diff --git a/holmes.py b/holmes.py
index eb11bfb7..53d9f71c 100644
--- a/holmes.py
+++ b/holmes.py
@@ -2,18 +2,19 @@
 # add_custom_certificate("cert goes here as a string (not path to the cert rather the cert itself)")
 
 import logging
-import re
 import warnings
 from pathlib import Path
-from typing import List, Optional, Pattern
-import json
+from typing import List, Optional
+
 import typer
 from rich.console import Console
 from rich.logging import RichHandler
 from rich.markdown import Markdown
 from rich.rule import Rule
+
 from holmes.utils.file_utils import write_json_file
-from holmes.config import Config, LLMType
+from holmes.config import BaseLLMConfig, LLMProviderType
+from holmes.core.provider import LLMProviderFactory
 from holmes.plugins.destinations import DestinationType
 from holmes.plugins.prompts import load_prompt
 from holmes.plugins.sources.opsgenie import OPSGENIE_TEAM_INTEGRATION_KEY_HELP
@@ -28,22 +29,12 @@
 )
 app.add_typer(investigate_app, name="investigate")
 
-def init_logging(verbose = False):
-    logging.basicConfig(level=logging.DEBUG if verbose else logging.INFO, format="%(message)s", handlers=[RichHandler(show_level=False, show_time=False)])
-    # disable INFO logs from OpenAI
-    logging.getLogger("httpx").setLevel(logging.WARNING)
-    # when running in --verbose mode we don't want to see DEBUG logs from these libraries
-    logging.getLogger("openai._base_client").setLevel(logging.INFO)
-    logging.getLogger("httpcore").setLevel(logging.INFO)
-    logging.getLogger("markdown_it").setLevel(logging.INFO)
-    # Suppress UserWarnings from the slack_sdk module
-    warnings.filterwarnings("ignore", category=UserWarning, module="slack_sdk.*")
-    return Console()
 # Common cli options
-opt_llm: Optional[LLMType] = typer.Option(
-    LLMType.OPENAI,
-    help="Which LLM to use ('openai' or 'azure')",
+llm_provider_names = ", ".join(str(tp) for tp in LLMProviderType)
+opt_llm: Optional[LLMProviderType] = typer.Option(
+    LLMProviderType.OPENAI,
+    help=f"LLM provider (supported values: {llm_provider_names})"
 )
 opt_api_key: Optional[str] = typer.Option(
     None,
@@ -111,6 +102,19 @@ def init_logging(verbose = False):
 system_prompt_help = "Advanced. System prompt for LLM. Values starting with builtin:// are loaded from holmes/plugins/prompts, values starting with file:// are loaded from the given path, other values are interpreted as a prompt string"
 
+def init_logging(verbose = False):
+    logging.basicConfig(level=logging.DEBUG if verbose else logging.INFO, format="%(message)s", handlers=[RichHandler(show_level=False, show_time=False)])
+    # disable INFO logs from OpenAI
+    logging.getLogger("httpx").setLevel(logging.WARNING)
+    # when running in --verbose mode we don't want to see DEBUG logs from these libraries
+    logging.getLogger("openai._base_client").setLevel(logging.INFO)
+    logging.getLogger("httpcore").setLevel(logging.INFO)
+    logging.getLogger("markdown_it").setLevel(logging.INFO)
+    # Suppress UserWarnings from the slack_sdk module
+    warnings.filterwarnings("ignore", category=UserWarning, module="slack_sdk.*")
+    return Console()
+
+
 # TODO: add interactive interpreter mode
 # TODO: add streaming output
 @app.command()
@@ -141,7 +145,7 @@ def ask(
     Ask any question and answer using available tools
     """
     console = init_logging(verbose)
-    config = Config.load_from_file(
+    config = BaseLLMConfig.load_from_file(
        config_file,
         api_key=api_key,
         llm=llm,
@@ -150,8 +154,9 @@ def ask(
         max_steps=max_steps,
         custom_toolsets=custom_toolsets,
     )
+    provider_factory = LLMProviderFactory(config)
     system_prompt = load_prompt(system_prompt)
-    ai = config.create_toolcalling_llm(console, allowed_toolsets)
+    ai = provider_factory.create_toolcalling_llm(console, allowed_toolsets)
     console.print("[bold yellow]User:[/bold yellow] " + prompt)
     response = ai.call(system_prompt, prompt)
     text_result = Markdown(response.result)
@@ -160,7 +165,8 @@ def ask(
     if show_tool_output and response.tool_calls:
         for tool_call in response.tool_calls:
             console.print(f"[bold magenta]Used Tool:[/bold magenta]", end="")
-            # we need to print this separately with markup=False because it contains arbitrary text and we don't want console.print to interpret it
+            # we need to print this separately with markup=False because it contains arbitrary text
+            # and we don't want console.print to interpret it
             console.print(f"{tool_call.description}. Output=\n{tool_call.result}", markup=False)
     console.print(f"[bold green]AI:[/bold green]", end=" ")
     console.print(text_result, soft_wrap=True)
@@ -180,7 +186,7 @@ def alertmanager(
         None, help="Password to use for basic auth"
     ),
     # common options
-    llm: Optional[LLMType] = opt_llm,
+    llm: Optional[LLMProviderType] = opt_llm,
     api_key: Optional[str] = opt_api_key,
     azure_endpoint: Optional[str] = opt_azure_endpoint,
     model: Optional[str] = opt_model,
@@ -203,7 +209,7 @@ def alertmanager(
     Investigate a Prometheus/Alertmanager alert
     """
     console = init_logging(verbose)
-    config = Config.load_from_file(
+    config = BaseLLMConfig.load_from_file(
         config_file,
         api_key=api_key,
         llm=llm,
@@ -219,14 +225,15 @@ def alertmanager(
         custom_toolsets=custom_toolsets,
         custom_runbooks=custom_runbooks
     )
+    provider_factory = LLMProviderFactory(config)
     system_prompt = load_prompt(system_prompt)
-    ai = config.create_issue_investigator(console, allowed_toolsets)
+    ai = provider_factory.create_issue_investigator(console, allowed_toolsets)
 
-    source = config.create_alertmanager_source()
+    source = provider_factory.create_alertmanager_source()
 
     if destination == DestinationType.SLACK:
-        slack = config.create_slack_destination()
+        slack = provider_factory.create_slack_destination()
 
     try:
         issues = source.fetch_issues()
@@ -282,7 +289,7 @@ def jira(
         False, help="Update Jira with AI results"
     ),
     # common options
-    llm: Optional[LLMType] = opt_llm,
+    llm: Optional[LLMProviderType] = opt_llm,
     api_key: Optional[str] = opt_api_key,
     azure_endpoint: Optional[str] = opt_azure_endpoint,
     model: Optional[str] = opt_model,
@@ -301,7 +308,7 @@ def jira(
     Investigate a Jira ticket
     """
     console = init_logging(verbose)
-    config = Config.load_from_file(
+    config = BaseLLMConfig.load_from_file(
         config_file,
         api_key=api_key,
         llm=llm,
@@ -315,10 +322,11 @@ def jira(
         custom_toolsets=custom_toolsets,
         custom_runbooks=custom_runbooks
     )
+    provider_factory = LLMProviderFactory(config)
     system_prompt = load_prompt(system_prompt)
 
-    ai = config.create_issue_investigator(console, allowed_toolsets)
-    source = config.create_jira_source()
+    ai = provider_factory.create_issue_investigator(console, allowed_toolsets)
+    source = provider_factory.create_jira_source()
     try:
         # TODO: allow passing issue ID
         issues = source.fetch_issues()
@@ -371,7 +379,7 @@ def github(
         help="Investigate tickets matching a GitHub query (e.g. 'is:issue is:open')",
     ),
     # common options
-    llm: Optional[LLMType] = opt_llm,
+    llm: Optional[LLMProviderType] = opt_llm,
     api_key: Optional[str] = opt_api_key,
     azure_endpoint: Optional[str] = opt_azure_endpoint,
     model: Optional[str] = opt_model,
@@ -390,7 +398,7 @@ def github(
     Investigate a GitHub issue
     """
     console = init_logging(verbose)
-    config = Config.load_from_file(
+    config = BaseLLMConfig.load_from_file(
         config_file,
         api_key=api_key,
         llm=llm,
@@ -405,10 +413,11 @@ def github(
         custom_toolsets=custom_toolsets,
         custom_runbooks=custom_runbooks
     )
+    provider_factory = LLMProviderFactory(config)
     system_prompt = load_prompt(system_prompt)
 
-    ai = config.create_issue_investigator(console, allowed_toolsets)
-    source = config.create_github_source()
+    ai = provider_factory.create_issue_investigator(console, allowed_toolsets)
+    source = provider_factory.create_github_source()
     try:
         issues = source.fetch_issues()
     except Exception as e:
diff --git a/holmes/common/env_vars.py b/holmes/common/env_vars.py
index 66d9d463..1487135b 100644
--- a/holmes/common/env_vars.py
+++ b/holmes/common/env_vars.py
@@ -6,6 +6,7 @@
 ROBUSTA_CONFIG_PATH = os.environ.get('ROBUSTA_CONFIG_PATH', "/etc/robusta/config/active_playbooks.yaml")
 
 ROBUSTA_ACCOUNT_ID = os.environ.get("ROBUSTA_ACCOUNT_ID", "")
+ROBUSTA_USER_ID = os.environ.get("ROBUSTA_USER_ID", "")
 STORE_URL = os.environ.get("STORE_URL", "")
 STORE_API_KEY = os.environ.get("STORE_API_KEY", "")
 STORE_EMAIL = os.environ.get("STORE_EMAIL", "")
diff --git a/holmes/config.py b/holmes/config.py
index 7a566d67..73fe5394 100644
--- a/holmes/config.py
+++ b/holmes/config.py
@@ -2,47 +2,23 @@
 import os
 import os.path
 from strenum import StrEnum
-from typing import List, Optional
-
-from openai import AzureOpenAI, OpenAI
-from pydantic import FilePath, SecretStr
-from pydash.arrays import concat
-from rich.console import Console
-
-from holmes.core.runbooks import RunbookManager
-from holmes.core.tool_calling_llm import (IssueInvestigator, ToolCallingLLM,
-                                          YAMLToolExecutor)
-from holmes.core.tools import ToolsetPattern, get_matching_toolsets
-from holmes.plugins.destinations.slack import SlackDestination
-from holmes.plugins.runbooks import (load_builtin_runbooks,
-                                     load_runbooks_from_file)
-from holmes.plugins.sources.github import GitHubSource
-from holmes.plugins.sources.jira import JiraSource
-from holmes.plugins.sources.opsgenie import OpsGenieSource
-from holmes.plugins.sources.pagerduty import PagerDutySource
-from holmes.plugins.sources.prometheus.plugin import AlertManagerSource
-from holmes.plugins.toolsets import (load_builtin_toolsets,
-                                     load_toolsets_from_file)
-from holmes.utils.pydantic_utils import RobustaBaseConfig, load_model_from_file
-
-
-class LLMType(StrEnum):
+from typing import Annotated, Any, Dict, List, Optional, get_args, get_origin, get_type_hints
+
+from pydantic import SecretStr, FilePath
+from holmes.utils.pydantic_utils import BaseConfig, EnvVarName, load_model_from_file
+
+
+class LLMProviderType(StrEnum):
     OPENAI = "openai"
     AZURE = "azure"
+    ROBUSTA = "robusta_ai"
 
 
-class Config(RobustaBaseConfig):
-    llm: Optional[LLMType] = LLMType.OPENAI
-    api_key: Optional[SecretStr] = (
-        None  # if None, read from OPENAI_API_KEY or AZURE_OPENAI_ENDPOINT env var
-    )
-    azure_endpoint: Optional[str] = (
-        None  # if None, read from AZURE_OPENAI_ENDPOINT env var
-    )
-    azure_api_version: Optional[str] = "2024-02-01"
-    model: Optional[str] = "gpt-4o"
-    max_steps: Optional[int] = 10
+class BaseLLMConfig(BaseConfig):
+    llm_provider: LLMProviderType = LLMProviderType.OPENAI
+    # FIXME: the following settings do not belong here. They define the
+    # configuration of specific actions, and not of the LLM provider.
     alertmanager_url: Optional[str] = None
     alertmanager_username: Optional[str] = None
     alertmanager_password: Optional[str] = None
@@ -74,198 +50,87 @@ class Config(RobustaBaseConfig):
     custom_toolsets: List[FilePath] = []
 
     @classmethod
-    def load_from_env(cls):
-        kwargs = {"llm": LLMType(os.getenv("HOLMES_LLM", "OPENAI").lower())}
-        for field_name in [
-            "model",
-            "api_key",
-            "azure_endpoint",
-            "max_steps",
-            "alertmanager_url",
-            "alertmanager_username",
-            "alertmanager_password",
-            "jira_url",
-            "jira_username",
-            "jira_api_key",
-            "jira_query",
-            "slack_token",
-            "slack_channel",
-            "github_url",
-            "github_owner",
-            "github_repository",
-            "github_pat",
-            "github_query",
-            # TODO
-            # custom_runbooks
-            # custom_toolsets
-        ]:
-            val = os.getenv(field_name.upper(), None)
-            if val is not None:
-                kwargs[field_name] = val
-        return cls(**kwargs)
-
-    def create_llm(self) -> OpenAI:
-        if self.llm == LLMType.OPENAI:
-            return OpenAI(
-                api_key=self.api_key.get_secret_value() if self.api_key else None,
-            )
-        elif self.llm == LLMType.AZURE:
-            return AzureOpenAI(
-                api_key=self.api_key.get_secret_value() if self.api_key else None,
-                azure_endpoint=self.azure_endpoint,
-                api_version=self.azure_api_version,
-            )
+    def _collect_env_vars(cls) -> Dict[str, Any]:
+        """Collect the environment variables that this class might require for setup.
+
+        Environment variable names are determined from model fields as follows:
+        - if the model field is not annotated with an EnvVarName, the env var name is
+          just the model field name in upper case
+        - if the model field is annotated with an EnvVarName, the env var name is
+          taken from the annotation.
+        """
+        vars_dict = {}
+        hints = get_type_hints(cls, include_extras=True)
+        for field_name in cls.model_fields:
+            if field_name == "llm_provider":
+                # Handled in load_from_env
+                continue
+            tp_obj = hints[field_name]
+            if get_origin(tp_obj) is Annotated:
+                tp_args = get_args(tp_obj)
+                base_type = tp_args[0]
+                for arg in tp_args[1:]:
+                    if isinstance(arg, EnvVarName):
+                        env_var_name = arg
+                        break
+                else:  # no EnvVarName(...) in annotations
+                    env_var_name = field_name.upper()
+            else:
+                base_type = tp_obj
+                env_var_name = field_name.upper()
+            if env_var_name in os.environ:
+                env_value = os.environ[env_var_name]
+                if get_origin(base_type) == list:
+                    value = [value.strip() for value in env_value.split(",")]
+                else:
+                    value = env_value
+                vars_dict[field_name] = value
+        return vars_dict
+
+    @classmethod
+    def load_from_env(cls) -> "BaseLLMConfig":
+        llm_name = os.environ.get("LLM_PROVIDER", "OPENAI").lower()
+        llm_provider = LLMProviderType(llm_name)
+        if llm_provider == LLMProviderType.AZURE:
+            final_class = AzureLLMConfig
+        elif llm_provider == LLMProviderType.OPENAI:
+            final_class = OpenAILLMConfig
+        elif llm_provider == LLMProviderType.ROBUSTA:
+            final_class = RobustaLLMConfig
         else:
-            raise ValueError(f"Unknown LLM type: {self.llm}")
+            raise NotImplementedError(f"Unknown LLM {llm_name}")
+        kwargs = final_class._collect_env_vars()
+        ret = final_class(**kwargs)
+        return ret
 
-    def _create_tool_executor(
-        self, console: Console, allowed_toolsets: ToolsetPattern
-    ) -> YAMLToolExecutor:
-        all_toolsets = load_builtin_toolsets()
-        for ts_path in self.custom_toolsets:
-            all_toolsets.extend(load_toolsets_from_file(ts_path))
-
-        if allowed_toolsets == "*":
-            matching_toolsets = all_toolsets
-        else:
-            matching_toolsets = get_matching_toolsets(
-                all_toolsets, allowed_toolsets.split(",")
-            )
-
-        enabled_toolsets = [ts for ts in matching_toolsets if ts.is_enabled()]
-        for ts in all_toolsets:
-            if ts not in matching_toolsets:
-                console.print(
-                    f"[yellow]Disabling toolset {ts.name} [/yellow] from {ts.get_path()}"
-                )
-            elif ts not in enabled_toolsets:
-                console.print(
-                    f"[yellow]Not loading toolset {ts.name}[/yellow] ({ts.get_disabled_reason()})"
-                )
-                #console.print(f"[red]The following tools will be disabled: {[t.name for t in ts.tools]}[/red])")
-            else:
-                logging.debug(f"Loaded toolset {ts.name} from {ts.get_path()}")
-                # console.print(f"[green]Loaded toolset {ts.name}[/green] from {ts.get_path()}")
-
-        enabled_tools = concat(*[ts.tools for ts in enabled_toolsets])
-        logging.debug(
-            f"Starting AI session with tools: {[t.name for t in enabled_tools]}"
-        )
-        return YAMLToolExecutor(enabled_toolsets)
-
-    def create_toolcalling_llm(
-        self, console: Console, allowed_toolsets: ToolsetPattern
-    ) -> IssueInvestigator:
-        tool_executor = self._create_tool_executor(console, allowed_toolsets)
-        return ToolCallingLLM(
-            self.create_llm(),
-            self.model,
-            tool_executor,
-            self.max_steps,
-        )
-
-    def create_issue_investigator(
-        self, console: Console, allowed_toolsets: ToolsetPattern
-    ) -> IssueInvestigator:
-        all_runbooks = load_builtin_runbooks()
-        for runbook_path in self.custom_runbooks:
-            all_runbooks.extend(load_runbooks_from_file(runbook_path))
-
-        runbook_manager = RunbookManager(all_runbooks)
-        tool_executor = self._create_tool_executor(console, allowed_toolsets)
-        return IssueInvestigator(
-            self.create_llm(),
-            self.model,
-            tool_executor,
-            runbook_manager,
-            self.max_steps,
-        )
-
-    def create_jira_source(self) -> JiraSource:
-        if self.jira_url is None:
-            raise ValueError("--jira-url must be specified")
-        if not (
-            self.jira_url.startswith("http://") or self.jira_url.startswith("https://")
-        ):
-            raise ValueError("--jira-url must start with http:// or https://")
-        if self.jira_username is None:
-            raise ValueError("--jira-username must be specified")
-        if self.jira_api_key is None:
-            raise ValueError("--jira-api-key must be specified")
-
-        return JiraSource(
-            url=self.jira_url,
-            username=self.jira_username,
-            api_key=self.jira_api_key.get_secret_value(),
-            jql_query=self.jira_query,
-        )
-
-    def create_github_source(self) -> GitHubSource:
-        if not (
-            self.github_url.startswith(
-                "http://") or self.github_url.startswith("https://")
-        ):
-            raise ValueError("--github-url must start with http:// or https://")
-        if self.github_owner is None:
-            raise ValueError("--github-owner must be specified")
-        if self.github_repository is None:
-            raise ValueError("--github-repository must be specified")
-        if self.github_pat is None:
-            raise ValueError("--github-pat must be specified")
-
-        return GitHubSource(
-            url=self.github_url,
-            owner=self.github_owner,
-            pat=self.github_pat.get_secret_value(),
-            repository=self.github_repository,
-            query=self.github_query,
-        )
-
-    def create_pagerduty_source(self) -> OpsGenieSource:
-        if self.pagerduty_api_key is None:
-            raise ValueError("--pagerduty-api-key must be specified")
-
-        return PagerDutySource(
-            api_key=self.pagerduty_api_key.get_secret_value(),
-            user_email=self.pagerduty_user_email,
-            incident_key=self.pagerduty_incident_key,
-        )
-
-    def create_opsgenie_source(self) -> OpsGenieSource:
-        if self.opsgenie_api_key is None:
-            raise ValueError("--opsgenie-api-key must be specified")
-
-        return OpsGenieSource(
-            api_key=self.opsgenie_api_key.get_secret_value(),
-            query=self.opsgenie_query,
-            team_integration_key=self.opsgenie_team_integration_key.get_secret_value() if self.opsgenie_team_integration_key else None,
-        )
-
-    def create_alertmanager_source(self) -> AlertManagerSource:
-        if self.alertmanager_url is None:
-            raise ValueError("--alertmanager-url must be specified")
-        if not (
-            self.alertmanager_url.startswith("http://")
-            or self.alertmanager_url.startswith("https://")
-        ):
-            raise ValueError("--alertmanager-url must start with http:// or https://")
-
-        return AlertManagerSource(
-            url=self.alertmanager_url,
-            username=self.alertmanager_username,
-            password=self.alertmanager_password,
-            alertname=self.alertmanager_alertname,
-        )
-
-    def create_slack_destination(self):
-        if self.slack_token is None:
-            raise ValueError("--slack-token must be specified")
-        if self.slack_channel is None:
-            raise ValueError("--slack-channel must be specified")
-        return SlackDestination(self.slack_token.get_secret_value(), self.slack_channel)
+
+    # TODO refactor everything below
 
     @classmethod
-    def load_from_file(cls, config_file: Optional[str], **kwargs) -> "Config":
+    def load_from_file(cls, config_file: Optional[str], **kwargs) -> "BaseLLMConfig":
+        # FIXME!
         if config_file is not None:
             logging.debug("Loading config from file %s", config_file)
             config_from_file = load_model_from_file(cls, config_file)
@@ -283,7 +148,26 @@ def load_from_file(cls, config_file: Optional[str], **kwargs) -> "Config":
         merged_config = config_from_file.dict()
         # remove Nones to avoid overriding config file with empty cli args
         cli_overrides = {
-            k: v for k, v in config_from_cli.dict().items() if v is not None and v != []
+            k: v for k, v in config_from_cli.model_dump().items() if v is not None and v != []
         }
         merged_config.update(cli_overrides)
         return cls(**merged_config)
+
+
+class BaseOpenAIConfig(BaseLLMConfig):
+    model: Annotated[Optional[str], EnvVarName("AI_MODEL")] = "gpt-4o"
+    max_steps: Optional[int] = 10
+
+
+class OpenAILLMConfig(BaseOpenAIConfig):
+    api_key: Annotated[Optional[SecretStr], EnvVarName("OPENAI_API_KEY")]
+
+
+class AzureLLMConfig(BaseOpenAIConfig):
+    api_key: Annotated[Optional[SecretStr], EnvVarName("AZURE_API_KEY")]
+    endpoint: Annotated[Optional[str], EnvVarName("AZURE_ENDPOINT")]
+    azure_api_version: Optional[str] = "2024-02-01"
+
+
+class RobustaLLMConfig(BaseOpenAIConfig):
+    url: Annotated[str, EnvVarName("ROBUSTA_AI_URL")]
diff --git a/holmes/core/provider.py b/holmes/core/provider.py
new file mode 100644
index 00000000..9d11161d
--- /dev/null
+++ b/holmes/core/provider.py
@@ -0,0 +1,183 @@
+import logging
+
+from openai import AzureOpenAI, OpenAI
+from pydash.arrays import concat
+from rich.console import Console
+
+from holmes.config import BaseLLMConfig, LLMProviderType
+from holmes.core.robusta_ai import RobustaAIToolCallingLLM, RobustaIssueInvestigator
+from holmes.core.runbooks import RunbookManager
+from holmes.core.tool_calling_llm import (
+    BaseIssueInvestigator,
+    BaseToolCallingLLM,
+    OpenAIIssueInvestigator,
+    OpenAIToolCallingLLM,
+    YAMLToolExecutor,
+)
+from holmes.core.tools import ToolsetPattern, get_matching_toolsets
+from holmes.plugins.destinations.slack import SlackDestination
+from holmes.plugins.runbooks import load_builtin_runbooks, load_runbooks_from_file
+from holmes.plugins.sources.jira import JiraSource
+from holmes.plugins.sources.github import GitHubSource
+from holmes.plugins.sources.opsgenie import OpsGenieSource
+from holmes.plugins.sources.pagerduty import PagerDutySource
+from holmes.plugins.sources.prometheus.plugin import AlertManagerSource
+from holmes.plugins.toolsets import load_builtin_toolsets, load_toolsets_from_file
+from holmes.utils.auth import SessionManager
+
+
+class LLMProviderFactory:
+    def __init__(self, config: BaseLLMConfig, session_manager: SessionManager = None):
+        self.config = config
+        self.session_manager = session_manager
+
+    def create_llm(self) -> OpenAI:
+        if self.config.llm_provider == LLMProviderType.OPENAI:
+            return OpenAI(
+                api_key=(self.config.api_key.get_secret_value() if self.config.api_key else None),
+            )
+        elif self.config.llm_provider == LLMProviderType.AZURE:
+            return AzureOpenAI(
+                api_key=(self.config.api_key.get_secret_value() if self.config.api_key else None),
+                azure_endpoint=self.config.endpoint,
+                api_version=self.config.azure_api_version,
+            )
+        else:
+            raise ValueError(f"Unknown LLM type: {self.config.llm_provider}")
+
+    def create_toolcalling_llm(self, console: Console, allowed_toolsets: ToolsetPattern) -> BaseToolCallingLLM:
+        if self.config.llm_provider in [LLMProviderType.OPENAI, LLMProviderType.AZURE]:
+            tool_executor = self._create_tool_executor(console, allowed_toolsets)
+            return OpenAIToolCallingLLM(
+                self.create_llm(),
+                self.config.model,
+                tool_executor,
+                self.config.max_steps,
+            )
+        else:
+            # TODO in the future
+            return RobustaAIToolCallingLLM()
+
+    def create_issue_investigator(self, console: Console, allowed_toolsets: ToolsetPattern) -> BaseIssueInvestigator:
+        all_runbooks = load_builtin_runbooks()
+        for runbook_path in self.config.custom_runbooks:
+            all_runbooks.extend(load_runbooks_from_file(runbook_path))
+        runbook_manager = RunbookManager(all_runbooks)
+
+        if self.config.llm_provider == LLMProviderType.ROBUSTA:
+            return RobustaIssueInvestigator(self.config.url, self.session_manager, runbook_manager)
+        else:
+            tool_executor = self._create_tool_executor(console, allowed_toolsets)
+            return OpenAIIssueInvestigator(
+                self.create_llm(),
+                self.config.model,
+                tool_executor,
+                runbook_manager,
+                self.config.max_steps,
+            )
+
+    def create_jira_source(self) -> JiraSource:
+        if self.config.jira_url is None:
+            raise ValueError("--jira-url must be specified")
+        if not (self.config.jira_url.startswith("http://") or self.config.jira_url.startswith("https://")):
+            raise ValueError("--jira-url must start with http:// or https://")
+        if self.config.jira_username is None:
+            raise ValueError("--jira-username must be specified")
+        if self.config.jira_api_key is None:
+            raise ValueError("--jira-api-key must be specified")
+
+        return JiraSource(
+            url=self.config.jira_url,
+            username=self.config.jira_username,
+            api_key=self.config.jira_api_key.get_secret_value(),
+            jql_query=self.config.jira_query,
+        )
+
+    def create_github_source(self) -> GitHubSource:
+        if not (self.config.github_url.startswith("http://") or self.config.github_url.startswith("https://")):
+            raise ValueError("--github-url must start with http:// or https://")
+        if self.config.github_owner is None:
+            raise ValueError("--github-owner must be specified")
+        if self.config.github_repository is None:
+            raise ValueError("--github-repository must be specified")
+        if self.config.github_pat is None:
+            raise ValueError("--github-pat must be specified")
+
+        return GitHubSource(
+            url=self.config.github_url,
+            owner=self.config.github_owner,
+            pat=self.config.github_pat.get_secret_value(),
+            repository=self.config.github_repository,
+            query=self.config.github_query,
+        )
+
+    def create_pagerduty_source(self) -> PagerDutySource:
+        if self.config.pagerduty_api_key is None:
+            raise ValueError("--pagerduty-api-key must be specified")
+
+        return PagerDutySource(
+            api_key=self.config.pagerduty_api_key.get_secret_value(),
+            user_email=self.config.pagerduty_user_email,
+            incident_key=self.config.pagerduty_incident_key,
+        )
+
+    def create_opsgenie_source(self) -> OpsGenieSource:
+        if self.config.opsgenie_api_key is None:
+            raise ValueError("--opsgenie-api-key must be specified")
+
+        return OpsGenieSource(
+            api_key=self.config.opsgenie_api_key.get_secret_value(),
+            query=self.config.opsgenie_query,
+            team_integration_key=(
+                self.config.opsgenie_team_integration_key.get_secret_value()
+                if self.config.opsgenie_team_integration_key
+                else None
+            ),
+        )
+
+    def create_alertmanager_source(self) -> AlertManagerSource:
+        if self.config.alertmanager_url is None:
+            raise ValueError("--alertmanager-url must be specified")
+        if not (
+            self.config.alertmanager_url.startswith("http://") or self.config.alertmanager_url.startswith("https://")
+        ):
+            raise ValueError("--alertmanager-url must start with http:// or https://")
+
+        return AlertManagerSource(
+            url=self.config.alertmanager_url,
+            username=self.config.alertmanager_username,
+            password=self.config.alertmanager_password,
+            alertname=self.config.alertmanager_alertname,
+        )
+
+    def create_slack_destination(self):
+        if self.config.slack_token is None:
+            raise ValueError("--slack-token must be specified")
+        if self.config.slack_channel is None:
+            raise ValueError("--slack-channel must be specified")
+        return SlackDestination(self.config.slack_token.get_secret_value(), self.config.slack_channel)
+
+    def _create_tool_executor(self, console: Console, allowed_toolsets: ToolsetPattern) -> YAMLToolExecutor:
+        all_toolsets = load_builtin_toolsets()
+        for ts_path in self.config.custom_toolsets:
+            all_toolsets.extend(load_toolsets_from_file(ts_path))
+
+        if allowed_toolsets == "*":
+            matching_toolsets = all_toolsets
+        else:
+            matching_toolsets = get_matching_toolsets(all_toolsets, allowed_toolsets.split(","))
+
+        enabled_toolsets = [ts for ts in matching_toolsets if ts.is_enabled()]
+        for ts in all_toolsets:
+            if ts not in matching_toolsets:
+                console.print(f"[yellow]Disabling toolset {ts.name} [/yellow] from {ts.get_path()}")
+            elif ts not in enabled_toolsets:
+                console.print(f"[yellow]Not loading toolset {ts.name}[/yellow] ({ts.get_disabled_reason()})")
+                # console.print(f"[red]The following tools will be disabled: {[t.name for t in ts.tools]}[/red])")
+            else:
+                logging.debug(f"Loaded toolset {ts.name} from {ts.get_path()}")
+                # console.print(f"[green]Loaded toolset {ts.name}[/green] from {ts.get_path()}")
+
+        enabled_tools = concat(*[ts.tools for ts in enabled_toolsets])
+        logging.debug(f"Starting AI session with tools: {[t.name for t in enabled_tools]}")
+        return YAMLToolExecutor(enabled_toolsets)
diff --git a/holmes/core/robusta_ai.py b/holmes/core/robusta_ai.py
new file mode 100644
index 00000000..eb4883e2
--- /dev/null
+++ b/holmes/core/robusta_ai.py
@@ -0,0 +1,59 @@
+import logging
+
+import requests
+
+from holmes.core.runbooks import RunbookManager
+from holmes.core.tool_calling_llm import BaseIssueInvestigator, BaseToolCallingLLM, LLMError, LLMResult
+from holmes.utils.auth import SessionManager
+
+
+class RobustaAICallError(LLMError):
+    pass
+
+
+class RobustaAIToolCallingLLM(BaseToolCallingLLM):
+    def __init__(self):
+        raise NotImplementedError("Robusta AI tool calling LLM is not supported yet")
+
+
+class RobustaIssueInvestigator(BaseIssueInvestigator):
+    def __init__(self, base_url: str, session_manager: SessionManager, runbook_manager: RunbookManager):
+        self.base_url = base_url
+        self.session_manager = session_manager
+        self.runbook_manager = runbook_manager
+
+    def call(self, system_prompt: str, user_prompt: str) -> LLMResult:
+        auth_token = self.session_manager.get_current_token()
+        if auth_token is None:
+            auth_token = self.session_manager.create_token()
+
+        payload = {
+            "auth": {"account_id": auth_token.account_id, "token": auth_token.token},
+            "body": {
+                "system_message": system_prompt,
+                "user_message": user_prompt,
+# TODO?
+#                "model": request.model,
+            },
+        }
+        try:
+            resp = requests.post(f"{self.base_url}/api/ai", json=payload)
+        except requests.RequestException as exc:
+            logging.exception("Robusta AI API call failed")
+            raise RobustaAICallError("Robusta AI API call failed") from exc
+        if resp.status_code == 401:
+            self.session_manager.invalidate_token(auth_token)
+            # Attempt auth again using a fresh token
+            auth_token = self.session_manager.create_token()
+            payload["auth"]["account_id"] = auth_token.account_id
+            payload["auth"]["token"] = auth_token.token
+            resp = requests.post(f"{self.base_url}/api/ai", json=payload)
+            if resp.status_code != 200:
+                logging.error(
+                    f"Failed to reauth with Robusta AI. Response status {resp.status_code}, content: {resp.text}"
+                )
+                raise RobustaAICallError("Unable to auth with Robusta AI")
+        resp_data = resp.json()
+        if not resp_data["success"]:
+            raise RobustaAICallError("Robusta AI API call failed")
+        return LLMResult(result=resp_data["msg"], prompt=user_prompt)
diff --git a/holmes/core/server_models.py b/holmes/core/server_models.py
new file mode 100644
index 00000000..dc5b332e
--- /dev/null
+++ b/holmes/core/server_models.py
@@ -0,0 +1,21 @@
+from typing import Union, List
+
+from pydantic import BaseModel
+
+
+class InvestigateContext(BaseModel):
+    type: str
+    value: Union[str, dict]
+
+
+class InvestigateRequest(BaseModel):
+    source: str  # "prometheus" etc
+    title: str
+    description: str
+    subject: dict
+    context: List[InvestigateContext]
+    source_instance_id: str
+    model: str = "gpt-4o"
+    system_prompt: str = "builtin://generic_investigation.jinja2"
+    # TODO in the future
+    # response_handler: ...
diff --git a/holmes/core/supabase_dal.py b/holmes/core/supabase_dal.py
index 0378da10..95e6bc95 100644
--- a/holmes/core/supabase_dal.py
+++ b/holmes/core/supabase_dal.py
@@ -2,21 +2,32 @@
 import json
 import logging
 import os
+from datetime import datetime
 from typing import Dict, Optional, List
+from uuid import uuid4
 
 import yaml
 from supabase import create_client
 from supabase.lib.client_options import ClientOptions
 from pydantic import BaseModel
 
-from holmes.common.env_vars import (ROBUSTA_CONFIG_PATH, ROBUSTA_ACCOUNT_ID, STORE_URL, STORE_API_KEY, STORE_EMAIL,
-                                    STORE_PASSWORD)
+from holmes.common.env_vars import (
+    ROBUSTA_ACCOUNT_ID,
+    ROBUSTA_CONFIG_PATH,
+    ROBUSTA_USER_ID,
+    STORE_API_KEY,
+    STORE_EMAIL,
+    STORE_PASSWORD,
+    STORE_URL,
+)
 
 SUPABASE_TIMEOUT_SECONDS = int(os.getenv("SUPABASE_TIMEOUT_SECONDS", 3600))
 
 ISSUES_TABLE = "Issues"
 EVIDENCE_TABLE = "Evidence"
+TOKENS_TABLE = "AuthTokens"
+ACCOUNT_USERS_TABLE = "AccountUsers"
 
 
 class RobustaConfig(BaseModel):
@@ -27,21 +38,33 @@ class RobustaToken(BaseModel):
     store_url: str
     api_key: str
     account_id: str
+    user_id: str
     email: str
     password: str
 
 
+class AuthToken(BaseModel):
+    account_id: str
+    user_id: str
+    token: str
+    type: str
+    deleted: bool = False
+    created_at: Optional[datetime] = None
+
+
 class SupabaseDal:
     def __init__(self):
         self.enabled = self.__init_config()
         if not self.enabled:
             logging.info("Robusta store initialization parameters not provided. Skipping.")
+            self.initialized = False
             return
-        logging.info(f"Initializing robusta store for account {self.account_id}")
+        logging.info(f"Initializing robusta store for account {self.account_id}, user {self.user_id}")
         options = ClientOptions(postgrest_client_timeout=SUPABASE_TIMEOUT_SECONDS)
         self.client = create_client(self.url, self.api_key, options)
         self.sign_in()
+        self.initialized = True
 
     @staticmethod
     def __load_robusta_config() -> Optional[RobustaToken]:
@@ -67,24 +90,48 @@ def __init_config(self) -> bool:
         robusta_token = self.__load_robusta_config()
         if robusta_token:
             self.account_id = robusta_token.account_id
+            self.user_id = robusta_token.user_id
             self.url = robusta_token.store_url
             self.api_key = robusta_token.api_key
             self.email = robusta_token.email
             self.password = robusta_token.password
         else:
             self.account_id = ROBUSTA_ACCOUNT_ID
+            self.user_id = ROBUSTA_USER_ID
             self.url = STORE_URL
             self.api_key = STORE_API_KEY
             self.email = STORE_EMAIL
             self.password = STORE_PASSWORD
 
         # valid only if all store parameters are provided
-        return all([self.account_id, self.url, self.api_key, self.email, self.password])
+        return self.check_settings()
+
+    def check_settings(self):
+        unset_attrs = []
+        for attr_name in [
+            "account_id",
+            "user_id",
+            "url",
+            "api_key",
+            "email",
+            "password",
+        ]:
+            if not getattr(self, attr_name, None):
+                unset_attrs.append(attr_name)
+        if unset_attrs:
+            logging.warning(f"Unset store config variables: {', '.join(unset_attrs)}")
+            return False
+        else:
+            return True
 
     def sign_in(self):
         logging.info("Supabase DAL login")
-        res = self.client.auth.sign_in_with_password({"email": self.email, "password": self.password})
-        self.client.auth.set_session(res.session.access_token, res.session.refresh_token)
+        res = self.client.auth.sign_in_with_password(
+            {"email": self.email, "password": self.password}
+        )
+        self.client.auth.set_session(
+            res.session.access_token, res.session.refresh_token
+        )
         self.client.postgrest.auth(res.session.access_token)
 
     def get_issue_data(self, issue_id: str) -> Optional[Dict]:
@@ -96,11 +143,10 @@ def get_issue_data(self, issue_id: str) -> Optional[Dict]:
         issue_data = None
         try:
             issue_response = (
-                self.client
-                .table(ISSUES_TABLE)
-                .select(f"*")
-                .filter("id", "eq", issue_id)
-                .execute()
+                self.client.table(ISSUES_TABLE)
+                .select("*")
+                .filter("id", "eq", issue_id)
+                .execute()
             )
             if len(issue_response.data):
                 issue_data = issue_response.data[0]
@@ -111,11 +157,51 @@ def get_issue_data(self, issue_id: str) -> Optional[Dict]:
         if not issue_data:
             return None
         evidence = (
-            self.client
-            .table(EVIDENCE_TABLE)
-            .select(f"*")
+            self.client.table(EVIDENCE_TABLE)
+            .select("*")
             .filter("issue_id", "eq", issue_id)
             .execute()
         )
         issue_data["evidence"] = evidence.data
-        return issue_data
\ No newline at end of file
+        return issue_data
+
+    def create_auth_token(self, token_type: str) -> AuthToken:
+        result = (
+            self.client.table(TOKENS_TABLE)
+            .insert(
+                {
+                    "account_id": self.account_id,
+                    "user_id": self.user_id,
+                    "token": str(uuid4()),
+                    "type": token_type,
+                }
+            )
+            .execute()
+        )
+        return AuthToken(**result.data[0])
+
+    def get_freshest_auth_token(self, token_type: str) -> Optional[AuthToken]:
+        result = (
+            self.client.table(TOKENS_TABLE)
+            .select("*")
+            .filter("type", "eq", token_type)
+            .filter("deleted", "eq", False)
+            .order("created_at", desc=True)
+            .limit(1)
+            .execute()
+        )
+        if not result.data:
+            return None
+        else:
+            return AuthToken(**result.data[0])
+
+    def invalidate_auth_token(self, token: AuthToken) -> None:
+        (
+            self.client.table(TOKENS_TABLE)
+            .update({"deleted": True})
+            .eq("account_id", token.account_id)
+            .eq("user_id", token.user_id)
+            .eq("token", token.token)
+            .eq("type", token.type)
+            .execute()
+        )
+        # TODO maybe handle errors such as non-existent tokens?
diff --git a/holmes/core/tool_calling_llm.py b/holmes/core/tool_calling_llm.py
index 6e89984d..09c2f8db 100644
--- a/holmes/core/tool_calling_llm.py
+++ b/holmes/core/tool_calling_llm.py
@@ -1,14 +1,12 @@
-import datetime
+import abc
 import json
 import logging
 import textwrap
-from typing import Dict, Generator, List, Optional
+from typing import List, Optional
 
 import jinja2
 from openai import BadRequestError, OpenAI
 from openai._types import NOT_GIVEN
-from openai.types.chat.chat_completion import Choice
-from openai.types.chat.chat_completion_chunk import ChatCompletionChunk
 from pydantic import BaseModel
 from rich.console import Console
 
@@ -17,11 +15,16 @@
 from holmes.core.tools import YAMLToolExecutor
 
 
+class LLMError(Exception):
+    pass
+
+
 class ToolCallResult(BaseModel):
     tool_name: str
     description: str
     result: str
 
+
 class LLMResult(BaseModel):
     tool_calls: Optional[List[ToolCallResult]] = None
     result: Optional[str] = None
@@ -33,8 +36,14 @@ def get_tool_usage_summary(self):
         [f"`{tool_call.description}`" for tool_call in self.tool_calls]
     )
 
-class ToolCallingLLM:
+
+class BaseToolCallingLLM(abc.ABC):
+    @abc.abstractmethod
+    def call(self, system_prompt: str, user_prompt: str) -> LLMResult:
+        raise NotImplementedError
+
+
+class OpenAIToolCallingLLM(BaseToolCallingLLM):
     def __init__(
         self,
         client: OpenAI,
@@ -47,7 +56,7 @@ def __init__(
         self.max_steps = max_steps
         self.model = model
 
-    def call(self, system_prompt, user_prompt) -> LLMResult:
+    def call(self, system_prompt: str, user_prompt: str) -> LLMResult:
         messages = [
             {
                 "role": "system",
@@ -67,11 +76,7 @@ def call(self, system_prompt, user_prompt) -> LLMResult:
             tool_choice = NOT_GIVEN if tools == NOT_GIVEN else "auto"
             try:
                 full_response = self.client.chat.completions.create(
-                    model=self.model,
-                    messages=messages,
-                    tools=tools,
-                    tool_choice=tool_choice,
-                    temperature=0.00000001
+                    model=self.model, messages=messages, tools=tools, tool_choice=tool_choice, temperature=0.00000001
                 )
                 logging.debug(f"got response {full_response}")
             # catch a known error that occurs with Azure and replace the error message with something more obvious to the user
@@ -84,11 +89,7 @@ def call(self, system_prompt, user_prompt) -> LLMResult:
                 raise
             response = full_response.choices[0]
             response_message = response.message
-            messages.append(
-                response_message.model_dump(
-                    exclude_defaults=True, exclude_unset=True, exclude_none=True
-                )
-            )
+            messages.append(response_message.model_dump(exclude_defaults=True, exclude_unset=True, exclude_none=True))
 
             tools_to_call = response_message.tool_calls
             if not tools_to_call:
@@ -107,9 +108,11 @@ def call(self, system_prompt, user_prompt) -> LLMResult:
                 tool_params = json.loads(t.function.arguments)
                 tool = self.tool_executor.get_tool_by_name(tool_name)
                 tool_response = tool.invoke(tool_params)
-                MAX_CHARS = 100_000 # an arbitrary limit - we will do something smarter in the future
+                MAX_CHARS = 100_000  # an arbitrary limit - we will do something smarter in the future
                 if len(tool_response) > MAX_CHARS:
-                    logging.warning(f"tool {tool_name} returned a very long response ({len(tool_response)} chars) - truncating to last 10000 chars")
+                    logging.warning(
+                        f"tool {tool_name} returned a very long response ({len(tool_response)} chars) - truncating to last {MAX_CHARS} chars"
+                    )
                     tool_response = tool_response[-MAX_CHARS:]
                 messages.append(
                     {
@@ -127,8 +130,30 @@ def call(self, system_prompt, user_prompt) -> LLMResult:
                 )
             )
 
-# TODO: consider getting rid of this entirely and moving templating into the cmds in holmes.py
-class IssueInvestigator(ToolCallingLLM):
+
+# TODO: consider getting rid of this entirely and moving templating into the cmds in holmes.py
+class BaseIssueInvestigator:
+    def call(self, system_prompt: str, user_prompt: str) -> LLMResult:
+        raise NotImplementedError()
+
+    def investigate(self, issue: Issue, prompt: str, console: Console) -> LLMResult:
+        environment = jinja2.Environment()
+        system_prompt_template = environment.from_string(prompt)
+        runbooks = self.runbook_manager.get_instructions_for_issue(issue)
+        if runbooks:
+            console.print(f"[bold]Analyzing with {len(runbooks)} runbooks: {runbooks}[/bold]")
+        else:
+            console.print(
+                f"[bold]No runbooks found for this issue. Using default behaviour. (Add runbooks to guide the investigation.)[/bold]"
+            )
+        system_prompt = system_prompt_template.render(issue=issue, runbooks=runbooks)
+        user_prompt = f"{issue.raw}"
+        logging.debug("Rendered system prompt:\n%s", textwrap.indent(system_prompt, "  "))
+        logging.debug("Rendered user prompt:\n%s", textwrap.indent(user_prompt, "  "))
+        return self.call(system_prompt, user_prompt)
+
+
+# NB: OpenAIToolCallingLLM must come before BaseIssueInvestigator in the MRO so
+# that its call() implementation takes precedence over the base class stub.
+class OpenAIIssueInvestigator(OpenAIToolCallingLLM, BaseIssueInvestigator):
     """
     Thin wrapper around ToolCallingLLM which:
     1) Provides a default prompt for RCA
@@ -146,27 +171,3 @@ def __init__(
     ):
         super().__init__(client, model, tool_executor, max_steps)
         self.runbook_manager = runbook_manager
-
-    def investigate(
-        self, issue: Issue, prompt: str, console: Console
-    ) -> LLMResult:
-        environment = jinja2.Environment()
-        system_prompt_template = environment.from_string(prompt)
-        runbooks = self.runbook_manager.get_instructions_for_issue(issue)
-        if runbooks:
-            console.print(
-                f"[bold]Analyzing with {len(runbooks)} runbooks: {runbooks}[/bold]"
-            )
-        else:
-            console.print(
-                f"[bold]No runbooks found for this issue. Using default behaviour. (Add runbooks to guide the investigation.)[/bold]"
-            )
-        system_prompt = system_prompt_template.render(issue=issue, runbooks=runbooks)
-        user_prompt = f"{issue.raw}"
-        logging.debug(
-            "Rendered system prompt:\n%s", textwrap.indent(system_prompt, "  ")
-        )
-        logging.debug(
-            "Rendered user prompt:\n%s", textwrap.indent(user_prompt, "  ")
-        )
-        return self.call(system_prompt, user_prompt)
diff --git a/holmes/plugins/runbooks/__init__.py b/holmes/plugins/runbooks/__init__.py
index 242c34a6..7b5d51e9 100644
--- a/holmes/plugins/runbooks/__init__.py
+++ b/holmes/plugins/runbooks/__init__.py
@@ -1,20 +1,21 @@
 import os
 import os.path
-from typing import List, Literal, Optional, Pattern, Union
+from typing import List, Optional, Pattern
 
-from pydantic import BaseModel, ConfigDict, Field, PrivateAttr
+from pydantic import BaseModel, PrivateAttr
 
-from holmes.utils.pydantic_utils import RobustaBaseConfig, load_model_from_file
+from holmes.utils.pydantic_utils import BaseConfig, load_model_from_file
 
 THIS_DIR = os.path.abspath(os.path.dirname(__file__))
 
-class IssueMatcher (RobustaBaseConfig):
+
+class IssueMatcher(BaseConfig):
     issue_id: Optional[Pattern] = None  # unique id
     issue_name: Optional[Pattern] = None  # not necessary unique
     source: Optional[Pattern] = None
 
-class Runbook(RobustaBaseConfig):
+class Runbook(BaseConfig):
     match: IssueMatcher
     instructions: str
 
@@ -26,15 +27,18 @@ def set_path(self, path: str):
     def get_path(self) -> str:
         return self._path
 
+
 class ListOfRunbooks(BaseModel):
     runbooks: List[Runbook]
 
+
 def load_runbooks_from_file(path: str) -> List[Runbook]:
     data: ListOfRunbooks = load_model_from_file(ListOfRunbooks, file_path=path)
     for runbook in data.runbooks:
         runbook.set_path(path)
     return data.runbooks
 
+
 def load_builtin_runbooks() -> List[Runbook]:
     all_runbooks = []
     for filename in os.listdir(THIS_DIR):
diff --git a/holmes/plugins/sources/jira/__init__.py b/holmes/plugins/sources/jira/__init__.py
index 764354cd..d25915ee 100644
--- a/holmes/plugins/sources/jira/__init__.py
+++ b/holmes/plugins/sources/jira/__init__.py
@@ -1,15 +1,12 @@
 import logging
-from typing import List, Literal, Optional, Pattern
+from typing import List, Pattern
 
-import humanize
 import requests
-from pydantic import BaseModel, SecretStr, ValidationError, parse_obj_as, validator
 from requests.auth import HTTPBasicAuth
 
 from holmes.core.issue import Issue
-from holmes.core.tool_calling_llm import LLMResult, ToolCallingLLM, ToolCallResult
+from holmes.core.tool_calling_llm import LLMResult
 from holmes.plugins.interfaces import SourcePlugin
-from holmes.plugins.utils import dict_to_markdown
 
 
 class JiraSource(SourcePlugin):
diff --git a/holmes/utils/auth.py b/holmes/utils/auth.py
new file mode 100644
index 00000000..f0699e7a
--- /dev/null
+++ b/holmes/utils/auth.py
@@ -0,0 +1,25 @@
+# from cachetools import cached
+from typing import Optional
+
+from holmes.core.supabase_dal import AuthToken, SupabaseDal
+
+
+class SessionManager:
+    def __init__(self, dal: SupabaseDal, token_type: str):
+        self.dal = dal
+        self.token_type = token_type
+        self.cached_token: Optional[AuthToken] = None
+
+    def get_current_token(self) -> Optional[AuthToken]:
+        if self.cached_token:
+            return self.cached_token
+        else:
+            return self.dal.get_freshest_auth_token(self.token_type)
+
+    def create_token(self) -> AuthToken:
+        new_token = self.dal.create_auth_token(self.token_type)
+        self.cached_token = new_token
+        return new_token
+
+    def invalidate_token(self, token: AuthToken):
+        self.dal.invalidate_auth_token(token)
diff --git a/holmes/utils/pydantic_utils.py b/holmes/utils/pydantic_utils.py
index 2ebe007e..cd7c0263 100644
--- a/holmes/utils/pydantic_utils.py
+++ b/holmes/utils/pydantic_utils.py
@@ -10,9 +10,15 @@
 
 PromptField = Annotated[str, BeforeValidator(lambda v: load_prompt(v))]
 
-class RobustaBaseConfig(BaseModel):
+
+class BaseConfig(BaseModel):
     model_config = ConfigDict(extra='forbid', validate_default=True)
 
+
+class EnvVarName(str):
+    pass
+
+
 def loc_to_dot_sep(loc: Tuple[Union[str, int], ...]) -> str:
     path = ""
     for i, x in enumerate(loc):
@@ -43,7 +49,7 @@ def load_model_from_file(
             contents = contents[yaml_path]
         return model.model_validate(contents)
     except ValidationError as e:
-        print(e)
+        print(e)  # FIXME bad
         bad_fields = [e["loc"] for e in convert_errors(e)]
         typer.secho(
             f"Invalid config file at {file_path}. Check the fields {bad_fields}.\nSee detailed errors above.",
diff --git a/server.py b/server.py
index a107bb85..4af77081 100644
--- a/server.py
+++ b/server.py
@@ -1,4 +1,6 @@
 import os
+import sys
+
 from holmes.utils.cert_utils import add_custom_certificate
 
 ADDITIONAL_CERTIFICATE: str = os.environ.get("CERTIFICATE", "")
@@ -9,39 +11,22 @@
 # IMPORTING ABOVE MIGHT INITIALIZE AN HTTPS CLIENT THAT DOESN'T TRUST THE CUSTOM CERTIFICATEE
 
 import logging
-import uvicorn
-import colorlog
-
 from typing import List, Union
-from fastapi import FastAPI
-from pydantic import BaseModel
+
+import colorlog
+import uvicorn
+from fastapi import FastAPI, HTTPException
 from rich.console import Console
 
-from holmes.common.env_vars import HOLMES_HOST, HOLMES_PORT, ALLOWED_TOOLSETS
-from holmes.core.supabase_dal import SupabaseDal
-from holmes.config import Config
+from holmes.common.env_vars import ALLOWED_TOOLSETS, HOLMES_HOST, HOLMES_PORT
+from holmes.config import BaseLLMConfig, LLMProviderType
 from holmes.core.issue import Issue
+from holmes.core.provider import LLMProviderFactory
+from holmes.core.server_models import InvestigateContext, InvestigateRequest
+from holmes.core.supabase_dal import SupabaseDal
+from holmes.core.tool_calling_llm import LLMError
 from holmes.plugins.prompts import load_prompt
-
-
-class InvestigateContext(BaseModel):
-    type: str
-    value: Union[str, dict]
-
-
-class InvestigateRequest(BaseModel):
-    source: str  # "prometheus" etc
-    title: str
-    description: str
-    subject: dict
-    context: List[InvestigateContext]
-    source_instance_id: str
-    include_tool_calls: bool = False
-    include_tool_call_results: bool = False
-    prompt_template: str = "builtin://generic_investigation.jinja2"
-    # TODO in the future
-    # response_handler: ...
+from holmes.utils.auth import SessionManager
 
 
 def init_logging():
@@ -49,8 +34,9 @@ def init_logging():
     logging_format = "%(log_color)s%(asctime)s.%(msecs)03d %(levelname)-8s %(message)s"
     logging_datefmt = "%Y-%m-%d %H:%M:%S"
 
-    print("setting up colored logging")
-    colorlog.basicConfig(format=logging_format, level=logging_level, datefmt=logging_datefmt)
+    colorlog.basicConfig(
+        format=logging_format, level=logging_level, datefmt=logging_datefmt
+    )
     logging.getLogger().setLevel(logging_level)
 
     httpx_logger = logging.getLogger("httpx")
@@ -61,34 +47,52 @@ def init_logging():
 
 
 init_logging()
+console = Console()
+config = BaseLLMConfig.load_from_env()
+logging.info(f"Starting AI server with config: {config}")
 dal = SupabaseDal()
+
+if not dal.initialized and config.llm_provider == LLMProviderType.ROBUSTA:
+    logging.error("Holmes cannot run without store configuration when the LLM provider is Robusta AI")
+    sys.exit(1)
+
+session_manager = SessionManager(dal, "AIRelay")
+provider_factory = LLMProviderFactory(config, session_manager)
 app = FastAPI()
 
-console = Console()
-config = Config.load_from_env()
+
+def fetch_context_data(context: List[InvestigateContext]) -> dict:
+    for context_item in context:
+        if context_item.type == "robusta_issue_id":
+            # Note we only accept a single robusta_issue_id. I don't think it
+            # makes sense to have several of them in the context structure.
+            return dal.get_issue_data(context_item.value)
 
 
 @app.post("/api/investigate")
-def investigate_issues(request: InvestigateRequest):
+def investigate_issue(request: InvestigateRequest):
     context = fetch_context_data(request.context)
     raw_data = request.model_dump()
+    raw_data.pop("model")
+    raw_data.pop("system_prompt")
     if context:
         raw_data["extra_context"] = context
 
-    ai = config.create_issue_investigator(console, allowed_toolsets=ALLOWED_TOOLSETS)
     issue = Issue(
-        id=context['id'] if context else "",
+        id=context["id"] if context else "",
         name=request.title,
         source_type=request.source,
         source_instance_id=request.source_instance_id,
         raw=raw_data,
     )
-    investigation = ai.investigate(
-        issue,
-        # TODO prompt should probably be configurable?
-        prompt=load_prompt(request.prompt),
-        console=console,
-    )
+
+    investigator = provider_factory.create_issue_investigator(console, allowed_toolsets=ALLOWED_TOOLSETS)
+    try:
+        investigation = investigator.investigate(
+            issue,
+            prompt=load_prompt(request.system_prompt),
+            console=console,
+        )
+    except LLMError as exc:
+        raise HTTPException(status_code=500, detail=f"Error calling the LLM provider: {str(exc)}")
 
     ret = {
         "analysis": investigation.result
     }
@@ -105,12 +109,5 @@ def investigate_issues(request: InvestigateRequest):
     return ret
 
 
-def fetch_context_data(context: List[InvestigateContext]) -> dict:
-    for context_item in context:
-        if context_item.type == "robusta_issue_id":
-            # Note we only accept a single robusta_issue_id. I don't think it
-            # makes sense to have several of them in the context structure.
-                return dal.get_issue_data(context_item.value)
-
 
 if __name__ == "__main__":
-    uvicorn.run(app, host=HOLMES_HOST, port=HOLMES_PORT)
\ No newline at end of file
+    uvicorn.run(app, host=HOLMES_HOST, port=HOLMES_PORT)
diff --git a/test-api.sh b/test-api.sh
index 49d4cc58..caf1495d 100755
--- a/test-api.sh
+++ b/test-api.sh
@@ -1,4 +1,4 @@
-curl -XPOST 127.0.0.1:8000/api/investigate -H "Content-Type: application/json" --data "{
+curl -XPOST 127.0.0.1:5050/api/investigate -H "Content-Type: application/json" --data "{
   \"source\": \"prometheus\",
   \"source_instance_id\": \"some-instance\",
   \"title\": \"Pod is crash looping.\",