Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions doc/config.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,9 @@ addr = "https://api.optimatist.com"
# in seconds
heartbeat_interval = 1
instance_id_file = "/etc/psh/instance.id"
max_retries = 3
# in seconds
base_delay = 1

[remote.rpc.data_export]
buf_size = 4096
Expand Down
12 changes: 12 additions & 0 deletions src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ use std::{fs, path::Path};

use anyhow::Result;
use serde::Deserialize;
use std::time::Duration;

const TEMPLATE: &str = include_str!("../doc/config.toml");

Expand Down Expand Up @@ -56,6 +57,17 @@ pub struct RpcConfig {
pub heartbeat_interval: u64,
pub instance_id_file: String,
pub data_export: DataExportConfig,
pub max_retries: Option<u32>,
#[serde(deserialize_with = "deserialize_duration")]
pub base_delay: Option<Duration>,
}

fn deserialize_duration<'de, D>(deserializer: D) -> Result<Option<Duration>, D::Error>
where
D: serde::Deserializer<'de>,
{
let seconds: u64 = serde::Deserialize::deserialize(deserializer)?;
Ok(Some(Duration::from_secs(seconds)))
}

#[derive(Deserialize)]
Expand Down
89 changes: 82 additions & 7 deletions src/services/rpc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,9 @@ use psh_proto::{
ExportDataReq, GetTaskReq, HeartbeatReq, TaskDoneReq, Unit,
psh_service_client::PshServiceClient,
};
use std::time::Duration;
use tokio::time::sleep;
use tonic::Code;
use tonic::{
Request,
transport::{Channel, ClientTlsConfig, Endpoint},
Expand All @@ -29,6 +32,8 @@ use crate::{config::RpcConfig, runtime::Task, services::host_info::new_info_req}
pub struct RpcClient {
token: String,
client: PshServiceClient<Channel>,
max_retries: u32,
base_delay: Duration,
}

fn into_req<T>(message: T, token: &str) -> Result<Request<T>> {
Expand All @@ -38,12 +43,69 @@ fn into_req<T>(message: T, token: &str) -> Result<Request<T>> {
Ok(req)
}

async fn retry_with_backoff<F, T>(
max_retries: u32,
base_delay: Duration,
mut operation: F,
) -> Result<T, tonic::Status>
where
F: AsyncFnMut() -> Result<T, tonic::Status>,
{
let mut attempts = 0;
loop {
match operation().await {
Ok(resp) => break Ok(resp),
Err(status) => {
attempts += 1;
if attempts >= max_retries {
tracing::error!("RpcClient max retries reached after {} attempts", attempts);
break Err(status);
}

let retry_delay = base_delay * (2_u32.pow(attempts - 1));

if status.code() == Code::Unknown && status.message().contains("transport error") {
tracing::warn!(
"RpcClient transport error detected (attempt {}/{}), retrying in {:?}...",
attempts,
max_retries,
retry_delay
);
sleep(retry_delay).await;
continue;
}
break Err(status);
}
}
}
}

impl RpcClient {
pub async fn new(config: &RpcConfig, token: String) -> Result<Self> {
let ep = Endpoint::from_shared(config.addr.clone())?
// 连接相关设置
.connect_timeout(Duration::from_secs(5))
.timeout(Duration::from_secs(30))
// TCP 相关设置
.tcp_keepalive(Some(Duration::from_secs(60)))
.tcp_nodelay(true)
// HTTP/2相关设置
.http2_keep_alive_interval(Duration::from_secs(30))
.keep_alive_while_idle(true)
// 并发和限流
.concurrency_limit(256)
.rate_limit(5, Duration::from_secs(1))
// TLS 配置
.tls_config(ClientTlsConfig::new().with_native_roots())?;

let client: PshServiceClient<Channel> = PshServiceClient::connect(ep).await?;
Ok(Self { token, client })

Ok(Self {
token,
client,
max_retries: config.max_retries.unwrap_or(3),
base_delay: config.base_delay.unwrap_or(Duration::from_secs(1)),
})
}

pub async fn send_host_info(&mut self, instance_id: String) -> Result<()> {
Expand All @@ -58,18 +120,31 @@ impl RpcClient {
self.client.export_data(req).await?;
Ok(())
}

pub async fn heartbeat(&mut self, message: HeartbeatReq) -> Result<()> {
let req = into_req(message, &self.token)?;
self.client.heartbeat(req).await?;
let token = &self.token;

retry_with_backoff(self.max_retries, self.base_delay, async || {
let req = into_req(message.clone(), token)
.map_err(|e| tonic::Status::invalid_argument(e.to_string()))?;
self.client.heartbeat(req).await
})
.await?;
Ok(())
}

pub async fn get_task(&mut self, instance_id: String) -> Result<Option<Task>> {
let req = into_req(GetTaskReq { instance_id }, &self.token)?;
let get_task_req = GetTaskReq { instance_id };
let token = &self.token;

let Some(task) = self.client.get_task(req).await?.into_inner().task else {
return Ok(None);
let response = retry_with_backoff(self.max_retries, self.base_delay, async || {
let req = into_req(get_task_req.clone(), token)
.map_err(|e| tonic::Status::invalid_argument(e.to_string()))?;
self.client.get_task(req).await
})
.await?;
let task = match response.into_inner().task {
Some(task) => task,
None => return Ok(None),
};

let end_time = match Utc.timestamp_millis_opt(task.end_time as _) {
Expand Down
Loading