Add hop for additional control dependencies #164568
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/164568. Note: links to docs will display an error until the docs builds have been completed. ✅ No failures as of commit 48c663d with merge base 7617b11. (This comment was automatically generated by Dr. CI and updates every 15 minutes.)
Adds a [control_deps](https://en.wikipedia.org/wiki/Control_dependency) higher-order operator to enforce explicit scheduling dependencies in FX graphs. This prevents unwanted operation reordering/fusion by giving nodes additional dependencies, which we also respect in inductor by adding weak deps on the additional dependencies. This can be generally useful (such as for ordering collectives), but in this case I am using it so that fusions do not interfere with aten planned comm-compute overlap.

There's definitely some similarity with the `with_effects` hop. Talked with @angelayi - when Richard is back we will figure out how we want to consolidate.

The implementation needs to be a subgraph (as opposed to `with_effects`) because inductor relies on `V.graph.current_node`. Changing the signature of the node with `with_effects` breaks this, and additionally also breaks striding constraints on the wrapped node - see this [TODO](https://github.com/pytorch/pytorch/blob/aed66248a01d309eb2ac1149b5f51310545b0783/torch/fx/experimental/proxy_tensor.py#L1246-L1249). By maintaining the node with its original calling structure in a subgraph, this all works.

Example transformation:

Before:
```
%add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%arg0_1, 1), kwargs = {})
%mm : [num_users=1] = call_function[target=torch.ops.aten.mm.default](args = (%arg1_1, %arg1_1), kwargs = {})
%mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add, 2), kwargs = {})
```

After:
```
add: "f32[256, 256]" = torch.ops.aten.add.Tensor(arg0_1, 1)
mm: "f32[256, 256]" = torch.ops.higher_order.control_deps((add,), subgraph_mm, arg1_1, arg1_1)
mul: "f32[256, 256]" = torch.ops.higher_order.control_deps((mm,), subgraph_mul, add)
```

The mm operation now explicitly depends on add completing first, and mul depends on mm, with the original operations preserved in subgraphs.

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @chenyang78 @kadeng @muchulee8 @amjames @chauhang @aakhundov @coconutruben
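For intuition, a minimal sketch of the call semantics described above - the extra deps only constrain scheduling and are ignored at runtime (illustrative only, not the actual implementation):

```python
# Minimal sketch of control_deps' runtime semantics (illustrative, not the real code):
# additional_deps are ordering-only inputs; the wrapped subgraph is simply invoked
# on its original arguments and its result is returned unchanged.
def control_deps_reference(additional_deps, subgraph, *args):
    del additional_deps  # only used by the scheduler to order nodes
    return subgraph(*args)
```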
```python
from torch.utils._ordered_set import OrderedSet
...
class ControlDeps(HigherOrderOperator):
```
Can this live with the other HOPs? Put it in torch/_higher_order_ops/
Can we defer this until we actually have the consolidation discussions with @zou3519?
This should live in torch/_higher_order_ops -- it's easier to keep track of HOPs that way.
```python
for add_dep in V.graph.additional_buffer_deps[buf.get_name()]:
    add_user(add_dep, node, is_weak=True)
    node.add_fake_dep(WeakDep(add_dep, node.get_name()))
```
You turn these into weak deps. That means they don't impede DCE. Is it important that FX level DCE is also able to DCE past control deps?
Yes, that is intentional. They are just additional dependencies.
For now - we could consider a non-DCE version as part of the consolidation.
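(For readers unfamiliar with the distinction: roughly, a weak dep orders two nodes if both are scheduled but does not count as a real use, so the producer can still be dead-code eliminated. A hedged sketch of that rule, with illustrative names, not inductor's actual code:)

```python
# Illustrative sketch of the weak-dep behavior discussed above (not inductor's real code):
# a weak dependency enforces ordering when both nodes survive, but it is ignored when
# deciding whether a node is dead, so DCE can still remove the producer.
def is_dead(node) -> bool:
    # a node with only weak (ordering-only) users has no real consumers
    return not node.has_side_effects and all(user.is_weak for user in node.users)
```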
```python
raise TypeError(
    f"subgraph must be GraphModule or callable, got {type(subgraph).__name__}"
)
return super().__call__(additional_deps, subgraph, *args, **kwargs)
```
I have a dumb, newbie question about HOPs. How exactly does mutation inside the HOP work here? I know that for many HOPs we have to go to great lengths to assert there's no mutation in the HOP. For this case, where you are just wrapping a single op, it seems feasible to handle a mutable op - but how exactly does this work?
This is only being inserted in a post-grad pass, so it does not need to handle mutation and the other things that invoke_subgraph does.
```python
ordered_node.name = original_name  # PRESERVE ORIGINAL NAME
...
# Track the replacement for future dependencies
replacements[dependent_node] = ordered_node
```
What if I have a replacement from n1 -> n2, and then later I do n2 -> n3, how do I ensure n1 -> n3? Or is this impossible?
Any node can only have a single set of additional deps, so it can't be replaced twice.
```python
subgraph_module = _create_subgraph_for_node(graph, dependent_node)
...
subgraph_attr_name = f"subgraph_{original_name}"
setattr(graph.owning_module, subgraph_attr_name, subgraph_module)
```
Can you assert that the subgraph name doesn't already exist? It would be better to have some way of just allocating a fresh name here.
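A simple way to do that (a sketch of the suggestion; the helper name is hypothetical):

```python
# Sketch of the suggestion above: allocate an unused attribute name instead of
# assuming f"subgraph_{original_name}" is free (helper name is hypothetical).
def fresh_subgraph_attr_name(owning_module, base_name: str) -> str:
    candidate = f"subgraph_{base_name}"
    suffix = 0
    while hasattr(owning_module, candidate):
        suffix += 1
        candidate = f"subgraph_{base_name}_{suffix}"
    return candidate
```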
Will this hop, added in post_grad, prevent other fusions (e.g. async-TP ag+mm, mm+rs), since those passes look for torch.ops.aten.mm?
@IvanKobzarev Since we're only adding these deps for nodes that are hidden, it shouldn't affect those fusions, which are targeting exposed collectives. But potentially, in the future, we could unwrap the subgraphs, do other passes, and add them back as needed.
```python
updated_dep_nodes = [replacements.get(dep, dep) for dep in dep_nodes]
...
# Create a subgraph that preserves the exact original operation
subgraph_module = _create_subgraph_for_node(graph, dependent_node)
```
On the topic of "better aot_eager perf": I imagine we probably want one of these outcomes when we are running bucketing/reordering and only using the aot_eager backend:
(1) we have the reordering pass not bother inserting these HOPs/subgraphs if we know we are not going to inductor, or
(2) we insert the dependencies unconditionally, but then we may need to make sure that the aot_eager runtime code doesn't do anything too slow with them. No action needed, mostly thinking out loud (maybe calling a subgraph in the FX interpreter is fast enough already, or if not, maybe we give the control_deps HOP a faster eager impl).
1 - we have a config for this; we'd set it to false for aot_eager.
2 - if it turns out we do need to insert these for another reordering pass, we'd probably want to inline the subgraph at the end of invocation (similarly, we'd potentially want to inline invoke_subgraph).
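To illustrate what that inlining could look like, here's a hedged sketch that unwraps each control_deps node back into the outer graph; the target check and subgraph layout are assumptions for illustration, not the actual pass:

```python
import torch.fx as fx

# Hedged sketch of "inline it at the end" for eager-only backends (illustrative):
# replace each control_deps call with the op(s) stored in its wrapped subgraph.
def inline_control_deps(gm: fx.GraphModule) -> None:
    for node in list(gm.graph.nodes):
        if node.op == "call_function" and "control_deps" in str(node.target):
            _deps, subgraph_node, *op_args = node.args
            sub_gm = getattr(gm, subgraph_node.target)  # get_attr node -> wrapped GraphModule
            placeholders = [n for n in sub_gm.graph.nodes if n.op == "placeholder"]
            with gm.graph.inserting_before(node):
                # copy the wrapped op(s) into the outer graph, remapping placeholders to args
                out = gm.graph.graph_copy(sub_gm.graph, dict(zip(placeholders, op_args)))
            node.replace_all_uses_with(out)
            gm.graph.erase_node(node)
    gm.recompile()
```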
Starting merge as part of PR stack under #164569
Now that we have a hop to add implicit deps - use those deps for comm/compute overlap. Pull Request resolved: #164569 Approved by: https://github.com/ezyang, https://github.com/IvanKobzarev ghstack dependencies: #164568
When we are looking if two nodes are dependent, limit path search within the bounds of their node idxs. Pull Request resolved: #164581 Approved by: https://github.com/ezyang ghstack dependencies: #164568, #164569
Original work by @ShatianWang, with lints applied. I am going to make a few changes and add tests in subsequent PRs, but I want to preserve the original commit first. Pull Request resolved: #164738 Approved by: https://github.com/IvanKobzarev ghstack dependencies: #164568, #164569, #164581
Discussed a bit with @bdhirsh, there are some things that might be problems. NB: I did not read this PR carefully, it might already handle these cases.
I can submit a PR to address this. For now, we do not trace this - it's only inserted post-tracing, so I'm not sure it's relevant. Nor is there a way for the pattern matcher to target this.
This is probably because of more general tracing/functionalization issues, which are not relevant here since this is only applied as a post-grad pass, post-functionalization.
I think this is only reinplacing today - which uses Fake Storages, and which this maintains. We should make sure other passes use Fake Storages as well. I do agree we need to be careful about the Fake Tensor Updater - once #159523 lands I can add a test.
Sorry, I think I was confused. To check: the additional_deps aren't being returned as outputs of the `control_deps` hop?
@zou3519 Yes, correct (today at least).
ghstack-source-id: 1cc8bad Pull Request resolved: pytorch#164568
Pull Request resolved: pytorch#164568 Approved by: https://github.com/ezyang, https://github.com/IvanKobzarev
Now that we have a hop to add implicit deps - use those deps for comm/compute overlap. Pull Request resolved: pytorch#164569 Approved by: https://github.com/ezyang, https://github.com/IvanKobzarev ghstack dependencies: pytorch#164568
When we are looking if two nodes are dependent, limit path search within the bounds of their node idxs. Pull Request resolved: pytorch#164581 Approved by: https://github.com/ezyang ghstack dependencies: pytorch#164568, pytorch#164569
Original work by @ShatianWang, with lints applied. I am going to make a few changes and add tests in subsequent PRs, but I want to preserve the original commit first. Pull Request resolved: pytorch#164738 Approved by: https://github.com/IvanKobzarev ghstack dependencies: pytorch#164568, pytorch#164569, pytorch#164581