Skip to content

Commit a68ba0d

Browse files
Updated code
1 parent 4c185b4 commit a68ba0d

File tree

2 files changed

+89
-58
lines changed

2 files changed

+89
-58
lines changed

open_instruct/grpo_fast.py

Lines changed: 16 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,6 @@
119119
RayProcess,
120120
_z3_params_to_fetch,
121121
calibrate_checkpoint_state_dir,
122-
check_calculation,
123122
clean_last_n_checkpoints_deepspeed,
124123
download_latest_checkpoint_from_gs,
125124
get_beaker_whoami,
@@ -1512,68 +1511,27 @@ def calculate_utilization_metrics(
15121511
f"Expected {len(prompt_lengths) * samples_per_prompt} response lengths, got {len(response_lengths)}"
15131512
)
15141513

1515-
# Calculate FLOPs and memory bytes for inference
1516-
actor_total_flops = model_dims.flops(prompt_lengths, response_lengths, samples_per_prompt=samples_per_prompt)
1517-
actor_total_memory_bytes = model_dims.memory_bytes(
1518-
prompt_lengths, num_engines, response_lengths=response_lengths, samples_per_prompt=samples_per_prompt
1519-
)
1520-
1521-
# Calculate MFU and MBU accounting for multiple GPUs
1522-
flops_per_second = actor_total_flops / total_generation_time
1523-
bytes_per_second = actor_total_memory_bytes / total_generation_time
1524-
# Scale device capabilities by number of GPUs
1525-
total_device_flops = model_dims.device_flops * num_inference_gpus
1526-
total_device_bandwidth = model_dims.device_memory_bandwidth * num_inference_gpus
1527-
actor_mfu = 100 * flops_per_second / total_device_flops
1528-
actor_mbu = 100 * bytes_per_second / total_device_bandwidth
1529-
1530-
check_calculation(
1531-
actor_mfu,
1532-
"Actor MFU",
1533-
model_dims,
1534-
total_generation_time,
1535-
prompt_lengths,
1536-
response_lengths,
1537-
samples_per_prompt,
1538-
num_inference_gpus,
1539-
)
1540-
1541-
check_calculation(
1542-
actor_mbu,
1543-
"Actor MBU",
1544-
model_dims,
1545-
total_generation_time,
1546-
prompt_lengths,
1547-
response_lengths,
1548-
samples_per_prompt,
1549-
num_inference_gpus,
1514+
actor_metrics = model_dims.calculate_actor_utilization(
1515+
prompt_lengths=prompt_lengths,
1516+
response_lengths=response_lengths,
1517+
total_generation_time=total_generation_time,
1518+
samples_per_prompt=samples_per_prompt,
1519+
num_inference_gpus=num_inference_gpus,
1520+
num_engines=num_engines,
15501521
)
15511522

1552-
# Calculate learner/training metrics
1553-
# For training, we need to use total sequence lengths (prompt + response) since training
1554-
# processes the full sequences, not separate prefill/decode operations
1555-
total_sequence_lengths = [
1556-
prompt_lengths[i // samples_per_prompt] + response_len for i, response_len in enumerate(response_lengths)
1557-
]
1558-
1559-
# For training FLOPs, pass total sequence lengths as prompt_lengths with response_lengths=None
1560-
training_flops = model_dims.flops(
1561-
prompt_lengths=total_sequence_lengths,
1562-
response_lengths=None,
1563-
samples_per_prompt=1, # Already expanded in total_sequence_lengths
1564-
is_training=True,
1523+
learner_metrics = model_dims.calculate_learner_utilization(
1524+
prompt_lengths=prompt_lengths,
1525+
response_lengths=response_lengths,
1526+
training_time=training_time,
1527+
samples_per_prompt=samples_per_prompt,
1528+
num_training_gpus=num_training_gpus,
15651529
)
15661530

1567-
# Calculate training MFU
1568-
training_flops_per_second = training_flops / training_time
1569-
total_training_device_flops = model_dims.device_flops * num_training_gpus
1570-
learner_mfu = 100 * training_flops_per_second / total_training_device_flops
1571-
1572-
check_calculation(
1573-
learner_mfu, "Learner MFU", model_dims, training_time, total_sequence_lengths, None, 1, num_training_gpus
1574-
)
1531+
utilization_metrics = {f"actor_{k}": v for k, v in actor_metrics.items()}
1532+
utilization_metrics["learner_mfu"] = learner_metrics["mfu"]
15751533

1576-
return {"actor_mfu": actor_mfu, "actor_mbu": actor_mbu, "learner_mfu": learner_mfu}
1534+
return utilization_metrics
15771535

15781536

15791537
def accumulate_inference_batches(

open_instruct/utils.py

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1945,6 +1945,79 @@ def memory_bytes(
19451945

19461946
return int(total_bytes)
19471947

1948+
def calculate_actor_utilization(
    self,
    prompt_lengths: list[int],
    response_lengths: list[int],
    total_generation_time: float,
    samples_per_prompt: int,
    num_inference_gpus: int,
    num_engines: int,
) -> dict[str, float]:
    """Compute inference-side utilization percentages for the actor.

    Derives achieved FLOP and memory-byte rates from the measured
    generation window, compares them against the aggregate peak
    capability of all inference GPUs, and sanity-checks both results
    via ``check_calculation``.

    Args:
        prompt_lengths: Token length of each unique prompt.
        response_lengths: Token length of every generated response
            (``len(prompt_lengths) * samples_per_prompt`` entries).
        total_generation_time: Wall-clock seconds spent generating.
        samples_per_prompt: Responses sampled per prompt.
        num_inference_gpus: GPU count serving inference.
        num_engines: Number of inference engines (passed through to
            the memory-bytes model).

    Returns:
        Dict with ``"mfu"`` and ``"mbu"`` as percentages.
    """
    flops_done = self.flops(prompt_lengths, response_lengths, samples_per_prompt=samples_per_prompt)
    bytes_moved = self.memory_bytes(
        prompt_lengths, num_engines, response_lengths=response_lengths, samples_per_prompt=samples_per_prompt
    )

    # Achieved rates over the measured generation window.
    achieved_flops_rate = flops_done / total_generation_time
    achieved_bytes_rate = bytes_moved / total_generation_time

    # Peak capability scales linearly with the number of inference GPUs.
    peak_flops = self.device_flops * num_inference_gpus
    peak_bandwidth = self.device_memory_bandwidth * num_inference_gpus

    actor_mfu = 100 * achieved_flops_rate / peak_flops
    actor_mbu = 100 * achieved_bytes_rate / peak_bandwidth

    # Validate both figures with the shared sanity-check helper,
    # in the same order as before: MFU first, then MBU.
    for metric_value, metric_label in ((actor_mfu, "Actor MFU"), (actor_mbu, "Actor MBU")):
        check_calculation(
            metric_value,
            metric_label,
            self,
            total_generation_time,
            prompt_lengths,
            response_lengths,
            samples_per_prompt,
            num_inference_gpus,
        )

    return {"mfu": actor_mfu, "mbu": actor_mbu}
1994+
1995+
def calculate_learner_utilization(
    self,
    prompt_lengths: list[int],
    response_lengths: list[int],
    training_time: float,
    samples_per_prompt: int,
    num_training_gpus: int,
) -> dict[str, float]:
    """Compute the training-side MFU percentage for the learner.

    Training processes each full sequence (prompt + response) as one
    pass, so prompts are expanded to pair with each of their sampled
    responses before estimating FLOPs. The result is checked with
    ``check_calculation``.

    Args:
        prompt_lengths: Token length of each unique prompt.
        response_lengths: Token length of every response
            (``len(prompt_lengths) * samples_per_prompt`` entries).
        training_time: Wall-clock seconds spent in the training step.
        samples_per_prompt: Responses sampled per prompt.
        num_training_gpus: GPU count used for training.

    Returns:
        Dict with a single ``"mfu"`` key, as a percentage.
    """
    # Each response pairs with its originating prompt; integer division
    # maps response index i back to prompt index i // samples_per_prompt.
    full_sequence_lengths = [
        resp_len + prompt_lengths[idx // samples_per_prompt]
        for idx, resp_len in enumerate(response_lengths)
    ]

    # Sequences are already expanded above, hence samples_per_prompt=1.
    flops_done = self.flops(
        prompt_lengths=full_sequence_lengths, response_lengths=None, samples_per_prompt=1, is_training=True
    )

    achieved_rate = flops_done / training_time
    peak_rate = self.device_flops * num_training_gpus
    learner_mfu = 100 * achieved_rate / peak_rate

    check_calculation(
        learner_mfu, "Learner MFU", self, training_time, full_sequence_lengths, None, 1, num_training_gpus
    )

    return {"mfu": learner_mfu}
2020+
19482021

19492022
def get_device_name(device_name: str) -> str:
19502023
"""Normalize a GPU device name to a standard key used in GPU_SPECS.

0 commit comments

Comments
 (0)