Skip to content

Commit 3e8dc56

Browse files
BowenBaofacebook-github-bot
authored andcommitted
Bug fix: ONNX export full operator (#21669)
Summary: Fix an obvious bug. Pull Request resolved: #21669 Reviewed By: zrphercule Differential Revision: D15806614 Pulled By: houseroad fbshipit-source-id: d0f6e934252e0057f3dbcc7f160236ee6f4497ac
1 parent 4b45f08 commit 3e8dc56

File tree

2 files changed

+12
-2
lines changed

2 files changed

+12
-2
lines changed

test/onnx/test_pytorch_onnx_caffe2.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1250,6 +1250,16 @@ def forward(self, x):
12501250
self.run_model_test(TensorFactory(), train=False, input=(x,), batch_size=BATCH_SIZE,
12511251
use_gpu=False, example_outputs=(torch.ones(x.size()),))
12521252

1253+
def test_full_script(self):
1254+
class FullClass(torch.jit.ScriptModule):
1255+
@torch.jit.script_method
1256+
def forward(self, x):
1257+
return torch.full((4, 5), x, dtype=torch.long)
1258+
1259+
x = torch.tensor(12)
1260+
self.run_model_test(FullClass(), train=False, input=(x,), batch_size=BATCH_SIZE,
1261+
use_gpu=False, example_outputs=FullClass()(x))
1262+
12531263
def test_where_functional(self):
12541264
class WhereFunctional(torch.nn.Module):
12551265
def forward(self, x):

torch/onnx/symbolic_opset9.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1132,8 +1132,8 @@ def ones_like(g, input, dtype, layout, device, pin_memory=False):
11321132
def full(g, sizes, value, dtype, layout, device, pin_memory=False):
11331133
const_value = sym_help._maybe_get_const(value, 't')
11341134
if sym_help._is_value(const_value):
1135-
tmp = zeros(sizes, dtype, layout, device)
1136-
return add(tmp, value, g.op("Constant", value_t=torch.tensor(1)))
1135+
tmp = zeros(g, sizes, dtype, layout, device)
1136+
return add(g, tmp, value, g.op("Constant", value_t=torch.tensor(1)))
11371137
else:
11381138
dtype = sym_help._get_const(dtype, 'i', 'dtype')
11391139
return g.op("ConstantOfShape", sizes,

0 commit comments

Comments
 (0)