diff --git a/alf/algorithms/merlin_algorithm.py b/alf/algorithms/merlin_algorithm.py
index 4f3492442..d121b11e2 100644
--- a/alf/algorithms/merlin_algorithm.py
+++ b/alf/algorithms/merlin_algorithm.py
@@ -543,6 +543,8 @@ def __init__(self,
             enc_layers.append(res_block)
             in_channels = 64
 
+        if output_activation is None:
+            output_activation = alf.math.identity
         enc_layers.extend([
             nn.Flatten(),
             alf.layers.FC(
diff --git a/alf/config_util.py b/alf/config_util.py
index 6eccaf5fe..267f58986 100644
--- a/alf/config_util.py
+++ b/alf/config_util.py
@@ -337,7 +337,11 @@ def pre_config(configs):
         try:
             config1(name, value, mutable=False)
             _HANDLED_PRE_CONFIGS.append((name, value))
-        except ValueError:
+        except ValueError as e:
+            # Most of the time, for command line flags, this warning is a false alarm.
+            # It can be useful for other failures, e.g. when the Config has
+            # already been used before its value is configured.
+            logging.warning("pre_config potential error: %s", e)
             _PRE_CONFIGS.append((name, value))
 
 
diff --git a/alf/data_structures.py b/alf/data_structures.py
index 9cd820990..7f72b6c55 100644
--- a/alf/data_structures.py
+++ b/alf/data_structures.py
@@ -279,7 +279,8 @@ def _generate_time_step(batched,
     if env_id is None:
         env_id = md.arange(batch_size, dtype=md.int32)
     if reward is not None:
-        assert reward.shape[:1] == outer_dims
+        assert reward.shape[:1] == outer_dims, "%s, %s" % (reward.shape,
+                                                           outer_dims)
     if prev_action is not None:
         flat_action = nest.flatten(prev_action)
         assert flat_action[0].shape[:1] == outer_dims
diff --git a/alf/networks/critic_networks.py b/alf/networks/critic_networks.py
index 72a910cc2..6978188a4 100644
--- a/alf/networks/critic_networks.py
+++ b/alf/networks/critic_networks.py
@@ -77,6 +77,7 @@ def __init__(self,
                  joint_fc_layer_params=None,
                  activation=torch.relu_,
                  kernel_initializer=None,
+                 last_bias_init_value=0.0,
                  use_fc_bn=False,
                  use_naive_parallel_network=False,
                  name="CriticNetwork"):
@@ -174,7 +175,8 @@ def __init__(self,
             last_activation=math_ops.identity,
             use_fc_bn=use_fc_bn,
             last_kernel_initializer=last_kernel_initializer,
-            name=name)
+            last_bias_init_value=last_bias_init_value,
+            name=name + ".joint_encoder")
         self._use_naive_parallel_network = use_naive_parallel_network
 
     def make_parallel(self, n):
diff --git a/alf/networks/encoding_networks.py b/alf/networks/encoding_networks.py
index 479401cc0..42329a7b4 100644
--- a/alf/networks/encoding_networks.py
+++ b/alf/networks/encoding_networks.py
@@ -405,6 +405,7 @@ def __init__(self,
                  last_layer_size=None,
                  last_activation=None,
                  last_kernel_initializer=None,
+                 last_bias_init_value=0.0,
                  last_use_fc_bn=False,
                  name="EncodingNetwork"):
         """
@@ -540,7 +541,8 @@ def __init__(self,
                     last_layer_size,
                     activation=last_activation,
                     use_bn=last_use_fc_bn,
-                    kernel_initializer=last_kernel_initializer))
+                    kernel_initializer=last_kernel_initializer,
+                    bias_init_value=last_bias_init_value))
             input_size = last_layer_size
 
         if output_tensor_spec is not None:
diff --git a/alf/trainers/policy_trainer.py b/alf/trainers/policy_trainer.py
index 8a9e078ac..2e9c27dc7 100644
--- a/alf/trainers/policy_trainer.py
+++ b/alf/trainers/policy_trainer.py
@@ -498,6 +498,7 @@ def __init__(self, config: TrainerConfig, ddp_rank: int = -1):
         logging.info(
             "observation_spec=%s" % pprint.pformat(env.observation_spec()))
         logging.info("action_spec=%s" % pprint.pformat(env.action_spec()))
+        logging.info("reward_spec=%s" % pprint.pformat(env.reward_spec()))
 
         # for offline buffer construction
         untransformed_observation_spec = env.observation_spec()
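Note: the new last_bias_init_value argument above threads from CriticNetwork
through EncodingNetwork into the final alf.layers.FC's bias_init_value. A
minimal usage sketch follows; the specs and layer sizes are illustrative
only and not part of this patch:

    import torch

    from alf.networks.critic_networks import CriticNetwork
    from alf.tensor_specs import BoundedTensorSpec, TensorSpec

    # Illustrative specs: a 4-dim observation and a 2-dim continuous action.
    obs_spec = TensorSpec((4, ))
    action_spec = BoundedTensorSpec((2, ), minimum=-1.0, maximum=1.0)

    # Start the critic's last FC layer with a nonzero bias, e.g. for
    # optimistic initial Q estimates; the weights still use the default
    # last_kernel_initializer.
    critic_net = CriticNetwork(
        input_tensor_spec=(obs_spec, action_spec),
        joint_fc_layer_params=(64, 64),
        last_bias_init_value=1.0)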
diff --git a/alf/utils/external_configurables.py b/alf/utils/external_configurables.py
index 5a00b7e2e..e9e6cda84 100644
--- a/alf/utils/external_configurables.py
+++ b/alf/utils/external_configurables.py
@@ -46,3 +46,5 @@
 
 gin.external_configurable(torch.nn.init.xavier_normal_,
                           'torch.nn.init.xavier_normal_')
+gin.external_configurable(torch.nn.Embedding, 'torch.nn.Embedding')
+gin.external_configurable(torch.nn.Sequential, 'torch.nn.Sequential')
diff --git a/alf/utils/normalizers.py b/alf/utils/normalizers.py
index 9ef8e2ca0..9f01e7560 100644
--- a/alf/utils/normalizers.py
+++ b/alf/utils/normalizers.py
@@ -138,7 +138,11 @@ def _summary(name, val):
         def _summarize_all(path, t, m2, m):
             if path:
                 path += "."
-            spec = TensorSpec.from_tensor(m if m2 is None else m2)
+            if m2 is not None:
+                spec = TensorSpec.from_tensor(m2)
+            else:
+                assert m is not None
+                spec = TensorSpec.from_tensor(m)
             _summary(path + "tensor.batch_min",
                      _reduce_along_batch_dims(t, spec, torch.min))
             _summary(path + "tensor.batch_max",
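Note: the two gin.external_configurable registrations make torch.nn.Embedding
and torch.nn.Sequential bindable from gin config files, and the
_summarize_all change replaces the one-line conditional with an explicit
branch so that a missing mean fails at a clear assert rather than deep
inside TensorSpec.from_tensor. A standalone sketch of the same selection
logic; _pick_moment_spec is a hypothetical name, not part of this patch:

    import torch

    from alf.tensor_specs import TensorSpec

    def _pick_moment_spec(m2, m):
        # Prefer the second moment m2 when available; otherwise fall back
        # to the mean m, asserting that at least one of them exists.
        if m2 is not None:
            return TensorSpec.from_tensor(m2)
        assert m is not None
        return TensorSpec.from_tensor(m)

    # Example: a normalizer that tracks only a running mean.
    spec = _pick_moment_spec(None, torch.zeros(8))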