Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 14 additions & 12 deletions src/crewai/flow/flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
from crewai.flow.persistence.base import FlowPersistence
from crewai.flow.types import FlowExecutionData
from crewai.flow.utils import get_possible_return_constants
from crewai.utilities.printer import Printer
from crewai.utilities.printer import Printer, PrinterColor

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -465,7 +465,7 @@ def __init__(
self._is_execution_resuming: bool = False

# Initialize state with initial values
self._state = self._create_initial_state()
self._state = self._create_initial_state(kwargs)
self.tracing = tracing
if (
is_tracing_enabled()
Expand All @@ -474,9 +474,6 @@ def __init__(
):
trace_listener = TraceCollectionListener()
trace_listener.setup_listeners(crewai_event_bus)
# Apply any additional kwargs
if kwargs:
self._initialize_state(kwargs)

crewai_event_bus.emit(
self,
Expand All @@ -502,9 +499,12 @@ def __init__(
method = method.__get__(self, self.__class__)
self._methods[method_name] = method

def _create_initial_state(self) -> T:
def _create_initial_state(self, kwargs: dict[str, Any] | None = None) -> T:
"""Create and initialize flow state with UUID and default values.

Args:
kwargs: Optional initial values for state fields

Returns:
New state instance with UUID and default values initialized

Expand All @@ -518,7 +518,8 @@ def _create_initial_state(self) -> T:
if isinstance(state_type, type):
if issubclass(state_type, FlowState):
# Create instance without id, then set it
instance = state_type()
init_kwargs = kwargs or {}
instance = state_type(**init_kwargs)
if not hasattr(instance, "id"):
instance.id = str(uuid4())
return cast(T, instance)
Expand All @@ -527,7 +528,8 @@ def _create_initial_state(self) -> T:
class StateWithId(state_type, FlowState): # type: ignore
pass

instance = StateWithId()
init_kwargs = kwargs or {}
instance = StateWithId(**init_kwargs)
if not hasattr(instance, "id"):
instance.id = str(uuid4())
return cast(T, instance)
Expand All @@ -541,13 +543,13 @@ class StateWithId(state_type, FlowState): # type: ignore
# Handle case where initial_state is a type (class)
if isinstance(self.initial_state, type):
if issubclass(self.initial_state, FlowState):
return cast(T, self.initial_state()) # Uses model defaults
return cast(T, self.initial_state(**(kwargs or {})))
if issubclass(self.initial_state, BaseModel):
# Validate that the model has an id field
model_fields = getattr(self.initial_state, "model_fields", None)
if not model_fields or "id" not in model_fields:
raise ValueError("Flow state model must have an 'id' field")
return cast(T, self.initial_state()) # Uses model defaults
return cast(T, self.initial_state(**(kwargs or {})))
if self.initial_state is dict:
return cast(T, {"id": str(uuid4())})

Expand Down Expand Up @@ -1086,7 +1088,7 @@ async def _execute_listeners(self, trigger_method: str, result: Any) -> None:
for method_name in self._start_methods:
# Check if this start method is triggered by the current trigger
if method_name in self._listeners:
condition_type, trigger_methods = self._listeners[
_, trigger_methods = self._listeners[
method_name
]
if current_trigger in trigger_methods:
Expand Down Expand Up @@ -1218,7 +1220,7 @@ async def _execute_single_listener(self, listener_name: str, result: Any) -> Non
raise

def _log_flow_event(
self, message: str, color: str = "yellow", level: str = "info"
self, message: str, color: PrinterColor = "yellow", level: str = "info"
) -> None:
"""Centralized logging method for flow events.

Expand Down
118 changes: 113 additions & 5 deletions tests/test_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,17 +4,17 @@
from datetime import datetime

import pytest
from pydantic import BaseModel
from pydantic import BaseModel, ValidationError

from crewai.flow.flow import Flow, and_, listen, or_, router, start
from crewai.events.event_bus import crewai_event_bus
from crewai.events.types.flow_events import (
FlowFinishedEvent,
FlowStartedEvent,
FlowPlotEvent,
FlowStartedEvent,
MethodExecutionFinishedEvent,
MethodExecutionStartedEvent,
)
from crewai.flow.flow import Flow, and_, listen, or_, router, start


def test_simple_sequential_flow():
Expand Down Expand Up @@ -679,11 +679,11 @@ def handle_flow_end(source, event):
assert isinstance(received_events[3], MethodExecutionStartedEvent)
assert received_events[3].method_name == "send_welcome_message"
assert received_events[3].params == {}
assert getattr(received_events[3].state, "sent") is False
assert received_events[3].state.sent is False

assert isinstance(received_events[4], MethodExecutionFinishedEvent)
assert received_events[4].method_name == "send_welcome_message"
assert getattr(received_events[4].state, "sent") is True
assert received_events[4].state.sent is True
assert received_events[4].result == "Welcome, Anakin!"

assert isinstance(received_events[5], FlowFinishedEvent)
Expand Down Expand Up @@ -894,3 +894,111 @@ def start(self):

flow = MyFlow()
assert flow.name == "MyFlow"


def test_flow_init_with_required_fields():
"""Test Flow initialization with Pydantic models having required fields."""

class RequiredFieldsState(BaseModel):
name: str
age: int

class RequiredFieldsFlow(Flow[RequiredFieldsState]):
@start()
def step_1(self):
assert self.state.name == "Alice"
assert self.state.age == 30

flow = RequiredFieldsFlow(name="Alice", age=30)
flow.kickoff()

assert flow.state.name == "Alice"
assert flow.state.age == 30
assert hasattr(flow.state, "id")
assert len(flow.state.id) == 36


def test_flow_init_with_required_fields_missing_values():
"""Test that Flow initialization fails when required fields are missing."""

class RequiredFieldsState(BaseModel):
name: str
age: int

class RequiredFieldsFlow(Flow[RequiredFieldsState]):
@start()
def step_1(self):
pass

with pytest.raises(ValidationError):
RequiredFieldsFlow()


def test_flow_init_with_mixed_required_optional_fields():
"""Test Flow with both required and optional fields."""

class MixedFieldsState(BaseModel):
name: str
age: int = 25
city: str | None = None

class MixedFieldsFlow(Flow[MixedFieldsState]):
@start()
def step_1(self):
assert self.state.name == "Bob"
assert self.state.age == 25
assert self.state.city is None

flow = MixedFieldsFlow(name="Bob")
flow.kickoff()

assert flow.state.name == "Bob"
assert flow.state.age == 25
assert flow.state.city is None


def test_flow_init_with_required_fields_and_overrides():
"""Test that kwargs override default values."""

class DefaultFieldsState(BaseModel):
name: str
age: int = 18
active: bool = True

class DefaultFieldsFlow(Flow[DefaultFieldsState]):
@start()
def step_1(self):
assert self.state.name == "Charlie"
assert self.state.age == 35
assert self.state.active is False

flow = DefaultFieldsFlow(name="Charlie", age=35, active=False)
flow.kickoff()

assert flow.state.name == "Charlie"
assert flow.state.age == 35
assert flow.state.active is False


def test_flow_init_backward_compatibility_with_flowstate():
"""Test that existing FlowState subclasses still work."""
from crewai.flow.flow import FlowState

class MyFlowState(FlowState):
counter: int = 0
message: str = "default"

class BackwardCompatFlow(Flow[MyFlowState]):
@start()
def step_1(self):
self.state.counter += 1

flow1 = BackwardCompatFlow()
flow1.kickoff()
assert flow1.state.counter == 1
assert flow1.state.message == "default"

flow2 = BackwardCompatFlow(counter=10, message="custom")
flow2.kickoff()
assert flow2.state.counter == 11
assert flow2.state.message == "custom"