Merge pull request #7 from LambdaLabsML/eole/fp16-v-autocast
fp16 v autocast
eolecvk authored Oct 11, 2022
2 parents fe172cb + cc801b8 commit 4c74b72
Showing 3 changed files with 108 additions and 58 deletions.
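
The diff below splits the old precision-driven autocast behaviour into an explicit use_autocast flag, so half-precision runs can be benchmarked both with and without the autocast context manager. For background, here is a minimal sketch (not part of this commit) of the two half-precision strategies being compared; the linear layer, tensor shapes, and variable names are placeholders rather than anything from scripts/benchmark.py, and a CUDA device is assumed.

import torch

x = torch.randn(8, 1024, device="cuda")

# Strategy 1, plain fp16: cast the weights to half and run directly.
model_fp16 = torch.nn.Linear(1024, 1024).cuda().half()
y_fp16 = model_fp16(x.half())

# Strategy 2, autocast: wrap the forward pass and let autocast choose
# per-op dtypes. As far as this diff shows, the benchmark applies the
# context on top of the fp16 pipeline (use_autocast is only set when
# precision == "half").
with torch.autocast("cuda"):
    y_autocast = model_fp16(x.half())

print(y_fp16.dtype, y_autocast.dtype)  # torch.float16 in both cases
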
2 changes: 2 additions & 0 deletions .gitignore
@@ -163,3 +163,5 @@ cython_debug/
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/

.vscode/
3 changes: 0 additions & 3 deletions .vscode/settings.json

This file was deleted.

161 changes: 106 additions & 55 deletions scripts/benchmark.py
@@ -13,6 +13,7 @@

prompt = "a photo of an astronaut riding a horse on mars"


def get_inference_pipeline(precision, backend):
"""
returns HuggingFace diffuser pipeline
@@ -31,12 +32,14 @@ def get_inference_pipeline(precision, backend):
torch_dtype=torch.float32 if precision == "single" else torch.float16,
)
pipe = pipe.to(device)
else:
else:
pipe = StableDiffusionOnnxPipeline.from_pretrained(
"CompVis/stable-diffusion-v1-4",
use_auth_token=os.environ["ACCESS_TOKEN"],
revision="onnx",
provider="CPUExecutionProvider" if device.type=="cpu" else "CUDAExecutionProvider",
provider="CPUExecutionProvider"
if device.type == "cpu"
else "CUDAExecutionProvider",
torch_dtype=torch.float32 if precision == "single" else torch.float16,
)

@@ -51,43 +54,59 @@ def null_safety(images, **kwargs):
return pipe


def do_inference(pipe, n_samples, precision, num_inference_steps):
def do_inference(pipe, n_samples, use_autocast, num_inference_steps):
torch.cuda.empty_cache()
context = autocast if (device.type == "cuda" and precision == 'half') else nullcontext
context = (
autocast if (device.type == "cuda" and use_autocast) else nullcontext
)
with context("cuda"):
images = pipe(prompt=[prompt] * n_samples, num_inference_steps=num_inference_steps).images
images = pipe(
prompt=[prompt] * n_samples, num_inference_steps=num_inference_steps
).images

return images


def get_inference_time(pipe, n_samples, n_repeats, precision, num_inference_steps):
def get_inference_time(
pipe, n_samples, n_repeats, use_autocast, num_inference_steps
):
from torch.utils.benchmark import Timer

timer = Timer(
stmt="do_inference(pipe, n_samples, precision, num_inference_steps)",
stmt="do_inference(pipe, n_samples, use_autocast, num_inference_steps)",
setup="from __main__ import do_inference",
globals={"pipe": pipe, "n_samples": n_samples, "precision": precision, "num_inference_steps": num_inference_steps},
num_threads=multiprocessing.cpu_count()
globals={
"pipe": pipe,
"n_samples": n_samples,
"use_autocast": use_autocast,
"num_inference_steps": num_inference_steps,
},
num_threads=multiprocessing.cpu_count(),
)
profile_result = timer.timeit(
n_repeats
) # benchmark.Timer performs 2 iterations for warmup
return round(profile_result.mean, 2)


def get_inference_memory(pipe, n_samples, precision, num_inference_steps):
def get_inference_memory(pipe, n_samples, use_autocast, num_inference_steps):
if not torch.cuda.is_available():
return 0

torch.cuda.empty_cache()
context = autocast if (device.type == "cuda" and precision == 'half') else nullcontext
context = autocast if (device.type == "cuda" and use_autocast) else nullcontext
with context("cuda"):
images = pipe(prompt=[prompt] * n_samples, num_inference_steps=num_inference_steps).images
images = pipe(
prompt=[prompt] * n_samples, num_inference_steps=num_inference_steps
).images

mem = torch.cuda.memory_reserved()
return round(mem / 1e9, 2)


def run_benchmark(n_repeats, n_samples, precision, backend, num_inference_steps):
def run_benchmark(
n_repeats, n_samples, precision, use_autocast, backend, num_inference_steps
):
"""
* n_repeats: nb datapoints for inference latency benchmark
* n_samples: number of samples to generate (~ batch size)
@@ -100,10 +119,16 @@ def run_benchmark(n_repeats, n_samples, precision, backend, num_inference_steps)
pipe = get_inference_pipeline(precision, backend)

logs = {
"memory": 0.00 if device.type=="cpu" else get_inference_memory(pipe, n_samples, precision, num_inference_steps),
"latency": get_inference_time(pipe, n_samples, n_repeats, precision, num_inference_steps),
"memory": 0.00
if device.type == "cpu"
else get_inference_memory(
pipe, n_samples, use_autocast, num_inference_steps
),
"latency": get_inference_time(
pipe, n_samples, n_repeats, use_autocast, num_inference_steps
),
}
print(f"n_samples: {n_samples}\tprecision: {precision}\tbackend: {backend}")
print(f"n_samples: {n_samples}\tprecision: {precision}\tautocast: {use_autocast}\tbackend: {backend}")
print(logs, "\n")
return logs

@@ -115,9 +140,8 @@ def get_device_description():
"""
if device.type == "cpu":
name = subprocess.check_output(
"grep -m 1 'model name' /proc/cpuinfo",
shell=True
).decode("utf-8")
"grep -m 1 'model name' /proc/cpuinfo", shell=True
).decode("utf-8")
name = " ".join(name.split(" ")[2:]).strip()
return name
else:
@@ -130,14 +154,23 @@ def run_benchmark_grid(grid, n_repeats, num_inference_steps):
{
"n_samples": (1, 2),
"precision": ("single", "half"),
"autocast" : ("yes", "no")
}
* n_repeats: nb datapoints for inference latency benchmark
"""

csv_fpath = pathlib.Path(__file__).parent.parent / "benchmark_tmp.csv"
# create benchmark.csv if not exists
if not os.path.isfile(csv_fpath):
header = ["device", "precision", "runtime", "n_samples", "latency", "memory"]
header = [
"device",
"precision",
"autocast",
"runtime",
"n_samples",
"latency",
"memory",
]
with open(csv_fpath, "w") as f:
writer = csv.writer(f)
writer.writerow(header)
@@ -148,48 +181,58 @@ def run_benchmark_grid(grid, n_repeats, num_inference_steps):
device_desc = get_device_description()
for n_samples in grid["n_samples"]:
for precision in grid["precision"]:
for backend in grid["backend"]:
try:
new_log = run_benchmark(
n_repeats=n_repeats,
n_samples=n_samples,
precision=precision,
backend=backend,
num_inference_steps=num_inference_steps
)
except Exception as e:
if "CUDA out of memory" in str(e) or "Failed to allocate memory" in str(e):
print(str(e))
torch.cuda.empty_cache()
new_log = {
"latency": -1.00,
"memory": -1.00
}
else:
raise e

latency = new_log["latency"]
memory = new_log["memory"]
new_row = [device_desc, precision, backend, n_samples, latency, memory]
writer.writerow(new_row)
use_autocast = False
if precision == "half":
for autocast in grid["autocast"]:
if autocast == "yes":
use_autocast = True
for backend in grid["backend"]:
try:
new_log = run_benchmark(
n_repeats=n_repeats,
n_samples=n_samples,
precision=precision,
use_autocast=use_autocast,
backend=backend,
num_inference_steps=num_inference_steps,
)
except Exception as e:
if "CUDA out of memory" in str(
e
) or "Failed to allocate memory" in str(e):
print(str(e))
torch.cuda.empty_cache()
new_log = {"latency": -1.00, "memory": -1.00}
else:
raise e

latency = new_log["latency"]
memory = new_log["memory"]
new_row = [
device_desc,
precision,
autocast,
backend,
n_samples,
latency,
memory,
]
writer.writerow(new_row)


if __name__ == "__main__":

parser = argparse.ArgumentParser()

parser.add_argument(
"--samples",
"--samples",
default="1",
type=str,
help="Comma sepearated list of batch sizes (number of samples)"
type=str,
help="Comma sepearated list of batch sizes (number of samples)",
)

parser.add_argument(
"--steps",
default=50,
type=int,
help="Number of diffusion steps."
"--steps", default=50, type=int, help="Number of diffusion steps."
)

parser.add_argument(
@@ -199,17 +242,25 @@ def run_benchmark_grid(grid, n_repeats, num_inference_steps):
help="Number of repeats.",
)

parser.add_argument(
"--autocast",
default="no",
type=str,
help="If 'yes', will perform additional runs with autocast activated for half precision inferences",
)

args = parser.parse_args()

grid = {
"n_samples": tuple(map(int, args.samples.split(","))),
# Only use single-precision for cpu because "LayerNormKernelImpl" not implemented for 'Half' on cpu,
"n_samples": tuple(map(int, args.samples.split(","))),
# Only use single-precision for cpu because "LayerNormKernelImpl" not implemented for 'Half' on cpu,
# Remove autocast won't help. Ref:
# https://github.com/CompVis/stable-diffusion/issues/307
"precision": ("single",) if device.type == "cpu" else ("single", "half"),
"autocast": ("no",) if args.autocast == "no" else ("yes", "no"),
# Only use onnx for cpu, until issues are fixed by upstreams. Ref:
# https://github.com/huggingface/diffusers/issues/489#issuecomment-1261577250
# https://github.com/huggingface/diffusers/pull/440
"backend": ("pytorch", "onnx") if device.type == "cpu" else ("pytorch",)
"backend": ("pytorch", "onnx") if device.type == "cpu" else ("pytorch",),
}
run_benchmark_grid(grid, n_repeats=args.repeats, num_inference_steps=args.steps)
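
For reference (not part of the commit), the latency figures produced by get_inference_time come from torch.utils.benchmark.Timer. A stripped-down sketch of that measurement pattern, with a placeholder do_work function standing in for the pipeline call, looks roughly like this:

import multiprocessing
import torch
from torch.utils.benchmark import Timer

def do_work(n):
    # Placeholder for the pipeline invocation being timed.
    return torch.randn(n, 1024).softmax(dim=-1)

timer = Timer(
    stmt="do_work(n)",
    setup="from __main__ import do_work",
    globals={"n": 8},
    num_threads=multiprocessing.cpu_count(),
)
measurement = timer.timeit(3)      # Timer warms up before the timed runs
print(round(measurement.mean, 2))  # mean seconds per call, as logged above

With the --autocast flag added here, an invocation along the lines of ACCESS_TOKEN=<token> python scripts/benchmark.py --samples 1,2 --steps 50 --autocast yes (flag names taken from the argparse calls above; the repeats flag is folded out of the diff) would write one row per benchmarked configuration to benchmark_tmp.csv.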
