Commit a77d633

neginraoof authored and facebook-github-bot committed
[ONNX] Fix view for dynamic input shape (#43558)
Summary: Export of the view op with a dynamic input shape is broken when using tensors that have a dimension of size 0. This fix removes the symbolic's use of the static input size to resolve the issue.

Pull Request resolved: #43558
Reviewed By: ailzhang
Differential Revision: D23965090
Pulled By: bzinodev
fbshipit-source-id: 628e9d7ee5d53375f25052340ca6feabf7ba7c53
1 parent 5d1fee2 commit a77d633
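For context, the regression is reproducible by exporting a module that calls view on a dynamically-shaped input and then feeding the exported model a 0-element tensor. A minimal sketch mirroring the test added below (the export call and the file name are illustrative, not part of this commit):

import torch

class ViewModel(torch.nn.Module):
    def forward(self, input):
        input = input.view(-1, 2)
        return input.view(1, -1)

# Mark dim 0 as dynamic; before this fix, the exporter baked the traced
# static input size into the graph, which broke for 0-element inputs.
torch.onnx.export(ViewModel(), torch.ones(2), "view.onnx",
                  input_names=["input_1"],
                  dynamic_axes={"input_1": [0]})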

File tree: 5 files changed, +40 −36 lines

test/onnx/expect/TestOperators.test_view.expect (17 additions, 7 deletions)

@@ -3,16 +3,26 @@ producer_name: "pytorch"
 producer_version: "CURRENT_VERSION"
 graph {
   node {
-    input: "0"
     output: "1"
-    name: "Flatten_0"
-    op_type: "Flatten"
+    name: "Constant_0"
+    op_type: "Constant"
     attribute {
-      name: "axis"
-      i: 1
-      type: INT
+      name: "value"
+      t {
+        dims: 2
+        data_type: 7
+        raw_data: "\001\000\000\000\000\000\000\000\001\000\000\000\000\000\000\000"
+      }
+      type: TENSOR
     }
   }
+  node {
+    input: "0"
+    input: "1"
+    output: "2"
+    name: "Reshape_1"
+    op_type: "Reshape"
+  }
   name: "torch-jit-export"
   input {
     name: "0"
@@ -28,7 +38,7 @@ graph {
     }
   }
   output {
-    name: "1"
+    name: "2"
     type {
       tensor_type {
         elem_type: 1
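The Constant's raw_data holds the Reshape target shape as little-endian int64 values (data_type 7 is INT64 in ONNX). A quick decode as a sanity check (sketch only):

import struct

raw = b"\x01\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00"
# Two little-endian int64 values: the shape [1, 1] fed to Reshape.
print(struct.unpack("<2q", raw))  # (1, 1)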

test/onnx/test_pytorch_onnx_onnxruntime.py (11 additions, 0 deletions)

@@ -2649,6 +2649,17 @@ def forward(self, input, other):
         shape = torch.randn(6, 4)
         self.run_test(ViewModel(), (x, shape))
 
+    def test_view_dynamic_zero_dim(self):
+        class ViewModel(torch.nn.Module):
+            def forward(self, input):
+                input = input.view(-1, 2)
+                return input.view(1, -1)
+
+        x = torch.ones(2)
+        another_x = torch.empty((0,))
+        self.run_test(ViewModel(), x, test_with_inputs=[another_x],
+                      input_names=['input_1'], dynamic_axes={'input_1': [0, ]})
+
     def test_view_as(self):
         class ViewModel(torch.nn.Module):
             def forward(self, input, other):
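run_test here exports the model once and then validates it in onnxruntime against both the original input and the extra inputs in test_with_inputs, which is what exercises the dynamic axis. Outside the test harness, the equivalent check on the 0-element input looks roughly like this (a sketch, assuming the model was exported to view.onnx as above and onnxruntime is installed):

import numpy as np
import onnxruntime as ort

sess = ort.InferenceSession("view.onnx")
# A 0-element input: view(-1, 2) yields shape (0, 2), view(1, -1) yields (1, 0).
out = sess.run(None, {"input_1": np.empty((0,), dtype=np.float32)})
print(out[0].shape)  # expected: (1, 0)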

torch/onnx/symbolic_opset12.py (2 additions, 2 deletions)

@@ -65,7 +65,7 @@ def celu(g, self, alpha):
 def argmax(g, input, dim, keepdim):
     if sym_help._is_none(dim):
         from torch.onnx.symbolic_opset9 import reshape
-        flattened = reshape(g, input, (-1,))
+        flattened = reshape(g, input, g.op("Constant", value_t=torch.tensor([-1])))
         return g.op('ArgMax', flattened, axis_i=0, keepdims_i=False, select_last_index_i=False)
     else:
         dim = _parse_arg(dim, 'i')
@@ -76,7 +76,7 @@ def argmax(g, input, dim, keepdim):
 def argmin(g, input, dim, keepdim):
     if sym_help._is_none(dim):
         from torch.onnx.symbolic_opset9 import reshape
-        flattened = reshape(g, input, (-1,))
+        flattened = reshape(g, input, g.op("Constant", value_t=torch.tensor([-1])))
         return g.op('ArgMin', flattened, axis_i=0, keepdims_i=False, select_last_index_i=False)
     else:
         dim = _parse_arg(dim, 'i')

torch/onnx/symbolic_opset8.py (1 addition, 15 deletions)

@@ -182,20 +182,6 @@ def addmm(g, self, mat1, mat2, beta, alpha):
     return g.op("Gemm", mat1, mat2, self, beta_f=sym_help._scalar(beta), alpha_f=sym_help._scalar(alpha))
 
 
-def view(g, self, size):
-    size = sym_help._maybe_get_const(size, 'is')
-    if sym_help._is_value(size):
-        shape = size
-    else:
-        if self.isCompleteTensor():
-            self_sizes = self.type().sizes()
-            if self_sizes and len(size) == 2 and self_sizes[0] == size[0]:
-                old_type, self = _try_cast_integer_to_float(g, self)
-                return _cast_to_type(g, g.op("Flatten", self, axis_i=1), old_type)
-        shape = g.op("Constant", value_t=torch.LongTensor(size))
-    return g.op("Reshape", self, shape)
-
-
 def flatten(g, input, start_dim, end_dim):
     start_dim_i = sym_help._get_const(start_dim, 'i', 'start_dim')
     end_dim_i = sym_help._get_const(end_dim, 'i', 'end_dim')
@@ -290,5 +276,5 @@ def repeat(g, self, repeats):
     sizes = self.type().sizes()
     diff_dims = repeat_size_len - len(sizes)
     if diff_dims > 0:
-        self = sym_opset9.view(g, self, [1] * diff_dims + sizes)
+        self = sym_opset9.view(g, self, g.op("Constant", value_t=torch.tensor([1] * diff_dims + sizes)))
     return g.op("Tile", self, repeats)

torch/onnx/symbolic_opset9.py (9 additions, 12 deletions)

@@ -374,7 +374,7 @@ def expand(g, self, size, implicit):
     # Expand with -1 dim value means dim is unchanged.
     # Since onnx::expand supports two-way broadcasting,
     # -1 dim value can be exported to onnx as 1
-    size = view(g, stack(g, size, 0), [-1])
+    size = view(g, stack(g, size, 0), g.op("Constant", value_t=torch.tensor([-1])))
     dtype = 4  # dim type is int64
     ones = ones_like(g, size, dtype)
     neg_ones = mul(g, ones, g.op("Constant", value_t=torch.tensor(-1)))
@@ -461,13 +461,10 @@ def view(g, self, size):
     if sym_help._is_value(size):
         shape = size
     else:
-        if self.isCompleteTensor():
-            self_sizes = self.type().sizes()
-            if self_sizes and len(size) == 2 and self_sizes[0] == size[0]:
-                return g.op("Flatten", self, axis_i=1)
         shape = g.op("Constant", value_t=torch.LongTensor(size))
     return g.op("Reshape", self, shape)
 
+
 def view_as(g, self, other):
     shape = g.op("Shape", other)
     return g.op("Reshape", self, shape)
@@ -1783,12 +1780,12 @@ def pixel_shuffle(g, self, upscale_factor):
     if len(dims) != 4:
         return _unimplemented("pixel_shuffle", "only support 4d input")
     output_channel = dims[1] // upscale_factor // upscale_factor
-    after_view = view(g, self, [-1, output_channel, upscale_factor, upscale_factor,
-                                dims[2], dims[3]])
+    after_view = view(g, self, g.op("Constant", value_t=torch.tensor([-1, output_channel, upscale_factor,
+                                                                      upscale_factor, dims[2], dims[3]])))
     after_transpose = g.op("Transpose", after_view, perm_i=[0, 1, 4, 2, 5, 3])
     return view(g, after_transpose,
-                [-1, output_channel, dims[2] * upscale_factor, dims[3] *
-                 upscale_factor])
+                g.op("Constant", value_t=torch.tensor([-1, output_channel, dims[2] * upscale_factor,
+                                                       dims[3] * upscale_factor])))
 
 
 def _generic_rnn(g, variant, input, initial_states, all_weights, has_biases,
@@ -2136,7 +2133,7 @@ def narrow(g, input, dim, start, length):
 
 def argmax(g, input, dim, keepdim):
     if sym_help._is_none(dim):
-        flattened = reshape(g, input, (-1,))
+        flattened = reshape(g, input, g.op("Constant", value_t=torch.tensor([-1])))
         return g.op('ArgMax', flattened, axis_i=0, keepdims_i=False)
     else:
         dim = _parse_arg(dim, 'i')
@@ -2146,7 +2143,7 @@ def argmax(g, input, dim, keepdim):
 
 def argmin(g, input, dim, keepdim):
     if sym_help._is_none(dim):
-        flattened = reshape(g, input, (-1,))
+        flattened = reshape(g, input, g.op("Constant", value_t=torch.tensor([-1])))
         return g.op('ArgMin', flattened, axis_i=0, keepdims_i=False)
     else:
         dim = _parse_arg(dim, 'i')
@@ -2453,7 +2450,7 @@ def baddbmm(g, self, batch1, batch2, beta, alpha):
 
 
 def meshgrid(g, tensor_list):
-    tensors = [view(g, t, torch.LongTensor([-1])) for t in sym_help._unpack_list(tensor_list)]
+    tensors = [view(g, t, g.op("Constant", value_t=torch.LongTensor([-1]))) for t in sym_help._unpack_list(tensor_list)]
     tensors_shape = [g.op("Shape", t) for t in tensors]
     out_shape = g.op("Concat", *tensors_shape, axis_i=0)
     out = []
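Two things change in concert across these files: the view symbolic drops its Flatten shortcut, which compared the traced static input size against the requested shape, and callers now pass target shapes as onnx::Constant nodes rather than Python sequences, so no static size leaks into the exported graph. A minimal sketch of the resulting pattern, assuming g is the graph-building context handed to symbolic functions (illustrative helper, not part of this commit):

import torch

def flatten_symbolic(g, input):
    # The shape lives in the graph as a Constant node, so the exported
    # model does not depend on the traced (static) input size.
    shape = g.op("Constant", value_t=torch.tensor([-1]))
    return g.op("Reshape", input, shape)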
