diff --git a/alf/algorithms/ddpg_algorithm.py b/alf/algorithms/ddpg_algorithm.py index 929246efc..681d59b93 100644 --- a/alf/algorithms/ddpg_algorithm.py +++ b/alf/algorithms/ddpg_algorithm.py @@ -42,11 +42,18 @@ DdpgActorState = namedtuple( "DdpgActorState", ['actor', 'critics'], default_value=()) DdpgState = namedtuple( - "DdpgState", ['actor', 'critics', 'noise'], default_value=()) + "DdpgState", ['actor', 'critics', 'noise', 'ensemble_ids'], + default_value=()) DdpgInfo = namedtuple( "DdpgInfo", [ - "reward", "step_type", "discount", "action", "action_distribution", - "actor_loss", "critic", "discounted_return" + "reward", + "step_type", + "discount", + "action", + "action_distribution", + "actor_loss", + "critic", + "discounted_return", ], default_value=()) DdpgLossInfo = namedtuple('DdpgLossInfo', ('actor', 'critic')) @@ -74,13 +81,18 @@ def __init__(self, config: TrainerConfig = None, ou_stddev=0.2, ou_damping=0.15, + noise_clipping=None, critic_loss_ctor=None, num_critic_replicas=1, target_update_tau=0.05, target_update_period=1, + actor_update_period=1, rollout_random_action=0., dqda_clipping=None, action_l2=0, + use_batch_ensemble=False, + ensemble_size=10, + input_with_ensemble_ids=False, actor_optimizer=None, critic_optimizer=None, checkpoint=None, @@ -124,12 +136,16 @@ def __init__(self, (OU) noise added in the default collect policy. ou_damping (float): Damping factor for the OU noise added in the default collect policy. + noise_clipping (float): when computing the action noise, clips the + noise element-wise between ``[-noise_clipping, noise_clipping]``. + Does not perform clipping if ``noise_clipping == 0``. critic_loss_ctor (None|OneStepTDLoss|MultiStepLoss): a critic loss constructor. If ``None``, a default ``OneStepTDLoss`` will be used. target_update_tau (float): Factor for soft update of the target networks. target_update_period (int): Period for soft update of the target networks. 
+ actor_update_period (int): Period for update of the actor_network. rollout_random_action (float): the probability of taking a uniform random action during a ``rollout_step()``. 0 means always directly taking actions added with OU noises and 1 means always sample @@ -139,6 +155,17 @@ def __init__(self, gradient dqda element-wise between ``[-dqda_clipping, dqda_clipping]``. Does not perform clipping if ``dqda_clipping == 0``. action_l2 (float): weight of squared action l2-norm on actor loss. + use_batch_ensemble (bool): whether to use BatchEnsemble FC and Conv2D + layers. If True, both BatchEnsemble layers will always be created + with ``output_ensemble_ids=True``, and as a result, the output of + the network is a tuple with ensemble_ids. + ensemble_size (int): ensemble size, only effective if use_batch_ensemble + is True. + input_with_ensemble_ids (bool): whether handle inputs with ensemble_ids, + if True, input to the network should be a tuple of two tensors, the + first one is the input data tensor and the second one is the + ensemble_ids. This option is only effective if use_batch_ensemble + is True. actor_optimizer (torch.optim.optimizer): The optimizer for actor. critic_optimizer (torch.optim.optimizer): The optimizer for critic. checkpoint (None|str): a string in the format of "prefix@path", @@ -148,6 +175,10 @@ def __init__(self, debug_summaries (bool): True if debug summaries should be created. name (str): The name of this algorithm. 
""" + if use_batch_ensemble: + assert config.use_rollout_state, ( + 'use_rollout_state needs to be True when use_batch_ensemble.') + self._calculate_priority = calculate_priority if epsilon_greedy is None: epsilon_greedy = alf.utils.common.get_epsilon_greedy(config) @@ -155,9 +186,16 @@ def __init__(self, critic_network = critic_network_ctor( input_tensor_spec=(observation_spec, action_spec), - output_tensor_spec=reward_spec) + output_tensor_spec=reward_spec, + use_batch_ensemble=use_batch_ensemble, + ensemble_size=ensemble_size, + input_with_ensemble_ids=input_with_ensemble_ids) actor_network = actor_network_ctor( - input_tensor_spec=observation_spec, action_spec=action_spec) + input_tensor_spec=observation_spec, + action_spec=action_spec, + use_batch_ensemble=use_batch_ensemble, + ensemble_size=ensemble_size, + input_with_ensemble_ids=input_with_ensemble_ids) critic_networks = critic_network.make_parallel(num_critic_replicas) @@ -176,6 +214,8 @@ def __init__(self, train_state_spec = DdpgState( noise=noise_state, + ensemble_ids=TensorSpec( + (), dtype=torch.int64) if use_batch_ensemble else (), actor=DdpgActorState( actor=actor_network.state_spec, critics=critic_networks.state_spec), @@ -221,7 +261,11 @@ def __init__(self, self._critic_losses[i] = critic_loss_ctor( name=("critic_loss" + str(i))) + self._use_batch_ensemble = use_batch_ensemble self._noise_process = noise_process + self._noise_clipping = noise_clipping + self._actor_update_period = actor_update_period + self._train_step_count = 0 self._update_target = common.TargetUpdater( models=[self._actor_network, self._critic_networks], @@ -239,6 +283,9 @@ def predict_step(self, inputs: TimeStep, state): def _predict_step(self, time_step: TimeStep, state, epsilon_greedy=1.): action, actor_state = self._actor_network( time_step.observation, state=state.actor.actor) + if self._use_batch_ensemble: + ensemble_ids = action[1] + action = action[0] empty_state = nest.map_structure(lambda x: (), 
self.rollout_state_spec) def _sample(a, noise): @@ -252,11 +299,15 @@ def _sample(a, noise): return a noise, noise_state = self._noise_process(state.noise) + if self._noise_clipping: + noise = torch.clamp(noise, -self._noise_clipping, + self._noise_clipping) noisy_action = nest.map_structure(_sample, action, noise) noisy_action = nest.map_structure(spec_utils.clip_to_spec, noisy_action, self._action_spec) state = empty_state._replace( noise=noise_state, + ensemble_ids=ensemble_ids if self._use_batch_ensemble else (), actor=DdpgActorState(actor=actor_state, critics=())) return AlgStep( @@ -265,9 +316,12 @@ def _sample(a, noise): info=DdpgInfo(action=noisy_action, action_distribution=action)) def rollout_step(self, time_step: TimeStep, state: DdpgState = None): - if self.need_full_rollout_state(): - raise NotImplementedError("Storing RNN state to replay buffer " - "is not supported by DdpgAlgorithm") + """``rollout_step()`` basically predicts actions like what is done by + ``predict_step()``. Additionally, if states are to be stored in a replay + buffer, then this function also calls ``_critic_networks``, + ``_target_critic_networks``, and ``_target_actor_network`` to maintain + their states. + """ def _update_random_action(spec, noisy_action): random_action = spec_utils.scale_to_spec( torch.rand_like(noisy_action) * 2 - 1, spec) ind = torch.where( torch.rand(noisy_action.shape[0]) < _rollout_random_action) noisy_action[ind[0], :] = random_action[ind[0], :] + observation = time_step.observation + if self._use_batch_ensemble and torch.count_nonzero( + state.ensemble_ids) > 0: + # If use_batch_ensemble, we want to use the same ensemble_ids + # to forward the actor_network during the rollout of an episode, + # except for the initial rollout_step, where the ensemble_ids + # in the initial rollout_state are all zeros. 
+ time_step = time_step._replace( + observation=(observation, state.ensemble_ids)) pred_step = self._predict_step(time_step, state, epsilon_greedy=1.0) if self._rollout_random_action > 0: nest.map_structure(_update_random_action, self._action_spec, pred_step.output) - return pred_step + + if self.need_full_rollout_state(): + _, critics_state = self._critic_networks( + (observation, pred_step.output), state.critics.critics) + _, target_critics_state = self._target_critic_networks( + (observation, pred_step.output), state.critics.target_critics) + _, target_actor_state = self._target_actor_network( + observation, state=state.critics.target_actor) + critic_state = DdpgCriticState( + critics=critics_state, + target_actor=target_actor_state, + target_critics=target_critics_state) + else: + critics_state = state.critics.critics + critic_state = state.critics + + actor_state = pred_step.state.actor._replace(critics=critics_state) + + new_state = pred_step.state._replace( + actor=actor_state, critics=critic_state) + + return pred_step._replace(state=new_state) def _critic_train_step(self, inputs: TimeStep, state: DdpgCriticState, rollout_info: DdpgInfo): target_action, target_actor_state = self._target_actor_network( inputs.observation, state=state.target_actor) + if self._use_batch_ensemble: + target_action = target_action[0] target_q_values, target_critic_states = self._target_critic_networks( (inputs.observation, target_action), state=state.target_critics) + if self._use_batch_ensemble: + target_q_values = target_q_values[0] if self.has_multidim_reward(): sign = self.reward_weights.sign() @@ -298,6 +386,8 @@ def _critic_train_step(self, inputs: TimeStep, state: DdpgCriticState, q_values, critic_states = self._critic_networks( (inputs.observation, rollout_info.action), state=state.critics) + if self._use_batch_ensemble: + q_values = q_values[0] state = DdpgCriticState( critics=critic_states, @@ -312,9 +402,13 @@ def _critic_train_step(self, inputs: TimeStep, state: 
DdpgCriticState, def _actor_train_step(self, inputs: TimeStep, state: DdpgActorState): action, actor_state = self._actor_network( inputs.observation, state=state.actor) + if self._use_batch_ensemble: + action = action[0] q_values, critic_states = self._critic_networks( (inputs.observation, action), state=state.critics) + if self._use_batch_ensemble: + q_values = q_values[0] if self.has_multidim_reward(): # Multidimensional reward: [B, replicas, reward_dim] q_values = q_values * self.reward_weights @@ -343,9 +437,22 @@ def actor_loss_fn(dqda, action): def train_step(self, inputs: TimeStep, state: DdpgState, rollout_info: DdpgInfo): + self._train_step_count += 1 critic_states, critic_info = self._critic_train_step( inputs=inputs, state=state.critics, rollout_info=rollout_info) - policy_step = self._actor_train_step(inputs=inputs, state=state.actor) + if self._train_step_count % self._actor_update_period == 0: + policy_step = self._actor_train_step( + inputs=inputs, state=state.actor) + critic_states = critic_states._replace( + critics=policy_step.state.critics) + else: + batch_dims = nest_utils.get_outer_rank(inputs.prev_action, + self._action_spec) + loss = torch.zeros(*inputs.prev_action.shape[:batch_dims]) + policy_step = AlgStep( + output=torch.zeros_like(inputs.prev_action), + state=state.actor, + info=LossInfo(loss=loss, extra=loss)) return policy_step._replace( state=state._replace( actor=policy_step.state, critics=critic_states), diff --git a/alf/algorithms/ddpg_algorithm_test.py b/alf/algorithms/ddpg_algorithm_test.py index 05c2dcd76..48f39f9b7 100644 --- a/alf/algorithms/ddpg_algorithm_test.py +++ b/alf/algorithms/ddpg_algorithm_test.py @@ -33,9 +33,14 @@ class DDPGAlgorithmTest(parameterized.TestCase, alf.test.TestCase): - @parameterized.parameters((1, 1, None), (2, 3, [1, 2, 3])) - def test_ddpg_algorithm(self, num_critic_replicas, reward_dim, - reward_weights): + @parameterized.parameters((1, 1, None, 2), (1, 1, None, 2, True), + (2, 3, [1, 2, 3])) + def 
test_ddpg_algorithm(self, + num_critic_replicas, + reward_dim, + reward_weights, + actor_update_period=1, + use_batch_ensemble=False): num_env = 128 num_eval_env = 100 steps_per_episode = 13 @@ -45,6 +50,7 @@ def test_ddpg_algorithm(self, num_critic_replicas, reward_dim, mini_batch_length=2, mini_batch_size=128, initial_collect_steps=steps_per_episode, + use_rollout_state=use_batch_ensemble, whole_replay_buffer_training=False, clear_replay_buffer=False, ) @@ -86,6 +92,9 @@ def test_ddpg_algorithm(self, num_critic_replicas, reward_dim, env=env, config=config, num_critic_replicas=num_critic_replicas, + actor_update_period=actor_update_period, + use_batch_ensemble=use_batch_ensemble, + ensemble_size=3, actor_optimizer=alf.optimizers.Adam(lr=1e-2), critic_optimizer=alf.optimizers.Adam(lr=1e-2), debug_summaries=False, diff --git a/alf/networks/actor_networks.py b/alf/networks/actor_networks.py index b7ec21e3e..9208bd01d 100644 --- a/alf/networks/actor_networks.py +++ b/alf/networks/actor_networks.py @@ -27,6 +27,7 @@ import alf.nest as nest from alf.initializers import variance_scaling_init from alf.networks import Network +from alf.networks.containers import Parallel from alf.tensor_specs import TensorSpec, BoundedTensorSpec from alf.utils import common, math_ops, spec_utils @@ -85,10 +86,21 @@ def __init__(self, a=-0.003, b=0.003) self._action_layers = nn.ModuleList() self._squashing_func = squashing_func + fc_layer_ctor = layers.FC + encoder_output_spec = self._encoding_net.output_spec + self._use_batch_ensemble = encoder_kwargs.get('use_batch_ensemble', + False) + if self._use_batch_ensemble: + encoder_output_spec = encoder_output_spec[0] + fc_layer_ctor = functools.partial( + layers.FCBatchEnsemble, + ensemble_size=encoder_kwargs.get('ensemble_size', 10), + output_ensemble_ids=False) + for single_action_spec in flat_action_spec: self._action_layers.append( - layers.FC( - self._encoding_net.output_spec.shape[0], + fc_layer_ctor( + encoder_output_spec.shape[0], 
single_action_spec.shape[0], kernel_initializer=last_kernel_initializer)) @@ -134,6 +146,11 @@ def forward(self, observation, state=()): i += 1 output_actions = nest.pack_sequence_as(self._action_spec, actions) + if self._use_batch_ensemble: + # note that when use_batch_ensemble, EncodingNetwork always + # outputs a tuple (output_tensor, ensemble_ids) + output_actions = (output_actions, encoded_obs[1]) + return output_actions, state @property @@ -149,12 +166,16 @@ def __init__(self, input_tensor_spec: TensorSpec, action_spec: BoundedTensorSpec, input_preprocessors=None, + input_preprocessors_ctor=None, preprocessing_combiner=None, conv_layer_params=None, fc_layer_params=None, activation=torch.relu_, squashing_func=torch.tanh, kernel_initializer=None, + use_batch_ensemble=False, + ensemble_size=10, + input_with_ensemble_ids=False, name="ActorNetwork"): """Creates an instance of ``ActorNetwork``, which maps the inputs to actions (single or nested) through a sequence of deterministic layers. @@ -189,6 +210,17 @@ def __init__(self, kernel_initializer (Callable): initializer for all the layers but the last layer. If none is provided a ``variance_scaling_initializer`` with uniform distribution will be used. + use_batch_ensemble (bool): whether to use BatchEnsemble FC and Conv2D + layers. If True, both BatchEnsemble layers will always be created + with ``output_ensemble_ids=True``, and as a result, the output of + the network is a tuple of (outputs, ensemble_ids). + ensemble_size (int): ensemble size, only effective if use_batch_ensemble + is True. + input_with_ensemble_ids (bool): whether handle inputs with ensemble_ids, + if True, input to the network should be a tuple of two tensors, the + first one is the input data tensor and the second one is the + ensemble_ids. This option is only effective if use_batch_ensemble + is True. 
name (str): name of the network """ super(ActorNetwork, self).__init__( @@ -202,7 +234,10 @@ def __init__(self, conv_layer_params=conv_layer_params, fc_layer_params=fc_layer_params, activation=activation, - kernel_initializer=kernel_initializer) + kernel_initializer=kernel_initializer, + use_batch_ensemble=use_batch_ensemble, + ensemble_size=ensemble_size, + input_with_ensemble_ids=input_with_ensemble_ids) @alf.configurable diff --git a/alf/networks/actor_networks_test.py b/alf/networks/actor_networks_test.py index 519ed86ae..d60fdb06c 100644 --- a/alf/networks/actor_networks_test.py +++ b/alf/networks/actor_networks_test.py @@ -23,7 +23,7 @@ class ActorNetworkTest(alf.test.TestCase, parameterized.TestCase): - def _init(self, lstm_hidden_size): + def _init(self, lstm_hidden_size, use_batch_ensemble=False): if lstm_hidden_size is not None: actor_fc_layer_params = (6, 4) network_ctor = functools.partial( @@ -40,12 +40,14 @@ def _init(self, lstm_hidden_size): ), dtype=torch.float32), ) * 2) state.append(()) else: - network_ctor = actor_network.ActorNetwork + network_ctor = functools.partial( + actor_network.ActorNetwork, + use_batch_ensemble=use_batch_ensemble) state = () return network_ctor, state - @parameterized.parameters((100, ), (None, ), ((200, 100), )) - def test_actor_networks(self, lstm_hidden_size): + @parameterized.parameters((100, ), (None, ), (None, True), ((200, 100), )) + def test_actor_networks(self, lstm_hidden_size, use_batch_ensemble=False): obs_spec = TensorSpec((3, 20, 20), torch.float32) action_spec = BoundedTensorSpec((5, ), torch.float32, 2., 3.) 
conv_layer_params = ((8, 3, 1), (16, 3, 2, 1)) @@ -53,7 +55,7 @@ def test_actor_networks(self, lstm_hidden_size): image = obs_spec.zeros(outer_dims=(1, )) - network_ctor, state = self._init(lstm_hidden_size) + network_ctor, state = self._init(lstm_hidden_size, use_batch_ensemble) actor_net = network_ctor( obs_spec, @@ -62,6 +64,8 @@ def test_actor_networks(self, lstm_hidden_size): fc_layer_params=fc_layer_params) action, state = actor_net(image, state) + if use_batch_ensemble: + action = action[0] # (batch_size, num_actions) self.assertEqual(action.shape, (1, 5)) diff --git a/alf/networks/critic_networks.py b/alf/networks/critic_networks.py index eb050593b..19fb60957 100644 --- a/alf/networks/critic_networks.py +++ b/alf/networks/critic_networks.py @@ -21,6 +21,8 @@ import alf import alf.utils.math_ops as math_ops import alf.nest as nest +from alf.networks.containers import Parallel, Sequential +from alf.networks.network import NetworkWrapper from alf.initializers import variance_scaling_init from alf.tensor_specs import TensorSpec @@ -82,6 +84,9 @@ def __init__(self, kernel_initializer=None, use_fc_bn=False, use_fc_ln=False, + use_batch_ensemble=False, + ensemble_size=10, + input_with_ensemble_ids=False, last_use_fc_bn=False, last_use_fc_ln=False, last_layer_activation=math_ops.identity, @@ -132,6 +137,23 @@ def __init__(self, FC layers (i.e. FC layers beside the last one). use_fc_ln (bool): whether use Layer Normalization for the internal FC layers (i.e. FC layers beside the last one). + use_batch_ensemble (bool): whether to use BatchEnsemble FC and Conv2D + layers. If True, both BatchEnsemble layers will always be created + with ``output_ensemble_ids=True``, and as a result, the output of + the network is a tuple of (outputs, ensemble_ids). + ensemble_size (int): ensemble size, only effective if use_batch_ensemble + is True. 
+ input_with_ensemble_ids (bool): whether handle inputs with ensemble_ids, + if True, input to the network should be a tuple of two tensors, the + first one is the input data tensor and the second one is the + ensemble_ids. This option is only effective if use_batch_ensemble + is True. + last_use_fc_bn (bool): whether use Batch Normalization for the last + fc layer. + last_use_fc_ln (bool): whether use Layer Normalization for the last + fc layer. + last_activation (nn.functional): activation function of the + additional layer specified by ``output_tensor_spec.numel``. use_naive_parallel_network (bool): if True, will use ``NaiveParallelNetwork`` when ``make_parallel`` is called. This might be useful in cases when the ``NaiveParallelNetwork`` @@ -160,6 +182,9 @@ def __init__(self, kernel_initializer=kernel_initializer, use_fc_bn=use_fc_bn, use_fc_ln=use_fc_ln, + use_batch_ensemble=use_batch_ensemble, + ensemble_size=ensemble_size, + input_with_ensemble_ids=input_with_ensemble_ids, name=name + ".obs_encoder") _check_action_specs_for_critic_networks(action_spec, @@ -175,13 +200,39 @@ def __init__(self, kernel_initializer=kernel_initializer, use_fc_bn=use_fc_bn, use_fc_ln=use_fc_ln, + use_batch_ensemble=use_batch_ensemble, + ensemble_size=ensemble_size, + input_with_ensemble_ids=input_with_ensemble_ids, name=name + ".action_encoder") last_kernel_initializer = functools.partial( torch.nn.init.uniform_, a=-0.003, b=0.003) if observation_action_combiner is None: - observation_action_combiner = alf.layers.NestConcat(dim=-1) + if use_batch_ensemble: + obs_spec = obs_encoder.output_spec + action_spec = action_encoder.output_spec + obs_action_spec = (obs_spec, action_spec) + + def _obs_action_combiner(inputs): + obs, action = inputs + ensemble_ids = None + if isinstance(obs_spec, tuple): + ensemble_ids = obs[1] + obs = obs[0] + if isinstance(action_spec, tuple): + if ensemble_ids is None: + ensemble_ids = action[1] + action = action[0] + outputs = 
alf.layers.NestConcat(dim=-1)((obs, action)) + if ensemble_ids is not None: + outputs = (outputs, ensemble_ids) + return outputs + + observation_action_combiner = NetworkWrapper( + _obs_action_combiner, obs_action_spec) + else: + observation_action_combiner = alf.layers.NestConcat(dim=-1) super().__init__( input_tensor_spec=input_tensor_spec, @@ -193,13 +244,18 @@ def __init__(self, kernel_initializer=kernel_initializer, use_fc_bn=use_fc_bn, use_fc_ln=use_fc_ln, + use_batch_ensemble=use_batch_ensemble, + ensemble_size=ensemble_size, + # when use_batch_ensemble, ensemble_ids of inputs should be handled + # already by input_preprocessors and preprocessing_combiner + input_with_ensemble_ids=False, last_layer_size=output_tensor_spec.numel, last_activation=last_layer_activation, last_kernel_initializer=last_kernel_initializer, last_use_fc_bn=last_use_fc_bn, last_use_fc_ln=last_use_fc_ln, name=name) - self._use_naive_parallel_network = use_naive_parallel_network + self._use_naive_parallel_network = use_naive_parallel_network or use_batch_ensemble def make_parallel(self, n): """Create a parallel critic network using ``n`` replicas of ``self``. 
diff --git a/alf/networks/critic_networks_test.py b/alf/networks/critic_networks_test.py index 5befc71f4..e3fcaa020 100644 --- a/alf/networks/critic_networks_test.py +++ b/alf/networks/critic_networks_test.py @@ -29,7 +29,7 @@ class CriticNetworksTest(parameterized.TestCase, alf.test.TestCase): - def _init(self, lstm_hidden_size): + def _init(self, lstm_hidden_size, use_batch_ensemble=False): if lstm_hidden_size is not None: post_rnn_fc_layer_params = (6, 4) network_ctor = functools.partial( @@ -45,12 +45,13 @@ def _init(self, lstm_hidden_size): size, ), dtype=torch.float32), ) * 2) else: - network_ctor = CriticNetwork + network_ctor = functools.partial( + CriticNetwork, use_batch_ensemble=use_batch_ensemble) state = () return network_ctor, state - @parameterized.parameters((100, ), (None, ), ((200, 100), )) - def test_critic(self, lstm_hidden_size): + @parameterized.parameters((100, ), (None, ), (None, True), ((200, 100), )) + def test_critic(self, lstm_hidden_size, use_batch_ensemble=False): obs_spec = TensorSpec((3, 20, 20), torch.float32) action_spec = TensorSpec((5, ), torch.float32) input_spec = (obs_spec, action_spec) @@ -64,7 +65,7 @@ def test_critic(self, lstm_hidden_size): network_input = (image, action) - network_ctor, state = self._init(lstm_hidden_size) + network_ctor, state = self._init(lstm_hidden_size, use_batch_ensemble) critic_net = network_ctor( input_spec, @@ -74,13 +75,19 @@ def test_critic(self, lstm_hidden_size): test_net_copy(critic_net) value, state = critic_net._test_forward() + if use_batch_ensemble: + value = value[0] self.assertEqual(value.shape, (2, )) if lstm_hidden_size is None: self.assertEqual(state, ()) value, state = critic_net(network_input, state) + output_spec = critic_net.output_spec + if use_batch_ensemble: + value = value[0] + output_spec = output_spec[0] - self.assertEqual(critic_net.output_spec, TensorSpec(())) + self.assertEqual(output_spec, TensorSpec(())) # (batch_size,) self.assertEqual(value.shape, (2, )) @@ -96,7 
+103,11 @@ def test_critic(self, lstm_hidden_size): lambda x: x.unsqueeze(1).expand(x.shape[0], 6, x.shape[1]), state) value, state = pnet(network_input, state) - self.assertEqual(pnet.output_spec, TensorSpec((6, ))) + output_spec = pnet.output_spec + if use_batch_ensemble: + value = value[0] + output_spec = output_spec[0] + self.assertEqual(output_spec, TensorSpec((6, ))) self.assertEqual(value.shape, (2, 6)) def test_make_parallel(self): diff --git a/alf/networks/encoding_networks.py b/alf/networks/encoding_networks.py index 74b45519e..3bbcc0406 100644 --- a/alf/networks/encoding_networks.py +++ b/alf/networks/encoding_networks.py @@ -691,14 +691,14 @@ def __init__(self, use_batch_ensemble (bool): whether to use BatchEnsemble FC and Conv2D layers. If True, both BatchEnsemble layers will always be created with ``output_ensemble_ids=True``, and as a result, the output of - the network is a tuple with ensemble_ids. + the network is a tuple of (outputs, ensemble_ids). ensemble_size (int): ensemble size, only effective if use_batch_ensemble is True. input_with_ensemble_ids (bool): whether handle inputs with ensemble_ids, if True, input to the network should be a tuple of two tensors, the first one is the input data tensor and the second one is the ensemble_ids. This option is only effective if use_batch_ensemble - is True. + is True. last_layer_size (int): an optional size of an additional layer appended at the very end. Note that if ``last_activation`` is specified, ``last_layer_size`` has to be specified explicitly. 
@@ -734,13 +734,17 @@ def __init__(self, assert preprocessing_combiner is not None, \ ("When a nested input tensor spec is provided, an input " + "preprocessing combiner must also be provided!") - spec = preprocessing_combiner(spec) nets.append(preprocessing_combiner) + spec = preprocessing_combiner(spec) + if isinstance(preprocessing_combiner, (_Sequential, Network)): + spec = spec[0] + if isinstance(spec, tuple): + spec = spec[0] else: assert isinstance(spec, TensorSpec), \ "The spec must be an instance of TensorSpec!" - if input_with_ensemble_ids: + if nets and input_with_ensemble_ids: nets = [ Parallel( (Sequential(*nets, input_tensor_spec=input_tensor_spec),