Skip to content

Commit ab2c874

Browse files
committed
Add __and__, __or__ onnx export
Nits; refactor wrapper; merge fix; flake8: fix line too long; fix typo; fix regression on not supporting boolean and/or; rename wrapped func.
1 parent 6741471 commit ab2c874

File tree

3 files changed

+51
-11
lines changed

3 files changed

+51
-11
lines changed

test/onnx/test_pytorch_onnx_caffe2.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1651,6 +1651,24 @@ def forward(self, input):
16511651
return view_by_prim_shape(input)
16521652
self.run_model_test(PrimShapeModel(), train=False, input=x, batch_size=BATCH_SIZE)
16531653

1654+
def test_and(self):
    """Export Tensor.__and__ (x & y) to ONNX and check it against the Caffe2 backend."""
    class AndModel(torch.nn.Module):
        def forward(self, x, y):
            return x & y

    # torch.randint's upper bound is exclusive, so the bound must be 2 to get a
    # mix of 0s and 1s. With (0, 1, ...) both inputs are all zeros and the test
    # could never distinguish a correct And export from a broken one.
    x = torch.randint(0, 2, (3, 5))
    y = torch.randint(0, 2, (3, 5))
    self.run_model_test(AndModel(), train=False, input=(x, y), batch_size=BATCH_SIZE)
1662+
1663+
def test_or(self):
    """Export Tensor.__or__ (x | y) to ONNX and check it against the Caffe2 backend."""
    class OrModel(torch.nn.Module):
        def forward(self, x, y):
            return x | y

    # torch.randint's upper bound is exclusive, so the bound must be 2 to get a
    # mix of 0s and 1s. With (0, 1, ...) both inputs are all zeros and the test
    # could never distinguish a correct Or export from a broken one.
    x = torch.randint(0, 2, (3, 5))
    y = torch.randint(0, 2, (3, 5))
    self.run_model_test(OrModel(), train=False, input=(x, y), batch_size=BATCH_SIZE)
1671+
16541672
# a bit of metaprogramming to set up all the rnn tests
16551673

16561674

torch/onnx/symbolic_helper.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -247,9 +247,9 @@ def _set_opset_version(opset_version):
247247
'Int': torch.onnx.TensorProtoDataType.INT32,
248248
'Long': torch.onnx.TensorProtoDataType.INT64,
249249
'Short': torch.onnx.TensorProtoDataType.INT16,
250+
'Bool': torch.onnx.TensorProtoDataType.BOOL,
250251
}
251252

252-
253253
scalar_name_to_pytorch = {
254254
'uint8_t': 'Byte',
255255
'int8_t': 'Char',

torch/onnx/symbolic_opset9.py

Lines changed: 32 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -681,10 +681,22 @@ def upsample_bilinear2d(g, input, output_size, align_corners):
681681
mode_s="linear")
682682

683683

684-
def wrap_logical_op_with_cast_to_uint8(func):
685-
def wrap_with_cast(g, input, other):
686-
return g.op("Cast", func(g, input, other), to_i=sym_help.cast_pytorch_to_onnx['Byte'])
687-
return wrap_with_cast
684+
def wrap_logical_op_with_cast_to(to_type):
    """Decorator factory: cast the wrapped symbolic's result to the ONNX dtype
    named by `to_type` (a key of `sym_help.cast_pytorch_to_onnx`, e.g. 'Byte'
    or 'Bool')."""
    from functools import wraps

    def decorator(fn):
        # functools.wraps preserves fn's name/docstring on the wrapper, which
        # keeps debugging and introspection of symbolic functions readable.
        @wraps(fn)
        def wrap_with_cast(g, input, other):
            return g.op("Cast", fn(g, input, other),
                        to_i=sym_help.cast_pytorch_to_onnx[to_type])
        return wrap_with_cast
    return decorator
690+
691+
692+
def wrap_logical_op_with_cast_to_and_from(to_type):
    # Decorator factory: cast both operands to `to_type` (e.g. 'Bool') before
    # applying the wrapped symbolic `fn`, then cast the result back to the
    # input's original scalar type, so the caller-visible dtype is unchanged.
    def decorator(fn):
        def wrap_with_cast(g, input, other):
            # Resolved lazily at call time: the _cast_<Type> helpers are
            # generated dynamically elsewhere in this module, so looking them
            # up here (not at decoration time) avoids definition-order issues.
            to_cast_func = globals()['_cast_{}'.format(to_type)]
            # Reuse wrap_logical_op_with_cast_to to cast the result back to
            # the input tensor's own scalar type.
            from_cast_func = wrap_logical_op_with_cast_to(input.type().scalarType())(fn)
            # NOTE(review): the trailing False is presumably the cast's
            # non_blocking flag — confirm against the _cast_* signature.
            return from_cast_func(g, to_cast_func(g, input, False), to_cast_func(g, other, False))
        return wrap_with_cast
    return decorator
688700

689701

690702
def wrap_logical_op_with_negation(func):
@@ -693,18 +705,18 @@ def wrap_with_not(g, input, other):
693705
return wrap_with_not
694706

695707

696-
@wrap_logical_op_with_cast_to_uint8
708+
@wrap_logical_op_with_cast_to('Byte')
def eq(g, self, other):
    # Element-wise equality via ONNX Equal; the decorator casts the result
    # to 'Byte' (uint8).
    return g.op("Equal", self, other)
699711

700712

701-
@wrap_logical_op_with_cast_to_uint8
713+
@wrap_logical_op_with_cast_to('Byte')
@wrap_logical_op_with_negation
def ne(g, self, other):
    # ne is built as the negation of Equal (inner decorator), with the final
    # result cast to 'Byte' by the outer decorator.
    return g.op("Equal", self, other)
705717

706718

707-
@wrap_logical_op_with_cast_to_uint8
719+
@wrap_logical_op_with_cast_to('Byte')
def gt(g, input, other):
    # Delegates to gt_impl so other symbolics (e.g. le) can reuse the raw,
    # uncast comparison; the decorator casts the result to 'Byte'.
    return gt_impl(g, input, other)
710722

@@ -714,7 +726,7 @@ def gt_impl(g, input, other):
714726
return g.op("Greater", input, sym_help._if_scalar_type_as(g, other, input))
715727

716728

717-
@wrap_logical_op_with_cast_to_uint8
729+
@wrap_logical_op_with_cast_to('Byte')
def lt(g, input, other):
    # Delegates to lt_impl so other symbolics (e.g. ge) can reuse the raw,
    # uncast comparison; the decorator casts the result to 'Byte'.
    return lt_impl(g, input, other)
720732

@@ -724,20 +736,30 @@ def lt_impl(g, input, other):
724736
return g.op("Less", input, sym_help._if_scalar_type_as(g, other, input))
725737

726738

727-
@wrap_logical_op_with_cast_to_uint8
739+
@wrap_logical_op_with_cast_to('Byte')
@wrap_logical_op_with_negation
def ge(g, input, other):
    # ge == not (input < other): negated lt (inner decorator), result cast
    # to 'Byte' (outer decorator). `other` may be a scalar, so normalize it
    # and match it to input's type first.
    other = sym_help._maybe_get_scalar(other)
    return lt_impl(g, input, sym_help._if_scalar_type_as(g, other, input))
732744

733745

734-
@wrap_logical_op_with_cast_to_uint8
746+
@wrap_logical_op_with_cast_to('Byte')
@wrap_logical_op_with_negation
def le(g, input, other):
    # le == not (input > other): negated gt (inner decorator), result cast
    # to 'Byte' (outer decorator). `other` may be a scalar, so normalize it
    # and match it to input's type first.
    other = sym_help._maybe_get_scalar(other)
    return gt_impl(g, input, sym_help._if_scalar_type_as(g, other, input))
739751

740752

753+
@wrap_logical_op_with_cast_to_and_from('Bool')
def __and_(g, input, other):
    # Element-wise AND: the decorator casts both operands to Bool, ONNX 'And'
    # is applied, and the result is cast back to the input's scalar type.
    return g.op('And', input, other)
756+
757+
758+
@wrap_logical_op_with_cast_to_and_from('Bool')
def __or_(g, input, other):
    # Element-wise OR: the decorator casts both operands to Bool, ONNX 'Or'
    # is applied, and the result is cast back to the input's scalar type.
    return g.op('Or', input, other)
761+
762+
741763
def where(g, condition, self, other):
742764
return g.op("Where", condition, self, other)
743765

0 commit comments

Comments
 (0)