Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

♻️ Adding lifespan support for FastAPI & migrated dynamic-scheduler to use it #7149

Merged
merged 27 commits into from
Feb 4, 2025
Merged
Show file tree
Hide file tree
Changes from 12 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
from collections.abc import AsyncGenerator, Callable
from contextlib import AsyncExitStack, asynccontextmanager
from typing import AsyncContextManager, TypeAlias

from fastapi import FastAPI

LifespanContextManager: TypeAlias = Callable[[FastAPI], AsyncContextManager[None]]


def combine_lfiespans(*lifespans: LifespanContextManager) -> LifespanContextManager:
GitHK marked this conversation as resolved.
Show resolved Hide resolved
"""Applyes the `setup` part of the context mangers in the order they are provided.
GitHK marked this conversation as resolved.
Show resolved Hide resolved
The `teardown` is in revere order with regarad to the `seutp`.
GitHK marked this conversation as resolved.
Show resolved Hide resolved
"""

@asynccontextmanager
async def _(app: FastAPI) -> AsyncGenerator[None, None]:
async with AsyncExitStack() as stack:
for lifespan in lifespans:
await stack.enter_async_context(lifespan(app))
yield

GitHK marked this conversation as resolved.
Show resolved Hide resolved
return _
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from typing import Any, Final

from fastapi import FastAPI
from servicelib.aiohttp import status
from servicelib.mimetype_constants import MIMETYPE_APPLICATION_JSON
from starlette.requests import Request
Expand All @@ -13,7 +14,7 @@
)


def is_last_response(response_headers: dict[bytes, bytes], message: dict[str, Any]):
def _is_last_response(response_headers: dict[bytes, bytes], message: dict[str, Any]):
if (
content_type := response_headers.get(b"content-type")
) and content_type == MIMETYPE_APPLICATION_JSON.encode():
Expand Down Expand Up @@ -79,7 +80,7 @@ async def _send_wrapper(message):
response_headers = dict(message.get("headers"))
message["headers"] = check_response_headers(response_headers)
elif message["type"] == "http.response.body":
if is_last_response(response_headers, message):
if _is_last_response(response_headers, message):
_profiler.stop()
profile_text = _profiler.output_text(
unicode=True, color=True, show_all=True
Expand All @@ -96,3 +97,8 @@ async def _send_wrapper(message):

finally:
_profiler.reset()


def initialize_profiler(app: FastAPI) -> None:
# NOTE: this cannot be ran once the application is started
app.add_middleware(ProfilerMiddleware)
Original file line number Diff line number Diff line change
@@ -1,28 +1,63 @@
# pylint: disable=protected-access


from collections.abc import AsyncIterator
from contextlib import asynccontextmanager

from fastapi import FastAPI
from prometheus_client import CollectorRegistry
from prometheus_fastapi_instrumentator import Instrumentator


def setup_prometheus_instrumentation(app: FastAPI) -> Instrumentator:
def initialize_prometheus_instrumentation(app: FastAPI) -> None:
# NOTE: this cannot be ran once the application is started
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

THOUGHT: ideally i would enforce this comment with code (e.g. with a decorator that if the app is started, it raises)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would not know how to do that.

I think this is also fine since an error will be raised pointing you to this function.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What about checking that the state that you initialize in this initializer is not initialized? :-)


# NOTE: use that registry to prevent having a global one
app.state.prometheus_registry = registry = CollectorRegistry(auto_describe=True)
instrumentator = Instrumentator(
app.state.prometheus_instrumentator = Instrumentator(
should_instrument_requests_inprogress=False, # bug in https://github.com/trallnag/prometheus-fastapi-instrumentator/issues/317
inprogress_labels=False,
registry=registry,
).instrument(app)
)
app.state.prometheus_instrumentator.instrument(app)


def _setup(app: FastAPI) -> None:
GitHK marked this conversation as resolved.
Show resolved Hide resolved
assert isinstance(app.state.prometheus_instrumentator, Instrumentator) # nosec
app.state.prometheus_instrumentator.expose(app, include_in_schema=False)


def _shutdown(app: FastAPI) -> None:
assert isinstance(app.state.prometheus_registry, CollectorRegistry) # nosec
registry = app.state.prometheus_registry
for collector in list(registry._collector_to_names.keys()): # noqa: SLF001
registry.unregister(collector)


def get_prometheus_instrumentator(app: FastAPI) -> Instrumentator:
assert isinstance(app.state.prometheus_instrumentator, Instrumentator) # nosec
return app.state.prometheus_instrumentator


def setup_prometheus_instrumentation(app: FastAPI) -> Instrumentator:
initialize_prometheus_instrumentation(app)

async def _on_startup() -> None:
instrumentator.expose(app, include_in_schema=False)
_setup(app)

def _unregister() -> None:
# NOTE: avoid registering collectors multiple times when running unittests consecutively (https://stackoverflow.com/a/62489287)
for collector in list(registry._collector_to_names.keys()): # noqa: SLF001
registry.unregister(collector)
_shutdown(app)

app.add_event_handler("startup", _on_startup)
app.add_event_handler("shutdown", _unregister)
return instrumentator

return get_prometheus_instrumentator(app)


@asynccontextmanager
async def lifespan_prometheus_instrumentation(app: FastAPI) -> AsyncIterator[None]:
# NOTE: requires ``initialize_prometheus_instrumentation`` to be called before the
GitHK marked this conversation as resolved.
Show resolved Hide resolved
# lifespan of the applicaiton runs, usually rigth after the ``FastAPI`` instance is created
_setup(app)
yield
_shutdown(app)
14 changes: 14 additions & 0 deletions packages/service-library/src/servicelib/fastapi/tracing.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@
"""

import logging
from collections.abc import AsyncIterator, Callable
from contextlib import asynccontextmanager
from typing import AsyncContextManager

from fastapi import FastAPI
from httpx import AsyncClient, Client
Expand Down Expand Up @@ -131,5 +134,16 @@ def setup_tracing(
RequestsInstrumentor().instrument()


def get_lifespan_tracing(
tracing_settings: TracingSettings, service_name: str
) -> Callable[[FastAPI], AsyncContextManager[None]]:
@asynccontextmanager
async def _(app: FastAPI) -> AsyncIterator[None]:
setup_tracing(app, tracing_settings, service_name)
yield

return _


def setup_httpx_client_tracing(client: AsyncClient | Client):
HTTPXClientInstrumentor.instrument_client(client)
40 changes: 40 additions & 0 deletions packages/service-library/tests/fastapi/test_lifespan_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
from collections.abc import AsyncIterator
from contextlib import asynccontextmanager

import pytest
from asgi_lifespan import LifespanManager
from fastapi import FastAPI
from servicelib.fastapi.lifespan_utils import combine_lfiespans
GitHK marked this conversation as resolved.
Show resolved Hide resolved


async def test_multiple_lifespan_managers(capsys: pytest.CaptureFixture):
GitHK marked this conversation as resolved.
Show resolved Hide resolved
@asynccontextmanager
async def database_lifespan(_: FastAPI) -> AsyncIterator[None]:
print("setup DB")
yield
print("shutdown DB")

@asynccontextmanager
async def cache_lifespan(_: FastAPI) -> AsyncIterator[None]:
print("setup CACHE")
yield
print("shutdown CACHE")

app = FastAPI(lifespan=combine_lfiespans(database_lifespan, cache_lifespan))
GitHK marked this conversation as resolved.
Show resolved Hide resolved

capsys.readouterr()

async with LifespanManager(app):
messages = capsys.readouterr().out

assert "setup DB" in messages
assert "setup CACHE" in messages
assert "shutdown DB" not in messages
assert "shutdown CACHE" not in messages

messages = capsys.readouterr().out

assert "setup DB" not in messages
assert "setup CACHE" not in messages
assert "shutdown DB" in messages
assert "shutdown CACHE" in messages
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from fastapi_pagination import add_pagination
from models_library.basic_types import BootModeEnum
from packaging.version import Version
from servicelib.fastapi.profiler_middleware import ProfilerMiddleware
from servicelib.fastapi.profiler import initialize_profiler
from servicelib.fastapi.tracing import setup_tracing
from servicelib.logging_utils import config_all_loggers

Expand Down Expand Up @@ -123,7 +123,7 @@ def init_app(settings: ApplicationSettings | None = None) -> FastAPI:
)

if settings.API_SERVER_PROFILING:
app.add_middleware(ProfilerMiddleware)
initialize_profiler(app)

if app.state.settings.API_SERVER_PROMETHEUS_INSTRUMENTATION_ENABLED:
setup_prometheus_instrumentation(app)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from models_library.basic_types import BootModeEnum
from servicelib.fastapi import timing_middleware
from servicelib.fastapi.openapi import override_fastapi_openapi_method
from servicelib.fastapi.profiler_middleware import ProfilerMiddleware
from servicelib.fastapi.profiler import initialize_profiler
from servicelib.fastapi.prometheus_instrumentation import (
setup_prometheus_instrumentation,
)
Expand Down Expand Up @@ -61,7 +61,7 @@ def create_app(settings: ApplicationSettings | None = None) -> FastAPI:

# MIDDLEWARES
if app.state.settings.CATALOG_PROFILING:
app.add_middleware(ProfilerMiddleware)
initialize_profiler(app)

if settings.SC_BOOT_MODE != BootModeEnum.PRODUCTION:
# middleware to time requests (ONLY for development)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
get_common_oas_options,
override_fastapi_openapi_method,
)
from servicelib.fastapi.profiler_middleware import ProfilerMiddleware
from servicelib.fastapi.profiler import initialize_profiler
from servicelib.fastapi.tracing import setup_tracing
from servicelib.logging_utils import config_all_loggers

Expand Down Expand Up @@ -204,7 +204,7 @@ def init_app(settings: AppSettings | None = None) -> FastAPI:
instrumentation.setup(app)

if settings.DIRECTOR_V2_PROFILING:
app.add_middleware(ProfilerMiddleware)
initialize_profiler(app)

# setup app --
app.add_event_handler("startup", on_startup)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
from ._setup import setup_frontend
from ._setup import initialize_frontend

__all__: tuple[str, ...] = ("setup_frontend",)
__all__: tuple[str, ...] = ("initialize_frontend",)
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from .routes import router


def setup_frontend(app: FastAPI) -> None:
def initialize_frontend(app: FastAPI) -> None:
settings: ApplicationSettings = app.state.settings

nicegui.app.include_router(router)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from . import _health, _meta


def setup_rest_api(app: FastAPI):
def initialize_rest_api(app: FastAPI) -> None:
app.include_router(_health.router)

api_router = APIRouter(prefix=f"/{API_VTAG}")
Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
from collections.abc import AsyncIterator
from contextlib import asynccontextmanager

from fastapi import FastAPI
from models_library.api_schemas_dynamic_scheduler import DYNAMIC_SCHEDULER_RPC_NAMESPACE
from servicelib.rabbitmq import RPCRouter
Expand All @@ -10,12 +13,10 @@
]


def setup_rpc_api_routes(app: FastAPI) -> None:
async def startup() -> None:
rpc_server = get_rabbitmq_rpc_server(app)
for router in ROUTERS:
await rpc_server.register_router(
router, DYNAMIC_SCHEDULER_RPC_NAMESPACE, app
)
@asynccontextmanager
async def lifespan_rpc_api_routes(app: FastAPI) -> AsyncIterator[None]:
rpc_server = get_rabbitmq_rpc_server(app)
for router in ROUTERS:
await rpc_server.register_router(router, DYNAMIC_SCHEDULER_RPC_NAMESPACE, app)

app.add_event_handler("startup", startup)
yield
Loading
Loading