Skip to content

[test] Foriloop #2

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 19 commits into
base: main
Choose a base branch
from
119 changes: 100 additions & 19 deletions examples/text_to_image/inference_tpu_single_device.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ def parser(args):
parser.add_argument(
'--batch-size',
type=int,
default=8,
default=2, # 8,
help='Number of images to generate'
)

Expand All @@ -26,7 +26,7 @@ def parser(args):
parser.add_argument(
'--inf-steps',
type=int,
default=30,
default=2, # 30,
help='Number of itterations to run the benchmark.'
)

Expand All @@ -35,33 +35,114 @@ def parser(args):

def main(args):
server = xp.start_server(9012)
pipe = DiffusionPipeline.from_pretrained(
"stabilityai/stable-diffusion-xl-base-0.9",
use_safetensors=True,
)
# pipe = DiffusionPipeline.from_pretrained(
# "stabilityai/stable-diffusion-xl-base-0.9",
# use_safetensors=True,
# )
device = xm.xla_device()
pipe.to(device)
# pipe.to(device)

bs = args.batch_size
inference_steps = args.inf_steps
height = width = args.width
bs = args.batch_size # 1
inference_steps = args.inf_steps # 2
height = width = args.width # 512

prompts = ["a photo of an astronaut riding a horse on mars"] * bs
print(f'batch size = {bs}, inference steps = {inference_steps}',
f'height = width = {width}',
flush=True
)

pipe = DiffusionPipeline.from_pretrained(
"stabilityai/stable-diffusion-xl-base-0.9",
use_safetensors=True,
)
pipe.to(device)

iters = 15
print('starting inference', flush=True)
start2 = time()
iters = 3
for i in range(iters):
start = time()
image = pipe(prompts,
num_inference_steps=inference_steps,
height=height,
width=width,
).images[0]
print(f'Step {i} inference time {time()-start} sec', flush=True)
# pipe2 = DiffusionPipeline.from_pretrained(
# "stabilityai/stable-diffusion-xl-base-0.9",
# use_safetensors=True,
# )
# pipe2.to(device)
image = pipe(["a photo of an astronaut riding a horse on mars"], # prompts,
num_inference_steps=2, # inference_steps,
height=512, # height,
width=512, # width,
).images[0]
print(f'Call pipeline without _xla_while_loop for three times used {time()-start2} sec', flush=True)


import torch
import torch_xla.experimental.fori_loop
from torch._higher_order_ops.while_loop import while_loop
def cond_fn(init, limit_value):
return limit_value[0] <= init[0]

def body_fn(init, limit_value):
# one_value = torch.ones(1, dtype=torch.int32, device=device)
# two_value = limit_value.clone()
# start = time()
# pipe = DiffusionPipeline.from_pretrained(
# "stabilityai/stable-diffusion-xl-base-0.9",
# use_safetensors=True,
# )
# # device = xm.xla_device()
# pipe.to(device)
image = pipe(["a photo of an astronaut riding a horse on mars"], # prompts,
num_inference_steps=2, # inference_steps,
height=512, # height,
width=512, # width,
).images[0]
# image = pipe(["a photo of an astronaut riding a horse on mars"], # prompts,
# num_inference_steps=2, # inference_steps,
# height=512, # height,
# width=512, # width,
# ).images[0]
# print("type of image: ", type(image))
# print(f'Step {i} inference time {time()-start} sec', flush=True)
one_value = torch.ones(1, dtype=torch.int32, device=device)
two_value = limit_value.clone()
return (torch.sub(init, one_value), two_value)

start = time()
# iters = 3
init = torch.tensor([3], dtype=torch.int32, device=device)
limit_value = torch.tensor([0], dtype=torch.int32, device=device)
# res = while_loop(cond_fn, body_fn, (init, limit_value))
from torch_xla.experimental.fori_loop import _xla_while_loop
res = _xla_while_loop(cond_fn, body_fn, (init, limit_value))
print(f'Call pipeline with _xla_while_loop for three times used {time()-start} sec', flush=True)
print("result of while_loop: ", res)
# expected = _fake_while_loop(cond_fn, body_fn, (init, limit_value))
# self.assertEqual(expected, res)

# start2 = time()
# iters = 3
# for i in range(iters):
# # pipe2 = DiffusionPipeline.from_pretrained(
# # "stabilityai/stable-diffusion-xl-base-0.9",
# # use_safetensors=True,
# # )
# pipe2.to(device)
# image2 = pipe2(["a photo of an astronaut riding a horse on mars"], # prompts,
# num_inference_steps=2, # inference_steps,
# height=512, # height,
# width=512, # width,
# ).images[0]
# print(f'Call pipeline without _xla_while_loop for three times used {time()-start2} sec', flush=True)

# iters = 1 # 15
# print('starting inference', flush=True)
# for i in range(iters):
# start = time()
# image = pipe(prompts,
# num_inference_steps=inference_steps,
# height=height,
# width=width,
# ).images[0]
# print(f'Step {i} inference time {time()-start} sec', flush=True)


if __name__ == '__main__':
Expand Down