|
10 | 10 | - run() function is the main entry point for running inference benchmarks.
|
11 | 11 | """
|
12 | 12 |
|
| 13 | +import os |
13 | 14 | from copy import deepcopy
|
14 | 15 | from pathlib import Path
|
15 | 16 |
|
16 | 17 | import torch
|
17 | 18 |
|
18 | 19 | from benchmarks.microbenchmarks.profiler import (
|
| 20 | + generate_memory_profile, |
19 | 21 | generate_model_profile,
|
| 22 | + visualize_memory_profile, |
20 | 23 | )
|
21 | 24 | from benchmarks.microbenchmarks.utils import (
|
22 | 25 | BenchmarkConfig,
|
@@ -98,11 +101,49 @@ def run(config: BenchmarkConfig) -> BenchmarkResult:
|
98 | 101 | if config.enable_profiler:
|
99 | 102 | print("Running profiler...")
|
100 | 103 | try:
|
101 |
| - result.profiler_json_path = generate_model_profile( |
102 |
| - m_copy, input_data, config.profiler_file_name |
| 104 | + profiler_json_path = generate_model_profile( |
| 105 | + model=m_copy, |
| 106 | + input_data=input_data, |
| 107 | + profile_file_path=os.path.join( |
| 108 | + config.output_dir, |
| 109 | + "profiler", |
| 110 | + f"{config._file_name}_profile.json", |
| 111 | + ), |
103 | 112 | )
|
| 113 | + result.profiler_json_path = profiler_json_path |
104 | 114 | except Exception as e:
|
105 |
| - print(f"Error running profiler for {config.name} with error: {e}") |
| 115 | + print(f"Error running profiler: {e}") |
| 116 | + |
| 117 | + # Run memory profiler if enabled |
| 118 | + if config.enable_memory_profiler: |
| 119 | + print("Running memory profiler...") |
| 120 | + try: |
| 121 | + result.memory_profile_path, result.memory_stats = ( |
| 122 | + generate_memory_profile( |
| 123 | + model=m_copy, |
| 124 | + input_data=input_data, |
| 125 | + profile_file_path=os.path.join( |
| 126 | + config.output_dir, |
| 127 | + "memory_profiler/pickle", |
| 128 | + f"{config._file_name}_memory_profile.pickle", |
| 129 | + ), |
| 130 | + ) |
| 131 | + ) |
| 132 | + |
| 133 | + if result.memory_profile_path: |
| 134 | + result.memory_visualization_path = visualize_memory_profile( |
| 135 | + result.memory_profile_path |
| 136 | + ) |
| 137 | + except ValueError as e: |
| 138 | + if "not enough values to unpack" in e: |
| 139 | + print( |
| 140 | + "Failed due to existing bugs, re-run the code to generate memory profile. Please raise an issue if it persists." |
| 141 | + ) |
| 142 | + except Exception as e: |
| 143 | + print(f"Error running memory profiler: {e}") |
| 144 | + import traceback |
| 145 | + |
| 146 | + traceback.print_exc() |
106 | 147 |
|
107 | 148 | return result
|
108 | 149 | except Exception as e:
|
|
0 commit comments