Skip to content

Commit 855bbaa

Browse files
committed
fix(guard, runner): enforce http body size and ws msg size
1 parent ce7323c commit 855bbaa

File tree

11 files changed

+125
-47
lines changed

11 files changed

+125
-47
lines changed

engine/artifacts/errors/guard.invalid_request_body.json

Lines changed: 5 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

engine/artifacts/errors/guard.invalid_response_body.json

Lines changed: 5 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

engine/packages/guard-core/src/errors.rs

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,24 @@
11
use rivet_error::*;
22
use serde::{Deserialize, Serialize};
33

4+
#[derive(RivetError, Serialize, Deserialize)]
5+
#[error(
6+
"guard",
7+
"invalid_request_body",
8+
"Unable to parse request body.",
9+
"Unable to parse request body: {0}."
10+
)]
11+
pub struct InvalidRequestBody(pub String);
12+
13+
#[derive(RivetError, Serialize, Deserialize)]
14+
#[error(
15+
"guard",
16+
"invalid_response_body",
17+
"Unable to parse response body.",
18+
"Unable to parse response body: {0}."
19+
)]
20+
pub struct InvalidResponseBody(pub String);
21+
422
#[derive(RivetError, Serialize, Deserialize)]
523
#[error(
624
"guard",
@@ -42,7 +60,7 @@ pub struct UriParseError(pub String);
4260
pub struct RequestBuildError(pub String);
4361

4462
#[derive(RivetError)]
45-
#[error("guard", "upstream_error", "Upstream error.", "Upstream error: {0}")]
63+
#[error("guard", "upstream_error", "Upstream error.", "Upstream error: {0}.")]
4664
pub struct UpstreamError(pub String);
4765

4866
#[derive(RivetError, Serialize, Deserialize)]

engine/packages/guard-core/src/proxy_service.rs

Lines changed: 22 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
use anyhow::{Context, Result, bail, ensure};
22
use bytes::Bytes;
33
use futures_util::{SinkExt, StreamExt};
4-
use http_body_util::{BodyExt, Full};
4+
use http_body_util::{BodyExt, Full, Limited};
55
use hyper::{
66
Request, Response, StatusCode,
77
body::Incoming as BodyIncoming,
@@ -40,6 +40,7 @@ use crate::{
4040

4141
pub const X_FORWARDED_FOR: HeaderName = HeaderName::from_static("x-forwarded-for");
4242
pub const X_RIVET_ERROR: HeaderName = HeaderName::from_static("x-rivet-error");
43+
pub const MAX_BODY_SIZE: usize = rivet_util::size::mebibytes(20) as usize;
4344

4445
const PROXY_STATE_CACHE_TTL: Duration = Duration::from_secs(60 * 60); // 1 hour
4546
const WEBSOCKET_CLOSE_LINGER: Duration = Duration::from_millis(100); // Keep TCP connection open briefly after WebSocket close
@@ -723,13 +724,11 @@ impl ProxyService {
723724
ResolveRouteOutput::Target(mut target) => {
724725
// Read the request body before proceeding with retries
725726
let (req_parts, body) = req.into_parts();
726-
let req_body = match http_body_util::BodyExt::collect(body).await {
727-
Ok(collected) => collected.to_bytes(),
728-
Err(err) => {
729-
tracing::debug!(?err, "Failed to read request body");
730-
Bytes::new()
731-
}
732-
};
727+
let req_body = Limited::new(body, MAX_BODY_SIZE)
728+
.collect()
729+
.await
730+
.map_err(|err| errors::InvalidRequestBody(err.to_string()).build())?
731+
.to_bytes();
733732

734733
// Use a value-returning loop to handle both errors and successful responses
735734
let mut attempts = 0;
@@ -742,7 +741,8 @@ impl ProxyService {
742741

743742
// Create the final request with body
744743
let proxied_req = builder
745-
.body(Full::<Bytes>::new(req_body.clone()))
744+
// NOTE: the `Bytes` type is cheaply cloneable, this is not resource intensive
745+
.body(Full::new(req_body.clone()))
746746
.map_err(|err| errors::RequestBuildError(err.to_string()).build())?;
747747

748748
// Send the request with timeout
@@ -800,10 +800,13 @@ impl ProxyService {
800800
return Ok(Response::from_parts(parts, streaming_body));
801801
} else {
802802
// For non-streaming responses, buffer as before
803-
let body_bytes = match BodyExt::collect(body).await {
804-
Ok(collected) => collected.to_bytes(),
805-
Err(_) => Bytes::new(),
806-
};
803+
let body_bytes = Limited::new(body, MAX_BODY_SIZE)
804+
.collect()
805+
.await
806+
.map_err(|err| {
807+
errors::InvalidResponseBody(err.to_string()).build()
808+
})?
809+
.to_bytes();
807810

808811
let full_body = ResponseBody::Full(Full::new(body_bytes));
809812
return Ok(Response::from_parts(parts, full_body));
@@ -857,15 +860,13 @@ impl ProxyService {
857860
ResolveRouteOutput::CustomServe(mut handler) => {
858861
// Collect request body
859862
let (req_parts, body) = req.into_parts();
860-
let collected_body = match http_body_util::BodyExt::collect(body).await {
861-
Ok(collected) => collected.to_bytes(),
862-
Err(err) => {
863-
tracing::debug!(?err, "Failed to read request body");
864-
Bytes::new()
865-
}
866-
};
863+
let req_body = Limited::new(body, MAX_BODY_SIZE)
864+
.collect()
865+
.await
866+
.map_err(|err| errors::InvalidRequestBody(err.to_string()).build())?
867+
.to_bytes();
867868
let req_collected =
868-
hyper::Request::from_parts(req_parts, Full::<Bytes>::new(collected_body));
869+
hyper::Request::from_parts(req_parts, Full::<Bytes>::new(req_body));
869870

870871
// Attempt request
871872
let mut attempts = 0;

engine/packages/guard-core/src/utils.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -181,6 +181,8 @@ pub(crate) fn err_into_response(err: anyhow::Error) -> Result<Response<ResponseB
181181
("guard", "service_unavailable") => StatusCode::SERVICE_UNAVAILABLE,
182182
("guard", "actor_ready_timeout") => StatusCode::SERVICE_UNAVAILABLE,
183183
("guard", "no_route") => StatusCode::NOT_FOUND,
184+
("guard", "invalid_request_body") => StatusCode::PAYLOAD_TOO_LARGE,
185+
("guard", "invalid_response_body") => StatusCode::PAYLOAD_TOO_LARGE,
184186
_ => StatusCode::BAD_REQUEST,
185187
};
186188

engine/packages/pegboard-gateway/src/lib.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -157,6 +157,7 @@ impl PegboardGateway {
157157
max_age: None,
158158
});
159159

160+
// NOTE: Size constraints have already been applied by guard
160161
let body_bytes = req
161162
.into_body()
162163
.collect()

engine/packages/pegboard-runner/src/ws_to_tunnel_task.rs

Lines changed: 50 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ use gas::prelude::*;
66
use hyper_tungstenite::tungstenite::Message;
77
use pegboard::actor_kv;
88
use pegboard::pubsub_subjects::GatewayReceiverSubject;
9+
use rivet_guard_core::proxy_service::MAX_BODY_SIZE;
910
use rivet_guard_core::websocket_handle::WebSocketReceiver;
1011
use rivet_runner_protocol::{self as protocol, PROTOCOL_MK2_VERSION, versioned};
1112
use std::sync::{Arc, atomic::Ordering};
@@ -783,12 +784,15 @@ async fn handle_tunnel_message_mk2(
783784
ctx: &StandaloneCtx,
784785
msg: protocol::mk2::ToServerTunnelMessage,
785786
) -> Result<()> {
786-
// Publish message to UPS
787-
let gateway_reply_to = GatewayReceiverSubject::new(msg.message_id.gateway_id).to_string();
788-
789787
// Extract inner data length before consuming msg
790-
let inner_data_len = tunnel_message_inner_data_len(&msg.message_kind);
788+
let inner_data_len = tunnel_message_inner_data_len_mk2(&msg.message_kind);
791789

790+
// Enforce incoming payload size
791+
if inner_data_len > MAX_BODY_SIZE {
792+
return Err(errors::WsError::InvalidPacket(format!("payload too large")).build());
793+
}
794+
795+
let gateway_reply_to = GatewayReceiverSubject::new(msg.message_id.gateway_id).to_string();
792796
let msg_serialized =
793797
versioned::ToGateway::wrap_latest(protocol::mk2::ToGateway::ToServerTunnelMessage(msg))
794798
.serialize_with_embedded_version(PROTOCOL_MK2_VERSION)
@@ -800,6 +804,7 @@ async fn handle_tunnel_message_mk2(
800804
"publishing tunnel message to gateway"
801805
);
802806

807+
// Publish message to UPS
803808
ctx.ups()
804809
.context("failed to get UPS instance for tunnel message")?
805810
.publish(&gateway_reply_to, &msg_serialized, PublishOpts::one())
@@ -814,22 +819,6 @@ async fn handle_tunnel_message_mk2(
814819
Ok(())
815820
}
816821

817-
/// Returns the length of the inner data payload for a tunnel message kind.
818-
fn tunnel_message_inner_data_len(kind: &protocol::mk2::ToServerTunnelMessageKind) -> usize {
819-
use protocol::mk2::ToServerTunnelMessageKind;
820-
match kind {
821-
ToServerTunnelMessageKind::ToServerResponseStart(resp) => {
822-
resp.body.as_ref().map_or(0, |b| b.len())
823-
}
824-
ToServerTunnelMessageKind::ToServerResponseChunk(chunk) => chunk.body.len(),
825-
ToServerTunnelMessageKind::ToServerWebSocketMessage(msg) => msg.data.len(),
826-
ToServerTunnelMessageKind::ToServerResponseAbort
827-
| ToServerTunnelMessageKind::ToServerWebSocketOpen(_)
828-
| ToServerTunnelMessageKind::ToServerWebSocketMessageAck(_)
829-
| ToServerTunnelMessageKind::ToServerWebSocketClose(_) => 0,
830-
}
831-
}
832-
833822
#[tracing::instrument(skip_all)]
834823
async fn handle_tunnel_message_mk1(
835824
ctx: &StandaloneCtx,
@@ -843,6 +832,14 @@ async fn handle_tunnel_message_mk1(
843832
return Ok(());
844833
}
845834

835+
// Extract inner data length before consuming msg
836+
let inner_data_len = tunnel_message_inner_data_len_mk1(&msg.message_kind);
837+
838+
// Enforce incoming payload size
839+
if inner_data_len > MAX_PAYLOAD_SIZE {
840+
return Err(errors::WsError::InvalidPacket(format!("payload too large")).build());
841+
}
842+
846843
// Publish message to UPS
847844
let gateway_reply_to = GatewayReceiverSubject::new(msg.message_id.gateway_id).to_string();
848845
let msg_serialized = versioned::ToGateway::v3_to_v6(versioned::ToGateway::V3(
@@ -864,6 +861,39 @@ async fn handle_tunnel_message_mk1(
864861
Ok(())
865862
}
866863

864+
/// Returns the length of the inner data payload for a tunnel message kind.
865+
fn tunnel_message_inner_data_len_mk2(kind: &protocol::mk2::ToServerTunnelMessageKind) -> usize {
866+
use protocol::mk2::ToServerTunnelMessageKind;
867+
match kind {
868+
ToServerTunnelMessageKind::ToServerResponseStart(resp) => {
869+
resp.body.as_ref().map_or(0, |b| b.len())
870+
}
871+
ToServerTunnelMessageKind::ToServerResponseChunk(chunk) => chunk.body.len(),
872+
ToServerTunnelMessageKind::ToServerWebSocketMessage(msg) => msg.data.len(),
873+
ToServerTunnelMessageKind::ToServerResponseAbort
874+
| ToServerTunnelMessageKind::ToServerWebSocketOpen(_)
875+
| ToServerTunnelMessageKind::ToServerWebSocketMessageAck(_)
876+
| ToServerTunnelMessageKind::ToServerWebSocketClose(_) => 0,
877+
}
878+
}
879+
880+
/// Returns the length of the inner data payload for a tunnel message kind.
881+
fn tunnel_message_inner_data_len_mk1(kind: &protocol::ToServerTunnelMessageKind) -> usize {
882+
use protocol::ToServerTunnelMessageKind;
883+
match kind {
884+
ToServerTunnelMessageKind::ToServerResponseStart(resp) => {
885+
resp.body.as_ref().map_or(0, |b| b.len())
886+
}
887+
ToServerTunnelMessageKind::ToServerResponseChunk(chunk) => chunk.body.len(),
888+
ToServerTunnelMessageKind::ToServerWebSocketMessage(msg) => msg.data.len(),
889+
ToServerTunnelMessageKind::ToServerResponseAbort
890+
| ToServerTunnelMessageKind::ToServerWebSocketOpen(_)
891+
| ToServerTunnelMessageKind::ToServerWebSocketMessageAck(_)
892+
| ToServerTunnelMessageKind::ToServerWebSocketClose(_)
893+
| ToServerTunnelMessageKind::DeprecatedTunnelAck => 0,
894+
}
895+
}
896+
867897
/// Send ack message for deprecated tunnel versions.
868898
///
869899
/// We have to parse as specifically a v2 message since we need the exact request & message ID

engine/sdks/typescript/runner/src/tunnel.ts

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ import {
1515
stringifyToClientTunnelMessageKind,
1616
stringifyToServerTunnelMessageKind,
1717
} from "./stringify";
18-
import { arraysEqual, idToStr, stringifyError, unreachable } from "./utils";
18+
import { arraysEqual, idToStr, MAX_BODY_SIZE, stringifyError, unreachable } from "./utils";
1919
import {
2020
HIBERNATABLE_SYMBOL,
2121
WebSocketTunnelAdapter,
@@ -855,6 +855,10 @@ export class Tunnel {
855855
// Read the body first to get the actual content
856856
const body = response.body ? await response.arrayBuffer() : null;
857857

858+
if (body && body.byteLength > MAX_BODY_SIZE) {
859+
throw new Error("Response body too large");
860+
}
861+
858862
// Convert headers to map and add Content-Length if not present
859863
const headers = new Map<string, string>();
860864
response.headers.forEach((value, key) => {
@@ -1079,7 +1083,7 @@ export class Tunnel {
10791083
});
10801084

10811085
if (clientMessageIndex < 0 || clientMessageIndex > 65535)
1082-
throw new Error("invalid websocket ack index");
1086+
throw new Error("Invalid websocket ack index");
10831087

10841088
// Get the actor to find the gatewayId
10851089
//
@@ -1157,7 +1161,7 @@ function buildRequestForWebSocket(
11571161
};
11581162

11591163
if (!path.startsWith("/")) {
1160-
throw new Error("path must start with leading slash");
1164+
throw new Error("Path must start with leading slash");
11611165
}
11621166

11631167
const request = new Request(`http://actor${path}`, {

engine/sdks/typescript/runner/src/utils.ts

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
11
import { logger } from "./log";
22

3+
// 20MiB. Keep in sync with MAX_BODY_SIZE from engine/packages/guard-core/src/proxy_service.rs
4+
export const MAX_BODY_SIZE = 20 * 1024 * 1024;
5+
36
export function unreachable(x: never): never {
47
throw `Unreachable: ${x}`;
58
}

engine/sdks/typescript/runner/src/websocket-tunnel-adapter.ts

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import type { Logger } from "pino";
22
import { VirtualWebSocket, type UniversalWebSocket, type RivetMessageEvent } from "@rivetkit/virtual-websocket";
33
import type { Tunnel } from "./tunnel";
4-
import { wrappingAddU16, wrappingLteU16, wrappingSubU16 } from "./utils";
4+
import { MAX_BODY_SIZE, wrappingAddU16, wrappingLteU16, wrappingSubU16 } from "./utils";
55

66
export const HIBERNATABLE_SYMBOL = Symbol("hibernatable");
77

@@ -70,11 +70,20 @@ export class WebSocketTunnelAdapter {
7070
let messageData: string | ArrayBuffer;
7171

7272
if (typeof data === "string") {
73+
let encoder = new TextEncoder();
74+
if (encoder.encode(data).byteLength > MAX_BODY_SIZE) {
75+
throw new Error("WebSocket message too large");
76+
}
77+
7378
messageData = data;
7479
} else if (data instanceof ArrayBuffer) {
80+
if (data.byteLength > MAX_BODY_SIZE) throw new Error("WebSocket message too large");
81+
7582
isBinary = true;
7683
messageData = data;
7784
} else if (ArrayBuffer.isView(data)) {
85+
if (data.byteLength > MAX_BODY_SIZE) throw new Error("WebSocket message too large");
86+
7887
isBinary = true;
7988
const view = data;
8089
const buffer = view.buffer instanceof SharedArrayBuffer

0 commit comments

Comments
 (0)