Skip to content

Commit 37a71d6

Browse files
Jake VanderPlascopybara-github
Jake VanderPlas
authored andcommitted
Fix issues related to new behavior of JAX DeviceArray.copy()
In jax-ml/jax#10069, JAX changes the behavior of DeviceArray.copy() so that it returns a DeviceArray rather than returning a numpy array. For converting a DeviceArray to numpy, the preferred method is now np.asarray(device_array). PiperOrigin-RevId: 438711926
1 parent 83252f9 commit 37a71d6

File tree

1 file changed

+1
-0
lines changed

1 file changed

+1
-0
lines changed

trax/rl/task.py

+1
Original file line numberDiff line numberDiff line change
@@ -271,6 +271,7 @@ def play(env, policy, dm_suite=False, max_steps=None, last_observation=None):
271271
cur_trajectory = Trajectory(last_observation)
272272
while not done and (max_steps is None or cur_step < max_steps):
273273
action, dist_inputs = policy(cur_trajectory)
274+
action = np.asarray(action)
274275
step = env.step(action)
275276
if dm_suite:
276277
(observation, reward, done) = (

0 commit comments

Comments
 (0)