diff --git a/backend/consensus/decisions.py b/backend/consensus/decisions.py index 1aef1228a..0c4386ef5 100644 --- a/backend/consensus/decisions.py +++ b/backend/consensus/decisions.py @@ -9,6 +9,7 @@ from typing import Any +from backend.node.types import NodeConfig from backend.consensus.effects import ( AddTimestampEffect, StatusUpdateEffect, @@ -320,7 +321,7 @@ def decide_accepted( contract_address: str | None, code_slot_b64: str | None, to_address: str, - leader_node_config: dict, + leader_node_config: NodeConfig, ) -> tuple[list[Effect], list[Effect], ConsensusRound, ConsensusRound | None]: """Decide effects for AcceptedState. @@ -478,7 +479,7 @@ def decide_finalizing( tx_hash: str, tx_status_accepted: bool, execution_result_success: bool, - leader_node_config: dict, + leader_node_config: NodeConfig, ) -> tuple[list[Effect], list[Effect], bool]: """Decide effects for FinalizingState. diff --git a/backend/domain/types.py b/backend/domain/types.py index 3fd56c88d..4edb43c81 100644 --- a/backend/domain/types.py +++ b/backend/domain/types.py @@ -6,19 +6,40 @@ import datetime from enum import Enum, IntEnum import os +from typing import NotRequired, TypedDict from backend.database_handler.models import TransactionStatus from backend.database_handler.types import ConsensusData from backend.database_handler.contract_snapshot import ContractSnapshot +class LLMProviderConfig(TypedDict, total=False): + temperature: float + max_tokens: int + use_max_completion_tokens: bool + + +class PluginConfig(TypedDict, total=False): + api_key_env_var: str + api_url: str + mock_response: dict + + +class ProviderParams(TypedDict): + provider: str + model: str + config: LLMProviderConfig + plugin: str + plugin_config: PluginConfig + + @dataclass class SimValidatorConfig: stake: int provider: str model: str - config: dict + config: LLMProviderConfig plugin: str - plugin_config: dict + plugin_config: PluginConfig @classmethod def from_dict(cls, d: dict) -> "SimValidatorConfig": @@ -80,9 +101,9 @@ def to_dict(self) -> dict: class LLMProvider: provider: str model: str - config: dict + config: LLMProviderConfig plugin: str - plugin_config: dict + plugin_config: PluginConfig id: int | None = None def __hash__(self): diff --git a/backend/node/base.py b/backend/node/base.py index bed162b13..ce5e770f7 100644 --- a/backend/node/base.py +++ b/backend/node/base.py @@ -18,7 +18,7 @@ import backend.node.genvm.base as genvmbase import backend.node.genvm.origin.calldata as calldata from backend.database_handler.contract_snapshot import ContractSnapshot -from backend.node.types import Receipt, ExecutionMode, Vote, ExecutionResultStatus +from backend.node.types import Receipt, ExecutionMode, Vote, ExecutionResultStatus, NodeConfig from backend.protocol_rpc.message_handler.base import IMessageHandler from .genvm.origin import logger as genvm_logger from .genvm.origin import public_abi @@ -659,7 +659,7 @@ async def exec_transaction(self, transaction: Transaction) -> Receipt: return receipt - def _create_enhanced_node_config(self, host_data: dict | None) -> dict: + def _create_enhanced_node_config(self, host_data: dict | None) -> NodeConfig: """ Create enhanced node_config that includes both primary and fallback provider info. diff --git a/backend/node/types.py b/backend/node/types.py index d49c83ca1..187fd61d9 100644 --- a/backend/node/types.py +++ b/backend/node/types.py @@ -1,6 +1,6 @@ from dataclasses import dataclass, field from enum import Enum -from typing import Iterable, Optional, Literal +from typing import Iterable, Optional, Literal, NotRequired, TypedDict import base64 import hashlib import json @@ -226,6 +226,21 @@ def from_dict(cls, input: dict) -> "PendingTransaction": ) +class NodeConfig(TypedDict, total=False): + address: str + stake: int + provider: str + model: str + config: dict + plugin: str + plugin_config: dict + private_key: str | None + fallback_validator: str | None + id: int + primary_model: dict | None + secondary_model: dict | None + + @dataclass class Receipt: result: bytes @@ -233,7 +248,7 @@ class Receipt: gas_used: int mode: ExecutionMode contract_state: dict[str, str] - node_config: dict + node_config: NodeConfig execution_result: ExecutionResultStatus eq_outputs: dict[int, str] | None = None vote: Optional[Vote] = None diff --git a/backend/protocol_rpc/endpoints.py b/backend/protocol_rpc/endpoints.py index 795de37d5..b23e3877d 100644 --- a/backend/protocol_rpc/endpoints.py +++ b/backend/protocol_rpc/endpoints.py @@ -15,7 +15,7 @@ from backend.database_handler.llm_providers import LLMProviderRegistry from backend.rollup.consensus_service import ConsensusService from backend.database_handler.models import Base, TransactionStatus -from backend.domain.types import LLMProvider, Validator, TransactionType, SimConfig +from backend.domain.types import LLMProvider, LLMProviderConfig, PluginConfig, ProviderParams, Validator, TransactionType, SimConfig from backend.node.create_nodes.providers import ( get_default_provider_for, validate_provider, @@ -270,7 +270,7 @@ async def check_with_semaphore(genvm_manager, provider): @check_forbidden_method_in_hosted_studio -def add_provider(session: Session, params: dict) -> int: +def add_provider(session: Session, params: ProviderParams) -> int: """Add a provider using the request-scoped session.""" llm_provider_registry = LLMProviderRegistry(session) @@ -288,7 +288,7 @@ def add_provider(session: Session, params: dict) -> int: @check_forbidden_method_in_hosted_studio -def update_provider(session: Session, id: int, params: dict) -> None: +def update_provider(session: Session, id: int, params: ProviderParams) -> None: """Update a provider using the request-scoped session.""" llm_provider_registry = LLMProviderRegistry(session) @@ -317,9 +317,9 @@ async def create_validator( stake: int, provider: str, model: str, - config: dict | None = None, + config: LLMProviderConfig | None = None, plugin: str | None = None, - plugin_config: dict | None = None, + plugin_config: PluginConfig | None = None, ) -> dict: # fallback for default provider llm_provider = None @@ -420,9 +420,9 @@ async def update_validator( stake: int, provider: str, model: str, - config: dict | None = None, + config: LLMProviderConfig | None = None, plugin: str | None = None, - plugin_config: dict | None = None, + plugin_config: PluginConfig | None = None, ) -> dict: # Remove validation while adding migration to update the db address # if not accounts_manager.is_valid_address(validator_address): diff --git a/backend/protocol_rpc/rpc_methods.py b/backend/protocol_rpc/rpc_methods.py index 1c8785850..5deb63ebc 100644 --- a/backend/protocol_rpc/rpc_methods.py +++ b/backend/protocol_rpc/rpc_methods.py @@ -26,6 +26,7 @@ ) from backend.protocol_rpc.rpc_decorators import rpc from backend.protocol_rpc.rpc_endpoint_manager import LogPolicy +from backend.domain.types import LLMProviderConfig, PluginConfig, ProviderParams # --------------------------------------------------------------------------- @@ -79,7 +80,7 @@ def reset_defaults_llm_providers( @rpc.method("sim_addProvider") def add_provider( - params: dict, + params: ProviderParams, session: Session = Depends(get_db_session), ) -> int: return impl.add_provider(session=session, params=params) @@ -88,7 +89,7 @@ def add_provider( @rpc.method("sim_updateProvider") def update_provider( id: int, - params: dict, + params: ProviderParams, session: Session = Depends(get_db_session), ) -> None: return impl.update_provider(session=session, id=id, params=params) @@ -107,9 +108,9 @@ async def create_validator( stake: int, provider: str, model: str, - config: dict | None = None, + config: LLMProviderConfig | None = None, plugin: str | None = None, - plugin_config: dict | None = None, + plugin_config: PluginConfig | None = None, session: Session = Depends(get_db_session), validators_manager=Depends(get_validators_manager), ) -> dict: @@ -170,7 +171,7 @@ async def update_validator( provider: str | None = None, model: str | None = None, plugin: str | None = None, - plugin_config: dict | None = None, + plugin_config: PluginConfig | None = None, session: Session = Depends(get_db_session), validators_manager=Depends(get_validators_manager), ) -> dict: