Skip to content

Commit d9859d1

Browse files
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent 4c71f8c commit d9859d1

File tree

1 file changed

+2
-6
lines changed

1 file changed

+2
-6
lines changed

docs/examples/quickstart_jax_utils.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -38,17 +38,13 @@ def speedometer(
3838
key = dropout_key
3939
for _ in range(warmup_iters):
4040
key, step_key = jax.random.split(key)
41-
loss, (param_grads, other_grads) = train_step_fn(
42-
variables, input, output_grad, step_key
43-
)
41+
loss, (param_grads, other_grads) = train_step_fn(variables, input, output_grad, step_key)
4442

4543
# Timing runs
4644
start = time.time()
4745
for _ in range(timing_iters):
4846
key, step_key = jax.random.split(key)
49-
loss, (param_grads, other_grads) = train_step_fn(
50-
variables, input, output_grad, step_key
51-
)
47+
loss, (param_grads, other_grads) = train_step_fn(variables, input, output_grad, step_key)
5248
end = time.time()
5349

5450
print(f"Mean time: {(end - start) * 1000 / timing_iters} ms")

0 commit comments

Comments
 (0)