18 changes: 18 additions & 0 deletions test/onnx/test_pytorch_onnx_caffe2.py
@@ -1651,6 +1651,24 @@ def forward(self, input):
                return view_by_prim_shape(input)
        self.run_model_test(PrimShapeModel(), train=False, input=x, batch_size=BATCH_SIZE)

+    def test_and(self):
+        class AndModel(torch.nn.Module):
+            def forward(self, x, y):
+                return x & y
+
+        x = torch.randint(0, 2, (3, 5))
+        y = torch.randint(0, 2, (3, 5))
+        self.run_model_test(AndModel(), train=False, input=(x, y), batch_size=BATCH_SIZE)
+
+    def test_or(self):
+        class OrModel(torch.nn.Module):
+            def forward(self, x, y):
+                return x | y
+
+        x = torch.randint(0, 2, (3, 5))
+        y = torch.randint(0, 2, (3, 5))
+        self.run_model_test(OrModel(), train=False, input=(x, y), batch_size=BATCH_SIZE)

    # a bit of metaprogramming to set up all the rnn tests


2 changes: 1 addition & 1 deletion torch/onnx/symbolic_helper.py
@@ -247,9 +247,9 @@ def _set_opset_version(opset_version):
    'Int': torch.onnx.TensorProtoDataType.INT32,
    'Long': torch.onnx.TensorProtoDataType.INT64,
    'Short': torch.onnx.TensorProtoDataType.INT16,
+    'Bool': torch.onnx.TensorProtoDataType.BOOL,
}


scalar_name_to_pytorch = {
    'uint8_t': 'Byte',
    'int8_t': 'Char',
42 changes: 32 additions & 10 deletions torch/onnx/symbolic_opset9.py
@@ -681,10 +681,22 @@ def upsample_bilinear2d(g, input, output_size, align_corners):
                mode_s="linear")


-def wrap_logical_op_with_cast_to_uint8(func):
-    def wrap_with_cast(g, input, other):
-        return g.op("Cast", func(g, input, other), to_i=sym_help.cast_pytorch_to_onnx['Byte'])
-    return wrap_with_cast
+def wrap_logical_op_with_cast_to(to_type):
+    def decorator(fn):
+        def wrap_with_cast(g, input, other):
+            return g.op("Cast", fn(g, input, other), to_i=sym_help.cast_pytorch_to_onnx[to_type])
+        return wrap_with_cast
+    return decorator
+
+
+def wrap_logical_op_with_cast_to_and_from(to_type):
+    def decorator(fn):
+        def wrap_with_cast(g, input, other):
+            to_cast_func = globals()['_cast_{}'.format(to_type)]
+            from_cast_func = wrap_logical_op_with_cast_to(input.type().scalarType())(fn)
+            return from_cast_func(g, to_cast_func(g, input, False), to_cast_func(g, other, False))
+        return wrap_with_cast
+    return decorator


def wrap_logical_op_with_negation(func):
@@ -693,18 +705,18 @@ def wrap_with_not(g, input, other):
    return wrap_with_not


-@wrap_logical_op_with_cast_to_uint8
+@wrap_logical_op_with_cast_to('Byte')
def eq(g, self, other):
    return g.op("Equal", self, other)


-@wrap_logical_op_with_cast_to_uint8
+@wrap_logical_op_with_cast_to('Byte')
@wrap_logical_op_with_negation
def ne(g, self, other):
    return g.op("Equal", self, other)


-@wrap_logical_op_with_cast_to_uint8
+@wrap_logical_op_with_cast_to('Byte')
def gt(g, input, other):
    return gt_impl(g, input, other)

@@ -714,7 +726,7 @@ def gt_impl(g, input, other):
    return g.op("Greater", input, sym_help._if_scalar_type_as(g, other, input))


-@wrap_logical_op_with_cast_to_uint8
+@wrap_logical_op_with_cast_to('Byte')
def lt(g, input, other):
    return lt_impl(g, input, other)

@@ -724,20 +736,30 @@ def lt_impl(g, input, other):
    return g.op("Less", input, sym_help._if_scalar_type_as(g, other, input))


-@wrap_logical_op_with_cast_to_uint8
+@wrap_logical_op_with_cast_to('Byte')
@wrap_logical_op_with_negation
def ge(g, input, other):
    other = sym_help._maybe_get_scalar(other)
    return lt_impl(g, input, sym_help._if_scalar_type_as(g, other, input))


-@wrap_logical_op_with_cast_to_uint8
+@wrap_logical_op_with_cast_to('Byte')
@wrap_logical_op_with_negation
def le(g, input, other):
    other = sym_help._maybe_get_scalar(other)
    return gt_impl(g, input, sym_help._if_scalar_type_as(g, other, input))


+@wrap_logical_op_with_cast_to_and_from('Bool')
+def __and_(g, input, other):
+    return g.op('And', input, other)
+
+
+@wrap_logical_op_with_cast_to_and_from('Bool')
+def __or_(g, input, other):
+    return g.op('Or', input, other)


def where(g, condition, self, other):
    return g.op("Where", condition, self, other)
