[ONNX] Add pass that fuses Conv and BatchNormalization #40547
Conversation
💊 CI failures summary and remediations: as of commit 2b1f1f8 (more details on the Dr. CI page), ci.pytorch.org: 1 failed. (This comment was automatically generated by Dr. CI.)
@KsenijaS Could you please resolve the conflicts and take a look at the Caffe2 and model test failures?
```python
        return func(self)
    return wrapper


def skipIfNoEmbed(func):
```
@houseroad The parameters of the ONNX and PyTorch models won't match for the ResNet model when fusing Conv and BatchNorm in eval mode, and the debug_embed_params test will fail because of that. One solution is to create a wrapper that skips the debug_embed_params test. What do you think of this solution?
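A wrapper along the lines proposed could look like this minimal sketch (the `embed_params` attribute name is an assumption for illustration, not the actual PyTorch test infrastructure):

```python
import functools
import unittest


def skipIfNoEmbed(func):
    # Hypothetical sketch: skip a test method when embedded-parameter
    # comparison is disabled (e.g. because the Conv-BN fusion pass
    # changes the model's parameters, so they can no longer match).
    @functools.wraps(func)
    def wrapper(self):
        if not getattr(self, "embed_params", True):
            raise unittest.SkipTest("debug_embed_params test skipped")
        return func(self)
    return wrapper
```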
Could you point me to where it fails?
neginraoof left a comment
Thanks. It would be great if @houseroad could take a look at the update in the Caffe2 tests.
```cpp
auto origconvNode = *it;
auto epsilon = bnNode->f(attr::epsilon);
auto w_conv_value = getValues(origconvNode, valsToParamsMap);
if (w_conv_value.size() < 1 ||
```
The logic here could be stated more clearly:
- if the Conv node has 3 inputs, the size of w_conv_value should be 2;
- if it has 2 inputs, the size should be 1;
- otherwise, continue (skip fusion).
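In Python terms, the suggested check is roughly the following (a sketch of the review comment's logic, not the actual C++ code; the function name is illustrative):

```python
def conv_params_ok(num_conv_inputs, w_conv_value):
    # A Conv node with a bias has 3 inputs (X, W, B) and should
    # contribute 2 fusable parameters; a Conv without bias has
    # 2 inputs (X, W) and should contribute 1. Any other shape
    # means fusion should be skipped for this node.
    if num_conv_inputs == 3:
        return len(w_conv_value) == 2
    if num_conv_inputs == 2:
        return len(w_conv_value) == 1
    return False
```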
@houseroad Can you please take a look? Thanks
houseroad left a comment
Overall it looks good to me; I didn't carefully check the fuse logic. I left some inline comments.
```python
        return func(self)
    return wrapper


def skipIfNoEmbed(func):
```
Could you point me to where it fails?
```python
for node in graph.nodes():
    assert node.kind() != "onnx::BatchNormalization"


def test_conv_bn(self):
```
Why do we want to add a model execution test here? test_utility_funs.py is supposed to contain only checks on the model structure.
We are comparing ORT outputs with and without training (without the pass / with the pass), while in test_pytorch_onnx_onnxruntime we are comparing PyTorch outputs with ORT outputs.
```python
ort_outs2 = ort_sess.run(None, ort_inputs)
[np.testing.assert_allclose(ort_out1, ort_out2, atol=1e-7, rtol=0.001)
 for ort_out1, ort_out2 in zip(ort_outs1, ort_outs2)]


def test_multiple_conv_bn(self):
```
Same here: can we move them to test_pytorch_onnx_onnxruntime.py?
Same as before: we are comparing the ORT outputs with the pass against the ORT outputs without the pass.
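The comparison being described can be sketched as a small helper (names are illustrative; `outs1` and `outs2` stand in for the two lists of ONNX Runtime outputs produced with and without the fusion pass):

```python
import numpy as np


def assert_ort_outputs_match(outs1, outs2, atol=1e-7, rtol=0.001):
    # Compare the outputs of the model exported without the pass
    # against the outputs of the model exported with the Conv-BN
    # fusion pass; fusion should not change the numerics beyond
    # floating-point tolerance.
    assert len(outs1) == len(outs2)
    for o1, o2 in zip(outs1, outs2):
        np.testing.assert_allclose(o1, o2, atol=atol, rtol=rtol)
```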
@houseroad The traceback for the ResNet test with the fuse Conv-BN pass is:
Traceback (most recent call last):
x and y nan location mismatch:
@houseroad do you have more questions?
facebook-github-bot left a comment
@bzinodev has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
@bzinodev can this PR be merged? Thanks!
Add a pass that fuses Conv and BatchNormalization nodes into a single Conv node.
This pass is applied only in inference mode (training is None or TrainingMode.EVAL).
Since this pass needs access to param_dict, it is written outside the peephole file where passes of this kind (fusing multiple nodes into one) are usually placed.
This PR also adds the skipIfNoEmbed wrapper to skip the debug_embed_params test:
the pass that fuses Conv and BatchNorm changes the parameters of the ResNet model, so the parameters of the ONNX and PyTorch models won't match. Since the parameters don't match, the debug_embed_params test for test_resnet will fail, and that is expected; therefore the debug_embed_params test for test_resnet should be skipped.
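For reference, the algebra behind folding BatchNorm into Conv can be sketched in NumPy (a simplified illustration of the idea, not the actual C++ pass; shapes assume a 2D Conv with per-output-channel BN statistics):

```python
import numpy as np


def fuse_conv_bn(w, b, gamma, beta, mean, var, eps=1e-5):
    # BatchNorm computes y = gamma * (x - mean) / sqrt(var + eps) + beta.
    # Because Conv is affine, this per-channel normalization can be
    # folded into the Conv weights and bias, producing one equivalent
    # Conv node and eliminating the BatchNormalization node.
    scale = gamma / np.sqrt(var + eps)        # one factor per output channel
    w_fused = w * scale.reshape(-1, 1, 1, 1)  # scale each output-channel filter
    b_fused = (b - mean) * scale + beta
    return w_fused, b_fused
```

This is why the fusion is valid only in eval mode: in training mode BatchNorm uses batch statistics rather than the fixed running mean and variance folded in here.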