@@ -217,18 +217,21 @@ def _check_order(self, handle: FlatParamHandle, is_training: bool) -> None:
217217 # TODO (awgu): Since every module has at most one handle in the
218218 # current implementation, this should never raise the error.
219219 assert self .world_size is not None # mypy
220- for (r1 , n1 ), (r2 , n2 ) in itertools .combinations (
221- (
222- (rank , world_num_valid_indices [rank ])
223- for rank in range (self .world_size )
224- ),
225- 2 ,
226- ):
227- if n1 != n2 :
228- raise RuntimeError (
229- f"{ msg_prefix } rank { r1 } is all-gathering { n1 } parameters "
230- f"while rank { r2 } is all-gathering { n2 } parameters"
231- )
220+ if not torch .distributed ._functional_collectives .is_torchdynamo_compiling ():
221+ # TODO(voz): Don't graph break on this - dynamo hates the n1 != n2
222+ # tensor comparison control flow.
223+ for (r1 , n1 ), (r2 , n2 ) in itertools .combinations (
224+ (
225+ (rank , world_num_valid_indices [rank ])
226+ for rank in range (self .world_size )
227+ ),
228+ 2 ,
229+ ):
230+ if n1 != n2 :
231+ raise RuntimeError (
232+ f"{ msg_prefix } rank { r1 } is all-gathering { n1 } parameters "
233+ f"while rank { r2 } is all-gathering { n2 } parameters"
234+ )
232235 world_indices = torch .zeros ( # type: ignore[call-overload]
233236 self .world_size * num_valid_indices , ** tensor_kwargs
234237 )
@@ -239,26 +242,31 @@ def _check_order(self, handle: FlatParamHandle, is_training: bool) -> None:
239242 # Copy entire tensor from D2H once to avoid per element D2H copies
240243 world_indices = world_indices .cpu ()
241244 # Check that all ranks plan to all-gather the same index parameters
242- for (r1 , i1 ), (r2 , i2 ) in itertools .combinations (
243- (
245+ if not torch .distributed ._functional_collectives .is_torchdynamo_compiling ():
246+ # TODO(voz): Don't graph break on this - dynamo hates the i1 != i2
247+ # tensor comparison control flow.
248+ for (r1 , i1 ), (r2 , i2 ) in itertools .combinations (
244249 (
245- rank ,
246- world_indices [
247- rank * num_valid_indices : (rank + 1 ) * num_valid_indices
248- ],
249- )
250- for rank in range (self .world_size )
251- ),
252- 2 ,
253- ):
254- if i1 != i2 :
255- r1_param_names = self ._get_names_from_handle_indices (i1 )
256- r2_param_names = self ._get_names_from_handle_indices (i2 )
257- raise RuntimeError (
258- f"{ msg_prefix } rank { r1 } is all-gathering parameters "
259- f"for { r1_param_names } while rank { r2 } is all-gathering "
260- f"parameters for { r2_param_names } "
261- )
250+ (
251+ rank ,
252+ world_indices [
253+ rank
254+ * num_valid_indices : (rank + 1 )
255+ * num_valid_indices
256+ ],
257+ )
258+ for rank in range (self .world_size )
259+ ),
260+ 2 ,
261+ ):
262+ if i1 != i2 :
263+ r1_param_names = self ._get_names_from_handle_indices (i1 )
264+ r2_param_names = self ._get_names_from_handle_indices (i2 )
265+ raise RuntimeError (
266+ f"{ msg_prefix } rank { r1 } is all-gathering parameters "
267+ f"for { r1_param_names } while rank { r2 } is all-gathering "
268+ f"parameters for { r2_param_names } "
269+ )
262270 elif self ._checking_order :
263271 # Only issue warnings on the first deviating iteration and stop
264272 # checking thereafter to avoid flooding the console
0 commit comments