diff --git a/sim/algo/ppo/on_policy_runner.py b/sim/algo/ppo/on_policy_runner.py
index e3fb6e1c..227f0fd3 100755
--- a/sim/algo/ppo/on_policy_runner.py
+++ b/sim/algo/ppo/on_policy_runner.py
@@ -30,6 +30,7 @@
 # Copyright (c) 2024 Beijing RobotEra TECHNOLOGY CO.,LTD. All rights reserved.
 
 # type: ignore
+import json
 import os
 import statistics
 import time
@@ -46,6 +47,51 @@
 from sim.algo.vec_env import VecEnv
 
 
+def make_serializable(obj):
+    """Recursively convert an object and its components to JSON-serializable types."""
+    if isinstance(obj, dict):
+        return {key: make_serializable(value) for key, value in obj.items()}
+    elif isinstance(obj, (list, tuple)):
+        return [make_serializable(item) for item in obj]
+    try:
+        # Try standard JSON serialization first
+        json.dumps(obj)
+        return obj
+    except (TypeError, OverflowError):
+        try:
+            # Fall back to float() for numpy/torch scalar types
+            return float(obj)
+        except (TypeError, ValueError):
+            try:
+                return int(obj)
+            except (TypeError, ValueError):
+                # If all else fails, convert to a tagged string representation
+                return f"<{type(obj).__name__}:{str(obj)}>"
+
+
+def write_config_file(config: dict, log_dir: str) -> None:
+    """Write the configuration to a JSON file in the log directory."""
+    if not os.path.exists(log_dir):
+        os.makedirs(log_dir)
+
+    metadata = {
+        "creation_time": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
+        "experiment_name": config["runner"]["experiment_name"],
+        "run_name": config["runner"]["run_name"],
+    }
+
+    serializable_config = make_serializable(config)
+
+    full_config = {
+        "metadata": metadata,
+        "configuration": serializable_config,
+    }
+
+    json_path = os.path.join(log_dir, "experiment_config.json")
+    with open(json_path, "w") as f:
+        json.dump(full_config, f, indent=2, sort_keys=True)
+
+
 class OnPolicyRunner:
     def __init__(self, env: VecEnv, train_cfg: dict, log_dir: Optional[str] = None, device: str = "cpu"):
         self.cfg = train_cfg["runner"]
@@ -102,6 +148,8 @@ def learn(self, num_learning_iterations: int, init_at_random_ep_len: bool = Fals
                 config=self.all_cfg,
             )
             self.writer = SummaryWriter(log_dir=self.log_dir, flush_secs=10)
+            if self.current_learning_iteration == 0:
+                write_config_file(self.all_cfg, self.log_dir)
         if init_at_random_ep_len:
             self.env.episode_length_buf = torch.randint_like(
                 self.env.episode_length_buf, high=int(self.env.max_episode_length)
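
Reviewer note: the fallback chain in make_serializable is easiest to verify with a quick standalone check. The sketch below is hypothetical and not part of the patch; it assumes the patch is applied so the function is importable from sim.algo.ppo.on_policy_runner. It exercises all three paths: values json.dumps already accepts pass through unchanged, numpy/torch scalars are coerced via float(), and anything non-numeric collapses to a tagged string, so write_config_file can dump an arbitrary train_cfg without raising TypeError.

    # Hypothetical usage sketch -- not part of the patch.
    import numpy as np
    import torch

    from sim.algo.ppo.on_policy_runner import make_serializable

    cfg = {
        "runner": {"experiment_name": "demo", "run_name": "run0"},
        "decimation": np.int64(10),      # json.dumps rejects np.int64 -> float() path
        "num_envs": torch.tensor(4096),  # 0-dim tensor -> float() path
        "device": torch.device("cuda"),  # float()/int() both fail -> "<device:cuda>"
        "scales": (0.1, 0.25),           # tuple -> list, elements recursed
    }

    print(make_serializable(cfg))
    # {'runner': {...}, 'decimation': 10.0, 'num_envs': 4096.0,
    #  'device': '<device:cuda>', 'scales': [0.1, 0.25]}

One quirk worth noting: because the float() branch is tried before int(), integer-valued scalars like np.int64(10) come back as 10.0 rather than 10, which is harmless for a human-readable config dump but matters if anything parses the JSON back into typed fields.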