Skip to content

Rule engine integration and UI improvements #4

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 11 commits into from
Apr 15, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,8 @@ source .venv/bin/activate # On Windows: .venv\Scripts\activate
uv pip install -e .
```

4. Create DB named `canvas` in postgres

## 🏃 Running the Application

Start the development server:
Expand Down
1 change: 0 additions & 1 deletion core/configs/dev.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,5 @@ server:
debug: True

database:
type: "postgresql"
url: "postgresql+asyncpg://postgres:postgres@localhost:5432/canvas"
schema_package: "core.repositories.schemas"
1 change: 0 additions & 1 deletion core/configs/prod.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,5 @@ server:
debug: False

database:
type: "postgresql"
url: "postgresql+asyncpg://postgres:postgres@localhost:5432/canvas"
schema_package: "core.repositories.schemas"
1 change: 0 additions & 1 deletion core/configs/stage.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,5 @@ server:
debug: False

database:
type: "postgresql"
url: "postgresql+asyncpg://postgres:postgres@localhost:5432/canvas"
schema_package: "core.repositories.schemas"
2 changes: 1 addition & 1 deletion core/controllers/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from .agent_controller import AgentController
from .base_controller import BaseController
from .crew_controller import CrewController
from .model_controller import ModelController
from .prompt_controller import PromptController
from .relation_controller import RelationController
from .task_controller import TaskController
from .flow_engine_controller import FlowEngineController
37 changes: 5 additions & 32 deletions core/controllers/base_controller.py
Original file line number Diff line number Diff line change
@@ -1,41 +1,14 @@
from core.services.agent.agent_service import AgentService
from core.services.crew.crew_service import CrewService
from core.services.model.model_service import ModelService
from core.services.prompt.prompt_service import PromptService
from core.services.relation.relation_service import RelationService
from core.services.task.task_service import TaskService
from shared.services.context_manager.context_manager_service import ContextManager


class BaseController:
class BaseController(ContextManager):
def __init__(self):
self.prompt_swagger_tags = ["Prompt"]
self.prompt_service: PromptService = PromptService()
super().__init__()

self.prompt_swagger_tags = ["Prompt"]
self.agent_swagger_tags = ["Agent"]
self.agent_service: AgentService = AgentService(
RelationService(), ModelService(), PromptService()
)

self.model_swagger_tags = ["Model"]
self.model_service: ModelService = ModelService()

self.relation_swagger_tags = ["Relation"]
self.relation_service: RelationService = RelationService()

self.task_swagger_tags = ["Task"]
self.task_service: TaskService = TaskService(
AgentService(RelationService(), ModelService(), PromptService()),
RelationService(),
PromptService(),
)

self.crew_swagger_tags = ["Crew"]
self.crew_service: CrewService = CrewService(
AgentService(RelationService(), ModelService(), PromptService()),
TaskService(
AgentService(RelationService(), ModelService(), PromptService()),
RelationService(),
PromptService(),
),
RelationService(),
)
self.flow_engine_swagger_tags = ["Flow Engine"]
170 changes: 170 additions & 0 deletions core/controllers/flow_engine_controller.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,170 @@
import importlib
import json
import os
import uuid
from typing import Dict, Any, List

from fastapi import APIRouter, HTTPException
from starlette.responses import StreamingResponse

from core.controllers.base_controller import BaseController
from flow_engine.flow_chain.dtos import (
NodeUiConfig,
FlowChain,
NodeTypes,
)
from flow_engine.flow_chain.services import FlowNodeRegistry
from shared.utils.funcs import get_root_path


class FlowEngineController(BaseController):
def __init__(self):
super().__init__()
self.router = APIRouter(tags=self.flow_engine_swagger_tags)

self.router.add_api_route(
"/node-types",
self.get_node_types,
methods=["GET"],
response_model=List[NodeUiConfig],
)
self.router.add_api_route(
"/crewai/flow-chains",
self.create_crewai_flow_chain,
methods=["POST"],
response_model=Dict[str, Any],
)
self.router.add_api_route(
"/crewai/flow-chains/{chain_id}/execute",
self.execute_crewai_flow_chain,
methods=["POST"],
response_model=Dict[str, Any],
)

@staticmethod
async def get_node_types() -> List[NodeUiConfig]:
node_types = []
root_path = get_root_path()
flow_nodes_path = os.path.join(root_path, "flow_engine", "flow_node")
node_ui_configs: List[NodeUiConfig] = []
for root, dirs, files in os.walk(flow_nodes_path):
if "ui_config.py" in files:
ui_config_path = os.path.join(root, "ui_config.py")
try:
spec = importlib.util.spec_from_file_location(
"ui_config", ui_config_path
)
ui_config_module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(ui_config_module)
if hasattr(ui_config_module, "ui_config"):
node_ui_configs.append(ui_config_module.ui_config)
except Exception as e:
print(f"Error loading {ui_config_path}: {e}")
return node_ui_configs

async def create_crewai_flow_chain(self, request: Dict[str, Any]) -> Dict[str, Any]:
try:
flow_chain = FlowChain(
id=str(uuid.uuid4()),
name=request.get("name", "Unnamed CrewAI Flow"),
description=request.get("description"),
nodes=request.get("nodes", []),
connections=request.get("connections", []),
debug_mode=request.get("debug_mode", False),
first_node_id=request.get("first_node_id"),
)
self.flow_chains[flow_chain.id] = flow_chain
nodes: List[Any] = []
for node_request in flow_chain.nodes:
node_class = FlowNodeRegistry.get_plugin(
request.get("node_template_id")
)
if not isinstance(node_request.configuration, dict):
node_request.configuration = node_request.configuration.model_dump()
node = node_class(
flow_chain_id=flow_chain.id,
node_id=node_request.id,
name=node_request.name,
node_type=node_request.node_type,
configuration=node_request.configuration,
connections=flow_chain.connections,
)
nodes.append(node)
self.flow_nodes[flow_chain.id] = nodes
return {
"flow_chain_id": flow_chain.id,
"message": "CrewAI flow chain created successfully",
"crew_created": True,
}
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))

async def execute_crewai_flow_chain(self, chain_id: str, request: Dict[str, Any]):
try:
flow_chain = self.flow_chains[chain_id]
if not flow_chain:
raise HTTPException(
status_code=404, detail="CrewAI flow chain not found"
)
first_node_id = flow_chain.first_node_id
self.flow_chain_events[chain_id] = {}
self.flow_chain_events[chain_id]["traversed_node"] = []
self.flow_chain_events[chain_id]["traversed_node"].append(first_node_id)
flow_nodes = self.flow_nodes.get(flow_chain.id)
total_tool_nodes = len(
[
node
for node in flow_nodes
if node.type == NodeTypes.TOOL or node.type == NodeTypes.CREW
]
)

async def stream_data():
last_index = 0
user_msg = request
has_init = False
while True:
if not has_init:
self.event.set()
has_init = True
await self.event.wait()
self.event.clear()

while last_index <= total_tool_nodes - 1:
traversed_node = self.flow_chain_events.get(flow_chain.id).get(
"traversed_node"
)
if len(traversed_node) == last_index:
yield json.dumps(
{
"flow_chain_id": chain_id,
"message": "Flow chain completed",
}
).encode("utf-8")
return
flow_node_id = traversed_node[last_index]
all_msg = self.flow_chain_events.get(flow_chain.id).get(
"message"
)
node_msg = None
if all_msg:
node_msg = all_msg.get(flow_node_id)
message = node_msg if node_msg else user_msg

flow_node = next(
(node for node in flow_nodes if node.id == flow_node_id),
None,
)
flow_node.process(message)
last_index += 1
yield json.dumps(
self.flow_chain_events.get(flow_chain.id)
).encode("utf-8")
if last_index == len(flow_chain.nodes) - 1:
yield json.dumps(
{"flow_chain_id": chain_id, "result": message}
).encode("utf-8")

return StreamingResponse(stream_data(), media_type="text/event-stream")
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
2 changes: 1 addition & 1 deletion core/controllers/task_controller.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from fastapi import APIRouter
from pydantic import UUID4

from core.controllers import BaseController
from core.controllers.base_controller import BaseController
from core.dtos.task import TaskDto


Expand Down
3 changes: 0 additions & 3 deletions core/dtos/settings/database_settings_dto.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,6 @@
from pydantic import BaseModel

from core.dtos.settings.database_types import DatabaseTypes


class DatabaseSettingsDto(BaseModel):
type: DatabaseTypes = DatabaseTypes.postgresql
url: str
schema_package: str
5 changes: 0 additions & 5 deletions core/dtos/settings/database_types.py

This file was deleted.

2 changes: 1 addition & 1 deletion core/repositories/agent_repository.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from core.repositories.base_repository import BaseRepository
from core.repositories.schemas.agent_schema import AgentSchema
from core.services.database import database_service
from shared.services.database import database_service


class AgentRepository(BaseRepository[AgentSchema]):
Expand Down
2 changes: 1 addition & 1 deletion core/repositories/crew_repository.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from core.repositories.base_repository import BaseRepository
from core.repositories.schemas.crew_schema import CrewSchema
from core.services.database import database_service
from shared.services.database import database_service


class CrewRepository(BaseRepository[CrewSchema]):
Expand Down
2 changes: 1 addition & 1 deletion core/repositories/model_repository.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from core.repositories.base_repository import BaseRepository
from core.repositories.schemas.model_schema import ModelSchema
from core.services.database import database_service
from shared.services.database import database_service


class ModelRepository(BaseRepository[ModelSchema]):
Expand Down
2 changes: 1 addition & 1 deletion core/repositories/prompt_repository.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from core.repositories.base_repository import BaseRepository
from core.repositories.schemas.prompt_schema import PromptSchema
from core.services.database import database_service
from shared.services.database import database_service


class PromptRepository(BaseRepository[PromptSchema]):
Expand Down
3 changes: 2 additions & 1 deletion core/repositories/relation_repository.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
from typing import List

from sqlmodel import select

from core.dtos.entity import EntityType
from core.dtos.entity.entity import Entity
from core.dtos.relation.relation_direction import RelationDirection
from core.repositories.base_repository import BaseRepository
from core.repositories.schemas.relation_schema import RelationSchema
from core.services.database import database_service
from shared.services.database import database_service


class RelationRepository(BaseRepository[RelationSchema]):
Expand Down
1 change: 1 addition & 0 deletions core/repositories/schemas/model_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ class ModelSchema(BaseSchema, table=True):
model_type: str = Field(nullable=False)
description: Optional[str] = Field(nullable=True)
api_key: Optional[str] = Field(nullable=True)
# @todo add is_default prop

created_at: datetime = Field(
sa_column=Column(DateTime(timezone=True), server_default=func.now())
Expand Down
2 changes: 1 addition & 1 deletion core/repositories/task_repository.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from core.repositories.base_repository import BaseRepository
from core.repositories.schemas.task_schema import TaskSchema
from core.services.database import database_service
from shared.services.database import database_service


class TaskRepository(BaseRepository[TaskSchema]):
Expand Down
4 changes: 3 additions & 1 deletion core/routers.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from fastapi.routing import APIRouter
from fastapi import APIRouter

from core.controllers import (
AgentController,
Expand All @@ -7,6 +7,7 @@
ModelController,
RelationController,
CrewController,
FlowEngineController,
)

routes: list[APIRouter] = [
Expand All @@ -16,4 +17,5 @@
ModelController().router,
RelationController().router,
CrewController().router,
FlowEngineController().router,
]
12 changes: 7 additions & 5 deletions core/services/agent/agent_service.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Optional, List
from typing import Optional, List, Dict

from crewai import Agent
from pydantic import validate_call
Expand All @@ -18,6 +18,8 @@


class AgentService(BaseService[AgentDto, Agent]):
agent_flows: Dict[str, List["AgentFlow"]] = {}

def __init__(
self,
relation_service: RelationService,
Expand All @@ -40,10 +42,10 @@ async def build(self, entity: AgentEntity):
llm = await self.model_service.build(
ModelEntity(agent_entity_details.llm_id)
)
prompt_entities: List[
PromptEntity
] = await self.relation_service.get_related_entities(
entity, RelationDirection.TO, PromptEntity
prompt_entities: List[PromptEntity] = (
await self.relation_service.get_related_entities(
entity, RelationDirection.TO, PromptEntity
)
)
prompt_details: list[PromptDto] = await self.prompt_service.read_by_ids(
prompt_entities
Expand Down
Loading