-
Notifications
You must be signed in to change notification settings - Fork 26.3k
[ONNX] Update ONNX constant folding to support opset 10. #22515
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Closed
spandantiwari
wants to merge
2
commits into
pytorch:master
from
spandantiwari:spandantiwari/constant_folding_opset10
Closed
Changes from all commits
Commits
Show all changes
2 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -64,47 +64,151 @@ void eraseUnusedBlockInputs(Block* b) { | |
| } | ||
| } | ||
|
|
||
| c10::optional<at::Tensor> runTorchBackendForOnnx( | ||
| const Node* node, | ||
| std::vector<at::Tensor>& inputTensorValues) { | ||
| at::Tensor updated_val; | ||
| if (node->kind() == onnx::Slice) { | ||
| assert(inputTensorValues.size() == 1); | ||
| if (!(node->hasAttributeS("starts") && node->hasAttributeS("ends"))) { | ||
| void handleNegativeStartEndIndex(int64_t& start, int64_t& end, int64_t& axis, | ||
| c10::IntArrayRef tensorSizes) { | ||
| if (start < 0) { | ||
| start = tensorSizes[axis] + start; | ||
| } | ||
| if (end < 0) { | ||
| end = tensorSizes[axis] + end; | ||
| } | ||
| // index higher than dimension is treated as the end. | ||
| if (end > tensorSizes[axis]) { | ||
| end = tensorSizes[axis]; | ||
| } | ||
| } | ||
|
|
||
| c10::optional<at::Tensor> runTorchSlice_opset9(const Node* node, | ||
| std::vector<at::Tensor>& inputTensorValues) { | ||
| assert(inputTensorValues.size() == 1); | ||
| if (inputTensorValues.size() != 1) { | ||
| std::cerr << "Warning: Constant folding - Invalid number of inputs found for opset 9 onnx::Slice op. " | ||
| << "Constant folding not applied." << std::endl; | ||
| return c10::nullopt; | ||
| } | ||
| if (!(node->hasAttributeS("starts") && node->hasAttributeS("ends"))) { | ||
| return c10::nullopt; | ||
| } | ||
| auto startsAttr = node->is(attr::starts); | ||
| auto endsAttr = node->is(attr::ends); | ||
| if (startsAttr.size() != endsAttr.size()) { | ||
| return c10::nullopt; | ||
| } | ||
| std::vector<int64_t> axesAttr; | ||
| if (node->hasAttributeS("axes")) { | ||
| axesAttr = node->is(attr::axes); | ||
| } else { | ||
| axesAttr.resize(startsAttr.size()); | ||
| std::iota(axesAttr.begin(), axesAttr.end(), 0); | ||
| } | ||
| auto updated_val = inputTensorValues[0]; | ||
| for (size_t i = 0; i < axesAttr.size(); ++i) { | ||
| // ONNX slice accepts negative starts and ends values. | ||
| int64_t axis = axesAttr[i], start = startsAttr[i], end = endsAttr[i]; | ||
| handleNegativeStartEndIndex(start, end, axis, updated_val.sizes()); | ||
| int64_t length = end - start; | ||
| if (length < 0 || start > updated_val.sizes()[axis] - length) | ||
| return c10::nullopt; | ||
| updated_val = at::narrow(updated_val, axis, start, length); | ||
| } | ||
| return c10::optional<at::Tensor>(updated_val); | ||
| } | ||
|
|
||
| c10::optional<at::Tensor> runTorchSlice_opset10(const Node* node, | ||
| std::vector<at::Tensor>& inputTensorValues) { | ||
| if (inputTensorValues.size() < 3 || inputTensorValues.size() > 5) { | ||
| std::cerr << "Warning: Constant folding - Invalid number of inputs found for opset 10 onnx::Slice op. " | ||
| << "Constant folding not applied." << std::endl; | ||
| return c10::nullopt; | ||
| } | ||
| // Checking validity of 'starts' and 'ends' input | ||
| if (inputTensorValues[1].sizes().size() != 1 || inputTensorValues[2].sizes().size() != 1) { | ||
| std::cerr << "Warning: Constant folding - Invalid 'starts' or 'ends' inputs found for opset 10 onnx::Slice op. " | ||
| << "Constant folding not applied." << std::endl; | ||
| return c10::nullopt; | ||
| } | ||
| if (inputTensorValues[1].sizes()[0] != inputTensorValues[2].sizes()[0] ) { | ||
| // Number of elements of 'starts' and 'ends' 1-D input tensors should be the same | ||
| return c10::nullopt; | ||
| } | ||
| // Checking 'axes' input, if available. | ||
| std::vector<int64_t> axes; | ||
| if (inputTensorValues.size() > 3) { | ||
| if (inputTensorValues[3].sizes().size() != 1) { | ||
| std::cerr << "Warning: Constant folding - Invalid 'axes' input found for opset 10 onnx::Slice op. " | ||
| << "Constant folding not applied." << std::endl; | ||
| return c10::nullopt; | ||
| } | ||
| auto startsAttr = node->is(attr::starts); | ||
| auto endsAttr = node->is(attr::ends); | ||
| if (startsAttr.size() != endsAttr.size()) { | ||
| if (inputTensorValues[3].sizes()[0] != inputTensorValues[1].sizes()[0] ) { | ||
| // Number of elements of 'axes' and 'ends' 1-D input tensors should be the same | ||
| std::cerr << "Warning: Constant folding - Invalid 'axes' or 'ends' inputs found for opset 10 onnx::Slice op. " | ||
| << "Constant folding not applied." << std::endl; | ||
| return c10::nullopt; | ||
| } | ||
| std::vector<int64_t> axesAttr; | ||
| if (node->hasAttributeS("axes")) { | ||
| axesAttr = node->is(attr::axes); | ||
| } else { | ||
| axesAttr.resize(startsAttr.size()); | ||
| std::iota(axesAttr.begin(), axesAttr.end(), 0); | ||
| auto axes_a = inputTensorValues[3].accessor<int64_t, 1>(); | ||
| axes.reserve(inputTensorValues[3].sizes()[0]); | ||
| for (size_t i = 0; i < inputTensorValues[3].sizes()[0]; ++i) { | ||
| axes[i] = axes_a[i]; | ||
| } | ||
| updated_val = inputTensorValues[0]; | ||
| for (size_t i = 0; i < axesAttr.size(); ++i) { | ||
| // ONNX slice accepts negative starts and ends values. | ||
| int64_t axis = axesAttr[i], start = startsAttr[i], end = endsAttr[i]; | ||
| if (start < 0) { | ||
| start = updated_val.sizes()[axis] + start; | ||
| } | ||
| if (end < 0) { | ||
| end = updated_val.sizes()[axis] + end; | ||
| } | ||
| // index higher than dimension is treated as the end. | ||
| if (end > updated_val.sizes()[axis]) { | ||
| end = updated_val.sizes()[axis]; | ||
| } | ||
| int64_t length = end - start; | ||
| if (length < 0 || start > updated_val.sizes()[axis] - length) | ||
| } | ||
| else { | ||
| axes = std::vector<int64_t>(inputTensorValues[1].sizes()[0], 0); | ||
| } | ||
| // Checking 'steps' input, if available. | ||
| if (inputTensorValues.size() > 4) { | ||
| if (inputTensorValues[4].sizes().size() != 1) { | ||
| std::cerr << "Warning: Constant folding - Invalid 'steps' input found for opset 10 onnx::Slice op. " | ||
| << "Constant folding not applied." << std::endl; | ||
| return c10::nullopt; | ||
| } | ||
| if (inputTensorValues[4].sizes()[0] != inputTensorValues[1].sizes()[0] ) { | ||
| // Number of elements of 'steps' and 'ends' 1-D input tensors should be the same | ||
| std::cerr << "Warning: Constant folding - Invalid 'steps' or 'ends' inputs found for opset 10 onnx::Slice op. " | ||
| << "Constant folding not applied." << std::endl; | ||
| return c10::nullopt; | ||
| } | ||
| auto steps_a = inputTensorValues[4].accessor<int64_t, 1>(); | ||
| for (size_t i = 0; i < inputTensorValues[4].sizes()[0]; ++i) { | ||
| // Only steps == 1 are supported for constant-folding. | ||
| if (steps_a[i] != 1) { | ||
| std::cerr << "Warning: Constant folding - Only steps=1 can be constant folded for opset 10 onnx::Slice op. " | ||
| << "Constant folding not applied." << std::endl; | ||
| return c10::nullopt; | ||
| updated_val = at::narrow(updated_val, axis, start, length); | ||
| } | ||
| } | ||
| } | ||
| auto starts_a = inputTensorValues[1].accessor<int64_t, 1>(); | ||
| auto ends_a = inputTensorValues[2].accessor<int64_t, 1>(); | ||
| auto updated_val = inputTensorValues[0]; | ||
| for (size_t i = 0; i < inputTensorValues[1].sizes()[0]; ++i) { | ||
| // ONNX slice accepts negative starts and ends values. | ||
| int64_t start = starts_a[i], end = ends_a[i], axis = axes[i]; | ||
| handleNegativeStartEndIndex(start, end, axis, updated_val.sizes()); | ||
| int64_t length = end - start; | ||
| if (length < 0 || start > updated_val.sizes()[axis] - length) | ||
| return c10::nullopt; | ||
| updated_val = at::narrow(updated_val, axis, start, length); | ||
| } | ||
| return c10::optional<at::Tensor>(updated_val); | ||
| } | ||
|
|
||
| c10::optional<at::Tensor> runTorchBackendForOnnx( | ||
| const Node* node, | ||
| std::vector<at::Tensor>& inputTensorValues, | ||
| int opset_version) { | ||
| at::Tensor updated_val; | ||
| if (node->kind() == onnx::Slice) { | ||
| if (opset_version == 9) { | ||
| return runTorchSlice_opset9(node, inputTensorValues); | ||
| } | ||
| else if (opset_version == 10) { | ||
| return runTorchSlice_opset10(node, inputTensorValues); | ||
| } | ||
| else { | ||
| std::cerr << "Warning: Constant folding - unsupported opset version. " | ||
| << "Constant folding not applied." << std::endl; | ||
| return c10::nullopt; | ||
| } | ||
| return c10::optional<at::Tensor>(updated_val); | ||
| } else if (node->kind() == onnx::Concat) { | ||
| if (!node->hasAttributeS("axis")) { | ||
| return c10::nullopt; | ||
|
|
@@ -218,7 +322,13 @@ std::vector<Node*> getOnnxConstParentsToRemove(Node* node) { | |
|
|
||
| // This method updates the block in-place to fold all the one-time | ||
| // constant-based computations/ops into an initializer node. | ||
| void ConstantFoldONNX(Block* b, ParamMap& paramsDict) { | ||
| void ConstantFoldONNX(Block* b, ParamMap& paramsDict, int opset_version) { | ||
| if (opset_version != 9 && opset_version != 10) { | ||
| // Number of elements of 'axes' and 'ends' 1-D input tensors should be the same | ||
| std::cerr << "Warning: Constant folding supported for only opsets 9 and 10. " | ||
| << "Constant folding not applied." << std::endl; | ||
| return; | ||
| } | ||
| AT_ASSERT(b->param_node()); | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We may also want to add a check on
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Good point. Done. |
||
| auto valsToParamsMap = buildValueToParamsMap(b, paramsDict); | ||
| // Only the root block is constant-folded. Folding nested blocks is | ||
|
|
@@ -234,13 +344,14 @@ void ConstantFoldONNX(Block* b, ParamMap& paramsDict) { | |
| // onnx::Constant, then skip this node. | ||
| continue; | ||
| } | ||
|
|
||
| auto inputTensorValues = getValues(node, valsToParamsMap); | ||
| if (inputTensorValues.empty()) { | ||
| // This is a terminal node with no inputs, such as onnx::Constant. Skip | ||
| // it. | ||
| continue; | ||
| } | ||
| auto updatedValWrapped = runTorchBackendForOnnx(node, inputTensorValues); | ||
| auto updatedValWrapped = runTorchBackendForOnnx(node, inputTensorValues, opset_version); | ||
| if (updatedValWrapped == c10::nullopt) { | ||
| // Constant folding is not supported for this op. Skip it. | ||
| continue; | ||
|
|
||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
just return nullopt?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sure. Fixed.