@@ -337,16 +337,35 @@ def add_to_consts_map(obj_id, node_name, target_name):
                 inputs_to_state[n] = targets
 
         _sink_params(self, inputs_to_state, [])
-        # Check all input nodes has been processed.
-        for name, module in self.named_modules():
-            if not hasattr(module, "graph"):
-                continue
-            for node in module.graph.nodes:
-                if node.op != "placeholder":
-                    continue
-                assert (
-                    node.name not in inputs_to_state
-                ), f"{node.name} was not sunk into the module {name} which has the graph: {module.graph}"
+
+        # Helper to check that the input nodes of `module` have been processed.
+        def check_module_inputs(module, scope):
+            if hasattr(module, "graph"):
+                for node in module.graph.nodes:
+                    # _sink_params() should turn placeholders into get_attr nodes
+                    # for attributes that are within scope of the current
+                    # module. We allow attributes to remain as placeholders if
+                    # they are inputs in the original module signature, meaning
+                    # they are a parent module's attribute, and therefore out of
+                    # scope of the current module.
+                    if (
+                        node.op == "placeholder"
+                        and node.name in inputs_to_state
+                        and any(
+                            fqn.split(".")[: len(scope)] == scope
+                            for fqn in inputs_to_state[node.name]
+                        )  # match the scope to avoid a spurious assertion
+                    ):
+                        raise AssertionError(
+                            f"{node.name} was not sunk into the module {scope} which has the graph: {module.graph}"
+                        )
+            # Recursively check the submodules.
+            for name, submod in module.named_children():
+                child_scope = scope + [name]  # avoid mutating the caller's scope list
+                check_module_inputs(submod, child_scope)
+
+        # Recursively check that all input nodes have been processed.
+        check_module_inputs(self, [])
 
         # Cache so we don't have to compute this every time.
         # NOTE: this needs to be kept in sync with the placeholders in
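
For reference, a minimal standalone sketch of the scope-prefix test that `check_module_inputs` applies above. The placeholder name and the `inputs_to_state` contents are hypothetical, not taken from the patch; a leftover placeholder is only flagged when one of its target FQNs lies under the current scope.

inputs_to_state = {"p_weight": ["encoder.linear.weight", "decoder.linear.weight"]}

def placeholder_is_in_scope(placeholder_name, scope):
    # A leftover placeholder is an error only if one of its target FQNs
    # lives under the current module path, i.e. scope is a prefix of the FQN.
    return any(
        fqn.split(".")[: len(scope)] == scope
        for fqn in inputs_to_state.get(placeholder_name, [])
    )

print(placeholder_is_in_scope("p_weight", ["encoder", "linear"]))  # True  -> should have been sunk here
print(placeholder_is_in_scope("p_weight", ["quantizer"]))          # False -> out of scope, allowed to remain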
@@ -1010,14 +1029,23 @@ def _sink_params(
     scope: tracks where we are in the module hierarchy, so that we can emit the
         right `getattr(self, "foo.bar")` calls, etc.
     """
+    # This dict records the inputs removed by child modules:
+    # it maps the id() of a child module to the list of placeholder node
+    # names that were removed from that child module's graph.
+    module_id_to_inputs_removed: Dict[int, List[str]] = defaultdict(list)
+
     # We need to use _modules here instead of named_children(), because we
     # explicitly want duplicate modules to show up in the traversal.
     for name, submodule in module._modules.items():
-        _sink_params(cast(torch.nn.Module, submodule), inputs_to_state, scope + [name])
+        submod_id_to_inputs_removed = _sink_params(
+            cast(torch.nn.Module, submodule), inputs_to_state, scope + [name]
+        )
+        for k, v in submod_id_to_inputs_removed.items():
+            module_id_to_inputs_removed[k].extend(v)
 
     if not hasattr(module, "graph"):
         # Not all modules have graphs defined, if they are empty modules with no operations (like ParameterList)
-        return
+        return module_id_to_inputs_removed
 
     graph = module.graph
     inputs = list(filter(lambda n: n.op == "placeholder", graph.nodes))
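
As a side note on the `_modules` comment in the hunk above, a small sketch, with hypothetical module names, of why the traversal iterates `module._modules` rather than `named_children()`: `named_children()` skips duplicate (shared) submodule objects, while `_modules` preserves every attribute slot, and the id()-keyed dict then collapses the results for a shared submodule into a single entry.

import torch

shared = torch.nn.Linear(2, 2)
parent = torch.nn.Module()
parent.branch_a = shared
parent.branch_b = shared  # the same object registered under two names

print([name for name, _ in parent.named_children()])  # ['branch_a'] -- duplicates skipped
print(list(parent._modules.keys()))                   # ['branch_a', 'branch_b']
print(parent.branch_a is parent.branch_b)             # True -- one id(), one dict entry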
@@ -1026,32 +1054,49 @@ def _sink_params(
     # Also remove from call_module nodes
     call_module_nodes = filter(lambda n: n.op == "call_module", graph.nodes)
     for node in call_module_nodes:
-        node.args = tuple(filter(lambda n: n.name not in inputs_to_state, node.args))
+        submodule = _recursive_getattr(module, node.target.split("."))
+        # Remove a placeholder from the call_module node's arguments only if we
+        # erased its placeholder node in the submodule's own _sink_params() call.
+        if submodule is not None and id(submodule) in module_id_to_inputs_removed:
+            node.args = tuple(
+                filter(
+                    lambda n: n.name not in module_id_to_inputs_removed[id(submodule)],
+                    node.args,
+                )
+            )
 
+    # Collect the inputs_to_state entries that correspond to the current scope.
+    inputs_to_state_of_scope: Dict[torch.fx.Node, List[str]] = {}
     for node in inputs:
         if node.name not in inputs_to_state:
             continue
 
-        if len(node.users) > 0:
-            state_name = None
-            for sn in inputs_to_state[node.name]:
-                sn_split = sn.split(".")
-                if sn_split[: len(scope)] == scope:
-                    state_name = sn_split
-                    break
-
-            # If there's a mismatch beteewn scope name and state name, then
-            # there must be multuple scopes pointing to the same state name,
-            # meaning some modules are shared. In such case, we can simply skip
-            # updating the current node because another later iteration will
-            # take care of this input node when the unique match between scope
-            # and state name occurs. To make sure this always happen, we should
-            # enforce the invariant that no placeholder node in the unflattened
-            # graph appears in inputs_to_state dict, which means all the extra
-            # input nodes have been handled.
-            if state_name is None:
-                continue
+        state_name = None
+        for sn in inputs_to_state[node.name]:
+            sn_split = sn.split(".")
+            if sn_split[: len(scope)] == scope:
+                state_name = sn_split
+                break
+
+        # If there's a mismatch between the scope name and the state name, then
+        # there must be multiple scopes pointing to the same state name,
+        # meaning some modules are shared. In that case, we can simply skip
+        # updating the current node, because a later iteration will
+        # take care of this input node when the unique match between scope
+        # and state name occurs. To make sure this always happens, we should
+        # enforce the invariant that no placeholder node in the unflattened
+        # graph appears in the inputs_to_state dict, which means all the extra
+        # input nodes have been handled.
+        if state_name is None:
+            continue
+
+        inputs_to_state_of_scope[node] = state_name
+
+    # Record the names of removed inputs so they can be returned to the caller.
+    inputs_removed: List[str] = []
 
+    for node, state_name in inputs_to_state_of_scope.items():
+        if len(node.users) > 0:
             attr_path = state_name[len(scope) :]
             state_attr = _recursive_getattr(module, attr_path)
             assert isinstance(state_attr, (torch.Tensor, torch.ScriptObject))
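
A minimal sketch of the scope matching and relative attribute path that feed `inputs_to_state_of_scope` above; the scope and state names are made up for illustration.

scope = ["encoder", "linear"]
inputs_to_state = {"p_weight": ["encoder.linear.weight", "decoder.linear.weight"]}

state_name = None
for sn in inputs_to_state["p_weight"]:
    sn_split = sn.split(".")
    if sn_split[: len(scope)] == scope:  # claim the FQN that belongs to this scope
        state_name = sn_split
        break

if state_name is not None:
    attr_path = state_name[len(scope) :]  # attribute path relative to the current module
    print(attr_path)                      # ['weight'] -> becomes a get_attr("weight") node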
@@ -1061,13 +1106,20 @@ def _sink_params(
                 new_node = graph.create_node("get_attr", ".".join(attr_path))
 
             node.replace_all_uses_with(new_node, propagate_meta=True)
+
         graph.erase_node(node)
+        inputs_removed.append(node.name)
+
     if isinstance(module, InterpreterModule):
         module.finalize()
 
+    return {id(module): inputs_removed}
+
 
 def _recursive_getattr(obj, attr_path):
     for attr in attr_path:
+        if not hasattr(obj, attr):
+            return None
         obj = getattr(obj, attr)
 
     return obj
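
Finally, a standalone sketch (with hypothetical classes) of the behaviour the patched `_recursive_getattr` provides: it walks an attribute path and returns None instead of raising AttributeError when a link is missing, which is what lets the call_module pruning above skip targets it cannot resolve.

def recursive_getattr(obj, attr_path):
    # Same walk as _recursive_getattr in the diff: bail out with None on a missing link.
    for attr in attr_path:
        if not hasattr(obj, attr):
            return None
        obj = getattr(obj, attr)
    return obj

class Leaf:
    weight = "tensor-ish"

class Root:
    child = Leaf()

print(recursive_getattr(Root(), ["child", "weight"]))   # 'tensor-ish'
print(recursive_getattr(Root(), ["child", "missing"]))  # None -> caller skips that submodule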