Commit fa68be1: Internal change

PiperOrigin-RevId: 705003784
Change-Id: I311a441e696fff40348979dd77f83dd79efe5060
Brax Team authored and btaba committed Dec 11, 2024
1 parent c59da3f commit fa68be1
Showing 7 changed files with 57 additions and 18 deletions.
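In brief: each of the five agent trainers (APG, ARS, ES, PPO, SAC) gains an optional wrap_env_fn argument that, when supplied, takes precedence over the default choice between envs.training.wrap and envs_v1.wrappers.wrap_for_training; learner.py threads the new argument through from get_env_factory(); and pyproject.toml picks up isort and pyink formatter settings.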
9 changes: 6 additions & 3 deletions brax/training/agents/apg/train.py
@@ -60,6 +60,7 @@ def train(
     episode_length: int,
     policy_updates: int,
     wrap_env: bool = True,
+    wrap_env_fn: Optional[Callable[[Any], Any]] = None,
     horizon_length: int = 32,
     num_envs: int = 1,
     num_evals: int = 1,
@@ -117,7 +118,9 @@ def train(
 
   env = environment
   if wrap_env:
-    if isinstance(env, envs.Env):
+    if wrap_env_fn is not None:
+      wrap_for_training = wrap_env_fn
+    elif isinstance(env, envs.Env):
       wrap_for_training = envs.training.wrap
     else:
       wrap_for_training = envs_v1.wrappers.wrap_for_training
@@ -132,7 +135,7 @@ def train(
         episode_length=episode_length,
         action_repeat=action_repeat,
         randomization_fn=v_randomization_fn,
-    )
+    )  # pytype: disable=wrong-keyword-args
 
   reset_fn = jax.jit(jax.vmap(env.reset))
   step_fn = jax.jit(jax.vmap(env.step))
@@ -326,7 +329,7 @@ def training_epoch_with_timing(
         episode_length=episode_length,
         action_repeat=action_repeat,
         randomization_fn=v_randomization_fn,
-    )
+    )  # pytype: disable=wrong-keyword-args
 
   evaluator = acting.Evaluator(
       eval_env,
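The same three-line change recurs in every trainer below: a caller-supplied wrap_env_fn wins, and only otherwise does the isinstance check pick between the v2 and v1 wrappers. Note that wrap_env_fn is annotated Callable[[Any], Any] yet called with episode_length, action_repeat, and randomization_fn keywords, which is presumably why each call site also gains a pytype suppression. A minimal sketch of a compatible wrapper function, assuming the stock envs.training.wrap as the base (the function name and defaults are illustrative, not part of this commit):

    from brax import envs

    def my_wrap_env_fn(env, episode_length=1000, action_repeat=1,
                       randomization_fn=None):
      """Hypothetical custom wrapper: stock wrapping plus room for extras."""
      # Reuse the default training wrappers so episode termination and
      # action repeat behave exactly as before.
      env = envs.training.wrap(
          env,
          episode_length=episode_length,
          action_repeat=action_repeat,
          randomization_fn=randomization_fn,
      )
      # Custom wrapping (logging, observation transforms, ...) would go here.
      return env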
9 changes: 6 additions & 3 deletions brax/training/agents/ars/train.py
@@ -54,6 +54,7 @@ class TrainingState:
 def train(
     environment: Union[envs_v1.Env, envs.Env],
     wrap_env: bool = True,
+    wrap_env_fn: Optional[Callable[[Any], Any]] = None,
     num_timesteps: int = 100,
     episode_length: int = 1000,
     action_repeat: int = 1,
@@ -105,7 +106,9 @@ def train(
   assert num_envs % local_devices_to_use == 0
   env = environment
   if wrap_env:
-    if isinstance(env, envs.Env):
+    if wrap_env_fn is not None:
+      wrap_for_training = wrap_env_fn
+    elif isinstance(env, envs.Env):
       wrap_for_training = envs.training.wrap
     else:
       wrap_for_training = envs_v1.wrappers.wrap_for_training
@@ -121,7 +124,7 @@ def train(
         episode_length=episode_length,
         action_repeat=action_repeat,
         randomization_fn=v_randomization_fn,
-    )
+    )  # pytype: disable=wrong-keyword-args
 
   obs_size = env.observation_size
   if isinstance(obs_size, Dict):
@@ -335,7 +338,7 @@ def training_epoch_with_timing(
         episode_length=episode_length,
         action_repeat=action_repeat,
         randomization_fn=v_randomization_fn,
-    )
+    )  # pytype: disable=wrong-keyword-args
 
   # Evaluator function
   evaluator = acting.Evaluator(
9 changes: 6 additions & 3 deletions brax/training/agents/es/train.py
@@ -77,6 +77,7 @@ class FitnessShaping(enum.Enum):
 def train(
     environment: Union[envs_v1.Env, envs.Env],
     wrap_env: bool = True,
+    wrap_env_fn: Optional[Callable[[Any], Any]] = None,
     num_timesteps: int = 100,
     episode_length: int = 1000,
     action_repeat: int = 1,
@@ -133,7 +134,9 @@ def train(
   assert num_envs % local_devices_to_use == 0
   env = environment
   if wrap_env:
-    if isinstance(env, envs.Env):
+    if wrap_env_fn is not None:
+      wrap_for_training = wrap_env_fn
+    elif isinstance(env, envs.Env):
       wrap_for_training = envs.training.wrap
     else:
       wrap_for_training = envs_v1.wrappers.wrap_for_training
@@ -149,7 +152,7 @@ def train(
         episode_length=episode_length,
         action_repeat=action_repeat,
         randomization_fn=v_randomization_fn,
-    )
+    )  # pytype: disable=wrong-keyword-args
 
   obs_size = env.observation_size
   if isinstance(obs_size, Dict):
@@ -398,7 +401,7 @@ def training_epoch_with_timing(
         episode_length=episode_length,
         action_repeat=action_repeat,
         randomization_fn=v_randomization_fn,
-    )
+    )  # pytype: disable=wrong-keyword-args
 
   # Evaluator function
   evaluator = acting.Evaluator(
14 changes: 10 additions & 4 deletions brax/training/agents/ppo/train.py
@@ -19,7 +19,7 @@
 
 import functools
 import time
-from typing import Callable, Mapping, Optional, Tuple, Union
+from typing import Any, Callable, Mapping, Optional, Tuple, Union
 
 from absl import logging
 from brax import base
@@ -135,6 +135,7 @@ def train(
     num_timesteps: int,
     episode_length: int,
     wrap_env: bool = True,
+    wrap_env_fn: Optional[Callable[[Any], Any]] = None,
     action_repeat: int = 1,
     num_envs: int = 1,
     max_devices_per_host: Optional[int] = None,
@@ -177,6 +178,9 @@ def train(
     episode_length: the length of an environment episode
     wrap_env: If True, wrap the environment for training. Otherwise use the
       environment as is.
+    wrap_env_fn: a custom function that wraps the environment for training.
+      If not specified, the environment is wrapped with the default training
+      wrapper.
     action_repeat: the number of timesteps to repeat an action
     num_envs: the number of parallel environments to use for rollouts
       NOTE: `num_envs` must be divisible by the total number of chips since each
@@ -296,7 +300,9 @@ def train(
       v_randomization_fn = functools.partial(
           randomization_fn, rng=randomization_rng
       )
-    if isinstance(environment, envs.Env):
+    if wrap_env_fn is not None:
+      wrap_for_training = wrap_env_fn
+    elif isinstance(environment, envs.Env):
       wrap_for_training = envs.training.wrap
     else:
       wrap_for_training = envs_v1.wrappers.wrap_for_training
@@ -305,7 +311,7 @@ def train(
         episode_length=episode_length,
         action_repeat=action_repeat,
         randomization_fn=v_randomization_fn,
-    )
+    )  # pytype: disable=wrong-keyword-args
 
   reset_fn = jax.jit(jax.vmap(env.reset))
   key_envs = jax.random.split(key_env, num_envs // process_count)
@@ -561,7 +567,7 @@ def training_epoch_with_timing(
         episode_length=episode_length,
         action_repeat=action_repeat,
         randomization_fn=v_randomization_fn,
-    )
+    )  # pytype: disable=wrong-keyword-args
 
   evaluator = acting.Evaluator(
       eval_env,
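A hedged end-to-end sketch of the new argument in use (the environment name and hyperparameters are illustrative assumptions, not from this commit):

    from brax import envs
    from brax.training.agents.ppo import train as ppo

    make_policy, params, metrics = ppo.train(
        environment=envs.get_environment('ant'),
        wrap_env_fn=my_wrap_env_fn,  # the hypothetical wrapper sketched earlier
        num_timesteps=1_000_000,
        episode_length=1000,
        num_envs=128,
    )

With the default wrap_env_fn=None, behavior is unchanged: the trainer falls back to envs.training.wrap or envs_v1.wrappers.wrap_for_training depending on the environment type.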
9 changes: 6 additions & 3 deletions brax/training/agents/sac/train.py
@@ -116,6 +116,7 @@ def train(
     num_timesteps,
     episode_length: int,
     wrap_env: bool = True,
+    wrap_env_fn: Optional[Callable[[Any], Any]] = None,
     action_repeat: int = 1,
     num_envs: int = 1,
     num_eval_envs: int = 128,
@@ -181,7 +182,9 @@ def train(
   assert num_envs % device_count == 0
   env = environment
   if wrap_env:
-    if isinstance(env, envs.Env):
+    if wrap_env_fn is not None:
+      wrap_for_training = wrap_env_fn
+    elif isinstance(env, envs.Env):
       wrap_for_training = envs.training.wrap
     else:
       wrap_for_training = envs_v1.wrappers.wrap_for_training
@@ -201,7 +204,7 @@ def train(
         episode_length=episode_length,
         action_repeat=action_repeat,
         randomization_fn=v_randomization_fn,
-    )
+    )  # pytype: disable=wrong-keyword-args
 
   obs_size = env.observation_size
   if isinstance(obs_size, Dict):
@@ -508,7 +511,7 @@ def training_epoch_with_timing(
         episode_length=episode_length,
         action_repeat=action_repeat,
         randomization_fn=v_randomization_fn,
-    )
+    )  # pytype: disable=wrong-keyword-args
 
   evaluator = acting.Evaluator(
       eval_env,
10 changes: 8 additions & 2 deletions brax/training/learner.py
@@ -183,16 +183,17 @@ def get_env_factory():
   if _CUSTOM_WRAP_ENV.value:
     pass
   else:
+    wrap_env_fn = None
     get_environment = functools.partial(
         envs.get_environment, backend=_BACKEND.value
     )
-  return get_environment
+  return get_environment, wrap_env_fn
 
 
 def main(unused_argv):
   logdir = _LOGDIR.value
 
-  get_environment = get_env_factory()
+  get_environment, wrap_env_fn = get_env_factory()
   with metrics.Writer(logdir) as writer:
     writer.write_hparams({
         'num_evals': _NUM_EVALS.value,
@@ -207,6 +208,7 @@ def main(unused_argv):
     )
     make_policy, params, _ = sac.train(
         environment=get_environment(_ENV.value),
+        wrap_env_fn=wrap_env_fn,
         num_envs=_NUM_ENVS.value,
         action_repeat=_ACTION_REPEAT.value,
         normalize_observations=_NORMALIZE_OBSERVATIONS.value,
@@ -228,6 +230,7 @@ def main(unused_argv):
   elif _LEARNER.value == 'es':
     make_policy, params, _ = es.train(
         environment=get_environment(_ENV.value),
+        wrap_env_fn=wrap_env_fn,
         num_timesteps=_TOTAL_ENV_STEPS.value,
         fitness_shaping=es.FitnessShaping[_FITNESS_SHAPING.value.upper()],
         population_size=_POPULATION_SIZE.value,
@@ -255,6 +258,7 @@ def main(unused_argv):
     )
     make_policy, params, _ = ppo.train(
         environment=get_environment(_ENV.value),
+        wrap_env_fn=wrap_env_fn,
         num_timesteps=_TOTAL_ENV_STEPS.value,
         episode_length=_EPISODE_LENGTH.value,
         network_factory=network_factory,
@@ -280,6 +284,7 @@ def main(unused_argv):
   elif _LEARNER.value == 'apg':
     make_policy, params, _ = apg.train(
         environment=get_environment(_ENV.value),
+        wrap_env_fn=wrap_env_fn,
         policy_updates=_POLICY_UPDATES.value,
         num_envs=_NUM_ENVS.value,
         action_repeat=_ACTION_REPEAT.value,
@@ -295,6 +300,7 @@ def main(unused_argv):
   elif _LEARNER.value == 'ars':
     make_policy, params, _ = ars.train(
         environment=get_environment(_ENV.value),
+        wrap_env_fn=wrap_env_fn,
         number_of_directions=_NUMBER_OF_DIRECTIONS.value,
         max_devices_per_host=_MAX_DEVICES_PER_HOST.value,
         action_repeat=_ACTION_REPEAT.value,
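In this commit the _CUSTOM_WRAP_ENV branch is a bare pass, so only the default path (wrap_env_fn = None) is wired up here. Purely as a hypothetical sketch of how a fork of learner.py could populate that branch (nothing below is part of this commit):

    def get_env_factory():
      """Returns an environment factory and an optional custom wrapper."""
      if _CUSTOM_WRAP_ENV.value:
        wrap_env_fn = my_wrap_env_fn  # hypothetical wrapper from earlier
      else:
        wrap_env_fn = None
      get_environment = functools.partial(
          envs.get_environment, backend=_BACKEND.value
      )
      return get_environment, wrap_env_fn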
15 changes: 15 additions & 0 deletions pyproject.toml
@@ -1,5 +1,20 @@
+[tool.isort]
+force_single_line = true
+force_sort_within_sections = true
+lexicographical = true
+single_line_exclusions = ["typing"]
+order_by_type = false
+group_by_package = true
+line_length = 120
+use_parentheses = true
+multi_line_output = 3
+skip_glob = ["**/*.ipynb"]
+
 [tool.pyink]
 line-length = 80
 unstable = true
 pyink-indentation = 2
 pyink-use-majority-quotes = true
+extend-exclude = '''(
+  .ipynb$
+)'''
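The isort block enforces one import per line, sorted lexicographically within sections, with typing exempted; pyink, Google's fork of Black, is pinned to 80-column lines and 2-space indentation, with notebooks excluded from both tools. A small illustration of what force_single_line plus single_line_exclusions = ["typing"] produces (the module names are illustrative):

    # Before isort, combined imports:
    #   from brax import envs, math
    #   from typing import Any, Callable
    # After isort with this configuration:
    from brax import envs
    from brax import math
    from typing import Any, Callable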
