diff --git a/httpx/_transports/mock.py b/httpx/_transports/mock.py index 8c418f59e0..d521838977 100644 --- a/httpx/_transports/mock.py +++ b/httpx/_transports/mock.py @@ -1,6 +1,7 @@ from __future__ import annotations import typing +from datetime import timedelta from .._models import Request, Response from .base import AsyncBaseTransport, BaseTransport @@ -13,8 +14,11 @@ class MockTransport(AsyncBaseTransport, BaseTransport): - def __init__(self, handler: SyncHandler | AsyncHandler) -> None: + def __init__( + self, handler: SyncHandler | AsyncHandler, delay: timedelta = timedelta(0) + ) -> None: self.handler = handler + self.delay = delay def handle_request( self, @@ -24,6 +28,8 @@ def handle_request( response = self.handler(request) if not isinstance(response, Response): # pragma: no cover raise TypeError("Cannot use an async handler in a sync Client") + + self._apply_elapsed(response) return response async def handle_async_request( @@ -40,4 +46,14 @@ async def handle_async_request( if not isinstance(response, Response): response = await response + self._apply_elapsed(response) return response + + def _apply_elapsed(self, response: Response) -> None: + # If the handler already set `response._elapsed`, it is preserved. + # If a delay was provided to MockTransport, `.elapsed` is set to that duration. + # If no delay is provided, `.elapsed` is explicitly set to None. + if hasattr(response, "_elapsed"): + return + + response._elapsed = self.delay diff --git a/tests/client/test_client.py b/tests/client/test_client.py index 657839018a..a7fd670a08 100644 --- a/tests/client/test_client.py +++ b/tests/client/test_client.py @@ -460,3 +460,41 @@ def cp1252_but_no_content_type(request): assert response.reason_phrase == "OK" assert response.encoding == "ISO-8859-1" assert response.text == text + + +def test_mocktransport_preserves_handler_elapsed(): + def handler(request): + r = httpx.Response(200) + r.elapsed = timedelta(seconds=1) + return r + + transport = httpx.MockTransport(handler, delay=timedelta(seconds=0.5)) + client = httpx.Client(transport=transport) + + response = client.get("https://example.com") + + assert response.elapsed == timedelta(seconds=1) + + +def test_mocktransport_sets_elapsed_to_delay(): + def handler(request): + return httpx.Response(200) + + transport = httpx.MockTransport(handler, delay=timedelta(seconds=0.5)) + client = httpx.Client(transport=transport) + + response = client.get("https://example.com") + + assert response.elapsed == timedelta(seconds=0.5) + + +def test_mocktransport_sets_elapsed_none_when_no_delay(): + def handler(request): + return httpx.Response(200) + + transport = httpx.MockTransport(handler) + client = httpx.Client(transport=transport) + + response = client.get("https://example.com") + + assert response.elapsed == timedelta(0)