Skip to content

Commit 438f12d

Browse files
SherlockNoMadpytorchmergebot
authored andcommitted
Rewrite some decomps to allow producing aten ops (#93099)
This introduces a new stop to the decomposition train. Before reaching prims.view_of, it will stop at aten.alias. Export path wants to get off the train at aten ops. Pull Request resolved: #93099 Approved by: https://github.com/ngimel
1 parent 332d55d commit 438f12d

File tree

3 files changed

+7
-6
lines changed

3 files changed

+7
-6
lines changed

test/expect/HasDecompTest.test_has_decomposition.expect

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -559,7 +559,6 @@ aten::addmv.out
559559
aten::addr_
560560
aten::affine_grid_generator
561561
aten::affine_grid_generator.out
562-
aten::alias
563562
aten::alias_copy
564563
aten::alias_copy.out
565564
aten::allclose

torch/_prims_common/wrappers.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -34,12 +34,9 @@ def _maybe_convert_to_dtype(a: None, dtype: torch.dtype) -> None:
3434

3535
# TODO: implement ref.cast with an option to enforce safe casting
3636
def _maybe_convert_to_dtype(a, dtype):
37-
import torch._prims as prims
3837
if isinstance(a, TensorLike):
3938
if a.dtype != dtype:
40-
# NOTE: this is incorrect on the CPU
41-
# See https://github.com/pytorch/pytorch/issues/77553
42-
return prims.convert_element_type(a, dtype)
39+
return a.to(dtype)
4340
return a
4441
if isinstance(a, Number):
4542
return utils.dtype_to_type_ctor(dtype)(a) # type: ignore[arg-type]

torch/_refs/__init__.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3905,12 +3905,17 @@ def T(a: TensorLikeType) -> TensorLikeType:
39053905
return a.t()
39063906

39073907

3908+
@register_decomposition(aten.alias)
3909+
def alias(a: TensorLikeType) -> TensorLikeType:
3910+
return prims.view_of(a)
3911+
3912+
39083913
@register_decomposition(aten.transpose)
39093914
def transpose(a: TensorLikeType, dim0: int, dim1: int) -> TensorLikeType:
39103915
_dim0, _dim1 = utils.canonicalize_dims(a.ndim, (dim0, dim1)) # type: ignore[misc]
39113916

39123917
if a.ndim <= 1 or dim0 == dim1:
3913-
return prims.view_of(a)
3918+
return aten.alias.default(a)
39143919

39153920
_permutation = list(range(0, a.ndim))
39163921
_permutation[_dim0] = _dim1

0 commit comments

Comments
 (0)