Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 12 additions & 10 deletions test/expect/TestScript.test_export_dynamic_slice.expect
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ ModelProto {
GraphProto {
name: "torch-jit-export"
inputs: [{name: "x", type:Tensor dims: 3 4 5}]
outputs: [{name: "7", type:Tensor dims: 4 5}]
outputs: [{name: "8", type:Tensor dims: 4 5}]
initializers: []
nodes: [
Node {type: "Constant", inputs: [], outputs: [1], attributes: [{ name: 'value', type: tensor, value:TensorProto shape: []}]},
Expand All @@ -15,19 +15,21 @@ ModelProto {
Node {type: "Gather", inputs: [x,2], outputs: [4], attributes: [{ name: 'axis', type: int, value: 0}]},
Node {type: "Shape", inputs: [x], outputs: [5], attributes: []},
Node {type: "Gather", inputs: [5,3], outputs: [6], attributes: [{ name: 'axis', type: int, value: 0}]},
Node {type: "Loop", inputs: [6,1,4], outputs: [7], attributes: [{ name: 'body', type: graph, value:
Node {type: "Cast", inputs: [1], outputs: [7], attributes: [{ name: 'to', type: int, value: 9}]},
Node {type: "Loop", inputs: [6,7,4], outputs: [8], attributes: [{ name: 'body', type: graph, value:
GraphProto {
name: "torch-jit-export1"
inputs: [{name: "i", type:Tensor dims: },{name: "cond", type:Tensor dims: },{name: "10", type:Tensor dims: }]
outputs: [{name: "1", type:Tensor dims: },{name: "16", type:Tensor dims: }]
inputs: [{name: "i", type:Tensor dims: },{name: "cond", type:Tensor dims: },{name: "11", type:Tensor dims: }]
outputs: [{name: "18", type:Tensor dims: },{name: "17", type:Tensor dims: }]
initializers: []
nodes: [
Node {type: "Unsqueeze", inputs: [2], outputs: [11], attributes: [{ name: 'axes', type: ints, values: [0]}]},
Node {type: "Unsqueeze", inputs: [i], outputs: [12], attributes: [{ name: 'axes', type: ints, values: [0]}]},
Node {type: "Unsqueeze", inputs: [2], outputs: [13], attributes: [{ name: 'axes', type: ints, values: [0]}]},
Node {type: "DynamicSlice", inputs: [x,11,12,13], outputs: [14], attributes: []},
Node {type: "ReduceSum", inputs: [14], outputs: [15], attributes: [{ name: 'axes', type: ints, values: [0]},{ name: 'keepdims', type: int, value: 0}]},
Node {type: "Add", inputs: [10,15], outputs: [16], attributes: []}
Node {type: "Unsqueeze", inputs: [2], outputs: [12], attributes: [{ name: 'axes', type: ints, values: [0]}]},
Node {type: "Unsqueeze", inputs: [i], outputs: [13], attributes: [{ name: 'axes', type: ints, values: [0]}]},
Node {type: "Unsqueeze", inputs: [2], outputs: [14], attributes: [{ name: 'axes', type: ints, values: [0]}]},
Node {type: "DynamicSlice", inputs: [x,12,13,14], outputs: [15], attributes: []},
Node {type: "ReduceSum", inputs: [15], outputs: [16], attributes: [{ name: 'axes', type: ints, values: [0]},{ name: 'keepdims', type: int, value: 0}]},
Node {type: "Add", inputs: [11,16], outputs: [17], attributes: []},
Node {type: "Cast", inputs: [1], outputs: [18], attributes: [{ name: 'to', type: int, value: 9}]}
]
}

Expand Down
22 changes: 13 additions & 9 deletions test/expect/TestScript.test_onnx_export_script_module_loop.expect
Original file line number Diff line number Diff line change
Expand Up @@ -6,31 +6,35 @@ ModelProto {
GraphProto {
name: "torch-jit-export"
inputs: [{name: "x.1", type:Tensor dims: 1 2 3}]
outputs: [{name: "4", type:Tensor dims: 1 2 3}]
outputs: [{name: "5", type:Tensor dims: 1 2 3}]
initializers: []
nodes: [
Node {type: "Constant", inputs: [], outputs: [1], attributes: [{ name: 'value', type: tensor, value:TensorProto shape: []}]},
Node {type: "Constant", inputs: [], outputs: [2], attributes: [{ name: 'value', type: tensor, value:TensorProto shape: []}]},
Node {type: "Constant", inputs: [], outputs: [3], attributes: [{ name: 'value', type: tensor, value:TensorProto shape: []}]},
Node {type: "Loop", inputs: [2,1,x.1], outputs: [4], attributes: [{ name: 'body', type: graph, value:
Node {type: "Cast", inputs: [1], outputs: [4], attributes: [{ name: 'to', type: int, value: 9}]},
Node {type: "Loop", inputs: [2,4,x.1], outputs: [5], attributes: [{ name: 'body', type: graph, value:
GraphProto {
name: "torch-jit-export1"
inputs: [{name: "5", type:Tensor dims: },{name: "cond.1", type:Tensor dims: },{name: "7", type:Tensor dims: }]
outputs: [{name: "1", type:Tensor dims: },{name: "8", type:Tensor dims: }]
inputs: [{name: "6", type:Tensor dims: },{name: "cond.1", type:Tensor dims: },{name: "8", type:Tensor dims: }]
outputs: [{name: "16", type:Tensor dims: },{name: "10", type:Tensor dims: }]
initializers: []
nodes: [
Node {type: "Loop", inputs: [3,1,7], outputs: [8], attributes: [{ name: 'body', type: graph, value:
Node {type: "Cast", inputs: [1], outputs: [9], attributes: [{ name: 'to', type: int, value: 9}]},
Node {type: "Loop", inputs: [3,9,8], outputs: [10], attributes: [{ name: 'body', type: graph, value:
GraphProto {
name: "torch-jit-export2"
inputs: [{name: "i", type:Tensor dims: },{name: "cond", type:Tensor dims: },{name: "11", type:Tensor dims: }]
outputs: [{name: "1", type:Tensor dims: },{name: "12", type:Tensor dims: }]
inputs: [{name: "i", type:Tensor dims: },{name: "cond", type:Tensor dims: },{name: "13", type:Tensor dims: }]
outputs: [{name: "15", type:Tensor dims: },{name: "14", type:Tensor dims: }]
initializers: []
nodes: [
Node {type: "Add", inputs: [11,i], outputs: [12], attributes: []}
Node {type: "Add", inputs: [13,i], outputs: [14], attributes: []},
Node {type: "Cast", inputs: [1], outputs: [15], attributes: [{ name: 'to', type: int, value: 9}]}
]
}

}]}
}]},
Node {type: "Cast", inputs: [1], outputs: [16], attributes: [{ name: 'to', type: int, value: 9}]}
]
}

Expand Down
88 changes: 88 additions & 0 deletions test/onnx/test_pytorch_onnx_caffe2.py
Original file line number Diff line number Diff line change
Expand Up @@ -1735,6 +1735,94 @@ def forward(self, x, y):
y = torch.randint(0, 1, (3, 5))
self.run_model_test(OrModel(), train=False, input=(x, y), batch_size=BATCH_SIZE)

def test_while(self):
class WhileModel(torch.jit.ScriptModule):
@torch.jit.script_method
def forward(self, x):
a = 0
while a < 4:
a += 1
return x + a

model = WhileModel()
inputs = torch.zeros(1, 2, 3, dtype=torch.long)
outputs = model(inputs)
self.run_model_test(model, train=False, input=(inputs,), batch_size=BATCH_SIZE,
example_outputs=(outputs,))

def test_while_cond(self):
class WhileModel(torch.jit.ScriptModule):
@torch.jit.script_method
def forward(self, x, a):
b = (a < 4)
while b:
a += b.to(torch.long)
b = (a < 4)
return x + a

model = WhileModel()
x = torch.zeros(1, 2, 3, dtype=torch.long)
a = torch.tensor([0], dtype=torch.long)
outputs = model(x, a)
self.run_model_test(model, train=False, input=(x, a), batch_size=BATCH_SIZE,
example_outputs=(outputs,))

def test_loop(self):
class LoopModel(torch.jit.ScriptModule):
@torch.jit.script_method
def forward(self, x):
for i in range(5):
x = x + i
return x

model = LoopModel()
inputs = torch.zeros(1, 2, 3, dtype=torch.long)
outputs = model(inputs)
self.run_model_test(model, train=False, input=(inputs,), batch_size=BATCH_SIZE,
example_outputs=(outputs,))

def test_dynamic_loop(self):
class LoopModel(torch.jit.ScriptModule):
@torch.jit.script_method
def forward(self, x):
for i in range(x.size(2)):
x = x + i
return x

model = LoopModel()
inputs = torch.zeros(1, 2, 3, dtype=torch.long)
outputs = model(inputs)
self.run_model_test(model, train=False, input=(inputs,), batch_size=BATCH_SIZE,
example_outputs=(outputs,))

def test_nested_loops(self):
class NestedLoopsModel(torch.jit.ScriptModule):
@torch.jit.script_method
def forward(self, x):
for i in range(5):
a = 0
while a < 4:
a += 1
for j in range(a):
x = x + j
x = x + a
return x

model = NestedLoopsModel()
inputs = torch.zeros(1, 2, 3, dtype=torch.long)
outputs = model(inputs)
self.run_model_test(model, train=False, input=(inputs,), batch_size=BATCH_SIZE,
example_outputs=(outputs,))

def test_select(self):
class SelectModel(torch.nn.Module):
def forward(self, x):
return torch.select(x, 0, 1)

model = SelectModel()
inputs = torch.randn(3, 2, 1)
self.run_model_test(model, train=False, input=(inputs, ), batch_size=BATCH_SIZE)

# a bit of metaprogramming to set up all the rnn tests


Expand Down
15 changes: 9 additions & 6 deletions torch/csrc/jit/export.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,8 @@ onnx::TensorProto_DataType ATenTypeToOnnxType(at::ScalarType at_type) {
return onnx::TensorProto_DataType_INT32;
case at::kLong:
return onnx::TensorProto_DataType_INT64;
case at::kBool:
return onnx::TensorProto_DataType_BOOL;
default:
AT_ERROR("unexpected tensor scalar type");
}
Expand All @@ -206,19 +208,20 @@ void EncoderBase::EncodeValueInfo(
onnx::ValueInfoProto* v,
const Value* n) {
v->set_name(n->uniqueName());
onnx::TypeProto* t = v->mutable_type();
onnx::TypeProto_Tensor* tensor_type = t->mutable_tensor_type();

onnx::TensorShapeProto* shape = tensor_type->mutable_shape();
if (CompleteTensorTypePtr node_type = n->type()->cast<CompleteTensorType>()) {
onnx::TypeProto* t = v->mutable_type();
onnx::TypeProto_Tensor* tensor_type = t->mutable_tensor_type();
onnx::TensorShapeProto* shape = tensor_type->mutable_shape();
const std::vector<std::int64_t>& sizes = node_type->sizes();
for (size_t i = 0; i < sizes.size(); i++) {
shape->add_dim();
shape->mutable_dim(i)->set_dim_value(sizes[i]);
}
tensor_type->set_elem_type(ATenTypeToOnnxType(node_type->scalarType()));
} else {
tensor_type->set_elem_type(onnx::TensorProto_DataType_UNDEFINED);
} else if (BoolTypePtr node_type = n->type()->cast<BoolType>()) {
onnx::TypeProto* t = v->mutable_type();
onnx::TypeProto_Tensor* tensor_type = t->mutable_tensor_type();
tensor_type->set_elem_type(ATenTypeToOnnxType(at::kBool));
}
}

Expand Down
55 changes: 52 additions & 3 deletions torch/csrc/jit/passes/onnx/fixup_onnx_loop.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,61 @@
namespace torch {
namespace jit {

namespace onnx{
using namespace ::c10::onnx;
}

Node* CreateCastToBoolNode(Value* val, Graph* graph) {
Node* cast_node = graph->create(onnx::Cast);
cast_node->addInput(val);
cast_node->i_(attr::to, /*Bool*/9);
return cast_node;
}

Node* InsertCastForCond(Value* cond_val, Graph* graph, Node* consumer_node) {
// prev: cond_val -> consumer_node
// after: cond_val -> cast -> consumer_node
// NOTE: The cast is required because operators like PyTorch Greater/Less return tensor
// in type torch.uint8. However the type for condition input in ONNX Loop must be Bool.
Node* cast_node = CreateCastToBoolNode(cond_val, graph);
cast_node->insertBefore(consumer_node);

consumer_node->replaceInputWith(cond_val, cast_node->output());
return cast_node;
}

bool IsCondCastRequired(Value* cond_val) {
const auto& type = cond_val->type();
if (type->isSubclass(TypeKind::DimensionedTensorType)) {
return type->expect<DimensionedTensorType>()->scalarType() != c10::kBool;
}
return !type->isSubclass(TypeKind::BoolType);
}

void FixupONNXLoops(Block* block) {
for (auto* node : block->nodes()) {
if (node->kind() == ::c10::onnx::Loop) {
AT_ASSERT(node->blocks().size() == 1);
auto* sub_block = node->blocks()[0];
sub_block->insertInput(1, "cond");
auto* loop_node = node;
auto* graph = loop_node->owningGraph();

// add cast to condition input outside the loop.
Value* cond_val = loop_node->inputs()[1];
if (IsCondCastRequired(cond_val))
InsertCastForCond(cond_val, graph, loop_node);

// Setup Loop input cond and i.
TORCH_INTERNAL_ASSERT(loop_node->blocks().size() == 1);
auto* sub_block = loop_node->blocks()[0];
Value* cond = sub_block->insertInput(1, "cond");
cond->setType(BoolType::create());

Value* i = sub_block->inputs()[0];
i->setType(CompleteTensorType::fromNumberType(IntType::get()));

// add cast to condition input inside the loop.
Value* next_cond_val = sub_block->outputs()[0];
if (IsCondCastRequired(next_cond_val))
InsertCastForCond(next_cond_val, graph, sub_block->return_node());
}
for (Block* block : node->blocks()) {
FixupONNXLoops(block);
Expand Down