Commit df143a8

rewriting deprecated modal scaledownwindow and adding pydra config (#81)

Authored by pythonomar22 and Simon Guo
Co-authored-by: Simon Guo <[email protected]>
1 parent 89544df commit df143a8
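For context on the commit title: Modal deprecated the container_idle_timeout argument of its class/function decorators in favor of scaledown_window, and this commit switches to the new name. A minimal before/after sketch (the app object, image, and the 5-second value mirror the diff below; this is an illustration, not a verbatim excerpt):

# Before (deprecated Modal argument):
# @app.cls(image=image, container_idle_timeout=5)

# After (current Modal argument; the idle window before scale-down, in seconds):
@app.cls(image=image, scaledown_window=5)
class EvalFunc:
    ...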

File tree

1 file changed: +118 -93 lines changed


scripts/generate_baseline_time_modal.py

Lines changed: 118 additions & 93 deletions
@@ -15,6 +15,8 @@
 import multiprocessing as mp
 import time
 import einops
+import pydra
+from pydra import Config, REQUIRED
 
 """
 Generate baseline time for KernelBench
@@ -48,6 +50,28 @@
 
 TIMING_DIR = os.path.join(REPO_TOP_PATH, "results", "timing")
 
+
+class BaselineConfig(Config):
+    def __init__(self):
+        # Problem level to generate baseline for
+        self.level = REQUIRED
+
+        # GPU type for Modal (L40S, H100, A100, A100-80GB, L4, T4, A10G)
+        self.gpu = REQUIRED
+
+        # Hardware name for saving results
+        self.hardware_name = REQUIRED
+
+        # Batch size for parallel processing
+        self.batch_size = 10
+
+        # Timeout for each batch in seconds
+        self.timeout = 1800
+
+        # Number of trials for timing
+        self.num_trials = 100
+
+
 # Modal Infra
 import modal
 app = modal.App("generate_baseline_modal")
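The three REQUIRED fields in BaselineConfig must be provided at launch time; pydra configs are typically overridden with key=value arguments on the command line. A hypothetical invocation (override syntax and values are assumptions for illustration, not taken from this commit):

python scripts/generate_baseline_time_modal.py level=1 gpu=L40S hardware_name=L40S_modal batch_size=8 timeout=3600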
@@ -127,7 +151,7 @@ def fetch_ref_arch_from_dataset(dataset: list[str],
     ref_arch_name = ref_arch_path.split("/")[-1]
     return (ref_arch_path, ref_arch_name, ref_arch_src)
 
-@app.cls(image=image, container_idle_timeout=5)
+@app.cls(image=image, scaledown_window=5)
 class EvalFunc:
 
     @modal.method()
@@ -188,121 +212,122 @@ def measure_program_time(
        except Exception as e:
            print(f"[Eval] Error in Measuring Performance: {e}")
 
-def measure_program_time_wrapper(*args, **kwargs):
+def measure_program_time_wrapper(gpu_type, *args, **kwargs):
     with app.run():
-        return EvalFunc.with_options(gpu=gpu)().measure_program_time.remote(*args, **kwargs)
+        return EvalFunc.with_options(gpu=gpu_type)().measure_program_time.remote(*args, **kwargs)
 
-def record_baseline_times(use_torch_compile: bool = False,
-                          torch_compile_backend: str="inductor",
+def record_baseline_times(config: BaselineConfig,
+                          use_torch_compile: bool = False,
+                          torch_compile_backend: str="inductor",
                           torch_compile_options: str="default",
                           file_name: str="baseline_time.json"):
     """
-    Generate baseline time for KernelBench, 
+    Generate baseline time for KernelBench,
     configure profiler options for PyTorch
     save to specified file
     """
     json_results = []
 
-    for level in [1, 2, 3]:
-        PROBLEM_DIR = os.path.join(KERNEL_BENCH_PATH, "level" + str(level))
-        dataset = construct_problem_dataset_from_problem_dir(PROBLEM_DIR)
-        num_problems = len(dataset)
-        total_work = [(i, *fetch_ref_arch_from_dataset(dataset, i)) for i in list(range(1, num_problems + 1))]
-
-        with tqdm(total=len(total_work), desc="Processing batches") as pbar:
-            while len(total_work) > 0:
-                curr_work_batch = total_work[:batch_size]
-                total_work = total_work[batch_size:] # pop the first batch_size elements
-
-                with mp.Pool() as pool:
-
-                    work_args = [
-                        (
-                            ref_arch_name,
-                            ref_arch_src,
-                            100,
-                            use_torch_compile,
-                            torch_compile_backend,
-                            torch_compile_options,
-                            torch.device(f"cuda:0"),
-                            False # do not print
-                        )
-                        for i, (p_id, ref_arch_path, ref_arch_name, ref_arch_src) in enumerate(curr_work_batch)
-                    ]
+    level = config.level
+    PROBLEM_DIR = os.path.join(KERNEL_BENCH_PATH, "level" + str(level))
+    dataset = construct_problem_dataset_from_problem_dir(PROBLEM_DIR)
+    num_problems = len(dataset)
+    total_work = [(i, *fetch_ref_arch_from_dataset(dataset, i)) for i in list(range(1, num_problems + 1))]
+
+    with tqdm(total=len(total_work), desc="Processing batches") as pbar:
+        while len(total_work) > 0:
+            curr_work_batch = total_work[:config.batch_size]
+            total_work = total_work[config.batch_size:] # pop the first batch_size elements
+
+            with mp.Pool() as pool:
+
+                work_args = [
+                    (
+                        config.gpu,
+                        ref_arch_name,
+                        ref_arch_src,
+                        config.num_trials,
+                        use_torch_compile,
+                        torch_compile_backend,
+                        torch_compile_options,
+                        torch.device(f"cuda:0"),
+                        False # do not print
+                    )
+                    for i, (p_id, ref_arch_path, ref_arch_name, ref_arch_src) in enumerate(curr_work_batch)
+                ]
+
+                start_time = time.time()
 
-                    start_time = time.time()
+                async_results = []
+                for work_arg in work_args:
+                    async_results.append(
+                        pool.apply_async(measure_program_time_wrapper, work_arg)
+                    )
 
-                    async_results = []
-                    for work_arg in work_args:
-                        async_results.append(
-                            pool.apply_async(measure_program_time_wrapper, work_arg)
+                batch_timeout = config.timeout
+                for i, async_result in enumerate(async_results):
+                    problem_id, _, ref_arch_name, _ = curr_work_batch[i]
+
+                    try:
+                        elapsed_time = time.time() - start_time
+                        remaining_time = max(0, batch_timeout - elapsed_time)
+                        result = async_result.get(timeout=remaining_time)
+                        json_results.append((f"level{level}", ref_arch_name, result))
+
+                    except mp.TimeoutError:
+                        print(
+                            f"[WARNING] Evaluation TIMED OUT for Problem ID: {problem_id}"
                        )
+                        json_results.append((f"level{level}", ref_arch_name, None))
 
-                    batch_timeout = timeout
-                    for i, async_result in enumerate(async_results):
-                        problem_id, _, ref_arch_name, _ = curr_work_batch[i]
-
-                        try:
-                            elapsed_time = time.time() - start_time
-                            remaining_time = max(0, batch_timeout - elapsed_time)
-                            result = async_result.get(timeout=remaining_time)
-                            json_results.append((f"level{level}", ref_arch_name, result))
-
-                        except mp.TimeoutError:
-                            print(
-                                f"[WARNING] Evaluation TIMED OUT for Problem ID: {problem_id}"
-                            )
-                            json_results.append((f"level{level}", ref_arch_name, None))
-
-                        except Exception as e:
-                            print(
-                                f"[ERROR] Evaluation FAILED for Problem ID: {problem_id}: {str(e)}"
-                            )
-                            json_results.append((f"level{level}", ref_arch_name, None))
+                    except Exception as e:
+                        print(
+                            f"[ERROR] Evaluation FAILED for Problem ID: {problem_id}: {str(e)}"
+                        )
+                        json_results.append((f"level{level}", ref_arch_name, None))
 
-                pbar.update(len(curr_work_batch))
+                pbar.update(len(curr_work_batch))
 
     save_path = os.path.join(TIMING_DIR, file_name)
     write_batch_to_json(json_results, save_path)
     return json_results
 
 
-if __name__ == "__main__":
-    # DEBUG and simple testing
-    # test_measure_particular_program(2, 28)
-    gpu = "A10G"
-    # Replace this with whatever hardware you are running on
-    hardware_name = f"{gpu}_modal"
-    print(f"Generating baseline time for {hardware_name}")
-    # input(f"You are about to start recording baseline time for {hardware_name}, press Enter to continue...")
-    # # Systematic recording of baseline time
-
-    # if os.path.exists(os.path.join(TIMING_DIR, hardware_name)):
-    #     input(f"Directory {hardware_name} already exists, Are you sure you want to overwrite? Enter to continue...")
+@pydra.main(base=BaselineConfig)
+def main(config: BaselineConfig):
+    """
+    Generate baseline time for KernelBench problems using Modal GPUs
+    """
+    print(f"Generating baseline time for level {config.level} on {config.gpu} Modal")
+    print(f"Hardware name: {config.hardware_name}")
+    print(f"Batch size: {config.batch_size}, Timeout: {config.timeout}s, Num trials: {config.num_trials}")
 
     # 1. Record Torch Eager
-    record_baseline_times(use_torch_compile=False,
-                          torch_compile_backend=None,
-                          torch_compile_options=None,
-                          file_name=f"{hardware_name}/baseline_time_torch.json")
-
-    record_baseline_times(use_torch_compile=True,
-                          torch_compile_backend="inductor",
-                          torch_compile_options="default",
-                          file_name=f"{hardware_name}/baseline_time_torch_compile_inductor_default.json")
-
-    # 2. Record Torch Compile using Inductor
-    # for torch_compile_mode in ["default", "reduce-overhead", "max-autotune", "max-autotune-no-cudagraphs"]:
-    #     record_baseline_times(use_torch_compile=True,
-    #                           torch_compile_backend="inductor",
-    #                           torch_compile_options=torch_compile_mode,
-    #                           file_name=f"{hardware_name}/baseline_time_torch_compile_inductor_{torch_compile_mode}.json")
-
-    # 3. Record Torch Compile using cudagraphs
-    # record_baseline_times(use_torch_compile=True,
-    #                       torch_compile_backend="cudagraphs",
-    #                       torch_compile_options=None,
-    #                       file_name=f"{hardware_name}/baseline_time_torch_compile_cudagraphs.json")
+    print("\n[1/2] Recording baseline times with PyTorch Eager execution...")
+    record_baseline_times(
+        config=config,
+        use_torch_compile=False,
+        torch_compile_backend=None,
+        torch_compile_options=None,
+        file_name=f"{config.hardware_name}/baseline_time_torch.json"
+    )
+
+    # 2. Record Torch Compile using Inductor (default mode)
+    print("\n[2/2] Recording baseline times with Torch Compile (inductor, default mode)...")
+    record_baseline_times(
+        config=config,
+        use_torch_compile=True,
+        torch_compile_backend="inductor",
+        torch_compile_options="default",
+        file_name=f"{config.hardware_name}/baseline_time_torch_compile_inductor_default.json"
+    )
+
+    print(f"\n✓ Baseline time generation complete!")
+    print(f"Results saved to: {os.path.join(TIMING_DIR, config.hardware_name)}")
+
+
+if __name__ == "__main__":
+    main()
 
 
 
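One detail worth noting in the hunk above: every problem in a batch shares a single timeout budget (config.timeout), so each async_result.get only waits for whatever time remains since the batch started. A standalone sketch of that shared-deadline pattern with a dummy task (illustrative only, not code from this repository):

import multiprocessing as mp
import time

def dummy_task(seconds):
    # Stand-in for a remote timing run; just sleeps and returns its input.
    time.sleep(seconds)
    return seconds

if __name__ == "__main__":
    batch_timeout = 5.0  # one budget shared by the whole batch
    with mp.Pool() as pool:
        async_results = [pool.apply_async(dummy_task, (s,)) for s in (1, 2, 10)]
        start_time = time.time()
        for r in async_results:
            remaining = max(0, batch_timeout - (time.time() - start_time))
            try:
                print("result:", r.get(timeout=remaining))
            except mp.TimeoutError:
                # The 10 s task exceeds what is left of the 5 s budget.
                print("timed out")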
