Skip to content

Commit

Permalink
lint
Browse files Browse the repository at this point in the history
  • Loading branch information
WT-MM committed Dec 30, 2024
1 parent b41d3df commit f627fb3
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 33 deletions.
50 changes: 21 additions & 29 deletions sim/play.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@
import logging
import math
import os

# Add local kinfer to the path because kinfer requires 3.11 and we are using 3.8
import sys
import time
import uuid
from datetime import datetime
Expand All @@ -17,29 +20,25 @@
import cv2
import h5py
import numpy as np
import onnx
from isaacgym import gymapi
from tqdm import tqdm
import onnx

# Add local kinfer to the path because kinfer requires 3.11 and we are using 3.8
import sys

# Get absolute path relative to this script
kinfer_path = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', 'third_party', 'kinfer'))
kinfer_path = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "third_party", "kinfer"))
sys.path.append(kinfer_path)

from kinfer import proto as P
from kinfer import proto as P
from kinfer.export.python import export_model

from sim.utils.resources import load_embodiment

from sim.env import run_dir # noqa: E402
from sim.envs import task_registry # noqa: E402

# Local imports third
from sim.model_export import ActorCfg, get_actor_policy
from sim.utils.helpers import get_args # noqa: E402
from sim.utils.logger import Logger # noqa: E402
from sim.utils.resources import load_embodiment

import torch # special case with isort: skip comment

Expand All @@ -52,11 +51,13 @@ def export_policy_as_jit(actor_critic: Any, path: Union[str, os.PathLike]) -> No
model = get_actor_jit(actor_critic)
model.save(path)


def get_actor_jit(actor_critic: Any) -> Any:
model = copy.deepcopy(actor_critic.actor).to("cpu")
traced_script_module = torch.jit.script(model)
return traced_script_module


def play(args: argparse.Namespace) -> None:
logger.info("Configuring environment and training settings...")
env_cfg, train_cfg = task_registry.get_cfgs(name=args.task)
Expand Down Expand Up @@ -114,42 +115,36 @@ def play(args: argparse.Namespace) -> None:
# 3 comes from the number of times num_actions is repeated in the observation (dof_pos, dof_vel, prev_actions)
# -1 comes from the fact that the latest observation is the current state and not part of the buffer
shape=[(env_cfg.env.frame_stack - 1) * (11 + env.num_dof * 3)],
description="Buffer of previous observations"
)
description="Buffer of previous observations",
),
),
P.ValueSchema(
value_name="vector_command",
vector_command=P.VectorCommandSchema(
dimensions=3, # x_vel, y_vel, rot
)
),
),
P.ValueSchema(
value_name="timestamp",
timestamp=P.TimestampSchema(
unit=P.TimestampUnit.SECONDS,
description="Current policy time in seconds"
)
timestamp=P.TimestampSchema(unit=P.TimestampUnit.SECONDS, description="Current policy time in seconds"),
),
P.ValueSchema(
value_name="joint_positions",
joint_positions=P.JointPositionsSchema(
joint_names=JOINT_NAMES,
unit=P.JointPositionUnit.RADIANS,
)
),
),
P.ValueSchema(
value_name="joint_velocities",
value_name="joint_velocities",
joint_velocities=P.JointVelocitiesSchema(
joint_names=JOINT_NAMES,
unit=P.JointVelocityUnit.RADIANS_PER_SECOND,
)
),
),
P.ValueSchema(
value_name="previous_actions",
joint_positions=P.JointPositionsSchema(
joint_names=JOINT_NAMES,
unit=P.JointPositionUnit.RADIANS
)
joint_positions=P.JointPositionsSchema(joint_names=JOINT_NAMES, unit=P.JointPositionUnit.RADIANS),
),
# Abusing the IMU schema to pass in euler and angular velocity instead of raw sensor data
P.ValueSchema(
Expand All @@ -158,16 +153,16 @@ def play(args: argparse.Namespace) -> None:
use_accelerometer=False,
use_gyroscope=True,
use_magnetometer=False,
)
),
),
P.ValueSchema(
value_name="imu_orientation",
imu=P.ImuSchema(
use_accelerometer=True,
use_gyroscope=False,
use_magnetometer=False,
)
)
),
),
]
)

Expand All @@ -184,10 +179,7 @@ def play(args: argparse.Namespace) -> None:
)

# Create the full model schema
model_schema = P.ModelSchema(
input_schema=input_schema,
output_schema=output_schema
)
model_schema = P.ModelSchema(input_schema=input_schema, output_schema=output_schema)

if args.export_onnx:
jit_policy = get_actor_jit(policy)
Expand Down
8 changes: 4 additions & 4 deletions sim/sim2sim.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@
import argparse
import math
import os

# Add local kinfer to the path because kinfer requires 3.11 and we are using 3.8
import sys
from copy import deepcopy
from dataclasses import dataclass
from typing import Dict, List, Tuple, Union
Expand All @@ -17,15 +20,12 @@
import onnxruntime as ort
import pygame
import torch
# Add local kinfer to the path because kinfer requires 3.11 and we are using 3.8
import sys

kinfer_path = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', 'third_party', 'kinfer'))
kinfer_path = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "third_party", "kinfer"))
sys.path.append(kinfer_path)

from kinfer.export.pytorch import export_to_onnx
from kinfer.inference.python import ONNXModel

from scipy.spatial.transform import Rotation as R
from tqdm import tqdm

Expand Down
1 change: 1 addition & 0 deletions sim/utils/resources.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import importlib
from typing import Any


def load_embodiment(embodiment: str) -> Any: # noqa: ANN401
# Dynamically import embodiment
module_name = f"sim.resources.{embodiment}.joints"
Expand Down

0 comments on commit f627fb3

Please sign in to comment.