Skip to content

Commit 5018186

Browse files
committed
add converter
1 parent 699b54b commit 5018186

File tree

8 files changed

+101
-3
lines changed

8 files changed

+101
-3
lines changed

.vscode/settings.json

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
{
2+
"marscode.chatLanguage": "cn"
3+
}

examples/model_100.onnx

1.44 MB
Binary file not shown.

examples/model_100.pt

4.34 MB
Binary file not shown.

sim/genesis/README.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,12 @@ cd rsl_rl && git checkout v1.0.2 && pip install -e .
2222

2323
`python zeroth_eval.py`
2424

25+
模型转换。
26+
27+
`cd genesis`
28+
29+
`python3 utils/convert_to_onnx.py --cfg logs/zeroth-walking/cfgs.pkl --model ../../examples/model_100.pt --output ../../examples/model_100.onnx`
30+
2531
注意⚠️
2632
Mac M系列芯片,使用micromamba替换conda使用,例如
2733
`micromamba activate genesis`

sim/genesis/README_EN.md

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,22 @@ To monitor the training process, start TensorBoard:
3232
tensorboard --logdir logs
3333
```
3434

35+
## Model Conversion to ONNX
36+
37+
To convert a trained model to ONNX format, run the following from the `sim/genesis` directory (the example paths are relative to it):
38+
39+
```bash
40+
python utils/convert_to_onnx.py \
41+
--cfg logs/zeroth-walking/cfgs.pkl \
42+
--model ../../examples/model_100.pt \
43+
--output ../../examples/model_100.onnx
44+
```
45+
46+
Arguments:
47+
- `--cfg`: Path to config file (.pkl)
48+
- `--model`: Path to model file (.pt)
49+
- `--output`: Output ONNX file path
50+
3551
## Evaluation
3652

3753
To view training results:
Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
import torch
2+
import torch.nn as nn
3+
import pickle
4+
import argparse
5+
6+
def parse_args(argv=None):
    """Parse the command-line options of the ONNX converter.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) argparse falls back to ``sys.argv[1:]``, so
            existing ``parse_args()`` callers behave exactly as before.

    Returns:
        argparse.Namespace with the string attributes ``cfg``, ``model``
        and ``output``.

    Raises:
        SystemExit: raised by argparse when a required option is missing.
    """
    parser = argparse.ArgumentParser(description='Convert model to ONNX format')
    parser.add_argument('--cfg', type=str, required=True,
                        help='Path to config file (.pkl)')
    parser.add_argument('--model', type=str, required=True,
                        help='Path to model file (.pt)')
    parser.add_argument('--output', type=str, required=True,
                        help='Output ONNX file path')
    return parser.parse_args(argv)
15+
16+
args = parse_args()

# Load the pickled training configuration; the rest of the script indexes it
# as a sequence of dicts (cfgs[0], cfgs[1], cfgs[4]).
# SECURITY: pickle.load can execute arbitrary code while unpickling, so only
# pass --cfg files produced by your own training runs.
with open(args.cfg, 'rb') as f:
    cfgs = pickle.load(f)

# Load the checkpoint. map_location='cpu' remaps tensors that were saved from
# a GPU run, so the conversion also works on machines without that device.
model_dict = torch.load(args.model, map_location='cpu', weights_only=True)
24+
25+
# 创建继承自ActorCritic的模型类
26+
from rsl_rl.modules import ActorCritic
27+
28+
class ExportModel(ActorCritic):
    """ActorCritic subclass that adds a ``forward`` method for export.

    The forward pass feeds the same observation tensor through both the
    actor and the critic sub-networks.
    """

    def forward(self, obs):
        """Return ``(actions, values)`` computed from ``obs``.

        ``actions`` is the output of ``self.actor`` and ``values`` the
        output of ``self.critic``; the actor is evaluated first.
        """
        return self.actor(obs), self.critic(obs)
35+
36+
# 根据配置创建模型实例
37+
# Rebuild the network from the dimensions stored in the training config.
# Entries read here: cfgs[0]['num_actions'], cfgs[1]['num_obs'] and the
# network hyper-parameters under cfgs[4]['policy'].
obs_dim = cfgs[1]['num_obs']
policy_cfg = cfgs[4]['policy']
model = ExportModel(
    num_actor_obs=obs_dim,
    num_critic_obs=obs_dim,
    num_actions=cfgs[0]['num_actions'],
    actor_hidden_dims=policy_cfg['actor_hidden_dims'],
    critic_hidden_dims=policy_cfg['critic_hidden_dims'],
    # Prefer the activation recorded in the policy config; 'elu' remains the
    # fallback so configs without that key behave as before.
    activation=policy_cfg.get('activation', 'elu'),
    init_noise_std=policy_cfg['init_noise_std'],
)

# Restore the trained weights and switch to inference mode before tracing.
model.load_state_dict(model_dict['model_state_dict'])
model.eval()

# Example input used to trace the graph: one observation row.
dummy_input = torch.randn(1, obs_dim)

# Export to ONNX. ExportModel.forward returns two tensors (actions, values),
# so both outputs are named and both get a dynamic batch dimension; naming
# only 'actions' would leave the second output with an auto-generated name
# and a batch size fixed to 1.
torch.onnx.export(
    model,
    dummy_input,
    args.output,
    input_names=['obs'],
    output_names=['actions', 'values'],
    dynamic_axes={
        'obs': {0: 'batch_size'},
        'actions': {0: 'batch_size'},
        'values': {0: 'batch_size'},
    },
)

print("Model successfully converted to ONNX format")

sim/genesis/utils/temp_read_pkl.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
import pickle
2+
3+
# Debug helper: unpickle the saved training configuration and print it.
# The path is relative to the current working directory.
with open('./logs/zeroth-walking/cfgs.pkl', 'rb') as cfg_file:
    data = pickle.load(cfg_file)

print(data)

sim/genesis/zeroth_train.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -39,14 +39,14 @@ def get_train_cfg(exp_name, max_iterations):
3939
"load_run": -1,
4040
"log_interval": 1,
4141
"max_iterations": max_iterations,
42-
"num_steps_per_env": 24,
42+
"num_steps_per_env": 48,
4343
"policy_class_name": "ActorCritic",
4444
"record_interval": -1,
4545
"resume": False,
4646
"resume_path": None,
4747
"run_name": "",
48-
"runner_class_name": "runner_class_name",
49-
"save_interval": 100,
48+
"runner_class_name": "OnPolicyRunner",
49+
"save_interval": 10,
5050
},
5151
"runner_class_name": "OnPolicyRunner",
5252
"seed": 1,

0 commit comments

Comments
 (0)