From d4e774b91b0cf8dd1cb775c8b6079a6a45914437 Mon Sep 17 00:00:00 2001 From: nishika26 Date: Tue, 9 Dec 2025 13:19:07 +0530 Subject: [PATCH 1/7] url validation --- backend/app/api/routes/collections.py | 8 +- backend/app/api/routes/documents.py | 15 +- backend/app/api/routes/llm.py | 5 +- backend/app/core/config.py | 3 +- backend/app/core/util.py | 15 - backend/app/tests/utils/test_callback_ssrf.py | 383 ++++++++++++++++++ backend/app/utils.py | 117 +++++- 7 files changed, 522 insertions(+), 24 deletions(-) create mode 100644 backend/app/tests/utils/test_callback_ssrf.py diff --git a/backend/app/api/routes/collections.py b/backend/app/api/routes/collections.py index 78a40ae7e..ed66cd04d 100644 --- a/backend/app/api/routes/collections.py +++ b/backend/app/api/routes/collections.py @@ -26,7 +26,7 @@ DeletionRequest, CollectionPublic, ) -from app.utils import APIResponse, load_description +from app.utils import APIResponse, load_description, validate_callback_url from app.services.collections import ( create_collection as create_service, delete_collection as delete_service, @@ -81,6 +81,9 @@ def create_collection( current_user: CurrentUserOrgProject, request: CreationRequest, ): + if request.callback_url: + validate_callback_url(str(request.callback_url)) + collection_job_crud = CollectionJobCrud(session, current_user.project_id) collection_job = collection_job_crud.create( CollectionJobCreate( @@ -130,6 +133,9 @@ def delete_collection( collection_id: UUID = FastPath(description="Collection to delete"), request: CallbackRequest | None = Body(default=None), ): + if request and request.callback_url: + validate_callback_url(str(request.callback_url)) + _ = CollectionCrud(session, current_user.project_id).read_one(collection_id) deletion_request = DeletionRequest( diff --git a/backend/app/api/routes/documents.py b/backend/app/api/routes/documents.py index a1c2d89dd..baac1365f 100644 --- a/backend/app/api/routes/documents.py +++ b/backend/app/api/routes/documents.py @@ -10,6 +10,7 @@ Query, UploadFile, ) +from pydantic import HttpUrl from fastapi import Path as FastPath from app.api.deps import CurrentUserOrgProject, SessionDep @@ -32,7 +33,12 @@ build_document_schema, build_document_schemas, ) -from app.utils import APIResponse, get_openai_client, load_description +from app.utils import ( + APIResponse, + get_openai_client, + load_description, + validate_callback_url, +) logger = logging.getLogger(__name__) @@ -108,9 +114,12 @@ async def upload_doc( | None = Form( None, description="Name of the transformer to apply when converting." ), - callback_url: str + callback_url: HttpUrl | None = Form(None, description="URL to call to report endpoint status"), ): + if callback_url: + validate_callback_url(str(callback_url)) + source_format, actual_transformer = pre_transform_validation( src_filename=src.filename, target_format=target_format, @@ -137,7 +146,7 @@ async def upload_doc( target_format=target_format, actual_transformer=actual_transformer, source_document_id=source_document.id, - callback_url=callback_url, + callback_url=str(callback_url), ) document_schema = DocumentPublic.model_validate( diff --git a/backend/app/api/routes/llm.py b/backend/app/api/routes/llm.py index 26c9ee423..42a09c8b0 100644 --- a/backend/app/api/routes/llm.py +++ b/backend/app/api/routes/llm.py @@ -5,7 +5,7 @@ from app.api.deps import AuthContextDep, SessionDep from app.models import LLMCallRequest, LLMCallResponse, Message from app.services.llm.jobs import start_job -from app.utils import APIResponse +from app.utils import APIResponse, validate_callback_url logger = logging.getLogger(__name__) @@ -44,6 +44,9 @@ def llm_call( project_id = _current_user.project.id organization_id = _current_user.organization.id + if request.callback_url: + validate_callback_url(str(request.callback_url)) + start_job( db=_session, request=request, diff --git a/backend/app/core/config.py b/backend/app/core/config.py index 515874af5..f46d6f40f 100644 --- a/backend/app/core/config.py +++ b/backend/app/core/config.py @@ -118,9 +118,10 @@ def AWS_S3_BUCKET(self) -> str: CELERY_ENABLE_UTC: bool = True CELERY_TIMEZONE: str = "UTC" - # callback timeouts + # callback timeouts and limits CALLBACK_CONNECT_TIMEOUT: int = 3 CALLBACK_READ_TIMEOUT: int = 10 + CALLBACK_MAX_RESPONSE_SIZE: int = 1048576 # (1*1024*1024) @computed_field # type: ignore[prop-decorator] @property diff --git a/backend/app/core/util.py b/backend/app/core/util.py index 41f2eb376..56a05dbc1 100644 --- a/backend/app/core/util.py +++ b/backend/app/core/util.py @@ -2,8 +2,6 @@ from datetime import datetime, timezone from fastapi import HTTPException -from requests import Session, RequestException -from pydantic import BaseModel, HttpUrl from openai import OpenAI @@ -24,19 +22,6 @@ def raise_from_unknown(error: Exception, status_code=500): raise HTTPException(status_code=status_code, detail=str(error)) -def post_callback(url: HttpUrl, payload: BaseModel): - errno = 0 - with Session() as session: - response = session.post(str(url), json=payload.model_dump()) - try: - response.raise_for_status() - except RequestException as err: - logger.warning(f"Callback failure: {err}") - errno += 1 - - return not errno - - def configure_openai(credentials: dict) -> tuple[OpenAI, bool]: """ Configure OpenAI client with the provided credentials. diff --git a/backend/app/tests/utils/test_callback_ssrf.py b/backend/app/tests/utils/test_callback_ssrf.py new file mode 100644 index 000000000..e39152436 --- /dev/null +++ b/backend/app/tests/utils/test_callback_ssrf.py @@ -0,0 +1,383 @@ +"""Tests for callback SSRF protection in utils.py""" + +import pytest +from unittest.mock import patch, MagicMock +import socket +import requests + +from app.utils import _is_private_ip, validate_callback_url, send_callback + + +class TestIsPrivateIP: + """Test suite for _is_private_ip function.""" + + def test_private_ipv4_addresses(self): + """Test that private IPv4 addresses are correctly identified.""" + private_ips = [ + "10.0.0.1", + "10.255.255.255", + "172.16.0.1", + "172.31.255.255", + "192.168.0.1", + "192.168.255.255", + ] + for ip in private_ips: + is_blocked, reason = _is_private_ip(ip) + assert is_blocked is True, f"{ip} should be identified as private" + assert reason == "private", f"{ip} should have reason 'private'" + + def test_localhost_addresses(self): + """Test that localhost/loopback addresses are blocked.""" + localhost_ips = [ + "127.0.0.1", + "127.0.0.2", + "127.255.255.255", + "::1", + ] + for ip in localhost_ips: + is_blocked, reason = _is_private_ip(ip) + assert is_blocked is True, f"{ip} should be identified as loopback" + assert ( + reason == "loopback/localhost" + ), f"{ip} should have reason 'loopback/localhost'" + + def test_link_local_addresses(self): + """Test that link-local addresses are blocked.""" + link_local_ips = [ + "169.254.0.1", + "169.254.169.254", + "169.254.255.255", + ] + for ip in link_local_ips: + is_blocked, reason = _is_private_ip(ip) + assert is_blocked is True, f"{ip} should be identified as link-local" + assert reason == "link-local", f"{ip} should have reason 'link-local'" + + def test_multicast_addresses(self): + """Test that multicast addresses are blocked.""" + multicast_ips = [ + "224.0.0.1", + "239.255.255.255", + ] + for ip in multicast_ips: + is_blocked, reason = _is_private_ip(ip) + assert is_blocked is True, f"{ip} should be identified as multicast" + assert reason == "multicast", f"{ip} should have reason 'multicast'" + + def test_public_ipv4_addresses(self): + """Test that public IPv4 addresses are not blocked.""" + public_ips = [ + "8.8.8.8", + "1.1.1.1", + "93.184.216.34", + "151.101.1.140", + ] + for ip in public_ips: + is_blocked, reason = _is_private_ip(ip) + assert is_blocked is False, f"{ip} should be identified as public" + assert reason == "", f"{ip} should have empty reason" + + def test_public_ipv6_addresses(self): + """Test that public IPv6 addresses are not blocked.""" + public_ipv6 = [ + "2001:4860:4860::8888", + "2606:4700:4700::1111", + ] + for ip in public_ipv6: + is_blocked, reason = _is_private_ip(ip) + assert is_blocked is False, f"{ip} should be identified as public" + assert reason == "", f"{ip} should have empty reason" + + def test_invalid_ip_addresses(self): + """Test that invalid IP addresses return False.""" + invalid_ips = [ + "not_an_ip", + "999.999.999.999", + "example.com", + ] + for ip in invalid_ips: + is_blocked, reason = _is_private_ip(ip) + assert is_blocked is False, f"{ip} should return False" + assert reason == "", f"{ip} should have empty reason" + + +class TestValidateCallbackURL: + """Test suite for validate_callback_url function.""" + + def test_reject_non_https_schemes(self): + """Test that non-HTTPS URL schemes are rejected.""" + non_https_urls = [ + "http://example.com/callback", + "ftp://example.com/callback", + "file:///etc/passwd", + ] + for url in non_https_urls: + with pytest.raises(ValueError, match="Only HTTPS URLs are allowed"): + validate_callback_url(url) + + @patch("socket.getaddrinfo") + def test_reject_localhost_by_name(self, mock_getaddrinfo): + """Test that localhost is rejected.""" + mock_getaddrinfo.return_value = [ + (socket.AF_INET, socket.SOCK_STREAM, 6, "", ("127.0.0.1", 443)) + ] + + with pytest.raises(ValueError, match="loopback/localhost IP address"): + validate_callback_url("https://localhost/callback") + + @patch("socket.getaddrinfo") + def test_reject_private_ip_addresses(self, mock_getaddrinfo): + """Test that private IPs in all RFC 1918 ranges are rejected.""" + private_ips = [ + ("10.0.0.1", "https://internal.company.com/callback"), + ("192.168.1.1", "https://router.local/callback"), + ("172.16.0.1", "https://internal-api.local/callback"), + ] + + for ip, url in private_ips: + mock_getaddrinfo.return_value = [ + (socket.AF_INET, socket.SOCK_STREAM, 6, "", (ip, 443)) + ] + + with pytest.raises(ValueError, match="private IP address"): + validate_callback_url(url) + + @patch("socket.getaddrinfo") + def test_reject_link_local_addresses(self, mock_getaddrinfo): + """Test that link-local addresses are rejected (including cloud metadata endpoints).""" + link_local_ips = [ + ( + "169.254.169.254", + "https://metadata.aws/callback", + ), # AWS metadata endpoint + ("169.254.0.1", "https://link-local.example/callback"), + ] + + for ip, url in link_local_ips: + mock_getaddrinfo.return_value = [ + (socket.AF_INET, socket.SOCK_STREAM, 6, "", (ip, 443)) + ] + + with pytest.raises(ValueError, match="link-local IP address"): + validate_callback_url(url) + + @patch("socket.getaddrinfo") + def test_accept_public_ip_addresses(self, mock_getaddrinfo): + """Test that valid HTTPS URLs with public IP addresses are accepted.""" + public_ips = [ + ("8.8.8.8", "https://api.example.com/callback"), + ("151.101.1.140", "https://webhook.site/unique-id"), + ] + + for ip, url in public_ips: + mock_getaddrinfo.return_value = [ + (socket.AF_INET, socket.SOCK_STREAM, 6, "", (ip, 443)) + ] + + validate_callback_url(url) + + def test_reject_url_without_hostname(self): + """Test that URLs without hostname are rejected.""" + with pytest.raises(ValueError, match="URL must have a valid hostname"): + validate_callback_url("https:///callback") + + def test_reject_invalid_url_format(self): + """Test that invalid URL formats are rejected.""" + with pytest.raises(ValueError, match="Only HTTPS URLs are allowed"): + validate_callback_url("not a url at all") + + @patch("socket.getaddrinfo") + def test_check_all_resolved_ips(self, mock_getaddrinfo): + """Test that all resolved IPs are checked (DNS round-robin).""" + mock_getaddrinfo.return_value = [ + (socket.AF_INET, socket.SOCK_STREAM, 6, "", ("8.8.8.8", 443)), + (socket.AF_INET, socket.SOCK_STREAM, 6, "", ("192.168.1.1", 443)), + ] + + with pytest.raises(ValueError, match="private IP address"): + validate_callback_url("https://malicious-dns.example/callback") + + @patch("socket.getaddrinfo") + def test_ipv6_public_address_accepted(self, mock_getaddrinfo): + """Test that public IPv6 addresses are accepted.""" + mock_getaddrinfo.return_value = [ + (socket.AF_INET6, socket.SOCK_STREAM, 6, "", ("2001:4860:4860::8888", 443)) + ] + + validate_callback_url("https://ipv6.example.com/callback") + + @patch("socket.getaddrinfo") + def test_ipv6_localhost_rejected(self, mock_getaddrinfo): + """Test that IPv6 localhost is rejected.""" + mock_getaddrinfo.return_value = [ + (socket.AF_INET6, socket.SOCK_STREAM, 6, "", ("::1", 443)) + ] + + with pytest.raises(ValueError, match="loopback/localhost IP address"): + validate_callback_url("https://localhost6/callback") + + +class TestSendCallback: + """Test suite for send_callback function.""" + + @patch("app.utils.validate_callback_url") + @patch("requests.Session") + def test_successful_callback(self, mock_session_class, mock_validate): + """Test successful callback execution.""" + mock_session = MagicMock() + mock_response = MagicMock() + mock_response.raise_for_status.return_value = None + mock_response.iter_content.return_value = [b"test"] + mock_session.post.return_value = mock_response + mock_session_class.return_value.__enter__.return_value = mock_session + + result = send_callback( + "https://api.example.com/callback", {"status": "success"} + ) + + assert result is True + mock_session.post.assert_called_once() + assert mock_session.post.call_args[1]["allow_redirects"] is False + + @patch("app.utils.validate_callback_url") + @patch("requests.Session") + def test_callback_network_error(self, mock_session_class, mock_validate): + """Test that callback returns False on network errors.""" + mock_session = MagicMock() + mock_session.post.side_effect = requests.RequestException("Connection refused") + mock_session_class.return_value.__enter__.return_value = mock_session + + result = send_callback("https://api.example.com/callback", {"data": "test"}) + + assert result is False + + @patch("app.utils.validate_callback_url") + @patch("requests.Session") + def test_callback_http_error(self, mock_session_class, mock_validate): + """Test that callback returns False on HTTP errors.""" + mock_session = MagicMock() + mock_response = MagicMock() + mock_response.raise_for_status.side_effect = requests.HTTPError("404 Not Found") + mock_session.post.return_value = mock_response + mock_session_class.return_value.__enter__.return_value = mock_session + + result = send_callback("https://api.example.com/callback", {"data": "test"}) + + assert result is False + + @patch("app.utils.validate_callback_url") + @patch("requests.Session") + def test_callback_disables_redirects(self, mock_session_class, mock_validate): + """Test that redirects are disabled to prevent redirect-based SSRF.""" + mock_session = MagicMock() + mock_response = MagicMock() + mock_response.raise_for_status.return_value = None + mock_response.iter_content.return_value = [b"test"] + mock_session.post.return_value = mock_response + mock_session_class.return_value.__enter__.return_value = mock_session + + send_callback("https://api.example.com/callback", {"data": "test"}) + + call_kwargs = mock_session.post.call_args[1] + assert call_kwargs["allow_redirects"] is False + + @patch("app.utils.validate_callback_url") + @patch("requests.Session") + def test_callback_uses_timeout(self, mock_session_class, mock_validate): + """Test that callback uses configured timeouts.""" + mock_session = MagicMock() + mock_response = MagicMock() + mock_response.raise_for_status.return_value = None + mock_response.iter_content.return_value = [b"test"] + mock_session.post.return_value = mock_response + mock_session_class.return_value.__enter__.return_value = mock_session + + send_callback("https://api.example.com/callback", {"data": "test"}) + + call_kwargs = mock_session.post.call_args[1] + assert "timeout" in call_kwargs + assert isinstance(call_kwargs["timeout"], tuple) + assert len(call_kwargs["timeout"]) == 2 + + @patch("app.utils.validate_callback_url") + @patch("requests.Session") + def test_callback_sends_json_data(self, mock_session_class, mock_validate): + """Test that callback sends data as JSON.""" + mock_session = MagicMock() + mock_response = MagicMock() + mock_response.raise_for_status.return_value = None + mock_response.iter_content.return_value = [b"test"] + mock_session.post.return_value = mock_response + mock_session_class.return_value.__enter__.return_value = mock_session + + test_data = {"status": "completed", "result": 42} + send_callback("https://api.example.com/callback", test_data) + + call_kwargs = mock_session.post.call_args[1] + assert "json" in call_kwargs + assert call_kwargs["json"] == test_data + + @patch("app.utils.validate_callback_url") + @patch("requests.Session") + def test_callback_response_size_limit_exceeded( + self, mock_session_class, mock_validate + ): + """Test that callback rejects responses exceeding size limit.""" + mock_session = MagicMock() + mock_response = MagicMock() + mock_response.raise_for_status.return_value = None + + chunk_size = 8192 + num_chunks = 130 + mock_response.iter_content.return_value = [ + b"x" * chunk_size for _ in range(num_chunks) + ] + + mock_session.post.return_value = mock_response + mock_session_class.return_value.__enter__.return_value = mock_session + + result = send_callback("https://api.example.com/callback", {"data": "test"}) + + assert result is False + mock_response.close.assert_called_once() + + @patch("app.utils.validate_callback_url") + @patch("requests.Session") + def test_callback_response_size_within_limit( + self, mock_session_class, mock_validate + ): + """Test that callback accepts responses within size limit.""" + mock_session = MagicMock() + mock_response = MagicMock() + mock_response.raise_for_status.return_value = None + + chunk_size = 8192 + num_chunks = 10 + mock_response.iter_content.return_value = [ + b"x" * chunk_size for _ in range(num_chunks) + ] + + mock_session.post.return_value = mock_response + mock_session_class.return_value.__enter__.return_value = mock_session + + result = send_callback("https://api.example.com/callback", {"data": "test"}) + + assert result is True + mock_response.close.assert_not_called() + + @patch("app.utils.validate_callback_url") + @patch("requests.Session") + def test_callback_uses_streaming(self, mock_session_class, mock_validate): + """Test that callback uses streaming to prevent loading entire response into memory.""" + mock_session = MagicMock() + mock_response = MagicMock() + mock_response.raise_for_status.return_value = None + mock_response.iter_content.return_value = [b"test"] + + mock_session.post.return_value = mock_response + mock_session_class.return_value.__enter__.return_value = mock_session + + send_callback("https://api.example.com/callback", {"data": "test"}) + + call_kwargs = mock_session.post.call_args[1] + assert call_kwargs["stream"] is True diff --git a/backend/app/utils.py b/backend/app/utils.py index 094c36829..0dbba5f7c 100644 --- a/backend/app/utils.py +++ b/backend/app/utils.py @@ -1,10 +1,13 @@ import functools as ft +import ipaddress import logging from dataclasses import dataclass from datetime import datetime, timedelta, timezone from pathlib import Path import requests +import socket from typing import Any, Dict, Generic, Optional, TypeVar +from urllib.parse import urlparse import jwt import emails @@ -262,12 +265,105 @@ def handle_openai_error(e: openai.OpenAIError) -> str: return str(e) +def _is_private_ip(ip: str) -> bool: + """Check if an IP address is private, localhost, or reserved.""" + try: + ip_obj = ipaddress.ip_address(ip) + + checks = [ + (ip_obj.is_loopback, "loopback/localhost"), + (ip_obj.is_link_local, "link-local"), + (ip_obj.is_multicast, "multicast"), + (ip_obj.is_private, "private"), + (ip_obj.is_reserved, "reserved"), + ] + + for is_blocked, reason in checks: + if is_blocked: + return (True, reason) + + return (False, "") + + except ValueError: + return (False, "") + + +def validate_callback_url(url: str) -> None: + """ + Validate callback URL to prevent SSRF attacks. + + Blocks: + - Non-HTTPS URLs + - Private IP addresses (10.0.0.0/8, 172.16.0.0/12, 192.168.0.0/16) + - Localhost/loopback addresses (127.0.0.0/8, ::1) + - Link-local addresses (169.254.0.0/16) + - Cloud metadata endpoints (169.254.169.254) + - Reserved IP ranges + + Args: + url: The callback URL to validate + + Raises: + ValueError: If URL is not allowed + """ + try: + parsed = urlparse(url) + + if parsed.scheme != "https": + raise ValueError( + f"Only HTTPS URLs are allowed for callbacks. Got: {parsed.scheme}" + ) + + if not parsed.hostname: + raise ValueError("URL must have a valid hostname") + + addr_info = socket.getaddrinfo( + parsed.hostname, + parsed.port or 443, + socket.AF_UNSPEC, + socket.SOCK_STREAM, + ) + + for info in addr_info: + ip_address = info[4][0] + is_blocked, reason = _is_private_ip(ip_address) + if is_blocked: + raise ValueError( + f"Callback URL resolves to {reason} IP address: {ip_address}. " + f"This IP type is not allowed for callbacks." + ) + + except ValueError: + raise + except Exception as e: + raise ValueError(f"Error validating callback URL: {str(e)}") from e + + def send_callback(callback_url: str, data: dict): - """Send results to the callback URL (synchronously).""" + """ + Send results to the callback URL (synchronously) with SSRF protection. + + Security features: + - HTTPS-only enforcement + - Private IP blocking (RFC 1918) + - Localhost/loopback blocking + - Cloud metadata endpoint blocking + - DNS rebinding protection + - Redirect following disabled + - Strict timeouts + - Response size limits (prevents DoS via large responses) + + Args: + callback_url: The HTTPS URL to send the callback to + data: The JSON data to send in the POST request + + Returns: + bool: True if callback succeeded, False otherwise + """ try: + validate_callback_url(str(callback_url)) + with requests.Session() as session: - # uncomment this to run locally without SSL - # session.verify = False response = session.post( callback_url, json=data, @@ -275,10 +371,25 @@ def send_callback(callback_url: str, data: dict): settings.CALLBACK_CONNECT_TIMEOUT, settings.CALLBACK_READ_TIMEOUT, ), + allow_redirects=False, + stream=True, ) + response.raise_for_status() + + total_size = 0 + for chunk in response.iter_content(chunk_size=8192): + total_size += len(chunk) + if total_size > settings.CALLBACK_MAX_RESPONSE_SIZE: + response.close() + logger.error( + f"[send_callback] Response size exceeds {settings.CALLBACK_MAX_RESPONSE_SIZE} bytes while reading" + ) + return False + logger.info(f"[send_callback] Callback sent successfully to {callback_url}") return True + except requests.RequestException as e: logger.error(f"[send_callback] Callback failed: {str(e)}", exc_info=True) return False From c95f14821687b27eb1113d3b1fbd90d6836a46da Mon Sep 17 00:00:00 2001 From: nishika26 Date: Tue, 9 Dec 2025 13:52:37 +0530 Subject: [PATCH 2/7] coderabbit pr --- .env.example | 3 ++- backend/app/api/routes/documents.py | 6 +++--- backend/app/utils.py | 2 +- 3 files changed, 6 insertions(+), 5 deletions(-) diff --git a/.env.example b/.env.example index 7c2cd9f02..14f1385a5 100644 --- a/.env.example +++ b/.env.example @@ -80,9 +80,10 @@ CELERY_ENABLE_UTC=true CELERY_TIMEZONE=Asia/Kolkata -# Callback Timeouts (in seconds) +# Callback Timeouts and size limit(in seconds and MB respectively) CALLBACK_CONNECT_TIMEOUT = 3 CALLBACK_READ_TIMEOUT = 10 +CALLBACK_MAX_RESPONSE_SIZE = 1048576 #(1*1024*1024) # require as a env if you want to use doc transformation OPENAI_API_KEY="" diff --git a/backend/app/api/routes/documents.py b/backend/app/api/routes/documents.py index baac1365f..13aa75fc8 100644 --- a/backend/app/api/routes/documents.py +++ b/backend/app/api/routes/documents.py @@ -114,11 +114,11 @@ async def upload_doc( | None = Form( None, description="Name of the transformer to apply when converting." ), - callback_url: HttpUrl + callback_url: str | None = Form(None, description="URL to call to report endpoint status"), ): if callback_url: - validate_callback_url(str(callback_url)) + validate_callback_url(callback_url) source_format, actual_transformer = pre_transform_validation( src_filename=src.filename, @@ -146,7 +146,7 @@ async def upload_doc( target_format=target_format, actual_transformer=actual_transformer, source_document_id=source_document.id, - callback_url=str(callback_url), + callback_url=callback_url, ) document_schema = DocumentPublic.model_validate( diff --git a/backend/app/utils.py b/backend/app/utils.py index 0dbba5f7c..48398e675 100644 --- a/backend/app/utils.py +++ b/backend/app/utils.py @@ -265,7 +265,7 @@ def handle_openai_error(e: openai.OpenAIError) -> str: return str(e) -def _is_private_ip(ip: str) -> bool: +def _is_private_ip(ip: str) -> tuple[bool, str]: """Check if an IP address is private, localhost, or reserved.""" try: ip_obj = ipaddress.ip_address(ip) From 83a8e99922ce8a4a019ab5c4b4e1dfbdd49a06c6 Mon Sep 17 00:00:00 2001 From: nishika26 Date: Tue, 9 Dec 2025 14:08:25 +0530 Subject: [PATCH 3/7] coderabbit pr review --- backend/app/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/backend/app/utils.py b/backend/app/utils.py index 48398e675..f80c072b2 100644 --- a/backend/app/utils.py +++ b/backend/app/utils.py @@ -339,7 +339,7 @@ def validate_callback_url(url: str) -> None: raise ValueError(f"Error validating callback URL: {str(e)}") from e -def send_callback(callback_url: str, data: dict): +def send_callback(callback_url: str, data: dict[str, Any]) -> bool: """ Send results to the callback URL (synchronously) with SSRF protection. From 1ba51504b1b4e3adebea52aa7c60bcb2180e1d4c Mon Sep 17 00:00:00 2001 From: nishika26 Date: Tue, 9 Dec 2025 14:35:49 +0530 Subject: [PATCH 4/7] coderabbit pr review --- backend/app/utils.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/backend/app/utils.py b/backend/app/utils.py index f80c072b2..4a38aa107 100644 --- a/backend/app/utils.py +++ b/backend/app/utils.py @@ -362,7 +362,11 @@ def send_callback(callback_url: str, data: dict[str, Any]) -> bool: """ try: validate_callback_url(str(callback_url)) + except ValueError as ve: + logger.error(f"[send_callback] Invalid callback URL: {ve}", exc_info=True) + return False + try: with requests.Session() as session: response = session.post( callback_url, @@ -387,7 +391,7 @@ def send_callback(callback_url: str, data: dict[str, Any]) -> bool: ) return False - logger.info(f"[send_callback] Callback sent successfully to {callback_url}") + logger.info(f"[send_callback] Callback sent successfully") return True except requests.RequestException as e: From 29cbaea3ff613d27b07f740895ad6ade1d05823f Mon Sep 17 00:00:00 2001 From: nishika26 Date: Tue, 9 Dec 2025 14:51:36 +0530 Subject: [PATCH 5/7] extra protection by stting trust_env = false --- backend/app/utils.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/backend/app/utils.py b/backend/app/utils.py index 4a38aa107..7d10c29fe 100644 --- a/backend/app/utils.py +++ b/backend/app/utils.py @@ -368,6 +368,8 @@ def send_callback(callback_url: str, data: dict[str, Any]) -> bool: try: with requests.Session() as session: + session.trust_env = False # Ignores environment proxies and other implicit settings for SSRF safety + response = session.post( callback_url, json=data, @@ -391,7 +393,7 @@ def send_callback(callback_url: str, data: dict[str, Any]) -> bool: ) return False - logger.info(f"[send_callback] Callback sent successfully") + logger.info("[send_callback] Callback sent successfully") return True except requests.RequestException as e: From 1b5c8d97602fb56c1b86fbac62fd4cecd9adfaab Mon Sep 17 00:00:00 2001 From: nishika26 Date: Thu, 11 Dec 2025 17:46:09 +0530 Subject: [PATCH 6/7] formatting issue --- backend/app/api/routes/llm.py | 1 - 1 file changed, 1 deletion(-) diff --git a/backend/app/api/routes/llm.py b/backend/app/api/routes/llm.py index f27389894..e244b2258 100644 --- a/backend/app/api/routes/llm.py +++ b/backend/app/api/routes/llm.py @@ -8,7 +8,6 @@ from app.utils import APIResponse, validate_callback_url, load_description - logger = logging.getLogger(__name__) router = APIRouter(tags=["LLM"]) From af5a3c8ce12fce742e77be1e6d3d5b86f3b2e25f Mon Sep 17 00:00:00 2001 From: nishika26 Date: Fri, 12 Dec 2025 14:21:37 +0530 Subject: [PATCH 7/7] removing response size limit --- .env.example | 1 - backend/app/core/config.py | 1 - backend/app/tests/utils/test_callback_ssrf.py | 65 ------------------- backend/app/utils.py | 12 ---- 4 files changed, 79 deletions(-) diff --git a/.env.example b/.env.example index 14f1385a5..ed36af591 100644 --- a/.env.example +++ b/.env.example @@ -83,7 +83,6 @@ CELERY_TIMEZONE=Asia/Kolkata # Callback Timeouts and size limit(in seconds and MB respectively) CALLBACK_CONNECT_TIMEOUT = 3 CALLBACK_READ_TIMEOUT = 10 -CALLBACK_MAX_RESPONSE_SIZE = 1048576 #(1*1024*1024) # require as a env if you want to use doc transformation OPENAI_API_KEY="" diff --git a/backend/app/core/config.py b/backend/app/core/config.py index f46d6f40f..d318ce98f 100644 --- a/backend/app/core/config.py +++ b/backend/app/core/config.py @@ -121,7 +121,6 @@ def AWS_S3_BUCKET(self) -> str: # callback timeouts and limits CALLBACK_CONNECT_TIMEOUT: int = 3 CALLBACK_READ_TIMEOUT: int = 10 - CALLBACK_MAX_RESPONSE_SIZE: int = 1048576 # (1*1024*1024) @computed_field # type: ignore[prop-decorator] @property diff --git a/backend/app/tests/utils/test_callback_ssrf.py b/backend/app/tests/utils/test_callback_ssrf.py index e39152436..3f8219ae2 100644 --- a/backend/app/tests/utils/test_callback_ssrf.py +++ b/backend/app/tests/utils/test_callback_ssrf.py @@ -316,68 +316,3 @@ def test_callback_sends_json_data(self, mock_session_class, mock_validate): call_kwargs = mock_session.post.call_args[1] assert "json" in call_kwargs assert call_kwargs["json"] == test_data - - @patch("app.utils.validate_callback_url") - @patch("requests.Session") - def test_callback_response_size_limit_exceeded( - self, mock_session_class, mock_validate - ): - """Test that callback rejects responses exceeding size limit.""" - mock_session = MagicMock() - mock_response = MagicMock() - mock_response.raise_for_status.return_value = None - - chunk_size = 8192 - num_chunks = 130 - mock_response.iter_content.return_value = [ - b"x" * chunk_size for _ in range(num_chunks) - ] - - mock_session.post.return_value = mock_response - mock_session_class.return_value.__enter__.return_value = mock_session - - result = send_callback("https://api.example.com/callback", {"data": "test"}) - - assert result is False - mock_response.close.assert_called_once() - - @patch("app.utils.validate_callback_url") - @patch("requests.Session") - def test_callback_response_size_within_limit( - self, mock_session_class, mock_validate - ): - """Test that callback accepts responses within size limit.""" - mock_session = MagicMock() - mock_response = MagicMock() - mock_response.raise_for_status.return_value = None - - chunk_size = 8192 - num_chunks = 10 - mock_response.iter_content.return_value = [ - b"x" * chunk_size for _ in range(num_chunks) - ] - - mock_session.post.return_value = mock_response - mock_session_class.return_value.__enter__.return_value = mock_session - - result = send_callback("https://api.example.com/callback", {"data": "test"}) - - assert result is True - mock_response.close.assert_not_called() - - @patch("app.utils.validate_callback_url") - @patch("requests.Session") - def test_callback_uses_streaming(self, mock_session_class, mock_validate): - """Test that callback uses streaming to prevent loading entire response into memory.""" - mock_session = MagicMock() - mock_response = MagicMock() - mock_response.raise_for_status.return_value = None - mock_response.iter_content.return_value = [b"test"] - - mock_session.post.return_value = mock_response - mock_session_class.return_value.__enter__.return_value = mock_session - - send_callback("https://api.example.com/callback", {"data": "test"}) - - call_kwargs = mock_session.post.call_args[1] - assert call_kwargs["stream"] is True diff --git a/backend/app/utils.py b/backend/app/utils.py index 7d10c29fe..78877d351 100644 --- a/backend/app/utils.py +++ b/backend/app/utils.py @@ -351,7 +351,6 @@ def send_callback(callback_url: str, data: dict[str, Any]) -> bool: - DNS rebinding protection - Redirect following disabled - Strict timeouts - - Response size limits (prevents DoS via large responses) Args: callback_url: The HTTPS URL to send the callback to @@ -378,21 +377,10 @@ def send_callback(callback_url: str, data: dict[str, Any]) -> bool: settings.CALLBACK_READ_TIMEOUT, ), allow_redirects=False, - stream=True, ) response.raise_for_status() - total_size = 0 - for chunk in response.iter_content(chunk_size=8192): - total_size += len(chunk) - if total_size > settings.CALLBACK_MAX_RESPONSE_SIZE: - response.close() - logger.error( - f"[send_callback] Response size exceeds {settings.CALLBACK_MAX_RESPONSE_SIZE} bytes while reading" - ) - return False - logger.info("[send_callback] Callback sent successfully") return True