 """
 
 
+import pathwaysutils
+
 import datetime
 import time
 import os
@@ -586,8 +588,6 @@ def generate_completions(
     worker_tokenizer_model,
     worker_config_inference,
     worker_config_train,
-    worker_data_buffer,
-    worker_data_buffer_lock,
     worker_input_data_shardings,
     engine_lock,
 ):
@@ -604,8 +604,6 @@ def generate_completions(
     worker_tokenizer_model: The tokenizer model.
     worker_config_inference: The configuration for the inference process.
     worker_config_train: The main training configuration.
-    worker_data_buffer: A list acting as a shared buffer to store generated data.
-    worker_data_buffer_lock: A lock to ensure thread-safe access to the buffer.
     worker_input_data_shardings: Sharding specifications for the data.
     engine_lock: A lock to ensure thread-safe use of the inference engine.
   """
@@ -615,7 +613,7 @@ def generate_completions(
   thread_example_batch_trimmed = jax.tree_util.tree_map(
       lambda arr: arr[
           : int(
-              worker_config_inference.per_device_batch_size
+              (worker_config_inference.per_device_batch_size // worker_config_inference.num_generations)
               * worker_config_train.inference_replicas
               * worker_config_train.inference_devices_per_replica
           )
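This hunk divides the inference per-device batch size by num_generations before scaling by the inference topology, presumably so the number of prompts sliced from each example batch matches what the engine later expands into num_generations completions apiece. A minimal sketch of the arithmetic, with illustrative values (the numbers and standalone variable names below are assumptions, not taken from this commit):

    # Illustrative values only; the real ones come from config_inference and the train config.
    per_device_batch_size = 8          # config_inference.per_device_batch_size
    num_generations = 4                # config_inference.num_generations
    inference_replicas = 2             # config_train.inference_replicas
    inference_devices_per_replica = 4  # config_train.inference_devices_per_replica

    # Prompts sliced from the example batch per rollout, after this change.
    num_prompts = int((per_device_batch_size // num_generations) * inference_replicas * inference_devices_per_replica)
    print(num_prompts)  # 16, versus int(8 * 2 * 4) == 64 before the change
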
@@ -626,13 +624,7 @@ def generate_completions(
       worker_config_inference, worker_tokenizer_model, worker_inference_engine, thread_example_batch_trimmed
   )
   processed_batch = jax.device_put(processed_batch, worker_input_data_shardings)
-  with worker_data_buffer_lock:
-    if not worker_data_buffer:
-      worker_data_buffer.append(processed_batch)
-    else:
-      worker_data_buffer[0] = jax.tree_util.tree_map(
-          lambda a, b: np.concatenate([a, b], axis=0), worker_data_buffer[0], processed_batch
-      )
+  return processed_batch
 
 
 def train_loop(config, config_inference, recorder, state=None):
@@ -705,6 +697,7 @@ def train_loop(config, config_inference, recorder, state=None):
 
   start_step = get_first_step(state)  # this is the start_step for training
   prof = profiler.Profiler(config, offset_step=start_step)
+  inference_prof = profiler.Profiler(config_inference, offset_step=start_step)
   data_loader = DataLoader(config_inference, inference_mesh, data_iterator, recorder)
   metric_logger = MetricLogger(config=config, learning_rate_schedule=learning_rate_schedule)
 
@@ -721,6 +714,7 @@ def generation_worker_fn(
       worker_input_data_shardings,
       engine_lock,
       stop_event,
+      profiler_object,
   ):
     """The target function for the data generation worker thread.
 
@@ -738,21 +732,40 @@ def generation_worker_fn(
       worker_input_data_shardings: Sharding specs for the generated data.
       engine_lock: A lock for thread-safe inference engine access.
       stop_event: A threading.Event to signal when the worker should stop.
+      profiler_object: The profiler object used to start and stop inference profiling.
     """
+    worker_step = 0
+    is_profiling = False
     while not stop_event.is_set():
       try:
-        with jax.profiler.StepTraceAnnotation("inference"):
-          generate_completions(
+        if worker_step == profiler_object.start_initial_profile_step and not is_profiling:
+          profiler_object.activate()
+          is_profiling = True
+        elif worker_step == profiler_object.finished_initial_profile_step and is_profiling:
+          profiler_object.deactivate()
+          is_profiling = False
+        with jax.profiler.StepTraceAnnotation("inference", step_num=worker_step):
+          processed_batch = generate_completions(
               data_loader,
               worker_inference_engine,
               worker_tokenizer_model,
               worker_config_inference,
               worker_config_train,
-              worker_data_buffer,
-              worker_data_buffer_lock,
               worker_input_data_shardings,
               engine_lock,
           )
+        jax.block_until_ready(processed_batch)
+
+        with worker_data_buffer_lock:
+          if not worker_data_buffer:
+            worker_data_buffer.append(processed_batch)
+          else:
+            worker_data_buffer[0] = jax.tree_util.tree_map(
+                lambda a, b: np.concatenate([a, b], axis=0),
+                worker_data_buffer[0],
+                processed_batch,
+            )
+        worker_step += 1
       except StopIteration:
         max_logging.log("Data iterator exhausted in generation worker. Stopping.")
         break
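The worker thread now owns the buffer update itself: it blocks until the generated batch is ready, then appends it or concatenates it into the single-element shared list under the lock. Below is a self-contained toy of that accumulation pattern, using small numpy pytrees in place of real rollout batches (the key names and shapes are illustrative only, not from this commit):

    import threading
    import numpy as np
    import jax

    buffer, lock = [], threading.Lock()

    def accumulate(batch):
      """Append the first batch, then concatenate later batches along the leading axis."""
      with lock:
        if not buffer:
          buffer.append(batch)
        else:
          buffer[0] = jax.tree_util.tree_map(lambda a, b: np.concatenate([a, b], axis=0), buffer[0], batch)

    accumulate({"tokens": np.zeros((4, 8), np.int32), "rewards": np.zeros((4,))})
    accumulate({"tokens": np.ones((4, 8), np.int32), "rewards": np.ones((4,))})
    assert buffer[0]["tokens"].shape == (8, 8)
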
@@ -764,19 +777,6 @@ def generation_worker_fn(
   stop_event = threading.Event()
   inference_engine_lock = threading.Lock()
 
-  max_logging.log("Inference Rollout")
-  generate_completions(
-      data_loader,
-      inference_engine,
-      tokenizer_model,
-      config_inference,
-      config,
-      data_buffer,
-      data_buffer_lock,
-      data_sharding,
-      inference_engine_lock,
-  )
-
   required_batch_size = int(config.per_device_batch_size * config.num_generations * mesh.size)
   generation_thread = threading.Thread(
       target=generation_worker_fn,
@@ -790,6 +790,7 @@ def generation_worker_fn(
           data_sharding,  # Sharding for the data put into the buffer
           inference_engine_lock,
           stop_event,
+          inference_prof,  # profiler object
       ),
       daemon=True,  # So it exits when the main thread exits
   )
@@ -830,8 +831,10 @@ def generation_worker_fn(
         {"params": state.params["params"]},
         {"params": state_mesh_shardings.params["params"]},
         mesh,
-        inference_state_mesh_shardings,
+        {"params": inference_state_mesh_shardings.params["params"]},
     )
+    with data_buffer_lock:
+      data_buffer.clear()
 
     step_time_delta = datetime.datetime.now() - last_step_completion
     last_step_completion = datetime.datetime.now()
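Clearing the shared buffer immediately after the parameter transfer drops any rollouts generated with the previous policy weights, so later training steps consume only completions sampled from the freshly synced engine. The consuming side is not part of this diff; the helper below is a hypothetical sketch (function name, buffer layout, and drain policy are all assumptions) of how a training step might take required_batch_size rows under the same lock:

    import jax

    def try_take_batch(buffer, lock, required_batch_size):
      """Return `required_batch_size` rows from the shared buffer, or None if not enough yet."""
      with lock:
        if not buffer:
          return None
        if jax.tree_util.tree_leaves(buffer[0])[0].shape[0] < required_batch_size:
          return None
        batch = jax.tree_util.tree_map(lambda a: a[:required_batch_size], buffer[0])
        buffer[0] = jax.tree_util.tree_map(lambda a: a[required_batch_size:], buffer[0])
        return batch
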
@@ -895,6 +898,7 @@ def main(argv: Sequence[str]) -> None:
   training and inference, sets up system environment variables, and launches
   the `train_loop`.
   """
+  pathwaysutils.initialize()
   jax.config.update("jax_default_prng_impl", "unsafe_rbg")
   # TF allocates extraneous GPU memory when using TFDS data
   # this leads to CUDA OOMs. WAR for now is to hide GPUs from TF
@@ -923,6 +927,7 @@ def main(argv: Sequence[str]) -> None:
       f"with {jax.device_count()} devices"
   )
   config_inference = pyconfig.initialize(configs_argv[1])
+
   if config.per_device_batch_size < 1.0 or config_inference.per_device_batch_size < 1.0:
     raise ValueError("GRPO does not support setting per_device_batch_size < 1.0")
   jax.config.update("jax_use_shardy_partitioner", config.shardy)