From a6595dea2ddd0af01885eb19b3ffcec70885c9d6 Mon Sep 17 00:00:00 2001 From: Robert Szefler Date: Thu, 6 Jun 2024 17:24:34 +0200 Subject: [PATCH 1/9] Support for calling Robusta AI --- holmes/common/env_vars.py | 5 ++++ holmes/core/supabase_dal.py | 59 +++++++++++++++++++++++++++++++++++-- holmes/utils/auth.py | 29 ++++++++++++++++++ requirements.txt | 1 + server.py | 34 ++++++++++++++------- 5 files changed, 114 insertions(+), 14 deletions(-) create mode 100644 holmes/utils/auth.py diff --git a/holmes/common/env_vars.py b/holmes/common/env_vars.py index 66d9d463..2990dd47 100644 --- a/holmes/common/env_vars.py +++ b/holmes/common/env_vars.py @@ -10,3 +10,8 @@ STORE_API_KEY = os.environ.get("STORE_API_KEY", "") STORE_EMAIL = os.environ.get("STORE_EMAIL", "") STORE_PASSWORD = os.environ.get("STORE_PASSWORD", "") + +# Currently supports BUILTIN and ROBUSTA_AI +AI_AGENT = os.environ.get("AI_AGENT", "BUILTIN") + +ROBUSTA_AI_URL = os.environ.get("ROBUSTA_AI_URL", "") diff --git a/holmes/core/supabase_dal.py b/holmes/core/supabase_dal.py index 0378da10..e8d88872 100644 --- a/holmes/core/supabase_dal.py +++ b/holmes/core/supabase_dal.py @@ -2,7 +2,9 @@ import json import logging import os +from datetime import datetime from typing import Dict, Optional, List +from uuid import uuid4 import yaml from supabase import create_client @@ -17,6 +19,8 @@ ISSUES_TABLE = "Issues" EVIDENCE_TABLE = "Evidence" +TOKENS_TABLE = "AuthTokens" +ACCOUNT_USERS_TABLE = "AccountUsers" class RobustaConfig(BaseModel): @@ -31,6 +35,15 @@ class RobustaToken(BaseModel): password: str +class AuthToken(BaseModel): + account_id: str + user_id: str + token: str + type: str + deleted: bool = False + created_at: datetime = None + + class SupabaseDal: def __init__(self): @@ -98,7 +111,7 @@ def get_issue_data(self, issue_id: str) -> Optional[Dict]: issue_response = ( self.client .table(ISSUES_TABLE) - .select(f"*") + .select("*") .filter("id", "eq", issue_id) .execute() ) @@ -113,9 +126,49 @@ def get_issue_data(self, issue_id: str) -> Optional[Dict]: evidence = ( self.client .table(EVIDENCE_TABLE) - .select(f"*") + .select("*") .filter("issue_id", "eq", issue_id) .execute() ) issue_data["evidence"] = evidence.data - return issue_data \ No newline at end of file + return issue_data + + def create_auth_token(self, token_type: str, user_id: str) -> AuthToken: + result = ( + self.client + .table(TOKENS_TABLE) + .insert( + { + "account_id": self.account_id, + "user_id": user_id, + "token": uuid4(), + "type": token_type, + } + ) + .execute() + ) + return AuthToken(**result.data[0]) + + def get_freshest_auth_token(self, token_type: str) -> AuthToken: + result = ( + self.client + .table(TOKENS_TABLE) + .select("*") + .filter("token_type", "eq", token_type) + .filter("deleted", "eq", False) + .order("created_at", desc=True) + .limit(1) + .execute() + ) + return AuthToken(**result.data[0]) + + def get_user_ids_for_account(self, account_id: str) -> List[str]: + return [ + row["user_id"] + for row in ( + self.client + .table(ACCOUNT_USERS_TABLE) + .select("user_id") + .filter("account_id", "eq", account_id) + ).data + ] diff --git a/holmes/utils/auth.py b/holmes/utils/auth.py new file mode 100644 index 00000000..6ce2582a --- /dev/null +++ b/holmes/utils/auth.py @@ -0,0 +1,29 @@ +# from cachetools import cached +from typing import Optional + +from holmes.core.supabase_dal import AuthToken, SupabaseDal + + +class SessionManager: + def __init__(self, dal: SupabaseDal, token_type: str): + self.dal = dal + self.token_type = token_type + 
self.cached_token: Optional[AuthToken] = None + # TODO should this part of initialization be moved to SupabaseDal? + user_ids = dal.get_user_ids_for_account(dal.account_id) + if not user_ids: + raise ValueError(f"No users found for account_id={dal.account_id}") + if len(user_ids) > 1: + raise ValueError(f"Multiple users found for account_id={dal.account_id}") + self.user_id = user_ids[0] + + def get_current_auth_token(self) -> AuthToken: + if self.cached_token: + return self.cached_token + else: + return self.dal.get_freshest_auth_token(self.token_type) + + def recreate_auth_token(self) -> AuthToken: + new_token = self.dal.create_auth_token(self.token_type, self.user_id) + self.cached_token = new_token + return new_token diff --git a/requirements.txt b/requirements.txt index c71c42ed..60e89a55 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,6 @@ annotated-types==0.7.0 ; python_version >= "3.11" and python_version < "4.0" anyio==4.4.0 ; python_version >= "3.11" and python_version < "4.0" +cachetools==5.3.3 certifi==2024.2.2 ; python_version >= "3.11" and python_version < "4.0" charset-normalizer==3.3.2 ; python_version >= "3.11" and python_version < "4.0" click==8.1.7 ; python_version >= "3.11" and python_version < "4.0" diff --git a/server.py b/server.py index a107bb85..b8f9cfd7 100644 --- a/server.py +++ b/server.py @@ -1,4 +1,8 @@ import os + +import jinja2 + +from holmes.utils.auth import SessionManager from holmes.utils.cert_utils import add_custom_certificate ADDITIONAL_CERTIFICATE: str = os.environ.get("CERTIFICATE", "") @@ -9,19 +13,21 @@ # IMPORTING ABOVE MIGHT INITIALIZE AN HTTPS CLIENT THAT DOESN'T TRUST THE CUSTOM CERTIFICATEE import logging -import uvicorn -import colorlog - from typing import List, Union -from fastapi import FastAPI +import colorlog +import uvicorn +from fastapi import FastAPI, HTTPException from pydantic import BaseModel from rich.console import Console -from holmes.common.env_vars import HOLMES_HOST, HOLMES_PORT, ALLOWED_TOOLSETS -from holmes.core.supabase_dal import SupabaseDal +from holmes.common.env_vars import ( + HOLMES_HOST, + HOLMES_PORT, +) from holmes.config import Config from holmes.core.issue import Issue +from holmes.core.supabase_dal import AuthToken, SupabaseDal from holmes.plugins.prompts import load_prompt @@ -40,6 +46,7 @@ class InvestigateRequest(BaseModel): include_tool_calls: bool = False include_tool_call_results: bool = False prompt_template: str = "builtin://generic_investigation.jinja2" + model: str = "gpt-4o" # TODO in the future # response_handler: ... 
@@ -50,7 +57,9 @@ def init_logging(): logging_datefmt = "%Y-%m-%d %H:%M:%S" print("setting up colored logging") - colorlog.basicConfig(format=logging_format, level=logging_level, datefmt=logging_datefmt) + colorlog.basicConfig( + format=logging_format, level=logging_level, datefmt=logging_datefmt + ) logging.getLogger().setLevel(logging_level) httpx_logger = logging.getLogger("httpx") @@ -62,6 +71,7 @@ def init_logging(): init_logging() dal = SupabaseDal() +session_manager = SessionManager("RelayHolmes") app = FastAPI() console = Console() @@ -69,15 +79,16 @@ def init_logging(): @app.post("/api/investigate") -def investigate_issues(request: InvestigateRequest): +def investigate_issue(request: InvestigateRequest): context = fetch_context_data(request.context) raw_data = request.model_dump() + raw_data.pop("model") + raw_data.pop("system_prompt") if context: raw_data["extra_context"] = context - ai = config.create_issue_investigator(console, allowed_toolsets=ALLOWED_TOOLSETS) issue = Issue( - id=context['id'] if context else "", + id=context["id"] if context else "", name=request.title, source_type=request.source, source_instance_id=request.source_instance_id, @@ -112,5 +123,6 @@ def fetch_context_data(context: List[InvestigateContext]) -> dict: # makes sense to have several of them in the context structure. return dal.get_issue_data(context_item.value) + if __name__ == "__main__": - uvicorn.run(app, host=HOLMES_HOST, port=HOLMES_PORT) \ No newline at end of file + uvicorn.run(app, host=HOLMES_HOST, port=HOLMES_PORT) From 4e731b292419a727cf36cdbe2a95c8f613b01c76 Mon Sep 17 00:00:00 2001 From: Robert Szefler Date: Mon, 10 Jun 2024 15:59:56 +0200 Subject: [PATCH 2/9] auth: user_id --- holmes/common/env_vars.py | 1 + holmes/core/supabase_dal.py | 75 +++++++++++++++++++++++-------------- holmes/utils/auth.py | 9 +---- requirements.txt | 1 - server.py | 2 +- 5 files changed, 49 insertions(+), 39 deletions(-) diff --git a/holmes/common/env_vars.py b/holmes/common/env_vars.py index 2990dd47..bed9e4df 100644 --- a/holmes/common/env_vars.py +++ b/holmes/common/env_vars.py @@ -6,6 +6,7 @@ ROBUSTA_CONFIG_PATH = os.environ.get('ROBUSTA_CONFIG_PATH', "/etc/robusta/config/active_playbooks.yaml") ROBUSTA_ACCOUNT_ID = os.environ.get("ROBUSTA_ACCOUNT_ID", "") +ROBUSTA_USER_ID = os.environ.get("ROBUSTA_USER_ID", "") STORE_URL = os.environ.get("STORE_URL", "") STORE_API_KEY = os.environ.get("STORE_API_KEY", "") STORE_EMAIL = os.environ.get("STORE_EMAIL", "") diff --git a/holmes/core/supabase_dal.py b/holmes/core/supabase_dal.py index e8d88872..4212e6a6 100644 --- a/holmes/core/supabase_dal.py +++ b/holmes/core/supabase_dal.py @@ -11,8 +11,15 @@ from supabase.lib.client_options import ClientOptions from pydantic import BaseModel -from holmes.common.env_vars import (ROBUSTA_CONFIG_PATH, ROBUSTA_ACCOUNT_ID, STORE_URL, STORE_API_KEY, STORE_EMAIL, - STORE_PASSWORD) +from holmes.common.env_vars import ( + ROBUSTA_ACCOUNT_ID, + ROBUSTA_CONFIG_PATH, + ROBUSTA_USER_ID, + STORE_API_KEY, + STORE_EMAIL, + STORE_PASSWORD, + STORE_URL, +) SUPABASE_TIMEOUT_SECONDS = int(os.getenv("SUPABASE_TIMEOUT_SECONDS", 3600)) @@ -31,6 +38,7 @@ class RobustaToken(BaseModel): store_url: str api_key: str account_id: str + user_id: str email: str password: str @@ -80,24 +88,48 @@ def __init_config(self) -> bool: robusta_token = self.__load_robusta_config() if robusta_token: self.account_id = robusta_token.account_id + self.user_id = robusta_token.user_id self.url = robusta_token.store_url self.api_key = robusta_token.api_key self.email = 
robusta_token.email self.password = robusta_token.password else: self.account_id = ROBUSTA_ACCOUNT_ID + self.user_id = ROBUSTA_USER_ID self.url = STORE_URL self.api_key = STORE_API_KEY self.email = STORE_EMAIL self.password = STORE_PASSWORD # valid only if all store parameters are provided - return all([self.account_id, self.url, self.api_key, self.email, self.password]) + return self.check_settings() + + def check_settings(self): + unset_attrs = [] + for attr_name in [ + "account_id", + "user_id", + "url", + "api_key", + "email", + "password", + ]: + if not getattr(self, attr_name, None): + unset_attrs.append(attr_name) + if unset_attrs: + logging.warning(f"Unset store config variables: {', '.join(unset_attrs)}") + return False + else: + return True def sign_in(self): logging.info("Supabase DAL login") - res = self.client.auth.sign_in_with_password({"email": self.email, "password": self.password}) - self.client.auth.set_session(res.session.access_token, res.session.refresh_token) + res = self.client.auth.sign_in_with_password( + {"email": self.email, "password": self.password} + ) + self.client.auth.set_session( + res.session.access_token, res.session.refresh_token + ) self.client.postgrest.auth(res.session.access_token) def get_issue_data(self, issue_id: str) -> Optional[Dict]: @@ -109,11 +141,10 @@ def get_issue_data(self, issue_id: str) -> Optional[Dict]: issue_data = None try: issue_response = ( - self.client - .table(ISSUES_TABLE) - .select("*") - .filter("id", "eq", issue_id) - .execute() + self.client.table(ISSUES_TABLE) + .select("*") + .filter("id", "eq", issue_id) + .execute() ) if len(issue_response.data): issue_data = issue_response.data[0] @@ -124,8 +155,7 @@ def get_issue_data(self, issue_id: str) -> Optional[Dict]: if not issue_data: return None evidence = ( - self.client - .table(EVIDENCE_TABLE) + self.client.table(EVIDENCE_TABLE) .select("*") .filter("issue_id", "eq", issue_id) .execute() @@ -133,14 +163,13 @@ def get_issue_data(self, issue_id: str) -> Optional[Dict]: issue_data["evidence"] = evidence.data return issue_data - def create_auth_token(self, token_type: str, user_id: str) -> AuthToken: + def create_auth_token(self, token_type: str) -> AuthToken: result = ( - self.client - .table(TOKENS_TABLE) + self.client.table(TOKENS_TABLE) .insert( { "account_id": self.account_id, - "user_id": user_id, + "user_id": self.user_id, "token": uuid4(), "type": token_type, } @@ -151,8 +180,7 @@ def create_auth_token(self, token_type: str, user_id: str) -> AuthToken: def get_freshest_auth_token(self, token_type: str) -> AuthToken: result = ( - self.client - .table(TOKENS_TABLE) + self.client.table(TOKENS_TABLE) .select("*") .filter("token_type", "eq", token_type) .filter("deleted", "eq", False) @@ -161,14 +189,3 @@ def get_freshest_auth_token(self, token_type: str) -> AuthToken: .execute() ) return AuthToken(**result.data[0]) - - def get_user_ids_for_account(self, account_id: str) -> List[str]: - return [ - row["user_id"] - for row in ( - self.client - .table(ACCOUNT_USERS_TABLE) - .select("user_id") - .filter("account_id", "eq", account_id) - ).data - ] diff --git a/holmes/utils/auth.py b/holmes/utils/auth.py index 6ce2582a..4851021b 100644 --- a/holmes/utils/auth.py +++ b/holmes/utils/auth.py @@ -9,13 +9,6 @@ def __init__(self, dal: SupabaseDal, token_type: str): self.dal = dal self.token_type = token_type self.cached_token: Optional[AuthToken] = None - # TODO should this part of initialization be moved to SupabaseDal? 
- user_ids = dal.get_user_ids_for_account(dal.account_id) - if not user_ids: - raise ValueError(f"No users found for account_id={dal.account_id}") - if len(user_ids) > 1: - raise ValueError(f"Multiple users found for account_id={dal.account_id}") - self.user_id = user_ids[0] def get_current_auth_token(self) -> AuthToken: if self.cached_token: @@ -24,6 +17,6 @@ def get_current_auth_token(self) -> AuthToken: return self.dal.get_freshest_auth_token(self.token_type) def recreate_auth_token(self) -> AuthToken: - new_token = self.dal.create_auth_token(self.token_type, self.user_id) + new_token = self.dal.create_auth_token(self.token_type) self.cached_token = new_token return new_token diff --git a/requirements.txt b/requirements.txt index 60e89a55..c71c42ed 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,6 +1,5 @@ annotated-types==0.7.0 ; python_version >= "3.11" and python_version < "4.0" anyio==4.4.0 ; python_version >= "3.11" and python_version < "4.0" -cachetools==5.3.3 certifi==2024.2.2 ; python_version >= "3.11" and python_version < "4.0" charset-normalizer==3.3.2 ; python_version >= "3.11" and python_version < "4.0" click==8.1.7 ; python_version >= "3.11" and python_version < "4.0" diff --git a/server.py b/server.py index b8f9cfd7..6e4f7dcc 100644 --- a/server.py +++ b/server.py @@ -71,7 +71,7 @@ def init_logging(): init_logging() dal = SupabaseDal() -session_manager = SessionManager("RelayHolmes") +session_manager = SessionManager(dal, "RelayHolmes") app = FastAPI() console = Console() From 1b78dc194255117fc079e8ff70b70c2937fc6a4f Mon Sep 17 00:00:00 2001 From: Robert Szefler Date: Mon, 10 Jun 2024 16:09:45 +0200 Subject: [PATCH 3/9] more auth fixes --- holmes/core/supabase_dal.py | 13 ++++++++++++- holmes/utils/auth.py | 7 +++++-- server.py | 4 +--- 3 files changed, 18 insertions(+), 6 deletions(-) diff --git a/holmes/core/supabase_dal.py b/holmes/core/supabase_dal.py index 4212e6a6..e2b7d06f 100644 --- a/holmes/core/supabase_dal.py +++ b/holmes/core/supabase_dal.py @@ -59,7 +59,7 @@ def __init__(self): if not self.enabled: logging.info("Robusta store initialization parameters not provided. skipping") return - logging.info(f"Initializing robusta store for account {self.account_id}") + logging.info(f"Initializing robusta store for account {self.account_id}, user {self.user_id}") options = ClientOptions(postgrest_client_timeout=SUPABASE_TIMEOUT_SECONDS) self.client = create_client(self.url, self.api_key, options) self.sign_in() @@ -189,3 +189,14 @@ def get_freshest_auth_token(self, token_type: str) -> AuthToken: .execute() ) return AuthToken(**result.data[0]) + + def invalidate_auth_token(self, token: AuthToken) -> None: + ( + self.client.table(TOKENS_TABLE) + .update({"deleted": True}) + .eq("account_id", token.account_id) + .eq("user_id", token.user_id) + .eq("token", token.token) + .eq("type", token.type) + ) + # TODO maybe handle errors such as non-existent tokens? 
diff --git a/holmes/utils/auth.py b/holmes/utils/auth.py index 4851021b..f0699e7a 100644 --- a/holmes/utils/auth.py +++ b/holmes/utils/auth.py @@ -10,13 +10,16 @@ def __init__(self, dal: SupabaseDal, token_type: str): self.token_type = token_type self.cached_token: Optional[AuthToken] = None - def get_current_auth_token(self) -> AuthToken: + def get_current_token(self) -> AuthToken: if self.cached_token: return self.cached_token else: return self.dal.get_freshest_auth_token(self.token_type) - def recreate_auth_token(self) -> AuthToken: + def create_token(self) -> AuthToken: new_token = self.dal.create_auth_token(self.token_type) self.cached_token = new_token return new_token + + def invalidate_token(self, token: AuthToken): + self.dal.invalidate_auth_token(token) diff --git a/server.py b/server.py index 6e4f7dcc..000a8621 100644 --- a/server.py +++ b/server.py @@ -57,9 +57,7 @@ def init_logging(): logging_datefmt = "%Y-%m-%d %H:%M:%S" print("setting up colored logging") - colorlog.basicConfig( - format=logging_format, level=logging_level, datefmt=logging_datefmt - ) + colorlog.basicConfig(format=logging_format, level=logging_level, datefmt=logging_datefmt) logging.getLogger().setLevel(logging_level) httpx_logger = logging.getLogger("httpx") From 4bd103d3c5271938f746a3f1169628732dbd0107 Mon Sep 17 00:00:00 2001 From: Robert Szefler Date: Mon, 10 Jun 2024 16:36:50 +0200 Subject: [PATCH 4/9] multi-AI-agent refactoring --- holmes.py | 46 +++++----- holmes/config.py | 126 +++++++++++++++++----------- holmes/core/tool_calling_llm.py | 3 + holmes/plugins/runbooks/__init__.py | 14 ++-- holmes/utils/pydantic_utils.py | 10 ++- server.py | 6 +- test-api.sh | 2 +- 7 files changed, 126 insertions(+), 81 deletions(-) diff --git a/holmes.py b/holmes.py index eb11bfb7..765de8f7 100644 --- a/holmes.py +++ b/holmes.py @@ -13,7 +13,7 @@ from rich.markdown import Markdown from rich.rule import Rule from holmes.utils.file_utils import write_json_file -from holmes.config import Config, LLMType +from holmes.config import LLMConfig, LLMProviderType from holmes.plugins.destinations import DestinationType from holmes.plugins.prompts import load_prompt from holmes.plugins.sources.opsgenie import OPSGENIE_TEAM_INTEGRATION_KEY_HELP @@ -28,22 +28,11 @@ ) app.add_typer(investigate_app, name="investigate") -def init_logging(verbose = False): - logging.basicConfig(level=logging.DEBUG if verbose else logging.INFO, format="%(message)s", handlers=[RichHandler(show_level=False, show_time=False)]) - # disable INFO logs from OpenAI - logging.getLogger("httpx").setLevel(logging.WARNING) - # when running in --verbose mode we don't want to see DEBUG logs from these libraries - logging.getLogger("openai._base_client").setLevel(logging.INFO) - logging.getLogger("httpcore").setLevel(logging.INFO) - logging.getLogger("markdown_it").setLevel(logging.INFO) - # Suppress UserWarnings from the slack_sdk module - warnings.filterwarnings("ignore", category=UserWarning, module="slack_sdk.*") - return Console() # Common cli options -opt_llm: Optional[LLMType] = typer.Option( - LLMType.OPENAI, - help="Which LLM to use ('openai' or 'azure')", +opt_llm: Optional[LLMProviderType] = typer.Option( + LLMProviderType.OPENAI, + help="LLM provider ('openai' or 'azure')", # TODO list all ) opt_api_key: Optional[str] = typer.Option( None, @@ -111,6 +100,19 @@ def init_logging(verbose = False): system_prompt_help = "Advanced. System prompt for LLM. 
Values starting with builtin:// are loaded from holmes/plugins/prompts, values starting with file:// are loaded from the given path, other values are interpreted as a prompt string" +def init_logging(verbose = False): + logging.basicConfig(level=logging.DEBUG if verbose else logging.INFO, format="%(message)s", handlers=[RichHandler(show_level=False, show_time=False)]) + # disable INFO logs from OpenAI + logging.getLogger("httpx").setLevel(logging.WARNING) + # when running in --verbose mode we don't want to see DEBUG logs from these libraries + logging.getLogger("openai._base_client").setLevel(logging.INFO) + logging.getLogger("httpcore").setLevel(logging.INFO) + logging.getLogger("markdown_it").setLevel(logging.INFO) + # Suppress UserWarnings from the slack_sdk module + warnings.filterwarnings("ignore", category=UserWarning, module="slack_sdk.*") + return Console() + + # TODO: add interactive interpreter mode # TODO: add streaming output @app.command() @@ -141,7 +143,7 @@ def ask( Ask any question and answer using available tools """ console = init_logging(verbose) - config = Config.load_from_file( + config = LLMConfig.load_from_file( config_file, api_key=api_key, llm=llm, @@ -180,7 +182,7 @@ def alertmanager( None, help="Password to use for basic auth" ), # common options - llm: Optional[LLMType] = opt_llm, + llm: Optional[LLMProviderType] = opt_llm, api_key: Optional[str] = opt_api_key, azure_endpoint: Optional[str] = opt_azure_endpoint, model: Optional[str] = opt_model, @@ -203,7 +205,7 @@ def alertmanager( Investigate a Prometheus/Alertmanager alert """ console = init_logging(verbose) - config = Config.load_from_file( + config = LLMConfig.load_from_file( config_file, api_key=api_key, llm=llm, @@ -282,7 +284,7 @@ def jira( False, help="Update Jira with AI results" ), # common options - llm: Optional[LLMType] = opt_llm, + llm: Optional[LLMProviderType] = opt_llm, api_key: Optional[str] = opt_api_key, azure_endpoint: Optional[str] = opt_azure_endpoint, model: Optional[str] = opt_model, @@ -301,7 +303,7 @@ def jira( Investigate a Jira ticket """ console = init_logging(verbose) - config = Config.load_from_file( + config = LLMConfig.load_from_file( config_file, api_key=api_key, llm=llm, @@ -371,7 +373,7 @@ def github( help="Investigate tickets matching a GitHub query (e.g. 
'is:issue is:open')", ), # common options - llm: Optional[LLMType] = opt_llm, + llm: Optional[LLMProviderType] = opt_llm, api_key: Optional[str] = opt_api_key, azure_endpoint: Optional[str] = opt_azure_endpoint, model: Optional[str] = opt_model, @@ -390,7 +392,7 @@ def github( Investigate a GitHub issue """ console = init_logging(verbose) - config = Config.load_from_file( + config = LLMConfig.load_from_file( config_file, api_key=api_key, llm=llm, diff --git a/holmes/config.py b/holmes/config.py index 7a566d67..e6364370 100644 --- a/holmes/config.py +++ b/holmes/config.py @@ -2,7 +2,7 @@ import os import os.path from strenum import StrEnum -from typing import List, Optional +from typing import Annotated, Any, Dict, List, Optional, get_args, get_type_hints from openai import AzureOpenAI, OpenAI from pydantic import FilePath, SecretStr @@ -21,28 +21,21 @@ from holmes.plugins.sources.opsgenie import OpsGenieSource from holmes.plugins.sources.pagerduty import PagerDutySource from holmes.plugins.sources.prometheus.plugin import AlertManagerSource -from holmes.plugins.toolsets import (load_builtin_toolsets, - load_toolsets_from_file) -from holmes.utils.pydantic_utils import RobustaBaseConfig, load_model_from_file +from holmes.plugins.toolsets import load_builtin_toolsets, load_toolsets_from_file +from holmes.utils.pydantic_utils import BaseConfig, EnvVarName, load_model_from_file -class LLMType(StrEnum): +class LLMProviderType(StrEnum): OPENAI = "openai" AZURE = "azure" + ROBUSTA = "robusta_ai" -class Config(RobustaBaseConfig): - llm: Optional[LLMType] = LLMType.OPENAI - api_key: Optional[SecretStr] = ( - None # if None, read from OPENAI_API_KEY or AZURE_OPENAI_ENDPOINT env var - ) - azure_endpoint: Optional[str] = ( - None # if None, read from AZURE_OPENAI_ENDPOINT env var - ) - azure_api_version: Optional[str] = "2024-02-01" - model: Optional[str] = "gpt-4o" - max_steps: Optional[int] = 10 +class BaseLLMConfig(BaseConfig): + llm: LLMProviderType = LLMProviderType.OPENAI + # FIXME: the following settings do not belong here. They define the + # configuration of specific actions, and not of the LLM provider. alertmanager_url: Optional[str] = None alertmanager_username: Optional[str] = None alertmanager_password: Optional[str] = None @@ -74,42 +67,79 @@ class Config(RobustaBaseConfig): custom_toolsets: List[FilePath] = [] @classmethod - def load_from_env(cls): - kwargs = {"llm": LLMType(os.getenv("HOLMES_LLM", "OPENAI").lower())} - for field_name in [ - "model", - "api_key", - "azure_endpoint", - "max_steps", - "alertmanager_url", - "alertmanager_username", - "alertmanager_password", - "jira_url", - "jira_username", - "jira_api_key", - "jira_query", - "slack_token", - "slack_channel", - "github_url", - "github_owner", - "github_repository", - "github_pat", - "github_query", - # TODO - # custom_runbooks - # custom_toolsets - ]: - val = os.getenv(field_name.upper(), None) - if val is not None: - kwargs[field_name] = val - return cls(**kwargs) + def _collect_env_vars(cls) -> Dict[str, Any]: + """Collect the environment variables that this class might require for setup. + + Environment variable names are determined from model fields as follows: + - if the model field is not annotated with an EnvVarName, the env var name is + just the model field name in upper case + - if the model field is annotated with an EnvVarName, the env var name is + taken from the annotation. 
+ """ + vars_dict = {} + hints = get_type_hints(cls, include_extras=True) + for field_name in cls.model_fields: + if field_name == "llm": + # Handled in load_from_env + continue + tp_obj = hints[field_name] + for arg in get_args(tp_obj): + if isinstance(arg, EnvVarName): + env_var_name = arg + break + else: + env_var_name = field_name.upper() + if env_var_name in os.environ: + vars_dict[field_name] = os.environ[env_var_name] + return vars_dict + + @classmethod + def load_from_env(cls) -> "BaseLLMConfig": + llm_name = os.getenv("LLM_PROVIDER", "OPENAI").lower() + llm_provider_type = LLMProviderType(llm_name) + if llm_provider_type == LLMProviderType.AZURE: + final_class = AzureLLMConfig + elif llm_provider_type == LLMProviderType.OPENAI: + final_class = OpenAILLMConfig + elif llm_provider_type == LLMProviderType.ROBUSTA: + final_class = RobustaLLMConfig + else: + raise NotImplementedError(f"Unknown LLM {llm_name}") + kwargs = final_class._collect_env_vars() + ret = final_class(**kwargs) + return ret + + +class BaseOpenAIConfig(BaseLLMConfig): + model: Optional[str] = "gpt-4o" + max_steps: Optional[int] = 10 + + +class OpenAILLMConfig(BaseOpenAIConfig): + api_key: Optional[SecretStr] + + +class AzureLLMConfig(BaseOpenAIConfig): + api_key: Optional[SecretStr] + endpoint: Optional[str] + azure_api_version: Optional[str] = "2024-02-01" + + +class RobustaLLMConfig(BaseLLMConfig): + url: Annotated[str, EnvVarName("ROBUSTA_AI_URL")] + + +# TODO refactor everything below + + +class LLMConfig(BaseLLMConfig): def create_llm(self) -> OpenAI: - if self.llm == LLMType.OPENAI: + if self.llm == LLMProviderType.OPENAI: return OpenAI( api_key=self.api_key.get_secret_value() if self.api_key else None, ) - elif self.llm == LLMType.AZURE: + elif self.llm == LLMProviderType.AZURE: return AzureOpenAI( api_key=self.api_key.get_secret_value() if self.api_key else None, azure_endpoint=self.azure_endpoint, @@ -155,7 +185,7 @@ def _create_tool_executor( def create_toolcalling_llm( self, console: Console, allowed_toolsets: ToolsetPattern - ) -> IssueInvestigator: + ) -> ToolCallingLLM: tool_executor = self._create_tool_executor(console, allowed_toolsets) return ToolCallingLLM( self.create_llm(), diff --git a/holmes/core/tool_calling_llm.py b/holmes/core/tool_calling_llm.py index 6e89984d..7f05b7fd 100644 --- a/holmes/core/tool_calling_llm.py +++ b/holmes/core/tool_calling_llm.py @@ -22,6 +22,7 @@ class ToolCallResult(BaseModel): description: str result: str + class LLMResult(BaseModel): tool_calls: Optional[List[ToolCallResult]] = None result: Optional[str] = None @@ -33,6 +34,7 @@ def get_tool_usage_summary(self): [f"`{tool_call.description}`" for tool_call in self.tool_calls] ) + class ToolCallingLLM: def __init__( @@ -127,6 +129,7 @@ def call(self, system_prompt, user_prompt) -> LLMResult: ) ) + # TODO: consider getting rid of this entirely and moving templating into the cmds in holmes.py class IssueInvestigator(ToolCallingLLM): """ diff --git a/holmes/plugins/runbooks/__init__.py b/holmes/plugins/runbooks/__init__.py index 242c34a6..7b5d51e9 100644 --- a/holmes/plugins/runbooks/__init__.py +++ b/holmes/plugins/runbooks/__init__.py @@ -1,20 +1,21 @@ import os import os.path -from typing import List, Literal, Optional, Pattern, Union +from typing import List, Optional, Pattern -from pydantic import BaseModel, ConfigDict, Field, PrivateAttr +from pydantic import BaseModel, PrivateAttr -from holmes.utils.pydantic_utils import RobustaBaseConfig, load_model_from_file +from holmes.utils.pydantic_utils import BaseConfig, 
load_model_from_file THIS_DIR = os.path.abspath(os.path.dirname(__file__)) -class IssueMatcher (RobustaBaseConfig): + +class IssueMatcher(BaseConfig): issue_id: Optional[Pattern] = None # unique id issue_name: Optional[Pattern] = None # not necessary unique source: Optional[Pattern] = None -class Runbook(RobustaBaseConfig): +class Runbook(BaseConfig): match: IssueMatcher instructions: str @@ -26,15 +27,18 @@ def set_path(self, path: str): def get_path(self) -> str: return self._path + class ListOfRunbooks(BaseModel): runbooks: List[Runbook] + def load_runbooks_from_file(path: str) -> List[Runbook]: data: ListOfRunbooks = load_model_from_file(ListOfRunbooks, file_path=path) for runbook in data.runbooks: runbook.set_path(path) return data.runbooks + def load_builtin_runbooks() -> List[Runbook]: all_runbooks = [] for filename in os.listdir(THIS_DIR): diff --git a/holmes/utils/pydantic_utils.py b/holmes/utils/pydantic_utils.py index 2ebe007e..cd7c0263 100644 --- a/holmes/utils/pydantic_utils.py +++ b/holmes/utils/pydantic_utils.py @@ -10,9 +10,15 @@ PromptField = Annotated[str, BeforeValidator(lambda v: load_prompt(v))] -class RobustaBaseConfig(BaseModel): + +class BaseConfig(BaseModel): model_config = ConfigDict(extra='forbid', validate_default=True) + +class EnvVarName(str): + pass + + def loc_to_dot_sep(loc: Tuple[Union[str, int], ...]) -> str: path = "" for i, x in enumerate(loc): @@ -43,7 +49,7 @@ def load_model_from_file( contents = contents[yaml_path] return model.model_validate(contents) except ValidationError as e: - print(e) + print(e) # FIXME bad_fields = [e["loc"] for e in convert_errors(e)] typer.secho( f"Invalid config file at {file_path}. Check the fields {bad_fields}.\nSee detailed errors above.", diff --git a/server.py b/server.py index 000a8621..1588c354 100644 --- a/server.py +++ b/server.py @@ -25,7 +25,7 @@ HOLMES_HOST, HOLMES_PORT, ) -from holmes.config import Config +from holmes.config import LLMConfig from holmes.core.issue import Issue from holmes.core.supabase_dal import AuthToken, SupabaseDal from holmes.plugins.prompts import load_prompt @@ -56,7 +56,6 @@ def init_logging(): logging_format = "%(log_color)s%(asctime)s.%(msecs)03d %(levelname)-8s %(message)s" logging_datefmt = "%Y-%m-%d %H:%M:%S" - print("setting up colored logging") colorlog.basicConfig(format=logging_format, level=logging_level, datefmt=logging_datefmt) logging.getLogger().setLevel(logging_level) @@ -68,12 +67,13 @@ def init_logging(): init_logging() +logging.info(f"Starting AI server with {AI_AGENT=}, {ROBUSTA_AI_URL=}") dal = SupabaseDal() session_manager = SessionManager(dal, "RelayHolmes") app = FastAPI() console = Console() -config = Config.load_from_env() +config = LLMConfig.load_from_env() @app.post("/api/investigate") diff --git a/test-api.sh b/test-api.sh index 49d4cc58..caf1495d 100755 --- a/test-api.sh +++ b/test-api.sh @@ -1,4 +1,4 @@ -curl -XPOST 127.0.0.1:8000/api/investigate -H "Content-Type: application/json" --data "{ +curl -XPOST 127.0.0.1:5050/api/investigate -H "Content-Type: application/json" --data "{ \"source\": \"prometheus\", \"source_instance_id\": \"some-instance\", \"title\": \"Pod is crash looping.\", From 08489abf0b20ecf65bb5aba6af8edd22742cbadf Mon Sep 17 00:00:00 2001 From: Robert Szefler Date: Tue, 11 Jun 2024 13:43:36 +0200 Subject: [PATCH 5/9] More LLM provider config refactors --- holmes.py | 43 +++-- holmes/common/env_vars.py | 5 - holmes/config.py | 242 +++++------------------- holmes/core/provider.py | 182 ++++++++++++++++++ holmes/core/robusta_ai.py | 54 
++++++ holmes/core/server_models.py | 21 ++ holmes/core/tool_calling_llm.py | 17 +- holmes/plugins/sources/jira/__init__.py | 7 +- server.py | 59 ++---- 9 files changed, 362 insertions(+), 268 deletions(-) create mode 100644 holmes/core/provider.py create mode 100644 holmes/core/robusta_ai.py create mode 100644 holmes/core/server_models.py diff --git a/holmes.py b/holmes.py index 765de8f7..20d551f9 100644 --- a/holmes.py +++ b/holmes.py @@ -2,18 +2,19 @@ # add_custom_certificate("cert goes here as a string (not path to the cert rather the cert itself)") import logging -import re import warnings from pathlib import Path -from typing import List, Optional, Pattern -import json +from typing import List, Optional + import typer from rich.console import Console from rich.logging import RichHandler from rich.markdown import Markdown from rich.rule import Rule + from holmes.utils.file_utils import write_json_file -from holmes.config import LLMConfig, LLMProviderType +from holmes.config import BaseLLMConfig, LLMProviderType +from holmes.core.provider import LLMProviderFactory from holmes.plugins.destinations import DestinationType from holmes.plugins.prompts import load_prompt from holmes.plugins.sources.opsgenie import OPSGENIE_TEAM_INTEGRATION_KEY_HELP @@ -30,9 +31,10 @@ # Common cli options +llm_provider_names = ", ".join(str(tp) for tp in LLMProviderType) opt_llm: Optional[LLMProviderType] = typer.Option( LLMProviderType.OPENAI, - help="LLM provider ('openai' or 'azure')", # TODO list all + help="LLM provider (supported values: {llm_provider_names})" ) opt_api_key: Optional[str] = typer.Option( None, @@ -143,7 +145,7 @@ def ask( Ask any question and answer using available tools """ console = init_logging(verbose) - config = LLMConfig.load_from_file( + config = BaseLLMConfig.load_from_file( config_file, api_key=api_key, llm=llm, @@ -152,8 +154,9 @@ def ask( max_steps=max_steps, custom_toolsets=custom_toolsets, ) + provider_factory = LLMProviderFactory(config) system_prompt = load_prompt(system_prompt) - ai = config.create_toolcalling_llm(console, allowed_toolsets) + ai = provider_factory.create_toolcalling_llm(console, allowed_toolsets) console.print("[bold yellow]User:[/bold yellow] " + prompt) response = ai.call(system_prompt, prompt) text_result = Markdown(response.result) @@ -162,7 +165,8 @@ def ask( if show_tool_output and response.tool_calls: for tool_call in response.tool_calls: console.print(f"[bold magenta]Used Tool:[/bold magenta]", end="") - # we need to print this separately with markup=False because it contains arbitrary text and we don't want console.print to interpret it + # we need to print this separately with markup=False because it contains arbitrary text + # and we don't want console.print to interpret it console.print(f"{tool_call.description}. 
Output=\n{tool_call.result}", markup=False) console.print(f"[bold green]AI:[/bold green]", end=" ") console.print(text_result, soft_wrap=True) @@ -205,7 +209,7 @@ def alertmanager( Investigate a Prometheus/Alertmanager alert """ console = init_logging(verbose) - config = LLMConfig.load_from_file( + config = BaseLLMConfig.load_from_file( config_file, api_key=api_key, llm=llm, @@ -221,14 +225,15 @@ def alertmanager( custom_toolsets=custom_toolsets, custom_runbooks=custom_runbooks ) + provider_factory = LLMProviderFactory(config) system_prompt = load_prompt(system_prompt) - ai = config.create_issue_investigator(console, allowed_toolsets) + ai = provider_factory.create_issue_investigator(console, allowed_toolsets) - source = config.create_alertmanager_source() + source = provider_factory.create_alertmanager_source() if destination == DestinationType.SLACK: - slack = config.create_slack_destination() + slack = provider_factory.create_slack_destination() try: issues = source.fetch_issues() @@ -303,7 +308,7 @@ def jira( Investigate a Jira ticket """ console = init_logging(verbose) - config = LLMConfig.load_from_file( + config = BaseLLMConfig.load_from_file( config_file, api_key=api_key, llm=llm, @@ -317,10 +322,11 @@ def jira( custom_toolsets=custom_toolsets, custom_runbooks=custom_runbooks ) + provider_factory = LLMProviderFactory(config) system_prompt = load_prompt(system_prompt) - ai = config.create_issue_investigator(console, allowed_toolsets) - source = config.create_jira_source() + ai = provider_factory.create_issue_investigator(console, allowed_toolsets) + source = provider_factory.create_jira_source() try: # TODO: allow passing issue ID issues = source.fetch_issues() @@ -392,7 +398,7 @@ def github( Investigate a GitHub issue """ console = init_logging(verbose) - config = LLMConfig.load_from_file( + config = BaseLLMConfig.load_from_file( config_file, api_key=api_key, llm=llm, @@ -407,10 +413,11 @@ def github( custom_toolsets=custom_toolsets, custom_runbooks=custom_runbooks ) + provider_factory = LLMProviderFactory(config) system_prompt = load_prompt(system_prompt) - ai = config.create_issue_investigator(console, allowed_toolsets) - source = config.create_github_source() + ai = provider_factory.create_issue_investigator(console, allowed_toolsets) + source = provider_factory.create_github_source() try: issues = source.fetch_issues() except Exception as e: diff --git a/holmes/common/env_vars.py b/holmes/common/env_vars.py index bed9e4df..1487135b 100644 --- a/holmes/common/env_vars.py +++ b/holmes/common/env_vars.py @@ -11,8 +11,3 @@ STORE_API_KEY = os.environ.get("STORE_API_KEY", "") STORE_EMAIL = os.environ.get("STORE_EMAIL", "") STORE_PASSWORD = os.environ.get("STORE_PASSWORD", "") - -# Currently supports BUILTIN and ROBUSTA_AI -AI_AGENT = os.environ.get("AI_AGENT", "BUILTIN") - -ROBUSTA_AI_URL = os.environ.get("ROBUSTA_AI_URL", "") diff --git a/holmes/config.py b/holmes/config.py index e6364370..73fe5394 100644 --- a/holmes/config.py +++ b/holmes/config.py @@ -2,26 +2,9 @@ import os import os.path from strenum import StrEnum -from typing import Annotated, Any, Dict, List, Optional, get_args, get_type_hints - -from openai import AzureOpenAI, OpenAI -from pydantic import FilePath, SecretStr -from pydash.arrays import concat -from rich.console import Console - -from holmes.core.runbooks import RunbookManager -from holmes.core.tool_calling_llm import (IssueInvestigator, ToolCallingLLM, - YAMLToolExecutor) -from holmes.core.tools import ToolsetPattern, get_matching_toolsets -from 
holmes.plugins.destinations.slack import SlackDestination -from holmes.plugins.runbooks import (load_builtin_runbooks, - load_runbooks_from_file) -from holmes.plugins.sources.github import GitHubSource -from holmes.plugins.sources.jira import JiraSource -from holmes.plugins.sources.opsgenie import OpsGenieSource -from holmes.plugins.sources.pagerduty import PagerDutySource -from holmes.plugins.sources.prometheus.plugin import AlertManagerSource -from holmes.plugins.toolsets import load_builtin_toolsets, load_toolsets_from_file +from typing import Annotated, Any, Dict, List, Optional, get_args, get_origin, get_type_hints + +from pydantic import SecretStr, FilePath from holmes.utils.pydantic_utils import BaseConfig, EnvVarName, load_model_from_file @@ -32,7 +15,7 @@ class LLMProviderType(StrEnum): class BaseLLMConfig(BaseConfig): - llm: LLMProviderType = LLMProviderType.OPENAI + llm_provider: LLMProviderType = LLMProviderType.OPENAI # FIXME: the following settings do not belong here. They define the # configuration of specific actions, and not of the LLM provider. @@ -79,29 +62,40 @@ def _collect_env_vars(cls) -> Dict[str, Any]: vars_dict = {} hints = get_type_hints(cls, include_extras=True) for field_name in cls.model_fields: - if field_name == "llm": + if field_name == "llm_provider": # Handled in load_from_env continue tp_obj = hints[field_name] - for arg in get_args(tp_obj): - if isinstance(arg, EnvVarName): - env_var_name = arg - break + if get_origin(tp_obj) is Annotated: + tp_args = get_args(tp_obj) + base_type = tp_args[0] + for arg in tp_args[1:]: + if isinstance(arg, EnvVarName): + env_var_name = arg + break + else: # no EnvVarName(...) in annotations + env_var_name = field_name.upper() else: + base_type = tp_obj env_var_name = field_name.upper() if env_var_name in os.environ: - vars_dict[field_name] = os.environ[env_var_name] + env_value = os.environ[env_var_name] + if get_origin(base_type) == list: + value = [value.strip() for value in env_value.split(",")] + else: + value = env_value + vars_dict[field_name] = value return vars_dict @classmethod def load_from_env(cls) -> "BaseLLMConfig": - llm_name = os.getenv("LLM_PROVIDER", "OPENAI").lower() - llm_provider_type = LLMProviderType(llm_name) - if llm_provider_type == LLMProviderType.AZURE: + llm_name = os.environ.get("LLM_PROVIDER", "OPENAI").lower() + llm_provider = LLMProviderType(llm_name) + if llm_provider == LLMProviderType.AZURE: final_class = AzureLLMConfig - elif llm_provider_type == LLMProviderType.OPENAI: + elif llm_provider == LLMProviderType.OPENAI: final_class = OpenAILLMConfig - elif llm_provider_type == LLMProviderType.ROBUSTA: + elif llm_provider == LLMProviderType.ROBUSTA: final_class = RobustaLLMConfig else: raise NotImplementedError(f"Unknown LLM {llm_name}") @@ -134,168 +128,9 @@ class RobustaLLMConfig(BaseLLMConfig): class LLMConfig(BaseLLMConfig): - def create_llm(self) -> OpenAI: - if self.llm == LLMProviderType.OPENAI: - return OpenAI( - api_key=self.api_key.get_secret_value() if self.api_key else None, - ) - elif self.llm == LLMProviderType.AZURE: - return AzureOpenAI( - api_key=self.api_key.get_secret_value() if self.api_key else None, - azure_endpoint=self.azure_endpoint, - api_version=self.azure_api_version, - ) - else: - raise ValueError(f"Unknown LLM type: {self.llm}") - - def _create_tool_executor( - self, console: Console, allowed_toolsets: ToolsetPattern - ) -> YAMLToolExecutor: - all_toolsets = load_builtin_toolsets() - for ts_path in self.custom_toolsets: - 
all_toolsets.extend(load_toolsets_from_file(ts_path)) - - if allowed_toolsets == "*": - matching_toolsets = all_toolsets - else: - matching_toolsets = get_matching_toolsets( - all_toolsets, allowed_toolsets.split(",") - ) - - enabled_toolsets = [ts for ts in matching_toolsets if ts.is_enabled()] - for ts in all_toolsets: - if ts not in matching_toolsets: - console.print( - f"[yellow]Disabling toolset {ts.name} [/yellow] from {ts.get_path()}" - ) - elif ts not in enabled_toolsets: - console.print( - f"[yellow]Not loading toolset {ts.name}[/yellow] ({ts.get_disabled_reason()})" - ) - #console.print(f"[red]The following tools will be disabled: {[t.name for t in ts.tools]}[/red])") - else: - logging.debug(f"Loaded toolset {ts.name} from {ts.get_path()}") - # console.print(f"[green]Loaded toolset {ts.name}[/green] from {ts.get_path()}") - - enabled_tools = concat(*[ts.tools for ts in enabled_toolsets]) - logging.debug( - f"Starting AI session with tools: {[t.name for t in enabled_tools]}" - ) - return YAMLToolExecutor(enabled_toolsets) - - def create_toolcalling_llm( - self, console: Console, allowed_toolsets: ToolsetPattern - ) -> ToolCallingLLM: - tool_executor = self._create_tool_executor(console, allowed_toolsets) - return ToolCallingLLM( - self.create_llm(), - self.model, - tool_executor, - self.max_steps, - ) - - def create_issue_investigator( - self, console: Console, allowed_toolsets: ToolsetPattern - ) -> IssueInvestigator: - all_runbooks = load_builtin_runbooks() - for runbook_path in self.custom_runbooks: - all_runbooks.extend(load_runbooks_from_file(runbook_path)) - - runbook_manager = RunbookManager(all_runbooks) - tool_executor = self._create_tool_executor(console, allowed_toolsets) - return IssueInvestigator( - self.create_llm(), - self.model, - tool_executor, - runbook_manager, - self.max_steps, - ) - - def create_jira_source(self) -> JiraSource: - if self.jira_url is None: - raise ValueError("--jira-url must be specified") - if not ( - self.jira_url.startswith("http://") or self.jira_url.startswith("https://") - ): - raise ValueError("--jira-url must start with http:// or https://") - if self.jira_username is None: - raise ValueError("--jira-username must be specified") - if self.jira_api_key is None: - raise ValueError("--jira-api-key must be specified") - - return JiraSource( - url=self.jira_url, - username=self.jira_username, - api_key=self.jira_api_key.get_secret_value(), - jql_query=self.jira_query, - ) - - def create_github_source(self) -> GitHubSource: - if not ( - self.github_url.startswith( - "http://") or self.github_url.startswith("https://") - ): - raise ValueError("--github-url must start with http:// or https://") - if self.github_owner is None: - raise ValueError("--github-owner must be specified") - if self.github_repository is None: - raise ValueError("--github-repository must be specified") - if self.github_pat is None: - raise ValueError("--github-pat must be specified") - - return GitHubSource( - url=self.github_url, - owner=self.github_owner, - pat=self.github_pat.get_secret_value(), - repository=self.github_repository, - query=self.github_query, - ) - - def create_pagerduty_source(self) -> OpsGenieSource: - if self.pagerduty_api_key is None: - raise ValueError("--pagerduty-api-key must be specified") - - return PagerDutySource( - api_key=self.pagerduty_api_key.get_secret_value(), - user_email=self.pagerduty_user_email, - incident_key=self.pagerduty_incident_key, - ) - - def create_opsgenie_source(self) -> OpsGenieSource: - if self.opsgenie_api_key is None: 
- raise ValueError("--opsgenie-api-key must be specified") - - return OpsGenieSource( - api_key=self.opsgenie_api_key.get_secret_value(), - query=self.opsgenie_query, - team_integration_key=self.opsgenie_team_integration_key.get_secret_value() if self.opsgenie_team_integration_key else None, - ) - - def create_alertmanager_source(self) -> AlertManagerSource: - if self.alertmanager_url is None: - raise ValueError("--alertmanager-url must be specified") - if not ( - self.alertmanager_url.startswith("http://") - or self.alertmanager_url.startswith("https://") - ): - raise ValueError("--alertmanager-url must start with http:// or https://") - - return AlertManagerSource( - url=self.alertmanager_url, - username=self.alertmanager_username, - password=self.alertmanager_password, - alertname=self.alertmanager_alertname, - ) - - def create_slack_destination(self): - if self.slack_token is None: - raise ValueError("--slack-token must be specified") - if self.slack_channel is None: - raise ValueError("--slack-channel must be specified") - return SlackDestination(self.slack_token.get_secret_value(), self.slack_channel) - @classmethod - def load_from_file(cls, config_file: Optional[str], **kwargs) -> "Config": + def load_from_file(cls, config_file: Optional[str], **kwargs) -> "BaseLLMConfig": + # FIXME! if config_file is not None: logging.debug("Loading config from file %s", config_file) config_from_file = load_model_from_file(cls, config_file) @@ -313,7 +148,26 @@ def load_from_file(cls, config_file: Optional[str], **kwargs) -> "Config": merged_config = config_from_file.dict() # remove Nones to avoid overriding config file with empty cli args cli_overrides = { - k: v for k, v in config_from_cli.dict().items() if v is not None and v != [] + k: v for k, v in config_from_cli.model_dump().items() if v is not None and v != [] } merged_config.update(cli_overrides) return cls(**merged_config) + + +class BaseOpenAIConfig(BaseLLMConfig): + model: Annotated[Optional[str], EnvVarName("AI_MODEL")] = "gpt-4o" + max_steps: Optional[int] = 10 + + +class OpenAILLMConfig(BaseOpenAIConfig): + api_key: Annotated[Optional[SecretStr], EnvVarName("OPENAI_API_KEY")] + + +class AzureLLMConfig(BaseOpenAIConfig): + api_key: Annotated[Optional[SecretStr], EnvVarName("AZURE_API_KEY")] + endpoint: Annotated[Optional[str], EnvVarName("AZURE_ENDPOINT")] + azure_api_version: Optional[str] = "2024-02-01" + + +class RobustaLLMConfig(BaseOpenAIConfig): + url: Annotated[str, EnvVarName("ROBUSTA_AI_URL")] diff --git a/holmes/core/provider.py b/holmes/core/provider.py new file mode 100644 index 00000000..0d281ece --- /dev/null +++ b/holmes/core/provider.py @@ -0,0 +1,182 @@ +import logging + +from openai import AzureOpenAI, OpenAI +from pydash.arrays import concat +from rich.console import Console + +from holmes.config import BaseLLMConfig, LLMProviderType +from holmes.core.robusta_ai import RobustaAIToolCallingLLM +from holmes.core.runbooks import RunbookManager +from holmes.core.tool_calling_llm import ( + BaseToolCallingLLM, + IssueInvestigator, + OpenAIToolCallingLLM, + YAMLToolExecutor, +) +from holmes.core.tools import ToolsetPattern, get_matching_toolsets +from holmes.plugins.destinations.slack import SlackDestination +from holmes.plugins.runbooks import load_builtin_runbooks, load_runbooks_from_file +from holmes.plugins.sources.jira import JiraSource +from holmes.plugins.sources.github import GitHubSource +from holmes.plugins.sources.prometheus.plugin import AlertManagerSource +from holmes.plugins.toolsets import 
load_builtin_toolsets, load_toolsets_from_file +from holmes.utils.auth import SessionManager + + +class LLMProviderFactory: + def __init__(self, config: BaseLLMConfig, session_manager: SessionManager = None): + self.config = config + self.session_manager = session_manager + + def create_llm(self) -> OpenAI: + if self.config.llm_provider == LLMProviderType.OPENAI: + return OpenAI( + api_key=self.config.api_key.get_secret_value() if self.config.api_key else None, + ) + elif self.config.llm_provider == LLMProviderType.AZURE: + return AzureOpenAI( + api_key=self.config.api_key.get_secret_value() if self.config.api_key else None, + azure_endpoint=self.config.azure_endpoint, + api_version=self.config.azure_api_version, + ) + else: + raise ValueError(f"Unknown LLM type: {self.config.llm_provider}") + + def create_toolcalling_llm(self, console: Console, allowed_toolsets: ToolsetPattern) -> BaseToolCallingLLM: + if self.config.llm_provider in [LLMProviderType.OPENAI, LLMProviderType.AZURE]: + tool_executor = self._create_tool_executor(console, allowed_toolsets) + return OpenAIToolCallingLLM( + self.create_llm(), + self.config.model, + tool_executor, + self.config.max_steps, + ) + else: + return RobustaAIToolCallingLLM( + self.config, self.session_manager + ) + + def create_issue_investigator(self, console: Console, allowed_toolsets: ToolsetPattern) -> IssueInvestigator: + all_runbooks = load_builtin_runbooks() + for runbook_path in self.config.custom_runbooks: + all_runbooks.extend(load_runbooks_from_file(runbook_path)) + + runbook_manager = RunbookManager(all_runbooks) + tool_executor = self._create_tool_executor(console, allowed_toolsets) + return IssueInvestigator( + self.create_llm(), + self.config.model, + tool_executor, + runbook_manager, + self.config.max_steps, + ) + + def create_jira_source(self) -> JiraSource: + if self.config.jira_url is None: + raise ValueError("--jira-url must be specified") + if not (self.config.jira_url.startswith("http://") or self.config.jira_url.startswith("https://")): + raise ValueError("--jira-url must start with http:// or https://") + if self.config.jira_username is None: + raise ValueError("--jira-username must be specified") + if self.config.jira_api_key is None: + raise ValueError("--jira-api-key must be specified") + + return JiraSource( + url=self.config.jira_url, + username=self.config.jira_username, + api_key=self.config.jira_api_key.get_secret_value(), + jql_query=self.config.jira_query, + ) + + def create_github_source(self) -> GitHubSource: + if not (self.config.github_url.startswith("http://") or self.config.github_url.startswith("https://")): + raise ValueError("--github-url must start with http:// or https://") + if self.config.github_owner is None: + raise ValueError("--github-owner must be specified") + if self.config.github_repository is None: + raise ValueError("--github-repository must be specified") + if self.config.github_pat is None: + raise ValueError("--github-pat must be specified") + + return GitHubSource( + url=self.config.github_url, + owner=self.config.github_owner, + pat=self.config.github_pat.get_secret_value(), + repository=self.config.github_repository, + query=self.config.github_query, + ) + + def create_pagerduty_source(self) -> PagerDutySource: + if self.pagerduty_api_key is None: + raise ValueError("--pagerduty-api-key must be specified") + + return PagerDutySource( + api_key=self.pagerduty_api_key.get_secret_value(), + user_email=self.pagerduty_user_email, + incident_key=self.pagerduty_incident_key, + ) + + def 
create_opsgenie_source(self) -> OpsGenieSource: + if self.opsgenie_api_key is None: + raise ValueError("--opsgenie-api-key must be specified") + + return OpsGenieSource( + api_key=self.opsgenie_api_key.get_secret_value(), + query=self.opsgenie_query, + team_integration_key=self.opsgenie_team_integration_key.get_secret_value() if self.opsgenie_team_integration_key else None, + ) + + def create_alertmanager_source(self) -> AlertManagerSource: + if self.config.alertmanager_url is None: + raise ValueError("--alertmanager-url must be specified") + if not ( + self.config.alertmanager_url.startswith("http://") or self.config.alertmanager_url.startswith("https://") + ): + raise ValueError("--alertmanager-url must start with http:// or https://") + + return AlertManagerSource( + url=self.config.alertmanager_url, + username=self.config.alertmanager_username, + password=self.config.alertmanager_password, + alertname=self.alertmanager_alertname, + ) + + def create_slack_destination(self): + if self.config.slack_token is None: + raise ValueError("--slack-token must be specified") + if self.config.slack_channel is None: + raise ValueError("--slack-channel must be specified") + return SlackDestination(self.config.slack_token.get_secret_value(), self.config.slack_channel) + + def _create_tool_executor(self, console: Console, allowed_toolsets: ToolsetPattern) -> YAMLToolExecutor: + all_toolsets = load_builtin_toolsets() + for ts_path in self.custom_toolsets: + all_toolsets.extend(load_toolsets_from_file(ts_path)) + + if allowed_toolsets == "*": + matching_toolsets = all_toolsets + else: + matching_toolsets = get_matching_toolsets( + all_toolsets, allowed_toolsets.split(",") + ) + + enabled_toolsets = [ts for ts in matching_toolsets if ts.is_enabled()] + for ts in all_toolsets: + if ts not in matching_toolsets: + console.print( + f"[yellow]Disabling toolset {ts.name} [/yellow] from {ts.get_path()}" + ) + elif ts not in enabled_toolsets: + console.print( + f"[yellow]Not loading toolset {ts.name}[/yellow] ({ts.get_disabled_reason()})" + ) + #console.print(f"[red]The following tools will be disabled: {[t.name for t in ts.tools]}[/red])") + else: + logging.debug(f"Loaded toolset {ts.name} from {ts.get_path()}") + # console.print(f"[green]Loaded to olset {ts.name}[/green] from {ts.get_path()}") + + enabled_tools = concat(*[ts.tools for ts in enabled_toolsets]) + logging.debug( + f"Starting AI session with tools: {[t.name for t in enabled_tools]}" + ) + return YAMLToolExecutor(enabled_toolsets) diff --git a/holmes/core/robusta_ai.py b/holmes/core/robusta_ai.py new file mode 100644 index 00000000..1d4c06b4 --- /dev/null +++ b/holmes/core/robusta_ai.py @@ -0,0 +1,54 @@ +# TODO finish refactor +import logging + +import jinja2 +import requests +from fastapi import HTTPException + +from holmes.core.tool_calling_llm import BaseToolCallingLLM, LLMResult +from holmes.plugins.prompts import load_prompt +from holmes.utils.auth import SessionManager +from holmes.core.server_models import InvestigateRequest + + +class RobustaAIToolCallingLLM(BaseToolCallingLLM): + def __init__(self, base_url: str, session_manager: SessionManager): + self.base_url = base_url + self.session_manager = session_manager + + def call(self, system_prompt: str, user_prompt: str) -> LLMResult: + pass + + def run_analysis(self, request: InvestigateRequest, issue): + # TODO refactor + """Delegate the AI analysis to Robusta AI running as a + separate service.""" + environment = jinja2.Environment() + sys_prompt_template = 
environment.from_string(load_prompt(request.system_prompt)) + # TODO what about runbooks? + sys_prompt = sys_prompt_template.render(issue=issue, runbooks=[]) + # TODO do we want a new token each time? + auth_token = self.session_manager.get_current_token() + payload = { + "auth": { + "account_id": auth_token.account_id, + "token": auth_token.token + }, + "system_message": sys_prompt, + "user_message": str(request.model_dump()), + "model": request.model, + } + resp = requests.post(self.base_url + "/api/ai", json=payload) + if resp.status_code == 401: + self.session_manager.invalidate_token(auth_token) + # Attempt auth again using a fresh token + auth_token = self.session_manager.create_token() + payload["auth"]["account_id"] = auth_token.account_id + payload["auth"]["token"] = auth_token.token + resp = requests.post(self.base_url + "/api/ai", json=payload) + if resp.status_code != 200: + logging.error(f"Failed to reauth with Robusta AI. Response status {resp.status_code}, content: {resp.text}") + raise HTTPException(status_code=400, detail="Unable to auth with Robusta AI") + # TODO reformat Robusta AI response to conform to the expected Holmes response + # format. + return resp.json() diff --git a/holmes/core/server_models.py b/holmes/core/server_models.py new file mode 100644 index 00000000..dc5b332e --- /dev/null +++ b/holmes/core/server_models.py @@ -0,0 +1,21 @@ +from typing import Union, List + +from pydantic import BaseModel + + +class InvestigateContext(BaseModel): + type: str + value: Union[str, dict] + + +class InvestigateRequest(BaseModel): + source: str # "prometheus" etc + title: str + description: str + subject: dict + context: List[InvestigateContext] + source_instance_id: str + model: str = "gpt-4o" + system_prompt: str = "builtin://generic_investigation.jinja2" + # TODO in the future + # response_handler: ... 
diff --git a/holmes/core/tool_calling_llm.py b/holmes/core/tool_calling_llm.py index 7f05b7fd..eba7cc5f 100644 --- a/holmes/core/tool_calling_llm.py +++ b/holmes/core/tool_calling_llm.py @@ -1,14 +1,12 @@ -import datetime +import abc import json import logging import textwrap -from typing import Dict, Generator, List, Optional +from typing import List, Optional import jinja2 from openai import BadRequestError, OpenAI from openai._types import NOT_GIVEN -from openai.types.chat.chat_completion import Choice -from openai.types.chat.chat_completion_chunk import ChatCompletionChunk from pydantic import BaseModel from rich.console import Console @@ -35,8 +33,13 @@ def get_tool_usage_summary(self): ) -class ToolCallingLLM: +class BaseToolCallingLLM(abc.ABC): + @abc.abstractmethod + def call(self, system_prompt: str, user_prompt: str) -> LLMResult: + raise NotImplementedError + +class OpenAIToolCallingLLM(BaseToolCallingLLM): def __init__( self, client: OpenAI, @@ -49,7 +52,7 @@ def __init__( self.max_steps = max_steps self.model = model - def call(self, system_prompt, user_prompt) -> LLMResult: + def call(self, system_prompt: str, user_prompt: str) -> LLMResult: messages = [ { "role": "system", @@ -131,7 +134,7 @@ def call(self, system_prompt, user_prompt) -> LLMResult: # TODO: consider getting rid of this entirely and moving templating into the cmds in holmes.py -class IssueInvestigator(ToolCallingLLM): +class IssueInvestigator(OpenAIToolCallingLLM): """ Thin wrapper around ToolCallingLLM which: 1) Provides a default prompt for RCA diff --git a/holmes/plugins/sources/jira/__init__.py b/holmes/plugins/sources/jira/__init__.py index 764354cd..d25915ee 100644 --- a/holmes/plugins/sources/jira/__init__.py +++ b/holmes/plugins/sources/jira/__init__.py @@ -1,15 +1,12 @@ import logging -from typing import List, Literal, Optional, Pattern +from typing import List, Pattern -import humanize import requests -from pydantic import BaseModel, SecretStr, ValidationError, parse_obj_as, validator from requests.auth import HTTPBasicAuth from holmes.core.issue import Issue -from holmes.core.tool_calling_llm import LLMResult, ToolCallingLLM, ToolCallResult +from holmes.core.tool_calling_llm import LLMResult from holmes.plugins.interfaces import SourcePlugin -from holmes.plugins.utils import dict_to_markdown class JiraSource(SourcePlugin): diff --git a/server.py b/server.py index 1588c354..1dba2e7d 100644 --- a/server.py +++ b/server.py @@ -1,8 +1,5 @@ import os -import jinja2 - -from holmes.utils.auth import SessionManager from holmes.utils.cert_utils import add_custom_certificate ADDITIONAL_CERTIFICATE: str = os.environ.get("CERTIFICATE", "") @@ -17,38 +14,21 @@ import colorlog import uvicorn -from fastapi import FastAPI, HTTPException -from pydantic import BaseModel +from fastapi import FastAPI from rich.console import Console from holmes.common.env_vars import ( + ALLOWED_TOOLSETS, HOLMES_HOST, HOLMES_PORT, ) -from holmes.config import LLMConfig +from holmes.config import BaseLLMConfig from holmes.core.issue import Issue -from holmes.core.supabase_dal import AuthToken, SupabaseDal +from holmes.core.provider import LLMProviderFactory +from holmes.core.server_models import InvestigateContext, InvestigateRequest +from holmes.core.supabase_dal import SupabaseDal from holmes.plugins.prompts import load_prompt - - -class InvestigateContext(BaseModel): - type: str - value: Union[str, dict] - - -class InvestigateRequest(BaseModel): - source: str # "prometheus" etc - title: str - description: str - subject: dict - 
context: List[InvestigateContext] - source_instance_id: str - include_tool_calls: bool = False - include_tool_call_results: bool = False - prompt_template: str = "builtin://generic_investigation.jinja2" - model: str = "gpt-4o" - # TODO in the future - # response_handler: ... +from holmes.utils.auth import SessionManager def init_logging(): @@ -67,13 +47,22 @@ def init_logging(): init_logging() -logging.info(f"Starting AI server with {AI_AGENT=}, {ROBUSTA_AI_URL=}") +config = BaseLLMConfig.load_from_env() +logging.info(f"Starting AI server with config: {config}") dal = SupabaseDal() session_manager = SessionManager(dal, "RelayHolmes") +provider_factory = LLMProviderFactory(config, session_manager=session_manager) app = FastAPI() console = Console() -config = LLMConfig.load_from_env() + + +def fetch_context_data(context: List[InvestigateContext]) -> dict: + for context_item in context: + if context_item.type == "robusta_issue_id": + # Note we only accept a single robusta_issue_id. I don't think it + # makes sense to have several of them in the context structure. + return dal.get_issue_data(context_item.value) @app.post("/api/investigate") @@ -92,9 +81,9 @@ def investigate_issue(request: InvestigateRequest): source_instance_id=request.source_instance_id, raw=raw_data, ) - investigation = ai.investigate( + investigator = provider_factory.create_issue_investigator(console, allowed_toolsets=ALLOWED_TOOLSETS) + investigation = investigator.investigate( issue, - # TODO prompt should probably be configurable? prompt=load_prompt(request.prompt), console=console, ) @@ -114,13 +103,5 @@ def investigate_issue(request: InvestigateRequest): return ret -def fetch_context_data(context: List[InvestigateContext]) -> dict: - for context_item in context: - if context_item.type == "robusta_issue_id": - # Note we only accept a single robusta_issue_id. I don't think it - # makes sense to have several of them in the context structure. 
- return dal.get_issue_data(context_item.value) - - if __name__ == "__main__": uvicorn.run(app, host=HOLMES_HOST, port=HOLMES_PORT) From f62ef46b9e50943d2bab1f4f75e56e282f6d0fd4 Mon Sep 17 00:00:00 2001 From: Robert Szefler Date: Wed, 12 Jun 2024 11:56:08 +0200 Subject: [PATCH 6/9] Robusta AI calling fixes --- holmes/core/provider.py | 37 +++++++++--------- holmes/core/robusta_ai.py | 44 +++++++++------------- holmes/core/supabase_dal.py | 2 + holmes/core/tool_calling_llm.py | 67 ++++++++++++++------------------- server.py | 15 ++++++-- 5 files changed, 80 insertions(+), 85 deletions(-) diff --git a/holmes/core/provider.py b/holmes/core/provider.py index 0d281ece..c55a860c 100644 --- a/holmes/core/provider.py +++ b/holmes/core/provider.py @@ -5,11 +5,12 @@ from rich.console import Console from holmes.config import BaseLLMConfig, LLMProviderType -from holmes.core.robusta_ai import RobustaAIToolCallingLLM +from holmes.core.robusta_ai import RobustaAIToolCallingLLM, RobustaIssueInvestigator from holmes.core.runbooks import RunbookManager from holmes.core.tool_calling_llm import ( + BaseIssueInvestigator, BaseToolCallingLLM, - IssueInvestigator, + OpenAIIssueInvestigator, OpenAIToolCallingLLM, YAMLToolExecutor, ) @@ -31,11 +32,11 @@ def __init__(self, config: BaseLLMConfig, session_manager: SessionManager = None def create_llm(self) -> OpenAI: if self.config.llm_provider == LLMProviderType.OPENAI: return OpenAI( - api_key=self.config.api_key.get_secret_value() if self.config.api_key else None, + api_key=(self.config.api_key.get_secret_value() if self.config.api_key else None), ) elif self.config.llm_provider == LLMProviderType.AZURE: return AzureOpenAI( - api_key=self.config.api_key.get_secret_value() if self.config.api_key else None, + api_key=(self.config.api_key.get_secret_value() if self.config.api_key else None), azure_endpoint=self.config.azure_endpoint, api_version=self.config.azure_api_version, ) @@ -52,24 +53,26 @@ def create_toolcalling_llm(self, console: Console, allowed_toolsets: ToolsetPatt self.config.max_steps, ) else: - return RobustaAIToolCallingLLM( - self.config, self.session_manager - ) + # TODO in the future + return RobustaAIToolCallingLLM() - def create_issue_investigator(self, console: Console, allowed_toolsets: ToolsetPattern) -> IssueInvestigator: + def create_issue_investigator(self, console: Console, allowed_toolsets: ToolsetPattern) -> BaseIssueInvestigator: all_runbooks = load_builtin_runbooks() for runbook_path in self.config.custom_runbooks: all_runbooks.extend(load_runbooks_from_file(runbook_path)) - runbook_manager = RunbookManager(all_runbooks) - tool_executor = self._create_tool_executor(console, allowed_toolsets) - return IssueInvestigator( - self.create_llm(), - self.config.model, - tool_executor, - runbook_manager, - self.config.max_steps, - ) + + if self.config.llm_provider == LLMProviderType.ROBUSTA: + return RobustaIssueInvestigator(self.config.url, self.session_manager, runbook_manager) + else: + tool_executor = self._create_tool_executor(console, allowed_toolsets) + return OpenAIIssueInvestigator( + self.create_llm(), + self.config.model, + tool_executor, + runbook_manager, + self.config.max_steps, + ) def create_jira_source(self) -> JiraSource: if self.config.jira_url is None: diff --git a/holmes/core/robusta_ai.py b/holmes/core/robusta_ai.py index 1d4c06b4..a0cc86df 100644 --- a/holmes/core/robusta_ai.py +++ b/holmes/core/robusta_ai.py @@ -1,42 +1,33 @@ # TODO finish refactor import logging -import jinja2 import requests from fastapi import 
HTTPException -from holmes.core.tool_calling_llm import BaseToolCallingLLM, LLMResult -from holmes.plugins.prompts import load_prompt +from holmes.core.runbooks import RunbookManager +from holmes.core.tool_calling_llm import BaseIssueInvestigator, BaseToolCallingLLM, LLMResult from holmes.utils.auth import SessionManager -from holmes.core.server_models import InvestigateRequest class RobustaAIToolCallingLLM(BaseToolCallingLLM): - def __init__(self, base_url: str, session_manager: SessionManager): + def __init__(self): + raise NotImplementedError("Robusta AI tool calling LLM is not supported yet") + + +class RobustaIssueInvestigator(BaseIssueInvestigator): + def __init__(self, base_url: str, session_manager: SessionManager, runbook_manager: RunbookManager): self.base_url = base_url self.session_manager = session_manager + self.runbook_manager = runbook_manager def call(self, system_prompt: str, user_prompt: str) -> LLMResult: - pass - - def run_analysis(self, request: InvestigateRequest, issue): - # TODO refactor - """Delegate the AI analysis to Robusta AI running as a - separate service.""" - environment = jinja2.Environment() - sys_prompt_template = environment.from_string(load_prompt(request.system_prompt)) - # TODO what about runbooks? - sys_prompt = sys_prompt_template.render(issue=issue, runbooks=[]) - # TODO do we want a new token each time? auth_token = self.session_manager.get_current_token() payload = { - "auth": { - "account_id": auth_token.account_id, - "token": auth_token.token - }, - "system_message": sys_prompt, - "user_message": str(request.model_dump()), - "model": request.model, + "auth": {"account_id": auth_token.account_id, "token": auth_token.token}, + "system_message": system_prompt, + "user_message": user_prompt, +# TODO +# "model": request.model, } resp = requests.post(self.base_url + "/api/ai", json=payload) if resp.status_code == 401: @@ -47,8 +38,9 @@ def run_analysis(self, request: InvestigateRequest, issue): payload["auth"]["token"] = auth_token.token resp = requests.post(self.base_url + "/api/ai", json=payload) if resp.status_code != 200: - logging.error(f"Failed to reauth with Robusta AI. Response status {resp.status_code}, content: {resp.text}") + logging.error( + f"Failed to reauth with Robusta AI. Response status {resp.status_code}, content: {resp.text}" + ) raise HTTPException(status_code=400, detail="Unable to auth with Robusta AI") - # TODO reformat Robusta AI response to conform to the expected Holmes response - # format. + # TODO LLMResult return resp.json() diff --git a/holmes/core/supabase_dal.py b/holmes/core/supabase_dal.py index e2b7d06f..60e77fa9 100644 --- a/holmes/core/supabase_dal.py +++ b/holmes/core/supabase_dal.py @@ -58,11 +58,13 @@ def __init__(self): self.enabled = self.__init_config() if not self.enabled: logging.info("Robusta store initialization parameters not provided. 
skipping") + self.initialized = False return logging.info(f"Initializing robusta store for account {self.account_id}, user {self.user_id}") options = ClientOptions(postgrest_client_timeout=SUPABASE_TIMEOUT_SECONDS) self.client = create_client(self.url, self.api_key, options) self.sign_in() + self.initialized = True @staticmethod def __load_robusta_config() -> Optional[RobustaToken]: diff --git a/holmes/core/tool_calling_llm.py b/holmes/core/tool_calling_llm.py index eba7cc5f..cb64c06e 100644 --- a/holmes/core/tool_calling_llm.py +++ b/holmes/core/tool_calling_llm.py @@ -72,11 +72,7 @@ def call(self, system_prompt: str, user_prompt: str) -> LLMResult: tool_choice = NOT_GIVEN if tools == NOT_GIVEN else "auto" try: full_response = self.client.chat.completions.create( - model=self.model, - messages=messages, - tools=tools, - tool_choice=tool_choice, - temperature=0.00000001 + model=self.model, messages=messages, tools=tools, tool_choice=tool_choice, temperature=0.00000001 ) logging.debug(f"got response {full_response}") # catch a known error that occurs with Azure and replace the error message with something more obvious to the user @@ -89,11 +85,7 @@ def call(self, system_prompt: str, user_prompt: str) -> LLMResult: raise response = full_response.choices[0] response_message = response.message - messages.append( - response_message.model_dump( - exclude_defaults=True, exclude_unset=True, exclude_none=True - ) - ) + messages.append(response_message.model_dump(exclude_defaults=True, exclude_unset=True, exclude_none=True)) tools_to_call = response_message.tool_calls if not tools_to_call: @@ -112,9 +104,11 @@ def call(self, system_prompt: str, user_prompt: str) -> LLMResult: tool_params = json.loads(t.function.arguments) tool = self.tool_executor.get_tool_by_name(tool_name) tool_response = tool.invoke(tool_params) - MAX_CHARS = 100_000 # an arbitrary limit - we will do something smarter in the future + MAX_CHARS = 100_000 # an arbitrary limit - we will do something smarter in the future if len(tool_response) > MAX_CHARS: - logging.warning(f"tool {tool_name} returned a very long response ({len(tool_response)} chars) - truncating to last 10000 chars") + logging.warning( + f"tool {tool_name} returned a very long response ({len(tool_response)} chars) - truncating to last 10000 chars" + ) tool_response = tool_response[-MAX_CHARS:] messages.append( { @@ -133,8 +127,29 @@ def call(self, system_prompt: str, user_prompt: str) -> LLMResult: ) -# TODO: consider getting rid of this entirely and moving templating into the cmds in holmes.py -class IssueInvestigator(OpenAIToolCallingLLM): +# TODO: consider getting rid of this entirely and moving templating into the cmds in holmes.py +class BaseIssueInvestigator: + def call(self, system_prompt: str, user_prompt: str) -> LLMResult: + raise NotImplementedError() + + def investigate(self, issue: Issue, prompt: str, console: Console) -> LLMResult: + environment = jinja2.Environment() + system_prompt_template = environment.from_string(prompt) + runbooks = self.runbook_manager.get_instructions_for_issue(issue) + if runbooks: + console.print(f"[bold]Analyzing with {len(runbooks)} runbooks: {runbooks}[/bold]") + else: + console.print( + f"[bold]No runbooks found for this issue. Using default behaviour. 
(Add runbooks to guide the investigation.)[/bold]" + ) + system_prompt = system_prompt_template.render(issue=issue, runbooks=runbooks) + user_prompt = f"{issue.raw}" + logging.debug("Rendered system prompt:\n%s", textwrap.indent(system_prompt, " ")) + logging.debug("Rendered user prompt:\n%s", textwrap.indent(user_prompt, " ")) + return self.call(system_prompt, user_prompt) + + +class OpenAIIssueInvestigator(BaseIssueInvestigator, OpenAIToolCallingLLM): """ Thin wrapper around ToolCallingLLM which: 1) Provides a default prompt for RCA @@ -152,27 +167,3 @@ def __init__( ): super().__init__(client, model, tool_executor, max_steps) self.runbook_manager = runbook_manager - - def investigate( - self, issue: Issue, prompt: str, console: Console - ) -> LLMResult: - environment = jinja2.Environment() - system_prompt_template = environment.from_string(prompt) - runbooks = self.runbook_manager.get_instructions_for_issue(issue) - if runbooks: - console.print( - f"[bold]Analyzing with {len(runbooks)} runbooks: {runbooks}[/bold]" - ) - else: - console.print( - f"[bold]No runbooks found for this issue. Using default behaviour. (Add runbooks to guide the investigation.)[/bold]" - ) - system_prompt = system_prompt_template.render(issue=issue, runbooks=runbooks) - user_prompt = f"{issue.raw}" - logging.debug( - "Rendered system prompt:\n%s", textwrap.indent(system_prompt, " ") - ) - logging.debug( - "Rendered user prompt:\n%s", textwrap.indent(user_prompt, " ") - ) - return self.call(system_prompt, user_prompt) diff --git a/server.py b/server.py index 1dba2e7d..457a497d 100644 --- a/server.py +++ b/server.py @@ -1,4 +1,5 @@ import os +import sys from holmes.utils.cert_utils import add_custom_certificate @@ -22,7 +23,7 @@ HOLMES_HOST, HOLMES_PORT, ) -from holmes.config import BaseLLMConfig +from holmes.config import BaseLLMConfig, LLMProviderType from holmes.core.issue import Issue from holmes.core.provider import LLMProviderFactory from holmes.core.server_models import InvestigateContext, InvestigateRequest @@ -36,7 +37,9 @@ def init_logging(): logging_format = "%(log_color)s%(asctime)s.%(msecs)03d %(levelname)-8s %(message)s" logging_datefmt = "%Y-%m-%d %H:%M:%S" - colorlog.basicConfig(format=logging_format, level=logging_level, datefmt=logging_datefmt) + colorlog.basicConfig( + format=logging_format, level=logging_level, datefmt=logging_datefmt + ) logging.getLogger().setLevel(logging_level) httpx_logger = logging.getLogger("httpx") @@ -47,14 +50,18 @@ def init_logging(): init_logging() +console = Console() config = BaseLLMConfig.load_from_env() logging.info(f"Starting AI server with config: {config}") dal = SupabaseDal() + +if not dal.initialized and config.llm_provider == LLMProviderType.ROBUSTA: + logging.error("Holmes cannot run without store configuration when the LLM provider is Robusta AI") + sys.exit(1) session_manager = SessionManager(dal, "RelayHolmes") -provider_factory = LLMProviderFactory(config, session_manager=session_manager) +provider_factory = LLMProviderFactory(config, session_manager) app = FastAPI() -console = Console() def fetch_context_data(context: List[InvestigateContext]) -> dict: From ae1540a5bc73d7023602665c6d3972b723bff7e0 Mon Sep 17 00:00:00 2001 From: Robert Szefler Date: Wed, 12 Jun 2024 12:30:42 +0200 Subject: [PATCH 7/9] token handling fixes --- holmes/core/robusta_ai.py | 3 +++ holmes/core/supabase_dal.py | 11 +++++++---- server.py | 6 +----- 3 files changed, 11 insertions(+), 9 deletions(-) diff --git a/holmes/core/robusta_ai.py b/holmes/core/robusta_ai.py index 
a0cc86df..dd53315f 100644 --- a/holmes/core/robusta_ai.py +++ b/holmes/core/robusta_ai.py @@ -22,6 +22,9 @@ def __init__(self, base_url: str, session_manager: SessionManager, runbook_manag def call(self, system_prompt: str, user_prompt: str) -> LLMResult: auth_token = self.session_manager.get_current_token() + if auth_token is None: + auth_token = self.session_manager.create_token() + payload = { "auth": {"account_id": auth_token.account_id, "token": auth_token.token}, "system_message": system_prompt, diff --git a/holmes/core/supabase_dal.py b/holmes/core/supabase_dal.py index 60e77fa9..95e6bc95 100644 --- a/holmes/core/supabase_dal.py +++ b/holmes/core/supabase_dal.py @@ -172,7 +172,7 @@ def create_auth_token(self, token_type: str) -> AuthToken: { "account_id": self.account_id, "user_id": self.user_id, - "token": uuid4(), + "token": str(uuid4()), "type": token_type, } ) @@ -180,17 +180,20 @@ def create_auth_token(self, token_type: str) -> AuthToken: ) return AuthToken(**result.data[0]) - def get_freshest_auth_token(self, token_type: str) -> AuthToken: + def get_freshest_auth_token(self, token_type: str) -> Optional[AuthToken]: result = ( self.client.table(TOKENS_TABLE) .select("*") - .filter("token_type", "eq", token_type) + .filter("type", "eq", token_type) .filter("deleted", "eq", False) .order("created_at", desc=True) .limit(1) .execute() ) - return AuthToken(**result.data[0]) + if not result.data: + return None + else: + return AuthToken(**result.data[0]) def invalidate_auth_token(self, token: AuthToken) -> None: ( diff --git a/server.py b/server.py index 457a497d..9ce8cf34 100644 --- a/server.py +++ b/server.py @@ -18,11 +18,7 @@ from fastapi import FastAPI from rich.console import Console -from holmes.common.env_vars import ( - ALLOWED_TOOLSETS, - HOLMES_HOST, - HOLMES_PORT, -) +from holmes.common.env_vars import ALLOWED_TOOLSETS, HOLMES_HOST, HOLMES_PORT from holmes.config import BaseLLMConfig, LLMProviderType from holmes.core.issue import Issue from holmes.core.provider import LLMProviderFactory From 1065a32c7475d45b6f8e9818a9732229fa987e22 Mon Sep 17 00:00:00 2001 From: Robert Szefler Date: Wed, 12 Jun 2024 15:50:59 +0200 Subject: [PATCH 8/9] some minor fixes; better error handling --- holmes/core/robusta_ai.py | 32 +++++++++++++++++++++----------- holmes/core/tool_calling_llm.py | 4 ++++ server.py | 19 +++++++++++-------- 3 files changed, 36 insertions(+), 19 deletions(-) diff --git a/holmes/core/robusta_ai.py b/holmes/core/robusta_ai.py index dd53315f..eb4883e2 100644 --- a/holmes/core/robusta_ai.py +++ b/holmes/core/robusta_ai.py @@ -1,14 +1,16 @@ -# TODO finish refactor import logging import requests -from fastapi import HTTPException from holmes.core.runbooks import RunbookManager -from holmes.core.tool_calling_llm import BaseIssueInvestigator, BaseToolCallingLLM, LLMResult +from holmes.core.tool_calling_llm import BaseIssueInvestigator, BaseToolCallingLLM, LLMError, LLMResult from holmes.utils.auth import SessionManager +class RobustaAICallError(LLMError): + pass + + class RobustaAIToolCallingLLM(BaseToolCallingLLM): def __init__(self): raise NotImplementedError("Robusta AI tool calling LLM is not supported yet") @@ -27,12 +29,18 @@ def call(self, system_prompt: str, user_prompt: str) -> LLMResult: payload = { "auth": {"account_id": auth_token.account_id, "token": auth_token.token}, - "system_message": system_prompt, - "user_message": user_prompt, -# TODO -# "model": request.model, + "body": { + "system_message": system_prompt, + "user_message": user_prompt, +# TODO? 
+# "model": request.model, + }, } - resp = requests.post(self.base_url + "/api/ai", json=payload) + try: + resp = requests.post(f"{self.base_url}/api/ai", json=payload) + except: + logging.exception("Robusta AI API call failed") + raise RobustaAICallError("Robusta AI API call failed") if resp.status_code == 401: self.session_manager.invalidate_token(auth_token) # Attempt auth again using a fresh token @@ -44,6 +52,8 @@ def call(self, system_prompt: str, user_prompt: str) -> LLMResult: logging.error( f"Failed to reauth with Robusta AI. Response status {resp.status_code}, content: {resp.text}" ) - raise HTTPException(status_code=400, detail="Unable to auth with Robusta AI") - # TODO LLMResult - return resp.json() + raise RobustaAICallError("Unable to auth with Robusta AI") + resp_data = resp.json() + if not resp_data["success"]: + raise RobustaAICallError("Robusta AI API call failed") + return LLMResult(result=resp_data["msg"], prompt=user_prompt) diff --git a/holmes/core/tool_calling_llm.py b/holmes/core/tool_calling_llm.py index cb64c06e..09c2f8db 100644 --- a/holmes/core/tool_calling_llm.py +++ b/holmes/core/tool_calling_llm.py @@ -15,6 +15,10 @@ from holmes.core.tools import YAMLToolExecutor +class LLMError(Exception): + pass + + class ToolCallResult(BaseModel): tool_name: str description: str diff --git a/server.py b/server.py index 9ce8cf34..4af77081 100644 --- a/server.py +++ b/server.py @@ -15,7 +15,7 @@ import colorlog import uvicorn -from fastapi import FastAPI +from fastapi import FastAPI, HTTPException from rich.console import Console from holmes.common.env_vars import ALLOWED_TOOLSETS, HOLMES_HOST, HOLMES_PORT @@ -24,6 +24,7 @@ from holmes.core.provider import LLMProviderFactory from holmes.core.server_models import InvestigateContext, InvestigateRequest from holmes.core.supabase_dal import SupabaseDal +from holmes.core.tool_calling_llm import LLMError from holmes.plugins.prompts import load_prompt from holmes.utils.auth import SessionManager @@ -54,12 +55,11 @@ def init_logging(): if not dal.initialized and config.llm_provider == LLMProviderType.ROBUSTA: logging.error("Holmes cannot run without store configuration when the LLM provider is Robusta AI") sys.exit(1) -session_manager = SessionManager(dal, "RelayHolmes") +session_manager = SessionManager(dal, "AIRelay") provider_factory = LLMProviderFactory(config, session_manager) app = FastAPI() - def fetch_context_data(context: List[InvestigateContext]) -> dict: for context_item in context: if context_item.type == "robusta_issue_id": @@ -85,11 +85,14 @@ def investigate_issue(request: InvestigateRequest): raw=raw_data, ) investigator = provider_factory.create_issue_investigator(console, allowed_toolsets=ALLOWED_TOOLSETS) - investigation = investigator.investigate( - issue, - prompt=load_prompt(request.prompt), - console=console, - ) + try: + investigation = investigator.investigate( + issue, + prompt=load_prompt(request.system_prompt), + console=console, + ) + except LLMError as exc: + raise HTTPException(status_code=500, detail=f"Error calling the LLM provider: {str(exc)}") ret = { "analysis": investigation.result } From 9585fc9f9d1d42dc2b4721e9b5252be9473dfc6c Mon Sep 17 00:00:00 2001 From: Robert Szefler Date: Fri, 14 Jun 2024 10:21:10 +0200 Subject: [PATCH 9/9] post rebase fixes --- holmes.py | 2 +- holmes/core/provider.py | 44 ++++++++++++++++++++--------------------- 2 files changed, 22 insertions(+), 24 deletions(-) diff --git a/holmes.py b/holmes.py index 20d551f9..53d9f71c 100644 --- a/holmes.py +++ b/holmes.py @@ -34,7 
+34,7 @@ llm_provider_names = ", ".join(str(tp) for tp in LLMProviderType) opt_llm: Optional[LLMProviderType] = typer.Option( LLMProviderType.OPENAI, - help="LLM provider (supported values: {llm_provider_names})" + help=f"LLM provider (supported values: {llm_provider_names})" ) opt_api_key: Optional[str] = typer.Option( None, diff --git a/holmes/core/provider.py b/holmes/core/provider.py index c55a860c..9d11161d 100644 --- a/holmes/core/provider.py +++ b/holmes/core/provider.py @@ -19,6 +19,8 @@ from holmes.plugins.runbooks import load_builtin_runbooks, load_runbooks_from_file from holmes.plugins.sources.jira import JiraSource from holmes.plugins.sources.github import GitHubSource +from holmes.plugins.sources.opsgenie import OpsGenieSource +from holmes.plugins.sources.pagerduty import PagerDutySource from holmes.plugins.sources.prometheus.plugin import AlertManagerSource from holmes.plugins.toolsets import load_builtin_toolsets, load_toolsets_from_file from holmes.utils.auth import SessionManager @@ -110,23 +112,27 @@ def create_github_source(self) -> GitHubSource: ) def create_pagerduty_source(self) -> PagerDutySource: - if self.pagerduty_api_key is None: + if self.config.pagerduty_api_key is None: raise ValueError("--pagerduty-api-key must be specified") return PagerDutySource( - api_key=self.pagerduty_api_key.get_secret_value(), - user_email=self.pagerduty_user_email, - incident_key=self.pagerduty_incident_key, + api_key=self.config.pagerduty_api_key.get_secret_value(), + user_email=self.config.pagerduty_user_email, + incident_key=self.config.pagerduty_incident_key, ) def create_opsgenie_source(self) -> OpsGenieSource: - if self.opsgenie_api_key is None: + if self.config.opsgenie_api_key is None: raise ValueError("--opsgenie-api-key must be specified") return OpsGenieSource( - api_key=self.opsgenie_api_key.get_secret_value(), - query=self.opsgenie_query, - team_integration_key=self.opsgenie_team_integration_key.get_secret_value() if self.opsgenie_team_integration_key else None, + api_key=self.config.opsgenie_api_key.get_secret_value(), + query=self.config.opsgenie_query, + team_integration_key=( + self.config.opsgenie_team_integration_key.get_secret_value() + if self.config.opsgenie_team_integration_key + else None + ), ) def create_alertmanager_source(self) -> AlertManagerSource: @@ -141,7 +147,7 @@ def create_alertmanager_source(self) -> AlertManagerSource: url=self.config.alertmanager_url, username=self.config.alertmanager_username, password=self.config.alertmanager_password, - alertname=self.alertmanager_alertname, + alertname=self.config.alertmanager_alertname, ) def create_slack_destination(self): @@ -153,33 +159,25 @@ def create_slack_destination(self): def _create_tool_executor(self, console: Console, allowed_toolsets: ToolsetPattern) -> YAMLToolExecutor: all_toolsets = load_builtin_toolsets() - for ts_path in self.custom_toolsets: + for ts_path in self.config.custom_toolsets: all_toolsets.extend(load_toolsets_from_file(ts_path)) if allowed_toolsets == "*": matching_toolsets = all_toolsets else: - matching_toolsets = get_matching_toolsets( - all_toolsets, allowed_toolsets.split(",") - ) + matching_toolsets = get_matching_toolsets(all_toolsets, allowed_toolsets.split(",")) enabled_toolsets = [ts for ts in matching_toolsets if ts.is_enabled()] for ts in all_toolsets: if ts not in matching_toolsets: - console.print( - f"[yellow]Disabling toolset {ts.name} [/yellow] from {ts.get_path()}" - ) + console.print(f"[yellow]Disabling toolset {ts.name} [/yellow] from {ts.get_path()}") elif 
ts not in enabled_toolsets:
-                console.print(
-                    f"[yellow]Not loading toolset {ts.name}[/yellow] ({ts.get_disabled_reason()})"
-                )
-                #console.print(f"[red]The following tools will be disabled: {[t.name for t in ts.tools]}[/red])")
+                console.print(f"[yellow]Not loading toolset {ts.name}[/yellow] ({ts.get_disabled_reason()})")
+                # console.print(f"[red]The following tools will be disabled: {[t.name for t in ts.tools]}[/red])")
             else:
                 logging.debug(f"Loaded toolset {ts.name} from {ts.get_path()}")
                 # console.print(f"[green]Loaded toolset {ts.name}[/green] from {ts.get_path()}")
 
         enabled_tools = concat(*[ts.tools for ts in enabled_toolsets])
-        logging.debug(
-            f"Starting AI session with tools: {[t.name for t in enabled_tools]}"
-        )
+        logging.debug(f"Starting AI session with tools: {[t.name for t in enabled_tools]}")
         return YAMLToolExecutor(enabled_toolsets)
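
Taken together, the Robusta AI call path after patch 9 can be summarized in one self-contained sketch: reuse the freshest stored token, retry once on 401 with a freshly minted token, then check the success flag in the response body. This is an illustration of the flow the patches implement, not a verbatim copy of holmes/core/robusta_ai.py; the SessionManager method names are as used in the patches, and the patch's RobustaAICallError is replaced by a generic exception here.

    import logging

    import requests

    def call_robusta_ai(session_manager, base_url: str, system_message: str, user_message: str) -> str:
        # Reuse the freshest stored token; mint one on first use (patch 7 behaviour).
        token = session_manager.get_current_token() or session_manager.create_token()
        payload = {
            "auth": {"account_id": token.account_id, "token": token.token},
            "body": {"system_message": system_message, "user_message": user_message},
        }
        resp = requests.post(f"{base_url}/api/ai", json=payload)
        if resp.status_code == 401:
            # Token rejected: invalidate it, mint a fresh one, retry once (patches 6 and 8).
            session_manager.invalidate_token(token)
            token = session_manager.create_token()
            payload["auth"] = {"account_id": token.account_id, "token": token.token}
            resp = requests.post(f"{base_url}/api/ai", json=payload)
        if resp.status_code != 200:
            logging.error("Robusta AI call failed: status %s, content: %s", resp.status_code, resp.text)
            raise RuntimeError("Unable to auth with Robusta AI")  # RobustaAICallError in the patch
        data = resp.json()
        if not data["success"]:
            raise RuntimeError("Robusta AI API call failed")
        return data["msg"]  # wrapped into LLMResult(result=..., prompt=...) by the investigator
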