Skip to content

Commit 3abb55c

Browse files
committed
add support test script for jax install
1 parent b1b1826 commit 3abb55c

File tree

1 file changed

+21
-0
lines changed

1 file changed

+21
-0
lines changed

scripts/test-jax-install.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
import jax
2+
import jax.numpy as jnp
3+
4+
devices = jax.devices()
5+
print(f"The available devices are: {devices}")
6+
7+
@jax.jit
8+
def matrix_multiply(a, b):
9+
return jnp.dot(a, b)
10+
11+
# Example usage:
12+
key = jax.random.PRNGKey(0)
13+
x = jax.random.normal(key, (1000, 1000))
14+
y = jax.random.normal(key, (1000, 1000))
15+
z = matrix_multiply(x, y)
16+
17+
# Now the function is JIT compiled and will likely run on GPU (if available)
18+
print(z)
19+
20+
devices = jax.devices()
21+
print(f"The available devices are: {devices}")

0 commit comments

Comments
 (0)