diff --git a/README.md b/README.md
index 2bd04b8e..55b33c1b 100644
--- a/README.md
+++ b/README.md
@@ -1,10 +1,48 @@
-![figure](https://neuralmmo.github.io/_static/banner.png)
+# The Winning Solution for the NeurIPS 2023 Neural MMO Challenge
 
-# ![icon](https://neuralmmo.github.io/_build/html/_images/icon.png) Welcome to the Platform!
+This solution is based on the [Neural MMO Baselines](https://github.com/NeuralMMO/baselines/tree/2.0?tab=readme-ov-file). For more information about the challenge, please refer to the [challenge homepage](https://www.aicrowd.com/challenges/neurips-2023-the-neural-mmo-challenge).
 
-[![](https://dcbadge.vercel.app/api/server/BkMmFUC?style=plastic)](https://discord.gg/BkMmFUC)
-[![Twitter](https://img.shields.io/twitter/url/https/twitter.com/cloudposse.svg?style=social&label=Follow%20%40jsuarez5341)](https://twitter.com/jsuarez5341)
+## How to Use
 
-Baselines repository for Neural MMO. New users should treat this as a starter project. The project is under active development with [Documentation](https://neuralmmo.github.io "Neural MMO Documentation") hosted by github.io.
+### Environment Installation
+Use `docker/build.sh` to build the training image.
+```shell
+cd docker
+bash build.sh
+```
 
-![figure](https://neuralmmo.github.io/_build/html/_images/poster.png)
+### Run Training
+Run inside the training container:
+```shell
+export WANDB_API_KEY=xxx # Change it to yours
+WANDB_PROJECT=xxx # Change it to yours
+WANDB_ENTITY=xxx # Change it to yours
+
+export WANDB_DISABLE_GIT=true
+export WANDB_DISABLE_CODE=true
+
+export OMP_NUM_THREADS=4
+
+python train.py \
+    --runs-dir runs \
+    --use-ray-vecenv true \
+    --wandb-project $WANDB_PROJECT \
+    --wandb-entity $WANDB_ENTITY \
+    --model ReducedModelV2 \
+    --meander-bonus-weight 0.0 \
+    --heal-bonus-weight 0.0 \
+    --num-npcs 128 \
+    --early-stop-agent-num 0 \
+    --resilient-population 0.0 \
+    --ppo-update-epochs 1 \
+    --train-num-steps 40000000 \
+    --num-maps 1280
+```
+
+### Evaluation
+
+After training, copy the checkpoints into `policies` and run:
+```shell
+python evaluate.py -p policies
+```
+`policies/submission.pkl` is the trained model we submitted.
diff --git a/docker/Dockerfile b/docker/Dockerfile
new file mode 100644
index 00000000..d464cbde
--- /dev/null
+++ b/docker/Dockerfile
@@ -0,0 +1,10 @@
+ARG DOCKER_BASE_IMAGE
+FROM ${DOCKER_BASE_IMAGE}
+
+COPY requirements.txt ./
+
+RUN sudo rm /etc/apt/sources.list.d/* && sudo apt-get update && sudo apt-get install -y vim
+
+RUN $HOME/anaconda3/bin/pip --no-cache-dir install -r requirements.txt
+
+RUN sudo rm requirements.txt
diff --git a/docker/build.sh b/docker/build.sh
new file mode 100644
index 00000000..5b84910f
--- /dev/null
+++ b/docker/build.sh
@@ -0,0 +1,5 @@
+DOCKER_BASE_IMAGE="${DOCKER_BASE_IMAGE:-rayproject/ray:1.6.0-py39-gpu}"
+
+docker build --no-cache \
+--build-arg DOCKER_BASE_IMAGE=${DOCKER_BASE_IMAGE} \
+-t ${DOCKER_BASE_IMAGE}-nmmo2023 .
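As a side note on the evaluation step in the README above: `policies/submission.pkl` is a plain pickle that, judging by `export_submission_file.py` later in this diff, bundles the policy source code together with its weights. Below is a minimal sketch for inspecting it, assuming that `{"policy_src": ..., "state_dict": ...}` layout, that the script is run from the repository root, and that torch is installed so the weight tensors can be unpickled.

```python
import pickle

# Minimal sketch: peek inside the bundled submission checkpoint.
# Assumes the {"policy_src": ..., "state_dict": ...} layout written by
# export_submission_file.py; torch must be importable for the tensors.
with open("policies/submission.pkl", "rb") as f:
    checkpoint = pickle.load(f)

print(checkpoint["policy_src"].splitlines()[0])   # first line of the embedded policy source
print(f"{len(checkpoint['state_dict'])} entries in state_dict")
```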
diff --git a/requirements.txt b/docker/requirements.txt similarity index 82% rename from requirements.txt rename to docker/requirements.txt index f38284b8..afb08d50 100644 --- a/requirements.txt +++ b/docker/requirements.txt @@ -9,6 +9,10 @@ scikit-learn==1.3.0 tensorboard==2.11.2 tiktoken==0.4.0 torch==1.13.1 +torchtyping==0.1.4 traitlets==5.9.0 transformers==4.31.0 wandb==0.13.7 + +nmmo==2.0.3 +pufferlib==0.4.5 diff --git a/environment.py b/environment.py index cff247c1..ab87f31a 100644 --- a/environment.py +++ b/environment.py @@ -1,12 +1,55 @@ +from typing import Dict, List, Optional +from nmmo.task.task_spec import TaskSpec +import numpy as np +import dill +import json +from types import SimpleNamespace + from argparse import Namespace import math +import copy import nmmo +from nmmo.lib.log import EventCode +from nmmo.core.observation import Observation +from nmmo.systems.skill import Skills +from nmmo.entity.entity import Entity + import pufferlib import pufferlib.emulation from leader_board import StatPostprocessor, calculate_entropy +_DEBUG_TASK_REWARD = False +_DEBUG_TASK_SETTING = {} # {"harvest": {"Fishing": 0.01}} + +_EVENTS = [ + "EAT_FOOD", + "DRINK_WATER", + "GO_FARTHEST", + "SCORE_HIT", + "PLAYER_KILL", + "CONSUME_ITEM", + "GIVE_ITEM", + "DESTROY_ITEM", + "HARVEST_ITEM", + "EQUIP_ITEM", + "LOOT_ITEM", + "GIVE_GOLD", + "LIST_ITEM", + "EARN_GOLD", + "BUY_ITEM", + "LEVEL_UP", +] +EVENTCODE_TO_EVENT = {getattr(EventCode, _): _ for _ in _EVENTS} +_COLS = [ + "type", + "level", + "number", + "gold", + "target_ent", +] + class Config(nmmo.config.Default): """Configuration for Neural MMO.""" @@ -45,7 +88,13 @@ def __init__( heal_bonus_weight=0, meander_bonus_weight=0, explore_bonus_weight=0, + task_learning_bonus_weight=0, + alive_bonus_weight=0, clip_unique_event=3, + adjust_ori_reward=False, + train_tasks_info=None, + task_reward_settings=None, + debug_print_events=False, ): super().__init__(env, agent_id, eval_mode) self.early_stop_agent_num = early_stop_agent_num @@ -55,10 +104,53 @@ def __init__( self.explore_bonus_weight = explore_bonus_weight self.clip_unique_event = clip_unique_event + self.adjust_ori_reward = adjust_ori_reward + + self.debug_print_events = debug_print_events + + self.alive_bonus_weight = alive_bonus_weight + + # Customized task reward + self.train_tasks_info = train_tasks_info + self._task_index: Optional[int] = ( + None # The index of the current task in `train_tasks_info` + ) + self.task_learning_bonus_weight = task_learning_bonus_weight + self.task_reward_settings = task_reward_settings + self.task_reward_setting: Optional[Dict] = None + + self.prev_done = False + + def _reset_task_reward_state(self) -> None: + self._seen_tiles = { + "co": set(), # collection of coordinates + "last_update_tick": 0, + } + + self._been_tiles = { # tiles visited + "co": set(), + "last_update_tick": 0, + } + + self._last_damage_inflicted = 0 + + self._last_harvest_skill_exp = 0 + + self._history_own = {} # The highest ownership record in history + def reset(self, obs): """Called at the start of each episode""" super().reset(obs) + self.prev_done = False + + if self.task_learning_bonus_weight: + self._update_task_index(obs["Task"]) + self._reset_task_reward_state() + + setting = self._get_task_reward_setting() + self.task_reward_setting = setting + @property def observation_space(self): """If you modify the shape of features, you need to specify the new obs space""" @@ -80,6 +172,10 @@ def action(self, action): def reward_done_info(self, reward, done, info): """Called on 
reward, done, and info before they are returned from the environment""" + self.env: nmmo.Env + + if self.adjust_ori_reward: + reward = self._adjust_ori_reward(reward, done, info) # Stop early if there are too few agents generating the training data if len(self.env.agents) <= self.early_stop_agent_num: @@ -92,39 +188,437 @@ def reward_done_info(self, reward, done, info): # Add "Healing" score based on health increase and decrease due to food and water healing_bonus = 0 - if self.agent_id in self.env.realm.players: + if self.heal_bonus_weight and self.agent_id in self.env.realm.players: if self.env.realm.players[self.agent_id].resources.health_restore > 0: healing_bonus = self.heal_bonus_weight # Add meandering bonus to encourage moving to various directions meander_bonus = 0 - if len(self._last_moves) > 5: + if self.meander_bonus_weight and len(self._last_moves) > 5: move_entropy = calculate_entropy(self._last_moves[-8:]) # of last 8 moves meander_bonus = self.meander_bonus_weight * (move_entropy - 1) # Unique event-based rewards, similar to exploration bonus # The number of unique events are available in self._curr_unique_count, self._prev_unique_count - if self.sqrt_achievement_rewards: - explore_bonus = math.sqrt(self._curr_unique_count) - math.sqrt( - self._prev_unique_count - ) - else: - explore_bonus = min( - self.clip_unique_event, - self._curr_unique_count - self._prev_unique_count, - ) - explore_bonus *= self.explore_bonus_weight + explore_bonus = 0 + if self.explore_bonus_weight: + if self.sqrt_achievement_rewards: + explore_bonus = math.sqrt(self._curr_unique_count) - math.sqrt( + self._prev_unique_count + ) + else: + explore_bonus = min( + self.clip_unique_event, + self._curr_unique_count - self._prev_unique_count, + ) + explore_bonus *= self.explore_bonus_weight + + alive_bonus = 0 + if self.alive_bonus_weight and not done: + alive_bonus = self._get_alive_bonus() + alive_bonus *= self.alive_bonus_weight + + task_learning_bonus = 0 + if self.task_learning_bonus_weight and not done: + task_learning_bonus = self._get_task_learning_bonus() + task_learning_bonus *= self.task_learning_bonus_weight + + if self.debug_print_events and done: + self._print_agent_all_events() reward = reward + explore_bonus + healing_bonus + meander_bonus + reward += alive_bonus + reward += task_learning_bonus + + self.prev_done = done return reward, done, info + def _adjust_ori_reward(self, reward, done, info) -> float: + if not reward: + return reward + + task_infos = list(info["task"].values()) + assert len(task_infos) == 1 + task_info = task_infos[0] + + if reward == -1: + assert done + if task_info["completed"]: + return -0.1 + else: + return -10.0 + + if reward == 1: + assert task_info["completed"] + return 10.0 + + return reward + + @property + def _eval_fn_name(self): + return self.train_tasks_info.eval_fn_name[self._task_index] + + @property + def _eval_fn_kwargs(self): + return self.train_tasks_info.eval_fn_kwargs[self._task_index] + + def _update_task_index(self, task_embedding: np.ndarray) -> None: + """Use task embedding to find the task index""" + if self.eval_mode: + self._task_index = None + return + + assert task_embedding.shape == (4096,) + + # diff = self.train_tasks_info.embedding_mat - task_embedding + # diff = np.sum(diff**2, axis=-1) + # (indexes,) = np.where(diff == 0) + + (indexes,) = np.where( + (self.train_tasks_info.embedding_mat == task_embedding).all(axis=1) + ) + + n_matched_task = len(indexes) + assert ( + n_matched_task == 1 + ), f"{n_matched_task} task match emb 
({task_embedding})" + + self._task_index = int(indexes[0]) + + assert self._task_index < self.train_tasks_info.n + + # if _DEBUG_TASK_REWARD and self.agent_id <= 20: + # print( + # f"agent_id {self.agent_id}, task index {self._task_index}" + # f", {self.train_tasks_info.eval_fn_name[self._task_index]}" + # f", {self.train_tasks_info.eval_fn_kwargs[self._task_index]}" + # ) + + return + + def _get_task_reward_setting(self) -> Dict: + if _DEBUG_TASK_REWARD and _DEBUG_TASK_SETTING: + return _DEBUG_TASK_SETTING + + _eval_fn_name = self._eval_fn_name + _eval_fn_kwargs = self._eval_fn_kwargs + + if _eval_fn_name not in self.task_reward_settings: + # print(f"Reward of eval fn {_eval_fn_name} not set") + return {} + + eval_fn_setting: Dict = self.task_reward_settings[_eval_fn_name] + try: + ret: Dict = eval_fn_setting[_eval_fn_kwargs[eval_fn_setting["_key"]]] + except: + ret: Dict = eval_fn_setting["_default"] + + return ret + + def _get_alive_bonus(self) -> float: + ret = 0 + + cur_tick = self.env.realm.tick + entity: Entity = self.env.realm.players.entities[self.agent_id] + + health_lost = 100 - entity.health.val + if health_lost > 50: + ret += -(health_lost - 50) / 50 * 0.001 + + return ret + + def _get_task_learning_bonus(self) -> float: + ret = 0 + + setting = self.task_reward_setting + + for reward_type, args in setting.items(): + if reward_type == "log": + _reward = self._task_log_bonus(args) + elif reward_type == "log_value": + _reward = self._task_log_bonus(args, use_value=True) + elif reward_type == "wander": + _reward = self._task_wander_bonus(args) + elif reward_type == "wander_occupy": + _reward = self._task_wander_occupy_bonus(args) + elif reward_type == "attack": + _reward = self._task_attack_bonus(args) + elif reward_type == "harvest": + _reward = self._task_harvest_bonus(args) + elif reward_type == "own": + _reward = self._task_own_bonus(args) + else: + raise Exception(f"Invalid reward type {reward_type}") + + ret += _reward + + # if _DEBUG_TASK_REWARD and _reward and self.agent_id <= 20: + # print( + # f"agent_id {self.agent_id}, current_tick {self.env.realm.tick}" + # f", task learning bonus: type {reward_type}, setting {setting}, reward {_reward}" + # f", # players remain {len(self.env.realm.players.entities)}" + # ) + + return ret + + def _task_log_bonus(self, args: Dict, use_value: bool = False) -> float: + """[TASK REWARD] Reward the agent for a specific log""" + ret = 0 + + assert args + + cur_tick = self.env.realm.tick + cur_logs = self.env.realm.event_log.get_data( + agents=[self.agent_id], tick=cur_tick + ) + + attr_to_col = self.env.realm.event_log.attr_to_col + + for line in cur_logs: + event_name = EVENTCODE_TO_EVENT.get(line[attr_to_col["event"]], "") + + if event_name in args: + if use_value: + if event_name == "EARN_GOLD": + value = line[attr_to_col["gold"]] + else: + raise NotImplementedError(event_name) + ret += args[event_name] * value + else: + ret += args[event_name] + + return ret + + def _task_wander_bonus(self, args: Dict) -> float: + """[TASK REWARD]""" + ret = 0 + + per_tile = args["per_tile"] + + obs: Observation = self.env.obs[self.agent_id] + current_tick = obs.current_tick + visible_tiles = obs.tiles + + n_new_seen_tiles = 0 + for tile in visible_tiles: + x, y, t = tile + if (x, y) not in self._seen_tiles["co"]: + n_new_seen_tiles += 1 + self._seen_tiles["co"].add((x, y)) + self._seen_tiles["last_update_tick"] = current_tick + + if current_tick > 1: + ret += n_new_seen_tiles * per_tile + + # if _DEBUG_TASK_REWARD and self.agent_id == 1: + # print( + # 
f"agent_id {self.agent_id}, current_tick {current_tick}" + # f", n_new_seen_tiles {n_new_seen_tiles}" + # ) + + return ret + + def _task_wander_occupy_bonus(self, args: Dict) -> float: + """[TASK REWARD]""" + ret = 0 + + per_tile = args["per_tile"] + + entity: Entity = self.env.realm.players.entities[self.agent_id] + current_tick = self.env.realm.tick + + if entity.pos not in self._been_tiles["co"]: + self._been_tiles["co"].add(entity.pos) + if current_tick > 1: + ret += per_tile + + self._been_tiles["last_update_tick"] = current_tick + + # if _DEBUG_TASK_REWARD and self.agent_id == 1: + # print( + # f"agent_id {self.agent_id}, current_tick {current_tick}" + # f", self._been_tiles {self._been_tiles}, +reward {ret}" + # ) + + return ret + + def _task_attack_bonus(self, args: Dict) -> float: + """[TASK REWARD]""" + ret = 0 + + entity = self.env.realm.players.entities[self.agent_id] + current_tick = self.env.realm.tick + + if entity.history.damage_inflicted > self._last_damage_inflicted: + assert isinstance(entity.history.attack, dict) + attack_style = entity.history.attack["style"] + if attack_style in args: + ret += args[attack_style] + + self._last_damage_inflicted = entity.history.damage_inflicted + + # if _DEBUG_TASK_REWARD and self.agent_id == 1: + # print( + # f"agent_id {self.agent_id}, current_tick {current_tick}" + # f", entity.history.attack {entity.history.attack}" + # f", entity.history.damage_inflicted {entity.history.damage_inflicted}" + # ) + + return ret + + def _task_harvest_bonus(self, args: Dict) -> float: + """[TASK REWARD]""" + ret = 0 + + entity = self.env.realm.players.entities[self.agent_id] + skills: Skills = entity.skills + current_tick = self.env.realm.tick + + skill_names = list(args.keys()) + assert ( + len(skill_names) == 1 + ), f"harvest reward require 1 skill but get {len(skill_names)}" + skill_name = skill_names[0] + + if skill_name == "Fishing": + skill = skills.fishing + elif skill_name == "Herbalism": + skill = skills.herbalism + elif skill_name == "Prospecting": + skill = skills.prospecting + elif skill_name == "Carving": + skill = skills.carving + elif skill_name == "Alchemy": + skill = skills.alchemy + else: + raise Exception(f"Invalid skill {skill_name}") + + cur_skill_exp = skill.exp.val + exp_diff = cur_skill_exp - self._last_harvest_skill_exp + if exp_diff > 0: + ret += args[skill_name] * exp_diff + self._last_harvest_skill_exp = cur_skill_exp + + # if _DEBUG_TASK_REWARD and self.agent_id <= 20: + # print( + # f"agent_id {self.agent_id}, current_tick {current_tick}" + # f", skill {skill_name}, exp {cur_skill_exp}, exp_diff {exp_diff}" + # ) + + return ret + + def _task_own_bonus(self, args: Dict) -> float: + """[TASK REWARD]""" + ret = 0 + + entity: Entity = self.env.realm.players.entities[self.agent_id] + current_tick = self.env.realm.tick + + packet = entity.inventory.packet() + for item in packet["items"]: + item_type = item["item"] + level = item["level"] + quantity = item["quantity"] + + reward_coef = args.get(item_type, args.get("", 0.0)) + if not reward_coef: + continue + + if item_type not in self._history_own: + self._history_own[item_type] = {} + if level not in self._history_own[item_type]: + self._history_own[item_type][level] = 0 + + quantity_diff = quantity - self._history_own[item_type][level] + + if quantity_diff > 0: + self._history_own[item_type][level] = quantity + ret += quantity_diff * level * reward_coef + + # if _DEBUG_TASK_REWARD and ret: + # print( + # f"agent_id {self.agent_id}, current_tick {current_tick}" + # f", 
_history_own {self._history_own}, +reward {ret}" + # ) + + return ret + + def _print_agent_all_events(self): + print(f"== agent_id {self.agent_id}'s logs ==") + log = self.env.realm.event_log.get_data(agents=[self.agent_id]) + self._print_events_log(log, self.env.realm.event_log.attr_to_col) + + @staticmethod + def _print_events_log(log, attr_to_col): + for line in log: + event_name = EVENTCODE_TO_EVENT.get(line[attr_to_col["event"]], "") + tick = line[attr_to_col["tick"]] + print( + f"tick {tick}, event {event_name}: " + + ", ".join([f"{_} {line[attr_to_col[_]]}" for _ in _COLS]) + ) + + +def get_tasks_info_for_reward_setting(tasks_path: str) -> SimpleNamespace: + with open(tasks_path, "rb") as f: + curriculums: List[TaskSpec] = dill.load(f) + + print(f"Load {len(curriculums)} train curriculums") + + ret = SimpleNamespace( + embedding_mat=None, # The matrix formed by concatenating all task embeddings + eval_fn_name=[], + eval_fn_kwargs=[], + n=0, + ) + + _mat = [] + + for curriculum in curriculums: + eval_fn_kwargs = { + key: value if isinstance(value, (str, int, float)) else value.__name__ + for key, value in curriculum.eval_fn_kwargs.items() + } + + _mat.append(curriculum.embedding) + ret.eval_fn_name.append(curriculum.eval_fn.__name__) + ret.eval_fn_kwargs.append(eval_fn_kwargs) + ret.n += 1 + + ret.embedding_mat = np.vstack(_mat) + + return ret + + +def load_task_reward_settings(path: str) -> Dict: + print(f"Load task reward setting {path}") + with open(path, "r") as f: + ret = json.load(f) + return ret + def make_env_creator(args: Namespace): # TODO: Max episode length + + use_task_reward = ( + not args.eval_mode + and args.task_reward_setting_path + and args.task_learning_bonus_weight + ) + + train_tasks_info = ( + get_tasks_info_for_reward_setting(args.tasks_path) if use_task_reward else None + ) + task_reward_settings = ( + load_task_reward_settings(args.task_reward_setting_path) + if use_task_reward + else None + ) + def env_creator(): """Create an environment.""" - env = nmmo.Env(Config(args)) + env = nmmo.Env(Config(args), seed=args.seed) env = pufferlib.emulation.PettingZooPufferEnv( env, postprocessor_cls=Postprocessor, @@ -135,6 +629,12 @@ def env_creator(): "heal_bonus_weight": args.heal_bonus_weight, "meander_bonus_weight": args.meander_bonus_weight, "explore_bonus_weight": args.explore_bonus_weight, + "task_learning_bonus_weight": args.task_learning_bonus_weight, + "alive_bonus_weight": args.alive_bonus_weight, + "adjust_ori_reward": args.adjust_ori_reward, + "train_tasks_info": train_tasks_info, + "task_reward_settings": task_reward_settings, + "debug_print_events": args.debug_print_events, }, ) return env diff --git a/evaluate.py b/evaluate.py index 2aed6e85..c0472980 100644 --- a/evaluate.py +++ b/evaluate.py @@ -18,7 +18,7 @@ from nmmo.task.task_spec import make_task_from_spec import pufferlib -from pufferlib.vectorization import Serial, Multiprocessing +from pufferlib.vectorization import Serial, Multiprocessing, Ray from pufferlib.policy_store import DirectoryPolicyStore from pufferlib.frameworks import cleanrl import pufferlib.policy_ranker @@ -203,7 +203,7 @@ def select_policies(self, policies): return [next(loop) for _ in range(self._num)] -def rank_policies(policy_store_dir, eval_curriculum_file, device): +def rank_policies(policy_store_dir, eval_curriculum_file, device, debug=False): # CHECK ME: can be custom models with different architectures loaded here? 
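    # When the new `debug` flag is set, evaluation runs a single environment
    # (instead of five) on the Multiprocessing backend rather than the Ray
    # vec-env, and prints each agent's event log via `debug_print_events`,
    # trading throughput for easier inspection.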
policy_store = setup_policy_store(policy_store_dir) policy_ranker = create_policy_ranker(policy_store_dir) @@ -213,13 +213,15 @@ def rank_policies(policy_store_dir, eval_curriculum_file, device): args = SimpleNamespace(**config.Config.asdict()) args.data_dir = policy_store_dir args.eval_mode = True - args.num_envs = 5 # sample a bit longer in each env + args.num_envs = 1 if debug else 5 # sample a bit longer in each env args.num_buffers = 1 args.learner_weight = 0 # evaluate mode args.selfplay_num_policies = num_policies + 1 args.early_stop_agent_num = 0 # run the full episode args.resilient_population = 0 # no resilient agents args.tasks_path = eval_curriculum_file # task-conditioning + args.use_ray_vecenv = False if debug else True + args.debug_print_events = debug # NOTE: This creates a dummy learner agent. Is it necessary? from reinforcement_learning import policy # import your policy @@ -241,7 +243,11 @@ def make_policy(envs): env_creator_kwargs={}, agent_creator=make_policy, data_dir=policy_store_dir, - vectorization=Multiprocessing, + vectorization=( + Serial + if args.use_serial_vecenv + else (Ray if args.use_ray_vecenv else Multiprocessing) + ), num_envs=args.num_envs, num_cores=args.num_envs, num_buffers=args.num_buffers, @@ -367,6 +373,11 @@ def make_policy(envs): default=None, help="The index of the task to assign in the curriculum file", ) + parser.add_argument( + "--debug", + action="store_true", + help="Debug mode (Default: False). Print events.", + ) # Parse and check the arguments eval_args = parser.parse_args() @@ -387,4 +398,9 @@ def make_policy(envs): else: logging.info("Ranking checkpoints from %s", eval_args.policy_store_dir) logging.info("Replays will NOT be generated") - rank_policies(eval_args.policy_store_dir, eval_args.task_file, eval_args.device) + rank_policies( + eval_args.policy_store_dir, + eval_args.task_file, + eval_args.device, + eval_args.debug, + ) diff --git a/export_submission_file.py b/export_submission_file.py new file mode 100644 index 00000000..34219d9b --- /dev/null +++ b/export_submission_file.py @@ -0,0 +1,71 @@ +""" +Generate pkl file for submission +""" + +import os +import pickle +import torch + +POLICY_PY_NAME = "policy_reduce_v2.py" +POLICY_CLASS_NAME = "ReducedModelV2" +# .pth file +CHECKPOINT_TO_SUBMIT = "" +OUT_NAME = CHECKPOINT_TO_SUBMIT + ".pkl" + +# replace policy.py with your file +custom_policy_file = "reinforcement_learning/" + POLICY_PY_NAME +assert os.path.exists(custom_policy_file), "CANNOT find the policy file" +print(custom_policy_file) + +# replace checkpoint with +checkpoint_to_submit = CHECKPOINT_TO_SUBMIT +assert os.path.exists(checkpoint_to_submit), "CANNOT find the checkpoint file" +assert checkpoint_to_submit.endswith( + "_state.pth" +), "the checkpoint file must end with _state.pth" +print(checkpoint_to_submit) + + +def create_custom_policy_pt(policy_file, pth_file, out_name="my_submission.pkl"): + assert out_name.endswith(".pkl"), "The file name must end with .pkl" + with open(policy_file, "r") as f: + src_code = f.read() + + # add the make_policy() function + # YOU SHOULD CHECK the name of your policy (if not Baseline), + # and the args that go into the policy + src_code += f""" + +class Config(nmmo.config.Default): + PROVIDE_ACTION_TARGETS = True + PROVIDE_NOOP_ACTION_TARGET = True + MAP_FORCE_GENERATION = False + TASK_EMBED_DIM = 4096 + COMMUNICATION_SYSTEM_ENABLED = False + +def make_policy(): + from pufferlib.frameworks import cleanrl + env = pufferlib.emulation.PettingZooPufferEnv(nmmo.Env(Config())) + # 
Parameters to your model should match your configuration + learner_policy = {POLICY_CLASS_NAME}( + env, + input_size=256, + hidden_size=256, + task_size=4096 + ) + return cleanrl.Policy(learner_policy) + """ + state_dict = torch.load(pth_file, map_location="cpu") + checkpoint = { + "policy_src": src_code, + "state_dict": state_dict, + } + with open(out_name, "wb") as out_file: + pickle.dump(checkpoint, out_file) + + +create_custom_policy_pt( + custom_policy_file, + checkpoint_to_submit, + out_name=OUT_NAME, +) diff --git a/leader_board.py b/leader_board.py index 85fd4064..feddf464 100644 --- a/leader_board.py +++ b/leader_board.py @@ -11,6 +11,10 @@ from nmmo.core.realm import Realm from nmmo.lib.log import EventCode import nmmo.systems.item as Item +import nmmo +from nmmo.entity.entity import Entity + +DISABLE_SCORE_HIT = False @dataclass @@ -159,7 +163,7 @@ def _reset_episode_stats(self): self._last_moves = [] self._last_price = 0 - def _update_stats(self, agent): + def _update_stats(self, agent: Entity): task = self.env.agent_task_map[agent.ent_id][0] # For each task spec, record whether its max progress and reward count self._curriculum[task.spec_name].append( @@ -218,6 +222,7 @@ def action(self, action): def reward_done_info(self, reward, done, info): """Update stats + info and save replays.""" + self.env: nmmo.Env # Remove the task from info. Curriculum info is processed in _update_stats() info.pop("task", None) @@ -237,7 +242,7 @@ def reward_done_info(self, reward, done, info): if "stats" not in info: info["stats"] = {} - agent = self.env.realm.players.dead_this_tick.get( + agent: Entity = self.env.realm.players.dead_this_tick.get( self.agent_id, self.env.realm.players.get(self.agent_id) ) assert agent is not None @@ -432,6 +437,13 @@ def extract_unique_event(log, attr_to_col): idx, attr_to_col["tick"] ].copy() # this is a hack + if DISABLE_SCORE_HIT: + return set( + tuple(row) + for row in log[:, attr_to_col["event"] :] + if row[0] != EventCode.SCORE_HIT + ) + # return unique events after masking return set(tuple(row) for row in log[:, attr_to_col["event"] :]) diff --git a/policies/submission.pkl b/policies/submission.pkl new file mode 100644 index 00000000..80a138d8 Binary files /dev/null and b/policies/submission.pkl differ diff --git a/reinforcement_learning/clean_pufferl.py b/reinforcement_learning/clean_pufferl.py index 5012fc5f..23767ffc 100644 --- a/reinforcement_learning/clean_pufferl.py +++ b/reinforcement_learning/clean_pufferl.py @@ -26,6 +26,29 @@ import pufferlib.utils import pufferlib.vectorization +TASK_EVAL_FN_NAMES = [ + "CountEvent", + "CanSeeTile", + "AttainSkill", + "PracticeSkillWithTool", + "TickGE", + "OccupyTile", + "CanSeeAgent", + "CanSeeGroup", + "ScoreHit", + "HoardGold", + "EarnGold", + "SpendGold", + "MakeProfit", + "PracticeInventoryManagement", + "OwnItem", + "EquipItem", + "ConsumeItem", + "HarvestItem", + "ListItem", + "BuyItem", +] + def unroll_nested_dict(d): if not isinstance(d, dict): @@ -58,6 +81,7 @@ class CleanPuffeRL: device: str = torch.device("cuda") if torch.cuda.is_available() else "cpu" total_timesteps: int = 10_000_000 learning_rate: float = 2.5e-4 + weight_decay: float = 0.0 num_buffers: int = 1 num_envs: int = 8 num_cores: int = psutil.cpu_count(logical=False) @@ -70,6 +94,8 @@ class CleanPuffeRL: policy_pool: pufferlib.policy_pool.PolicyPool = None policy_selector: pufferlib.policy_ranker.PolicySelector = None + as_fine_tune: bool = False + # Wandb wandb_entity: str = None wandb_project: str = None @@ -94,12 +120,17 @@ def 
__post_init__(self, *args, **kwargs): f"with policy {resume_state['policy_checkpoint_name']}" ) - self.wandb_run_id = resume_state.get("wandb_run_id", None) - self.learning_rate = resume_state.get("learning_rate", self.learning_rate) - - self.global_step = resume_state.get("global_step", 0) - self.agent_step = resume_state.get("agent_step", 0) - self.update = resume_state.get("update", 0) + if self.as_fine_tune: + self.wandb_run_id = None + self.global_step = 0 + self.agent_step = 0 + self.update = 0 + else: + self.wandb_run_id = resume_state.get("wandb_run_id", None) + self.learning_rate = resume_state.get("learning_rate", self.learning_rate) + self.global_step = resume_state.get("global_step", 0) + self.agent_step = resume_state.get("agent_step", 0) + self.update = resume_state.get("update", 0) self.total_updates = self.total_timesteps // self.batch_size self.envs_per_worker = self.num_envs // self.num_cores @@ -183,9 +214,12 @@ def __post_init__(self, *args, **kwargs): # Setup optimizer self.optimizer = optim.Adam( - self.agent.parameters(), lr=self.learning_rate, eps=1e-5 + self.agent.parameters(), + lr=self.learning_rate, + eps=1e-5, + weight_decay=self.weight_decay, ) - if "optimizer_state_dict" in resume_state: + if not self.as_fine_tune and "optimizer_state_dict" in resume_state: self.optimizer.load_state_dict(resume_state["optimizer_state_dict"]) ### Allocate Storage @@ -287,6 +321,7 @@ def evaluate(self, show_progress=False): step = 0 infos = defaultdict(lambda: defaultdict(list)) stats = defaultdict(lambda: defaultdict(list)) + curriculum = defaultdict(lambda: defaultdict(list)) performance = defaultdict(list) progress_bar = tqdm(total=self.batch_size, disable=not show_progress) @@ -398,7 +433,17 @@ def evaluate(self, show_progress=False): stat = float(stat) stats[policy_name][name].append(stat) except: - continue + if name.startswith("curriculum/"): + # Task completion status grouped by eval_fn + eval_fn_name = name.split("_")[1] + assert isinstance(stat, list) and len(stat) == 1 + _max_progress, reward_signal_count = stat[-1] + curriculum[policy_name][eval_fn_name].append( + _max_progress + ) + + else: + continue if self.policy_pool.scores and self.policy_ranker is not None: self.policy_ranker.update_ranks( @@ -428,6 +473,13 @@ def evaluate(self, show_progress=False): self.global_step += self.batch_size if self.wandb_entity: + tasks_log = {} + for fn_name in TASK_EVAL_FN_NAMES: + v = curriculum["learner"][fn_name] + tasks_log[f"charts/tasks/max_progress/{fn_name}"] = ( + np.mean(v) if v else None + ) + wandb.log( { "performance/env_time": env_step_time, @@ -439,6 +491,7 @@ def evaluate(self, show_progress=False): for k, v in performance.items() }, **{f"charts/{k}": np.mean(v) for k, v in stats["learner"].items()}, + **tasks_log, "charts/reward": float(torch.mean(data.rewards)), "agent_steps": self.global_step, "global_step": self.global_step, diff --git a/reinforcement_learning/config.py b/reinforcement_learning/config.py index 168a164f..804a4fc9 100644 --- a/reinforcement_learning/config.py +++ b/reinforcement_learning/config.py @@ -26,19 +26,25 @@ class Config: runs_dir = "/tmp/runs" # Directory for runs policy_store_dir = None # Policy store directory use_serial_vecenv = False # Use serial vecenv implementation + use_ray_vecenv = False # Use ray vecenv implementation learner_weight = 1.0 # Weight of learner policy max_opponent_policies = 0 # Maximum number of opponent policies to train against eval_num_policies = 2 # Number of policies to use for evaluation eval_num_rounds 
= 1 # Number of rounds to use for evaluation wandb_project = None # WandB project name wandb_entity = None # WandB entity name + as_fine_tune = False # PPO Args bptt_horizon = 8 # Train on this number of steps of a rollout at a time. Used to reduce GPU memory. ppo_training_batch_size = 128 # Number of rows in a training batch ppo_update_epochs = 3 # Number of update epochs to use for training ppo_learning_rate = 0.00015 # Learning rate + weight_decay = 0.0 # Adam's weight_decay parameter clip_coef = 0.1 # PPO clip coefficient + no_clip_vloss = False # Whether to disable clip value loss + ent_coef = 0.01 # Policy entropy coefficient + vf_coef = 0.5 # Value function coefficient # Environment Args num_agents = 128 # Number of agents to use for training @@ -52,6 +58,7 @@ class Config: 0.2 # Percentage of agents to be resilient to starvation/dehydration ) tasks_path = None # Path to tasks to use for training + task_reward_setting_path = None # Customized reward settings for different tasks eval_mode = False # Run the postprocessor in the eval mode early_stop_agent_num = ( 8 # Stop the episode when the number of agents reaches this number @@ -60,9 +67,13 @@ class Config: heal_bonus_weight = 0.03 meander_bonus_weight = 0.02 explore_bonus_weight = 0.01 + task_learning_bonus_weight = 0.0 # weight for task_reward_setting_path's rewards + alive_bonus_weight = 0.0 spawn_immunity = 20 + adjust_ori_reward = False # Policy Args + model = "Baseline" # Name of model class input_size = 256 hidden_size = 256 num_lstm_layers = 0 # Number of LSTM layers to use @@ -72,6 +83,9 @@ class Config: attentional_decode = True # Use attentional action decoder extra_encoders = True # Use inventory and market encoders + # Debug + debug_print_events = False + @classmethod def asdict(cls): return { diff --git a/reinforcement_learning/policy_mix_encoders.py b/reinforcement_learning/policy_mix_encoders.py new file mode 100644 index 00000000..c4adf4e0 --- /dev/null +++ b/reinforcement_learning/policy_mix_encoders.py @@ -0,0 +1,184 @@ +import argparse +import torch +import torch.nn.functional as F +from torch.nn import ModuleList +from typing import Dict + +import pufferlib +import pufferlib.emulation +import pufferlib.models + +import nmmo +from nmmo.entity.entity import EntityState +from reinforcement_learning.policy import ( + TileEncoder, + PlayerEncoder, + ItemEncoder, + InventoryEncoder, + MarketEncoder, + TaskEncoder, + ActionDecoder, +) + +EntityId = EntityState.State.attr_name_to_col["id"] + + +class MixtureEncodersModel(pufferlib.models.Policy): + """Multi-Task Reinforcement Learning with Context-based Representations""" + def __init__(self, env, input_size=256, hidden_size=256, task_size=4096): + super().__init__(env) + + self.k = 4 + + self.flat_observation_space = env.flat_observation_space + self.flat_observation_structure = env.flat_observation_structure + + self.tile_encoders = ModuleList( + [TileEncoderV2(input_size) for _ in range(self.k)] + ) + self.player_encoders = ModuleList( + [PlayerEncoder(input_size, hidden_size) for _ in range(self.k)] + ) + self.item_encoders = ModuleList( + [ItemEncoder(input_size, hidden_size) for _ in range(self.k)] + ) + self.inventory_encoders = ModuleList( + [InventoryEncoder(input_size, hidden_size) for _ in range(self.k)] + ) + self.market_encoders = ModuleList( + [MarketEncoder(input_size, hidden_size) for _ in range(self.k)] + ) + self.proj_enc_fcs = ModuleList( + [torch.nn.Linear(4 * input_size, input_size) for _ in range(self.k)] + ) + + self.task_encoder = 
TaskEncoder(input_size, hidden_size, task_size) + self.proj_z_fc = torch.nn.Linear(input_size, input_size) + self.proj_fc = torch.nn.Linear(2 * input_size, input_size) + self.action_decoder = ActionDecoder(input_size, hidden_size) + self.value_head = torch.nn.Linear(hidden_size, 1) + + def encode_observations(self, flat_observations): + env_outputs = pufferlib.emulation.unpack_batched_obs( + flat_observations, + self.flat_observation_space, + self.flat_observation_structure, + ) + + encs = [] + player_embeddings_list = [] + item_embeddings_list = [] + market_embeddings_list = [] + + tile = env_outputs["Tile"] + tile[:, :, :2] -= tile[:, 112:113, :2].clone() + tile[:, :, :2] += 7 + + for i in range(self.k): + tile_encoder = self.tile_encoders[i] + player_encoder = self.player_encoders[i] + item_encoder = self.item_encoders[i] + inventory_encoder = self.inventory_encoders[i] + market_encoder = self.market_encoders[i] + proj_enc_fc = self.proj_enc_fcs[i] + + tile = tile_encoder(env_outputs["Tile"]) + # env_outputs["Entity"] shape (BS, agents, n_entity_states) + # player_embeddings shape (BS, agents, input_size) + player_embeddings, my_agent = player_encoder( + env_outputs["Entity"], env_outputs["AgentId"][:, 0] + ) + + item_embeddings = item_encoder(env_outputs["Inventory"]) + inventory = inventory_encoder(item_embeddings) + + market_embeddings = item_encoder(env_outputs["Market"]) + market = market_encoder(market_embeddings) + + enc = torch.cat([tile, my_agent, inventory, market], dim=-1) + # shape (BS, input_size) + enc = proj_enc_fc(enc) + + encs.append(enc.unsqueeze(-2)) + + player_embeddings_list.append(player_embeddings.unsqueeze(-1)) + item_embeddings_list.append(item_embeddings.unsqueeze(-1)) + market_embeddings_list.append(market_embeddings.unsqueeze(-1)) + + # shape (BS, k, input_size) + encs = torch.cat(encs, dim=-2) + # shape (BS, agents, input_size, k) + player_embeddings_list = torch.cat(player_embeddings_list, dim=-1) + item_embeddings_list = torch.cat(item_embeddings_list, dim=-1) + market_embeddings_list = torch.cat(market_embeddings_list, dim=-1) + + task = self.task_encoder(env_outputs["Task"]) + + with torch.no_grad(): + # shape (BS, k) + alpha = torch.matmul(encs, task.unsqueeze(-1)).squeeze(-1) + alpha = torch.softmax(alpha, dim=-1) + + # shape (BS, input_size) + z_context = torch.matmul( + encs.transpose(-1, -2), alpha.unsqueeze(-1) + ).squeeze(-1) + + # shape (BS, agents, input_size) + player_embeddings = torch.matmul( + player_embeddings_list, alpha.unsqueeze(1).unsqueeze(-1) + ).squeeze(-1) + item_embeddings = torch.matmul( + item_embeddings_list, alpha.unsqueeze(1).unsqueeze(-1) + ).squeeze(-1) + market_embeddings = torch.matmul( + market_embeddings_list, alpha.unsqueeze(1).unsqueeze(-1) + ).squeeze(-1) + + # shape (BS, input_size) + z_enc = self.proj_z_fc(z_context) + + state_enc = torch.cat([task, z_enc], dim=-1) + obs = self.proj_fc(state_enc) + + return obs, ( + player_embeddings, + item_embeddings, + market_embeddings, + env_outputs["ActionTargets"], + ) + + def decode_actions(self, hidden, lookup): + actions = self.action_decoder(hidden, lookup) + value = self.value_head(hidden) + return actions, value + + +class TileEncoderV2(torch.nn.Module): + def __init__(self, input_size): + super().__init__() + self.tile_offset = torch.tensor([i * 256 for i in range(3)]) + self.embedding = torch.nn.Embedding(3 * 256, 32) + + self.tile_conv_1 = torch.nn.Conv2d(96, 32, 3) + self.tile_conv_2 = torch.nn.Conv2d(32, 8, 3) + self.tile_fc = torch.nn.Linear(8 * 11 * 11, 
input_size) + + def forward(self, tile): + tile = self.embedding( + tile.long().clip(0, 255) + self.tile_offset.to(tile.device) + ) + + agents, tiles, features, embed = tile.shape + tile = ( + tile.view(agents, tiles, features * embed) + .transpose(1, 2) + .view(agents, features * embed, 15, 15) + ) + + tile = F.relu(self.tile_conv_1(tile)) + tile = F.relu(self.tile_conv_2(tile)) + tile = tile.contiguous().view(agents, -1) + tile = F.relu(self.tile_fc(tile)) + + return tile diff --git a/reinforcement_learning/policy_reduce.py b/reinforcement_learning/policy_reduce.py new file mode 100644 index 00000000..0ecff633 --- /dev/null +++ b/reinforcement_learning/policy_reduce.py @@ -0,0 +1,377 @@ +import argparse +import torch +import torch.nn.functional as F +from typing import Dict + +import pufferlib +import pufferlib.emulation +import pufferlib.models + +import nmmo +from nmmo.entity.entity import EntityState + +EVAL_MODE = False +# print(f"** EVAL_MODE {EVAL_MODE}") + +EntityId = EntityState.State.attr_name_to_col["id"] + + +class ReducedModel(pufferlib.models.Policy): + """Reduce observation space""" + + def __init__(self, env, input_size=256, hidden_size=256, task_size=4096): + super().__init__(env) + + self.flat_observation_space = env.flat_observation_space + self.flat_observation_structure = env.flat_observation_structure + + self.tile_encoder = ReducedTileEncoder(input_size) + self.player_encoder = ReducedPlayerEncoder(input_size, hidden_size) + self.item_encoder = ReducedItemEncoder(input_size, hidden_size) + self.inventory_encoder = InventoryEncoder(input_size, hidden_size) + self.market_encoder = MarketEncoder(input_size, hidden_size) + self.task_encoder = TaskEncoder(input_size, hidden_size, task_size) + self.proj_fc = torch.nn.Linear(5 * input_size, hidden_size) + self.action_decoder = ActionDecoder(input_size, hidden_size) + self.value_head = torch.nn.Linear(hidden_size, 1) + + def encode_observations(self, flat_observations): + env_outputs = pufferlib.emulation.unpack_batched_obs( + flat_observations, + self.flat_observation_space, + self.flat_observation_structure, + ) + tile = self.tile_encoder(env_outputs["Tile"]) + player_embeddings, my_agent = self.player_encoder( + env_outputs["Entity"], env_outputs["AgentId"][:, 0] + ) + + item_embeddings = self.item_encoder(env_outputs["Inventory"]) + inventory = self.inventory_encoder(item_embeddings) + + market_embeddings = self.item_encoder(env_outputs["Market"]) + market = self.market_encoder(market_embeddings) + + task = self.task_encoder(env_outputs["Task"]) + + obs = torch.cat([tile, my_agent, inventory, market, task], dim=-1) + obs = self.proj_fc(obs) + + return obs, ( + player_embeddings, + item_embeddings, + market_embeddings, + env_outputs["ActionTargets"], + ) + + def no_explore_post_processing(self, logits): + # logits shape (BS, n sub-action dim) + max_index = torch.argmax(logits, dim=-1) + ret = torch.full_like(logits, fill_value=-1e9) + ret[torch.arange(logits.shape[0]), max_index] = 0 + + return ret + + def decode_actions(self, hidden, lookup): + actions = self.action_decoder(hidden, lookup) + value = self.value_head(hidden) + + if EVAL_MODE: + actions = [self.no_explore_post_processing(logits) for logits in actions] + # TODO: skip value + + return actions, value + + +class ReducedTileEncoder(torch.nn.Module): + def __init__(self, input_size): + super().__init__() + self.embedding = torch.nn.Embedding(256, 32) + + self.tile_conv_1 = torch.nn.Conv2d(32, 16, 3) + self.tile_conv_2 = torch.nn.Conv2d(16, 8, 3) + self.tile_fc = 
torch.nn.Linear(8 * 11 * 11, input_size) + + def forward(self, tile): + # tile: row, col, material_id + tile = tile[:, :, 2:] + + tile = self.embedding(tile.long().clip(0, 255)) + + agents, tiles, features, embed = tile.shape + tile = ( + tile.view(agents, tiles, features * embed) + .transpose(1, 2) + .view(agents, features * embed, 15, 15) + ) + + tile = F.relu(self.tile_conv_1(tile)) + tile = F.relu(self.tile_conv_2(tile)) + tile = tile.contiguous().view(agents, -1) + tile = F.relu(self.tile_fc(tile)) + + return tile + + +class ReducedPlayerEncoder(torch.nn.Module): + """ """ + + def __init__(self, input_size, hidden_size): + super().__init__() + + discrete_attr = [ + "id", # pos player entity id & neg npc entity id + "npc_type", + "attacker_id", # just pos player entity id + "message", + ] + self.discrete_idxs = [ + EntityState.State.attr_name_to_col[key] for key in discrete_attr + ] + self.discrete_offset = torch.Tensor( + [i * 256 for i in range(len(discrete_attr))] + ) + + _max_exp = 100 + _max_level = 10 + + continuous_attr_and_scale = [ + ("row", 256), + ("col", 256), + ("damage", 100), + ("time_alive", 1024), + ("freeze", 3), + ("item_level", 50), + ("latest_combat_tick", 1024), + ("gold", 100), + ("health", 100), + ("food", 100), + ("water", 100), + ("melee_level", _max_level), + ("melee_exp", _max_exp), + ("range_level", _max_level), + ("range_exp", _max_exp), + ("mage_level", _max_level), + ("mage_exp", _max_exp), + ("fishing_level", _max_level), + ("fishing_exp", _max_exp), + ("herbalism_level", _max_level), + ("herbalism_exp", _max_exp), + ("prospecting_level", _max_level), + ("prospecting_exp", _max_exp), + ("carving_level", _max_level), + ("carving_exp", _max_exp), + ("alchemy_level", _max_exp), + ("alchemy_exp", _max_level), + ] + self.continuous_idxs = [ + EntityState.State.attr_name_to_col[key] + for key, _ in continuous_attr_and_scale + ] + self.continuous_scale = torch.Tensor( + [scale for _, scale in continuous_attr_and_scale] + ) + + self.embedding = torch.nn.Embedding(len(discrete_attr) * 256, 32) + + emb_dim = len(discrete_attr) * 32 + len(continuous_attr_and_scale) + self.agent_fc = torch.nn.Linear(emb_dim, hidden_size) + self.my_agent_fc = torch.nn.Linear(emb_dim, input_size) + + def forward(self, agents, my_id): + # self._debug(agents) + + # Pull out rows corresponding to the agent + agent_ids = agents[:, :, EntityId] + mask = (agent_ids == my_id.unsqueeze(1)) & (agent_ids != 0) + mask = mask.int() + row_indices = torch.where( + mask.any(dim=1), mask.argmax(dim=1), torch.zeros_like(mask.sum(dim=1)) + ) + + if self.discrete_offset.device != agents.device: + self.discrete_offset = self.discrete_offset.to(agents.device) + self.continuous_scale = self.continuous_scale.to(agents.device) + + # Embed each feature separately + # agents shape (BS, agents, n of states) + discrete = agents[:, :, self.discrete_idxs] + self.discrete_offset + discrete = self.embedding(discrete.long().clip(0, 255)) + batch, item, attrs, embed = discrete.shape + discrete = discrete.view(batch, item, attrs * embed) + + continuous = agents[:, :, self.continuous_idxs] / self.continuous_scale + + # shape (BS, agents, x) + agent_embeddings = torch.cat([discrete, continuous], dim=-1) + + my_agent_embeddings = agent_embeddings[ + torch.arange(agents.shape[0]), row_indices + ] + + # Project to input of recurrent size + agent_embeddings = self.agent_fc(agent_embeddings) + my_agent_embeddings = self.my_agent_fc(my_agent_embeddings) + my_agent_embeddings = F.relu(my_agent_embeddings) + + return 
agent_embeddings, my_agent_embeddings + + def _debug(self, agents): + agents_max, _ = torch.max(agents, dim=-2) + agents_max, _ = torch.max(agents_max, dim=-2) + print(f"agents_max {agents_max.tolist()}") + + +class ReducedItemEncoder(torch.nn.Module): + def __init__(self, input_size, hidden_size): + super().__init__() + self.item_offset = torch.tensor([i * 256 for i in range(16)]) + self.embedding = torch.nn.Embedding(256, 32) + + self.fc = torch.nn.Linear(2 * 32 + 12, hidden_size) + + self.discrete_idxs = [1, 14] + self.discrete_offset = torch.Tensor([2, 0]) + self.continuous_idxs = [3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 15] + self.continuous_scale = torch.Tensor( + [ + 10, + 10, + 10, + 100, + 100, + 100, + 40, + 40, + 40, + 100, + 100, + 100, + ] + ) + + def forward(self, items): + if self.discrete_offset.device != items.device: + self.discrete_offset = self.discrete_offset.to(items.device) + self.continuous_scale = self.continuous_scale.to(items.device) + + # Embed each feature separately + discrete = items[:, :, self.discrete_idxs] + self.discrete_offset + discrete = self.embedding(discrete.long().clip(0, 255)) + batch, item, attrs, embed = discrete.shape + discrete = discrete.view(batch, item, attrs * embed) + + continuous = items[:, :, self.continuous_idxs] / self.continuous_scale + + item_embeddings = torch.cat([discrete, continuous], dim=-1) + item_embeddings = self.fc(item_embeddings) + return item_embeddings + + +class InventoryEncoder(torch.nn.Module): + def __init__(self, input_size, hidden_size): + super().__init__() + self.fc = torch.nn.Linear(12 * hidden_size, input_size) + + def forward(self, inventory): + agents, items, hidden = inventory.shape + inventory = inventory.view(agents, items * hidden) + return self.fc(inventory) + + +class MarketEncoder(torch.nn.Module): + def __init__(self, input_size, hidden_size): + super().__init__() + self.fc = torch.nn.Linear(hidden_size, input_size) + + def forward(self, market): + return self.fc(market).mean(-2) + + +class TaskEncoder(torch.nn.Module): + def __init__(self, input_size, hidden_size, task_size): + super().__init__() + self.fc = torch.nn.Linear(task_size, input_size) + + def forward(self, task): + return self.fc(task.clone()) + + +class ActionDecoder(torch.nn.Module): + def __init__(self, input_size, hidden_size): + super().__init__() + self.layers = torch.nn.ModuleDict( + { + "attack_style": torch.nn.Linear(hidden_size, 3), + "attack_target": torch.nn.Linear(hidden_size, hidden_size), + "market_buy": torch.nn.Linear(hidden_size, hidden_size), + "inventory_destroy": torch.nn.Linear(hidden_size, hidden_size), + "inventory_give_item": torch.nn.Linear(hidden_size, hidden_size), + "inventory_give_player": torch.nn.Linear(hidden_size, hidden_size), + "gold_quantity": torch.nn.Linear(hidden_size, 99), + "gold_target": torch.nn.Linear(hidden_size, hidden_size), + "move": torch.nn.Linear(hidden_size, 5), + "inventory_sell": torch.nn.Linear(hidden_size, hidden_size), + "inventory_price": torch.nn.Linear(hidden_size, 99), + "inventory_use": torch.nn.Linear(hidden_size, hidden_size), + } + ) + + def apply_layer(self, layer, embeddings, mask, hidden): + hidden = layer(hidden) + if hidden.dim() == 2 and embeddings is not None: + hidden = torch.matmul(embeddings, hidden.unsqueeze(-1)).squeeze(-1) + + if mask is not None: + hidden = hidden.masked_fill(mask == 0, -1e9) + + return hidden + + def forward(self, hidden, lookup): + ( + player_embeddings, + inventory_embeddings, + market_embeddings, + action_targets, + ) = lookup + + embeddings 
= { + "attack_target": player_embeddings, + "market_buy": market_embeddings, + "inventory_destroy": inventory_embeddings, + "inventory_give_item": inventory_embeddings, + "inventory_give_player": player_embeddings, + "gold_target": player_embeddings, + "inventory_sell": inventory_embeddings, + "inventory_use": inventory_embeddings, + } + + action_targets = { + "attack_style": action_targets["Attack"]["Style"], + "attack_target": action_targets["Attack"]["Target"], + "market_buy": action_targets["Buy"]["MarketItem"], + "inventory_destroy": action_targets["Destroy"]["InventoryItem"], + "inventory_give_item": action_targets["Give"]["InventoryItem"], + "inventory_give_player": action_targets["Give"]["Target"], + "gold_quantity": action_targets["GiveGold"]["Price"], + "gold_target": action_targets["GiveGold"]["Target"], + "move": action_targets["Move"]["Direction"], + "inventory_sell": action_targets["Sell"]["InventoryItem"], + "inventory_price": action_targets["Sell"]["Price"], + "inventory_use": action_targets["Use"]["InventoryItem"], + } + + actions = [] + for key, layer in self.layers.items(): + mask = None + mask = action_targets[key] + embs = embeddings.get(key) + if embs is not None and embs.shape[1] != mask.shape[1]: + b, _, f = embs.shape + zeros = torch.zeros([b, 1, f], dtype=embs.dtype, device=embs.device) + embs = torch.cat([embs, zeros], dim=1) + + action = self.apply_layer(layer, embs, mask, hidden) + actions.append(action) + + return actions diff --git a/reinforcement_learning/policy_reduce_v2.py b/reinforcement_learning/policy_reduce_v2.py new file mode 100644 index 00000000..21f9d2a9 --- /dev/null +++ b/reinforcement_learning/policy_reduce_v2.py @@ -0,0 +1,411 @@ +import argparse +import torch +import torch.nn.functional as F +from typing import Dict + +import pufferlib +import pufferlib.emulation +import pufferlib.models + +import nmmo +from nmmo.entity.entity import EntityState + +EVAL_MODE = False +# print(f"** EVAL_MODE {EVAL_MODE}") + +EntityId = EntityState.State.attr_name_to_col["id"] + + +class ReducedModelV2(pufferlib.models.Policy): + """Reduce observation space""" + + def __init__(self, env, input_size=256, hidden_size=256, task_size=4096): + super().__init__(env) + + self.flat_observation_space = env.flat_observation_space + self.flat_observation_structure = env.flat_observation_structure + + self.tile_encoder = ReducedTileEncoder(input_size) + self.player_encoder = ReducedPlayerEncoder(input_size, hidden_size) + self.item_encoder = ReducedItemEncoder(input_size, hidden_size) + self.inventory_encoder = InventoryEncoder(input_size, hidden_size) + self.market_encoder = MarketEncoder(input_size, hidden_size) + self.task_encoder = TaskEncoder(input_size, hidden_size, task_size) + self.proj_fc = torch.nn.Linear(5 * input_size, hidden_size) + self.action_decoder = ReducedActionDecoder(input_size, hidden_size) + self.value_head = torch.nn.Linear(hidden_size, 1) + + def encode_observations(self, flat_observations): + env_outputs = pufferlib.emulation.unpack_batched_obs( + flat_observations, + self.flat_observation_space, + self.flat_observation_structure, + ) + tile = self.tile_encoder(env_outputs["Tile"]) + player_embeddings, my_agent = self.player_encoder( + env_outputs["Entity"], env_outputs["AgentId"][:, 0] + ) + + item_embeddings = self.item_encoder(env_outputs["Inventory"]) + inventory = self.inventory_encoder(item_embeddings) + + market_embeddings = self.item_encoder(env_outputs["Market"]) + market = self.market_encoder(market_embeddings) + + task = 
self.task_encoder(env_outputs["Task"]) + + obs = torch.cat([tile, my_agent, inventory, market, task], dim=-1) + obs = self.proj_fc(obs) + + return obs, ( + player_embeddings, + item_embeddings, + market_embeddings, + env_outputs["ActionTargets"], + ) + + def no_explore_post_processing(self, logits): + # logits shape (BS, n sub-action dim) + max_index = torch.argmax(logits, dim=-1) + ret = torch.full_like(logits, fill_value=-1e9) + ret[torch.arange(logits.shape[0]), max_index] = 0 + + return ret + + def decode_actions(self, hidden, lookup): + actions = self.action_decoder(hidden, lookup) + value = self.value_head(hidden) + + if EVAL_MODE: + actions = [self.no_explore_post_processing(logits) for logits in actions] + # TODO: skip value + + return actions, value + + +class ReducedTileEncoder(torch.nn.Module): + def __init__(self, input_size): + super().__init__() + self.embedding = torch.nn.Embedding(256, 32) + + self.tile_conv_1 = torch.nn.Conv2d(32, 16, 3) + self.tile_conv_2 = torch.nn.Conv2d(16, 8, 3) + self.tile_fc = torch.nn.Linear(8 * 11 * 11, input_size) + + def forward(self, tile): + # tile: row, col, material_id + tile = tile[:, :, 2:] + + tile = self.embedding(tile.long().clip(0, 255)) + + agents, tiles, features, embed = tile.shape + tile = ( + tile.view(agents, tiles, features * embed) + .transpose(1, 2) + .view(agents, features * embed, 15, 15) + ) + + tile = F.relu(self.tile_conv_1(tile)) + tile = F.relu(self.tile_conv_2(tile)) + tile = tile.contiguous().view(agents, -1) + tile = F.relu(self.tile_fc(tile)) + + return tile + + +class ReducedPlayerEncoder(torch.nn.Module): + """ """ + + def __init__(self, input_size, hidden_size): + super().__init__() + + discrete_attr = [ + "id", # pos player entity id & neg npc entity id + "npc_type", + "attacker_id", # just pos player entity id + "message", + ] + self.discrete_idxs = [ + EntityState.State.attr_name_to_col[key] for key in discrete_attr + ] + self.discrete_offset = torch.Tensor( + [i * 256 for i in range(len(discrete_attr))] + ) + + _max_exp = 100 + _max_level = 10 + + continuous_attr_and_scale = [ + ("row", 256), + ("col", 256), + ("damage", 100), + ("time_alive", 1024), + ("freeze", 3), + ("item_level", 50), + ("latest_combat_tick", 1024), + ("gold", 100), + ("health", 100), + ("food", 100), + ("water", 100), + ("melee_level", _max_level), + ("melee_exp", _max_exp), + ("range_level", _max_level), + ("range_exp", _max_exp), + ("mage_level", _max_level), + ("mage_exp", _max_exp), + ("fishing_level", _max_level), + ("fishing_exp", _max_exp), + ("herbalism_level", _max_level), + ("herbalism_exp", _max_exp), + ("prospecting_level", _max_level), + ("prospecting_exp", _max_exp), + ("carving_level", _max_level), + ("carving_exp", _max_exp), + ("alchemy_level", _max_exp), + ("alchemy_exp", _max_level), + ] + self.continuous_idxs = [ + EntityState.State.attr_name_to_col[key] + for key, _ in continuous_attr_and_scale + ] + self.continuous_scale = torch.Tensor( + [scale for _, scale in continuous_attr_and_scale] + ) + + self.embedding = torch.nn.Embedding(len(discrete_attr) * 256, 32) + + emb_dim = len(discrete_attr) * 32 + len(continuous_attr_and_scale) + self.agent_fc = torch.nn.Linear(emb_dim, hidden_size) + self.my_agent_fc = torch.nn.Linear(emb_dim, input_size) + + def forward(self, agents, my_id): + # self._debug(agents) + + # Pull out rows corresponding to the agent + agent_ids = agents[:, :, EntityId] + mask = (agent_ids == my_id.unsqueeze(1)) & (agent_ids != 0) + mask = mask.int() + row_indices = torch.where( + mask.any(dim=1), 
mask.argmax(dim=1), torch.zeros_like(mask.sum(dim=1)) + ) + + if self.discrete_offset.device != agents.device: + self.discrete_offset = self.discrete_offset.to(agents.device) + self.continuous_scale = self.continuous_scale.to(agents.device) + + # Embed each feature separately + # agents shape (BS, agents, n of states) + discrete = agents[:, :, self.discrete_idxs] + self.discrete_offset + discrete = self.embedding(discrete.long().clip(0, 255)) + batch, item, attrs, embed = discrete.shape + discrete = discrete.view(batch, item, attrs * embed) + + continuous = agents[:, :, self.continuous_idxs] / self.continuous_scale + + # shape (BS, agents, x) + agent_embeddings = torch.cat([discrete, continuous], dim=-1) + + my_agent_embeddings = agent_embeddings[ + torch.arange(agents.shape[0]), row_indices + ] + + # Project to input of recurrent size + agent_embeddings = self.agent_fc(agent_embeddings) + my_agent_embeddings = self.my_agent_fc(my_agent_embeddings) + my_agent_embeddings = F.relu(my_agent_embeddings) + + return agent_embeddings, my_agent_embeddings + + def _debug(self, agents): + agents_max, _ = torch.max(agents, dim=-2) + agents_max, _ = torch.max(agents_max, dim=-2) + print(f"agents_max {agents_max.tolist()}") + + +class ReducedItemEncoder(torch.nn.Module): + def __init__(self, input_size, hidden_size): + super().__init__() + self.item_offset = torch.tensor([i * 256 for i in range(16)]) + self.embedding = torch.nn.Embedding(256, 32) + + self.fc = torch.nn.Linear(2 * 32 + 12, hidden_size) + + self.discrete_idxs = [1, 14] + self.discrete_offset = torch.Tensor([2, 0]) + self.continuous_idxs = [3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 15] + self.continuous_scale = torch.Tensor( + [ + 10, + 10, + 10, + 100, + 100, + 100, + 40, + 40, + 40, + 100, + 100, + 100, + ] + ) + + def forward(self, items): + if self.discrete_offset.device != items.device: + self.discrete_offset = self.discrete_offset.to(items.device) + self.continuous_scale = self.continuous_scale.to(items.device) + + # Embed each feature separately + discrete = items[:, :, self.discrete_idxs] + self.discrete_offset + discrete = self.embedding(discrete.long().clip(0, 255)) + batch, item, attrs, embed = discrete.shape + discrete = discrete.view(batch, item, attrs * embed) + + continuous = items[:, :, self.continuous_idxs] / self.continuous_scale + + item_embeddings = torch.cat([discrete, continuous], dim=-1) + item_embeddings = self.fc(item_embeddings) + return item_embeddings + + +class InventoryEncoder(torch.nn.Module): + def __init__(self, input_size, hidden_size): + super().__init__() + self.fc = torch.nn.Linear(12 * hidden_size, input_size) + + def forward(self, inventory): + agents, items, hidden = inventory.shape + inventory = inventory.view(agents, items * hidden) + return self.fc(inventory) + + +class MarketEncoder(torch.nn.Module): + def __init__(self, input_size, hidden_size): + super().__init__() + self.fc = torch.nn.Linear(hidden_size, input_size) + + def forward(self, market): + return self.fc(market).mean(-2) + + +class TaskEncoder(torch.nn.Module): + def __init__(self, input_size, hidden_size, task_size): + super().__init__() + self.fc = torch.nn.Linear(task_size, input_size) + + def forward(self, task): + return self.fc(task.clone()) + + +class ReducedActionDecoder(torch.nn.Module): + def __init__(self, input_size, hidden_size): + super().__init__() + # order corresponding to action space + self.sub_action_keys = [ + "attack_style", + "attack_target", + "market_buy", + "inventory_destroy", + "inventory_give_item", + 
"inventory_give_player", + "gold_quantity", + "gold_target", + "move", + "inventory_sell", + "inventory_price", + "inventory_use", + ] + self.layers = torch.nn.ModuleDict( + { + "attack_style": torch.nn.Linear(hidden_size, 3), + "attack_target": torch.nn.Linear(hidden_size, hidden_size), + "market_buy": torch.nn.Linear(hidden_size, hidden_size), + "inventory_destroy": torch.nn.Linear(hidden_size, hidden_size), + # "inventory_give_item": torch.nn.Linear(hidden_size, hidden_size), # TODO: useful for Inventory Management? + # "inventory_give_player": torch.nn.Linear(hidden_size, hidden_size), + # "gold_quantity": torch.nn.Linear(hidden_size, 99), + # "gold_target": torch.nn.Linear(hidden_size, hidden_size), + "move": torch.nn.Linear(hidden_size, 5), + "inventory_sell": torch.nn.Linear(hidden_size, hidden_size), + "inventory_price": torch.nn.Linear(hidden_size, 99), + "inventory_use": torch.nn.Linear(hidden_size, hidden_size), + } + ) + + def apply_layer(self, layer, embeddings, mask, hidden): + hidden = layer(hidden) + if hidden.dim() == 2 and embeddings is not None: + hidden = torch.matmul(embeddings, hidden.unsqueeze(-1)).squeeze(-1) + + if mask is not None: + hidden = hidden.masked_fill(mask == 0, -1e9) + + return hidden + + def act_noob_action(self, key, mask): + if key in ("inventory_give_item", "inventory_give_player", "gold_target"): + noob_action_index = -1 + elif key in ("gold_quantity",): + noob_action_index = 0 + else: + raise NotImplementedError(key) + + logits = torch.full_like(mask, fill_value=-1e9) + logits[:, noob_action_index] = 0 + + return logits + + def forward(self, hidden, lookup): + ( + player_embeddings, + inventory_embeddings, + market_embeddings, + action_targets, + ) = lookup + + embeddings = { + "attack_target": player_embeddings, + "market_buy": market_embeddings, + "inventory_destroy": inventory_embeddings, + "inventory_give_item": inventory_embeddings, + "inventory_give_player": player_embeddings, + "gold_target": player_embeddings, + "inventory_sell": inventory_embeddings, + "inventory_use": inventory_embeddings, + } + + action_targets = { + "attack_style": action_targets["Attack"]["Style"], + "attack_target": action_targets["Attack"]["Target"], + "market_buy": action_targets["Buy"]["MarketItem"], + "inventory_destroy": action_targets["Destroy"]["InventoryItem"], + "inventory_give_item": action_targets["Give"]["InventoryItem"], + "inventory_give_player": action_targets["Give"]["Target"], + "gold_quantity": action_targets["GiveGold"]["Price"], + "gold_target": action_targets["GiveGold"]["Target"], + "move": action_targets["Move"]["Direction"], + "inventory_sell": action_targets["Sell"]["InventoryItem"], + "inventory_price": action_targets["Sell"]["Price"], + "inventory_use": action_targets["Use"]["InventoryItem"], + } + + actions = [] + for key in self.sub_action_keys: + mask = action_targets[key] + + if key in self.layers: + layer = self.layers[key] + embs = embeddings.get(key) + if embs is not None and embs.shape[1] != mask.shape[1]: + b, _, f = embs.shape + zeros = torch.zeros([b, 1, f], dtype=embs.dtype, device=embs.device) + embs = torch.cat([embs, zeros], dim=1) + + action = self.apply_layer(layer, embs, mask, hidden) + + else: + action = self.act_noob_action(key, mask) + + actions.append(action) + + return actions diff --git a/reinforcement_learning/policy_routing.py b/reinforcement_learning/policy_routing.py new file mode 100644 index 00000000..52a5d90e --- /dev/null +++ b/reinforcement_learning/policy_routing.py @@ -0,0 +1,187 @@ +import argparse 
+import math + +import torch +import torch.nn.functional as F +from torch.nn import ModuleList +from torch.nn.parameter import Parameter +from torch.nn import init + +from typing import Dict + +import pufferlib +import pufferlib.emulation +import pufferlib.models + +import nmmo +from nmmo.entity.entity import EntityState +from reinforcement_learning.policy import ( + TileEncoder, + PlayerEncoder, + ItemEncoder, + InventoryEncoder, + MarketEncoder, + TaskEncoder, + ActionDecoder, +) + +EntityId = EntityState.State.attr_name_to_col["id"] + + +class PolicyRoutingModel(pufferlib.models.Policy): + """Multi-Task Reinforcement Learning with Soft Modularization""" + + L = 2 + n = 2 + + def __init__(self, env, input_size=256, hidden_size=256, task_size=4096): + super().__init__(env) + + self.flat_observation_space = env.flat_observation_space + self.flat_observation_structure = env.flat_observation_structure + + self.tile_encoder = TileEncoder(input_size) + self.player_encoder = PlayerEncoder(input_size, hidden_size) + self.item_encoder = ItemEncoder(input_size, hidden_size) + self.inventory_encoder = InventoryEncoder(input_size, hidden_size) + self.market_encoder = MarketEncoder(input_size, hidden_size) + self.task_encoder = TaskEncoder(input_size, hidden_size, task_size) + self.proj_fc = torch.nn.Linear(4 * input_size, input_size) + self.action_decoder = ActionDecoder(input_size, hidden_size) + self.value_head = torch.nn.Linear(hidden_size, 1) + + self.input_size = input_size + + self.routing_layers = ModuleList( + [RoutingLayer(input_size, self.n, l) for l in range(self.L)] + ) + + self.base_policy_layers = ModuleList( + [BasePolicyLayer(input_size, self.n) for l in range(self.L)] + ) + + self.g_fc = torch.nn.Linear(input_size, self.n * input_size) + self.hidden_fc = torch.nn.Linear(self.n, 1) + + def encode_observations(self, flat_observations): + env_outputs = pufferlib.emulation.unpack_batched_obs( + flat_observations, + self.flat_observation_space, + self.flat_observation_structure, + ) + tile = self.tile_encoder(env_outputs["Tile"]) + player_embeddings, my_agent = self.player_encoder( + env_outputs["Entity"], env_outputs["AgentId"][:, 0] + ) + + item_embeddings = self.item_encoder(env_outputs["Inventory"]) + inventory = self.inventory_encoder(item_embeddings) + + market_embeddings = self.item_encoder(env_outputs["Market"]) + market = self.market_encoder(market_embeddings) + + # inputs of base policy network & routing network + # shape (BS, input_size) + state_enc = torch.cat([tile, my_agent, inventory, market], dim=-1) + state_enc = self.proj_fc(state_enc) + task_enc = self.task_encoder(env_outputs["Task"]) + + # shape (BS, input_size) + state_mul_task = state_enc * task_enc + + batch_size, _ = state_mul_task.shape + + p_l = None + gs = self.g_fc(state_enc) + gs = gs.view(batch_size, self.n, self.input_size) + for l in range(self.L): + # shape (BS, n^2) + p_l = self.routing_layers[l](state_mul_task, p_l) + + # shape (BS, n, n) + p_l_mat = p_l.view(batch_size, self.n, self.n) + p_l_softmax = torch.softmax(p_l_mat, dim=-1) + + # shape (BS, n, input_size) + gs = self.base_policy_layers[l](gs, p_l_softmax) + + # shape (BS, input_size) + hidden = self.hidden_fc(gs.transpose(-1, -2)).squeeze(-1) + + return hidden, ( + player_embeddings, + item_embeddings, + market_embeddings, + env_outputs["ActionTargets"], + ) + + def decode_actions(self, hidden, lookup): + actions = self.action_decoder(hidden, lookup) + value = self.value_head(hidden) + return actions, value + + +class 
RoutingLayer(torch.nn.Module): + def __init__(self, input_size, n, l): + """ + input_size: D + """ + super().__init__() + self.l = l + + n2 = n**2 + + if self.l > 0: + self.w_u_fc = torch.nn.Linear(n2, input_size, bias=False) + + self.w_d_fc = torch.nn.Linear(input_size, n2, bias=False) + + def forward(self, state_mul_task, p_l): + """ + state_mul_task: shape (BS, input_size) + p_l: shape (BS, n^2) + """ + x = state_mul_task + if self.l > 0: + assert p_l is not None + # shape (BS, input_size) + x = self.w_u_fc(p_l) * x + + x = F.relu(x) + x = self.w_d_fc(x) + + return x + + +class BasePolicyLayer(torch.nn.Module): + def __init__(self, input_size, n): + """ + input_size: D + """ + super().__init__() + + self.weight = Parameter(torch.empty((1, n, input_size, input_size))) + init.kaiming_uniform_(self.weight, a=math.sqrt(5)) + + def forward(self, x, p_l_softmax): + """ + x: shape (BS, n, input_size) + p_l_softmax: shape (BS, n, n) + """ + # shape (BS, n, input_size) + x = torch.matmul( + x.unsqueeze(-2), # shape (BS, n, 1, input_size) + self.weight, # shape (1, n, input_size, input_size) + ).squeeze(-2) + + x = F.relu(x) + + # shape (BS, n, input_size) + x = torch.matmul(p_l_softmax, x) + + return x + + +class PolicyRoutingModelDeep(PolicyRoutingModel): + L = 4 + n = 4 diff --git a/train.py b/train.py index 81e8e1cc..02fb757f 100644 --- a/train.py +++ b/train.py @@ -2,19 +2,66 @@ import logging import torch -from pufferlib.vectorization import Serial, Multiprocessing +from pufferlib.vectorization import Serial, Multiprocessing, Ray from pufferlib.policy_store import DirectoryPolicyStore +from pufferlib.models import RecurrentWrapper from pufferlib.frameworks import cleanrl import environment -from reinforcement_learning import clean_pufferl, policy, config +from reinforcement_learning import clean_pufferl, config +from reinforcement_learning.policy import Baseline +from reinforcement_learning.policy_mix_encoders import MixtureEncodersModel +from reinforcement_learning.policy_routing import ( + PolicyRoutingModel, + PolicyRoutingModelDeep, +) +from reinforcement_learning.policy_reduce import ReducedModel +from reinforcement_learning.policy_reduce_v2 import ReducedModelV2 # NOTE: this file changes when running curriculum generation track # Run test_task_encoder.py to regenerate this file (or get it from the repo) BASELINE_CURRICULUM_FILE = "reinforcement_learning/curriculum_with_embedding.pkl" CUSTOM_CURRICULUM_FILE = "curriculum_generation/custom_curriculum_with_embedding.pkl" +TASK_REWARD_SETTING_PATH = "reinforcement_learning/task_reward_setting.json" + + +def get_make_policy_fn(args): + try: + model_cls = eval(args.model) + assert model_cls in ( + Baseline, + MixtureEncodersModel, + PolicyRoutingModel, + PolicyRoutingModelDeep, + ReducedModel, + ReducedModelV2, + ) + except: + raise Exception(f"Invalid model `{args.model}`") + + def make_policy(envs): + learner_policy = model_cls( + envs.driver_env, + input_size=args.input_size, + hidden_size=args.hidden_size, + task_size=args.task_size, + ) + if args.num_lstm_layers > 0: + learner_policy = RecurrentWrapper( + env=envs.driver_env, + policy=learner_policy, + input_size=args.input_size, + hidden_size=args.hidden_size, + num_layers=args.num_lstm_layers, + ) + return cleanrl.RecurrentPolicy(learner_policy) + else: + return cleanrl.Policy(learner_policy) + + return make_policy + def setup_env(args): run_dir = os.path.join(args.runs_dir, args.run_name) @@ -28,35 +75,32 @@ def setup_env(args): logging.info("Using policy store from %s", 
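+        # Periodic training checkpoints (checkpoint_interval below) are written to
+        # this policy store; presumably it also supplies the opponent pool for self-play.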
args.policy_store_dir) policy_store = DirectoryPolicyStore(args.policy_store_dir) - def make_policy(envs): - learner_policy = policy.Baseline( - envs.driver_env, - input_size=args.input_size, - hidden_size=args.hidden_size, - task_size=args.task_size, - ) - return cleanrl.Policy(learner_policy) - trainer = clean_pufferl.CleanPuffeRL( device=torch.device(args.device), seed=args.seed, env_creator=environment.make_env_creator(args), env_creator_kwargs={}, - agent_creator=make_policy, + agent_creator=get_make_policy_fn(args), data_dir=run_dir, exp_name=args.run_name, policy_store=policy_store, wandb_entity=args.wandb_entity, wandb_project=args.wandb_project, wandb_extra_data=args, + as_fine_tune=args.as_fine_tune, checkpoint_interval=args.checkpoint_interval, - vectorization=Serial if args.use_serial_vecenv else Multiprocessing, + vectorization=( + Serial + if args.use_serial_vecenv + else (Ray if args.use_ray_vecenv else Multiprocessing) + ), total_timesteps=args.train_num_steps, num_envs=args.num_envs, num_cores=args.num_cores or args.num_envs, num_buffers=args.num_buffers, batch_size=args.rollout_batch_size, learning_rate=args.ppo_learning_rate, + weight_decay=args.weight_decay, selfplay_learner_weight=args.learner_weight, selfplay_num_policies=args.max_opponent_policies + 1, # record_loss = args.record_loss, @@ -72,6 +116,9 @@ def reinforcement_learning_track(trainer, args): bptt_horizon=args.bptt_horizon, batch_rows=args.ppo_training_batch_size // args.bptt_horizon, clip_coef=args.clip_coef, + clip_vloss=not args.no_clip_vloss, + ent_coef=args.ent_coef, + vf_coef=args.vf_coef, ) @@ -142,6 +189,7 @@ def curriculum_generation_track(trainer, args, use_elm=True): if args.track == "rl": args.tasks_path = BASELINE_CURRICULUM_FILE + args.task_reward_setting_path = TASK_REWARD_SETTING_PATH trainer = setup_env(args) reinforcement_learning_track(trainer, args) elif args.track == "curriculum":
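+        # Presumably dispatches to curriculum_generation_track() (defined above) to
+        # regenerate the task curriculum instead of running RL training.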