Skip to content

Commit 4a61e69

Browse files
nfreqalik-git
andauthored
Implement PolicyService (#61)
* Implement PolicyService * lint --------- Co-authored-by: Ali Kuwajerwala <[email protected]>
1 parent 0fba226 commit 4a61e69

File tree

14 files changed

+309
-6
lines changed

14 files changed

+309
-6
lines changed

Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ members = [
77
]
88

99
[workspace.package]
10-
version = "0.7.8"
10+
version = "0.7.9"
1111
authors = [
1212
"Benjamin Bolte <[email protected]>",
1313
"Denys Bezmenov <[email protected]>",

kos-py/pykos/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
"""KOS Python client."""
22

3-
__version__ = "0.7.8"
3+
__version__ = "0.7.9"
44

55
from . import services
66
from .client import KOS

kos-py/pykos/client.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from pykos.services.imu import IMUServiceClient
1313
from pykos.services.inference import InferenceServiceClient
1414
from pykos.services.led_matrix import LEDMatrixServiceClient
15+
from pykos.services.policy import PolicyServiceClient
1516
from pykos.services.process_manager import ProcessManagerServiceClient
1617
from pykos.services.sim import SimServiceClient
1718
from pykos.services.sound import SoundServiceClient
@@ -40,6 +41,7 @@ def __init__(self, ip: str = "localhost", port: int = 50051) -> None:
4041
self._process_manager: ProcessManagerServiceClient | None = None
4142
self._inference: InferenceServiceClient | None = None
4243
self._sim: SimServiceClient | None = None
44+
self._policy: PolicyServiceClient | None = None
4345

4446
@property
4547
def imu(self) -> IMUServiceClient:
@@ -81,6 +83,14 @@ def process_manager(self) -> ProcessManagerServiceClient:
8183
raise RuntimeError("Process Manager client not initialized! Must call `connect()` manually.")
8284
return self._process_manager
8385

86+
@property
87+
def policy(self) -> PolicyServiceClient:
88+
if self._policy is None:
89+
self.connect()
90+
if self._policy is None:
91+
raise RuntimeError("Policy client not initialized! Must call `connect()` manually.")
92+
return self._policy
93+
8494
@property
8595
def inference(self) -> InferenceServiceClient:
8696
if self._inference is None:
@@ -108,6 +118,7 @@ def connect(self) -> None:
108118
self._led_matrix = LEDMatrixServiceClient(self._channel)
109119
self._sound = SoundServiceClient(self._channel)
110120
self._process_manager = ProcessManagerServiceClient(self._channel)
121+
self._policy = PolicyServiceClient(self._channel)
111122
self._inference = InferenceServiceClient(self._channel)
112123
self._sim = SimServiceClient(self._channel)
113124

kos-py/pykos/services/policy.py

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
"""Policy service client."""
2+
3+
import grpc.aio
4+
from google.protobuf.empty_pb2 import Empty
5+
6+
from kos_protos import policy_pb2, policy_pb2_grpc
7+
from kos_protos.policy_pb2 import StartPolicyRequest
8+
from pykos.services import AsyncClientBase
9+
10+
11+
class PolicyServiceClient(AsyncClientBase):
12+
"""Client for the PolicyService."""
13+
14+
def __init__(self, channel: grpc.aio.Channel) -> None:
15+
super().__init__()
16+
self.stub = policy_pb2_grpc.PolicyServiceStub(channel)
17+
18+
async def start_policy(
19+
self, action: str, action_scale: float, episode_length: int, dry_run: bool
20+
) -> policy_pb2.StartPolicyResponse:
21+
"""Start policy execution.
22+
23+
Args:
24+
action: The action string for the policy
25+
action_scale: Scale factor for actions
26+
episode_length: Length of the episode
27+
dry_run: Whether to perform a dry run
28+
29+
Returns:
30+
The response from the server.
31+
"""
32+
request = StartPolicyRequest(
33+
action=action,
34+
action_scale=action_scale,
35+
episode_length=episode_length,
36+
dry_run=dry_run,
37+
)
38+
return await self.stub.StartPolicy(request)
39+
40+
async def stop_policy(self, request: Empty = Empty()) -> policy_pb2.StopPolicyResponse:
41+
"""Stop policy execution.
42+
43+
Returns:
44+
The response from the server.
45+
"""
46+
return await self.stub.StopPolicy(request)
47+
48+
async def get_state(self, request: Empty = Empty()) -> policy_pb2.GetStateResponse:
49+
"""Get the current policy state.
50+
51+
Returns:
52+
The response from the server containing the policy state.
53+
"""
54+
return await self.stub.GetState(request)

kos-stub/src/lib.rs

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,22 @@
11
mod actuator;
22
mod imu;
3+
mod policy;
34
mod process_manager;
45
use crate::actuator::StubActuator;
56
use crate::imu::StubIMU;
7+
use crate::policy::StubPolicy;
68
use crate::process_manager::StubProcessManager;
79
use async_trait::async_trait;
810
use kos::hal::Operation;
911
use kos::kos_proto::actuator::actuator_service_server::ActuatorServiceServer;
1012
use kos::kos_proto::imu::imu_service_server::ImuServiceServer;
13+
use kos::kos_proto::policy::policy_service_server::PolicyServiceServer;
1114
use kos::kos_proto::process_manager::process_manager_service_server::ProcessManagerServiceServer;
12-
use kos::services::{ActuatorServiceImpl, IMUServiceImpl, ProcessManagerServiceImpl};
15+
use kos::services::{
16+
ActuatorServiceImpl, IMUServiceImpl, PolicyServiceImpl, ProcessManagerServiceImpl,
17+
};
1318
use kos::{services::OperationsServiceImpl, Platform, ServiceEnum};
19+
1420
use std::future::Future;
1521
use std::pin::Pin;
1622
use std::sync::Arc;
@@ -52,6 +58,7 @@ impl Platform for StubPlatform {
5258
let actuator = StubActuator::new(operations_service.clone());
5359
let imu = StubIMU::new(operations_service.clone());
5460
let process_manager = StubProcessManager::new();
61+
let policy = StubPolicy::new();
5562

5663
Ok(vec![
5764
ServiceEnum::Actuator(ActuatorServiceServer::new(ActuatorServiceImpl::new(
@@ -61,6 +68,10 @@ impl Platform for StubPlatform {
6168
ProcessManagerServiceImpl::new(Arc::new(process_manager)),
6269
)),
6370
ServiceEnum::Imu(ImuServiceServer::new(IMUServiceImpl::new(Arc::new(imu)))),
71+
ServiceEnum::Policy(PolicyServiceServer::new(
72+
// Add this block
73+
PolicyServiceImpl::new(Arc::new(policy)),
74+
)),
6475
])
6576
})
6677
}

kos-stub/src/policy.rs

Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,96 @@
1+
use async_trait::async_trait;
2+
use eyre::Result;
3+
use kos::hal::{GetStateResponse, Policy, StartPolicyResponse, StopPolicyResponse};
4+
use kos::kos_proto::common::{Error, ErrorCode};
5+
use std::collections::HashMap;
6+
use std::sync::Mutex;
7+
use uuid::Uuid;
8+
9+
pub struct StubPolicy {
10+
policy_uuid: Mutex<Option<String>>,
11+
state: Mutex<HashMap<String, String>>,
12+
}
13+
14+
impl Default for StubPolicy {
15+
fn default() -> Self {
16+
Self::new()
17+
}
18+
}
19+
20+
impl StubPolicy {
21+
pub fn new() -> Self {
22+
StubPolicy {
23+
policy_uuid: Mutex::new(None),
24+
state: Mutex::new(HashMap::new()),
25+
}
26+
}
27+
}
28+
29+
#[async_trait]
30+
impl Policy for StubPolicy {
31+
async fn start_policy(
32+
&self,
33+
action: String,
34+
action_scale: f32,
35+
episode_length: i32,
36+
dry_run: bool,
37+
) -> Result<StartPolicyResponse> {
38+
let mut policy_uuid = self.policy_uuid.lock().unwrap();
39+
if policy_uuid.is_some() {
40+
return Ok(StartPolicyResponse {
41+
policy_uuid: None,
42+
error: Some(Error {
43+
code: ErrorCode::InvalidArgument as i32,
44+
message: "Policy is already running".to_string(),
45+
}),
46+
});
47+
}
48+
49+
let new_uuid = Uuid::new_v4().to_string();
50+
*policy_uuid = Some(new_uuid.clone());
51+
52+
// Update state with policy parameters
53+
let mut state = self.state.lock().unwrap();
54+
state.insert("action".to_string(), action);
55+
state.insert("action_scale".to_string(), action_scale.to_string());
56+
state.insert("episode_length".to_string(), episode_length.to_string());
57+
state.insert("dry_run".to_string(), dry_run.to_string());
58+
59+
Ok(StartPolicyResponse {
60+
policy_uuid: Some(new_uuid),
61+
error: None,
62+
})
63+
}
64+
65+
async fn stop_policy(&self) -> Result<StopPolicyResponse> {
66+
let mut policy_uuid = self.policy_uuid.lock().unwrap();
67+
if policy_uuid.is_none() {
68+
return Ok(StopPolicyResponse {
69+
policy_uuid: None,
70+
error: Some(Error {
71+
code: ErrorCode::InvalidArgument as i32,
72+
message: "Policy is not running".to_string(),
73+
}),
74+
});
75+
}
76+
77+
let stopped_uuid = policy_uuid.take().unwrap();
78+
79+
// Clear the state when stopping
80+
let mut state = self.state.lock().unwrap();
81+
state.clear();
82+
83+
Ok(StopPolicyResponse {
84+
policy_uuid: Some(stopped_uuid),
85+
error: None,
86+
})
87+
}
88+
89+
async fn get_state(&self) -> Result<GetStateResponse> {
90+
let state = self.state.lock().unwrap();
91+
Ok(GetStateResponse {
92+
state: state.clone(),
93+
error: None,
94+
})
95+
}
96+
}

kos/build.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ fn main() {
1616
"kos/sim.proto",
1717
"kos/inference.proto",
1818
"kos/process_manager.proto",
19+
"kos/policy.proto",
1920
"kos/system.proto",
2021
"kos/led_matrix.proto",
2122
"kos/sound.proto",

kos/proto/kos/policy.proto

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
syntax = "proto3";
2+
3+
package kos.policy;
4+
5+
import "google/protobuf/empty.proto";
6+
import "kos/common.proto";
7+
8+
option go_package = "kos/policy;policy";
9+
option java_package = "com.kos.policy";
10+
option csharp_namespace = "KOS.Policy";
11+
12+
// The PolicyService manages policy execution.
13+
service PolicyService {
14+
// Starts policy execution.
15+
rpc StartPolicy(StartPolicyRequest) returns (StartPolicyResponse);
16+
17+
// Stops policy execution.
18+
rpc StopPolicy(google.protobuf.Empty) returns (StopPolicyResponse);
19+
20+
// Gets the current policy state.
21+
rpc GetState(google.protobuf.Empty) returns (GetStateResponse);
22+
}
23+
24+
message StartPolicyRequest {
25+
string action = 1;
26+
float action_scale = 2;
27+
int32 episode_length = 3;
28+
bool dry_run = 4;
29+
}
30+
31+
message StartPolicyResponse {
32+
optional string policy_uuid = 1;
33+
kos.common.Error error = 2;
34+
}
35+
36+
message StopPolicyResponse {
37+
optional string policy_uuid = 1;
38+
kos.common.Error error = 2;
39+
}
40+
41+
message GetStateResponse {
42+
map<string, string> state = 1;
43+
kos.common.Error error = 2;
44+
}

kos/src/daemon.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ fn add_service_to_router(
4040
ServiceEnum::Inference(svc) => router.add_service(svc),
4141
ServiceEnum::LEDMatrix(svc) => router.add_service(svc),
4242
ServiceEnum::Sound(svc) => router.add_service(svc),
43+
ServiceEnum::Policy(svc) => router.add_service(svc),
4344
}
4445
}
4546

kos/src/grpc_interface.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,10 @@ pub mod kos {
1919
tonic::include_proto!("kos/kos.processmanager");
2020
}
2121

22+
pub mod policy {
23+
tonic::include_proto!("kos/kos.policy");
24+
}
25+
2226
pub mod system {
2327
tonic::include_proto!("kos/kos.system");
2428
}

kos/src/hal.rs

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,8 @@ pub use crate::grpc_interface::google::longrunning::*;
22
pub use crate::grpc_interface::kos;
33
pub use crate::grpc_interface::kos::common::ActionResponse;
44
pub use crate::kos_proto::{
5-
actuator::*, common::ActionResult, imu::*, inference::*, led_matrix::*, process_manager::*,
6-
sound::*,
5+
actuator::*, common::ActionResult, imu::*, inference::*, led_matrix::*, policy::*,
6+
process_manager::*, sound::*,
77
};
88
use async_trait::async_trait;
99
use bytes::Bytes;
@@ -55,6 +55,19 @@ pub trait ProcessManager: Send + Sync {
5555
async fn stop_kclip(&self) -> Result<KClipStopResponse>;
5656
}
5757

58+
#[async_trait]
59+
pub trait Policy: Send + Sync {
60+
async fn start_policy(
61+
&self,
62+
action: String,
63+
action_scale: f32,
64+
episode_length: i32,
65+
dry_run: bool,
66+
) -> Result<StartPolicyResponse>;
67+
async fn stop_policy(&self) -> Result<StopPolicyResponse>;
68+
async fn get_state(&self) -> Result<GetStateResponse>;
69+
}
70+
5871
#[async_trait]
5972
pub trait Inference: Send + Sync {
6073
async fn upload_model(

kos/src/lib.rs

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,11 +19,12 @@ use hal::imu_service_server::ImuServiceServer;
1919
use hal::inference_service_server::InferenceServiceServer;
2020
use hal::led_matrix_service_server::LedMatrixServiceServer;
2121
use hal::process_manager_service_server::ProcessManagerServiceServer;
22+
use hal::policy_service_server::PolicyServiceServer;
2223
use hal::sound_service_server::SoundServiceServer;
2324
use services::OperationsServiceImpl;
2425
use services::{
2526
ActuatorServiceImpl, IMUServiceImpl, InferenceServiceImpl, LEDMatrixServiceImpl,
26-
ProcessManagerServiceImpl, SoundServiceImpl,
27+
ProcessManagerServiceImpl, SoundServiceImpl, PolicyServiceImpl,
2728
};
2829
use std::fmt::Debug;
2930
use std::future::Future;
@@ -48,6 +49,12 @@ impl Debug for ProcessManagerServiceImpl {
4849
}
4950
}
5051

52+
impl Debug for PolicyServiceImpl {
53+
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
54+
write!(f, "PolicyServiceImpl")
55+
}
56+
}
57+
5158
impl Debug for InferenceServiceImpl {
5259
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
5360
write!(f, "InferenceServiceImpl")
@@ -74,6 +81,7 @@ pub enum ServiceEnum {
7481
Inference(InferenceServiceServer<InferenceServiceImpl>),
7582
LEDMatrix(LedMatrixServiceServer<LEDMatrixServiceImpl>),
7683
Sound(SoundServiceServer<SoundServiceImpl>),
84+
Policy(PolicyServiceServer<PolicyServiceImpl>),
7785
}
7886

7987
#[async_trait]

0 commit comments

Comments
 (0)