Skip to content

Commit cd7c86e

Browse files
IvanYashchuk and pytorchmergebot
authored and committed
Add prims.clone (#86705)
This simple PR adds `clone` as a primitive. The current implementation of `clone` is not supported by the nvFuser executor because it is built from `empty_like` + `copy_to`. Pull Request resolved: #86705. Approved by: https://github.com/mruberry
1 parent 3356d03 commit cd7c86e

File tree

4 files changed

+42
-5
lines changed

4 files changed

+42
-5
lines changed

torch/_prims/__init__.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -159,6 +159,7 @@
159159
#
160160
# Data conversion and movement prims
161161
#
162+
"clone",
162163
"convert_element_type",
163164
"device_put",
164165
"item",
@@ -595,6 +596,40 @@ def _conj_physical_meta(input: TensorLikeType) -> TensorLikeType:
595596
return_type=RETURN_TYPE.NEW,
596597
)
597598

599+
600+
def _clone_meta(
    input: TensorLikeType, *, memory_format: torch.memory_format = torch.preserve_format
) -> TensorLikeType:
    """Meta function for the ``clone`` prim: build the output's metadata.

    Allocates an uninitialized tensor that mirrors ``input``'s shape, dtype,
    layout, device and ``requires_grad`` flag; no data is copied here.

    Args:
        input: the tensor being cloned.
        memory_format: requested memory format of the result. ``None`` is
            accepted because the prim schema declares
            ``MemoryFormat? memory_format=None``; it is treated the same as
            ``torch.preserve_format``.

    Returns:
        A new tensor whose strides either follow ``memory_format`` or, for
        ``torch.preserve_format``, are derived from ``input`` via
        ``utils.compute_elementwise_output_strides``.
    """
    # The schema-level default is None while the Python-level default is
    # torch.preserve_format. Normalize so both spellings take the
    # stride-preserving path instead of None silently falling through to the
    # explicit-format branch below.
    if memory_format is None:
        memory_format = torch.preserve_format

    if memory_format == torch.preserve_format:
        strides = utils.compute_elementwise_output_strides(input)
        return torch.empty_strided(
            input.shape,
            strides,
            dtype=input.dtype,
            layout=input.layout,
            device=input.device,
            requires_grad=input.requires_grad,
        )

    # An explicit, non-preserving memory format was requested.
    return torch.empty(
        input.shape,
        dtype=input.dtype,
        layout=input.layout,
        device=input.device,
        requires_grad=input.requires_grad,
        memory_format=memory_format,
    )
623+
624+
625+
# Register `clone` as a primitive: `_clone_meta` supplies the output metadata
# and `torch.clone` is the ATen implementation. The schema exposes
# `memory_format` as an optional keyword-only argument defaulting to None.
clone = _make_prim(
    schema="clone(Tensor self, *, MemoryFormat? memory_format=None) -> Tensor",
    return_type=RETURN_TYPE.NEW,
    meta=_clone_meta,
    impl_aten=torch.clone,
    doc="Returns the copy of a tensor",
)
632+
598633
digamma = _make_elementwise_unary_prim(
599634
"digamma",
600635
impl_aten=torch.digamma,

torch/_prims/nvfuser_prims.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@
4242
"atanh",
4343
"cos",
4444
"cosh",
45+
"clone",
4546
"bitwise_not",
4647
"ceil",
4748
"erf",
@@ -322,9 +323,14 @@ def _amin_nvfuser(
322323
return fd.ops.min(a, dims, keep_dims)
323324

324325

326+
def _clone_nvfuser(fd: Any, input: TensorLikeType, *, memory_format=None):
    """nvFuser implementation registered for the ``clone`` prim.

    Calls ``fd.ops.set`` on ``input`` and returns whatever that call yields.

    ``memory_format`` is accepted so the signature lines up with the prim's
    keyword argument, but it is not consulted here.
    """
    cloned = fd.ops.set(input)
    return cloned
328+
329+
325330
_nvfuser_impls["native_batch_norm"] = _native_batch_norm_nvfuser
326331
_nvfuser_impls["broadcast_in_dim"] = _broadcast_in_dim_nvfuser
327332
_nvfuser_impls["convert_element_type"] = _convert_element_type_nvfuser
333+
_nvfuser_impls["clone"] = _clone_nvfuser
328334
_nvfuser_impls["transpose"] = _transpose_nvfuser
329335
_nvfuser_impls["squeeze"] = _squeeze_nvfuser
330336
_nvfuser_impls["view_of"] = _view_of_nvfuser

torch/_refs/__init__.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1675,10 +1675,7 @@ def where(
16751675
def clone(
16761676
a: TensorLikeType, *, memory_format: torch.memory_format = torch.preserve_format
16771677
) -> TensorLikeType:
1678-
result = torch.empty_like(
1679-
a, requires_grad=a.requires_grad, memory_format=memory_format
1680-
)
1681-
copy_to(result, a)
1678+
result = prims.clone(a, memory_format=memory_format)
16821679
return result
16831680

16841681

torch/testing/_internal/common_methods_invocations.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17578,7 +17578,6 @@ def reference_flatten(input, start_dim=0, end_dim=-1):
1757817578
PythonRefInfo(
1757917579
"_refs.clone",
1758017580
torch_opinfo_name="clone",
17581-
supports_nvfuser=False,
1758217581
),
1758317582
#
1758417583
# View & Shape OpInfos

0 commit comments

Comments
 (0)