Commit cc9a196

Merge pull request #2482 from AI-Hypercomputer:mohit/new_resharding
PiperOrigin-RevId: 820471060
2 parents 8c57202 + af9d942 commit cc9a196

File tree

10 files changed: +635 -105 lines changed


end_to_end/tpu/test_grpo.sh

Lines changed: 22 additions & 14 deletions
@@ -5,10 +5,14 @@
 # External users can update pre-trained model checkpoint GCS path (gs://) to your accessible locations.
 # Usage:
 HF_TOKEN=<huggingface access token> \
-MODEL=llama3.3-70b TOKENIZER=meta-llama/Llama-3.3-70B \
-NUM_SAMPLERS=4 DEVICES_PER_SAMPLER=8 \
+MODEL=llama3.1-8b TOKENIZER=meta-llama/Llama-3.1-8B-Instruct \
+NUM_SAMPLERS=2 DEVICES_PER_SAMPLER=8 \
 TRAINING_PER_DEVICE_BATCH_SIZE=1 \
-INFERENCE_PER_DEVICE_BATCH_SIZE=4 \
+INFERENCE_PER_DEVICE_BATCH_SIZE=1 \
+TRAINING_SUBSLICE=2,8 \
+INFERENCE_SUBSLICE=2,8 \
+MAX_PREFILL_LENGTH=128 \
+MAX_TARGET_LENGTH=256 \
 STEPS=20 \
 bash end_to_end/tpu/test_grpo.sh
 '
@@ -23,10 +27,11 @@ JAX_BACKEND_TARGET=grpc://127.0.0.1:29000
 ENABLE_PATHWAYS_PERSISTENCE='1'
 HF_TOKEN=${HF_TOKEN}

-MAX_PREFILL_LENGTH=128
-MAX_TARGET_LENGTH=256
+MAX_PREFILL_LENGTH=${MAX_PREFILL_LENGTH:-128}
+MAX_TARGET_LENGTH=${MAX_TARGET_LENGTH:-256}
 NUM_GENERATIONS=2

+INFERENCE_PER_DEVICE_BS=$((${INFERENCE_PER_DEVICE_BATCH_SIZE} * ${NUM_GENERATIONS}))

 COMMON_ARGS="model_name=${MODEL} base_output_directory=${BASE_OUTPUT_DIRECTORY} \
 max_prefill_predict_length=${MAX_PREFILL_LENGTH} max_target_length=${MAX_TARGET_LENGTH} \
@@ -35,19 +40,22 @@ tokenizer_type=huggingface tokenizer_path=${TOKENIZER} \
 dataset_type=hf hf_path='trl-lib/tldr' \
 enable_single_controller=true \
 dtype=bfloat16 weight_dtype=bfloat16 \
-allow_split_physical_axes=true enable_goodput_recording=false monitor_goodput=false \
-profiler=xplane skip_first_n_steps_for_profiler=10 profiler_steps=5"
+allow_split_physical_axes=true enable_goodput_recording=false monitor_goodput=false"

 TRAINING_ARGS="run_name=${RUN_NAME} scan_layers=true \
-inference_replicas=${NUM_SAMPLERS} inference_devices_per_replica=${DEVICES_PER_SAMPLER} \
-inference_rollouts=5 \
-per_device_batch_size=${TRAINING_PER_DEVICE_BATCH_SIZE} num_generations=${NUM_GENERATIONS} steps=${STEPS}"
+inference_replicas=${NUM_SAMPLERS} inference_devices_per_replica=${DEVICES_PER_SAMPLER} subslice_shape=${TRAINING_SUBSLICE} \
+inference_rollouts=1 \
+per_device_batch_size=${TRAINING_PER_DEVICE_BATCH_SIZE} num_generations=${NUM_GENERATIONS} steps=${STEPS} \
+profiler=xplane skip_first_n_steps_for_profiler=5 profiler_steps=3"

+# Make sure profiles on inference TPUs are not captured while profiling the trainer TPUs.
+# Set a small number for profiler_steps during inference, as the profiles turn out large in size.
 INFERENCE_ARGS="run_name=grpo scan_layers=false \
-per_device_batch_size=${INFERENCE_PER_DEVICE_BATCH_SIZE} \
-ici_data_parallelism=${NUM_SAMPLERS} ici_tensor_parallelism=${DEVICES_PER_SAMPLER}"
+per_device_batch_size=${INFERENCE_PER_DEVICE_BS} num_generations=${NUM_GENERATIONS} \
+ici_data_parallelism=${NUM_SAMPLERS} ici_tensor_parallelism=${DEVICES_PER_SAMPLER} subslice_shape=${INFERENCE_SUBSLICE} \
+profiler=xplane skip_first_n_steps_for_profiler=10 profiler_steps=2"

 JAX_PLATFORMS=proxy JAX_BACKEND_TARGET=grpc://127.0.0.1:29000 ENABLE_PATHWAYS_PERSISTENCE='1' \
-python3 -m MaxText.experimental.rl.grpo_trainer "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"/experimental/rl/grpo.yml \
-${COMMON_ARGS} ${TRAINING_ARGS} ${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/experimental/rl/grpo_inference.yml \
+python3 src/MaxText/experimental/rl/grpo_trainer.py src/MaxText/experimental/rl/grpo.yml \
+${COMMON_ARGS} ${TRAINING_ARGS} src/MaxText/experimental/rl/grpo_inference.yml \
 ${COMMON_ARGS} ${INFERENCE_ARGS}
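Note on the batch-size bookkeeping introduced above: the script now multiplies the per-device inference batch by NUM_GENERATIONS before passing it to the inference config, and the trainer divides it back out when trimming prompts (see grpo_trainer.py below). A minimal sketch of the arithmetic, using the example values from the usage block; the variable names are illustrative:

```python
# Illustrative values copied from the usage block above; not part of the change itself.
NUM_SAMPLERS = 2                      # inference replicas
DEVICES_PER_SAMPLER = 8               # devices per inference replica
INFERENCE_PER_DEVICE_BATCH_SIZE = 1   # prompts per inference device
NUM_GENERATIONS = 2                   # completions sampled per prompt

# What the script passes to the inference config as per_device_batch_size:
inference_per_device_bs = INFERENCE_PER_DEVICE_BATCH_SIZE * NUM_GENERATIONS  # 2

# Prompts kept per rollout (mirrors the trimming in grpo_trainer.py):
prompts_per_rollout = (inference_per_device_bs // NUM_GENERATIONS) * NUM_SAMPLERS * DEVICES_PER_SAMPLER  # 16
print(inference_per_device_bs, prompts_per_rollout)
```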

src/MaxText/experimental/rl/grpo_trainer.py

Lines changed: 35 additions & 30 deletions
@@ -33,6 +33,8 @@
 """


+import pathwaysutils
+
 import datetime
 import time
 import os
@@ -586,8 +588,6 @@ def generate_completions(
     worker_tokenizer_model,
     worker_config_inference,
     worker_config_train,
-    worker_data_buffer,
-    worker_data_buffer_lock,
     worker_input_data_shardings,
     engine_lock,
 ):
@@ -604,8 +604,6 @@ def generate_completions(
     worker_tokenizer_model: The tokenizer model.
     worker_config_inference: The configuration for the inference process.
     worker_config_train: The main training configuration.
-    worker_data_buffer: A list acting as a shared buffer to store generated data.
-    worker_data_buffer_lock: A lock to ensure thread-safe access to the buffer.
     worker_input_data_shardings: Sharding specifications for the data.
     engine_lock: A lock to ensure thread-safe use of the inference engine.
   """
@@ -615,7 +613,7 @@ def generate_completions(
   thread_example_batch_trimmed = jax.tree_util.tree_map(
       lambda arr: arr[
           : int(
-              worker_config_inference.per_device_batch_size
+              (worker_config_inference.per_device_batch_size // worker_config_inference.num_generations)
               * worker_config_train.inference_replicas
              * worker_config_train.inference_devices_per_replica
          )
@@ -626,13 +624,7 @@ def generate_completions(
       worker_config_inference, worker_tokenizer_model, worker_inference_engine, thread_example_batch_trimmed
   )
   processed_batch = jax.device_put(processed_batch, worker_input_data_shardings)
-  with worker_data_buffer_lock:
-    if not worker_data_buffer:
-      worker_data_buffer.append(processed_batch)
-    else:
-      worker_data_buffer[0] = jax.tree_util.tree_map(
-          lambda a, b: np.concatenate([a, b], axis=0), worker_data_buffer[0], processed_batch
-      )
+  return processed_batch


 def train_loop(config, config_inference, recorder, state=None):
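The two hunks above change generate_completions so that it trims the sampled batch to the number of prompts the inference mesh can hold and returns the processed batch instead of appending it to a shared buffer. A small, hedged sketch of the trimming pattern with toy shapes; the names and values here are illustrative, not the repository's exact config:

```python
# Hedged sketch: slice every leaf of a batch pytree down to the prompt count implied by
# the inference topology. Toy shapes and config values are illustrative only.
import jax
import numpy as np

per_device_batch_size = 2      # inference per-device batch, already multiplied by num_generations
num_generations = 2
inference_replicas = 2
inference_devices_per_replica = 8

rows = (per_device_batch_size // num_generations) * inference_replicas * inference_devices_per_replica  # 16

batch = {"prompt_tokens": np.zeros((64, 128), dtype=np.int32), "prompt_length": np.full((64,), 128)}
trimmed = jax.tree_util.tree_map(lambda arr: arr[:rows], batch)
print(trimmed["prompt_tokens"].shape)  # (16, 128)
```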
@@ -705,6 +697,7 @@ def train_loop(config, config_inference, recorder, state=None):

   start_step = get_first_step(state)  # this is the start_step for training
   prof = profiler.Profiler(config, offset_step=start_step)
+  inference_prof = profiler.Profiler(config_inference, offset_step=start_step)
   data_loader = DataLoader(config_inference, inference_mesh, data_iterator, recorder)
   metric_logger = MetricLogger(config=config, learning_rate_schedule=learning_rate_schedule)

@@ -721,6 +714,7 @@ def generation_worker_fn(
       worker_input_data_shardings,
       engine_lock,
       stop_event,
+      profiler_object,
   ):
     """The target function for the data generation worker thread.

@@ -738,21 +732,40 @@ def generation_worker_fn(
       worker_input_data_shardings: Sharding specs for the generated data.
       engine_lock: A lock for thread-safe inference engine access.
       stop_event: A threading.Event to signal when the worker should stop.
+      profiler_object: The profiler used to capture profiles on the inference devices.
     """
+    worker_step = 0
+    is_profiling = False
     while not stop_event.is_set():
       try:
-        with jax.profiler.StepTraceAnnotation("inference"):
-          generate_completions(
+        if worker_step == profiler_object.start_initial_profile_step and not is_profiling:
+          profiler_object.activate()
+          is_profiling = True
+        elif worker_step == profiler_object.finished_initial_profile_step and is_profiling:
+          profiler_object.deactivate()
+          is_profiling = False
+        with jax.profiler.StepTraceAnnotation("inference", step_num=worker_step):
+          processed_batch = generate_completions(
               data_loader,
               worker_inference_engine,
               worker_tokenizer_model,
               worker_config_inference,
               worker_config_train,
-              worker_data_buffer,
-              worker_data_buffer_lock,
               worker_input_data_shardings,
               engine_lock,
           )
+        jax.block_until_ready(processed_batch)
+
+        with worker_data_buffer_lock:
+          if not worker_data_buffer:
+            worker_data_buffer.append(processed_batch)
+          else:
+            worker_data_buffer[0] = jax.tree_util.tree_map(
+                lambda a, b: np.concatenate([a, b], axis=0),
+                worker_data_buffer[0],
+                processed_batch,
+            )
+        worker_step += 1
       except StopIteration:
         max_logging.log("Data iterator exhausted in generation worker. Stopping.")
         break
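With this change the generation worker owns all buffer management: it annotates each rollout with a step counter, drives the inference-side profiler from that counter, blocks until the rollout is ready, and then folds it into the single shared buffer slot under the lock. A condensed, hedged sketch of that producer pattern; the helper names and toy batch are illustrative, not the repository's exact code:

```python
# Hedged sketch of the producer loop: a daemon thread generates batches and concatenates
# each one into a shared single-slot buffer under a lock. generate_batch() is a stand-in.
import threading
import numpy as np
import jax

data_buffer, data_buffer_lock = [], threading.Lock()
stop_event = threading.Event()

def generation_worker(generate_batch):
  step = 0
  while not stop_event.is_set():
    # A real worker would also call profiler.activate()/deactivate() at configured steps.
    with jax.profiler.StepTraceAnnotation("inference", step_num=step):
      batch = generate_batch()
    jax.block_until_ready(batch)  # make sure the rollout is finished before buffering it
    with data_buffer_lock:
      if not data_buffer:
        data_buffer.append(batch)
      else:  # grow the buffered batch along the example axis
        data_buffer[0] = jax.tree_util.tree_map(
            lambda a, b: np.concatenate([a, b], axis=0), data_buffer[0], batch)
    step += 1

thread = threading.Thread(target=generation_worker, args=(lambda: {"x": np.ones((4, 8))},), daemon=True)
thread.start()  # the trainer thread later sets stop_event and joins
```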
@@ -764,19 +777,6 @@ def generation_worker_fn(
   stop_event = threading.Event()
   inference_engine_lock = threading.Lock()

-  max_logging.log("Inference Rollout")
-  generate_completions(
-      data_loader,
-      inference_engine,
-      tokenizer_model,
-      config_inference,
-      config,
-      data_buffer,
-      data_buffer_lock,
-      data_sharding,
-      inference_engine_lock,
-  )
-
   required_batch_size = int(config.per_device_batch_size * config.num_generations * mesh.size)
   generation_thread = threading.Thread(
       target=generation_worker_fn,
@@ -790,6 +790,7 @@ def generation_worker_fn(
           data_sharding,  # Sharding for the data put into the buffer
           inference_engine_lock,
           stop_event,
+          inference_prof,  # profiler object
       ),
       daemon=True,  # So it exits when the main thread exits
   )
@@ -830,8 +831,10 @@ def generation_worker_fn(
           {"params": state.params["params"]},
           {"params": state_mesh_shardings.params["params"]},
           mesh,
-          inference_state_mesh_shardings,
+          {"params": inference_state_mesh_shardings.params["params"]},
       )
+      with data_buffer_lock:
+        data_buffer.clear()

       step_time_delta = datetime.datetime.now() - last_step_completion
       last_step_completion = datetime.datetime.now()
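On the trainer side, the resharded inference parameters are now passed as a {"params": ...} pytree, and the shared buffer is cleared under its lock right after the updated weights are pushed to the samplers, presumably so that later steps only consume rollouts from the new policy. A hedged sketch of the consumer side this implies; the helper names and polling loop are assumptions for illustration, not the repository's code:

```python
# Hedged sketch of the consumer side: the trainer waits until the buffer holds enough
# rollout examples, removes what it needs, and clears whatever is left after resharding.
import time
import jax

def take_training_batch(data_buffer, data_buffer_lock, required_batch_size, poll_seconds=0.1):
  """Blocks until enough buffered rollout rows exist, then removes and returns them."""
  while True:
    with data_buffer_lock:
      if data_buffer and jax.tree_util.tree_leaves(data_buffer[0])[0].shape[0] >= required_batch_size:
        batch = jax.tree_util.tree_map(lambda a: a[:required_batch_size], data_buffer[0])
        data_buffer[0] = jax.tree_util.tree_map(lambda a: a[required_batch_size:], data_buffer[0])
        return batch
    time.sleep(poll_seconds)

def on_weights_resharded(data_buffer, data_buffer_lock):
  # Mirrors the added `data_buffer.clear()`: presumably drops rollouts sampled from the
  # previous policy weights so the next training batch comes from fresh ones.
  with data_buffer_lock:
    data_buffer.clear()
```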
@@ -895,6 +898,7 @@ def main(argv: Sequence[str]) -> None:
   training and inference, sets up system environment variables, and launches
   the `train_loop`.
   """
+  pathwaysutils.initialize()
   jax.config.update("jax_default_prng_impl", "unsafe_rbg")
   # TF allocates extraneous GPU memory when using TFDS data
   # this leads to CUDA OOMs. WAR for now is to hide GPUs from TF
@@ -923,6 +927,7 @@ def main(argv: Sequence[str]) -> None:
       f"with {jax.device_count()} devices"
   )
   config_inference = pyconfig.initialize(configs_argv[1])
+
   if config.per_device_batch_size < 1.0 or config_inference.per_device_batch_size < 1.0:
     raise ValueError("GRPO does not support setting per_device_batch_size < 1.0")
   jax.config.update("jax_use_shardy_partitioner", config.shardy)