Commit cdddab8

Update benchmark script, add Dockerfile
1 parent ef843a0 commit cdddab8

File tree

scripts/Dockerfile
scripts/Makefile
scripts/benchmark.py

3 files changed: +106 -48 lines changed

scripts/Dockerfile

Lines changed: 13 additions & 0 deletions
@@ -0,0 +1,13 @@
+FROM nvcr.io/nvidia/pytorch:22.11-py3
+
+ENV PYTHONDONTWRITEBYTECODE 1
+ENV PYTHONUNBUFFERED 1
+
+RUN pip install --pre xformers
+RUN pip install diffusers==0.11.0 accelerate transformers
+
+WORKDIR /workspace
+
+COPY benchmark.py /workspace/benchmark.py
+RUN (printf '#!/bin/bash\npython benchmark.py \"$@\"\n' >> /entry.sh) && chmod a+x /entry.sh
+ENTRYPOINT ["/entry.sh"]
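
As a quick smoke test of the image (an illustrative snippet, not part of the commit), the pins can be verified from inside the container:

    # Illustrative check: the pinned diffusers release and the pre-release
    # xformers wheel should both import cleanly inside the image.
    import diffusers
    import xformers

    assert diffusers.__version__ == "0.11.0"
    print("xformers:", xformers.__version__)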

scripts/Makefile

Lines changed: 17 additions & 0 deletions
@@ -0,0 +1,17 @@
+bench:
+	docker build -t sd-bench .
+	docker run \
+		--rm -it \
+		--gpus all \
+		--shm-size=128g \
+		--net=host \
+		-v $(PWD):/workspace/results \
+		sd-bench \
+		--steps 30 \
+		--samples 1,2,4,8,16,32,64,128 \
+		--autocast no \
+		--xformers yes \
+		--output_file /workspace/results/results.csv
+
+clean:
+	rm results.csv

scripts/benchmark.py

Lines changed: 76 additions & 48 deletions
@@ -5,6 +5,7 @@
 import pathlib
 import csv
 from contextlib import nullcontext
+import itertools
 import torch
 from torch import autocast
 from diffusers import StableDiffusionPipeline, StableDiffusionOnnxPipeline
@@ -13,12 +14,18 @@
 
 prompt = "a photo of an astronaut riding a horse on mars"
 
+def make_bool(yes_or_no):
+    if yes_or_no.lower() == "yes":
+        return True
+    elif yes_or_no.lower() == "no":
+        return False
+    else:
+        raise ValueError(f"unrecognised input {yes_or_no}")
 
 def get_inference_pipeline(precision, backend):
     """
     returns HuggingFace diffuser pipeline
     cf https://github.com/huggingface/diffusers#text-to-image-generation-with-stable-diffusion
-    note: could not download from CompVis/stable-diffusion-v1-4 (access restricted)
     """
 
     assert precision in ("half", "single"), "precision in ['half', 'single']"
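
make_bool maps the CLI's yes/no strings onto booleans and fails loudly on anything else; its behaviour, illustrated (snippet not part of the diff):

    assert make_bool("yes") is True
    assert make_bool("No") is False   # matching is case-insensitive
    # make_bool("true") raises ValueError: unrecognised input true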
@@ -28,7 +35,6 @@ def get_inference_pipeline(precision, backend):
     pipe = StableDiffusionPipeline.from_pretrained(
         "CompVis/stable-diffusion-v1-4",
         revision="main" if precision == "single" else "fp16",
-        use_auth_token=os.environ["ACCESS_TOKEN"],
         torch_dtype=torch.float32 if precision == "single" else torch.float16,
     )
     pipe = pipe.to(device)
@@ -103,9 +109,9 @@ def get_inference_memory(pipe, n_samples, use_autocast, num_inference_steps):
     mem = torch.cuda.memory_reserved()
     return round(mem / 1e9, 2)
 
-
+@torch.inference_mode()
 def run_benchmark(
-    n_repeats, n_samples, precision, use_autocast, backend, num_inference_steps
+    n_repeats, n_samples, precision, use_autocast, xformers, backend, num_inference_steps
 ):
     """
     * n_repeats: nb datapoints for inference latency benchmark
@@ -116,7 +122,14 @@ def run_benchmark(
     dict like {'memory usage': 17.70, 'latency': 86.71}
     """
 
+    print(f"n_samples: {n_samples}\tprecision: {precision}\tautocast: {use_autocast}\txformers: {xformers}\tbackend: {backend}")
+
     pipe = get_inference_pipeline(precision, backend)
+    if xformers:
+        pipe.enable_xformers_memory_efficient_attention()
+
+    if n_samples > 16:
+        pipe.enable_vae_slicing()
 
     logs = {
         "memory": 0.00
@@ -128,8 +141,8 @@ def run_benchmark(
             pipe, n_samples, n_repeats, use_autocast, num_inference_steps
         ),
     }
-    print(f"n_samples: {n_samples}\tprecision: {precision}\tautocast: {use_autocast}\tbackend: {backend}")
     print(logs, "\n")
+    print("============================")
     return logs
 
 
@@ -148,7 +161,7 @@ def get_device_description():
     return torch.cuda.get_device_name()
 
 
-def run_benchmark_grid(grid, n_repeats, num_inference_steps):
+def run_benchmark_grid(grid, n_repeats, num_inference_steps, csv_fpath):
     """
     * grid : dict like
         {
@@ -159,13 +172,13 @@ def run_benchmark_grid(grid, n_repeats, num_inference_steps):
     * n_repeats: nb datapoints for inference latency benchmark
     """
 
-    csv_fpath = pathlib.Path(__file__).parent.parent / "benchmark_tmp.csv"
     # create benchmark.csv if not exists
     if not os.path.isfile(csv_fpath):
         header = [
             "device",
             "precision",
             "autocast",
+            "xformers",
             "runtime",
             "n_samples",
             "latency",
@@ -179,45 +192,45 @@ def run_benchmark_grid(grid, n_repeats, num_inference_steps):
     with open(csv_fpath, "a") as f:
         writer = csv.writer(f)
         device_desc = get_device_description()
-        for n_samples in grid["n_samples"]:
-            for precision in grid["precision"]:
-                use_autocast = False
-                if precision == "half":
-                    for autocast in grid["autocast"]:
-                        if autocast == "yes":
-                            use_autocast = True
-                for backend in grid["backend"]:
-                    try:
-                        new_log = run_benchmark(
-                            n_repeats=n_repeats,
-                            n_samples=n_samples,
-                            precision=precision,
-                            use_autocast=use_autocast,
-                            backend=backend,
-                            num_inference_steps=num_inference_steps,
-                        )
-                    except Exception as e:
-                        if "CUDA out of memory" in str(
-                            e
-                        ) or "Failed to allocate memory" in str(e):
-                            print(str(e))
-                            torch.cuda.empty_cache()
-                            new_log = {"latency": -1.00, "memory": -1.00}
-                        else:
-                            raise e
-
-                    latency = new_log["latency"]
-                    memory = new_log["memory"]
-                    new_row = [
-                        device_desc,
-                        precision,
-                        autocast,
-                        backend,
-                        n_samples,
-                        latency,
-                        memory,
-                    ]
-                    writer.writerow(new_row)
+        for trial in itertools.product(*grid.values()):
+
+            n_samples, precision, use_autocast, xformers, backend = trial
+            use_autocast = make_bool(use_autocast)
+            xformers = make_bool(xformers)
+
+            try:
+                new_log = run_benchmark(
+                    n_repeats=n_repeats,
+                    n_samples=n_samples,
+                    precision=precision,
+                    use_autocast=use_autocast,
+                    xformers=xformers,
+                    backend=backend,
+                    num_inference_steps=num_inference_steps,
+                )
+            except Exception as e:
+                if "CUDA out of memory" in str(
+                    e
+                ) or "Failed to allocate memory" in str(e):
+                    print(str(e))
+                    torch.cuda.empty_cache()
+                    new_log = {"latency": -1.00, "memory": -1.00}
+                else:
+                    raise e
+
+            latency = new_log["latency"]
+            memory = new_log["memory"]
+            new_row = [
+                device_desc,
+                precision,
+                use_autocast,
+                xformers,
+                backend,
+                n_samples,
+                latency,
+                memory,
+            ]
+            writer.writerow(new_row)
 
 
 if __name__ == "__main__":
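
The rewritten loop leans on dict insertion order (guaranteed since Python 3.7): itertools.product(*grid.values()) yields one tuple per configuration, in exactly the key order the unpacking expects. A toy illustration with hypothetical grid values:

    import itertools

    grid = {
        "n_samples": (1, 2),
        "precision": ("half",),
        "autocast": ("no",),
        "xformers": ("yes", "no"),
        "backend": ("pytorch",),
    }
    # One benchmark trial per element of the cartesian product:
    for n_samples, precision, autocast, xformers, backend in itertools.product(*grid.values()):
        print(n_samples, precision, autocast, xformers, backend)
    # 1 half no yes pytorch
    # 1 half no no pytorch
    # 2 half no yes pytorch
    # 2 half no no pytorch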
@@ -249,6 +262,20 @@ def run_benchmark_grid(grid, n_repeats, num_inference_steps):
         help="If 'yes', will perform additional runs with autocast activated for half precision inferences",
     )
 
+    parser.add_argument(
+        "--xformers",
+        default="yes",
+        type=str,
+        help="If 'yes', will use xformers memory-efficient attention",
+    )
+
+    parser.add_argument(
+        "--output_file",
+        default="results.csv",
+        type=str,
+        help="Path to the output csv file to write",
+    )
+
     args = parser.parse_args()
 
     grid = {
@@ -257,10 +284,11 @@ def run_benchmark_grid(grid, n_repeats, num_inference_steps):
         # Removing autocast won't help. Ref:
         # https://github.com/CompVis/stable-diffusion/issues/307
         "precision": ("single",) if device.type == "cpu" else ("single", "half"),
-        "autocast": ("no",) if args.autocast == "no" else ("yes", "no"),
+        "autocast": args.autocast.split(","),
+        "xformers": args.xformers.split(","),
         # Only use onnx for cpu, until issues are fixed by upstreams. Ref:
         # https://github.com/huggingface/diffusers/issues/489#issuecomment-1261577250
         # https://github.com/huggingface/diffusers/pull/440
         "backend": ("pytorch", "onnx") if device.type == "cpu" else ("pytorch",),
     }
-    run_benchmark_grid(grid, n_repeats=args.repeats, num_inference_steps=args.steps)
+    run_benchmark_grid(grid, n_repeats=args.repeats, num_inference_steps=args.steps, csv_fpath=args.output_file)
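
Because each flag is split on commas, a single run can sweep several settings (e.g. --autocast yes,no --xformers yes,no), while the Makefile above pins one value per axis with --autocast no --xformers yes. The parsing is plain str.split:

    "yes,no".split(",")  # -> ["yes", "no"]: two grid entries for this axis
    "yes".split(",")     # -> ["yes"]: a single entry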
