
Commit 6abca6a

kwen2501 authored and pytorchmergebot committed
[export][unflatten] More strictly respect scope when removing inputs (#127607)
Code snippet from TorchTitan (LLaMa):

```python
for layer in self.layers.values():
    h = layer(h, self.freqs_cis)
```

`self.freqs_cis` is a buffer of the root module (`self`). It is also an explicit arg in the call signature of the original `layer` modules. If scope is not respected (`freqs_cis`'s scope corresponds only to the root), `_sink_params` can remove `freqs_cis` from `layer`'s call signature, resulting in a runtime error.

There are two fixes in this PR:

1. We filter `inputs_to_state` down to the entries corresponding to the current scope, using existing code that does prefix matching.
2. We delay the removal of param inputs from `call_module` nodes' `args` until the `_sink_params` call on that submodule returns. `_sink_params` now returns information on which inputs were actually removed by the submodule, which is more accurate than just doing:

```python
for node in call_module_nodes:
    node.args = tuple(filter(lambda n: n.name not in inputs_to_state, node.args))
```

Before the PR:

![Screenshot 2024-05-31 at 1 40 24 AM](https://github.com/pytorch/pytorch/assets/6676466/a2e06b18-44d5-40ca-b242-0edab45075b7)

After the PR:

![Screenshot 2024-05-31 at 1 43 41 AM](https://github.com/pytorch/pytorch/assets/6676466/b72afb94-cdfa-420d-b88b-29a92bf2a0c0)

Pull Request resolved: #127607
Approved by: https://github.com/pianpwk
1 parent e216df4 · commit 6abca6a
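
To make the scope rule in fix 1 concrete, here is a minimal, self-contained sketch (hypothetical helper names, not the actual `torch.export.unflatten` code) of the prefix matching the PR relies on: a fully qualified state name such as `layers.0.weight` belongs to a module's scope only if the scope is a prefix of it, so a root buffer like `const` is never treated as sinkable inside `layers.0`.

```python
# Illustrative sketch of the scope/prefix-matching idea from fix 1.
# in_scope, filter_inputs_to_state, and the example FQNs are hypothetical,
# not the real torch.export.unflatten implementation.
from typing import Dict, List


def in_scope(state_fqn: str, scope: List[str]) -> bool:
    # A state FQN like "layers.0.weight" belongs to scope ["layers", "0"]
    # because the scope is a prefix of the split FQN.
    return state_fqn.split(".")[: len(scope)] == scope


def filter_inputs_to_state(
    inputs_to_state: Dict[str, List[str]], scope: List[str]
) -> Dict[str, List[str]]:
    # Keep only placeholder -> state mappings that live under `scope`.
    return {
        name: fqns
        for name, fqns in inputs_to_state.items()
        if any(in_scope(fqn, scope) for fqn in fqns)
    }


if __name__ == "__main__":
    inputs_to_state = {
        "const": ["const"],             # root buffer, like freqs_cis in TorchTitan
        "weight": ["layers.0.weight"],  # parameter owned by the first layer
    }
    print(filter_inputs_to_state(inputs_to_state, ["layers", "0"]))
    # -> {'weight': ['layers.0.weight']}: 'const' stays a placeholder in the
    # submodule, so the submodule's call signature is left intact.
```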

File tree

2 files changed (+106, -32 lines)


test/export/test_unflatten.py

Lines changed: 22 additions & 0 deletions
@@ -747,6 +747,28 @@ def forward(self, x):
         unep = unflatten(ep)
         self.assertTrue(torch.allclose(unep(*inps), m(*inps)))
 
+    def test_attr_as_submod_input(self):
+        class layer(torch.nn.Module):
+            def forward(self, x, const) -> torch.Tensor:
+                return x + const
+
+        class M(torch.nn.Module):
+            def __init__(self) -> None:
+                super().__init__()
+                self.register_buffer("const", torch.ones(4, 8))
+                self.layers = torch.nn.ModuleList([layer() for _ in range(2)])
+
+            def forward(self, x: torch.Tensor) -> torch.Tensor:
+                for layer in self.layers:
+                    x = layer(x, self.const)
+                return x
+
+        mod = M()
+        x = torch.randn(4, 8)
+        ep = export(mod, (x,))
+        unflattened = unflatten(ep)
+        torch.testing.assert_close(unflattened(x), mod(x))
+
 
 if __name__ == "__main__":
     run_tests()

torch/export/unflatten.py

Lines changed: 84 additions & 32 deletions
@@ -337,16 +337,35 @@ def add_to_consts_map(obj_id, node_name, target_name):
             inputs_to_state[n] = targets
 
         _sink_params(self, inputs_to_state, [])
-        # Check all input nodes has been processed.
-        for name, module in self.named_modules():
-            if not hasattr(module, "graph"):
-                continue
-            for node in module.graph.nodes:
-                if node.op != "placeholder":
-                    continue
-                assert (
-                    node.name not in inputs_to_state
-                ), f"{node.name} was not sunk into the module {name} which has the graph: {module.graph}"
+
+        # Helper function to check that input nodes of `module` have been processed.
+        def check_module_inputs(module, scope):
+            if hasattr(module, "graph"):
+                for node in module.graph.nodes:
+                    # sink_params() should turn placeholders into get_attr nodes
+                    # for attributes that are within scope of the current
+                    # module. We allow attributes to remain as placeholders if
+                    # they are inputs in the original module signature, meaning
+                    # they are a parent module's attribute, and therefore out of
+                    # scope of the current module.
+                    if (
+                        node.op == "placeholder"
+                        and node.name in inputs_to_state
+                        and any(
+                            fqn.split(".")[: len(scope)] == scope
+                            for fqn in inputs_to_state[node.name]
+                        )  # matching scope to avoid a wrong assert
+                    ):
+                        raise AssertionError(
+                            f"{node.name} was not sunk into the module {scope} which has the graph: {module.graph}"
+                        )
+            # Recursively check the submodules.
+            for name, submod in module.named_children():
+                scope.append(name)
+                check_module_inputs(submod, scope)
+
+        # Recursively check that all input nodes have been processed.
+        check_module_inputs(self, [])
 
         # Cache so we don't have to compute this every time.
         # NOTE: this needs to be kept in sync with the placeholders in
@@ -1010,14 +1029,23 @@ def _sink_params(
     scope: tracks where we are in the module hierarchy, so that we can emit the
     right `getattr(self, "foo.bar")` calls, etc.
     """
+    # This dict records inputs removed by child modules.
+    # Maps the module object id to the list of placeholder node names
+    # in the child module that were removed.
+    module_id_to_inputs_removed: Dict[int, List[str]] = defaultdict(list)
+
     # We need to use _modules here instead of named_children(), because we
     # explicitly want duplicate modules to show up in the traversal.
     for name, submodule in module._modules.items():
-        _sink_params(cast(torch.nn.Module, submodule), inputs_to_state, scope + [name])
+        submod_id_to_inputs_removed = _sink_params(
+            cast(torch.nn.Module, submodule), inputs_to_state, scope + [name]
+        )
+        for k, v in submod_id_to_inputs_removed.items():
+            module_id_to_inputs_removed[k].extend(v)
 
     if not hasattr(module, "graph"):
         # Not all modules have graphs defined, if they are empty modules with no operations (like ParameterList)
-        return
+        return module_id_to_inputs_removed
 
     graph = module.graph
     inputs = list(filter(lambda n: n.op == "placeholder", graph.nodes))
@@ -1026,32 +1054,49 @@ def _sink_params(
     # Also remove from call_module nodes
    call_module_nodes = filter(lambda n: n.op == "call_module", graph.nodes)
     for node in call_module_nodes:
-        node.args = tuple(filter(lambda n: n.name not in inputs_to_state, node.args))
+        submodule = _recursive_getattr(module, node.target.split("."))
+        # remove placeholder from call_module node arguments, only if we've
+        # erased the placeholder node in the corresponding _sink_params() call
+        if submodule is not None and id(submodule) in module_id_to_inputs_removed:
+            node.args = tuple(
+                filter(
+                    lambda n: n.name not in module_id_to_inputs_removed[id(submodule)],
+                    node.args,
+                )
+            )
 
+    # Filter inputs_to_state down to the entries corresponding to the current scope.
+    inputs_to_state_of_scope: Dict[torch.fx.Node, list[str]] = {}
     for node in inputs:
         if node.name not in inputs_to_state:
             continue
 
-        if len(node.users) > 0:
-            state_name = None
-            for sn in inputs_to_state[node.name]:
-                sn_split = sn.split(".")
-                if sn_split[: len(scope)] == scope:
-                    state_name = sn_split
-                    break
-
-            # If there's a mismatch beteewn scope name and state name, then
-            # there must be multuple scopes pointing to the same state name,
-            # meaning some modules are shared. In such case, we can simply skip
-            # updating the current node because another later iteration will
-            # take care of this input node when the unique match between scope
-            # and state name occurs. To make sure this always happen, we should
-            # enforce the invariant that no placeholder node in the unflattened
-            # graph appears in inputs_to_state dict, which means all the extra
-            # input nodes have been handled.
-            if state_name is None:
-                continue
+        state_name = None
+        for sn in inputs_to_state[node.name]:
+            sn_split = sn.split(".")
+            if sn_split[: len(scope)] == scope:
+                state_name = sn_split
+                break
+
+        # If there's a mismatch between scope name and state name, then
+        # there must be multiple scopes pointing to the same state name,
+        # meaning some modules are shared. In such a case, we can simply skip
+        # updating the current node because another later iteration will
+        # take care of this input node when the unique match between scope
+        # and state name occurs. To make sure this always happens, we should
+        # enforce the invariant that no placeholder node in the unflattened
+        # graph appears in the inputs_to_state dict, which means all the extra
+        # input nodes have been handled.
+        if state_name is None:
+            continue
+
+        inputs_to_state_of_scope[node] = state_name
+
+    # Record names of removed inputs for return purposes.
+    inputs_removed: List[str] = []
 
+    for node, state_name in inputs_to_state_of_scope.items():
+        if len(node.users) > 0:
             attr_path = state_name[len(scope) :]
             state_attr = _recursive_getattr(module, attr_path)
             assert isinstance(state_attr, (torch.Tensor, torch.ScriptObject))
@@ -1061,13 +1106,20 @@ def _sink_params(
             new_node = graph.create_node("get_attr", ".".join(attr_path))
 
             node.replace_all_uses_with(new_node, propagate_meta=True)
+
         graph.erase_node(node)
+        inputs_removed.append(node.name)
+
     if isinstance(module, InterpreterModule):
         module.finalize()
 
+    return {id(module): inputs_removed}
+
 
 def _recursive_getattr(obj, attr_path):
     for attr in attr_path:
+        if not hasattr(obj, attr):
+            return None
         obj = getattr(obj, attr)
 
     return obj
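
The diff above changes `_sink_params` to report what it removed rather than having the parent guess. As a rough illustration of that contract (hypothetical helper names, not code from this file), the parent merges each child call's `{id(submodule): [removed placeholder names]}` report and prunes a `call_module` node's args only against what its own target actually erased:

```python
# Sketch of the "report removed inputs, then prune call args" contract from
# fix 2. merge_removed and prune_call_args are illustrative helpers, not part
# of torch.export.unflatten.
from collections import defaultdict
from typing import Dict, List


def merge_removed(
    parent: Dict[int, List[str]], child: Dict[int, List[str]]
) -> Dict[int, List[str]]:
    # Fold a child call's report into the parent's running record.
    merged: Dict[int, List[str]] = defaultdict(list, {k: list(v) for k, v in parent.items()})
    for k, v in child.items():
        merged[k].extend(v)
    return dict(merged)


def prune_call_args(
    call_args: List[str], callee_id: int, removed: Dict[int, List[str]]
) -> List[str]:
    # Drop only the arguments whose placeholders the callee actually erased.
    gone = set(removed.get(callee_id, []))
    return [a for a in call_args if a not in gone]


if __name__ == "__main__":
    layer_id = 42  # stand-in for id(submodule)
    # The layer reports that it removed nothing: the root buffer passed to it
    # was out of its scope, so its placeholder (and the argument) must stay.
    removed = merge_removed({}, {layer_id: []})
    print(prune_call_args(["x", "const"], layer_id, removed))  # ['x', 'const']
```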
