27 changes: 26 additions & 1 deletion caffe2/onnx/backend.cc
@@ -365,7 +365,8 @@ Caffe2Backend::get_special_operators() const {
{"DynamicSlice", &Caffe2Backend::CreateDynamicSlice},
{"RandomNormal", &Caffe2Backend::CreateRandomNormal},
{"RandomNormalLike", &Caffe2Backend::CreateRandomNormal},
{"Where", &Caffe2Backend::CreateWhereOp}};
{"Where", &Caffe2Backend::CreateWhereOp},
{"NonZero", &Caffe2Backend::CreateNonZeroOp}};
return kSpecialOperators;
}

@@ -598,6 +599,30 @@ Caffe2Ops Caffe2Backend::CreateWhereOp(
return CommonOnnxNodeToCaffe2Ops(&new_node, ctx);
}

Caffe2Ops Caffe2Backend::CreateNonZeroOp(
OnnxNode* onnx_node,
const ConversionContext& ctx) {
// Native Caffe2 doesn't support NonZero; fall back to ATen.
// ATen nonzero is equivalent to Transpose(ONNX::NonZero).
const auto& node = onnx_node->node;

onnx::NodeProto converted;
converted.CopyFrom(onnx_node->node);

auto nonzero_output = dummy_->NewDummyName();
converted.set_output(0, nonzero_output);
converted.set_op_type("ATen");
onnx::AttributeProto* attr = converted.add_attribute();
attr->set_name("operator");
attr->set_s("nonzero");
OnnxNode new_node(converted);
auto ret = CommonOnnxNodeToCaffe2Ops(&new_node, ctx);

auto* c2_transpose = ret.ops.Add();
BuildOperator(c2_transpose, "Transpose", {nonzero_output}, {onnx_node->node.output(0)});
return ret;
}

Caffe2Ops Caffe2Backend::CreateReciprocal(
OnnxNode* onnx_node,
const ConversionContext& ctx) {
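The converter above leans on the layout relationship between ATen's nonzero and the ONNX NonZero operator. A small numpy sketch of that relationship, purely illustrative (numpy stands in for both runtimes here):

import numpy as np

x = np.array([[0, 1, 1], [1, 0, 0]])

# PyTorch/ATen nonzero layout: shape (num_nonzero, ndim), one row per index tuple.
aten_style = np.argwhere(x)

# ONNX NonZero layout: shape (ndim, num_nonzero), one row per dimension.
onnx_style = np.stack(np.nonzero(x))

# Transposing one layout yields the other, which is why the converter emits an
# ATen nonzero op followed by a Caffe2 Transpose to produce the ONNX-shaped output.
assert (aten_style.T == onnx_style).all()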
2 changes: 2 additions & 0 deletions caffe2/onnx/backend.h
@@ -238,6 +238,8 @@ class CAFFE2_API Caffe2Backend {

Caffe2Ops CreateWhereOp(OnnxNode* onnx_node, const ConversionContext& ctx);

Caffe2Ops CreateNonZeroOp(OnnxNode* onnx_node, const ConversionContext& ctx);

Caffe2Ops CreateBatchNormalization(
OnnxNode* onnx_node,
const ConversionContext& ctx);
13 changes: 5 additions & 8 deletions test/onnx/test_onnx_opset.py
@@ -157,18 +157,15 @@ def __init__(self):
super(MyModule, self).__init__()

def forward(self, x):
return torch._dim_arange(x, 1)
return x - 1

module = MyModule()
ops_8 = [{"op_name" : "Shape"}, {"op_name" : "Constant"},
ops_8 = [{"op_name" : "Constant"},
{"op_name" : "Cast", "attributes": [{"name": "to", "i": 7, "type": 2}]},
{"op_name" : "Gather", "attributes": [{"name": "axis", "i": 0, "type": 2}]},
{"op_name" : "Range"}]
ops_9 = [{"op_name" : "Shape"}, {"op_name" : "Constant"},
{"op_name" : "Gather", "attributes": [{"name": "axis", "i": 0, "type": 2}]},
{"op_name" : "Range"}]
{"op_name" : "Sub"}]
ops_9 = [{"op_name" : "Constant"}, {"op_name" : "Sub"}]
ops = {8 : ops_8, 9 : ops_9}
x = torch.ones(5, 6)
x = torch.ones(5, 6, dtype=torch.long)
check_onnx_opsets_operator(module, x, ops, opset_versions=[8, 9])

def test_slice(self):
55 changes: 55 additions & 0 deletions test/onnx/test_pytorch_onnx_caffe2.py
@@ -1782,6 +1782,7 @@ def forward(self, input):
x = torch.randn(1, 2, 3, 4, requires_grad=True)
self.run_model_test(CeilModel(), train=False, input=x, batch_size=BATCH_SIZE)

@skipIfUnsupportedMinOpsetVersion(9)
def test__dim_arange(self):
class DimArange(torch.nn.Module):
def forward(self, input):
@@ -1790,6 +1791,60 @@ def forward(self, input):
x = torch.ones(5, 6)
self.run_model_test(DimArange(), train=False, input=x, batch_size=BATCH_SIZE)

@skipIfUnsupportedMinOpsetVersion(9)
def test_arange_end(self):
class ArangeScript(torch.jit.ScriptModule):
@torch.jit.script_method
def forward(self, a):
return torch.arange(a.size(0), dtype=torch.float).view(-1, 1) + a

x = torch.randn(3, 4, requires_grad=True)
outputs = ArangeScript()(x)
self.run_model_test(ArangeScript(), train=False, input=(x,), batch_size=BATCH_SIZE,
example_outputs=(outputs,))

class ArangeModel(torch.nn.Module):
def forward(self, a):
return torch.arange(a.size(0), dtype=torch.float).view(-1, 1) + a

self.run_model_test(ArangeModel(), train=False, input=(x,), batch_size=BATCH_SIZE)

@skipIfUnsupportedMinOpsetVersion(9)
def test_arange_start_end(self):
class ArangeScript(torch.jit.ScriptModule):
@torch.jit.script_method
def forward(self, a):
return torch.arange(2, a.size(0) + 2, dtype=torch.float).view(-1, 1) + a

x = torch.randn(3, 4, requires_grad=True)
outputs = ArangeScript()(x)
self.run_model_test(ArangeScript(), train=False, input=(x,), batch_size=BATCH_SIZE,
example_outputs=(outputs,))

class ArangeModel(torch.nn.Module):
def forward(self, a):
return torch.arange(2, a.size(0) + 2, dtype=torch.float).view(-1, 1) + a

self.run_model_test(ArangeModel(), train=False, input=(x,), batch_size=BATCH_SIZE)

@skipIfUnsupportedMinOpsetVersion(9)
def test_arange_start_end_step(self):
class ArangeScript(torch.jit.ScriptModule):
@torch.jit.script_method
def forward(self, a):
return torch.arange(2, a.size(0) * a.size(1) + 2, a.size(1), dtype=torch.float).view(-1, 1) + a

x = torch.randn(3, 4, requires_grad=True)
outputs = ArangeScript()(x)
self.run_model_test(ArangeScript(), train=False, input=(x,), batch_size=BATCH_SIZE,
example_outputs=(outputs,))

class ArangeModel(torch.nn.Module):
def forward(self, a):
return torch.arange(2, a.size(0) * a.size(1) + 2, a.size(1), dtype=torch.float).view(-1, 1) + a

self.run_model_test(ArangeModel(), train=False, input=(x,), batch_size=BATCH_SIZE)

def test_log2(self):
class Log2Model(torch.nn.Module):
def forward(self, input):
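For reference, the arange modules exercised above can also be exported standalone; a minimal sketch, assuming only the public torch.onnx.export entry point (opset 9 is needed because the lowering relies on ONNX NonZero):

import io
import torch

class ArangeEnd(torch.nn.Module):
    def forward(self, a):
        return torch.arange(a.size(0), dtype=torch.float).view(-1, 1) + a

# Export the traced module to an in-memory buffer at opset 9.
f = io.BytesIO()
torch.onnx.export(ArangeEnd(), torch.randn(3, 4), f, opset_version=9)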
60 changes: 60 additions & 0 deletions test/onnx/test_pytorch_onnx_onnxruntime.py
@@ -246,6 +246,66 @@ def forward(self, input, other):
y = torch.randn(4, 1, requires_grad=True)
self.run_test(model, (x, y))

@skipIfUnsupportedMinOpsetVersion(9)
def test_arange_end(self):
class ArangeScript(torch.jit.ScriptModule):
@torch.jit.script_method
def forward(self, a):
return torch.arange(a.size(0), dtype=torch.float).view(-1, 1) + a

x = torch.randn(3, 4, requires_grad=True)
outputs = ArangeScript()(x)
self.run_test(ArangeScript(), x)

class ArangeModel(torch.nn.Module):
def forward(self, a):
return torch.arange(a.size(0), dtype=torch.float).view(-1, 1) + a

self.run_test(ArangeModel(), x)

@skipIfUnsupportedMinOpsetVersion(9)
def test_arange_start_end(self):
class ArangeScript(torch.jit.ScriptModule):
@torch.jit.script_method
def forward(self, a):
return torch.arange(2, a.size(0) + 2, dtype=torch.float).view(-1, 1) + a

x = torch.randn(3, 4, requires_grad=True)
outputs = ArangeScript()(x)
self.run_test(ArangeScript(), x)

class ArangeModel(torch.nn.Module):
def forward(self, a):
return torch.arange(2, a.size(0) + 2, dtype=torch.float).view(-1, 1) + a

self.run_test(ArangeModel(), x)

@skipIfUnsupportedMinOpsetVersion(9)
def test_arange_start_end_step(self):
class ArangeScript(torch.jit.ScriptModule):
@torch.jit.script_method
def forward(self, a):
return torch.arange(2, a.size(0) * a.size(1) + 2, a.size(1), dtype=torch.float).view(-1, 1) + a

x = torch.randn(3, 4, requires_grad=True)
outputs = ArangeScript()(x)
self.run_test(ArangeScript(), x)

class ArangeModel(torch.nn.Module):
def forward(self, a):
return torch.arange(2, a.size(0) * a.size(1) + 2, a.size(1), dtype=torch.float).view(-1, 1) + a

self.run_test(ArangeModel(), x)

@skipIfUnsupportedMinOpsetVersion(9)
def test__dim_arange(self):
class DimArange(torch.nn.Module):
def forward(self, input):
return torch._dim_arange(input, 1)

x = torch.ones(5, 6)
self.run_test(DimArange(), x)

def test_gt(self):
class GreaterModel(torch.nn.Module):
def forward(self, input, other):
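As a rough standalone counterpart to these tests, the exported graph can be fed to onnxruntime directly; a sketch under the assumption that the model is exported with an explicit input name (the name "a" below is illustrative):

import io
import numpy as np
import onnxruntime
import torch

class ArangeEnd(torch.nn.Module):
    def forward(self, a):
        return torch.arange(a.size(0), dtype=torch.float).view(-1, 1) + a

x = torch.randn(3, 4)
f = io.BytesIO()
torch.onnx.export(ArangeEnd(), x, f, opset_version=9, input_names=["a"])

# Run the exported graph in onnxruntime and compare against eager PyTorch.
sess = onnxruntime.InferenceSession(f.getvalue())
(out,) = sess.run(None, {"a": x.numpy()})
np.testing.assert_allclose(out, ArangeEnd()(x).numpy(), rtol=1e-5)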
2 changes: 1 addition & 1 deletion torch/onnx/symbolic_opset8.py
@@ -39,7 +39,7 @@

black_listed_operators = [
"nonzero", "where", "scatter", "scatter_add", "erf", "sign", "isnan", "gather",
"masked_fill"
"arange", "masked_fill"
]

for black_listed_op in black_listed_operators:
44 changes: 43 additions & 1 deletion torch/onnx/symbolic_opset9.py
@@ -1484,7 +1484,11 @@ def symbolic(g, *args):
def _dim_arange(g, like, dim):
like_shape = g.op('Shape', like)
stop = g.op("Gather", like_shape, g.op("Constant", value_t=torch.tensor(dim)), axis_i=0)
return g.op("_caffe2::Range", stop)
if sym_help._operator_export_type == torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK:
return g.op("_caffe2::Range", stop)
else:
# aten::arange(Scalar end, ScalarType dtype, Layout, Device, bool pin_memory)
return arange(g, stop, 4, None, None, None)
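Outside the Caffe2 fallback path, _dim_arange now routes through the generic arange lowering defined below. The eager-mode relationship it encodes is simply the following (an illustrative check, not exporter code):

import torch

x = torch.ones(5, 6)
# torch._dim_arange(x, dim) enumerates indices along dimension `dim`.
assert (torch._dim_arange(x, 1) == torch.arange(x.size(1))).all()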


def detach(g, input):
@@ -1666,6 +1670,44 @@ def logsumexp(g, input, dim, keepdim):
return g.op('ReduceLogSumExp', input, axes_i=dim, keepdims_i=keepdim)


def arange(g, *args):
if sym_help._operator_export_type == torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK:
return g.op("ATen", *args, operator_s="arange")

def _get_arange_dtype(dtype):
dtype = sym_help._maybe_get_const(dtype, 'i')
if sym_help._is_value(dtype):
dtype = 4 # default to int64
return dtype

if len(args) == 5:
# aten::arange(Scalar end, ScalarType dtype, Layout, Device, bool pin_memory)
dtype = _get_arange_dtype(args[1])
end = g.op("Unsqueeze", args[0], axes_i=[0])
arange_tensor = g.op("Squeeze", nonzero(g, ones(g, end, dtype, *(args[2:]))), axes_i=[1])
return g.op("Cast", arange_tensor, to_i=sym_help.scalar_type_to_onnx[dtype])
elif len(args) == 6:
# aten::arange(Scalar start, Scalar end, ScalarType dtype, Layout, Device, bool pin_memory)
dtype = _get_arange_dtype(args[2])
end = g.op("Unsqueeze", args[1], axes_i=[0])
start = g.op("Unsqueeze", args[0], axes_i=[0])
range_tensor = g.op("Sub", end, start)
arange_tensor = g.op("Add", g.op("Squeeze", nonzero(g, ones(g, range_tensor, dtype, *(args[3:]))), axes_i=[1]), start)
return g.op("Cast", arange_tensor, to_i=sym_help.scalar_type_to_onnx[dtype])
elif len(args) == 7:
# aten::arange(Scalar start, Scalar end, Scalar step, ScalarType dtype, Layout, Device, bool pin_memory)
dtype = _get_arange_dtype(args[3])
step = g.op("Unsqueeze", args[2], axes_i=[0])
end = g.op("Unsqueeze", args[1], axes_i=[0])
start = g.op("Unsqueeze", args[0], axes_i=[0])
range_tensor = g.op("Div", g.op("Sub", end, start), step)
arange_tensor = g.op("Squeeze", nonzero(g, ones(g, range_tensor, dtype, *(args[4:]))), axes_i=[1])
arange_tensor = g.op("Add", g.op("Mul", arange_tensor, step), start)
return g.op("Cast", arange_tensor, to_i=sym_help.scalar_type_to_onnx[dtype])
else:
raise NotImplementedError("Unknown aten::arange signature taking " + str(len(args)) + " arguments.")


def masked_fill(g, self, mask, value):
mask = _cast_Bool(g, mask, False)
value = sym_help._maybe_get_scalar(value)
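The opset 9 arange symbolic above has no Range op to target, so it materializes the sequence with Ones -> NonZero -> Squeeze plus arithmetic. A rough numpy sketch of the start/end/step case, mirroring the emitted graph rather than the exporter API (an evenly divisible range is assumed, matching the Div in the symbolic):

import numpy as np

def arange_via_nonzero(start, end, step):
    # Sequence length, as the graph computes it with Sub and Div.
    length = (end - start) // step
    # Ones(length) -> NonZero -> Squeeze yields the indices [0, 1, ..., length - 1].
    base = np.nonzero(np.ones(length))[0]
    # Mul by step and Add start, matching the Mul and Add in the symbolic.
    return base * step + start

print(arange_via_nonzero(2, 14, 4))  # [ 2  6 10]
print(np.arange(2, 14, 4))           # [ 2  6 10]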