diff --git a/src/main.rs b/src/main.rs index daed748..f9447a6 100644 --- a/src/main.rs +++ b/src/main.rs @@ -176,6 +176,97 @@ fn is_header_allowed(header: &str) -> bool { ) } +struct RangeRequest { + start: u64, + end: u64, + total_size: u64, +} + +fn parse_range(range_str: &str, total_size: u64) -> Option { + let range_parts: Vec<&str> = range_str.split('-').collect(); + + if range_parts.len() != 2 { + return None; + } + + let start_result = range_parts[0].parse::(); + let start = start_result.unwrap_or(0); + + // Parse end position - if empty, use total_size-1 (open-ended range) + let end = if range_parts[1].is_empty() { + total_size.saturating_sub(1) // Avoid underflow + } else { + let end_result = range_parts[1].parse::(); + let parsed_end = end_result.unwrap_or(total_size.saturating_sub(1)); + parsed_end.min(total_size.saturating_sub(1)) + }; + + Some(RangeRequest { + start, + end, + total_size, + }) +} + +fn handle_range_response_correction( + response: &mut HttpResponseBuilder, + range_str: Option<&String>, + resp: &reqwest::Response, +) -> Option<()> { + // Check if this is a range request (either in headers or query string) + let has_range_request = resp.headers().contains_key("range") || range_str.is_some(); + + // Only apply correction if we have a range request and response is 200 (should be 206) + if !has_range_request + || !resp.status().is_success() + || resp.status() == reqwest::StatusCode::PARTIAL_CONTENT + { + return None; + } + + // Get content length from response headers + let total_size = resp + .headers() + .get("content-length")? + .to_str() + .ok()? + .parse::() + .ok()?; + + if total_size == 0 { + return None; + } + + // Parse the range request into start/end positions + let range_request = parse_range(range_str?, total_size)?; + + // Apply proper HTTP range headers + apply_range_headers(response, &range_request); + + Some(()) +} + +fn apply_range_headers(response: &mut HttpResponseBuilder, range_request: &RangeRequest) { + let content_range_value = format!( + "bytes {}-{}/{}", + range_request.start, range_request.end, range_request.total_size + ); + + // Use total_size when range was truncated by YouTube (start > end) + // Otherwise calculate normally + let actual_length = if range_request.start > range_request.end { + range_request.total_size + } else { + range_request.end - range_request.start + 1 + }; + + // Set proper partial content response + response.status(actix_web::http::StatusCode::PARTIAL_CONTENT); + // Set required headers for partial content responses + response.insert_header(("Content-Range", content_range_value)); + response.insert_header(("Content-Length", actual_length.to_string())); +} + async fn index(req: HttpRequest) -> Result> { if req.method() == actix_web::http::Method::OPTIONS { let mut response = HttpResponse::Ok(); @@ -287,7 +378,8 @@ async fn index(req: HttpRequest) -> Result> { if let Some(expiry) = query.get("expire") { let expiry = expiry.parse::()?; let now = SystemTime::now(); - let now = now.duration_since(UNIX_EPOCH) + let now = now + .duration_since(UNIX_EPOCH) .expect("Time went backwards") .as_secs() as i64; if now > expiry { @@ -389,6 +481,10 @@ async fn index(req: HttpRequest) -> Result> { } } + // Fix range request handling - convert 200 to 206 if we have a range request + // and ensure Content-Range header is present + handle_range_response_correction(&mut response, range.as_ref(), &resp); + if rewrite { if let Some(content_type) = resp.headers().get("content-type") { #[cfg(feature = "avif")] @@ -523,6 +619,10 @@ async fn index(req: HttpRequest) -> Result> { if let Some(clen) = clen { if range != &format!("0-{}", clen - 1) { response.status(StatusCode::PARTIAL_CONTENT); + + // Add proper Content-Range header for UMP streams + let content_range_value = format!("bytes {}/{}", range, clen); + response.insert_header(("Content-Range", content_range_value.clone())); } } }