Skip to content

Commit

Permalink
add converter
Browse files Browse the repository at this point in the history
  • Loading branch information
BigJohnn committed Jan 16, 2025
1 parent 699b54b commit 5018186
Show file tree
Hide file tree
Showing 8 changed files with 101 additions and 3 deletions.
3 changes: 3 additions & 0 deletions .vscode/settings.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
{
"marscode.chatLanguage": "cn"
}
Binary file added examples/model_100.onnx
Binary file not shown.
Binary file added examples/model_100.pt
Binary file not shown.
6 changes: 6 additions & 0 deletions sim/genesis/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,12 @@ cd rsl_rl && git checkout v1.0.2 && pip install -e .

`python zeroth_eval.py`

模型转换。

`cd genesis`

`python3 utils/convert_to_onnx.py --cfg logs/zeroth-walking/cfgs.pkl --model ../../examples/model_100.pt --output ../../examples/model_100.onnx`

注意⚠️
Mac M系列芯片,使用micromamba替换conda使用,例如
`micromamba activate genesis`
Expand Down
16 changes: 16 additions & 0 deletions sim/genesis/README_EN.md
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,22 @@ To monitor the training process, start TensorBoard:
tensorboard --logdir logs
```

## Model Conversion to ONNX

To convert trained models to ONNX format:

```bash
python utils/convert_to_onnx.py \
--cfg logs/zeroth-walking/cfgs.pkl \
--model ../../examples/model_100.pt \
--output ../../examples/model_100.onnx
```

Arguments:
- `--cfg`: Path to config file (.pkl)
- `--model`: Path to model file (.pt)
- `--output`: Output ONNX file path

## Evaluation

To view training results:
Expand Down
68 changes: 68 additions & 0 deletions sim/genesis/utils/convert_to_onnx.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
import torch
import torch.nn as nn
import pickle
import argparse

def parse_args():
parser = argparse.ArgumentParser(description='Convert model to ONNX format')
parser.add_argument('--cfg', type=str, required=True,
help='Path to config file (.pkl)')
parser.add_argument('--model', type=str, required=True,
help='Path to model file (.pt)')
parser.add_argument('--output', type=str, required=True,
help='Output ONNX file path')
return parser.parse_args()

args = parse_args()

# 加载配置
with open(args.cfg, 'rb') as f:
cfgs = pickle.load(f)

# 加载模型
model_dict = torch.load(args.model, weights_only=True)

# 创建继承自ActorCritic的模型类
from rsl_rl.modules import ActorCritic

class ExportModel(ActorCritic):
def forward(self, obs):
# 使用actor网络生成动作
actions = self.actor(obs)
# 使用critic网络评估状态值
values = self.critic(obs)
return actions, values

# 根据配置创建模型实例
model = ExportModel(
num_actor_obs=cfgs[1]['num_obs'],
num_critic_obs=cfgs[1]['num_obs'],
num_actions=cfgs[0]['num_actions'],
actor_hidden_dims=cfgs[4]['policy']['actor_hidden_dims'],
critic_hidden_dims=cfgs[4]['policy']['critic_hidden_dims'],
activation='elu',
init_noise_std=cfgs[4]['policy']['init_noise_std']
)

# 加载模型参数
model.load_state_dict(model_dict['model_state_dict'])
model.eval()

# 创建示例输入
obs_dim = cfgs[1]['num_obs'] # 从obs_cfg获取观测维度
dummy_input = torch.randn(1, obs_dim)

# 转换为ONNX
torch.onnx.export(
model,
dummy_input,
args.output,
input_names=['obs'],
output_names=['actions'],
dynamic_axes={
'obs': {0: 'batch_size'},
'actions': {0: 'batch_size'}
}
)

print("Model successfully converted to ONNX format")
5 changes: 5 additions & 0 deletions sim/genesis/utils/temp_read_pkl.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
import pickle

with open('./logs/zeroth-walking/cfgs.pkl', 'rb') as f:
data = pickle.load(f)
print(data)
6 changes: 3 additions & 3 deletions sim/genesis/zeroth_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,14 +39,14 @@ def get_train_cfg(exp_name, max_iterations):
"load_run": -1,
"log_interval": 1,
"max_iterations": max_iterations,
"num_steps_per_env": 24,
"num_steps_per_env": 48,
"policy_class_name": "ActorCritic",
"record_interval": -1,
"resume": False,
"resume_path": None,
"run_name": "",
"runner_class_name": "runner_class_name",
"save_interval": 100,
"runner_class_name": "OnPolicyRunner",
"save_interval": 10,
},
"runner_class_name": "OnPolicyRunner",
"seed": 1,
Expand Down

0 comments on commit 5018186

Please sign in to comment.