
Commit 4c74b72

Merge pull request #7 from LambdaLabsML/eole/fp16-v-autocast
fp16 v autocast
2 parents fe172cb + cc801b8
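The change splits what the script previously treated as a single "half precision" switch into two independent settings: the dtype the pipeline weights are loaded with (torch_dtype=torch.float16) and whether the forward pass runs inside torch.autocast("cuda") (the new use_autocast flag threaded through do_inference and run_benchmark below). A minimal illustrative sketch, not part of the commit, reusing the model id and prompt from the script (the checkpoint may require a Hugging Face access token, as the script's ACCESS_TOKEN does):

import torch
from diffusers import StableDiffusionPipeline

prompt = "a photo of an astronaut riding a horse on mars"

# Mode 1: fp16 weights -- the pipeline itself is loaded in half precision.
pipe = StableDiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4", torch_dtype=torch.float16
).to("cuda")
image_fp16 = pipe(prompt).images[0]

# Mode 2: autocast -- additionally wrap the call in torch.autocast("cuda"),
# which is what do_inference() does when use_autocast is True.
with torch.autocast("cuda"):
    image_autocast = pipe(prompt).images[0]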

File tree

3 files changed: 108 additions, 58 deletions

.gitignore
.vscode/settings.json
scripts/benchmark.py

.gitignore

Lines changed: 2 additions & 0 deletions
@@ -163,3 +163,5 @@ cython_debug/
 # and can be added to the global gitignore or merged into this file. For a more nuclear
 # option (not recommended) you can uncomment the following to ignore the entire idea folder.
 #.idea/
+
+.vscode/

.vscode/settings.json

Lines changed: 0 additions & 3 deletions
This file was deleted.

scripts/benchmark.py

Lines changed: 106 additions & 55 deletions
@@ -13,6 +13,7 @@

 prompt = "a photo of an astronaut riding a horse on mars"

+
 def get_inference_pipeline(precision, backend):
     """
     returns HuggingFace diffuser pipeline
@@ -31,12 +32,14 @@ def get_inference_pipeline(precision, backend):
             torch_dtype=torch.float32 if precision == "single" else torch.float16,
         )
         pipe = pipe.to(device)
-    else:
+    else:
         pipe = StableDiffusionOnnxPipeline.from_pretrained(
             "CompVis/stable-diffusion-v1-4",
             use_auth_token=os.environ["ACCESS_TOKEN"],
             revision="onnx",
-            provider="CPUExecutionProvider" if device.type=="cpu" else "CUDAExecutionProvider",
+            provider="CPUExecutionProvider"
+            if device.type == "cpu"
+            else "CUDAExecutionProvider",
             torch_dtype=torch.float32 if precision == "single" else torch.float16,
         )

@@ -51,43 +54,59 @@ def null_safety(images, **kwargs):
     return pipe


-def do_inference(pipe, n_samples, precision, num_inference_steps):
+def do_inference(pipe, n_samples, use_autocast, num_inference_steps):
     torch.cuda.empty_cache()
-    context = autocast if (device.type == "cuda" and precision == 'half') else nullcontext
+    context = (
+        autocast if (device.type == "cuda" and use_autocast) else nullcontext
+    )
     with context("cuda"):
-        images = pipe(prompt=[prompt] * n_samples, num_inference_steps=num_inference_steps).images
+        images = pipe(
+            prompt=[prompt] * n_samples, num_inference_steps=num_inference_steps
+        ).images

     return images


-def get_inference_time(pipe, n_samples, n_repeats, precision, num_inference_steps):
+def get_inference_time(
+    pipe, n_samples, n_repeats, use_autocast, num_inference_steps
+):
     from torch.utils.benchmark import Timer
+
     timer = Timer(
-        stmt="do_inference(pipe, n_samples, precision, num_inference_steps)",
+        stmt="do_inference(pipe, n_samples, use_autocast, num_inference_steps)",
         setup="from __main__ import do_inference",
-        globals={"pipe": pipe, "n_samples": n_samples, "precision": precision, "num_inference_steps": num_inference_steps},
-        num_threads=multiprocessing.cpu_count()
+        globals={
+            "pipe": pipe,
+            "n_samples": n_samples,
+            "use_autocast": use_autocast,
+            "num_inference_steps": num_inference_steps,
+        },
+        num_threads=multiprocessing.cpu_count(),
     )
     profile_result = timer.timeit(
         n_repeats
     ) # benchmark.Timer performs 2 iterations for warmup
     return round(profile_result.mean, 2)


-def get_inference_memory(pipe, n_samples, precision, num_inference_steps):
+def get_inference_memory(pipe, n_samples, use_autocast, num_inference_steps):
     if not torch.cuda.is_available():
         return 0
-
+
     torch.cuda.empty_cache()
-    context = autocast if (device.type == "cuda" and precision == 'half') else nullcontext
+    context = autocast if (device.type == "cuda" and use_autocast) else nullcontext
     with context("cuda"):
-        images = pipe(prompt=[prompt] * n_samples, num_inference_steps=num_inference_steps).images
+        images = pipe(
+            prompt=[prompt] * n_samples, num_inference_steps=num_inference_steps
+        ).images

     mem = torch.cuda.memory_reserved()
     return round(mem / 1e9, 2)


-def run_benchmark(n_repeats, n_samples, precision, backend, num_inference_steps):
+def run_benchmark(
+    n_repeats, n_samples, precision, use_autocast, backend, num_inference_steps
+):
     """
     * n_repeats: nb datapoints for inference latency benchmark
     * n_samples: number of samples to generate (~ batch size)
@@ -100,10 +119,16 @@ def run_benchmark(n_repeats, n_samples, precision, backend, num_inference_steps)
     pipe = get_inference_pipeline(precision, backend)

     logs = {
-        "memory": 0.00 if device.type=="cpu" else get_inference_memory(pipe, n_samples, precision, num_inference_steps),
-        "latency": get_inference_time(pipe, n_samples, n_repeats, precision, num_inference_steps),
+        "memory": 0.00
+        if device.type == "cpu"
+        else get_inference_memory(
+            pipe, n_samples, use_autocast, num_inference_steps
+        ),
+        "latency": get_inference_time(
+            pipe, n_samples, n_repeats, use_autocast, num_inference_steps
+        ),
     }
-    print(f"n_samples: {n_samples}\tprecision: {precision}\tbackend: {backend}")
+    print(f"n_samples: {n_samples}\tprecision: {precision}\tautocast: {use_autocast}\tbackend: {backend}")
     print(logs, "\n")
     return logs

@@ -115,9 +140,8 @@ def get_device_description():
     """
     if device.type == "cpu":
         name = subprocess.check_output(
-            "grep -m 1 'model name' /proc/cpuinfo",
-            shell=True
-        ).decode("utf-8")
+            "grep -m 1 'model name' /proc/cpuinfo", shell=True
+        ).decode("utf-8")
         name = " ".join(name.split(" ")[2:]).strip()
         return name
     else:
@@ -130,14 +154,23 @@ def run_benchmark_grid(grid, n_repeats, num_inference_steps):
         {
             "n_samples": (1, 2),
             "precision": ("single", "half"),
+            "autocast" : ("yes", "no")
         }
     * n_repeats: nb datapoints for inference latency benchmark
     """

     csv_fpath = pathlib.Path(__file__).parent.parent / "benchmark_tmp.csv"
     # create benchmark.csv if not exists
     if not os.path.isfile(csv_fpath):
-        header = ["device", "precision", "runtime", "n_samples", "latency", "memory"]
+        header = [
+            "device",
+            "precision",
+            "autocast",
+            "runtime",
+            "n_samples",
+            "latency",
+            "memory",
+        ]
         with open(csv_fpath, "w") as f:
             writer = csv.writer(f)
             writer.writerow(header)
@@ -148,48 +181,58 @@ def run_benchmark_grid(grid, n_repeats, num_inference_steps):
     device_desc = get_device_description()
     for n_samples in grid["n_samples"]:
         for precision in grid["precision"]:
-            for backend in grid["backend"]:
-                try:
-                    new_log = run_benchmark(
-                        n_repeats=n_repeats,
-                        n_samples=n_samples,
-                        precision=precision,
-                        backend=backend,
-                        num_inference_steps=num_inference_steps
-                    )
-                except Exception as e:
-                    if "CUDA out of memory" in str(e) or "Failed to allocate memory" in str(e):
-                        print(str(e))
-                        torch.cuda.empty_cache()
-                        new_log = {
-                            "latency": -1.00,
-                            "memory": -1.00
-                        }
-                    else:
-                        raise e
-
-                latency = new_log["latency"]
-                memory = new_log["memory"]
-                new_row = [device_desc, precision, backend, n_samples, latency, memory]
-                writer.writerow(new_row)
+            use_autocast = False
+            if precision == "half":
+                for autocast in grid["autocast"]:
+                    if autocast == "yes":
+                        use_autocast = True
+                    for backend in grid["backend"]:
+                        try:
+                            new_log = run_benchmark(
+                                n_repeats=n_repeats,
+                                n_samples=n_samples,
+                                precision=precision,
+                                use_autocast=use_autocast,
+                                backend=backend,
+                                num_inference_steps=num_inference_steps,
+                            )
+                        except Exception as e:
+                            if "CUDA out of memory" in str(
+                                e
+                            ) or "Failed to allocate memory" in str(e):
+                                print(str(e))
+                                torch.cuda.empty_cache()
+                                new_log = {"latency": -1.00, "memory": -1.00}
+                            else:
+                                raise e
+
+                        latency = new_log["latency"]
+                        memory = new_log["memory"]
+                        new_row = [
+                            device_desc,
+                            precision,
+                            autocast,
+                            backend,
+                            n_samples,
+                            latency,
+                            memory,
+                        ]
+                        writer.writerow(new_row)


 if __name__ == "__main__":

     parser = argparse.ArgumentParser()

     parser.add_argument(
-        "--samples",
+        "--samples",
         default="1",
-        type=str,
-        help="Comma sepearated list of batch sizes (number of samples)"
+        type=str,
+        help="Comma sepearated list of batch sizes (number of samples)",
     )

     parser.add_argument(
-        "--steps",
-        default=50,
-        type=int,
-        help="Number of diffusion steps."
+        "--steps", default=50, type=int, help="Number of diffusion steps."
     )

     parser.add_argument(
@@ -199,17 +242,25 @@ def run_benchmark_grid(grid, n_repeats, num_inference_steps):
         help="Number of repeats.",
     )

+    parser.add_argument(
+        "--autocast",
+        default="no",
+        type=str,
+        help="If 'yes', will perform additional runs with autocast activated for half precision inferences",
+    )
+
     args = parser.parse_args()

     grid = {
-        "n_samples": tuple(map(int, args.samples.split(","))),
-        # Only use single-precision for cpu because "LayerNormKernelImpl" not implemented for 'Half' on cpu,
+        "n_samples": tuple(map(int, args.samples.split(","))),
+        # Only use single-precision for cpu because "LayerNormKernelImpl" not implemented for 'Half' on cpu,
         # Remove autocast won't help. Ref:
         # https://github.com/CompVis/stable-diffusion/issues/307
         "precision": ("single",) if device.type == "cpu" else ("single", "half"),
+        "autocast": ("no",) if args.autocast == "no" else ("yes", "no"),
         # Only use onnx for cpu, until issues are fixed by upstreams. Ref:
         # https://github.com/huggingface/diffusers/issues/489#issuecomment-1261577250
         # https://github.com/huggingface/diffusers/pull/440
-        "backend": ("pytorch", "onnx") if device.type == "cpu" else ("pytorch",)
+        "backend": ("pytorch", "onnx") if device.type == "cpu" else ("pytorch",),
     }
     run_benchmark_grid(grid, n_repeats=args.repeats, num_inference_steps=args.steps)
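
For reference, with the new flag in place, a GPU run invoked as python scripts/benchmark.py --samples 1,2 --autocast yes (flag names as defined above; the sample values are only an example) expands to the grid below, while the default --autocast no keeps the autocast entry at ("no",) so no extra autocast runs are scheduled:

# Illustrative expansion of the grid for a CUDA device with --autocast yes,
# following the grid construction in the __main__ block above.
grid = {
    "n_samples": (1, 2),
    "precision": ("single", "half"),
    "autocast": ("yes", "no"),
    "backend": ("pytorch",),
}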
