@@ -125,7 +125,6 @@ def foreach_reduce(
     orig_dtype: torch.dtype,
     reduce_dtype: Optional[torch.dtype],
     device: torch.device,
-    divide_factors: Union[Tuple[None, None], Tuple[float, float]],
     all_reduce_group: Optional[dist.ProcessGroup],
     all_reduce_stream: torch.cuda.Stream,
 ) -> torch.cuda.Event:
@@ -142,7 +141,9 @@ def foreach_reduce(
     )
     grad_dtype = unsharded_grads[0].dtype
     reduce_dtype = reduce_dtype or grad_dtype
-    predivide_factor, postdivide_factor = divide_factors
+    predivide_factor, postdivide_factor = _get_gradient_divide_factors(
+        reduce_scatter_group, all_reduce_group, reduce_dtype
+    )
     world_size = reduce_scatter_group.size()
     padded_unsharded_sizes = tuple(
         _get_dim0_padded_size(grad.size(), world_size) for grad in unsharded_grads
@@ -166,18 +167,22 @@ def foreach_reduce(
             (reduce_scatter_output_numel,)
         )
         _div_if_needed(reduce_scatter_input, predivide_factor)
-        _reduce_scatter(
-            post_reduce_output,
-            reduce_scatter_input,
-            reduce_scatter_group,
-            divide_factors,
+        dist.reduce_scatter_tensor(
+            output=post_reduce_output,
+            input=reduce_scatter_input,
+            group=reduce_scatter_group,
+            op=ReduceOp.AVG if predivide_factor is None else ReduceOp.SUM,
         )
         view_out_stream = reduce_scatter_stream
         if all_reduce_group is not None:
             view_out_stream = all_reduce_stream
             all_reduce_stream.wait_stream(reduce_scatter_stream)
             with torch.cuda.stream(all_reduce_stream):
-                _all_reduce(post_reduce_output, all_reduce_group, divide_factors)
+                dist.all_reduce(
+                    post_reduce_output,
+                    group=all_reduce_group,
+                    op=ReduceOp.AVG if predivide_factor is None else ReduceOp.SUM,
+                )
         with torch.cuda.stream(view_out_stream):
             _div_if_needed(post_reduce_output, postdivide_factor)
             post_reduce_output = _to_dtype_if_needed(post_reduce_output, orig_dtype)
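
Note on the op selection above: when predivide_factor is None (fp32/bf16 reduce dtype), the division by the data-parallel size is folded into the NCCL collective via ReduceOp.AVG; for fp16, the collective uses ReduceOp.SUM and the scaling is split across the two _div_if_needed calls. A minimal single-process sketch of that fp16 arithmetic (the n/pre/post names and values are made up for illustration, not part of the commit):

import torch

n = 8                 # total data-parallel size
pre, post = 4.0, 2.0  # factors the helper below returns for n=8 in the fp16 path
grads = [torch.full((4,), float(i + 1)) for i in range(n)]  # fake per-rank grads

# fp32/bf16 path: NCCL averages inside the collective (ReduceOp.AVG)
mean = torch.stack(grads).mean(dim=0)

# fp16 path: divide by `pre`, SUM-reduce, divide by `post`
summed = torch.stack([g / pre for g in grads]).sum(dim=0) / post

assert torch.allclose(mean, summed)  # same mean, smaller intermediate magnitudes
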
@@ -257,30 +262,27 @@ def _get_all_gather_input_metadatas(
     )
 
 
-def _reduce_scatter(
-    output: torch.Tensor,
-    input: torch.Tensor,
-    group: dist.ProcessGroup,
-    divide_factors: Union[Tuple[None, None], Tuple[float, float]],
-) -> None:
-    if divide_factors[0]:
-        dist.reduce_scatter_tensor(output, input, group=group)
-    else:
-        # Using NCCL's reduce-scatter to do the division by world size saves
-        # extra memory read/write from a separate division kernel
-        dist.reduce_scatter_tensor(output, input, op=ReduceOp.AVG, group=group)
-
-
-def _all_reduce(
-    tensor: torch.Tensor,
-    group: dist.ProcessGroup,
-    divide_factors: Union[Tuple[None, None], Tuple[float, float]],
-) -> None:
-    if divide_factors[0]:
-        dist.all_reduce(tensor, group=group)
-    else:
-        # saves extra memory read/write from a separate division kernel
-        dist.all_reduce(tensor, op=ReduceOp.AVG, group=group)
+def _get_gradient_divide_factors(
+    reduce_scatter_group: dist.ProcessGroup,
+    all_reduce_group: Optional[dist.ProcessGroup],
+    reduce_dtype: torch.dtype,
+) -> Union[Tuple[None, None], Tuple[float, float]]:
+    # For fp32/bf16, we do not need to worry about overflow/underflow, so we
+    # use NCCL's built-in division to avoid separate div kernels
+    if reduce_dtype in (torch.float32, torch.bfloat16):
+        return None, None
+    data_parallel_size = reduce_scatter_group.size()
+    if all_reduce_group is not None:
+        data_parallel_size *= all_reduce_group.size()
+    # Since fp16 has smaller dynamic range than fp32/bf16, we want to avoid
+    # overflow/underflow. For N data parallel workers, each worker computes
+    # g_i, and they collectively reduce (g_1 + ... + g_N) / N. To avoid
+    # overflow/underflow, we divide by ~sqrt(N) before/after the reduction.
+    factor: int = 1
+    while data_parallel_size % factor == 0 and data_parallel_size / factor > factor:
+        factor *= 2
+    factor = float(factor)
+    return (factor, data_parallel_size / factor)
 
 
 def _div_if_needed(tensor: torch.Tensor, div_factor: Optional[float]) -> None:
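
For reference, a standalone sketch of the pre/post divide factors the new helper computes for fp16 (the factors function here is illustrative only, mirroring the loop above):

def factors(n: int):
    factor = 1
    while n % factor == 0 and n / factor > factor:
        factor *= 2
    return float(factor), n / factor

# factors(2)  -> (2.0, 1.0)
# factors(8)  -> (4.0, 2.0)
# factors(16) -> (4.0, 4.0)
# factors(64) -> (8.0, 8.0)
# In each case predivide * postdivide == n, and both stay close to sqrt(n).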