@@ -96,6 +96,7 @@ def __init__(
         master_weights: bool = True,
         extra_dp_group: Optional[ProcessGroup] = None,
         verbose: bool = False,
+        enable_async_reduce: bool = True,
     ) -> None:
         assert mixed_precision in (torch.float16, torch.bfloat16)
         reuse_fp16_chunk = master_weights if not enable_gradient_accumulation else False
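The new `enable_async_reduce` flag defaults to on. A hypothetical usage sketch follows; the import path and the other constructor arguments are assumptions (not taken from this diff), and a distributed environment must already be initialized:

```python
import torch
import torch.nn as nn
from colossalai.zero import GeminiDDP  # assumed import path

# Sketch only: everything here besides `enable_async_reduce` is an assumption,
# and colossalai.launch must have set up the process group beforehand.
module = nn.Linear(1024, 1024).cuda()
model = GeminiDDP(
    module,
    mixed_precision=torch.bfloat16,
    master_weights=True,
    enable_async_reduce=True,  # default; False restores synchronous reduction
)
```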
@@ -178,6 +179,7 @@ def __init__(
             if is_ddp_ignored(p):
                 continue
             if p.requires_grad:
+                assert not hasattr(p, "_grad_handle")
                 p._grad_handle = p.register_hook(
                     partial(
                         GeminiDDP.grad_handle,
@@ -187,6 +189,7 @@ def __init__(
                         master_weights=self.master_weights,
                         enable_gradient_accumulation=self.enable_gradient_accumulation,
                         p=p,
+                        async_reduce=enable_async_reduce,
                     )
                 )
 
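The hook is installed once per trainable parameter (hence the new double-registration assert), with everything except the incoming gradient pre-bound via `partial`. A self-contained toy of the same `register_hook` + `partial` pattern, with hypothetical names rather than the real `grad_handle`:

```python
import torch
from functools import partial

def toy_grad_handle(grad, p):
    # Like GeminiDDP.grad_handle: consume the incoming gradient and return a
    # placeholder; the real hook first copies `grad` into the chunk buffer.
    print(f"hook fired for param of shape {tuple(p.shape)}")
    return torch.empty_like(grad)

p = torch.nn.Parameter(torch.randn(4))
assert not hasattr(p, "_grad_handle")  # same guard as the diff above
p._grad_handle = p.register_hook(partial(toy_grad_handle, p=p))
p.sum().backward()   # fires the hook as p's gradient is produced
p._grad_handle.remove()
```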
@@ -334,6 +337,11 @@ def _pre_backward(self):
             setattr(param, "_gemini_reduced", False)
 
     def _post_backward(self):
+        for param in self.param2name:
+            if hasattr(param, "_release_grad_chunk_cb"):
+                param._release_grad_chunk_cb()
+                delattr(param, "_release_grad_chunk_cb")
+
         if self.chunk_manager.accessed_mem != 0:
             error_params = ["Reduction failed at followed parameters:"]
             for param in self.param2name:
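The new preamble drains the release callbacks parked by async reduction (see `grad_handle` below) before the accessed-memory check runs. The install/flush protocol in isolation, using stand-in names:

```python
# Hypothetical sketch: a backward hook parks a zero-argument callback on each
# parameter; the post-backward pass runs and removes every pending one once.
class FakeParam:  # stand-in for nn.Parameter
    pass

def install_cb(p, chunk_id):
    assert not hasattr(p, "_release_grad_chunk_cb")
    p._release_grad_chunk_cb = lambda: print(f"wait + release chunk {chunk_id}")

def post_backward(params):
    for p in params:
        if hasattr(p, "_release_grad_chunk_cb"):
            p._release_grad_chunk_cb()
            delattr(p, "_release_grad_chunk_cb")

params = [FakeParam(), FakeParam()]
for i, p in enumerate(params):
    install_cb(p, i)
post_backward(params)  # each callback runs exactly once
post_backward(params)  # safe no-op afterwards
```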
@@ -371,6 +379,7 @@ def grad_handle(
         master_weights: bool,
         enable_gradient_accumulation: bool,
         p: nn.Parameter,
+        async_reduce: bool,
     ):
         setattr(p, "_gemini_reduced", True)
         empty_grad = torch.empty_like(grad)
@@ -406,31 +415,57 @@ def grad_handle(
             grad_chunk.copy_tensor_to_chunk_slice(p, grad, update_ptr=chunk_manager.reuse_fp16_chunk)
         else:
             grad_chunk.add_tensor_to_chunk_slice(p, grad)
-        reduced = chunk_manager.reduce_chunk(grad_chunk)
-        if reduced:
-            if not chunk_manager.reuse_fp16_chunk:
-                if chunk.keep_gathered:
-                    chunk_manager.fake_release_chunk(chunk)
-                else:
-                    chunk_manager.release_chunk(chunk)
-            if grad_chunk.is_gathered:
-                grad_chunk.cuda_global_chunk.div_(chunk.pg_size)
-                if chunk.extra_dp_group is not None:
-                    grad_chunk.cuda_global_chunk.div_(chunk.extra_dp_size)
+        reduced = chunk_manager.reduce_chunk(grad_chunk, async_op=async_reduce)
+        if reduced:  # if not async, release immediately; otherwise release once the reduce work finishes
+            if async_reduce:
+                # dirty fix: park a release callback on the parameter; _post_backward flushes it
+                assert not hasattr(p, "_release_grad_chunk_cb")
+
+                def _release_grad_chunk_cb():
+                    grad_chunk.wait_async_reduce()
+                    GeminiDDP.release_grad_chunk_handle(
+                        chunk_manager,
+                        grads_device,
+                        master_weights,
+                        enable_gradient_accumulation,
+                        p,
+                        chunk,
+                        grad_chunk,
+                    )
+
+                p._release_grad_chunk_cb = _release_grad_chunk_cb
             else:
-                grad_chunk.cuda_shard.div_(chunk.pg_size)
-                if chunk.extra_dp_group is not None:
-                    grad_chunk.cuda_shard.div_(chunk.extra_dp_size)
-            # check overflow elements
-            chunk_manager.overflow_counter += grad_chunk.has_inf_or_nan
-            # record l2 norm for gradient clipping. flag is bound to fp16 chunk
-            if chunk.l2_norm_flag:
-                grad_chunk.set_l2_norm()
-            chunk_manager.move_chunk(grad_chunk, grads_device[p], force_copy=True)
-            if not (master_weights) or (enable_gradient_accumulation):
-                chunk_manager.move_chunk(chunk, grads_device[p], force_copy=True)
+                GeminiDDP.release_grad_chunk_handle(
+                    chunk_manager, grads_device, master_weights, enable_gradient_accumulation, p, chunk, grad_chunk
+                )
         return empty_grad
 
+    @staticmethod
+    def release_grad_chunk_handle(
+        chunk_manager, grads_device, master_weights, enable_gradient_accumulation, p, chunk, grad_chunk
+    ):
+        if not chunk_manager.reuse_fp16_chunk:
+            if chunk.keep_gathered:
+                chunk_manager.fake_release_chunk(chunk)
+            else:
+                chunk_manager.release_chunk(chunk)
+        if grad_chunk.is_gathered:
+            grad_chunk.cuda_global_chunk.div_(chunk.pg_size)
+            if chunk.extra_dp_group is not None:
+                grad_chunk.cuda_global_chunk.div_(chunk.extra_dp_size)
+        else:
+            grad_chunk.cuda_shard.div_(chunk.pg_size)
+            if chunk.extra_dp_group is not None:
+                grad_chunk.cuda_shard.div_(chunk.extra_dp_size)
+        # check overflow elements
+        chunk_manager.overflow_counter += grad_chunk.has_inf_or_nan
+        # record l2 norm for gradient clipping. flag is bound to fp16 chunk
+        if chunk.l2_norm_flag:
+            grad_chunk.set_l2_norm()
+        chunk_manager.move_chunk(grad_chunk, grads_device[p], force_copy=True)
+        if not (master_weights) or (enable_gradient_accumulation):
+            chunk_manager.move_chunk(chunk, grads_device[p], force_copy=True)
+
     def zero_grad(self, set_to_none: bool = False) -> None:
         self.module.zero_grad(set_to_none=True)