Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

update sim reset function #53

Merged
merged 2 commits into from
Feb 20, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 44 additions & 18 deletions kos-py/pykos/services/sim.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,29 @@
from pykos.services import AsyncClientBase


class DefaultPosition(TypedDict):
qpos: list[float]
class StartingPosition(TypedDict):
x: float
y: float
z: float


class StartingQuaternion(TypedDict):
x: float
y: float
z: float
w: float


class JointPosition(TypedDict):
name: str
pos: NotRequired[float]
vel: NotRequired[float]


class ResetRequest(TypedDict):
initial_state: NotRequired[DefaultPosition]
randomize: NotRequired[bool]
pos: NotRequired[StartingPosition]
quat: NotRequired[StartingQuaternion]
joints: NotRequired[list[JointPosition]]


class StepRequest(TypedDict):
Expand All @@ -27,7 +43,6 @@ class StepRequest(TypedDict):
class SimulationParameters(TypedDict):
time_scale: NotRequired[float]
gravity: NotRequired[float]
initial_state: NotRequired[DefaultPosition]


class SimServiceClient(AsyncClientBase):
Expand Down Expand Up @@ -55,12 +70,28 @@ async def reset(self, **kwargs: Unpack[ResetRequest]) -> common_pb2.ActionRespon
Returns:
ActionResponse indicating success/failure
"""
initial_state = None
if "initial_state" in kwargs:
pos = kwargs["initial_state"]
initial_state = sim_pb2.DefaultPosition(qpos=pos["qpos"])

request = sim_pb2.ResetRequest(initial_state=initial_state, randomize=kwargs.get("randomize"))
pos = None
if (pos_dict := kwargs.get("pos")) is not None:
pos = sim_pb2.StartingPosition(
x=pos_dict["x"],
y=pos_dict["y"],
z=pos_dict["z"],
)

quat = None
if (quat_dict := kwargs.get("quat")) is not None:
quat = sim_pb2.StartingQuaternion(
x=quat_dict["x"],
y=quat_dict["y"],
z=quat_dict["z"],
w=quat_dict["w"],
)

joints_values = None
if (joints_dict := kwargs.get("joints")) is not None:
joints_values = sim_pb2.JointValues(values=[sim_pb2.JointValue(**joint) for joint in joints_dict])

request = sim_pb2.ResetRequest(pos=pos, quat=quat, joints=joints_values)
return await self.stub.Reset(request)

async def set_paused(self, paused: bool) -> common_pb2.ActionResponse:
Expand Down Expand Up @@ -95,7 +126,6 @@ async def set_parameters(self, **kwargs: Unpack[SimulationParameters]) -> common
>>> client.set_parameters(
... time_scale=1.0,
... gravity=9.81,
... initial_state={"qpos": [0.0, 0.0, 0.0]}
... )

Args:
Expand All @@ -107,13 +137,9 @@ async def set_parameters(self, **kwargs: Unpack[SimulationParameters]) -> common
Returns:
ActionResponse indicating success/failure
"""
initial_state = None
if "initial_state" in kwargs:
pos = kwargs["initial_state"]
initial_state = sim_pb2.DefaultPosition(qpos=pos["qpos"])

params = sim_pb2.SimulationParameters(
time_scale=kwargs.get("time_scale"), gravity=kwargs.get("gravity"), initial_state=initial_state
time_scale=kwargs.get("time_scale"),
gravity=kwargs.get("gravity"),
)
request = sim_pb2.SetParametersRequest(parameters=params)
return await self.stub.SetParameters(request)
Expand Down
41 changes: 24 additions & 17 deletions kos/proto/kos/sim.proto
Original file line number Diff line number Diff line change
Expand Up @@ -27,29 +27,41 @@ service SimulationService {
rpc GetParameters(google.protobuf.Empty) returns (GetParametersResponse);
}

// Default position for the simulation (initial state)
message DefaultPosition {
repeated float qpos = 1;
message StartingPosition {
float x = 1;
float y = 2;
float z = 3;
}

message StartingQuaternion {
float x = 1;
float y = 2;
float z = 3;
float w = 4;
}

message JointValue {
string name = 1;
optional float pos = 2;
optional float vel = 3;
}

message JointValues {
repeated JointValue values = 1;
}

// Request to reset the simulation to initial state
message ResetRequest {
// If provided, reset to this specific state, otherwise use default
optional DefaultPosition initial_state = 1;
// If true, randomize the initial state within pre-set bounds
optional bool randomize = 2;
optional StartingPosition pos = 1;
optional StartingQuaternion quat = 2;
optional JointValues joints = 3;
}

// Request to pause or resume the simulation
message SetPausedRequest {
bool paused = 1;
}

// Request to step the simulation forward
message StepRequest {
// Number of simulation steps to take
uint32 num_steps = 1;
// Time per step in seconds
optional float step_size = 2;
}

Expand All @@ -62,12 +74,7 @@ message GetParametersResponse {
kos.common.Error error = 2; // Error details if any
}

// Controllable parameters for the simulation
message SimulationParameters {
// Time scale for the simulation
optional float time_scale = 1;
// Strength of gravity for the simulation
optional float gravity = 2;
// Initial state for the simulation
optional DefaultPosition initial_state = 3;
}