Skip to content

Commit 67256d5

Browse files
desertfirepytorchmergebot
authored andcommitted
[aotinductor] Solves a problem where a tensor is returned more than once (#112177)
Pull Request resolved: #112177 Approved by: https://github.com/zhxchen17
1 parent 7180357 commit 67256d5

File tree

2 files changed

+19
-1
lines changed

2 files changed

+19
-1
lines changed

test/inductor/test_aot_inductor.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1011,6 +1011,18 @@ def forward(self, x):
10111011
x = torch.randn(5, device=self.device)
10121012
self.check_model(Model(self.device), (x,))
10131013

1014+
def test_repeat_output(self):
1015+
class Model(torch.nn.Module):
1016+
def __init__(self):
1017+
super().__init__()
1018+
1019+
def forward(self, x):
1020+
y = torch.sin(x)
1021+
return y, y
1022+
1023+
example_inputs = (torch.randn(3, 10, device=self.device),)
1024+
self.check_model(Model(), example_inputs)
1025+
10141026

10151027
class AOTInductorTestABICompatibleCpu(TestCase):
10161028
device = "cpu"
@@ -1036,6 +1048,8 @@ class AOTInductorTestABICompatibleCpu(TestCase):
10361048
"test_freezing": TestFailure(("abi_compatible_cpu",), is_skip=True),
10371049
"test_normal_functional": TestFailure(("abi_compatible_cpu",)),
10381050
"test_poi_multiple_dynamic": TestFailure(("abi_compatible_cpu",)),
1051+
# There is a double-free issue which will be fixed in another PR
1052+
"test_repeat_output": TestFailure(("abi_compatible_cpu",), is_skip=True),
10391053
"test_sdpa": TestFailure(("abi_compatible_cpu",)),
10401054
"test_sdpa_2": TestFailure(("abi_compatible_cpu",)),
10411055
"test_simple_dynamic": TestFailure(("abi_compatible_cpu",)),
@@ -1058,6 +1072,8 @@ class AOTInductorTestABICompatibleCuda(TestCase):
10581072
{
10591073
"test_dup_unbacked_sym_decl": TestFailure(("abi_compatible_cuda",)),
10601074
"test_normal_functional": TestFailure(("abi_compatible_cuda",)),
1075+
# There is a double-free issue which will be fixed in another PR
1076+
"test_repeat_output": TestFailure(("abi_compatible_cuda",), is_skip=True),
10611077
},
10621078
)
10631079

torch/_export/exported_program.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,9 @@ def _unlift(gm, inp_pos_to_param_buffer_name, in_spec, out_spec, state_dict, buf
7171
# Step 2: Find the all the buffers that were mutated and update them
7272
if node.op == "output":
7373
user_output_nodes = []
74-
for return_node in node.all_input_nodes:
74+
# In the case that the same node is returned multiple times,
75+
# node.all_input_nodes will only iterate that node once
76+
for return_node in pytree.tree_flatten(node.args)[0]:
7577
return_node_name = return_node.name
7678
# we found a param/buffer mutation
7779
if return_node_name in buffers_to_mutate:

0 commit comments

Comments
 (0)