
Commit 490bdf6

Merge remote-tracking branch 'pytorch/main' into parq

2 parents: 9e70d6d + e3db2b2

40 files changed: +986 −452 lines

.pre-commit-config.yaml (+1 −1)

```diff
@@ -2,7 +2,7 @@
 # See https://pre-commit.com/hooks.html for more hooks
 repos:
 - repo: https://github.com/pre-commit/pre-commit-hooks
-  rev: v4.4.0
+  rev: v5.0.0
   hooks:
   - id: trailing-whitespace
   - id: end-of-file-fixer
```
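Not part of the diff, but for context: a `rev` bump like this is normally produced and verified with pre-commit's own CLI. A minimal sketch, wrapped in Python's `subprocess` purely for illustration (running the two commands directly in a shell is equivalent):

```python
# Sketch: update hook revisions and re-run all hooks, as one would to
# produce and verify a bump like v4.4.0 -> v5.0.0. Assumes pre-commit is
# installed and the repo has a .pre-commit-config.yaml.
import subprocess

# Rewrites each hook's `rev` to the latest tag of its repo.
subprocess.run(["pre-commit", "autoupdate"], check=True)

# Runs every configured hook against the entire repository.
subprocess.run(["pre-commit", "run", "--all-files"], check=True)
```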

benchmarks/microbenchmarks/README.md (+13 −2)

````diff
@@ -30,11 +30,12 @@ python -m benchmarks.microbenchmarks.benchmark_runner --config path/to/config.ym
 
 ```yaml
 # Sample configuration for inference benchmarks
+benchmark_mode: "inference"
 quantization_config_recipe_names:
   - "baseline"
   - "int8wo"
-  - "int4wo-128"
-  - "int4wo-128-hqq"
+  - "float8wo"
+  - "float8dq-tensor"
 
 output_dir: "benchmarks/microbenchmarks/results"
 
@@ -50,10 +51,20 @@ model_params:
   compile: "max-autotune" # Options: "default", "max-autotune", "false"
   device: "cuda" # Options: "cuda", "mps", "xpu", "cpu"
   model_type: "linear" # Options: "linear", "ln_linear_sigmoid"
+  enable_profiler: true # Enable standard profiling
+  enable_memory_profiler: true # Enable CUDA memory profiling
 ```
 
 ## Configuration Options
 
+### Profiling Options
+- `enable_profiler`: Enable standard PyTorch profiling (default: false)
+- `enable_memory_profiler`: Enable CUDA memory profiling (default: false)
+  - Only works when device is set to "cuda"
+  - Generates memory snapshots before and after inference
+  - Creates visualizations of memory usage
+  - Outputs are saved in the memory_profiler subdirectory
+
 ### Quantization Methods
 Currently, quantization string is in same format as the one being passed in llama/generate.py.
 - `baseline`: No quantization
````
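The snapshot-and-visualization behavior the new Profiling Options bullets describe lines up with PyTorch's CUDA memory snapshot hooks. Below is a minimal sketch of that underlying workflow, assuming the benchmark's memory profiler builds on these (underscore-prefixed, semi-private) APIs; the model, input, and output filename are illustrative, not taken from the benchmark code:

```python
import torch

# Illustrative stand-ins; the real benchmark builds these from the YAML
# config (model_type, shapes, device).
model = torch.nn.Linear(4096, 4096, device="cuda")
input_data = torch.randn(16, 4096, device="cuda")

# Begin recording a history of CUDA allocation/free events.
torch.cuda.memory._record_memory_history(max_entries=100_000)

with torch.no_grad():
    model(input_data)
torch.cuda.synchronize()

# Write the recorded history to a pickle; the benchmark stores these under
# the memory_profiler subdirectory mentioned above.
torch.cuda.memory._dump_snapshot("linear_memory_profile.pickle")

# Stop recording.
torch.cuda.memory._record_memory_history(enabled=None)
```

The resulting pickle can be loaded into the interactive viewer at https://pytorch.org/memory_viz, which is presumably what the "visualizations of memory usage" bullet automates.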

benchmarks/microbenchmarks/benchmark_inference.py (+44 −3)

```diff
@@ -10,13 +10,16 @@
 - run() function is the main entry point for running inference benchmarks.
 """
 
+import os
 from copy import deepcopy
 from pathlib import Path
 
 import torch
 
 from benchmarks.microbenchmarks.profiler import (
+    generate_memory_profile,
     generate_model_profile,
+    visualize_memory_profile,
 )
 from benchmarks.microbenchmarks.utils import (
     BenchmarkConfig,
@@ -98,11 +101,49 @@ def run(config: BenchmarkConfig) -> BenchmarkResult:
         if config.enable_profiler:
             print("Running profiler...")
             try:
-                result.profiler_json_path = generate_model_profile(
-                    m_copy, input_data, config.profiler_file_name
+                profiler_json_path = generate_model_profile(
+                    model=m_copy,
+                    input_data=input_data,
+                    profile_file_path=os.path.join(
+                        config.output_dir,
+                        "profiler",
+                        f"{config._file_name}_profile.json",
+                    ),
                 )
+                result.profiler_json_path = profiler_json_path
             except Exception as e:
-                print(f"Error running profiler for {config.name} with error: {e}")
+                print(f"Error running profiler: {e}")
+
+        # Run memory profiler if enabled
+        if config.enable_memory_profiler:
+            print("Running memory profiler...")
+            try:
+                result.memory_profile_path, result.memory_stats = (
+                    generate_memory_profile(
+                        model=m_copy,
+                        input_data=input_data,
+                        profile_file_path=os.path.join(
+                            config.output_dir,
+                            "memory_profiler/pickle",
+                            f"{config._file_name}_memory_profile.pickle",
+                        ),
+                    )
+                )
+
+                if result.memory_profile_path:
+                    result.memory_visualization_path = visualize_memory_profile(
+                        result.memory_profile_path
+                    )
+            except ValueError as e:
+                if "not enough values to unpack" in str(e):
+                    print(
+                        "Failed due to existing bugs, re-run the code to generate memory profile. Please raise an issue if it persists."
+                    )
+            except Exception as e:
+                print(f"Error running memory profiler: {e}")
+                import traceback
+
+                traceback.print_exc()
 
         return result
     except Exception as e:
```
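`generate_model_profile` itself lives in `benchmarks/microbenchmarks/profiler.py`, which this excerpt does not show. For orientation, here is a hedged sketch of what such a helper plausibly looks like, assuming it wraps `torch.profiler` and exports the Chrome-trace JSON whose path the diff constructs; the body below is an assumption, not the repo's implementation:

```python
import os

import torch
from torch.profiler import ProfilerActivity, profile


def generate_model_profile_sketch(model, input_data, profile_file_path):
    """Hypothetical stand-in for generate_model_profile (assumption)."""
    # Ensure the target directory (e.g. <output_dir>/profiler/) exists.
    os.makedirs(os.path.dirname(profile_file_path) or ".", exist_ok=True)

    activities = [ProfilerActivity.CPU]
    if input_data.is_cuda:
        activities.append(ProfilerActivity.CUDA)

    # Profile a single forward pass and record operator input shapes.
    with profile(activities=activities, record_shapes=True) as prof:
        with torch.no_grad():
            model(input_data)

    # Chrome-trace JSON; open in chrome://tracing or Perfetto.
    prof.export_chrome_trace(profile_file_path)
    return profile_file_path
```

The switch to keyword arguments in the diff (`model=`, `input_data=`, `profile_file_path=`) keeps these call sites readable as the helper grows parameters.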
