Open
Labels
torch.compile (Torch compile and other relevant tutorials)
Description
In the “Demonstrating Speedups” section of the torch.compile tutorial:
https://docs.pytorch.org/tutorials/intermediate/torch_compile_tutorial.html#demonstrating-speedups
the `timed` helper is documented as returning the elapsed time in seconds, but its implementation divides by 1024 instead of 1000:
```python
import torch

# Returns the result of running `fn()` and the time it took for `fn()` to run,
# in seconds. We use CUDA events and synchronization for the most accurate
# measurements.
def timed(fn):
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    result = fn()
    end.record()
    torch.cuda.synchronize()
    return result, start.elapsed_time(end) / 1024
```

However, `torch.cuda.Event.elapsed_time` returns the elapsed time in milliseconds, so converting to seconds requires dividing by 1000. The last line should be:

```python
return result, start.elapsed_time(end) / 1000
```
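To show the size of the error, here is a minimal pure-Python sketch that does not need a GPU. It assumes a hypothetical measurement of 1500 ms, a value `elapsed_time()` might return. The helper names `ms_to_seconds` and `buggy_ms_to_seconds` are illustrative only and are not part of the tutorial or of PyTorch.

```python
# Hypothetical helpers for illustration; not part of the tutorial or PyTorch.
MS_PER_SECOND = 1000


def ms_to_seconds(elapsed_ms: float) -> float:
    """Correct conversion: torch.cuda.Event.elapsed_time reports milliseconds."""
    return elapsed_ms / MS_PER_SECOND


def buggy_ms_to_seconds(elapsed_ms: float) -> float:
    """The conversion currently used in the tutorial's `timed` helper."""
    return elapsed_ms / 1024


elapsed_ms = 1500.0  # assumed example measurement, in milliseconds
correct = ms_to_seconds(elapsed_ms)          # 1.5 s
reported = buggy_ms_to_seconds(elapsed_ms)   # 1.46484375 s
underreport = 1 - reported / correct         # fraction by which the bug shrinks the time
print(correct, reported, underreport)
```

Every value reported by the current helper is scaled by 1000/1024, so absolute times in seconds come out about 2.3% too low. Ratios of two such timings, such as an eager-versus-compiled speedup, are unaffected because the same factor cancels out.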