diff --git a/alf/algorithms/curl_encoder.py b/alf/algorithms/curl_encoder.py
new file mode 100644
index 000000000..c74488f3d
--- /dev/null
+++ b/alf/algorithms/curl_encoder.py
@@ -0,0 +1,216 @@
+# Copyright (c) 2021 Horizon Robotics and ALF Contributors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Contrastive Unsupervised Representations for Reinforcement Learning."""
+
+import os
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torchvision
+from skimage.io import imsave
+from skimage.util.shape import view_as_windows
+
+import alf
+from alf.algorithms.algorithm import Algorithm
+from alf.data_structures import AlgStep, LossInfo, TimeStep
+from alf.utils import common
+
+
+def create_encoder(input_spec, feature_dim, num_layers=2, num_filters=32):
+    """Create the image encoder for the CURL algorithm.
+
+    Args:
+        input_spec (TensorSpec): spec describing the input image tensor.
+        feature_dim (int): dimension of the output feature.
+        num_layers (int): number of conv layers.
+        num_filters (int): number of filters in each conv layer.
+
+    Returns:
+        alf.nn.Sequential: a module that performs the described encoding.
+    """
+    stacks = [
+        alf.layers.Conv2D(input_spec.shape[0], num_filters, 3, strides=2)
+    ]
+    for i in range(num_layers - 1):
+        stacks.append(
+            alf.layers.Conv2D(num_filters, num_filters, 3, strides=1))
+    before_fc = alf.nn.Sequential(
+        *stacks, alf.layers.Reshape((-1, )), input_tensor_spec=input_spec)
+    return alf.nn.Sequential(
+        before_fc,
+        alf.layers.FC(
+            input_size=before_fc.output_spec.shape[0],
+            output_size=feature_dim,
+            use_ln=True))
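# [Editor's sketch, not part of the patch] Minimal usage of ``create_encoder``
# above. With the defaults (two 3x3 convs, strides 2 then 1, no padding) an
# 84x84 input shrinks to 41x41 and then 39x39 feature maps before the flatten
# and LayerNorm'ed FC head. ``alf.TensorSpec`` and the ``(output, state)``
# network call convention are assumed from ALF's API at the time of this PR.
import alf
from alf.algorithms.curl_encoder import create_encoder

spec = alf.TensorSpec((3, 84, 84))
encoder = create_encoder(spec, feature_dim=50)
latent, _ = encoder(spec.zeros(outer_dims=(4, )))  # latent has shape (4, 50)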
+
+
+@alf.configurable
+class curl_encoder(Algorithm):
+    """The encoder part of Contrastive Unsupervised Representations for
+    Reinforcement Learning (CURL). It can be combined with most RL
+    algorithms, such as SAC.
+    """
+
+    def __init__(self,
+                 observation_spec,
+                 feature_dim,
+                 crop_size=84,
+                 action_spec=None,
+                 encoder_tau=0.05,
+                 debug_summaries=False,
+                 optimizer=None,
+                 output_tanh=False,
+                 save_image=False,
+                 use_pytorch_randcrop=False,
+                 detach_encoder=False,
+                 name='curl_encoder'):
+        """
+        Args:
+            observation_spec (TensorSpec): spec of the input tensor
+                (B x C x H x W); H and W are assumed to be equal.
+            feature_dim (int): dimension of the output feature; the bilinear
+                matrix ``W`` has shape (feature_dim x feature_dim).
+            crop_size (int): height and width of the cropped image. After
+                cropping, the image has shape
+                (B x C x crop_size x crop_size).
+            encoder_tau (float): factor for the soft update of the target
+                encoder.
+            output_tanh (bool): whether to attach a tanh layer at the end of
+                the encoder.
+        """
+        super().__init__(
+            train_state_spec=observation_spec,
+            optimizer=optimizer,
+            debug_summaries=debug_summaries,
+            name=name)
+        self.observation_spec = observation_spec
+        self.channels = observation_spec.shape[0]
+        self.after_crop_spec = alf.BoundedTensorSpec((self.channels,
+                                                      crop_size, crop_size))
+        self.feature_dim = feature_dim
+        self.output_spec = alf.BoundedTensorSpec((feature_dim, ))
+        self.crop_size = crop_size
+        self._encoding_net = create_encoder(self.after_crop_spec, feature_dim)
+        self._target_encoding_net = self._encoding_net.copy(
+            name='target_encoding_net_ctor')
+        self.CrossEntropyLoss = nn.CrossEntropyLoss(reduction='none')
+        self.W = nn.Parameter(torch.rand(feature_dim, feature_dim))
+        self.output_tanh = output_tanh
+        self.save_image = save_image
+        self.use_pytorch_randcrop = use_pytorch_randcrop
+        self.detach_encoder = detach_encoder
+        self._update_target = common.get_target_updater(
+            models=[self._encoding_net],
+            target_models=[self._target_encoding_net],
+            tau=encoder_tau)
+        if use_pytorch_randcrop:
+            self.pytorch_randcrop = torchvision.transforms.RandomCrop(
+                self.crop_size)
+
+    def random_crop(self, obs, output_size, save_image=False):
+        """Randomly crop the input images. For each image, the crop position
+        is identical across channels.
+
+        Args:
+            obs (Tensor): batch of images with shape (B, C, H, W).
+            output_size (int): height and width of the output images.
+            save_image (bool): if True, save the original and the cropped
+                images.
+
+        Returns:
+            Tensor: the cropped images.
+        """
+        if self.use_pytorch_randcrop:
+            return self.pytorch_randcrop(obs)
+        else:
+            imgs = obs.cpu().numpy()
+            n = imgs.shape[0]
+            img_size = imgs.shape[-1]
+            crop_max = img_size - output_size
+            imgs = np.transpose(imgs, (0, 2, 3, 1))
+            w1 = np.random.randint(0, crop_max, n)
+            h1 = np.random.randint(0, crop_max, n)
+            # For each image, select one (w1, h1) window shared by all
+            # channels.
+            windows = view_as_windows(
+                imgs, (1, output_size, output_size, 1))[..., 0, :, :, 0]
+            cropped_imgs = windows[np.arange(n), w1, h1]
+            if save_image:
+                image_dir = os.path.expanduser("~/image_test/")
+                for i in range(n):
+                    imsave(image_dir + "origin" + str(i) + ".PNG",
+                           imgs[i, :, :, 0])
+                    imsave(image_dir + "cropped" + str(i) + ".PNG",
+                           cropped_imgs[i, 0, :, :])
+            return torch.from_numpy(cropped_imgs).to(obs.device)
+
+    def predict_step(self, inputs: TimeStep, state):
+        # random crop
+        crop_obs = self.random_crop(
+            inputs.observation, self.crop_size, save_image=self.save_image)
+        latent = self._encoding_net(crop_obs)[0]
+        if self.output_tanh:
+            output = torch.tanh(latent)
+        else:
+            output = latent
+        return AlgStep(output=output, state=state)
+
+    def rollout_step(self, inputs: TimeStep, state):
+        # random crop
+        crop_obs = self.random_crop(
+            inputs.observation, self.crop_size, save_image=self.save_image)
+        latent = self._encoding_net(crop_obs)[0]
+        if self.output_tanh:
+            output = torch.tanh(latent)
+        else:
+            output = latent
+        return AlgStep(output=output, state=state)
+
+    def train_step(self, inputs: TimeStep, state, rollout_info=None):
+        # Two independent random crops of the same observations give the
+        # query (anchor) and key (positive) views.
+        rc_obs_1 = self.random_crop(
+            inputs.observation, self.crop_size, save_image=self.save_image)
+        rc_obs_2 = self.random_crop(
+            inputs.observation, self.crop_size, save_image=self.save_image)
+
+        # Encode the query with the online encoder and the key with the
+        # target encoder (no gradient through the key).
+        latent_q = self._encoding_net(rc_obs_1)[0]
+        latent_k = self._target_encoding_net(rc_obs_2)[0].detach()
+
+        # Bilinear similarities; the positive pair of sample i is entry
+        # (i, i), so the cross-entropy labels are simply arange(B).
+        W_z = torch.matmul(self.W, latent_k.T)
+        logits = torch.matmul(latent_q, W_z)
+        logits = logits - torch.max(logits, 1)[0][:, None]
+        labels = torch.arange(logits.shape[0]).long()
+        loss = self.CrossEntropyLoss(logits, labels)
+        if self.detach_encoder:
+            latent_q = latent_q.detach()
+        return AlgStep(output=latent_q, state=state, info=LossInfo(loss=loss))
+
+    def after_update(self, root_inputs=None, train_info=None):
+        self._update_target()
+
+    def _trainable_attributes_to_ignore(self):
+        return ['_target_encoding_net']
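# [Editor's sketch, not part of the patch] The contrastive objective in
# ``train_step`` above, in isolation: given B query codes z_q (online encoder)
# and B key codes z_k (target encoder), the bilinear scores
# logits[i, j] = z_q[i]^T W z_k[j] treat the diagonal as positives, so the
# cross-entropy labels are just arange(B). Shapes below are illustrative.
import torch
import torch.nn.functional as F

B, D = 8, 50
z_q = torch.randn(B, D)
z_k = torch.randn(B, D)
W = torch.randn(D, D)
logits = z_q @ (W @ z_k.T)                            # (B, B) score matrix
logits = logits - logits.max(dim=1, keepdim=True)[0]  # numerical stability
loss = F.cross_entropy(logits, torch.arange(B))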
diff --git a/alf/environments/suite_dmc2gym.py b/alf/environments/suite_dmc2gym.py
new file mode 100644
index 000000000..b533bc5f6
--- /dev/null
+++ b/alf/environments/suite_dmc2gym.py
@@ -0,0 +1,245 @@
+# Copyright (c) 2021 Horizon Robotics and ALF Contributors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import alf
+import dmc2gym
+from alf.environments.alf_environment import AlfEnvironment
+from alf.environments import gym_wrappers, alf_wrappers, alf_gym_wrapper
+
+
+class AlfEnvironmentDMC2GYMWrapper(AlfEnvironment):
+    """An AlfEnvironment wrapper that forwards all calls to the wrapped
+    environment but discards per-step ``env_info``."""
+
+    def __init__(self, env):
+        """Create an ALF environment base wrapper.
+
+        Args:
+            env (AlfEnvironment): an AlfEnvironment instance to wrap.
+        """
+        super(AlfEnvironmentDMC2GYMWrapper, self).__init__()
+        self._env = env
+
+    def __getattr__(self, name):
+        """Forward all other calls to the base environment."""
+        if name.startswith('_'):
+            raise AttributeError(
+                "attempted to get missing private attribute '{}'".format(name))
+        return getattr(self._env, name)
+
+    @property
+    def batched(self):
+        return self._env.batched
+
+    @property
+    def batch_size(self):
+        return self._env.batch_size
+
+    @property
+    def num_tasks(self):
+        return self._env.num_tasks
+
+    @property
+    def task_names(self):
+        return self._env.task_names
+
+    def _reset(self):
+        reset_result = self._env.reset()
+        reset_result = reset_result._replace(env_info={})
+        return reset_result
+
+    def _step(self, action):
+        step_result = self._env.step(action)
+        step_result = step_result._replace(env_info={})
+        return step_result
+
+    def get_info(self):
+        return self._env.get_info()
+
+    def env_info_spec(self):
+        return self._env.env_info_spec()
+
+    def time_step_spec(self):
+        return self._env.time_step_spec()
+
+    def observation_spec(self):
+        return self._env.observation_spec()
+
+    def action_spec(self):
+        return self._env.action_spec()
+
+    def reward_spec(self):
+        return self._env.reward_spec()
+
+    def close(self):
+        return self._env.close()
+
+    def render(self, mode='rgb_array'):
+        return self._env.render(mode)
+
+    def seed(self, seed):
+        return self._env.seed(seed)
+
+    def wrapped_env(self):
+        return self._env
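# [Editor's sketch, not part of the patch] The only behavior the wrapper above
# changes is emptying ``env_info`` in ``_reset``/``_step``, which keeps
# dmc2gym's per-step info dict (e.g. 'internal_state') out of ALF's metric
# accumulation. The core idiom, demonstrated on a stand-in namedtuple:
from collections import namedtuple

TimeStepStub = namedtuple('TimeStepStub', ['observation', 'env_info'])
ts = TimeStepStub(observation=0, env_info={'internal_state': (1, 2, 3)})
ts = ts._replace(env_info={})  # what _reset()/_step() do to every step
assert ts.env_info == {}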
+
+
+@alf.configurable
+def dmc2gym_loader(environment_name,
+                   domain_name='cheetah',
+                   task_name='run',
+                   seed=1,
+                   pre_transform_image_size=100,
+                   action_repeat=4,
+                   env_id=None,
+                   discount=1.0,
+                   max_episode_steps=1000,
+                   gym_env_wrappers=(),
+                   alf_env_wrappers=(AlfEnvironmentDMC2GYMWrapper, ),
+                   image_channel_first=False):
+    """Load a MuJoCo environment through dmc2gym.
+
+    This loader does not use ``environment_name``; please use ``domain_name``
+    and ``task_name`` instead.
+    For the installation of dmc2gym, see https://github.com/denisyarats/dmc2gym.
+    For the installation of dm_control, see https://github.com/deepmind/dm_control.
+    For the installation of MuJoCo200, see https://roboti.us.
+
+    Args:
+        environment_name (str): do not use this arg; it only exists to match
+            the signature expected by ``create_environment``.
+        domain_name (str): name of the MuJoCo domain to use.
+        task_name (str): name of the task we want the agent to perform in
+            the given MuJoCo domain.
+        seed (int): random seed for the environment.
+        pre_transform_image_size (int): height and width of the image output
+            by the environment (before any cropping).
+        action_repeat (int): action repeat of the gym environment.
+        env_id (str): the environment id generated from domain_name,
+            task_name and seed.
+        discount (float): discount to use for the environment.
+        max_episode_steps (int): if None, max_episode_steps is set to the
+            default step limit defined in the environment's spec. No limit
+            is applied if set to 0 or if the environment's spec does not
+            define max_episode_steps.
+        gym_env_wrappers (Iterable): iterable of gym_wrappers classes to
+            apply directly on the gym environment.
+        alf_env_wrappers (Iterable): iterable of alf_wrappers classes to
+            apply on the ALF environment.
+        image_channel_first (bool): whether to transpose image channels to
+            the first dimension.
+
+    Returns:
+        A wrapped AlfEnvironment.
+    """
+    gym_env = dmc2gym.make(
+        domain_name=domain_name,
+        task_name=task_name,
+        seed=seed,
+        visualize_reward=False,
+        from_pixels=True,
+        episode_length=max_episode_steps,
+        height=pre_transform_image_size,
+        width=pre_transform_image_size,
+        frame_skip=action_repeat)
+    return wrap_env(
+        gym_env,
+        env_id=env_id,
+        discount=discount,
+        max_episode_steps=max_episode_steps,
+        gym_env_wrappers=gym_env_wrappers,
+        alf_env_wrappers=alf_env_wrappers,
+        image_channel_first=image_channel_first)
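# [Editor's sketch, not part of the patch] Direct use of the loader above;
# normally it is wired in via ``alf.config('create_environment', ...)`` as in
# the example conf file further below. Argument values are illustrative.
from alf.environments.suite_dmc2gym import dmc2gym_loader

env = dmc2gym_loader(
    environment_name=None,  # ignored; see the docstring
    domain_name='cheetah',
    task_name='run',
    pre_transform_image_size=100,
    action_repeat=4)
print(env.observation_spec())  # 100x100 pixel observations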
+
+
+@alf.configurable
+def wrap_env(gym_env,
+             env_id=None,
+             discount=1.0,
+             max_episode_steps=0,
+             gym_env_wrappers=(),
+             time_limit_wrapper=alf_wrappers.TimeLimit,
+             normalize_action=True,
+             clip_action=True,
+             alf_env_wrappers=(),
+             image_channel_first=True,
+             auto_reset=True):
+    """Wrap the given gym environment with AlfGymWrapper.
+
+    Note that by default a TimeLimit wrapper is used to limit episode lengths
+    to the default benchmarks defined by the registered environments.
+
+    Also note that all gym wrappers assume images are 'channel_last' by
+    default, while PyTorch only supports 'channel_first' image inputs. To
+    enable this transpose, ``image_channel_first`` is set to True by default.
+    ``gym_wrappers.ImageChannelFirst`` is applied after all
+    ``gym_env_wrappers`` and before the AlfGymWrapper.
+
+    Args:
+        gym_env (gym.Env): an instance of an OpenAI gym environment.
+        env_id (int): (optional) ID of the environment.
+        discount (float): discount to use for the environment.
+        max_episode_steps (int): used to create a TimeLimit wrapper. No limit
+            is applied if set to 0. Usually set to
+            ``gym_spec.max_episode_steps`` as done in ``load``.
+        gym_env_wrappers (Iterable): iterable of gym_wrappers classes to
+            apply directly on the gym environment.
+        time_limit_wrapper (AlfEnvironmentBaseWrapper): wrapper that accepts
+            (env, max_episode_steps) params to enforce a time limit. Usually
+            this should be left as the default, ``alf_wrappers.TimeLimit``.
+        normalize_action (bool): if True, scale continuous actions to
+            ``[-1, 1]`` so that they are better handled by algorithms that
+            compute entropies.
+        clip_action (bool): if True, clip continuous actions to the bounds
+            specified by ``action_spec``. If ``normalize_action`` is also
+            True, this clipping happens after the normalization (i.e., clips
+            to ``[-1, 1]``).
+        alf_env_wrappers (Iterable): iterable of alf_wrappers classes to
+            apply on the ALF environment.
+        image_channel_first (bool): whether to transpose image channels to
+            the first dimension. PyTorch only supports channel_first image
+            inputs.
+        auto_reset (bool): if True (default), reset the environment
+            automatically after a terminal state is reached.
+
+    Returns:
+        An AlfEnvironment instance.
+    """
+    for wrapper in gym_env_wrappers:
+        gym_env = wrapper(gym_env)
+
+    # Apply the channel_first transpose on the gym (py) env.
+    if image_channel_first:
+        gym_env = gym_wrappers.ImageChannelFirst(gym_env)
+
+    if normalize_action:
+        # normalize continuous actions to [-1, 1]
+        gym_env = gym_wrappers.NormalizedAction(gym_env)
+
+    if clip_action:
+        # clip continuous actions according to gym_env.action_space
+        gym_env = gym_wrappers.ContinuousActionClip(gym_env)
+
+    env = alf_gym_wrapper.AlfGymWrapper(
+        gym_env=gym_env,
+        env_id=env_id,
+        discount=discount,
+        auto_reset=auto_reset,
+    )
+
+    if max_episode_steps > 0:
+        env = time_limit_wrapper(env, max_episode_steps)
+
+    for wrapper in alf_env_wrappers:
+        env = wrapper(env)
+
+    return env
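# [Editor's note, not part of the patch] Wrapper order applied by ``wrap_env``,
# innermost first:
#   gym_env -> gym_env_wrappers -> ImageChannelFirst -> NormalizedAction
#     -> ContinuousActionClip -> AlfGymWrapper -> TimeLimit -> alf_env_wrappers
# Below is a sketch of the affine map that action normalization has to invert
# before stepping the raw env (assumed semantics of ``NormalizedAction``):
import numpy as np


def unscale_action(a, low, high):
    """Map a normalized action in [-1, 1] back to the box [low, high]."""
    return low + (np.asarray(a) + 1.0) * (high - low) / 2.0


assert unscale_action(-1.0, -2.0, 2.0) == -2.0
assert unscale_action(1.0, -2.0, 2.0) == 2.0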
diff --git a/alf/examples/curl_cheetah_run_conf.py b/alf/examples/curl_cheetah_run_conf.py
new file mode 100644
index 000000000..d7253d855
--- /dev/null
+++ b/alf/examples/curl_cheetah_run_conf.py
@@ -0,0 +1,103 @@
+# Copyright (c) 2021 Horizon Robotics and ALF Contributors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import math
+from functools import partial
+
+import alf
+from alf.algorithms.agent import Agent
+from alf.algorithms.curl_encoder import curl_encoder
+from alf.algorithms.data_transformer import FrameStacker, ImageScaleTransformer
+from alf.algorithms.sac_algorithm import SacAlgorithm
+from alf.environments.suite_dmc2gym import dmc2gym_loader
+from alf.networks import (NormalProjectionNetwork, ActorDistributionNetwork,
+                          CriticNetwork)
+from alf.optimizers import Adam
+from alf.utils.dist_utils import calc_default_target_entropy
+from alf.utils.losses import element_wise_squared_loss
+from alf.utils.math_ops import clipped_exp
+
+seed = 1
+
+alf.config(
+    'create_environment',
+    env_name='dmc2gym',
+    env_load_fn=dmc2gym_loader,
+    num_parallel_environments=1,
+    nonparallel=True,
+    seed=seed)
+
+actor_network_cls = partial(
+    ActorDistributionNetwork,
+    fc_layer_params=(256, 256),
+    continuous_projection_net_ctor=partial(
+        NormalProjectionNetwork,
+        state_dependent_std=True,
+        squash_mean=False,
+        scale_distribution=True,
+        std_transform=clipped_exp))
+
+critic_network_cls = partial(
+    CriticNetwork,
+    use_naive_parallel_network=True,
+    joint_fc_layer_params=(256, 256))
+
+alf.config(
+    'SacAlgorithm',
+    actor_network_cls=actor_network_cls,
+    critic_network_cls=critic_network_cls,
+    actor_optimizer=Adam(lr=2e-4),
+    critic_optimizer=Adam(lr=2e-4),
+    alpha_optimizer=Adam(lr=1e-4, betas=(0.5, 0.999)),
+    target_entropy=partial(calc_default_target_entropy, min_prob=0.1),
+    target_update_tau=0.01,
+    initial_log_alpha=math.log(0.1),
+    target_update_period=2)
+
+alf.config(
+    'OneStepTDLoss', gamma=0.99, td_error_loss_fn=element_wise_squared_loss)
+
+alf.config(
+    'curl_encoder',
+    feature_dim=50,
+    crop_size=84,
+    optimizer=Adam(lr=2e-4),
+    save_image=False,
+    use_pytorch_randcrop=True)
+
+alf.config(
+    'Agent',
+    rl_algorithm_cls=SacAlgorithm,
+    representation_learner_cls=curl_encoder)
+
+alf.config('FrameStacker', stack_size=3)
+
+alf.config(
+    'TrainerConfig',
+    initial_collect_steps=1000,
+    mini_batch_length=2,
+    unroll_length=1,
+    mini_batch_size=512,
+    num_updates_per_train_iter=1,
+    whole_replay_buffer_training=False,
+    clear_replay_buffer=False,
+    num_iterations=1000000,
+    num_checkpoints=20,
+    evaluate=False,
+    debug_summaries=True,
+    summarize_grads_and_vars=True,
+    summary_interval=100,
+    replay_buffer_length=100000,
+    algorithm_ctor=Agent,
+    profiling=False,
+    data_transformer_ctor=[FrameStacker, ImageScaleTransformer],
+)
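# [Editor's note, not part of the patch] A conf file like this is normally
# launched through ALF's trainer entry point; the exact invocation below is
# assumed from ALF's usual workflow, and the root_dir is hypothetical:
#
#   python -m alf.bin.train \
#       --conf alf/examples/curl_cheetah_run_conf.py \
#       --root_dir ~/tmp/curl_cheetah_run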