Skip to content

Commit f446b82

Browse files
anderspapittocolesbury
authored andcommitted
introduce shape_as_tensor and reshape_from_variable_shape (#5824)
1 parent 99b1f6c commit f446b82

File tree

1 file changed

+20
-0
lines changed

1 file changed

+20
-0
lines changed

torch/onnx/operators.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
import torch
2+
import torch.onnx
3+
4+
5+
def _shape_as_tensor(g, input):
6+
return g.op('Shape', input)
7+
8+
9+
@torch.onnx.symbolic_override(_shape_as_tensor)
10+
def shape_as_tensor(x):
11+
return torch.LongTensor(tuple(x.shape))
12+
13+
14+
def _reshape_from_tensor_shape(g, input, shape):
15+
return g.op('Reshape', input, shape)
16+
17+
18+
@torch.onnx.symbolic_override(_reshape_from_tensor_shape)
19+
def reshape_from_tensor_shape(x, shape):
20+
return x.reshape(shape.tolist())

0 commit comments

Comments
 (0)