Skip to content

Commit da70976

Browse files
BowenBaofacebook-github-bot
authored andcommitted
[ONNX] Add support for operator add between tensor list (#41888)
Summary: E.g. ```python outs = [] outs += [torch.randn(3,4)] outs = outs + [torch.randn(4,5), torch.randn(5,6)] ``` Pull Request resolved: #41888 Reviewed By: houseroad Differential Revision: D23172880 Pulled By: bzinodev fbshipit-source-id: 93865106e3de5908a993e0cfa82f626ba94dab7e
1 parent c64594f commit da70976

File tree

3 files changed

+23
-1
lines changed

3 files changed

+23
-1
lines changed

test/onnx/test_pytorch_onnx_onnxruntime.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2700,11 +2700,15 @@ def forward(self, x):
27002700
res1 = []
27012701
arr = x.split([3, 4, 1, 1, 2, 3, 2], 0)
27022702
res2 = torch.zeros(3, 4, dtype=torch.long)
2703+
res3 = []
2704+
res4 = []
27032705
for i in range(len(arr)):
27042706
res = res.append(arr[i].sum(0, False))
27052707
res1 = res1.append(arr[-1 - i].sum(0, False))
27062708
res2 += 1
2707-
return torch.stack(res), torch.stack(res1), res2
2709+
res3 = res3 + [arr[i].sum(0, False)]
2710+
res4 += [arr[-1 - i].sum(0, False)]
2711+
return torch.stack(res), torch.stack(res1), res2, torch.stack(res3), torch.stack(res4)
27082712

27092713
model = ListLoopModel()
27102714
inputs = torch.randn(16)
@@ -2723,6 +2727,8 @@ def forward(self, x):
27232727

27242728
res.insert(0, tensors[1])
27252729
res.append(tensors[2])
2730+
res += [tensors[3], tensors[4]]
2731+
res = res + [tensors[5]]
27262732
return torch.ones(len(res))
27272733

27282734
model = ListModel()

torch/onnx/symbolic_opset11.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -284,6 +284,19 @@ def append(g, self, tensor):
284284
return g.op("SequenceInsert", self, tensor)
285285

286286

287+
def add(g, self, other, alpha=None):
288+
if sym_help._is_value(self) and sym_help._is_tensor_list(self):
289+
tensor_list_node = other.node()
290+
if tensor_list_node.kind() != "prim::ListConstruct":
291+
return _unimplemented("add", "does not support adding dynamic tensor list to another")
292+
tensors = sym_help._unpack_list(other)
293+
l = self
294+
for t in tensors:
295+
l = g.op("SequenceInsert", l, t)
296+
return l
297+
298+
return torch.onnx.symbolic_opset9.add(g, self, other, alpha)
299+
287300
def insert(g, self, pos, tensor):
288301
return g.op("SequenceInsert", self, tensor, pos)
289302

torch/onnx/symbolic_opset9.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,9 @@ def reshape_as(g, self, other):
8383

8484

8585
def add(g, self, other, alpha=None):
86+
if sym_help._is_value(self) and sym_help._is_tensor_list(self):
87+
return sym_help._onnx_opset_unsupported_detailed('Add', 9, 11, 'Add between list of tensors not supported')
88+
8689
# default alpha arg is to allow no-alpha add (aten add st overload no alpha)
8790
if alpha and sym_help._scalar(sym_help._maybe_get_scalar(alpha)) != 1:
8891
return _unimplemented("add", "alpha != 1")

0 commit comments

Comments
 (0)