diff --git a/ding/entry/tests/test_serial_entry.py b/ding/entry/tests/test_serial_entry.py index d36f6bc717..c6e537883d 100644 --- a/ding/entry/tests/test_serial_entry.py +++ b/ding/entry/tests/test_serial_entry.py @@ -651,7 +651,9 @@ def test_discrete_dt(): from ding.data import create_dataset from ding.config import compile_config from ding.model import DecisionTransformer + from ding.model.template.elastic_decision_transformer import ElasticDecisionTransformer from ding.policy import DTPolicy + from ding.policy.edt import EDTPolicy from ding.framework.middleware import interaction_evaluator, trainer, CkptSaver, \ OfflineMemoryDataFetcher, offline_logger, termination_checker ding_init(config[0]) diff --git a/ding/model/template/elastic_decision_transformer.py b/ding/model/template/elastic_decision_transformer.py new file mode 100644 index 0000000000..0d6cbb0ff2 --- /dev/null +++ b/ding/model/template/elastic_decision_transformer.py @@ -0,0 +1,556 @@ +""" +This is the implementation of elastic decision transformer + +Reference: https://github.com/kristery/Elastic-DT/blob/master/decision_transformer/model.py +""" +import math +from typing import Union, Optional, Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F +from ding.utils import SequenceType + +class MaskedCausalAttention(nn.Module): + def __init__( + self, + h_dim: int, + max_T: int, + n_heads: int, + drop_p: float, + mgdt: bool = False, + dt_mask: bool = False, + att_mask: Optional[torch.Tensor] = None, + num_inputs: int = 4, + real_rtg: bool = False # currently not used to change the attention mask since it will make sampling more complicated + ) -> None: + """ + Overview: + The implementation of masked causal attention in decision transformer. + + Arguments: + - h_dim (:obj:`int`): The dimension of the hidden layers, such as 128. + - max_T (:obj:`int`): The max context length of the attention, such as 6. + - n_heads (:obj:`int`): The number of heads in calculating attention, such as 8. + - drop_p (:obj:`float`): The drop rate of the drop-out layer, such as 0.1. + - mgdt (:obj:`bool`): If use multi-game decision transformer. + - dt_mask (:obj:`bool`): If use decision transformer mask. + - att_mask (:obj:`Optional[torch.Tensor]`): Define attention mask manually of default. + - num_inputs (:obj:`int`): The number of inputs when mgdt mode is used. + - real_rth (:obj:`bool`): + """ + super().__init__() + + self.n_heads = n_heads + self.max_T = max_T + self.num_inputs=num_inputs + self.real_rtg=real_rtg + + self.q_net = nn.Linear(h_dim, h_dim) + self.k_net = nn.Linear(h_dim, h_dim) + self.v_net = nn.Linear(h_dim, h_dim) + + self.proj_net = nn.Linear(h_dim, h_dim) + + self.att_drop = nn.Dropout(drop_p) + self.proj_drop = nn.Dropout(drop_p) + + if att_mask is not None: + mask = att_mask + else: + ones = torch.ones((max_T, max_T)) + mask = torch.tril(ones).view(1, 1, max_T, max_T) + if (mgdt and not dt_mask): + # need to mask the return except for the first return entry + # this is the default practice used by their notebook + # for every inference, we first estimate the return value for the first return + # then we estimate the action for at timestamp t + # it is actually not mentioned in the paper. 
(ref: ret_sample_fn, single_return_token) + # mask other ret entries (s, R, a, s, R, a) + period = num_inputs + ret_order = 2 + ret_masked_rows = torch.arange(period + ret_order-1, max_T, period).long() + # print(ret_masked_rows) + # print(max_T, ret_masked_rows, mask.shape) + mask[:, :, :, ret_masked_rows] = 0 + + # register buffer makes sure mask does not get updated + # during backpropagation + self.register_buffer("mask", mask) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Overview: + MaskedCausalAttention forward computation graph, input a sequence tensor \ + and return a tensor with the same shape. + + Arguments: + - x (:obj:`torch.Tensor`): The input tensor. + + Returns: + - out (:obj:`torch.Tensor`): Output tensor, the shape is the same as the input. + + Examples: + >>> inputs = torch.randn(2, 4, 64) + >>> model = MaskedCausalAttention(64, 5, 4, 0.1) + >>> outputs = model(inputs) + >>> assert outputs.shape == torch.Size([2, 4, 64]) + """ + B, T, C = x.shape # batch size, seq length, h_dim * n_heads + + N, D = self.n_heads, C // self.n_heads, + # N = num heads, D = attention dim + + # rearrange q, k, v as (B, N, T, D) + q = self.q_net(x).view(B, T, N, D).transpose(1, 2) + k = self.k_net(x).view(B, T, N, D).transpose(1, 2) + v = self.v_net(x).view(B, T, N, D).transpose(1, 2) + + # weights (B, N, T, T) + weights = q @ k.transpose(2, 3) / math.sqrt(D) + # causal mask applied to weights + #print(f"shape of weights: {weights.shape}, shape of mask: {self.mask.shape}, T: {T}") + weights = weights.masked_fill( + self.mask[..., :T, :T] == 0, float("-inf") + ) + # normalize weights, all -inf -> 0 after softmax + normalized_weights = F.softmax(weights, dim=-1) + + # attention (B, N, T, D) + attention = self.att_drop(normalized_weights @ v) + + # gather heads and project (B, N, T, D) -> (B, T, N*D) + attention = attention.transpose(1, 2).contiguous().view(B, T, N * D) + + out = self.proj_drop(self.proj_net(attention)) + return out + +class Block(nn.Module): + def __init__(self, + h_dim: int, + max_T: int, + n_heads: int, + drop_p: float, + mgdt: bool=False, + dt_mask: bool=False, + att_mask: Optional[torch.Tensor]=None, + num_inputs: int=4, + real_rtg: bool=False + ) -> None: + """ + Overview: + The decision transformer block based on MaskedCasualAttention. + + Arguments: + - h_dim (:obj:`int`): The dimension of the hidden layers, such as 128. + - max_T (:obj:`int`): The max context length of the attention, such as 6. + - n_heads (:obj:`int`): The number of heads in calculating attention, such as 8. + - drop_p (:obj:`float`): The drop rate of the drop-out layer, such as 0.1. + - mgdt (:obj:`bool`): If use multi-game decision transformer. + - dt_mask (:obj:`bool`): If use decision transformer mask. + - att_mask (:obj:`Optional[torch.Tensor]`): Define attention mask manually of default. + - num_inputs (:obj:`int`): The number of inputs when mgdt mode is used. 
+ - real_rth (:obj:`bool`): + """ + super().__init__() + self.num_inputs = num_inputs + self.attention = MaskedCausalAttention( + h_dim, + max_T, + n_heads, + drop_p, + mgdt=mgdt, + dt_mask=dt_mask, + att_mask=att_mask, + num_inputs=num_inputs, + real_rtg=real_rtg + ) + self.mlp = nn.Sequential( + nn.Linear(h_dim, 4 * h_dim), + nn.GELU(), + nn.Linear(4 * h_dim, h_dim), + nn.Dropout(drop_p), + ) + self.ln1 = nn.LayerNorm(h_dim) + self.ln2 = nn.LayerNorm(h_dim) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Overview: + Forward computation graph of the decision transformer block, input a sequence tensor + and return a tensor with the same shape. + + Arguments: + - x (:obj:`torch.Tensor`): The input tensor. + + Returns: + - output (:obj:`torch.Tensor`): Output tensor, the shape is the same as the input. + + Examples: + >>> inputs = torch.randn(2, 4, 64) + >>> model = Block(64, 5, 4, 0.1) + >>> outputs = model(inputs) + >>> outputs.shape == torch.Size([2, 4, 64]) + """ + # Attention -> LayerNorm -> MLP -> LayerNorm + # print(f"shape of x: {x.shape}, shape of attention: {self.attention(x).shape}") + x = x + self.attention(x) # residual + x = self.ln1(x) + x = x + self.mlp(x) # residual + x = self.ln2(x) + return x + + +class DecisionTransformer(nn.Module): + """ + Overview: + The implementation of decision transformer. + Interfaces: + ``__init__``, ``forward``, ``configure_optimizers`` + """ + + def __init__( + self, + state_dim: int, + act_dim: int, + n_blocks: int, + h_dim: int, + context_len: int, + n_heads: int, + drop_p: float, + env_name: str, + max_timestep: int = 4096, + num_bin: int = 120, + dt_mask: bool = False, + rtg_scale: int =1000, + ) -> None: + """ + Overview: + Initialize the DecisionTransformer Model according to input arguments. + + Arguments: + - state_dim (:obj:`int`): Dimension of state, such as 17. + - act_dim (:obj:`int`): The dimension of actions, such as 6. + - n_blocks (:obj:`int`): The number of transformer blocks in the decision transformer, such as 3. + - h_dim (:obj:`int`): The dimension of the hidden layers, such as 128. + - context_len (:obj:`int`): The max context length of the attention, such as 6. + - n_heads (:obj:`int`): The number of heads in calculating attention, such as 8. + - drop_p (:obj:`float`): The drop rate of the drop-out layer, such as 0.1. + - env_name (:obj:`str`): The name of environment. + - max_timestep (:obj:`int`): The max length of the total sequence, defaults to be 4096. + - num_bin (:obj:`int`): Number of return output bins, such as 60. + - dt_mask (:obj:`bool`): Whether use mask in the blocks of Decision Transformer. + - rtg_scale (:obj:`int`): The scale factor for normalizing the return-to-go values during training. 
+ """ + super().__init__() + + self.state_dim = state_dim + self.act_dim = act_dim + self.h_dim = h_dim + self.num_bin = num_bin + # for return scaling + self.env_name = env_name + self.rtg_scale = rtg_scale + + ### transformer blocks + input_seq_len = 4 * context_len + blocks = [Block( + h_dim, + input_seq_len, + n_heads, + drop_p, + mgdt=True, + dt_mask=dt_mask, + ) for _ in range(n_blocks) + ] + self.transformer = nn.Sequential(*blocks) + + ### projection heads (project to embedding) + self.embed_ln = nn.LayerNorm(h_dim) + self.embed_timestep = nn.Embedding(max_timestep, h_dim) + self.embed_rtg = torch.nn.Linear(1, h_dim) + self.embed_state = torch.nn.Linear(state_dim, h_dim) + self.embed_reward = torch.nn.Linear(1, h_dim) + + # continuous actions + self.embed_action = torch.nn.Linear(act_dim, h_dim) + use_action_tanh = True # True for continuous actions + + + ### prediction heads + self.predict_rtg = torch.nn.Linear(h_dim, int(num_bin)) + self.predict_state = torch.nn.Linear(h_dim, state_dim) + self.predict_action = nn.Sequential( + *( + [nn.Linear(h_dim, act_dim)] + + ([nn.Tanh()] if use_action_tanh else []) + ) + ) + self.predict_reward = torch.nn.Linear(h_dim, 1) + + def forward( + self, + timesteps: torch.Tensor, + states: torch.Tensor, + actions: torch.Tensor, + returns_to_go: torch.Tensor, + rewards: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + + B, T, _ = states.shape + + returns_to_go = returns_to_go.float() + # returns_to_go = ( + # encode_return( + # self.env_name, returns_to_go, num_bin=self.num_bin, rtg_scale=self.rtg_scale + # ) + # - self.num_bin / 2 + # ) / (self.num_bin / 2) + time_embeddings = self.embed_timestep(timesteps) + + # time embeddings are treated similar to positional embeddings + state_embeddings = self.embed_state(states) + time_embeddings + action_embeddings = self.embed_action(actions) + time_embeddings + returns_embeddings = self.embed_rtg(returns_to_go) + time_embeddings + rewards_embeddings = self.embed_reward(rewards) + time_embeddings + + + # stack rtg, states and actions and reshape sequence as + # (r_0, s_0, a_0, r_1, s_1, a_1, r_2, s_2, a_2 ...) + h = ( + torch.stack( + ( + state_embeddings, + returns_embeddings, + action_embeddings, + rewards_embeddings, + ), + dim=1, + ) + .permute(0, 2, 1, 3) + .reshape(B, 4 * T, self.h_dim) + ) + + h = self.embed_ln(h) + + # transformer and prediction + h = self.transformer(h) + + h = h.reshape(B, T, 4, self.h_dim).permute(0, 2, 1, 3) + + # get predictions + return_preds = self.predict_rtg(h[:, 0]) # predict next rtg given s + state_preds = self.predict_state( + h[:, 3] + ) # predict next state given s, R, a, r + action_preds = self.predict_action( + h[:, 1] + ) # predict action given s, R + reward_preds = self.predict_reward( + h[:, 2] + ) # predict reward given s, R, a + + return state_preds, action_preds, return_preds, reward_preds + + +# a version that does not use reward at all +class ElasticDecisionTransformer(DecisionTransformer): + """ + Overview: + The implementation of elsatic decision transformer. 
+ Interfaces: + ``__init__``, ``forward`` + """ + def __init__( + self, + state_dim: int, + act_dim: int, + n_blocks: int, + h_dim: int, + context_len: int, + n_heads: int, + drop_p: float, + env_name: str, + max_timestep: int = 4096, + num_bin: int = 120, + dt_mask: bool = False, + rtg_scale: int = 1000, + num_inputs: int = 3, + real_rtg: bool = False, + is_continuous: bool = True, # True for continuous action + ) -> None: + """ + Overview: + Initialize the Elastic Decision Transformer Model. The definition of Elastic Decision Transformer \ + is defined based on Decision Transformer. + + Arguments: + - state_dim (:obj:`int`): Dimension of state, such as 17. + - act_dim (:obj:`int`): The dimension of actions, such as 6. + - n_blocks (:obj:`int`): The number of transformer blocks in the decision transformer, such as 3. + - h_dim (:obj:`int`): The dimension of the hidden layers, such as 128. + - context_len (:obj:`int`): The max context length of the attention, such as 6. + - n_heads (:obj:`int`): The number of heads in calculating attention, such as 8. + - drop_p (:obj:`float`): The drop rate of the drop-out layer, such as 0.1. + - max_timestep (:obj:`int`): The max length of the total sequence, defaults to be 4096. + - num_bin (:obj:`int`): Number of return output bins, such as 60. + - dt_mask (:obj:`bool`): Whether use mask in the blocks of Decision Transformer. + - rtg_scale (:obj:`int`): The scale factor for normalizing the return-to-go values during training. + - num_inputs (:obj:`int`): The input arguments of EDT. 3 for state, return, action while 4 for state, return, action, reward. + - real_rtg (:obj:`bool`): Realized return-to-go, which represents the actual cumulative return from the current state to the end of the episode. + - is_continuous (:obj:`bool`): True for continuous action, while False for discrete action. 
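+            - env_name (:obj:`str`): The name of environment, such as 'HalfCheetah-v3'.
+
+        .. note::
+            When ``num_inputs`` is 3, each timestep contributes three tokens stacked as \
+            (state, return-to-go, action); when ``num_inputs`` is 4, a reward token is appended as \
+            (state, return-to-go, action, reward) and ``rewards`` must then be passed to ``forward``.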
+ """ + super().__init__(state_dim, act_dim, n_blocks, h_dim, context_len, n_heads, drop_p, \ + env_name, max_timestep=max_timestep, num_bin=num_bin, dt_mask=dt_mask, rtg_scale=rtg_scale, + ) + # return, state, action + self.num_inputs = num_inputs + self.is_continuous = is_continuous + input_seq_len = num_inputs * context_len + blocks = [ + Block( + h_dim, + input_seq_len, + n_heads, + drop_p, + mgdt=True, + dt_mask=dt_mask, + num_inputs=num_inputs, + real_rtg=real_rtg, + ) + for _ in range(n_blocks) + ] + self.transformer = nn.Sequential(*blocks) + + ### projection heads (project to embedding) + self.embed_ln = nn.LayerNorm(h_dim) + self.embed_timestep = nn.Embedding(max_timestep, h_dim) + self.embed_rtg = torch.nn.Linear(1, h_dim) + self.embed_state = torch.nn.Linear(state_dim, h_dim) + self.embed_reward = torch.nn.Linear(1, h_dim) + # # discrete actions + if not self.is_continuous: + self.embed_action = torch.nn.Embedding(18, h_dim) + else: + self.embed_action = torch.nn.Linear(act_dim, h_dim) + + ### prediction heads + self.predict_rtg = torch.nn.Linear(h_dim, int(num_bin)) + self.predict_rtg2 = torch.nn.Linear(h_dim, 1) + self.predict_state = torch.nn.Linear(h_dim + act_dim, state_dim) + self.predict_action = nn.Sequential( + *([nn.Linear(h_dim, act_dim)] + ([nn.Tanh()] if is_continuous else [])) + ) + self.predict_reward = torch.nn.Linear(h_dim, 1) + + def forward( + self, + timesteps: torch.Tensor, + states: torch.Tensor, + actions: torch.Tensor, + returns_to_go: torch.Tensor, + *args, + **kwargs + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Overview: + Forward computation graph of the decision transformer, input a sequence tensor \ + and return a tensor with the same shape. Suppose B is batch size and T is context length. + + Arguments: + - timesteps (:obj:`torch.Tensor`): The timestep for input sequence with shape (B, T). + - states (:obj:`torch.Tensor`): The sequence of states with shape (B, T, S) where S is state size. + - actions (:obj:`torch.Tensor`): The sequence of actions with shape (B, T, A) where A is action size. + - returns_to_go (:obj:`torch.Tensor`): The sequence of return-to-go with shape (B, T, 1). + - rewards (:obj:`Optional[torch.Tensor]`): The sequence of rewards obtained at each timestep with shape (B, T, 1). \ + If provided and `num_inputs` is 4, it will be used in the computation. + + Returns: + - output (:obj:`Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]`): Output contains 5 tensors, \ + they are correspondingly `state_preds`, `action_preds`, `return_preds`, `return_preds2`, `reward_preds`. + + Examples: + >>> B, T, S, A, H, N = 5, 23, 17, 7, 64, 121 + >>> # B: batch_size + >>> # T: length + >>> # S: state_dim + >>> # A: action_dim + >>> # H: hidden_din + >>> # N: num_bin + >>> model = ElasticDecisionTransformer( + ... state_dim=S, + ... act_dim=A, + ... h_dim=H, + ... context_len=T, + ... num_bin=N, + ... n_blocks=5, + ... n_heads=8, # H must be divisible by n_heads + ... drop_p=0.1, + ... env_name="example_env", + ... ) + >>> timesteps = torch.randint(0, 4096, (B, T)) + >>> states = torch.randn(B, T, S) + >>> actions = torch.randn(B, T, A) + >>> returns_to_go = torch.randn(B, T, 1) + >>> rewards = torch.randn(B, T, 1) + >>> state_preds, action_preds, return_preds, return_preds2, reward_preds = model( + ... timesteps, states, actions, returns_to_go + ... 
) + >>> assert state_preds.shape == torch.Size([B, T, S]) + >>> assert action_preds.shape == torch.Size([B, T, A]) + >>> assert return_preds.shape == torch.Size([B, T, N]) + >>> assert return_preds2.shape == torch.Size([B, T, 1]) + >>> assert reward_preds.shape == torch.Size([B, T, 1]) + + """ + B, T, _ = states.shape + returns_to_go = returns_to_go.float() + # returns_to_go = ( + # encode_return( + # self.env_name, returns_to_go, num_bin=self.num_bin, rtg_scale=self.rtg_scale + # ) + # - self.num_bin / 2 + # ) / (self.num_bin / 2) + rewards = kwargs.get("rewards", None) + time_embeddings = self.embed_timestep(timesteps) + state_embeddings = self.embed_state(states) + time_embeddings + action_embeddings = self.embed_action(actions) + time_embeddings + returns_embeddings = self.embed_rtg(returns_to_go) + time_embeddings + if rewards is not None and self.num_inputs == 4: + rewards_embeddings = self.embed_reward(rewards) + time_embeddings + assert self.num_inputs == 3 or 4 + + # stack rtg, states and actions and reshape sequence as + # (r_0, s_0, a_0, r_1, s_1, a_1, r_2, s_2, a_2 ...) + if self.num_inputs == 3: + h = ( + torch.stack((state_embeddings, returns_embeddings, action_embeddings), dim=1,) + .permute(0, 2, 1, 3).reshape(B, self.num_inputs * T, self.h_dim) + ) + elif self.num_inputs == 4: + h = ( + torch.stack((state_embeddings, returns_embeddings, action_embeddings, rewards_embeddings), dim=1,) + .permute(0, 2, 1, 3).reshape(B, self.num_inputs * T, self.h_dim) + ) + h = self.embed_ln(h) + + # transformer and prediction + h = self.transformer(h) + h = h.reshape(B, T, self.num_inputs, self.h_dim).permute(0, 2, 1, 3) + + # get predictions + return_preds = self.predict_rtg(h[:, 0]) # predict next rtg given s + return_preds2 = self.predict_rtg2(h[:, 0]) # predict next rtg with implicit loss + action_preds = self.predict_action(h[:, 1]) # predict action given s, R + state_preds = self.predict_state(torch.cat((h[:, 1], action_preds), 2)) + reward_preds = self.predict_reward(h[:, 2]) # predict reward given s, R, a + + return ( + state_preds, + action_preds, + return_preds, + return_preds2, + reward_preds, + ) + + \ No newline at end of file diff --git a/ding/model/template/tests/test_elastic_decision_transformer.py b/ding/model/template/tests/test_elastic_decision_transformer.py new file mode 100644 index 0000000000..651aadc8ac --- /dev/null +++ b/ding/model/template/tests/test_elastic_decision_transformer.py @@ -0,0 +1,66 @@ +import pytest +from itertools import product +import torch +import torch.nn as nn +import torch.nn.functional as F + +from ding.model.template.elastic_decision_transformer import ElasticDecisionTransformer +from ding.torch_utils import is_differentiable + +@pytest.mark.unittest +def test_elastic_decision_transformer(): + B, T = 4, 6 + state_dim = 3 + act_dim = 2 + num_bin = 120 + model = ElasticDecisionTransformer( + state_dim=state_dim, + act_dim=act_dim, + h_dim=8, + context_len=T, + num_bin=num_bin, + n_blocks=3, + n_heads=2, #! H must be divisible by n_heads + drop_p=0.1, + env_name="example_env", + num_inputs=3, #! 
must be 3 or 4 + is_continuous=True + ) + + timesteps = torch.randint(0, 100, (B, T)) + + states = torch.randn(B, T, state_dim) + + actions = torch.randn(B, T, act_dim) + action_target = torch.randn([B, T, act_dim]) + + returns_to_go_sample = torch.tensor([1, 0.8, 0.6, 0.4, 0.2, 0.]) + returns_to_go = returns_to_go_sample.repeat([B, 1]).unsqueeze(-1) # B x T x 1 + + rewards = torch.randn(B, T, 1) + + traj_mask = torch.ones([B, T], dtype=torch.long) + + assert action_target.shape == (B, T, act_dim) + returns_to_go = returns_to_go.float() + # Forward + state_preds, action_preds, return_preds, return_preds2, reward_preds = model( + timesteps, states, actions, returns_to_go, rewards = rewards + ) + assert state_preds.shape == torch.Size([B, T, state_dim]) + assert action_preds.shape == torch.Size([B, T, act_dim]) + assert return_preds.shape == torch.Size([B, T, num_bin]) + assert return_preds2.shape == torch.Size([B, T, 1]) + assert reward_preds.shape == torch.Size([B, T, 1]) + + action_preds = action_preds.view(-1, act_dim)[traj_mask.view(-1, ) > 0] + action_target = action_target.view(-1, act_dim)[traj_mask.view(-1, ) > 0] + + action_loss = F.mse_loss(action_preds, action_target) + + is_differentiable( + action_loss, [ + model.transformer, model.embed_action, model.predict_action, model.embed_rtg, + model.embed_state + ] + ) diff --git a/ding/policy/edt.py b/ding/policy/edt.py new file mode 100644 index 0000000000..cebe1d65c5 --- /dev/null +++ b/ding/policy/edt.py @@ -0,0 +1,634 @@ +from typing import List, Dict, Any, Tuple, Optional +import math +from collections import namedtuple +import torch.nn.functional as F +import torch +import numpy as np +from ding.torch_utils import to_device +from ding.utils import POLICY_REGISTRY +from ding.utils.data import default_decollate +from .base_policy import Policy + +REF_MIN_SCORE = { + 'halfcheetah' : -280.178953, + 'walker2d' : 1.629008, + 'hopper' : -20.272305, + 'ant' : -325.6, + 'antmaze' : 0.0, +} + +REF_MAX_SCORE = { + 'halfcheetah' : 12135.0, + 'walker2d' : 4592.3, + 'hopper' : 3234.3, + 'ant' : 3879.7, + 'antmaze' : 700, +} + +@POLICY_REGISTRY.register('edt') +class EDTPolicy(Policy): + """ + Overview: + Policy class of Decision Transformer algorithm in discrete environments. + Paper link: https://arxiv.org/abs/2106.01345. + """ + config = dict( + # (str) RL policy register name (refer to function "POLICY_REGISTRY"). + type='edt', + # (bool) Whether to use cuda for network. + cuda=False, + # (bool) Whether the RL algorithm is on-policy or off-policy. + on_policy=False, + # (bool) Whether use priority(priority sample, IS weight, update priority) + priority=False, + # (int) N-step reward for target q_value estimation + obs_shape=4, + action_shape=2, + rtg_scale=1000, # normalize returns to go + max_eval_ep_len=1000, # max len of one episode + batch_size=64, # training batch size + wt_decay=1e-4, # decay weight in optimizer + warmup_steps=10000, # steps for learning rate warmup + context_len=20, # length of transformer input + learning_rate=1e-4, + ) + + def default_model(self) -> Tuple[str, List[str]]: + """ + Overview: + Return this algorithm default neural network model setting for demonstration. ``__init__`` method will \ + automatically call this method to get the default model setting and create model. + Returns: + - model_info (:obj:`Tuple[str, List[str]]`): The registered model name and model's import_names. + + .. 
note:: + The user can define and use customized network model but must obey the same inferface definition indicated \ + by import_names path. For example about DQN, its registered name is ``dqn`` and the import_names is \ + ``ding.model.template.q_learning``. + """ + return 'edt', ['ding.model.template.edt'] + + def _init_learn(self) -> None: + """ + Overview: + Initialize the learn mode of policy, including related attributes and modules. For Decision Transformer, \ + it mainly contains the optimizer, algorithm-specific arguments such as rtg_scale and lr scheduler. + This method will be called in ``__init__`` method if ``learn`` field is in ``enable_field``. + + .. note:: + For the member variables that need to be saved and loaded, please refer to the ``_state_dict_learn`` \ + and ``_load_state_dict_learn`` methods. + + .. note:: + For the member variables that need to be monitored, please refer to the ``_monitor_vars_learn`` method. + + .. note:: + If you want to set some spacial member variables in ``_init_learn`` method, you'd better name them \ + with prefix ``_learn_`` to avoid conflict with other modes, such as ``self._learn_attr1``. + """ + # rtg_scale: scale of `return to go` + # rtg_target: max target of `return to go` + # Our goal is normalize `return to go` to (0, 1), which will favour the covergence. + # As a result, we usually set rtg_scale == rtg_target. + self.env_name = self._cfg.env_id + + self.rtg_scale = self._cfg.rtg_scale # normalize returns to go + self.rtg_target = self._cfg.rtg_target # max target reward_to_go + self.max_eval_ep_len = self._cfg.max_eval_ep_len # max len of one episode + + self.expectile = self._cfg.weights.expectile + self.top_percentile = self._cfg.weights.top_percentile + self.expert_weight = self._cfg.weights.expert_weight + self.exp_loss_weight = self._cfg.weights.exp_loss_weight + self.state_loss_weight = self._cfg.weights.state_loss_weight + self.cross_entropy_weight = self._cfg.weights.cross_entropy_weight + + + + lr = self._cfg.learning_rate # learning rate + wt_decay = self._cfg.wt_decay # weight decay + warmup_steps = self._cfg.warmup_steps # warmup steps for lr scheduler + + self.clip_grad_norm_p = self._cfg.clip_grad_norm_p + + self.context_len = self._cfg.model.context_len # K in decision transformer + self.state_dim = self._cfg.model.state_dim + self.act_dim = self._cfg.model.act_dim + self.num_bin = self._cfg.model.num_bin # num of bin + + self._learn_model = self._model + self._atari_env = 'state_mean' not in self._cfg + self._basic_discrete_env = not self._cfg.model.continuous and 'state_mean' in self._cfg + + if self._atari_env: + self._optimizer = self._learn_model.configure_optimizers(wt_decay, lr) + else: + self._optimizer = torch.optim.AdamW(self._learn_model.parameters(), lr=lr, weight_decay=wt_decay) + + self._scheduler = torch.optim.lr_scheduler.LambdaLR( + self._optimizer, lambda steps: min((steps + 1) / warmup_steps, 1) + ) + + self.max_env_score = -1.0 + + def _forward_learn(self, data: List[torch.Tensor]) -> Dict[str, Any]: + """ + Overview: + Policy forward function of learn mode (training policy and updating parameters). Forward means \ + that the policy inputs some training batch data from the offline dataset and then returns the output \ + result, including various training information such as loss, current learning rate. 
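+            For EDT, the total loss combines an action loss, a next-state prediction loss (weighted by \
+            ``state_loss_weight``), an expectile loss on the implicit return head (weighted by \
+            ``exp_loss_weight``) and, for continuous-action models, a cross-entropy loss over binned \
+            returns-to-go (weighted by ``cross_entropy_weight``).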
+ Arguments: + - data (:obj:`List[torch.Tensor]`): The input data used for policy forward, including a series of \ + processed torch.Tensor data, i.e., timesteps, states, actions, returns_to_go, traj_mask. + Returns: + - info_dict (:obj:`Dict[str, Any]`): The information dict that indicated training result, which will be \ + recorded in text log and tensorboard, values must be python scalar or a list of scalars. For the \ + detailed definition of the dict, refer to the code of ``_monitor_vars_learn`` method. + + .. note:: + The input value can be torch.Tensor or dict/list combinations and current policy supports all of them. \ + For the data type that not supported, the main reason is that the corresponding model does not support it. \ + You can implement you own model rather than use the default model. For more information, please raise an \ + issue in GitHub repo and we will continue to follow up. + + """ + self._learn_model.train() + + timesteps, states, next_states, actions, returns_to_go, rewards, traj_mask = data + + # The shape of `returns_to_go` may differ with different dataset (B x T or B x T x 1), + # and we need a 3-dim tensor + if len(returns_to_go.shape) == 2: + returns_to_go = returns_to_go.unsqueeze(-1) + if len(rewards.shape) == 2: + rewards = rewards.unsqueeze(-1) + # Guarantee return and reward has shape [B, T, 1] + + if self._basic_discrete_env: + actions = actions.to(torch.long) + actions = actions.squeeze(-1) + action_target = torch.clone(actions).detach().to(self._device) # [B, T, A] + state_target = torch.clone(states).detach().to(self._device) # [B, T, S] + return_to_go_target = torch.clone(returns_to_go).detach().to(self._device) + + if self._atari_env: + state_preds, action_preds, return_preds, imp_return_preds, reward_preds = self._learn_model.forward( + timesteps=timesteps, states=states, actions=actions, returns_to_go=returns_to_go, tar=1 + ) + else: + state_preds, action_preds, return_preds, imp_return_preds, reward_preds = self._learn_model.forward( + timesteps=timesteps, states=states, actions=actions, returns_to_go=returns_to_go + ) + def expectile_loss(diff: torch.Tensor, expectile: float=0.8) -> torch.Tensor: + weight = torch.where(diff > 0, expectile, (1 - expectile)) + return weight * (diff**2) + + def cross_entropy(logits, labels): + # labels = F.one_hot(labels.long(), num_classes=int(num_bin)).squeeze(2) + labels = F.one_hot( + labels.long(), num_classes=int(self.num_bin) + ).squeeze() + criterion = torch.nn.CrossEntropyLoss() + return criterion(logits, labels.float()) + + def encode_return(env_name, ret, scale=1.0, num_bin=120, rtg_scale=1000): + env_key = env_name.split("-")[0].lower() + if env_key not in REF_MAX_SCORE: + ret_max = 100 + else: + ret_max = REF_MAX_SCORE[env_key] + if env_key not in REF_MIN_SCORE: + ret_min = -20 + else: + ret_min = REF_MIN_SCORE[env_key] + ret_max /= rtg_scale + ret_min /= rtg_scale + interval = (ret_max - ret_min) / (num_bin-1) + ret = torch.clip(ret, ret_min, ret_max) + return ((ret - ret_min) // interval).float() + + + + if self._atari_env: + action_loss = F.cross_entropy(action_preds.reshape(-1, action_preds.size(-1)), action_target.reshape(-1)) + else: + traj_mask = traj_mask.view(-1, ) + + # only consider non padded elements + action_preds = action_preds.view(-1, self.act_dim)[traj_mask > 0] + state_preds = state_preds.view(-1, self.state_dim)[traj_mask > 0] + imp_return_preds = imp_return_preds.reshape(-1, 1)[traj_mask > 0] + return_preds = return_preds.reshape(-1, int(self.num_bin))[traj_mask > 0] + + + if 
self._cfg.model.continuous: + action_target = action_target.view(-1, self.act_dim)[traj_mask > 0] + action_loss = F.mse_loss(action_preds, action_target) + state_target = next_states.view(-1, self.state_dim)[traj_mask > 0] + state_loss = F.mse_loss(state_preds, state_target) + imp_return_target = returns_to_go.reshape(-1, 1)[traj_mask > 0] + imp_loss = expectile_loss((imp_return_target - imp_return_preds), self.expectile).mean() + return_target = ( + encode_return( + self.env_name, + returns_to_go, + num_bin=self.num_bin, + rtg_scale=self.rtg_scale, + ).float().reshape(-1, 1)[traj_mask > 0] + ) + return_cross_entropy_loss = cross_entropy(return_preds, return_target) + + else: + action_target = action_target.view(-1)[traj_mask > 0] + action_loss = F.cross_entropy(action_preds, action_target) + state_target = next_states.view(-1)[traj_mask > 0] + state_loss = F.cross_entropy(state_preds, state_target) + imp_return_target = returns_to_go.reshape(-1, 1)[traj_mask > 0] + imp_loss = expectile_loss((imp_return_target - imp_return_preds), self.expectile).mean() + + edt_loss = action_loss \ + + state_loss * self.state_loss_weight \ + + imp_loss * self.exp_loss_weight \ + + if self._cfg.model.continuous: + edt_loss += return_cross_entropy_loss * self.cross_entropy_weight + + self._optimizer.zero_grad() + edt_loss.backward() + if self._cfg.multi_gpu: + self.sync_gradients(self._learn_model) + torch.nn.utils.clip_grad_norm_(self._learn_model.parameters(), self.clip_grad_norm_p) + self._optimizer.step() + self._scheduler.step() + + return { + 'cur_lr': self._optimizer.state_dict()['param_groups'][0]['lr'], + 'action_loss': action_loss.detach().cpu().item(), + 'state_loss': state_loss.detach().cpu().item(), + 'implict_loss': imp_loss.detach().cpu().item(), + 'total_loss': edt_loss.detach().cpu().item(), + } + + def _init_eval(self) -> None: + """ + Overview: + Initialize the eval mode of policy, including related attributes and modules. For DQN, it contains the \ + eval model, some algorithm-specific parameters such as context_len, max_eval_ep_len, etc. + This method will be called in ``__init__`` method if ``eval`` field is in ``enable_field``. + + .. tip:: + For the evaluation of complete episodes, we need to maintain some historical information for transformer \ + inference. These variables need to be initialized in ``_init_eval`` and reset in ``_reset_eval`` when \ + necessary. + + .. note:: + If you want to set some spacial member variables in ``_init_eval`` method, you'd better name them \ + with prefix ``_eval_`` to avoid conflict with other modes, such as ``self._eval_attr1``. 
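+        .. note::
+            The per-environment evaluation buffers (``states``, ``actions``, ``rewards_to_go``, ``rewards``) are \
+            allocated with length ``max_eval_ep_len + 2 * context_len`` rather than ``max_eval_ep_len``, which \
+            leaves headroom for the shifted context windows that ``_return_heuristic`` evaluates at every step.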
+ """ + self._eval_model = self._model + # init data + self._device = torch.device(self._device) + + self.real_rtg = self._cfg.real_rtg + self.rtg_scale = self._cfg.rtg_scale # normalize returns to go + self.rtg_target = self._cfg.rtg_target # max target reward_to_go + self.state_dim = self._cfg.model.state_dim + self.act_dim = self._cfg.model.act_dim + self.eval_batch_size = self._cfg.evaluator_env_num + self.max_eval_ep_len = self._cfg.max_eval_ep_len + self.context_len = self._cfg.model.context_len # K in decision transformer + self.expectile = self._cfg.weights.expectile + + self.rs_steps = self._cfg.eval.rs_steps + self.rs_ratio = self._cfg.weights.rs_ratio + self.heuristic = self._cfg.eval.heuristic + self.heuristic_delta = self._cfg.eval.heuristic_delta + + + + self.t = [0 for _ in range(self.eval_batch_size)] + if self._cfg.model.continuous: + self.actions = torch.zeros((self.eval_batch_size, self.max_eval_ep_len + 2 * self.context_len, self.act_dim), + dtype=torch.float32, device=self._device) + else: + # (B, eval_len + 2 * context_len, A) for actions + self.actions = torch.zeros((self.eval_batch_size, self.max_eval_ep_len + 2 * self.context_len, 1), + dtype=torch.long, device=self._device) + + self._atari_env = 'state_mean' not in self._cfg + self._basic_discrete_env = not self._cfg.model.continuous and 'state_mean' in self._cfg + + if self._atari_env: + self.running_rtg = [self.rtg_target for _ in range(self.eval_batch_size)] + self.states = torch.zeros((self.eval_batch_size, self.max_eval_ep_len + 2 * self.context_len,) + tuple(self.state_dim), + dtype=torch.float32, device=self._device) + else: + # (B, eval_len + 2 * context_len, S) for states + self.running_rtg = [self.rtg_target / self.rtg_scale for _ in range(self.eval_batch_size)] + self.states = torch.zeros((self.eval_batch_size, self.max_eval_ep_len + 2 * self.context_len, self.state_dim), + dtype=torch.float32, device=self._device) + self.state_mean = torch.from_numpy(np.array(self._cfg.state_mean)).to(self._device) + self.state_std = torch.from_numpy(np.array(self._cfg.state_std)).to(self._device) + + + self.timesteps = torch.arange(start=0, end=self.max_eval_ep_len + 2 * self.context_len, step=1) + self.timesteps = self.timesteps.repeat(self.eval_batch_size, 1).to(self._device) + + # (B, eval_len + 2 * context_len, 1) for rtg & rewards + self.rewards_to_go = torch.zeros((self.eval_batch_size, self.max_eval_ep_len + 2 * self.context_len, 1), + dtype=torch.float32, device=self._device) + self.rewards = torch.zeros((self.eval_batch_size, self.max_eval_ep_len + 2 * self.context_len, 1), + dtype=torch.float32, device=self._device) + + def decode_return(env_name: str, ret, scale: float=1.0, num_bin: int=120, rtg_scale: int=1000): + env_key = env_name.split("-")[0].lower() + if env_key not in REF_MAX_SCORE: + ret_max = 100 + else: + ret_max = REF_MAX_SCORE[env_key] + if env_key not in REF_MIN_SCORE: + ret_min = -20 + else: + ret_min = REF_MIN_SCORE[env_key] + ret_max /= rtg_scale + ret_min /= rtg_scale + interval = (ret_max - ret_min) / num_bin + return ret * interval + ret_min + + def _return_heuristic(self, + model: torch.nn.Module, + timesteps: torch.Tensor, + states: torch.Tensor, + actions: torch.Tensor, + rewards_to_go: torch.Tensor, + rewards: torch.Tensor, + context_len: int, + t: int, + # top_percentile: float, + # num_bin: int, + # rtg_scale: int, + # expert_weight: float, + # mgdt_sampling: bool = False, + rs_steps: int = 2, + rs_ratio: int = 1, + real_rtg: bool = False, + use_heuristic: bool = False, + 
heuristic_delta: int = 1, + previous_index: Optional[int] = None, + ) -> Tuple[torch.Tensor, int]: + highest_ret = -9999 + estimated_rtg = None + best_i = 0 + best_act = None + if t < context_len: + for i in range(0, math.ceil((t + 1) / rs_ratio), rs_steps): + _, act_preds, ret_preds, imp_ret_preds, _ = model.forward( + timesteps[:, i : context_len + i], + states[:, i : context_len + i], + actions[:, i : context_len + i], + rewards_to_go[:, i : context_len + i], + rewards[:, i : context_len + i], + ) + _, act_preds, ret_preds, imp_ret_preds_pure, _ = model.forward( + timesteps[:, i : context_len + i], + states[:, i : context_len + i], + actions[:, i : context_len + i], + imp_ret_preds, + rewards[:, i : context_len + i], + ) + if not real_rtg: + imp_ret_preds = imp_ret_preds_pure + ret_i = imp_ret_preds[:, t - i].detach().item() + if ret_i > highest_ret: + highest_ret = ret_i + best_i = i + estimated_rtg = imp_ret_preds.detach() + best_act = act_preds[0, t - i].detach() + else: + if use_heuristic: + prev_best_index = context_len - previous_index + loop = (prev_best_index-heuristic_delta, prev_best_index+1+heuristic_delta) + else: + loop = (0, math.ceil(context_len/rs_ratio), rs_steps) + for i in range(*loop): + if use_heuristic and (i < 0 or i >= context_len): + continue + _, act_preds, ret_preds, imp_ret_preds, _ = model.forward( + timesteps[:, t - context_len + 1 + i : t + 1 + i], + states[:, t - context_len + 1 + i : t + 1 + i], + actions[:, t - context_len + 1 + i : t + 1 + i], + rewards_to_go[:, t - context_len + 1 + i : t + 1 + i], + rewards[:, t - context_len + 1 + i : t + 1 + i], + ) + _, act_preds, ret_preds, imp_ret_preds_pure, _ = model.forward( + timesteps[:, t - context_len + 1 + i : t + 1 + i], + states[:, t - context_len + 1 + i : t + 1 + i], + actions[:, t - context_len + 1 + i : t + 1 + i], + imp_ret_preds, + rewards[:, t - context_len + 1 + i : t + 1 + i], + ) + if not real_rtg: + imp_ret_preds = imp_ret_preds_pure + + ret_i = imp_ret_preds[:, -1 - i].detach().item() + if ret_i > highest_ret: + highest_ret = ret_i + best_i = i + # estimated_rtg = imp_ret_preds.detach() + best_act = act_preds[0, -1 - i].detach() + return best_act, context_len - best_i + + def _forward_eval(self, data: Dict[int, Any]) -> Dict[int, Any]: + """ + Overview: + Policy forward function of eval mode (evaluation policy performance, such as interacting with envs. \ + Forward means that the policy gets some input data (current obs/return-to-go and historical information) \ + from the envs and then returns the output data, such as the action to interact with the envs. \ + Arguments: + - data (:obj:`Dict[int, Any]`): The input data used for policy forward, including at least the obs and \ + reward to calculate running return-to-go. The key of the dict is environment id and the value is the \ + corresponding data of the env. + Returns: + - output (:obj:`Dict[int, Any]`): The output data of policy forward, including at least the action. The \ + key of the dict is the same as the input data, i.e. environment id. + + .. note:: + Decision Transformer will do different operations for different types of envs in evaluation. 
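+        .. note::
+            Unlike vanilla Decision Transformer evaluation, EDT does not act on a fixed history length. At every \
+            step, ``_return_heuristic`` re-runs the model on several shifted context windows (i.e. different \
+            effective history lengths), scores each window with the implicit return head (``predict_rtg2``), and \
+            outputs the action predicted under the highest-scoring window. When the ``eval.heuristic`` flag is \
+            True, the search is restricted to a ``heuristic_delta`` neighbourhood of the window chosen at the \
+            previous step.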
+ """ + # save and forward + + + data_id = list(data.keys()) + + self._eval_model.eval() + with torch.no_grad(): + + print(self.t) + best_acts = [] + for i in data_id: + curr_states = self.states[i].unsqueeze(0) + curr_runninng_rtg = self.running_rtg[i] + curr_rewards_to_go = self.rewards_to_go[i].unsqueeze(0) + curr_rewards = self.rewards[i].unsqueeze(0) + curr_actions = self.actions[i].unsqueeze(0) + previous_index = None + for t in range(self.max_eval_ep_len): + if self._atari_env: + curr_states[0, t] = data[i]['obs'].to(self._device) + else: + curr_states[0, t] = (data[i]['obs'].to(self._device) - self.state_mean) / self.state_std + # print(f"curr_states[0, t] 的 shape 是 {curr_states[0, t].shape}, 而 state的shape是{curr_states.shape}") + curr_runninng_rtg = curr_runninng_rtg - (data[i]['reward'] / self.rtg_scale).to(self._device) + curr_rewards_to_go[0, t] = curr_runninng_rtg + curr_rewards[0, t] = data[i]['reward'] + act, best_index = self._return_heuristic( + model=self._eval_model, + timesteps=self.timesteps[i].unsqueeze(0), + states=curr_states, + actions=curr_actions, + rewards_to_go=curr_rewards_to_go, + rewards=curr_rewards, + context_len=self.context_len, + t=t, + rs_steps=self.rs_steps, + rs_ratio=self.rs_ratio, + real_rtg=self.real_rtg, + use_heuristic=self.heuristic, + heuristic_delta=self.heuristic_delta, + previous_index=previous_index + ) + previous_index = best_index + best_acts.append(act) + acts = torch.stack(best_acts, dim=0) + print(f"acts has shape {acts.shape}") + # previous_index = None + # for t in range(self.max_eval_ep_len): + # if self._atari_env: + # self.states[0, t] = data[0]['obs'].to(self._device) + # else: + # self.states[0, t] = (data[0]['obs'].to(self._device) - self.state_mean) / self.state_std + # print(f"self.states[0, t] 的 shape 是 {self.states[0, t].shape}, 而 state的shape是{self.states.shape}") + # self.running_rtg[0] = self.running_rtg[0] - (data[0]['reward'] / self.rtg_scale).to(self._device) + # self.rewards_to_go[0, t] = self.running_rtg[0] + # self.rewards[0, t] = data[0]['reward'] + # act, best_index = self._return_heuristic( + # model=self._eval_model, + # timesteps=self.timesteps, + # states=self.states, + # actions=self.actions, + # rewards_to_go=self.rewards_to_go, + # rewards=self.rewards, + # context_len=self.context_len, + # t=t, + # rs_steps=self.rs_steps, + # rs_ratio=self.rs_ratio, + # real_rtg=self.real_rtg, + # use_heuristic=self.heuristic, + # heuristic_delta=self.heuristic_delta, + # previous_index=previous_index + # ) + # previous_index = best_index + # act = act.unsqueeze(0) + # # print(f"{t} ended! act has shape {act.shape}") + for i in data_id: + self.actions[i, self.t[i]] = acts[i] # TODO: self.actions[i] should be a queue when exceed max_t + self.t[i] += 1 + + if self._cuda: + acts = to_device(acts, 'cpu') + output = {'action': acts} + output = default_decollate(output) + return {i: d for i, d in zip(data_id, output)} + + def _reset_eval(self, data_id: Optional[List[int]] = None) -> None: + """ + Overview: + Reset some statvaeful riables for eval mode when necessary, such as the historical info of transformer \ + for decision transformer. If ``data_id`` is None, it means to reset all the stateful \ + varaibles. Otherwise, it will reset the stateful variables according to the ``data_id``. For example, \ + different environments/episodes in evaluation in ``data_id`` will have different history. 
+ Arguments: + - data_id (:obj:`Optional[List[int]]`): The id of the data, which is used to reset the stateful variables \ + specified by ``data_id``. + """ + # clean data + if data_id is None: + self.t = [0 for _ in range(self.eval_batch_size)] + self.timesteps = torch.arange( + start=0, end=self.max_eval_ep_len + 2 * self.context_len, step=1 + ).repeat(self.eval_batch_size, 1).to(self._device) + if not self._cfg.model.continuous: + self.actions = torch.zeros( + (self.eval_batch_size, self.max_eval_ep_len + 2 * self.context_len, 1), dtype=torch.long, device=self._device + ) + else: + self.actions = torch.zeros( + (self.eval_batch_size, self.max_eval_ep_len + 2 * self.context_len, self.act_dim), + dtype=torch.float32, + device=self._device + ) + if self._atari_env: + self.states = torch.zeros( + ( + self.eval_batch_size, + self.max_eval_ep_len + 2 * self.context_len, + ) + tuple(self.state_dim), + dtype=torch.float32, + device=self._device + ) + self.running_rtg = [self.rtg_target for _ in range(self.eval_batch_size)] + else: + self.states = torch.zeros( + (self.eval_batch_size, self.max_eval_ep_len + 2 * self.context_len, self.state_dim), + dtype=torch.float32, + device=self._device + ) + self.running_rtg = [self.rtg_target / self.rtg_scale for _ in range(self.eval_batch_size)] + + self.rewards_to_go = torch.zeros( + (self.eval_batch_size, self.max_eval_ep_len + 2 * self.context_len, 1), dtype=torch.float32, device=self._device + ) + else: + for i in data_id: + self.t[i] = 0 + if not self._cfg.model.continuous: + self.actions[i] = torch.zeros((self.max_eval_ep_len + 2 * self.context_len, 1), dtype=torch.long, device=self._device) + else: + self.actions[i] = torch.zeros( + (self.max_eval_ep_len + 2 * self.context_len, self.act_dim), dtype=torch.float32, device=self._device + ) + if self._atari_env: + self.states[i] = torch.zeros( + (self.max_eval_ep_len + 2 * self.context_len, ) + tuple(self.state_dim), dtype=torch.float32, device=self._device + ) + self.running_rtg[i] = self.rtg_target + else: + self.states[i] = torch.zeros( + (self.max_eval_ep_len + 2 * self.context_len, self.state_dim), dtype=torch.float32, device=self._device + ) + self.running_rtg[i] = self.rtg_target / self.rtg_scale + self.timesteps[i] = torch.arange(start=0, end=self.max_eval_ep_len + 2 * self.context_len, step=1).to(self._device) + self.rewards_to_go[i] = torch.zeros((self.max_eval_ep_len + 2 * self.context_len, 1), dtype=torch.float32, device=self._device) + self.rewards[i] = torch.zeros((self.max_eval_ep_len + 2 * self.context_len, 1), dtype=torch.float32, device=self._device) + + def _monitor_vars_learn(self) -> List[str]: + """ + Overview: + Return the necessary keys for logging the return dict of ``self._forward_learn``. The logger module, such \ + as text logger, tensorboard logger, will use these keys to save the corresponding data. + Returns: + - necessary_keys (:obj:`List[str]`): The list of the necessary keys to be logged. 
+ """ + return ['cur_lr', 'action_loss'] + + def _init_collect(self) -> None: + pass + + def _forward_collect(self, data: Dict[int, Any], eps: float) -> Dict[int, Any]: + pass + + def _get_train_sample(self, data: List[Dict[str, Any]]) -> List[Dict[str, Any]]: + pass + + def _process_transition(self, obs: Any, policy_output: Dict[str, Any], timestep: namedtuple) -> Dict[str, Any]: + pass diff --git a/ding/utils/data/dataset.py b/ding/utils/data/dataset.py index f29dd3335a..cb0bc04046 100755 --- a/ding/utils/data/dataset.py +++ b/ding/utils/data/dataset.py @@ -843,7 +843,291 @@ def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tenso traj_mask = torch.ones(self.context_len, dtype=torch.long) return timesteps, states, actions, rtgs, traj_mask +@DATASET_REGISTRY.register('edt_d4rl_trajectory') +class EDTTrajectoryDataset(D4RLTrajectoryDataset): + """ + Overview: + D4RL trajectory dataset for EDT, which is used for offline RL algorithms. + Interfaces: + ``__init__``, ``__len__``, ``__getitem__`` + """ + def __init__(self, cfg: dict) -> None: + """ + Overview: + Initialization method. + Arguments: + - cfg (:obj:`dict`): Config dict. + """ + dataset_path = cfg.dataset.data_dir_prefix + rtg_scale = cfg.dataset.rtg_scale + self.context_len = cfg.dataset.context_len + self.env_type = cfg.dataset.env_type + + if 'hdf5' in dataset_path: # for mujoco env + try: + import h5py + import collections + except ImportError: + import sys + logging.warning("not found h5py package, please install it trough `pip install h5py ") + sys.exit(1) + dataset = h5py.File(dataset_path, 'r') + + N = dataset['rewards'].shape[0] + data_ = collections.defaultdict(list) + + use_timeouts = False + if 'timeouts' in dataset: + use_timeouts = True + + episode_step = 0 + paths = [] + for i in range(N): + done_bool = bool(dataset['terminals'][i]) + if use_timeouts: + final_timestep = dataset['timeouts'][i] + else: + final_timestep = (episode_step == 1000 - 1) + for k in ['observations', 'actions', 'rewards', 'terminals']: + data_[k].append(dataset[k][i]) + if done_bool or final_timestep: + episode_step = 0 + episode_data = {} + for k in data_: + episode_data[k] = np.array(data_[k]) + paths.append(episode_data) + data_ = collections.defaultdict(list) + episode_step += 1 + + self.trajectories = paths + + + # calculate state mean and variance and returns_to_go for all traj + states = [] + for traj in self.trajectories: + traj_len = traj['observations'].shape[0] + states.append(traj['observations']) + # calculate returns to go and rescale them + traj['returns_to_go'] = discount_cumsum(traj['rewards'], 1.0) / rtg_scale + # used for input normalization + states = np.concatenate(states, axis=0) + self.state_mean, self.state_std = np.mean(states, axis=0), np.std(states, axis=0) + 1e-6 + + # normalize states + for traj in self.trajectories: + traj['observations'] = (traj['observations'] - self.state_mean) / self.state_std + traj['next_observations'] = (traj['next_observations'] - self.state_mean) / self.state_std + + elif 'pkl' in dataset_path: + if 'dqn' in dataset_path: + # load dataset + with open(dataset_path, 'rb') as f: + self.trajectories = pickle.load(f) + + if isinstance(self.trajectories[0], list): + # for our collected dataset, e.g. 
cartpole/lunarlander case + trajectories_tmp = [] + + original_keys = ['obs', 'next_obs', 'action', 'reward'] + keys = ['observations', 'next_observations', 'actions', 'rewards'] + trajectories_tmp = [ + { + key: np.stack( + [ + self.trajectories[eps_index][transition_index][o_key] + for transition_index in range(len(self.trajectories[eps_index])) + ], + axis=0 + ) + for key, o_key in zip(keys, original_keys) + } for eps_index in range(len(self.trajectories)) + ] + self.trajectories = trajectories_tmp + + states = [] + for traj in self.trajectories: + # traj_len = traj['observations'].shape[0] + states.append(traj['observations']) + # calculate returns to go and rescale them + traj['returns_to_go'] = discount_cumsum(traj['rewards'], 1.0) / rtg_scale + + # used for input normalization + states = np.concatenate(states, axis=0) + self.state_mean, self.state_std = np.mean(states, axis=0), np.std(states, axis=0) + 1e-6 + + # normalize states + for traj in self.trajectories: + traj['observations'] = (traj['observations'] - self.state_mean) / self.state_std + traj['next_observations'] = (traj['next_observations'] - self.state_mean) / self.state_std + else: + # load dataset + with open(dataset_path, 'rb') as f: + self.trajectories = pickle.load(f) + + states = [] + for traj in self.trajectories: + states.append(traj['observations']) + # calculate returns to go and rescale them + traj['returns_to_go'] = discount_cumsum(traj['rewards'], 1.0) / rtg_scale + + # used for input normalization + states = np.concatenate(states, axis=0) + self.state_mean, self.state_std = np.mean(states, axis=0), np.std(states, axis=0) + 1e-6 + + # normalize states + for traj in self.trajectories: + traj['observations'] = (traj['observations'] - self.state_mean) / self.state_std + traj['next_observations'] = (traj['next_observations'] - self.state_mean) / self.state_std + else: + # -- load data from memory (make more efficient) + obss = [] + actions = [] + returns = [0] + done_idxs = [] + stepwise_returns = [] + + transitions_per_buffer = np.zeros(50, dtype=int) + num_trajectories = 0 + while len(obss) < cfg.dataset.num_steps: + buffer_num = np.random.choice(np.arange(50 - cfg.dataset.num_buffers, 50), 1)[0] + i = transitions_per_buffer[buffer_num] + frb = FixedReplayBuffer( + data_dir=cfg.dataset.data_dir_prefix + '/1/replay_logs', + replay_suffix=buffer_num, + observation_shape=(84, 84), + stack_size=4, + update_horizon=1, + gamma=0.99, + observation_dtype=np.uint8, + batch_size=32, + replay_capacity=100000 + ) + if frb._loaded_buffers: + done = False + curr_num_transitions = len(obss) + trajectories_to_load = cfg.dataset.trajectories_per_buffer + while not done: + states, ac, ret, next_states, next_action, next_reward, terminal, indices = \ + frb.sample_transition_batch(batch_size=1, indices=[i]) + states = states.transpose((0, 3, 1, 2))[0] # (1, 84, 84, 4) --> (4, 84, 84) + obss.append(states) + actions.append(ac[0]) + stepwise_returns.append(ret[0]) + if terminal[0]: + done_idxs.append(len(obss)) + returns.append(0) + if trajectories_to_load == 0: + done = True + else: + trajectories_to_load -= 1 + returns[-1] += ret[0] + i += 1 + if i >= 100000: + obss = obss[:curr_num_transitions] + actions = actions[:curr_num_transitions] + stepwise_returns = stepwise_returns[:curr_num_transitions] + returns[-1] = 0 + i = transitions_per_buffer[buffer_num] + done = True + num_trajectories += (cfg.dataset.trajectories_per_buffer - trajectories_to_load) + transitions_per_buffer[buffer_num] = i + + actions = np.array(actions) + returns 
= np.array(returns) + stepwise_returns = np.array(stepwise_returns) + done_idxs = np.array(done_idxs) + + # -- create reward-to-go dataset + start_index = 0 + rtg = np.zeros_like(stepwise_returns) + for i in done_idxs: + i = int(i) + curr_traj_returns = stepwise_returns[start_index:i] + for j in range(i - 1, start_index - 1, -1): # start from i-1 + rtg_j = curr_traj_returns[j - start_index:i - start_index] + rtg[j] = sum(rtg_j) + start_index = i + + # -- create timestep dataset + start_index = 0 + timesteps = np.zeros(len(actions) + 1, dtype=int) + for i in done_idxs: + i = int(i) + timesteps[start_index:i + 1] = np.arange(i + 1 - start_index) + start_index = i + 1 + + self.obss = obss + self.actions = actions + self.done_idxs = done_idxs + self.rtgs = rtg + self.timesteps = timesteps + + def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Overview: + Get the item of the dataset. + Arguments: + - idx (:obj:`int`): The index of the dataset. + """ + if self.env_type != 'atari': + traj = self.trajectories[idx] + traj_len = traj['observations'].shape[0] + + if traj_len > self.context_len: + si = np.random.randint(0, traj_len - self.context_len) + states = torch.from_numpy(traj['observations'][si:si + self.context_len]) + next_states = torch.from_numpy(traj["next_observations"][si:si + self.context_len]) + actions = torch.from_numpy(traj['actions'][si:si + self.context_len]) + returns_to_go = torch.from_numpy(traj['returns_to_go'][si:si + self.context_len]) + rewards = torch.from_numpy(traj["rewards"][si : si + self.context_len]) + timesteps = torch.arange(start=si, end=si + self.context_len, step=1) + + # all ones since no padding + traj_mask = torch.ones(self.context_len, dtype=torch.long) + else: + padding_len = self.context_len - traj_len + + # padding with zeros + states = torch.from_numpy(traj['observations']) + states = torch.cat( + [states, torch.zeros(([padding_len] + list(states.shape[1:])), dtype=states.dtype)], dim=0 + ) + + next_states = torch.from_numpy(traj['next_observations']) + next_states = torch.cat( + [next_states, torch.zeros(([padding_len] + list(next_states.shape[1:])), dtype=states.dtype)], dim=0 + ) + + actions = torch.from_numpy(traj['actions']) + actions = torch.cat( + [actions, torch.zeros(([padding_len] + list(actions.shape[1:])), dtype=actions.dtype)], dim=0 + ) + + returns_to_go = torch.from_numpy(traj['returns_to_go']) + returns_to_go = torch.cat( + [ + returns_to_go, + torch.zeros(([padding_len] + list(returns_to_go.shape[1:])), dtype=returns_to_go.dtype) + ], + dim=0 + ) + + rewards = torch.from_numpy(traj["rewards"]) + rewards = torch.cat( + [ + rewards, + torch.zeros(([padding_len] + list(rewards.shape[1:])), dtype=rewards.dtype,), + ], + dim=0 + ) + timesteps = torch.arange(start=0, end=self.context_len, step=1) + + traj_mask = torch.cat( + [torch.ones(traj_len, dtype=torch.long), + torch.zeros(padding_len, dtype=torch.long)], dim=0 + ) + return timesteps, states, next_states, actions, returns_to_go, rewards, traj_mask @DATASET_REGISTRY.register('d4rl_diffuser') class D4RLDiffuserDataset(Dataset): """ diff --git a/dizoo/d4rl/config/halfcheetah_medium_edt_config.py b/dizoo/d4rl/config/halfcheetah_medium_edt_config.py new file mode 100644 index 0000000000..526d6922ce --- /dev/null +++ b/dizoo/d4rl/config/halfcheetah_medium_edt_config.py @@ -0,0 +1,92 @@ +from easydict import EasyDict +from copy import deepcopy + +halfcheetah_edt_config = dict( + 
exp_name='edt_log/d4rl/halfcheetah/halfcheetah_medium_edt_seed0', + env=dict( + env_id='HalfCheetah-v3', + collector_env_num=1, + evaluator_env_num=8, + use_act_scale=True, + n_evaluator_episode=8, + stop_value=6000, + ), + dataset=dict( + env_type='mujoco', + rtg_scale=1000, + context_len=20, + data_dir_prefix='/d4rl/halfcheetah-medium-v2.pkl', + ), + policy=dict( + env_id='HalfCheetah-v3', + cuda=True, + real_rtg=False, + stop_value=6000, + state_mean=None, + state_std=None, + evaluator_env_num=8, + env_name='HalfCheetah-v3', + rtg_target=6000, # max target return to go + max_eval_ep_len=1000, # max lenght of one episode + wt_decay=1e-4, + warmup_steps=10000, + context_len=20, + weight_decay=0.1, + clip_grad_norm_p=0.25, + model=dict( + state_dim=17, + act_dim=6, + n_blocks=4, + h_dim=512, + context_len=20, + n_heads=4, + drop_p=0.1, + max_timestep=4096, + num_bin=60, + dt_mask=False, + rtg_scale=1000, + num_inputs=3, + real_rtg=False, + continuous=True, + ), + learn=dict(batch_size=128), + learning_rate=1e-4, + weights=dict( + top_percentile=0.15, + expectile=0.99, + expert_weight=10, + exp_loss_weight=0.5, + state_loss_weight=1.0, + cross_entropy_weight=0.001, + rs_ratio=1, # between 1 and 2 + + ), + collect=dict( + data_type='d4rl_trajectory', + unroll_len=1, + ), + eval=dict( + evaluator=dict(eval_freq=1000, ), + rs_steps=2, + heuristic=False, + heuristic_delta=2), + ), +) + +halfcheetah_edt_config = EasyDict(halfcheetah_edt_config) +main_config = halfcheetah_edt_config +halfcheetah_edt_create_config = dict( + env=dict( + type='mujoco', + import_names=['dizoo.mujoco.envs.mujoco_env'], + ), + env_manager=dict(type='subprocess'), + policy=dict(type='edt'), +) +halfcheetah_edt_create_config = EasyDict(halfcheetah_edt_create_config) +create_config = halfcheetah_edt_create_config + +if __name__ == "__main__": + from ding.entry import serial_pipeline_edt + config = deepcopy([main_config, create_config]) + serial_pipeline_edt(config, seed=0, max_train_iter=1000) diff --git a/dizoo/d4rl/config/halfcheetah_medium_expert_edt_config.py b/dizoo/d4rl/config/halfcheetah_medium_expert_edt_config.py new file mode 100644 index 0000000000..80904d1d7d --- /dev/null +++ b/dizoo/d4rl/config/halfcheetah_medium_expert_edt_config.py @@ -0,0 +1,76 @@ +from easydict import EasyDict +from copy import deepcopy + +halfcheetah_dt_config = dict( + exp_name='edt_log/d4rl/halfcheetah/halfcheetah_medium_expert_edt_seed0', + env=dict( + env_id='HalfCheetah-v3', + collector_env_num=1, + evaluator_env_num=8, + use_act_scale=True, + n_evaluator_episode=8, + stop_value=6000, + ), + dataset=dict( + env_type='mujoco', + rtg_scale=1000, + context_len=20, + data_dir_prefix='/d4rl/halfcheetah-medium-expert-v2.pkl', + ), + policy=dict( + cuda=True, + stop_value=6000, + state_mean=None, + state_std=None, + evaluator_env_num=8, + env_name='HalfCheetah-v3', + rtg_target=6000, # max target return to go + max_eval_ep_len=1000, # max lenght of one episode + wt_decay=1e-4, + warmup_steps=10000, + context_len=20, + weight_decay=0.1, + clip_grad_norm_p=0.25, + model=dict( + state_dim=17, + act_dim=6, + n_blocks=4, + h_dim=512, + context_len=20, + n_heads=4, + drop_p=0.1, + max_timestep=4096, + num_bin=60, + dt_mask=False, + rtg_scale=1000, + num_inputs=3, + real_rtg=False, + continuous=True, + ), + learn=dict(batch_size=128), + learning_rate=1e-4, + collect=dict( + data_type='d4rl_trajectory', + unroll_len=1, + ), + eval=dict(evaluator=dict(eval_freq=1000, ), ), + ), +) + +halfcheetah_dt_config = EasyDict(halfcheetah_dt_config) 
+main_config = halfcheetah_dt_config +halfcheetah_dt_create_config = dict( + env=dict( + type='mujoco', + import_names=['dizoo.mujoco.envs.mujoco_env'], + ), + env_manager=dict(type='subprocess'), + policy=dict(type='edt'), +) +halfcheetah_dt_create_config = EasyDict(halfcheetah_dt_create_config) +create_config = halfcheetah_dt_create_config + +if __name__ == "__main__": + from ding.entry import serial_pipeline_edt + config = deepcopy([main_config, create_config]) + serial_pipeline_edt(config, seed=0, max_train_iter=1000) diff --git a/dizoo/d4rl/config/halfcheetah_medium_replay_edt_config.py b/dizoo/d4rl/config/halfcheetah_medium_replay_edt_config.py new file mode 100644 index 0000000000..eab048f855 --- /dev/null +++ b/dizoo/d4rl/config/halfcheetah_medium_replay_edt_config.py @@ -0,0 +1,76 @@ +from easydict import EasyDict +from copy import deepcopy + +halfcheetah_dt_config = dict( + exp_name='edt_log/d4rl/halfcheetah/halfcheetah_medium_replay_edt_seed0', + env=dict( + env_id='HalfCheetah-v3', + collector_env_num=1, + evaluator_env_num=8, + use_act_scale=True, + n_evaluator_episode=8, + stop_value=6000, + ), + dataset=dict( + env_type='mujoco', + rtg_scale=1000, + context_len=20, + data_dir_prefix='/d4rl/halfcheetah-medium-replay-v2.pkl', + ), + policy=dict( + cuda=True, + stop_value=6000, + state_mean=None, + state_std=None, + evaluator_env_num=8, + env_name='HalfCheetah-v3', + rtg_target=6000, # max target return to go + max_eval_ep_len=1000, # max length of one episode + wt_decay=1e-4, + warmup_steps=10000, + context_len=20, + weight_decay=0.1, + clip_grad_norm_p=0.25, + model=dict( + state_dim=17, + act_dim=6, + n_blocks=4, + h_dim=512, + context_len=20, + n_heads=4, + drop_p=0.1, + max_timestep=4096, + num_bin=60, + dt_mask=False, + rtg_scale=1000, + num_inputs=3, + real_rtg=False, + continuous=True, + ), + learn=dict(batch_size=128), + learning_rate=1e-4, + collect=dict( + data_type='d4rl_trajectory', + unroll_len=1, + ), + eval=dict(evaluator=dict(eval_freq=1000, ), ), + ), +) + +halfcheetah_dt_config = EasyDict(halfcheetah_dt_config) +main_config = halfcheetah_dt_config +halfcheetah_dt_create_config = dict( + env=dict( + type='mujoco', + import_names=['dizoo.mujoco.envs.mujoco_env'], + ), + env_manager=dict(type='subprocess'), + policy=dict(type='edt'), +) +halfcheetah_dt_create_config = EasyDict(halfcheetah_dt_create_config) +create_config = halfcheetah_dt_create_config + +if __name__ == "__main__": + from ding.entry import serial_pipeline_edt + config = deepcopy([main_config, create_config]) + serial_pipeline_edt(config, seed=0, max_train_iter=1000) diff --git a/dizoo/d4rl/config/hopper_medium_edt_config.py b/dizoo/d4rl/config/hopper_medium_edt_config.py new file mode 100644 index 0000000000..c00ebe44c1 --- /dev/null +++ b/dizoo/d4rl/config/hopper_medium_edt_config.py @@ -0,0 +1,92 @@ +from easydict import EasyDict +from copy import deepcopy + +hopper_dt_config = dict( + exp_name='edt_log/d4rl/hopper/hopper_medium_edt_seed0', + env=dict( + env_id='Hopper-v3', + collector_env_num=1, + evaluator_env_num=2, + use_act_scale=True, + n_evaluator_episode=2, + stop_value=3600, + ), + dataset=dict( + env_type='mujoco', + rtg_scale=1000, + context_len=20, + data_dir_prefix='/d4rl/hopper-medium-v2.pkl', #! This specifies the directory of the dataset + ), + policy=dict( + env_id='Hopper-v3', + real_rtg=False, + cuda=True, + stop_value=3600, + state_mean=None, + state_std=None, + evaluator_env_num=2, #! 
the evaluator env num in policy should match the env config above + env_name='Hopper-v3', + rtg_target=3600, # max target return to go + max_eval_ep_len=20, # max length of one episode + wt_decay=1e-4, + warmup_steps=10000, + context_len=20, + weight_decay=0.1, + clip_grad_norm_p=0.25, + model=dict( + state_dim=11, + act_dim=3, + n_blocks=3, + h_dim=512, + context_len=20, + n_heads=1, + drop_p=0.1, + max_timestep=4096, + num_bin=60, + dt_mask=False, + rtg_scale=1000, + num_inputs=3, + real_rtg=False, + continuous=True, + ), + learn=dict(batch_size=128,), + learning_rate=1e-4, + weights=dict( + top_percentile=0.15, + expectile=0.99, + expert_weight=10, + exp_loss_weight=0.5, + state_loss_weight=1.0, + cross_entropy_weight=0.001, + rs_ratio=1, # between 1 and 2 + + ), + collect=dict( + data_type='edt_d4rl_trajectory', + unroll_len=1, + ), + eval=dict( + evaluator=dict(eval_freq=1000, ), + rs_steps=2, + heuristic=False, + heuristic_delta=2), + ), +) + +hopper_dt_config = EasyDict(hopper_dt_config) +main_config = hopper_dt_config +hopper_dt_create_config = dict( + env=dict( + type='mujoco', + import_names=['dizoo.mujoco.envs.mujoco_env'], + ), + env_manager=dict(type='subprocess'), + policy=dict(type='edt'), +) +hopper_dt_create_config = EasyDict(hopper_dt_create_config) +create_config = hopper_dt_create_config + +if __name__ == "__main__": + from ding.entry import serial_pipeline_edt + config = deepcopy([main_config, create_config]) + serial_pipeline_edt(config, seed=0, max_train_iter=1000) \ No newline at end of file diff --git a/dizoo/d4rl/config/hopper_medium_expert_edt_config.py b/dizoo/d4rl/config/hopper_medium_expert_edt_config.py new file mode 100644 index 0000000000..3410cbab7d --- /dev/null +++ b/dizoo/d4rl/config/hopper_medium_expert_edt_config.py @@ -0,0 +1,76 @@ +from easydict import EasyDict +from copy import deepcopy + +hopper_edt_config = dict( + exp_name='edt_log/d4rl/hopper/hopper_medium_expert_edt_seed0', + env=dict( + env_id='Hopper-v3', + collector_env_num=1, + evaluator_env_num=8, + use_act_scale=True, + n_evaluator_episode=8, + stop_value=3600, + ), + dataset=dict( + env_type='mujoco', + rtg_scale=1000, + context_len=20, + data_dir_prefix='/d4rl/hopper-medium-expert-v2.pkl', + ), + policy=dict( + cuda=True, + stop_value=3600, + state_mean=None, + state_std=None, + evaluator_env_num=8, + env_name='Hopper-v3', + rtg_target=3600, # max target return to go + max_eval_ep_len=1000, # max length of one episode + wt_decay=1e-4, + warmup_steps=10000, + context_len=20, + weight_decay=0.1, + clip_grad_norm_p=0.25, + model=dict( + state_dim=11, + act_dim=3, + n_blocks=4, + h_dim=512, + context_len=20, + n_heads=4, + drop_p=0.1, + max_timestep=4096, + num_bin=60, + dt_mask=False, + rtg_scale=1000, + num_inputs=3, + real_rtg=False, + continuous=True, + ), + learn=dict(batch_size=128,), + learning_rate=1e-4, + collect=dict( + data_type='d4rl_trajectory', + unroll_len=1, + ), + eval=dict(evaluator=dict(eval_freq=1000, ), ), + ), +) + +hopper_edt_config = EasyDict(hopper_edt_config) +main_config = hopper_edt_config +hopper_edt_create_config = dict( + env=dict( + type='mujoco', + import_names=['dizoo.mujoco.envs.mujoco_env'], + ), + env_manager=dict(type='subprocess'), + policy=dict(type='edt'), +) +hopper_edt_create_config = EasyDict(hopper_edt_create_config) +create_config = hopper_edt_create_config + +if __name__ == "__main__": + from ding.entry import serial_pipeline_edt + config = deepcopy([main_config, create_config]) + serial_pipeline_edt(config, seed=0, max_train_iter=1000) \ No newline at 
end of file diff --git a/dizoo/d4rl/config/hopper_medium_replay_edt_config.py b/dizoo/d4rl/config/hopper_medium_replay_edt_config.py new file mode 100644 index 0000000000..b316fb5694 --- /dev/null +++ b/dizoo/d4rl/config/hopper_medium_replay_edt_config.py @@ -0,0 +1,76 @@ +from easydict import EasyDict +from copy import deepcopy + +hopper_edt_config = dict( + exp_name='edt_log/d4rl/hopper/hopper_medium_replay_edt_seed0', + env=dict( + env_id='Hopper-v3', + collector_env_num=1, + evaluator_env_num=8, + use_act_scale=True, + n_evaluator_episode=8, + stop_value=3600, + ), + dataset=dict( + env_type='mujoco', + rtg_scale=1000, + context_len=20, + data_dir_prefix='/d4rl/hopper-medium-replay-v2.pkl', + ), + policy=dict( + cuda=True, + stop_value=3600, + state_mean=None, + state_std=None, + evaluator_env_num=8, + env_name='Hopper-v3', + rtg_target=3600, # max target return to go + max_eval_ep_len=1000, # max length of one episode + wt_decay=1e-4, + warmup_steps=10000, + context_len=20, + weight_decay=0.1, + clip_grad_norm_p=0.25, + model=dict( + state_dim=11, + act_dim=3, + n_blocks=4, + h_dim=512, + context_len=20, + n_heads=4, + drop_p=0.1, + max_timestep=4096, + num_bin=60, + dt_mask=False, + rtg_scale=1000, + num_inputs=3, + real_rtg=False, + continuous=True, + ), + learn=dict(batch_size=128,), + learning_rate=1e-4, + collect=dict( + data_type='d4rl_trajectory', + unroll_len=1, + ), + eval=dict(evaluator=dict(eval_freq=1000, ), ), + ), +) + +hopper_edt_config = EasyDict(hopper_edt_config) +main_config = hopper_edt_config +hopper_edt_create_config = dict( + env=dict( + type='mujoco', + import_names=['dizoo.mujoco.envs.mujoco_env'], + ), + env_manager=dict(type='subprocess'), + policy=dict(type='edt'), +) +hopper_edt_create_config = EasyDict(hopper_edt_create_config) +create_config = hopper_edt_create_config + +if __name__ == "__main__": + from ding.entry import serial_pipeline_edt + config = deepcopy([main_config, create_config]) + serial_pipeline_edt(config, seed=0, max_train_iter=1000) \ No newline at end of file diff --git a/dizoo/d4rl/config/walker2d_medium_edt_config.py b/dizoo/d4rl/config/walker2d_medium_edt_config.py new file mode 100644 index 0000000000..3d2f31123b --- /dev/null +++ b/dizoo/d4rl/config/walker2d_medium_edt_config.py @@ -0,0 +1,92 @@ +from easydict import EasyDict +from copy import deepcopy + +walk2d_edt_config = dict( + exp_name='edt_log/d4rl/walker2d/walker2d_medium_edt_seed0', + env=dict( + env_id='Walker2d-v3', + collector_env_num=1, + evaluator_env_num=8, + use_act_scale=True, + n_evaluator_episode=8, + stop_value=5000, + ), + dataset=dict( + env_type='mujoco', + rtg_scale=1000, + context_len=20, + data_dir_prefix='/d4rl/walker2d-medium-v2.pkl', + ), + policy=dict( + env_id='Walker2d-v3', + real_rtg=False, + cuda=True, + stop_value=5000, + state_mean=None, + state_std=None, + evaluator_env_num=8, + env_name='Walker2d-v3', + rtg_target=5000, # max target return to go + max_eval_ep_len=1000, # max length of one episode + wt_decay=1e-4, + warmup_steps=10000, + context_len=20, + weight_decay=0.1, + clip_grad_norm_p=0.25, + model=dict( + state_dim=17, + act_dim=6, + n_blocks=4, + h_dim=512, + context_len=20, + n_heads=4, + drop_p=0.1, + max_timestep=4096, + num_bin=60, + dt_mask=False, + rtg_scale=1000, + num_inputs=3, + real_rtg=False, + continuous=True, + ), + learn=dict(batch_size=128), + learning_rate=1e-4, + weights=dict( + top_percentile=0.15, + expectile=0.99, + expert_weight=10, + exp_loss_weight=0.5, + state_loss_weight=1.0, + cross_entropy_weight=0.001, + 
rs_ratio=1, # between 1 and 2 + + ), + collect=dict( + data_type='d4rl_trajectory', + unroll_len=1, + ), + eval=dict( + evaluator=dict(eval_freq=1000, ), + rs_steps=2, + heuristic=False, + heuristic_delta=2), + ), +) + +walk2d_edt_config = EasyDict(walk2d_edt_config) +main_config = walk2d_edt_config +walk2d_edt_create_config = dict( + env=dict( + type='mujoco', + import_names=['dizoo.mujoco.envs.mujoco_env'], + ), + env_manager=dict(type='subprocess'), + policy=dict(type='edt'), +) +walk2d_edt_create_config = EasyDict(walk2d_edt_create_config) +create_config = walk2d_edt_create_config + +if __name__ == "__main__": + from ding.entry import serial_pipeline_edt + config = deepcopy([main_config, create_config]) + serial_pipeline_edt(config, seed=0, max_train_iter=1000) \ No newline at end of file diff --git a/dizoo/d4rl/config/walker2d_medium_expert_edt_config.py b/dizoo/d4rl/config/walker2d_medium_expert_edt_config.py new file mode 100644 index 0000000000..b00f59a54d --- /dev/null +++ b/dizoo/d4rl/config/walker2d_medium_expert_edt_config.py @@ -0,0 +1,76 @@ +from easydict import EasyDict +from copy import deepcopy + +walk2d_edt_config = dict( + exp_name='edt_log/d4rl/walker2d/walker2d_medium_expert_edt_seed0', + env=dict( + env_id='Walker2d-v3', + collector_env_num=1, + evaluator_env_num=8, + use_act_scale=True, + n_evaluator_episode=8, + stop_value=5000, + ), + dataset=dict( + env_type='mujoco', + rtg_scale=1000, + context_len=20, + data_dir_prefix='/d4rl/walker2d-medium-expert-v2.pkl', + ), + policy=dict( + cuda=True, + stop_value=5000, + state_mean=None, + state_std=None, + evaluator_env_num=8, + env_name='Walker2d-v3', + rtg_target=5000, # max target return to go + max_eval_ep_len=1000, # max length of one episode + wt_decay=1e-4, + warmup_steps=10000, + context_len=20, + weight_decay=0.1, + clip_grad_norm_p=0.25, + model=dict( + state_dim=17, + act_dim=6, + n_blocks=4, + h_dim=512, + context_len=20, + n_heads=4, + drop_p=0.1, + max_timestep=4096, + num_bin=60, + dt_mask=False, + rtg_scale=1000, + num_inputs=3, + real_rtg=False, + continuous=True, + ), + learn=dict(batch_size=128), + learning_rate=1e-4, + collect=dict( + data_type='d4rl_trajectory', + unroll_len=1, + ), + eval=dict(evaluator=dict(eval_freq=1000, ), ), + ), +) + +walk2d_edt_config = EasyDict(walk2d_edt_config) +main_config = walk2d_edt_config +walk2d_edt_create_config = dict( + env=dict( + type='mujoco', + import_names=['dizoo.mujoco.envs.mujoco_env'], + ), + env_manager=dict(type='subprocess'), + policy=dict(type='edt'), +) +walk2d_edt_create_config = EasyDict(walk2d_edt_create_config) +create_config = walk2d_edt_create_config + +if __name__ == "__main__": + from ding.entry import serial_pipeline_edt + config = deepcopy([main_config, create_config]) + serial_pipeline_edt(config, seed=0, max_train_iter=1000) diff --git a/dizoo/d4rl/config/walker2d_medium_replay_edt_config.py b/dizoo/d4rl/config/walker2d_medium_replay_edt_config.py new file mode 100644 index 0000000000..95afd4b5b2 --- /dev/null +++ b/dizoo/d4rl/config/walker2d_medium_replay_edt_config.py @@ -0,0 +1,76 @@ +from easydict import EasyDict +from copy import deepcopy + +walk2d_edt_config = dict( + exp_name='edt_log/d4rl/walker2d/walker2d_medium_replay_edt_seed0', + env=dict( + env_id='Walker2d-v3', + collector_env_num=1, + evaluator_env_num=8, + use_act_scale=True, + n_evaluator_episode=8, + stop_value=5000, + ), + dataset=dict( + env_type='mujoco', + rtg_scale=1000, + context_len=20, + data_dir_prefix='/d4rl/walker2d-medium-replay-v2.pkl', + ), + policy=dict( + 
cuda=True, + stop_value=5000, + state_mean=None, + state_std=None, + evaluator_env_num=8, + env_name='Walker2d-v3', + rtg_target=5000, # max target return to go + max_eval_ep_len=1000, # max length of one episode + wt_decay=1e-4, + warmup_steps=10000, + context_len=20, + weight_decay=0.1, + clip_grad_norm_p=0.25, + model=dict( + state_dim=17, + act_dim=6, + n_blocks=4, + h_dim=512, + context_len=20, + n_heads=4, + drop_p=0.1, + max_timestep=4096, + num_bin=60, + dt_mask=False, + rtg_scale=1000, + num_inputs=3, + real_rtg=False, + continuous=True, + ), + learn=dict(batch_size=128), + learning_rate=1e-4, + collect=dict( + data_type='d4rl_trajectory', + unroll_len=1, + ), + eval=dict(evaluator=dict(eval_freq=1000, ), ), + ), +) + +walk2d_edt_config = EasyDict(walk2d_edt_config) +main_config = walk2d_edt_config +walk2d_edt_create_config = dict( + env=dict( + type='mujoco', + import_names=['dizoo.mujoco.envs.mujoco_env'], + ), + env_manager=dict(type='subprocess'), + policy=dict(type='edt'), +) +walk2d_edt_create_config = EasyDict(walk2d_edt_create_config) +create_config = walk2d_edt_create_config + +if __name__ == "__main__": + from ding.entry import serial_pipeline_edt + config = deepcopy([main_config, create_config]) + serial_pipeline_edt(config, seed=0, max_train_iter=1000) diff --git a/dizoo/d4rl/entry/d4rl_edt_mujoco.py b/dizoo/d4rl/entry/d4rl_edt_mujoco.py new file mode 100644 index 0000000000..d56a9dd9ae --- /dev/null +++ b/dizoo/d4rl/entry/d4rl_edt_mujoco.py @@ -0,0 +1,48 @@ +import gym +import torch +import numpy as np +from ditk import logging +from ding.model.template.elastic_decision_transformer import ElasticDecisionTransformer +from ding.policy.edt import EDTPolicy +from ding.envs import BaseEnvManagerV2 +from ding.envs.env_wrappers.env_wrappers import AllinObsWrapper +from ding.data import create_dataset +from ding.config import compile_config +from ding.framework import task, ding_init +from ding.framework.context import OfflineRLContext +from ding.framework.middleware import interaction_evaluator, trainer, CkptSaver, offline_data_fetcher_from_mem, offline_logger, termination_checker +from ding.utils import set_pkg_seed +from dizoo.d4rl.envs import D4RLEnv +from dizoo.d4rl.config.hopper_medium_edt_config import main_config, create_config + + +def main(): + # If you don't have offline data, you need to prepare it first and set the data_path in config + # For demonstration, we can also train an RL policy (e.g. SAC) and collect some data + logging.getLogger().setLevel(logging.INFO) + cfg = compile_config(main_config, create_cfg=create_config, auto=True) + ding_init(cfg) + with task.start(async_mode=False, ctx=OfflineRLContext()): + evaluator_env = BaseEnvManagerV2( + env_fn=[lambda: AllinObsWrapper(D4RLEnv(cfg.env)) for _ in range(cfg.env.evaluator_env_num)], + cfg=cfg.env.manager + ) + + set_pkg_seed(cfg.seed, use_cuda=cfg.policy.cuda) + + dataset = create_dataset(cfg) + # env_data_stats = dataset.get_d4rl_dataset_stats(cfg.policy.dataset_name) + cfg.policy.state_mean, cfg.policy.state_std = dataset.get_state_stats() + model = ElasticDecisionTransformer(**cfg.policy.model) + policy = EDTPolicy(cfg.policy, model=model) + task.use(interaction_evaluator(cfg, policy.eval_mode, evaluator_env)) + task.use(offline_data_fetcher_from_mem(cfg, dataset)) + task.use(trainer(cfg, policy.learn_mode)) + task.use(termination_checker(max_train_iter=5e4)) + task.use(CkptSaver(policy, cfg.exp_name, train_freq=1000)) + task.use(offline_logger()) + task.run() + + +if __name__ == "__main__": + main()
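As a quick sanity check of the reward-to-go preprocessing added to the trajectory dataset above, the per-trajectory loop can be reproduced standalone. This is a minimal sketch, not part of the patch; the array values below are illustrative only, and it mirrors the loop over done_idxs and stepwise_returns shown in the dataset code.

import numpy as np

# toy data: two trajectories with stepwise rewards [1, 2, 3] and [4, 5]
stepwise_returns = np.array([1.0, 2.0, 3.0, 4.0, 5.0])
done_idxs = np.array([3, 5])  # exclusive end index of each trajectory

rtg = np.zeros_like(stepwise_returns)
start_index = 0
for i in done_idxs:
    i = int(i)
    curr_traj_returns = stepwise_returns[start_index:i]
    for j in range(i - 1, start_index - 1, -1):
        # rtg[j] is the sum of rewards from step j to the end of its trajectory
        rtg[j] = curr_traj_returns[j - start_index:i - start_index].sum()
    start_index = i

print(rtg)  # expected: [6. 5. 3. 9. 5.]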