Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 0 additions & 4 deletions docs/en/concepts/tasks.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -63,10 +63,6 @@ crew = Crew(
| **Guardrails** _(optional)_ | `guardrails` | `Optional[List[Callable] | List[str]]` | List of guardrails to validate task output before proceeding to next task. |
| **Guardrail Max Retries** _(optional)_ | `guardrail_max_retries` | `Optional[int]` | Maximum number of retries when guardrail validation fails. Defaults to 3. |

<Note type="warning" title="Deprecated: max_retries">
The task attribute `max_retries` is deprecated and will be removed in v1.0.0.
Use `guardrail_max_retries` instead to control retry attempts when a guardrail fails.
</Note>

## Creating Tasks

Expand Down
84 changes: 34 additions & 50 deletions lib/crewai/src/crewai/task.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

from collections.abc import Callable
from concurrent.futures import Future
from copy import copy as shallow_copy
import datetime
Expand All @@ -10,18 +11,19 @@
from pathlib import Path
import threading
from typing import (
TYPE_CHECKING,
Any,
ClassVar,
cast,
get_args,
get_origin,
)
import uuid
import warnings

from pydantic import (
UUID4,
BaseModel,
ConfigDict,
Field,
PrivateAttr,
field_validator,
Expand All @@ -37,6 +39,7 @@
TaskFailedEvent,
TaskStartedEvent,
)
from crewai.llms.base_llm import BaseLLM
from crewai.security import Fingerprint, SecurityConfig
from crewai.tasks.output_format import OutputFormat
from crewai.tasks.task_output import TaskOutput
Expand All @@ -57,6 +60,9 @@
from crewai.utilities.string_utils import interpolate_only


if TYPE_CHECKING:
from crewai.agent.core import Agent

_printer = Printer()


Expand Down Expand Up @@ -101,17 +107,17 @@ class Task(BaseModel):
description="Configuration for the agent",
default=None,
)
callback: Any | None = Field(
callback: Callable[[TaskOutput], None] | None = Field(
description="Callback to be executed after the task is completed.", default=None
)
agent: BaseAgent | None = Field(
agent: Agent | None = Field(
description="Agent responsible for execution the task.", default=None
)
context: list[Task] | None | _NotSpecified = Field(
description="Other tasks that will have their output used as context for this task.",
default=NOT_SPECIFIED,
)
async_execution: bool | None = Field(
async_execution: bool = Field(
description="Whether the task should be executed asynchronously or not.",
default=False,
)
Expand Down Expand Up @@ -151,11 +157,11 @@ class Task(BaseModel):
frozen=True,
description="Unique identifier for the object, not set by user.",
)
human_input: bool | None = Field(
human_input: bool = Field(
description="Whether the task should have a human review the final answer of the agent",
default=False,
)
markdown: bool | None = Field(
markdown: bool = Field(
description="Whether the task should instruct the agent to return the final answer formatted in Markdown",
default=False,
)
Expand All @@ -172,11 +178,6 @@ class Task(BaseModel):
default=None,
description="List of guardrails to validate task output before proceeding to next task. Also supports a single guardrail function or string description of a guardrail to validate task output before proceeding to next task",
)

max_retries: int | None = Field(
default=None,
description="[DEPRECATED] Maximum number of retries when guardrail fails. Use guardrail_max_retries instead. Will be removed in v1.0.0",
)
guardrail_max_retries: int = Field(
default=3, description="Maximum number of retries when guardrail fails"
)
Expand All @@ -187,8 +188,8 @@ class Task(BaseModel):
end_time: datetime.datetime | None = Field(
default=None, description="End time of the task execution"
)
allow_crewai_trigger_context: bool | None = Field(
default=None,
allow_crewai_trigger_context: bool = Field(
default=False,
description="Whether this task should append 'Trigger Payload: {crewai_trigger_payload}' to the task description when crewai_trigger_payload exists in crew inputs.",
)
_guardrail: GuardrailCallable | None = PrivateAttr(default=None)
Expand All @@ -202,7 +203,9 @@ class Task(BaseModel):
_original_expected_output: str | None = PrivateAttr(default=None)
_original_output_file: str | None = PrivateAttr(default=None)
_thread: threading.Thread | None = PrivateAttr(default=None)
model_config = {"arbitrary_types_allowed": True}
model_config = ConfigDict(
arbitrary_types_allowed=True,
)

@field_validator("guardrail")
@classmethod
Expand Down Expand Up @@ -288,15 +291,16 @@ def ensure_guardrail_is_callable(self) -> Task:
if self.agent is None:
raise ValueError("Agent is required to use LLMGuardrail")

self._guardrail = cast(
GuardrailCallable,
LLMGuardrail(description=self.guardrail, llm=self.agent.llm),
self._guardrail = LLMGuardrail(
description=self.guardrail, llm=cast(BaseLLM, self.agent.llm)
)

return self

@model_validator(mode="after")
def ensure_guardrails_is_list_of_callables(self) -> Task:
from crewai.tasks.llm_guardrail import LLMGuardrail

guardrails = []
if self.guardrails is not None:
if isinstance(self.guardrails, (list, tuple)):
Expand All @@ -309,14 +313,11 @@ def ensure_guardrails_is_list_of_callables(self) -> Task:
raise ValueError(
"Agent is required to use non-programmatic guardrails"
)
from crewai.tasks.llm_guardrail import LLMGuardrail

guardrails.append(
cast(
GuardrailCallable,
LLMGuardrail(
description=guardrail, llm=self.agent.llm
),
LLMGuardrail(
description=guardrail,
llm=cast(BaseLLM, self.agent.llm),
)
)
else:
Expand All @@ -329,14 +330,11 @@ def ensure_guardrails_is_list_of_callables(self) -> Task:
raise ValueError(
"Agent is required to use non-programmatic guardrails"
)
from crewai.tasks.llm_guardrail import LLMGuardrail

guardrails.append(
cast(
GuardrailCallable,
LLMGuardrail(
description=self.guardrails, llm=self.agent.llm
),
LLMGuardrail(
description=self.guardrails,
llm=cast(BaseLLM, self.agent.llm),
)
)
else:
Expand Down Expand Up @@ -436,21 +434,9 @@ def check_output(self) -> Self:
)
return self

@model_validator(mode="after")
def handle_max_retries_deprecation(self) -> Self:
if self.max_retries is not None:
warnings.warn(
"The 'max_retries' parameter is deprecated and will be removed in CrewAI v1.0.0. "
"Please use 'guardrail_max_retries' instead.",
DeprecationWarning,
stacklevel=2,
)
self.guardrail_max_retries = self.max_retries
return self

def execute_sync(
self,
agent: BaseAgent | None = None,
agent: Agent | None = None,
context: str | None = None,
tools: list[BaseTool] | None = None,
) -> TaskOutput:
Expand Down Expand Up @@ -488,9 +474,9 @@ def execute_async(

def _execute_task_async(
self,
agent: BaseAgent | None,
agent: Agent | None,
context: str | None,
tools: list[Any] | None,
tools: list[BaseTool] | None,
future: Future[TaskOutput],
) -> None:
"""Execute the task asynchronously with context handling."""
Expand All @@ -499,9 +485,9 @@ def _execute_task_async(

def _execute_core(
self,
agent: BaseAgent | None,
agent: Agent | None,
context: str | None,
tools: list[Any] | None,
tools: list[BaseTool] | None,
) -> TaskOutput:
"""Run the core execution logic of the task."""
try:
Expand Down Expand Up @@ -611,8 +597,6 @@ def prompt(self) -> str:
if trigger_payload is not None:
description += f"\n\nTrigger Payload: {trigger_payload}"

tasks_slices = [description]

output = self.i18n.slice("expected_output").format(
expected_output=self.expected_output
)
Expand Down Expand Up @@ -715,7 +699,7 @@ def increment_delegations(self, agent_name: str | None) -> None:
self.processed_by_agents.add(agent_name)
self.delegations += 1

def copy( # type: ignore
def copy( # type: ignore[override]
self, agents: list[BaseAgent], task_mapping: dict[str, Task]
) -> Task:
"""Creates a deep copy of the Task while preserving its original class type.
Expand Down Expand Up @@ -859,7 +843,7 @@ def fingerprint(self) -> Fingerprint:
def _invoke_guardrail_function(
self,
task_output: TaskOutput,
agent: BaseAgent,
agent: Agent,
tools: list[BaseTool],
guardrail: GuardrailCallable | None,
guardrail_index: int | None = None,
Expand Down
13 changes: 7 additions & 6 deletions lib/crewai/src/crewai/tasks/llm_guardrail.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from pydantic import BaseModel, Field

from crewai.agent import Agent
from crewai.agent.core import Agent
from crewai.lite_agent_output import LiteAgentOutput
from crewai.llms.base_llm import BaseLLM
from crewai.tasks.task_output import TaskOutput
Expand Down Expand Up @@ -38,7 +38,9 @@ def __init__(

self.llm: BaseLLM = llm

def _validate_output(self, task_output: TaskOutput) -> LiteAgentOutput:
def _validate_output(
self, task_output: TaskOutput | LiteAgentOutput
) -> LiteAgentOutput:
agent = Agent(
role="Guardrail Agent",
goal="Validate the output of the task",
Expand All @@ -64,18 +66,17 @@ def _validate_output(self, task_output: TaskOutput) -> LiteAgentOutput:

return agent.kickoff(query, response_format=LLMGuardrailResult)

def __call__(self, task_output: TaskOutput) -> tuple[bool, Any]:
def __call__(self, task_output: TaskOutput | LiteAgentOutput) -> tuple[bool, Any]:
"""Validates the output of a task based on specified criteria.

Args:
task_output (TaskOutput): The output to be validated.
task_output: The output to be validated.

Returns:
Tuple[bool, Any]: A tuple containing:
A tuple containing:
- bool: True if validation passed, False otherwise
- Any: The validation result or error message
"""

try:
result = self._validate_output(task_output)
if not isinstance(result.pydantic, LLMGuardrailResult):
Expand Down
Loading