29 changes: 18 additions & 11 deletions test/onnx/test_utility_funs.py
@@ -13,6 +13,7 @@


class TestUtilityFuns(TestCase):
opset_version = 9

def test_is_in_onnx_export(self):
test_self = self
@@ -26,7 +27,7 @@ def forward(self, x):
x = torch.randn(3, 4)
f = io.BytesIO()
try:
torch.onnx.export(MyModule(), x, f)
torch.onnx.export(MyModule(), x, f, opset_version=self.opset_version)
except ValueError:
self.assertFalse(torch.onnx.is_in_onnx_export())

@@ -37,7 +38,7 @@ def forward(self, x):
b = torch.transpose(a, 1, 0)
return b + x

_set_opset_version(9)
_set_opset_version(self.opset_version)
x = torch.ones(3, 2)
graph, _, __ = utils._model_to_graph(TransposeModule(), (x, ),
do_constant_folding=True,
@@ -55,7 +56,7 @@ def forward(self, x):
b = torch.narrow(a, 0, 0, 1)
return b + x

_set_opset_version(9)
_set_opset_version(self.opset_version)
x = torch.ones(1, 3)
graph, _, __ = utils._model_to_graph(NarrowModule(), (x, ),
do_constant_folding=True,
@@ -73,7 +74,7 @@ def forward(self, x):
b = a[1:10] # index exceeds dimension
return b + x

_set_opset_version(9)
_set_opset_version(self.opset_version)
x = torch.ones(1, 3)
graph, _, __ = utils._model_to_graph(SliceIndexExceedsDimModule(), (x, ),
do_constant_folding=True,
@@ -92,7 +93,7 @@ def forward(self, x):
b = a[0:-1] # index relative to the end
return b + x

_set_opset_version(9)
_set_opset_version(self.opset_version)
x = torch.ones(1, 3)
graph, _, __ = utils._model_to_graph(SliceNegativeIndexModule(), (x, ),
do_constant_folding=True,
@@ -110,7 +111,7 @@ def forward(self, x):
b = torch.unsqueeze(a, 0)
return b + x

_set_opset_version(9)
_set_opset_version(self.opset_version)
x = torch.ones(1, 2, 3)
graph, _, __ = utils._model_to_graph(UnsqueezeModule(), (x, ),
do_constant_folding=True,
@@ -129,7 +130,7 @@ def forward(self, x):
c = torch.cat((a, b), 0)
return b + c

_set_opset_version(9)
_set_opset_version(self.opset_version)
x = torch.ones(2, 3)
graph, _, __ = utils._model_to_graph(ConcatModule(), (x, ),
do_constant_folding=True,
@@ -149,7 +150,7 @@ def __init__(self):
def forward(self, input, initial_state):
return self.mygru(input, initial_state)

_set_opset_version(9)
_set_opset_version(self.opset_version)
input = torch.randn(5, 3, 7)
h0 = torch.randn(1, 3, 3)
graph, _, __ = utils._model_to_graph(GruNet(), (input, h0),
@@ -169,7 +170,7 @@ def __init__(self):
def forward(self, A):
return torch.matmul(A, torch.transpose(self.B, -1, -2))

_set_opset_version(9)
_set_opset_version(self.opset_version)
A = torch.randn(2, 3)
graph, _, __ = utils._model_to_graph(MatMulNet(), (A),
do_constant_folding=True)
@@ -185,9 +186,10 @@ def forward(self, input):

def is_model_stripped(f, strip_doc_string=None):
if strip_doc_string is None:
torch.onnx.export(MyModule(), x, f)
torch.onnx.export(MyModule(), x, f, opset_version=self.opset_version)
else:
torch.onnx.export(MyModule(), x, f, strip_doc_string=strip_doc_string)
torch.onnx.export(MyModule(), x, f, strip_doc_string=strip_doc_string,
opset_version=self.opset_version)
model = onnx.load(io.BytesIO(f.getvalue()))
model_strip = copy.copy(model)
onnx.helper.strip_doc_string(model_strip)
@@ -198,5 +200,10 @@ def is_model_stripped(f, strip_doc_string=None):
# test strip_doc_string=False
self.assertFalse(is_model_stripped(io.BytesIO(), False))

# opset 10 tests
TestUtilityFuns_opset10 = type(str("TestUtilityFuns_opset10"),
(TestCase,),
dict(TestUtilityFuns.__dict__, opset_version=10))

if __name__ == '__main__':
run_tests()
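
For context, a minimal sketch of the export call these parameterized tests now pin to an explicit opset (the module here is illustrative, not from the test file):

import io
import torch

class Illustrative(torch.nn.Module):
    def forward(self, x):
        return x + 1.0

# Export the same module at each opset the suite now covers.
for opset in (9, 10):
    f = io.BytesIO()
    torch.onnx.export(Illustrative(), torch.randn(2, 3), f, opset_version=opset)
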
5 changes: 3 additions & 2 deletions torch/csrc/jit/init.cpp
@@ -115,8 +115,9 @@ void initJITBindings(PyObject* module) {
.def(
"_jit_pass_onnx_constant_fold",
[](std::shared_ptr<Graph>& graph,
std::map<std::string, at::Tensor>& paramsDict) {
ConstantFoldONNX(graph->block(), paramsDict); // overload resolution
std::map<std::string, at::Tensor>& paramsDict,
int opset_version) {
ConstantFoldONNX(graph->block(), paramsDict, opset_version); // overload resolution
return paramsDict;
},
pybind11::return_value_policy::move)
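
End-to-end, the rebound pass runs whenever folding is requested during export; a hedged, illustrative sketch, assuming the do_constant_folding flag that export forwards to _model_to_graph (see the utils.py change below):

import io
import torch

class ConstTranspose(torch.nn.Module):
    def forward(self, x):
        a = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
        return torch.transpose(a, 0, 1) + x

# At opset 10 this now reaches ConstantFoldONNX with opset_version == 10.
f = io.BytesIO()
torch.onnx.export(ConstTranspose(), torch.ones(2, 2), f, opset_version=10,
                  do_constant_folding=True)
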
183 changes: 147 additions & 36 deletions torch/csrc/jit/passes/onnx/constant_fold.cpp
@@ -64,47 +64,151 @@ void eraseUnusedBlockInputs(Block* b) {
}
}

void handleNegativeStartEndIndex(int64_t& start, int64_t& end, int64_t& axis,
c10::IntArrayRef tensorSizes) {
if (start < 0) {
start = tensorSizes[axis] + start;
}
if (end < 0) {
end = tensorSizes[axis] + end;
}
// index higher than dimension is treated as the end.
if (end > tensorSizes[axis]) {
end = tensorSizes[axis];
}
}
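
A hedged Python mirror of this helper, with worked examples of the clamping (normalize_bounds is a hypothetical name, not in the patch):

def normalize_bounds(start, end, axis, sizes):
    # Negative starts/ends count back from the end of the dimension.
    if start < 0:
        start += sizes[axis]
    if end < 0:
        end += sizes[axis]
    # An end index past the dimension is treated as the end.
    return start, min(end, sizes[axis])

# a[0:-1] on a dimension of size 3 -> (0, 2)
assert normalize_bounds(0, -1, 0, (3,)) == (0, 2)
# a[1:10] on a dimension of size 3 -> end clamped to 3
assert normalize_bounds(1, 10, 0, (3,)) == (1, 3)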

c10::optional<at::Tensor> runTorchSlice_opset9(const Node* node,
std::vector<at::Tensor>& inputTensorValues) {
assert(inputTensorValues.size() == 1);
Reviewer (Member): just return nullopt?

Author: Sure. Fixed.

if (inputTensorValues.size() != 1) {
std::cerr << "Warning: Constant folding - Invalid number of inputs found for opset 9 onnx::Slice op. "
<< "Constant folding not applied." << std::endl;
return c10::nullopt;
}
if (!(node->hasAttributeS("starts") && node->hasAttributeS("ends"))) {
return c10::nullopt;
}
auto startsAttr = node->is(attr::starts);
auto endsAttr = node->is(attr::ends);
if (startsAttr.size() != endsAttr.size()) {
return c10::nullopt;
}
std::vector<int64_t> axesAttr;
if (node->hasAttributeS("axes")) {
axesAttr = node->is(attr::axes);
} else {
axesAttr.resize(startsAttr.size());
std::iota(axesAttr.begin(), axesAttr.end(), 0);
}
auto updated_val = inputTensorValues[0];
for (size_t i = 0; i < axesAttr.size(); ++i) {
// ONNX slice accepts negative starts and ends values.
int64_t axis = axesAttr[i], start = startsAttr[i], end = endsAttr[i];
handleNegativeStartEndIndex(start, end, axis, updated_val.sizes());
int64_t length = end - start;
if (length < 0 || start > updated_val.sizes()[axis] - length)
return c10::nullopt;
updated_val = at::narrow(updated_val, axis, start, length);
}
return c10::optional<at::Tensor>(updated_val);
}
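
In opset 9, Slice carries starts/ends/axes as node attributes; a hedged Python counterpart of the loop above, reusing normalize_bounds from the sketch earlier (torch.narrow standing in for at::narrow):

import torch

def fold_slice_opset9(val, starts, ends, axes=None):
    axes = list(range(len(starts))) if axes is None else axes
    for axis, start, end in zip(axes, starts, ends):
        dim = val.size(axis)
        start, end = normalize_bounds(start, end, 0, (dim,))
        length = end - start
        if length < 0 or start > dim - length:
            return None  # folding not applied
        val = torch.narrow(val, axis, start, length)
    return val

# Slicing columns 0..1 of a 2x3 constant:
folded = fold_slice_opset9(torch.arange(6.).reshape(2, 3), [0], [-1], axes=[1])
assert folded.shape == (2, 2)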

c10::optional<at::Tensor> runTorchSlice_opset10(const Node* node,
std::vector<at::Tensor>& inputTensorValues) {
if (inputTensorValues.size() < 3 || inputTensorValues.size() > 5) {
std::cerr << "Warning: Constant folding - Invalid number of inputs found for opset 10 onnx::Slice op. "
<< "Constant folding not applied." << std::endl;
return c10::nullopt;
}
// Checking validity of 'starts' and 'ends' input
if (inputTensorValues[1].sizes().size() != 1 || inputTensorValues[2].sizes().size() != 1) {
std::cerr << "Warning: Constant folding - Invalid 'starts' or 'ends' inputs found for opset 10 onnx::Slice op. "
<< "Constant folding not applied." << std::endl;
return c10::nullopt;
}
if (inputTensorValues[1].sizes()[0] != inputTensorValues[2].sizes()[0] ) {
// Number of elements of 'starts' and 'ends' 1-D input tensors should be the same
return c10::nullopt;
}
// Checking 'axes' input, if available.
std::vector<int64_t> axes;
if (inputTensorValues.size() > 3) {
if (inputTensorValues[3].sizes().size() != 1) {
std::cerr << "Warning: Constant folding - Invalid 'axes' input found for opset 10 onnx::Slice op. "
<< "Constant folding not applied." << std::endl;
return c10::nullopt;
}
if (inputTensorValues[3].sizes()[0] != inputTensorValues[1].sizes()[0] ) {
// Number of elements of 'axes' and 'ends' 1-D input tensors should be the same
std::cerr << "Warning: Constant folding - Invalid 'axes' or 'ends' inputs found for opset 10 onnx::Slice op. "
<< "Constant folding not applied." << std::endl;
return c10::nullopt;
}
auto axes_a = inputTensorValues[3].accessor<int64_t, 1>();
axes.resize(inputTensorValues[3].sizes()[0]);
for (size_t i = 0; i < inputTensorValues[3].sizes()[0]; ++i) {
axes[i] = axes_a[i];
}
}
else {
axes = std::vector<int64_t>(inputTensorValues[1].sizes()[0], 0);
}
// Checking 'steps' input, if available.
if (inputTensorValues.size() > 4) {
if (inputTensorValues[4].sizes().size() != 1) {
std::cerr << "Warning: Constant folding - Invalid 'steps' input found for opset 10 onnx::Slice op. "
<< "Constant folding not applied." << std::endl;
return c10::nullopt;
}
if (inputTensorValues[4].sizes()[0] != inputTensorValues[1].sizes()[0] ) {
// Number of elements of 'steps' and 'ends' 1-D input tensors should be the same
std::cerr << "Warning: Constant folding - Invalid 'steps' or 'ends' inputs found for opset 10 onnx::Slice op. "
<< "Constant folding not applied." << std::endl;
return c10::nullopt;
}
auto steps_a = inputTensorValues[4].accessor<int64_t, 1>();
for (size_t i = 0; i < inputTensorValues[4].sizes()[0]; ++i) {
// Only steps == 1 are supported for constant-folding.
if (steps_a[i] != 1) {
std::cerr << "Warning: Constant folding - Only steps=1 can be constant folded for opset 10 onnx::Slice op. "
<< "Constant folding not applied." << std::endl;
return c10::nullopt;
}
}
}
auto starts_a = inputTensorValues[1].accessor<int64_t, 1>();
auto ends_a = inputTensorValues[2].accessor<int64_t, 1>();
auto updated_val = inputTensorValues[0];
for (size_t i = 0; i < inputTensorValues[1].sizes()[0]; ++i) {
// ONNX slice accepts negative starts and ends values.
int64_t start = starts_a[i], end = ends_a[i], axis = axes[i];
handleNegativeStartEndIndex(start, end, axis, updated_val.sizes());
int64_t length = end - start;
if (length < 0 || start > updated_val.sizes()[axis] - length)
return c10::nullopt;
updated_val = at::narrow(updated_val, axis, start, length);
}
return c10::optional<at::Tensor>(updated_val);
}
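
In opset 10, starts/ends/axes/steps arrive as 1-D tensor inputs instead of attributes, and only unit steps are folded; a hedged counterpart that delegates to fold_slice_opset9 from the sketch above (note it mirrors the C++ default of axes = zeros when the input is absent):

import torch

def fold_slice_opset10(val, starts, ends, axes=None, steps=None):
    if steps is not None and any(s != 1 for s in steps.tolist()):
        return None  # only steps == 1 can be constant folded
    axes = [0] * starts.numel() if axes is None else axes.tolist()
    return fold_slice_opset9(val, starts.tolist(), ends.tolist(), axes)

# The same slice as above, with tensor inputs as opset 10 provides them:
folded = fold_slice_opset10(torch.arange(6.).reshape(2, 3),
                            torch.tensor([0]), torch.tensor([-1]),
                            axes=torch.tensor([1]), steps=torch.tensor([1]))
assert folded.shape == (2, 2)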

c10::optional<at::Tensor> runTorchBackendForOnnx(
const Node* node,
std::vector<at::Tensor>& inputTensorValues,
int opset_version) {
at::Tensor updated_val;
if (node->kind() == onnx::Slice) {
if (opset_version == 9) {
return runTorchSlice_opset9(node, inputTensorValues);
}
else if (opset_version == 10) {
return runTorchSlice_opset10(node, inputTensorValues);
}
else {
std::cerr << "Warning: Constant folding - unsupported opset version. "
<< "Constant folding not applied." << std::endl;
return c10::nullopt;
}
return c10::optional<at::Tensor>(updated_val);
} else if (node->kind() == onnx::Concat) {
if (!node->hasAttributeS("axis")) {
return c10::nullopt;
@@ -218,7 +322,13 @@ std::vector<Node*> getOnnxConstParentsToRemove(Node* node) {

// This method updates the block in-place to fold all the one-time
// constant-based computations/ops into an initializer node.
void ConstantFoldONNX(Block* b, ParamMap& paramsDict) {
void ConstantFoldONNX(Block* b, ParamMap& paramsDict, int opset_version) {
if (opset_version != 9 && opset_version != 10) {
// Constant folding is supported only for opsets 9 and 10.
std::cerr << "Warning: Constant folding supported for only opsets 9 and 10. "
<< "Constant folding not applied." << std::endl;
return;
}
AT_ASSERT(b->param_node());
Reviewer (Member): We may also want to add a check on opset_version here as well. If it's not supported, skipping should be enough.

Author: Good point. Done.

auto valsToParamsMap = buildValueToParamsMap(b, paramsDict);
// Only the root block is constant-folded. Folding nested blocks is
@@ -234,13 +344,14 @@ void ConstantFoldONNX(Block* b, ParamMap& paramsDict) {
// onnx::Constant, then skip this node.
continue;
}

auto inputTensorValues = getValues(node, valsToParamsMap);
if (inputTensorValues.empty()) {
// This is a terminal node with no inputs, such as onnx::Constant. Skip
// it.
continue;
}
auto updatedValWrapped = runTorchBackendForOnnx(node, inputTensorValues);
auto updatedValWrapped = runTorchBackendForOnnx(node, inputTensorValues, opset_version);
if (updatedValWrapped == c10::nullopt) {
// Constant folding is not supported for this op. Skip it.
continue;
2 changes: 1 addition & 1 deletion torch/csrc/jit/passes/onnx/constant_fold.h
@@ -5,7 +5,7 @@
namespace torch {
namespace jit {

void ConstantFoldONNX(Block* b, std::map<std::string, at::Tensor>& paramDict);
void ConstantFoldONNX(Block* b, std::map<std::string, at::Tensor>& paramDict, int opset_version);

}
} // namespace torch
5 changes: 3 additions & 2 deletions torch/onnx/utils.py
@@ -349,8 +349,9 @@ def _model_to_graph(model, args, verbose=False, training=False,
param_names = input_and_param_names[len(input_and_param_names) - len(params):]
params_dict = dict(zip(param_names, params))

if do_constant_folding and _export_onnx_opset_version == 9:
params_dict = torch._C._jit_pass_onnx_constant_fold(graph, params_dict)
if do_constant_folding and _export_onnx_opset_version in [9, 10]:
params_dict = torch._C._jit_pass_onnx_constant_fold(graph, params_dict,
_export_onnx_opset_version)
torch._C._jit_pass_dce_allow_deleting_nodes_with_side_effects(graph)

if verbose:
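
A hedged end-to-end sketch of the widened gate, mirroring the tests above (the import path for _set_opset_version is assumed from the test file, whose import block is truncated in this diff):

import torch
from torch.onnx import utils
from torch.onnx.symbolic_helper import _set_opset_version  # path assumed

class ConstFold(torch.nn.Module):
    def forward(self, x):
        a = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
        return torch.transpose(a, 1, 0) + x

# With the gate now accepting opset 10, the transpose of the constant
# should fold into an initializer rather than remain a graph node.
_set_opset_version(10)
graph, _, _ = utils._model_to_graph(ConstFold(), (torch.ones(3, 2),),
                                    do_constant_folding=True)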