diff --git a/src/crewai/flow/flow.py b/src/crewai/flow/flow.py index 85bb077ee9..872f939814 100644 --- a/src/crewai/flow/flow.py +++ b/src/crewai/flow/flow.py @@ -31,7 +31,7 @@ from crewai.flow.persistence.base import FlowPersistence from crewai.flow.types import FlowExecutionData from crewai.flow.utils import get_possible_return_constants -from crewai.utilities.printer import Printer +from crewai.utilities.printer import Printer, PrinterColor logger = logging.getLogger(__name__) @@ -465,7 +465,7 @@ def __init__( self._is_execution_resuming: bool = False # Initialize state with initial values - self._state = self._create_initial_state() + self._state = self._create_initial_state(kwargs) self.tracing = tracing if ( is_tracing_enabled() @@ -474,9 +474,6 @@ def __init__( ): trace_listener = TraceCollectionListener() trace_listener.setup_listeners(crewai_event_bus) - # Apply any additional kwargs - if kwargs: - self._initialize_state(kwargs) crewai_event_bus.emit( self, @@ -502,9 +499,12 @@ def __init__( method = method.__get__(self, self.__class__) self._methods[method_name] = method - def _create_initial_state(self) -> T: + def _create_initial_state(self, kwargs: dict[str, Any] | None = None) -> T: """Create and initialize flow state with UUID and default values. + Args: + kwargs: Optional initial values for state fields + Returns: New state instance with UUID and default values initialized @@ -518,7 +518,8 @@ def _create_initial_state(self) -> T: if isinstance(state_type, type): if issubclass(state_type, FlowState): # Create instance without id, then set it - instance = state_type() + init_kwargs = kwargs or {} + instance = state_type(**init_kwargs) if not hasattr(instance, "id"): instance.id = str(uuid4()) return cast(T, instance) @@ -527,7 +528,8 @@ def _create_initial_state(self) -> T: class StateWithId(state_type, FlowState): # type: ignore pass - instance = StateWithId() + init_kwargs = kwargs or {} + instance = StateWithId(**init_kwargs) if not hasattr(instance, "id"): instance.id = str(uuid4()) return cast(T, instance) @@ -541,13 +543,13 @@ class StateWithId(state_type, FlowState): # type: ignore # Handle case where initial_state is a type (class) if isinstance(self.initial_state, type): if issubclass(self.initial_state, FlowState): - return cast(T, self.initial_state()) # Uses model defaults + return cast(T, self.initial_state(**(kwargs or {}))) if issubclass(self.initial_state, BaseModel): # Validate that the model has an id field model_fields = getattr(self.initial_state, "model_fields", None) if not model_fields or "id" not in model_fields: raise ValueError("Flow state model must have an 'id' field") - return cast(T, self.initial_state()) # Uses model defaults + return cast(T, self.initial_state(**(kwargs or {}))) if self.initial_state is dict: return cast(T, {"id": str(uuid4())}) @@ -1086,7 +1088,7 @@ async def _execute_listeners(self, trigger_method: str, result: Any) -> None: for method_name in self._start_methods: # Check if this start method is triggered by the current trigger if method_name in self._listeners: - condition_type, trigger_methods = self._listeners[ + _, trigger_methods = self._listeners[ method_name ] if current_trigger in trigger_methods: @@ -1218,7 +1220,7 @@ async def _execute_single_listener(self, listener_name: str, result: Any) -> Non raise def _log_flow_event( - self, message: str, color: str = "yellow", level: str = "info" + self, message: str, color: PrinterColor = "yellow", level: str = "info" ) -> None: """Centralized logging method for flow events. diff --git a/tests/test_flow.py b/tests/test_flow.py index 504cf8e6e9..1ba229239b 100644 --- a/tests/test_flow.py +++ b/tests/test_flow.py @@ -4,17 +4,17 @@ from datetime import datetime import pytest -from pydantic import BaseModel +from pydantic import BaseModel, ValidationError -from crewai.flow.flow import Flow, and_, listen, or_, router, start from crewai.events.event_bus import crewai_event_bus from crewai.events.types.flow_events import ( FlowFinishedEvent, - FlowStartedEvent, FlowPlotEvent, + FlowStartedEvent, MethodExecutionFinishedEvent, MethodExecutionStartedEvent, ) +from crewai.flow.flow import Flow, and_, listen, or_, router, start def test_simple_sequential_flow(): @@ -679,11 +679,11 @@ def handle_flow_end(source, event): assert isinstance(received_events[3], MethodExecutionStartedEvent) assert received_events[3].method_name == "send_welcome_message" assert received_events[3].params == {} - assert getattr(received_events[3].state, "sent") is False + assert received_events[3].state.sent is False assert isinstance(received_events[4], MethodExecutionFinishedEvent) assert received_events[4].method_name == "send_welcome_message" - assert getattr(received_events[4].state, "sent") is True + assert received_events[4].state.sent is True assert received_events[4].result == "Welcome, Anakin!" assert isinstance(received_events[5], FlowFinishedEvent) @@ -894,3 +894,111 @@ def start(self): flow = MyFlow() assert flow.name == "MyFlow" + + +def test_flow_init_with_required_fields(): + """Test Flow initialization with Pydantic models having required fields.""" + + class RequiredFieldsState(BaseModel): + name: str + age: int + + class RequiredFieldsFlow(Flow[RequiredFieldsState]): + @start() + def step_1(self): + assert self.state.name == "Alice" + assert self.state.age == 30 + + flow = RequiredFieldsFlow(name="Alice", age=30) + flow.kickoff() + + assert flow.state.name == "Alice" + assert flow.state.age == 30 + assert hasattr(flow.state, "id") + assert len(flow.state.id) == 36 + + +def test_flow_init_with_required_fields_missing_values(): + """Test that Flow initialization fails when required fields are missing.""" + + class RequiredFieldsState(BaseModel): + name: str + age: int + + class RequiredFieldsFlow(Flow[RequiredFieldsState]): + @start() + def step_1(self): + pass + + with pytest.raises(ValidationError): + RequiredFieldsFlow() + + +def test_flow_init_with_mixed_required_optional_fields(): + """Test Flow with both required and optional fields.""" + + class MixedFieldsState(BaseModel): + name: str + age: int = 25 + city: str | None = None + + class MixedFieldsFlow(Flow[MixedFieldsState]): + @start() + def step_1(self): + assert self.state.name == "Bob" + assert self.state.age == 25 + assert self.state.city is None + + flow = MixedFieldsFlow(name="Bob") + flow.kickoff() + + assert flow.state.name == "Bob" + assert flow.state.age == 25 + assert flow.state.city is None + + +def test_flow_init_with_required_fields_and_overrides(): + """Test that kwargs override default values.""" + + class DefaultFieldsState(BaseModel): + name: str + age: int = 18 + active: bool = True + + class DefaultFieldsFlow(Flow[DefaultFieldsState]): + @start() + def step_1(self): + assert self.state.name == "Charlie" + assert self.state.age == 35 + assert self.state.active is False + + flow = DefaultFieldsFlow(name="Charlie", age=35, active=False) + flow.kickoff() + + assert flow.state.name == "Charlie" + assert flow.state.age == 35 + assert flow.state.active is False + + +def test_flow_init_backward_compatibility_with_flowstate(): + """Test that existing FlowState subclasses still work.""" + from crewai.flow.flow import FlowState + + class MyFlowState(FlowState): + counter: int = 0 + message: str = "default" + + class BackwardCompatFlow(Flow[MyFlowState]): + @start() + def step_1(self): + self.state.counter += 1 + + flow1 = BackwardCompatFlow() + flow1.kickoff() + assert flow1.state.counter == 1 + assert flow1.state.message == "default" + + flow2 = BackwardCompatFlow(counter=10, message="custom") + flow2.kickoff() + assert flow2.state.counter == 11 + assert flow2.state.message == "custom"