diff --git a/nemoguardrails/actions_server/actions_server.py b/nemoguardrails/actions_server/actions_server.py index 58d49437b..e45131a46 100644 --- a/nemoguardrails/actions_server/actions_server.py +++ b/nemoguardrails/actions_server/actions_server.py @@ -16,7 +16,7 @@ import logging from typing import Dict, Optional -from fastapi import FastAPI +from fastapi import Depends, FastAPI from pydantic import BaseModel, Field from nemoguardrails.actions.action_dispatcher import ActionDispatcher @@ -34,7 +34,12 @@ # Create action dispatcher object to communicate with actions -app.action_dispatcher = ActionDispatcher(load_all_actions=True) +_action_dispatcher = ActionDispatcher(load_all_actions=True) + + +def get_action_dispatcher() -> ActionDispatcher: + """Dependency to provide the action dispatcher instance.""" + return _action_dispatcher class RequestBody(BaseModel): @@ -58,22 +63,26 @@ class ResponseBody(BaseModel): summary="Execute action", response_model=ResponseBody, ) -async def run_action(body: RequestBody): +async def run_action( + body: RequestBody, + action_dispatcher: ActionDispatcher = Depends(get_action_dispatcher), +): """Execute the specified action and return the result. Args: body (RequestBody): The request body containing action_name and action_parameters. + action_dispatcher (ActionDispatcher): The action dispatcher dependency. Returns: dict: The response containing the execution status and result. """ - log.info(f"Request body: {body}") - result, status = await app.action_dispatcher.execute_action( + log.info("Request body: %s", body) + result, status = await action_dispatcher.execute_action( body.action_name, body.action_parameters ) resp = {"status": status, "result": result} - log.info(f"Response: {resp}") + log.info("Response: %s", resp) return resp @@ -81,7 +90,9 @@ async def run_action(body: RequestBody): "/v1/actions/list", summary="List available actions", ) -async def get_actions_list(): +async def get_actions_list( + action_dispatcher: ActionDispatcher = Depends(get_action_dispatcher), +): """Returns the list of available actions.""" - return app.action_dispatcher.get_registered_actions() + return action_dispatcher.get_registered_actions()