Skip to content

Commit 52de340

Browse files
BowenBaofacebook-github-bot
authored andcommitted
Export torch.masked_fill with onnx::where
Summary: Pull Request resolved: #22521 Reviewed By: zrphercule Differential Revision: D16155168 Pulled By: houseroad fbshipit-source-id: 5d419f08213324d474b839ba1ae13c799aeee92a
1 parent 6c99753 commit 52de340

File tree

4 files changed

+43
-0
lines changed

4 files changed

+43
-0
lines changed

test/onnx/test_pytorch_onnx_caffe2.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1956,6 +1956,24 @@ def forward(self, x):
19561956
inputs = torch.randn(3, 2, 1)
19571957
self.run_model_test(model, train=False, input=(inputs, ), batch_size=BATCH_SIZE)
19581958

1959+
1960+
@skipIfUnsupportedMinOpsetVersion(9)
1961+
def test_masked_fill(self):
1962+
class MaskedFillModel(torch.nn.Module):
1963+
def forward(self, x):
1964+
mask = torch.tensor([[0, 0, 1], [1, 1, 0]], dtype=torch.uint8)
1965+
return x.masked_fill(mask, 2)
1966+
1967+
x = torch.zeros(4, 2, 3, requires_grad=True)
1968+
self.run_model_test(MaskedFillModel(), input=(x, ), train=False, batch_size=BATCH_SIZE)
1969+
1970+
class MaskedFillModel2(torch.nn.Module):
1971+
def forward(self, x):
1972+
return x.masked_fill(x > 3, -1)
1973+
1974+
x = torch.arange(16).view(2, 2, 4).to(torch.float32)
1975+
self.run_model_test(MaskedFillModel2(), input=(x, ), train=False, batch_size=BATCH_SIZE)
1976+
19591977
# a bit of metaprogramming to set up all the rnn tests
19601978

19611979

test/onnx/test_pytorch_onnx_onnxruntime.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -364,6 +364,23 @@ def forward(self, x):
364364
x = torch.randn(2, 3, 4)
365365
self.run_test(TensorFactory(), x)
366366

367+
@skipIfUnsupportedMinOpsetVersion(9)
368+
def test_masked_fill(self):
369+
class MaskedFillModel(torch.nn.Module):
370+
def forward(self, x):
371+
mask = torch.tensor([[0, 0, 1], [1, 1, 0]], dtype=torch.uint8)
372+
return x.masked_fill(mask, 2)
373+
374+
x = torch.zeros(4, 2, 3, requires_grad=True)
375+
self.run_test(MaskedFillModel(), x)
376+
377+
class MaskedFillModel2(torch.nn.Module):
378+
def forward(self, x):
379+
return x.masked_fill(x > 3, -1)
380+
381+
x = torch.arange(16).view(2, 2, 4).to(torch.float32)
382+
self.run_test(MaskedFillModel2(), x)
383+
367384

368385
# opset 7 tests
369386
TestONNXRuntime_opset7 = type(str("TestONNXRuntime_opset7"),
@@ -375,6 +392,7 @@ def forward(self, x):
375392
(unittest.TestCase,),
376393
dict(TestONNXRuntime.__dict__, opset_version=8))
377394

395+
378396
# opset 10 tests
379397
TestONNXRuntime_opset10 = type(str("TestONNXRuntime_opset10"),
380398
(unittest.TestCase,),

torch/onnx/symbolic_opset8.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@
3939

4040
black_listed_operators = [
4141
"nonzero", "where", "scatter", "scatter_add", "erf", "sign", "isnan", "gather",
42+
"masked_fill"
4243
]
4344

4445
for black_listed_op in black_listed_operators:

torch/onnx/symbolic_opset9.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1664,3 +1664,9 @@ def gather(g, self, dim, index, sparse_grad=False):
16641664
@parse_args('v', 'is', 'i')
16651665
def logsumexp(g, input, dim, keepdim):
16661666
return g.op('ReduceLogSumExp', input, axes_i=dim, keepdims_i=keepdim)
1667+
1668+
1669+
def masked_fill(g, self, mask, value):
1670+
mask = _cast_Bool(g, mask, False)
1671+
value = sym_help._maybe_get_scalar(value)
1672+
return g.op('Where', mask, sym_help._if_scalar_type_as(g, value, self), self)

0 commit comments

Comments
 (0)