JIT pass for add relu fusion. #39343
Conversation
Summary: Building on top of the previous PR that adds the fused add_relu op, this PR adds a JIT pass that transforms the input graph to find all fusible instances of add + relu and fuse them. Test Plan: python test/test_jit.py TestJit.test_add_relu_fusion Reviewers: Subscribers: Tasks: Tags: [ghstack-poisoned]
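As a rough usage sketch, the pass can be invoked on a scripted module's graph. The binding name torch._C._jit_pass_fuse_add_relu is inferred from the test plan and should be treated as an assumption:

import torch

class M(torch.nn.Module):
    def forward(self, a, b):
        tmp = torch.add(a, b)
        return torch.relu_(tmp)

m = torch.jit.script(M())
# Assumed binding name for the pass added in this PR.
torch._C._jit_pass_fuse_add_relu(m.graph)
print(m.graph)  # the add/relu_ pair should now show up as one fused node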
a = a * -10
a = a + 5
Why?
I think this was out of habit from quantized stuff. I was trying to bias the values of the tensor, rather than leave them uniformly distributed between 0 and 1, to be able to catch errors that are not easily canceled out. In this specific instance it may not matter.
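For context, a minimal sketch of the biased inputs being described, with values shifted away from the default uniform [0, 1) range so that relu actually clamps some of them:

import torch

a = torch.rand(5, 5)  # uniform in [0, 1)
a = a * -10           # scaled into (-10, 0]
a = a + 5             # shifted into (-5, 5], mixing signs so errors don't cancel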
test/test_jit.py (Outdated)
new_res = m(a, b)
FileCheck().check_not("aten::add(") \
    .check_not("aten::relu_(") \
    .check("aten::add_relu_(") \
Why is this add_relu_? Shouldn't it be add_relu, since the original operand was not modified? Doesn't the optimized model mutate a, whereas the original did not?
So add is not in-place but relu_ is; it made sense to replace add + relu_ with add_relu_.
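A minimal sketch of the pattern in question: the out-of-place add produces a fresh tensor, and the in-place relu_ then mutates only that fresh tensor, which is why rewriting the pair as a single in-place add_relu_ looks safe at first glance:

import torch

def pattern(a, b):
    tmp = torch.add(a, b)    # out-of-place: a and b are untouched
    return torch.relu_(tmp)  # in-place, but only on tmp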
test/test_jit.py (Outdated)
    .check_not("aten::relu(") \
    .check("aten::add_relu_(") \
    .run(m.graph)
torch.testing.assert_allclose(orig_res, new_res)
We should also test that both models properly mutate a. Above, we should probably test that neither of them mutates it.
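A hedged sketch of such a check, again assuming the _jit_pass_fuse_add_relu binding; the module here is illustrative, not from the PR:

import torch

class M(torch.nn.Module):
    def forward(self, a, b):
        tmp = torch.add(a, b)
        return torch.relu_(tmp)

m_eager = M()
m = torch.jit.script(M())
torch._C._jit_pass_fuse_add_relu(m.graph)  # assumed binding name

a_eager = torch.rand(4, 4)
a_fused = a_eager.clone()
b = torch.rand(4, 4)

# Outputs must match, and the inputs must be mutated (or left alone) identically.
torch.testing.assert_allclose(m_eager(a_eager, b), m(a_fused, b))
torch.testing.assert_allclose(a_eager, a_fused)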
torch/csrc/jit/passes/fuse_relu.cpp (Outdated)
// inplace. since in the graph, it does not seem
// that output of add will be really used anywhere else.
I don't agree with this, since you're now mutating one of the addends, rather than the sum.
Hmm. I think you are right. I need to check a graph like this:
def m(a, b, c):
    tmp = torch.add(a, b)
    x = torch.relu_(tmp)
    d = torch.add(a, c)
    return x + d
If add + relu_ is replaced with add_relu_, then a is mutated, and if torch.add(a, c) is executed afterwards, that is wrong. This transformation adds a new alias: in the original graph, with tmp = torch.add(a, b), a and tmp don't alias, but after replacing it with tmp = torch.add_relu_(a, b), a and tmp alias each other.
I will test and fix this.
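A small eager-mode illustration of the hazard, using a.add_(b).relu_() to stand in for the fused in-place op:

import torch

a = torch.tensor([1.0, -2.0])
b = torch.tensor([3.0, 4.0])
c = torch.tensor([10.0, 10.0])

# Original semantics: a is never written to.
x = torch.relu_(torch.add(a, b))   # the add output is a fresh tensor
d = torch.add(a, c)                # reads the original a -> [11., 8.]

# Fused in-place semantics, simulated: the sum is written into a.
a2 = torch.tensor([1.0, -2.0])
x2 = a2.add_(b).relu_()            # a2 now holds relu(a + b) = [4., 2.]
d2 = torch.add(a2, c)              # reads the mutated a2 -> [14., 12.]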
@dreiss I fixed this and also added exactly the test that I mentioned above. Without the fix, the test fails; after the fix, it passes.
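A hedged sketch of what that regression test might look like (again assuming the _jit_pass_fuse_add_relu binding; the actual test in the PR may differ):

import torch

def f(a, b, c):
    tmp = torch.add(a, b)
    x = torch.relu_(tmp)
    d = torch.add(a, c)  # reads a after the candidate fusion site
    return x + d

scripted = torch.jit.script(f)
torch._C._jit_pass_fuse_add_relu(scripted.graph)  # assumed binding name

a, b, c = torch.rand(4), torch.rand(4), torch.rand(4)
# With the alias-safety fix, the pass must not fuse here (a is read later),
# so the scripted result has to match eager execution.
torch.testing.assert_allclose(f(a.clone(), b, c), scripted(a.clone(), b, c))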
This pull request has been merged in c5dcf05.
Stack from ghstack:
Summary:
Building on top of the previous PR that adds the fused add_relu op, this PR adds
a JIT pass that transforms the input graph to find all fusible instances of add
+ relu and fuses them.
Test Plan:
python test/test_jit.py TestJit.test_add_relu_fusion
Reviewers:
Subscribers:
Tasks:
Tags:
Differential Revision: D21822396