
Conversation

@KsenijaS
Contributor

@KsenijaS KsenijaS commented Jun 25, 2020

Add a pass that fuses Conv and BatchNormalization nodes into a single Conv node.
This pass is only applied in inference mode (training is None or TrainingMode.EVAL).
Since this pass needs access to param_dict, it is written outside the peephole file, where these kinds of passes (fusing multiple nodes into one) are usually placed.

This PR also adds a skipIfNoEmbed wrapper to skip the debug_embed_params test:
the pass that fuses Conv and BatchNorm changes the parameters of the ResNet model, so the parameters of the ONNX and PyTorch models won't match. Since the parameters don't match, the debug_embed_params test for test_resnet will fail, and that is expected; therefore the debug_embed_params test for test_resnet should be skipped.
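The fusion the PR describes folds the BatchNorm statistics into the Conv weight and bias. A rough NumPy sketch of the arithmetic (not the PR's actual C++ implementation, which operates on JIT graph parameters; the function name here is hypothetical):

```python
import numpy as np

def fuse_conv_bn_weights(conv_w, conv_b, bn_mean, bn_var, bn_eps, bn_w, bn_b):
    # Fold BatchNorm into Conv:
    #   scale = gamma / sqrt(var + eps)
    #   W' = W * scale   (scaled per output channel)
    #   b' = (b - mean) * scale + beta
    scale = bn_w / np.sqrt(bn_var + bn_eps)
    fused_w = conv_w * scale.reshape(-1, 1, 1, 1)
    fused_b = (conv_b - bn_mean) * scale + bn_b
    return fused_w, fused_b
```

Since the fused weights differ element-wise from the original PyTorch parameters, a parameter-by-parameter comparison such as debug_embed_params necessarily fails even though the model outputs match.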

@KsenijaS KsenijaS requested a review from apaszke as a code owner June 25, 2020 01:27
@dr-ci

dr-ci bot commented Jun 25, 2020

💊 CI failures summary and remediations

As of commit 2b1f1f8 (more details on the Dr. CI page):


  • 1/1 failures possibly* introduced in this PR
    • 1/1 non-CircleCI failure(s)

ci.pytorch.org: 1 failed


This comment was automatically generated by Dr. CI.

@facebook-github-bot facebook-github-bot added the oncall: jit Add this issue/PR to JIT oncall triage queue label Jun 25, 2020
@gchanan gchanan requested a review from houseroad June 25, 2020 20:19
@gchanan gchanan added the triaged This issue has been looked at a team member, and triaged and prioritized into an appropriate module label Jun 25, 2020
@neginraoof
Contributor

@KsenijaS Could you please resolve the conflicts and take a look at the Caffe2 and model test failures?
Thanks

return func(self)
return wrapper

def skipIfNoEmbed(func):
Contributor Author


@houseroad Parameters for the ONNX and PyTorch models won't match for the ResNet model when fusing Conv and BatchNorm in eval mode, and the debug_embed_params test will fail because of that. One solution is to create a wrapper that skips the debug_embed_params test. What do you think of this solution?
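A minimal sketch of what such a skip wrapper could look like, assuming a hypothetical `embed_params` flag on the test class (the PR's actual decorator may key off a different attribute):

```python
import functools
import unittest

def skipIfNoEmbed(func):
    # Sketch: skip the wrapped test when the test class runs in
    # embed_params mode, since element-wise parameter comparison
    # is intentionally broken by the Conv-BN fusion pass.
    @functools.wraps(func)
    def wrapper(self):
        if getattr(self, "embed_params", False):
            raise unittest.SkipTest(
                "debug_embed_params not applicable after Conv-BN fusion")
        return func(self)
    return wrapper
```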

Member


Could you point me to where it fails?

Contributor

@neginraoof neginraoof left a comment


Thanks.
It would be great if @houseroad could take a look at the updates to the Caffe2 tests.

auto origconvNode = *it;
auto epsilon = bnNode->f(attr::epsilon);
auto w_conv_value = getValues(origconvNode, valsToParamsMap);
if (w_conv_value.size() < 1 ||
Contributor


The logic here could be stated more clearly:
if the Conv node has 3 inputs, the size of w_conv_value should be 2;
if it has 2 inputs, the size of w_conv_value should be 1;
continue (skip fusion) if neither holds.
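The check the reviewer suggests can be restated as a small predicate. This is a Python sketch mirroring the intent of the C++ pass, with a hypothetical helper name; a Conv node's inputs are (X, W) or (X, W, B), so the number of resolvable parameter tensors should be one less than the input count:

```python
def conv_params_complete(num_conv_inputs, num_param_values):
    # Conv with 3 inputs (X, W, B) should resolve 2 param tensors (W, B);
    # Conv with 2 inputs (X, W) should resolve 1 param tensor (W).
    # Anything else means the params are not all constants: skip fusion.
    if num_conv_inputs == 3:
        return num_param_values == 2
    if num_conv_inputs == 2:
        return num_param_values == 1
    return False
```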

@KsenijaS
Contributor Author

@houseroad can you please take a look? Thanks

Member

@houseroad houseroad left a comment


Overall it looks good to me; I didn't carefully check the fuse logic. Left some inline comments.

return func(self)
return wrapper

def skipIfNoEmbed(func):
Member


Could you point me to where it fails?

for node in graph.nodes():
assert node.kind() != "onnx::BatchNormalization"

def test_conv_bn(self):
Member


Why do we want to add a model execution test here? test_utility_funs.py is supposed to contain only checks on the model structure.

Contributor Author


We are comparing ORT outputs with and without training (without the pass / with the pass), whereas in pytorch_onnx_onnxruntime we are comparing PyTorch outputs with ORT outputs.

ort_outs2 = ort_sess.run(None, ort_inputs)
[np.testing.assert_allclose(ort_out1, ort_out2, atol=1e-7, rtol=0.001) for ort_out1, ort_out2 in zip(ort_outs1, ort_outs2)]

def test_multiple_conv_bn(self):
Member


Same here: can we move them to test_pytorch_onnx_onnxruntime.py?

Contributor Author


Same as before: we are comparing ORT outputs with the pass against ORT outputs without the pass.
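The comparison pattern the tests use (zipping two lists of ONNX Runtime outputs and asserting closeness) can be restated as a small helper. This is a sketch of the idiom, not the PR's exact code:

```python
import numpy as np

def assert_outputs_match(outs1, outs2, atol=1e-7, rtol=1e-3):
    # Compare two lists of model outputs element-wise, e.g. the
    # ORT outputs of the unfused export vs. the fused export.
    assert len(outs1) == len(outs2), "output count mismatch"
    for o1, o2 in zip(outs1, outs2):
        np.testing.assert_allclose(o1, o2, atol=atol, rtol=rtol)
```

Using an explicit loop rather than a list comprehension makes clear the assertions are executed for their side effects.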

@KsenijaS
Contributor Author

KsenijaS commented Jul 16, 2020

@houseroad The traceback for test_resnet with the fuse Conv-BN pass is:
FAIL: test_resnet (main.TestCaffe2Backend_opset9)


Traceback (most recent call last):
File "test/onnx/test_pytorch_onnx_caffe2.py", line 518, in test_resnet
state_dict=state_dict, atol=1e-5)
File "test/onnx/test_pytorch_onnx_caffe2.py", line 219, in run_model_test
operator_export_type=operator_export_type)
File "test/onnx/test_pytorch_onnx_caffe2.py", line 170, in run_debug_test
np.testing.assert_almost_equal(x.data.cpu().numpy(), y, decimal=3)
File "/home/ksenija/anaconda3/envs/pytorch/lib/python3.7/site-packages/numpy/testing/_private/utils.py", line 600, in assert_almost_equal
return assert_array_almost_equal(actual, desired, decimal, err_msg)
File "/home/ksenija/anaconda3/envs/pytorch/lib/python3.7/site-packages/numpy/testing/_private/utils.py", line 1064, in assert_array_almost_equal
precision=decimal)
File "/home/ksenija/anaconda3/envs/pytorch/lib/python3.7/site-packages/numpy/testing/_private/utils.py", line 785, in assert_array_compare
flagged = func_assert_same_pos(x, y, func=isnan, hasval='nan')
File "/home/ksenija/anaconda3/envs/pytorch/lib/python3.7/site-packages/numpy/testing/_private/utils.py", line 761, in func_assert_same_pos
raise AssertionError(msg)
AssertionError:
Arrays are not almost equal to 3 decimals

x and y nan location mismatch:
x: array([[-0.979, 0.093, -0.469, ..., -2.597, -0.888, 1.386],
[-0.582, -0.08 , -0.224, ..., -2.605, -0.825, 1.359]],
dtype=float32)
y: array([[nan, nan, nan, ..., nan, nan, nan],
[nan, nan, nan, ..., nan, nan, nan]], dtype=float32)

@KsenijaS
Contributor Author

@houseroad do you have more questions?

@bzinodev bzinodev self-requested a review July 20, 2020 18:38
Contributor

@facebook-github-bot facebook-github-bot left a comment


@bzinodev has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@KsenijaS
Contributor Author

@bzinodev can this PR be merged? Thanks!

@facebook-github-bot
Contributor

@bzinodev merged this pull request in af5d0bf.


Labels

Merged oncall: jit Add this issue/PR to JIT oncall triage queue open source triaged This issue has been looked at a team member, and triaged and prioritized into an appropriate module


9 participants