feature(rjy): add crowd md env new, and multi-head policy #230

Open · wants to merge 24 commits into base: main
Commits (24)
c27ae92  env(rjy): add crowdsim env (Jun 8, 2023)
c6acd7d  config(rjy): add mz/ez config for crowdsim (Nov 13, 2023)
0d235ab  Merge branch 'main' into rjy-crowd-2 (Apr 6, 2024)
ad0cd02  env(rjy): add crowdsim env (Apr 7, 2024)
14542a1  feature(rjy): add RGCN for represent net (Apr 8, 2024)
dc4a774  feature(rjy): add obs/action env mode. fix rgcn pipeline. (May 1, 2024)
c99db40  feature(rjy): add multi-head policy(combine logits) (May 3, 2024)
61831f1  feature(rjy): modify new env with transmitted data (May 3, 2024)
9599faa  feature(rjy): add rough vis of crowdsim (May 5, 2024)
15d9a44  polish(rjy): fix new env info in collecter (May 5, 2024)
3c8804d  feature(rjy): add sez mlp_multi-head (May 6, 2024)
c6723a0  feature(rjy): set the environment to two modes (May 6, 2024)
e100fe4  Merge branch 'rjy-crowd-md-com-sez' into rjy-crowd-md-env-new (May 6, 2024)
c4e9d58  feature(rjy): add ez multi-head model (May 7, 2024)
f677af1  Merge branch 'rjy-crowd-md-com-ez' into rjy-crowd-md-env-new (May 7, 2024)
cb044af  polish(rjy): add v_trans in config (May 7, 2024)
715b5b8  fix(rjy): fix env bug (Jun 11, 2024)
fecf5d3  feature(rjy): add entropy info/set margin (Jun 14, 2024)
c4da015  Merge pull request #1 from nighood/rjy-crowd-md-env-scale (nighood, Jun 14, 2024)
63d37a8  polish(rjy): polish code according to comments (Jun 20, 2024)
745f0a8  Merge branch 'rjy-crowd-md-env-new' of github.com:nighood/LightZero i… (Jun 20, 2024)
f41a41b  polish(pu): reformat zoo/crowd_sim/ (puyuan1996, Feb 14, 2025)
0a7af42  Merge tag 'main' of https://github.com/opendilab/LightZero into rjy-c… (puyuan1996, Feb 14, 2025)
62c38a2  Merge branch 'main' into rjy-crowd-md-env-new (puyuan1996, Feb 14, 2025)
11 changes: 7 additions & 4 deletions lzero/agent/efficientzero.py
@@ -110,6 +110,9 @@ def __init__(
             elif self.cfg.policy.model.model_type == 'conv':
                 from lzero.model.efficientzero_model import EfficientZeroModel
                 model = EfficientZeroModel(**self.cfg.policy.model)
+            elif self.cfg.policy.model.model_type == 'mlp_md':
+                from lzero.model.efficientzero_model_md import EfficientZeroModelMD
+                model = EfficientZeroModelMD(**self.cfg.policy.model)
             else:
                 raise NotImplementedError
         if self.cfg.policy.cuda and torch.cuda.is_available():
@@ -124,8 +127,8 @@ def __init__(
         self.env_fn, self.collector_env_cfg, self.evaluator_env_cfg = get_vec_env_setting(self.cfg.env)
 
     def train(
-            self,
-            step: int = int(1e7),
+        self,
+        step: int = int(1e7),
     ) -> TrainingReturn:
         """
         Overview:
@@ -356,8 +359,8 @@ def deploy(
         return EvalReturn(eval_value=np.mean(reward_list), eval_value_std=np.std(reward_list))
 
     def batch_evaluate(
-            self,
-            n_evaluator_episode: int = None,
+        self,
+        n_evaluator_episode: int = None,
     ) -> EvalReturn:
         """
         Overview:
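For orientation, here is a minimal sketch of how the new 'mlp_md' branch is reached, assuming the usual LightZero agent config layout; the observation and action sizes below are hypothetical, and the constructor simply receives the model config dict, as in the diff above.

    from easydict import EasyDict

    # Hypothetical config; only model_type matters for the dispatch shown above.
    cfg = EasyDict(dict(policy=dict(model=dict(
        model_type='mlp_md',   # new in this PR: the multi-head ("md") MLP model
        observation_shape=10,  # hypothetical flat observation size
        action_space_size=5,   # hypothetical action count
    ))))

    if cfg.policy.model.model_type == 'mlp_md':
        # Module path taken from the diff; kwargs mirror the agent code.
        from lzero.model.efficientzero_model_md import EfficientZeroModelMD
        model = EfficientZeroModelMD(**cfg.policy.model)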
14 changes: 10 additions & 4 deletions lzero/agent/muzero.py
@@ -110,6 +110,12 @@ def __init__(
             elif self.cfg.policy.model.model_type == 'conv':
                 from lzero.model.muzero_model import MuZeroModel
                 model = MuZeroModel(**self.cfg.policy.model)
+            elif self.cfg.policy.model.model_type == 'rgcn':
+                from lzero.model.muzero_model_gcn import MuZeroModelGCN
+                model = MuZeroModelGCN(**self.cfg.policy.model)
+            elif self.cfg.policy.model.model_type == 'mlp_md':
+                from lzero.model.muzero_model_md import MuZeroModelMD
+                model = MuZeroModelMD(**self.cfg.policy.model)
             else:
                 raise NotImplementedError
         if self.cfg.policy.cuda and torch.cuda.is_available():
@@ -124,8 +130,8 @@ def __init__(
         self.env_fn, self.collector_env_cfg, self.evaluator_env_cfg = get_vec_env_setting(self.cfg.env)
 
     def train(
-            self,
-            step: int = int(1e7),
+        self,
+        step: int = int(1e7),
     ) -> TrainingReturn:
         """
         Overview:
@@ -356,8 +362,8 @@ def deploy(
         return EvalReturn(eval_value=np.mean(reward_list), eval_value_std=np.std(reward_list))
 
     def batch_evaluate(
-            self,
-            n_evaluator_episode: int = None,
+        self,
+        n_evaluator_episode: int = None,
     ) -> EvalReturn:
         """
         Overview:
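The commit history describes the multi-head policy as "combine logits". As a purely hypothetical illustration of one standard construction (not necessarily what the *ModelMD classes implement), per-head logits for a factored action space can be broadcast-added in log-space to score joint actions:

    import torch

    head_a = torch.randn(8, 5)  # logits for sub-action A (batch of 8, 5 choices)
    head_b = torch.randn(8, 3)  # logits for sub-action B (3 choices)

    # joint[b, i, j] = head_a[b, i] + head_b[b, j]: summing the logits of
    # independent heads yields logits over the 5 x 3 joint action space,
    # and softmax over them equals the product of the per-head softmaxes.
    joint_logits = (head_a.unsqueeze(-1) + head_b.unsqueeze(-2)).reshape(8, -1)  # [8, 15]
    probs = torch.softmax(joint_logits, dim=-1)
    assert torch.allclose(probs.sum(-1), torch.ones(8))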
18 changes: 13 additions & 5 deletions lzero/agent/sampled_efficientzero.py
@@ -93,7 +93,12 @@ def __init__(
         cfg.main_config.exp_name = exp_name
         self.origin_cfg = cfg
         self.cfg = compile_config(
-            cfg.main_config, seed=seed, env=None, auto=True, policy=SampledEfficientZeroPolicy, create_cfg=cfg.create_config
+            cfg.main_config,
+            seed=seed,
+            env=None,
+            auto=True,
+            policy=SampledEfficientZeroPolicy,
+            create_cfg=cfg.create_config
         )
         self.exp_name = self.cfg.exp_name
@@ -110,6 +115,9 @@ def __init__(
             elif self.cfg.policy.model.model_type == 'conv':
                 from lzero.model.sampled_efficientzero_model import SampledEfficientZeroModel
                 model = SampledEfficientZeroModel(**self.cfg.policy.model)
+            elif self.cfg.policy.model.model_type == 'mlp_md':
+                from lzero.model.sampled_efficientzero_model_md import SampledEfficientZeroModelMD
+                model = SampledEfficientZeroModelMD(**self.cfg.policy.model)
             else:
                 raise NotImplementedError
         if self.cfg.policy.cuda and torch.cuda.is_available():
@@ -124,8 +132,8 @@ def __init__(
         self.env_fn, self.collector_env_cfg, self.evaluator_env_cfg = get_vec_env_setting(self.cfg.env)
 
     def train(
-            self,
-            step: int = int(1e7),
+        self,
+        step: int = int(1e7),
     ) -> TrainingReturn:
         """
         Overview:
@@ -356,8 +364,8 @@ def deploy(
         return EvalReturn(eval_value=np.mean(reward_list), eval_value_std=np.std(reward_list))
 
     def batch_evaluate(
-            self,
-            n_evaluator_episode: int = None,
+        self,
+        n_evaluator_episode: int = None,
     ) -> EvalReturn:
         """
         Overview:
28 changes: 19 additions & 9 deletions lzero/mcts/utils.py
@@ -6,8 +6,9 @@
 from graphviz import Digraph
 
 
-def generate_random_actions_discrete(num_actions: int, action_space_size: int, num_of_sampled_actions: int,
-                                     reshape=False):
+def generate_random_actions_discrete(
+    num_actions: int, action_space_size: int, num_of_sampled_actions: int, reshape=False
+):
     """
     Overview:
         Generate a list of random actions.
@@ -19,10 +20,7 @@ def generate_random_actions_discrete(num_actions: int, action_space_size: int, n…
     Returns:
         A list of random actions.
     """
-    actions = [
-        np.random.randint(0, action_space_size, num_of_sampled_actions).reshape(-1)
-        for _ in range(num_actions)
-    ]
+    actions = [np.random.randint(0, action_space_size, num_of_sampled_actions).reshape(-1) for _ in range(num_actions)]
 
     # If num_of_sampled_actions == 1, flatten the actions to a list of numbers
     if num_of_sampled_actions == 1:
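Between the two hunks, a brief usage sketch of the helper as reformatted above; the return shapes follow from the visible list comprehension (the reshape branch is truncated in this view), and the example values are illustrative.

    import numpy as np
    from lzero.mcts.utils import generate_random_actions_discrete

    actions = generate_random_actions_discrete(
        num_actions=3,             # number of independent draws
        action_space_size=5,       # each action is an int in [0, 5)
        num_of_sampled_actions=2,  # samples per draw
    )
    # -> a list of three 1-D arrays, e.g. [array([4, 0]), array([3, 3]), array([1, 2])]
    # With num_of_sampled_actions == 1 the result is flattened to plain ints,
    # per the comment above, e.g. [2, 0, 4].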
@@ -97,7 +95,7 @@ def prepare_observation(observation_list, model_type='conv'):
     Returns:
         - np.ndarray: Reshaped array of observations.
     """
-    assert model_type in ['conv', 'mlp', 'conv_context', 'mlp_context'], "model_type must be either 'conv' or 'mlp'"
+    assert model_type in ['conv', 'mlp', 'conv_context', 'mlp_context', 'rgcn', 'mlp_md'], "model_type must be either 'conv' or 'mlp'"
     observation_array = np.array(observation_list)
     batch_size = observation_array.shape[0]
@@ -109,13 +107,25 @@
         # Reshape to [B, S*C, W, H]
         _, stack_num, channels, width, height = observation_array.shape
         observation_array = observation_array.reshape(batch_size, stack_num * channels, width, height)
 
-    elif model_type in ['mlp', 'mlp_context']:
+    elif model_type in ['mlp', 'mlp_md', 'mlp_context']:
         if observation_array.ndim == 3:
             # Flatten the last two dimensions
             observation_array = observation_array.reshape(batch_size, -1)
         else:
             raise ValueError("For 'mlp' model_type, the observation must have 3 dimensions [B, S, O]")
+    elif model_type == 'rgcn':
+        if observation_array.ndim == 4:
+            # TODO(rjy): strage process
Reviewer (Collaborator): What does "strage process" mean?

Reviewer (Collaborator): Please explain what each of the cases under 'rgcn' means.
+            # observation_array should be reshaped to [B, S*M, O], where M is the agent number
+            # now observation_array.shape = [B, S, M, O]
+            observation_array = observation_array.reshape(batch_size, -1, observation_array.shape[-1])
+        elif observation_array.ndim == 3:
+            # Flatten the last two dimensions
+            observation_array = observation_array.reshape(batch_size, -1)
+        else:
+            raise ValueError(
+                "For 'rgcn' model_type, the observation must have 3 dimensions [B, S, O] or 4 dimensions [B, S, M, O]"
+            )
 
     return observation_array

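A small numpy check of the reshapes above, with illustrative sizes (B=8 batch, S=4 stacked observations, M=3 agents, O=10 features per agent):

    import numpy as np

    obs = np.zeros((8, 4, 3, 10))  # [B, S, M, O]
    batch_size = obs.shape[0]

    # 'rgcn' with 4-D input: merge the stack and agent axes -> [B, S*M, O]
    assert obs.reshape(batch_size, -1, obs.shape[-1]).shape == (8, 12, 10)

    # 'mlp' / 'mlp_md' / 'mlp_context' with 3-D input: flatten -> [B, S*O]
    assert np.zeros((8, 4, 10)).reshape(batch_size, -1).shape == (8, 40)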
8 changes: 3 additions & 5 deletions lzero/model/common.py
@@ -7,6 +7,8 @@
"""
import math
from dataclasses import dataclass

import itertools
from typing import Callable, List, Optional
from typing import Tuple

@@ -250,6 +252,7 @@ def remove_hooks(self):

 class DownSample(nn.Module):
 
+
     def __init__(self, observation_shape: SequenceType, out_channels: int,
                  activation: nn.Module = nn.ReLU(inplace=True),
                  norm_type: Optional[str] = 'BN',
@@ -903,11 +906,6 @@ def __init__(
         self.conv1x1_value = nn.Conv2d(num_channels, value_head_channels, 1)
         self.conv1x1_policy = nn.Conv2d(num_channels, policy_head_channels, 1)
 
-        if observation_shape[1] == 96:
-            latent_shape = (observation_shape[1] / 16, observation_shape[2] / 16)
-        elif observation_shape[1] == 64:
-            latent_shape = (observation_shape[1] / 8, observation_shape[2] / 8)
-
         if norm_type == 'BN':
             self.norm_value = nn.BatchNorm2d(value_head_channels)
             self.norm_policy = nn.BatchNorm2d(policy_head_channels)
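The deleted branch computed a latent_shape that appears unused, and its '/' division would have produced float tuples. Should that size ever be needed again, here is a hedged sketch (a hypothetical helper, not part of the PR) using integer division, assuming the 96-to-1/16 and 64-to-1/8 downsampling ratios visible in the removed lines:

    def latent_hw(observation_shape):
        # Spatial size of the latent plane after downsampling, using //
        # instead of the removed float '/'.
        h, w = observation_shape[1], observation_shape[2]
        ratio = 16 if h == 96 else 8  # ratios from the removed branch
        return h // ratio, w // ratio

    assert latent_hw((3, 96, 96)) == (6, 6)
    assert latent_hw((3, 64, 64)) == (8, 8)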