
Conversation

@eellison
Contributor

@eellison eellison commented Oct 3, 2025

Stack from ghstack (oldest at bottom):

Adds a [control_deps](https://en.wikipedia.org/wiki/Control_dependency) higher-order operator to enforce explicit scheduling dependencies in FX graphs. This prevents unwanted operation reordering/fusion by giving nodes additional dependencies, which we also respect in Inductor by adding WeakDeps on the additional dependencies.

This can be generally useful (such as for ordering collectives), but in this case I am using it so that fusions do not interfere with planned comm-compute overlap at the aten level.

There's definitely some similarity with the `with_effects` hop. Talked with @angelayi - when @zou3519 is back we will figure out how we want to consolidate.

The implementation needs to be a subgraph (as opposed to `with_effects`) because Inductor relies on `V.graph.current_node`. Changing the signature of the node with `with_effects` breaks this, and it also breaks the striding constraints on the wrapped node - see this [TODO](https://github.com/pytorch/pytorch/blob/aed66248a01d309eb2ac1149b5f51310545b0783/torch/fx/experimental/proxy_tensor.py#L1246-L1249). By maintaining the node with its original calling structure inside a subgraph, this all works.

Example transformation:

Before:

```
%add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%arg0_1, 1), kwargs = {})
%mm : [num_users=1] = call_function[target=torch.ops.aten.mm.default](args = (%arg1_1, %arg1_1), kwargs = {})
%mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add, 2), kwargs = {})
```

After:

add: "f32[256, 256]" = torch.ops.aten.add.Tensor(arg0_1, 1)
mm: "f32[256, 256]" = torch.ops.higher_order.control_deps((add,), subgraph_mm, arg1_1, arg1_1)
mul: "f32[256, 256]" = torch.ops.higher_order.control_deps((mm,), subgraph_mul, add)

The mm operation now explicitly depends on add completing first, and mul depends on mm, with original operations preserved in subgraphs.
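
For intuition, here is a minimal runnable sketch of what the transformed example computes, assuming the operator's value semantics are simply "run the wrapped subgraph" while the dependency tuple only constrains scheduling (the `control_deps_reference` helper below is illustrative, not the actual HOP implementation):

```python
import torch

def control_deps_reference(additional_deps, subgraph, *args):
    # The extra deps are ignored for value computation; they only add
    # scheduling edges (Inductor turns them into WeakDeps).
    return subgraph(*args)

# The subgraphs preserve the original ops with their original calling structure.
def subgraph_mm(a, b):
    return torch.ops.aten.mm.default(a, b)

def subgraph_mul(x):
    return torch.ops.aten.mul.Tensor(x, 2)

arg0_1 = torch.randn(256, 256)
arg1_1 = torch.randn(256, 256)

add = torch.ops.aten.add.Tensor(arg0_1, 1)
mm = control_deps_reference((add,), subgraph_mm, arg1_1, arg1_1)  # must run after add
mul = control_deps_reference((mm,), subgraph_mul, add)            # must run after mm
```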

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @chenyang78 @kadeng @muchulee8 @amjames @chauhang @aakhundov @coconutruben

@pytorch-bot

pytorch-bot bot commented Oct 3, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/164568

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 48c663d with merge base 7617b11:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@eellison eellison added the topic: not user facing label Oct 3, 2025
```python
from torch.utils._ordered_set import OrderedSet


class ControlDeps(HigherOrderOperator):
```
Contributor

@ezyang ezyang Oct 3, 2025

Can this live with the other HOPs? Put it in torch/_higher_order_ops/

Contributor Author

Can we defer this until we actually have the discussions on consolidation and such with @zou3519?

Contributor

This should live in torch/_higher_order_ops -- it's easier to keep track of HOPs that way.


```python
for add_dep in V.graph.additional_buffer_deps[buf.get_name()]:
    # register this node as a weak user of each additional dependency and
    # give it a WeakDep: an ordering-only edge that the scheduler respects
    # but that does not impede DCE
    add_user(add_dep, node, is_weak=True)
    node.add_fake_dep(WeakDep(add_dep, node.get_name()))
```
Contributor

You turn these into weak deps. That means they don't impede DCE. Is it important that FX level DCE is also able to DCE past control deps?

Contributor Author

Yes, that is intentional. They are just additional dependencies.

Contributor Author

For now - we could consider a non-DCE version as part of the consolidation.

```python
raise TypeError(
    f"subgraph must be GraphModule or callable, got {type(subgraph).__name__}"
)
return super().__call__(additional_deps, subgraph, *args, **kwargs)
```
Contributor

I have a dumb newbie question about HOPs. How exactly does mutation inside the HOP work here? I know for many HOPs we have to go to great lengths to assert there's no mutation in the HOP. For this case, where you are just wrapping a single op, it seems feasible to handle a mutable op. But how exactly does this work?

Contributor Author

This is only being inserted in a post-grad pass, so it does not need to handle mutation and the other things that invoke_subgraph does.

```python
ordered_node.name = original_name  # preserve the original name

# Track the replacement for future dependencies
replacements[dependent_node] = ordered_node
```
Contributor

What if I have a replacement from n1 -> n2, and then later I do n2 -> n3, how do I ensure n1 -> n3? Or is this impossible?

Contributor Author

Any node can only have a single set of additional deps, so it can't be replaced twice.

```python
subgraph_module = _create_subgraph_for_node(graph, dependent_node)

subgraph_attr_name = f"subgraph_{original_name}"
setattr(graph.owning_module, subgraph_attr_name, subgraph_module)
```
Contributor

Can you assert that the subgraph name doesn't already exist? It would be better to have some way of just allocating a fresh name here.
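
For illustration, a minimal sketch of one way to allocate a fresh name (the `_fresh_subgraph_attr_name` helper is hypothetical, not from this PR; `graph`, `original_name`, and `subgraph_module` are from the snippet above):

```python
def _fresh_subgraph_attr_name(owning_module, base_name):
    # Start from "subgraph_<base_name>" and append a counter until the
    # attribute does not already exist on the owning module.
    candidate = f"subgraph_{base_name}"
    counter = 0
    while hasattr(owning_module, candidate):
        counter += 1
        candidate = f"subgraph_{base_name}_{counter}"
    return candidate

subgraph_attr_name = _fresh_subgraph_attr_name(graph.owning_module, original_name)
setattr(graph.owning_module, subgraph_attr_name, subgraph_module)
```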

@IvanKobzarev
Contributor

Will this hop, added in post_grad, prevent other fusions (e.g. async-TP ag+mm and mm+rs, whose passes look for torch.ops.aten.mm)?
Or do we apply all the other FX fusions first and then apply the additional deps?

@eellison
Contributor Author

eellison commented Oct 3, 2025

@IvanKobzarev since we're only adding these deps for nodes that are hidden, it shouldn't affect those fusions, which are targeting exposed collectives.

But, potentially, in the future we could unwrap the subgraphs, do other passes, and add them back as needed.

```python
updated_dep_nodes = [replacements.get(dep, dep) for dep in dep_nodes]

# Create a subgraph that preserves the exact original operation
subgraph_module = _create_subgraph_for_node(graph, dependent_node)
```
Contributor

On the topic of "better aot_eager perf": I imagine that we probably want one of these outcomes in the situation where we are running bucketing/reordering and only using the aot_eager backend:

(1) we have the reordering pass not bother inserting these HOPs/subgraphs if we know we are not going to Inductor

(2) or we insert the dependencies unconditionally, but if we do, we may need to make sure that the aot_eager runtime code doesn't do anything too slow with them? No action needed, mostly thinking out loud (maybe calling a subgraph in the FX interpreter is fast enough already, or if not, maybe we give the control_deps HOP a faster eager impl)

Contributor Author

1 - we have a config for this; we'd set it to false for aot_eager.

2 - if it turns out we do need to insert these for another reordering pass, we'd probably want to inline it at the end of invocation. (Similarly, we'd potentially want to inline invoke_subgraph)
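
For illustration, a rough sketch of what such inlining could look like as a cleanup pass (hypothetical, not part of this PR): rewrite each control_deps node into a plain call of its subgraph submodule and drop the ordering-only dependency tuple. It assumes each node's args are `(deps_tuple, get_attr node for the subgraph, *call_args)`, as in the example above:

```python
import torch
import torch.fx as fx

def inline_control_deps(gm: fx.GraphModule) -> fx.GraphModule:
    for node in list(gm.graph.nodes):
        if (
            node.op == "call_function"
            and node.target is torch.ops.higher_order.control_deps
        ):
            _deps, subgraph_node, *call_args = node.args
            # assumed layout: the subgraph is referenced via a get_attr node
            assert subgraph_node.op == "get_attr"
            with gm.graph.inserting_before(node):
                replacement = gm.graph.call_module(
                    subgraph_node.target, tuple(call_args), dict(node.kwargs)
                )
            node.replace_all_uses_with(replacement)
            gm.graph.erase_node(node)
    gm.graph.eliminate_dead_code()  # drops the now-unused get_attr nodes
    gm.recompile()
    return gm
```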

@pytorchmergebot
Collaborator

Starting merge as part of PR stack under #164569

@eellison eellison mentioned this pull request Oct 6, 2025
pytorchmergebot pushed a commit that referenced this pull request Oct 6, 2025
Now that we have a hop to add implicit deps - use those deps for comm/compute overlap.

Pull Request resolved: #164569
Approved by: https://github.com/ezyang, https://github.com/IvanKobzarev
ghstack dependencies: #164568
pytorchmergebot pushed a commit that referenced this pull request Oct 6, 2025
When we are checking whether two nodes are dependent, limit the path search to within the bounds of their node indices.

Pull Request resolved: #164581
Approved by: https://github.com/ezyang
ghstack dependencies: #164568, #164569
pytorchmergebot pushed a commit that referenced this pull request Oct 7, 2025
Original work by @ShatianWang, with lints applied. I am going to make a few changes and add tests in subsequent PRs, but I want to preserve the original commit first.

Pull Request resolved: #164738
Approved by: https://github.com/IvanKobzarev
ghstack dependencies: #164568, #164569, #164581
@zou3519
Contributor

zou3519 commented Oct 7, 2025

Discussed a bit with @bdhirsh; there are some things that might be problems. NB: I did not read this PR carefully - it might already handle these cases.

  1. The HOP should not return the inputs directly as outputs. Instead it should return aliases. This is because bad things happen when you make_fx any operation that returns an input directly as an output (we used to have some clobbering happen for in-place operations; I'm not sure how we avoided it). make_fx might happen at the Inductor level, because the pattern matcher does use make_fx to trace patterns.
  2. There are passes in Inductor that analyze views and aliases. Those need to be updated to understand that this HOP returns aliases. In general HOPs don't support aliases, and we may have baked that assumption into various places (we might have also used this information in invoke_subgraph caching somewhere?). An example is the FakeTensorUpdater: FakeTensorUpdater doesn't support HOPs yet, but assuming it does, it needs to understand that this new HOP returns views. The view information is important to prevent silent incorrectness in reinplacing (this has bitten us a couple of times).

@eellison
Contributor Author

eellison commented Oct 7, 2025

> This is because bad things happen when you make_fx any operation that returns the input directly as an output (we used to have some clobbering happen for inplace operations, I'm not sure how we avoided it). make_fx might happen at the Inductor level, because the pattern matcher does use make_fx to trace patterns

I can submit a PR to address this. For now, we do not trace this - it's only inserted post-tracing, so I'm not sure it's relevant. Nor is there a way for the pattern matcher to target this.

> there are passes in inductor that analyze views and aliases. Those need to be updated to understand that this HOP returns aliases. In general HOPs don't support aliases

This is probably because of more general tracing/functionalization issues, which are not relevant here since this is only applied as a post-grad pass, post-functionalization.

> there are passes in inductor that analyze views and aliases.

I think the only such pass today is reinplacing - which uses Fake Storages, and which this maintains. We should make sure other passes use Fake Storages as well.

I do agree we need to be careful about the FakeTensorUpdater - once #159523 lands I can add a test.

@zou3519
Contributor

zou3519 commented Oct 7, 2025

> 1. the HOP should not return the inputs directly as outputs. Instead it should return aliases. This is because bad things happen when you make_fx any operation that returns the input directly as an output (we used to have some clobbering happen for inplace operations, I'm not sure how we avoided it). make_fx might happen at the Inductor level, because the pattern matcher does use make_fx to trace patterns

Sorry, I think I was confused. To check: the additional_deps aren't being returned as outputs of control_deps?

@eellison
Contributor Author

eellison commented Oct 7, 2025

@zou3519 yes, correct (today at least)

eellison added a commit to eellison/pytorch that referenced this pull request Oct 11, 2025
ghstack-source-id: 1cc8bad
Pull Request resolved: pytorch#164568
Chao1Han pushed a commit to Chao1Han/pytorch that referenced this pull request Oct 21, 2025

Pull Request resolved: pytorch#164568
Approved by: https://github.com/ezyang, https://github.com/IvanKobzarev
@github-actions github-actions bot deleted the gh/eellison/836/head branch November 7, 2025 02:15