Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 20 additions & 6 deletions libs/core/langchain_core/utils/function_calling.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from pydantic.v1 import BaseModel as BaseModelV1
from pydantic.v1 import Field as Field_v1
from pydantic.v1 import create_model as create_model_v1
from typing_extensions import TypedDict, is_typeddict
from typing_extensions import NotRequired, Required, TypedDict, is_typeddict

import langchain_core
from langchain_core._api import beta
Expand Down Expand Up @@ -237,8 +237,15 @@ def _convert_any_typed_dicts_to_pydantic(
)
fields: dict = {}
for arg, arg_type in annotations_.items():
if get_origin(arg_type) is Annotated: # type: ignore[comparison-overlap]
annotated_args = get_args(arg_type)
origin = get_origin(arg_type)
is_not_required = origin is NotRequired # type: ignore[comparison-overlap]
if origin in {NotRequired, Required}: # type: ignore[comparison-overlap]
inner_arg_type = get_args(arg_type)[0]
else:
inner_arg_type = arg_type

if get_origin(inner_arg_type) is Annotated:
annotated_args = get_args(inner_arg_type)
new_arg_type = _convert_any_typed_dicts_to_pydantic(
annotated_args[0], depth=depth + 1, visited=visited
)
Expand All @@ -256,20 +263,27 @@ def _convert_any_typed_dicts_to_pydantic(
raise ValueError(msg)
if arg_desc := arg_descriptions.get(arg):
field_kwargs["description"] = arg_desc
if is_not_required and "default" not in field_kwargs:
field_kwargs["default"] = None
fields[arg] = (new_arg_type, Field_v1(**field_kwargs))
else:
new_arg_type = _convert_any_typed_dicts_to_pydantic(
arg_type, depth=depth + 1, visited=visited
inner_arg_type, depth=depth + 1, visited=visited
)
field_kwargs = {"default": ...}
# NotRequired fields have None as default, required fields use ...
field_kwargs = {"default": None if is_not_required else ...}
if arg_desc := arg_descriptions.get(arg):
field_kwargs["description"] = arg_desc
fields[arg] = (new_arg_type, Field_v1(**field_kwargs))
model = create_model_v1(typed_dict.__name__, **fields)
model.__doc__ = description
visited[typed_dict] = model
return model
if (origin := get_origin(type_)) and (type_args := get_args(type_)):

if (origin := get_origin(type_)) and origin in {NotRequired, Required}: # type: ignore[assignment]
return type_

if (origin := get_origin(type_)) and (type_args := get_args(type_)): # type: ignore[assignment]
subscriptable_origin = _py_38_safe_origin(origin)
type_args = tuple(
_convert_any_typed_dicts_to_pydantic(arg, depth=depth + 1, visited=visited)
Expand Down
17 changes: 17 additions & 0 deletions libs/core/tests/unit_tests/utils/test_function_calling.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import pytest
from pydantic import BaseModel as BaseModelV2Maybe # pydantic: ignore
from pydantic import Field as FieldV2Maybe # pydantic: ignore
from typing_extensions import NotRequired
from typing_extensions import TypedDict as ExtensionsTypedDict

try:
Expand Down Expand Up @@ -1168,3 +1169,19 @@ class MyModel(BaseModel):
func = convert_to_openai_function(MyModel, strict=True)
actual = func["parameters"]["required"]
assert actual == expected


def test_convert_to_openai_function_typed_dict_with_not_required() -> None:
class MyTypedDict(TypingTypedDict):
"""A TypedDict with NotRequired field."""

required_field: str
optional_field: NotRequired[str]

result = convert_to_openai_function(MyTypedDict)

assert result["name"] == "MyTypedDict"
assert "required_field" in result["parameters"]["properties"]
assert "optional_field" in result["parameters"]["properties"]
assert "required_field" in result["parameters"]["required"]
assert "optional_field" not in result["parameters"]["required"]
Loading