Skip to content

Commit

Permalink
Internal change
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 700718610
Change-Id: I9044c43a9d02974269a1cb84d70869b529dd13f5
  • Loading branch information
Brax Team authored and btaba committed Nov 27, 2024
1 parent 300b107 commit 9ede872
Show file tree
Hide file tree
Showing 4 changed files with 251 additions and 181 deletions.
2 changes: 1 addition & 1 deletion brax/io/json.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ def dumps(sys: System, states: List[State]) -> Text:
for id_ in range(sys.ngeom):
link_idx = sys.geom_bodyid[id_] - 1

rgba = sys.geom_rgba[id_]
rgba = sys.mj_model.geom_rgba[id_]
if (rgba == [0.5, 0.5, 0.5, 1.0]).all():
# convert the default mjcf color to brax default color
rgba = np.array([0.4, 0.33, 0.26, 1.0])
Expand Down
5 changes: 2 additions & 3 deletions brax/training/agents/ppo/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,14 +243,13 @@ def train(
preprocess_observations_fn=normalize)
make_policy = ppo_networks.make_inference_fn(ppo_network)

optimizer = optax.adam(learning_rate=learning_rate)
if max_grad_norm is not None:
# TODO(btaba): Move gradient clipping to `training/gradients.py`.
# TODO: Move gradient clipping to `training/gradients.py`.
optimizer = optax.chain(
optax.clip_by_global_norm(max_grad_norm),
optax.adam(learning_rate=learning_rate)
)
else:
optimizer = optax.adam(learning_rate=learning_rate)

loss_fn = functools.partial(
ppo_losses.compute_ppo_loss,
Expand Down
Loading

0 comments on commit 9ede872

Please sign in to comment.