Skip to content

feat: use Pydantic V2 models when possible instead of defaulting to Pydantic V1 models #222

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 40 commits into
base: development
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
40 commits
Select commit Hold shift + click to select a range
e5e3f25
feat: added test for pydantic compatibility
adubovik Feb 25, 2025
1cf9903
feat: tentative migration to Pydantic V2
adubovik Feb 25, 2025
bce447f
fix: pytest: treat warnings as errors
adubovik Feb 25, 2025
cf15078
feat: support pydantic v2
adubovik Feb 26, 2025
e4ab8f7
fix: report deprecations are errors
adubovik Feb 26, 2025
ff2e8a3
fix: remove dead code
adubovik Feb 26, 2025
d72d902
fix: moving code around
adubovik Feb 26, 2025
a29ccbd
fix: fixed old Python versions
adubovik Feb 26, 2025
e1174ac
fix: moving code
adubovik Feb 26, 2025
40d4adb
fix: add version agnostic model_validator helper
adubovik Feb 26, 2025
e26e0e1
fix: fix validation tests
adubovik Feb 26, 2025
2865029
Merge branch 'development' into feat/compatibility-with-pydantic-v2
adubovik Feb 27, 2025
5dc575f
fix: minor fixes
adubovik Feb 27, 2025
911b9c8
fix: migrate tic_tac_toe example to Pydantic v2
adubovik Feb 27, 2025
5601798
feat: fix buttons for Pydantic V2
adubovik Mar 4, 2025
f789dd1
fix: fix Pydantic V1 tests
adubovik Mar 4, 2025
c008370
fix: fix Pydantic V2 tests
adubovik Mar 4, 2025
10575a9
fix: minor fix
adubovik Mar 4, 2025
f0891a2
fix: fix linter; first green version
adubovik Mar 4, 2025
33a5640
fix: eliminate "Using extra keyword arguments on `Field` is deprecate…
adubovik Mar 4, 2025
d3eda8c
fix: remove Config class deprecation in Pydantic V2
adubovik Mar 4, 2025
26c3861
chore: move ConfigWrapper to a helper module
adubovik Mar 4, 2025
f8f836c
chore: simplify code
adubovik Mar 4, 2025
2823569
chore: refactor ModelConfigWrapper
adubovik Mar 4, 2025
2729d97
chore: simplify form decorator
adubovik Mar 4, 2025
d655213
fix: simplified FormMetaclass
adubovik Mar 5, 2025
f0aa0ef
fix: minor renamings
adubovik Mar 5, 2025
1e85d3b
feat: simplify chat completion validation tests
adubovik Mar 5, 2025
73d77ac
feat: fix inheritance test
adubovik Mar 5, 2025
f785cf3
chore: moving files around
adubovik Mar 5, 2025
69db930
chore: moving code around
adubovik Mar 5, 2025
0334168
fix: minor fix
adubovik Mar 5, 2025
8184812
feat: add PYDANTIC_V2 env var for backward compatibility
adubovik Mar 5, 2025
03d6106
fix: introduce _pydantic.py for pydantic basic reexports
adubovik Mar 5, 2025
1f200f0
chore: add module comments
adubovik Mar 5, 2025
7749704
chore: moving files around
adubovik Mar 6, 2025
9246f3d
chore: add test for PYDANTIC_V2=True mode
adubovik Mar 10, 2025
7b66e3b
Merge branch 'development' into feat/compatibility-with-pydantic-v2
adubovik Mar 24, 2025
42ff0ee
fix: first cut extra field validation for Pydantic V2
adubovik Mar 24, 2025
69637d2
fix: fix validation tests
adubovik Mar 24, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .flake8
Original file line number Diff line number Diff line change
Expand Up @@ -12,4 +12,5 @@ exclude =
.nox,
.pytest_cache
__pycache__,
_pydantic.py,
__init__.py
1 change: 1 addition & 0 deletions .python-version
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
3.8
4 changes: 4 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,9 @@ format: install
test: install
$(POETRY) run -- nox -s test $(if $(PYTHON),--python=$(PYTHON),) -- $(ARGS)

test_fast: install
$(POETRY) run -- nox -s test $(if $(PYTHON),--python=$(PYTHON),) -- -m 'not slow' $(ARGS)

benchmark: install
python -m benchmark.benchmark_merge_chunks

Expand All @@ -47,5 +50,6 @@ help:
@echo 'lint - run linters'
@echo '-- TESTS --'
@echo 'test - run unit tests'
@echo 'test_fast - run unit tests without slow tests'
@echo 'test PYTHON=<python_version> - run unit tests with the specific python version'
@echo 'benchmark - run benchmarks'
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,7 @@ install PyCharm>=2023.2 with [built-in Black support](https://blog.jetbrains.com
|Variable|Default|Description|
|---|---|---|
|DIAL_SDK_LOG|WARNING|DIAL SDK log level|
|PYDANTIC_V2|False|When `True` and Pydantic V2 is installed, DIAL SDK classes for requests/responses will be based on Pydantic V2 `BaseModel`. Otherwise, they will be based on Pydantic V1 `BaseModel`.|

## Lint

Expand Down
2 changes: 1 addition & 1 deletion aidial_sdk/_errors.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
from fastapi import HTTPException, Request
from fastapi.responses import JSONResponse

from aidial_sdk._pydantic import ValidationError
from aidial_sdk.exceptions import HTTPException as DIALException
from aidial_sdk.exceptions import InvalidRequestError
from aidial_sdk.pydantic_v1 import ValidationError


def pydantic_validation_exception_handler(
Expand Down
50 changes: 50 additions & 0 deletions aidial_sdk/_pydantic/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
"""
The entry point for all Pydantic definition that unifies v1 and v2 APIs for SDK internals.

It's private, since it expected that the SDK client
will either import `aidial_sdk.pydantic_v1` or `pydantic`.

This is the only place where `pydantic` imports
are allowed in the DIAL SDK package.
"""

from typing import TYPE_CHECKING

import pydantic

from aidial_sdk.utils.env import env_bool

INSTALLED_PYDANTIC_V2 = pydantic.VERSION.startswith("2.")
USE_PYDANTIC_V2 = env_bool("PYDANTIC_V2", False)
PYDANTIC_V2 = INSTALLED_PYDANTIC_V2 and USE_PYDANTIC_V2

if TYPE_CHECKING:
from pydantic import BaseModel
from pydantic import ConfigDict as ConfigDict
from pydantic import Field, ValidationError
from pydantic import field_validator as validator
from pydantic import model_validator
from pydantic._internal._model_construction import ModelMetaclass
from pydantic.fields import FieldInfo
else:

if PYDANTIC_V2:
from pydantic import BaseModel, ConfigDict, Field, ValidationError
from pydantic import field_validator as validator
from pydantic import model_validator
from pydantic._internal._model_construction import ModelMetaclass
from pydantic.fields import FieldInfo
else:
from pydantic.v1 import BaseModel, Field, validator

try:
from pydantic.v1.main import ModelMetaclass
except ImportError:
from pydantic.main import ModelMetaclass
from pydantic.v1 import ValidationError, root_validator
from pydantic.v1.fields import FieldInfo

def _fail(*args, **kwargs):
raise ImportError("ConfigDict is only supported in Pydantic v2")

ConfigDict = _fail
97 changes: 97 additions & 0 deletions aidial_sdk/_pydantic/_compat.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
"""
The module provide the basic Pydantic BaseModel extended
with `model_dump` method mimicking the one from Pydantic V2.

All SDK models inherit from this class.

It proves to be useful since
1. `model_dump` method is used extensively in the SDK,
2. the SDK client may call this method on SDK models even if the client uses Pydantic V1.
"""

from datetime import date, datetime
from typing import Any, Dict, Iterable, Mapping, Optional, Set, Union, cast

from typing_extensions import Literal

import aidial_sdk._pydantic as pydantic
from aidial_sdk._pydantic import PYDANTIC_V2

_IncEx = Union[Set[int], Set[str], Dict[int, Any], Dict[str, Any], None]


class BaseModel(pydantic.BaseModel):
if not PYDANTIC_V2:
# we define aliases for some of the new pydantic v2 methods so
# that we can just document these methods without having to specify
# a specific pydantic version as some users may not know which
# pydantic version they are currently using

def model_dump(
self,
*,
mode: Union[Literal["json", "python"], str] = "python",
include: _IncEx = None,
exclude: _IncEx = None,
by_alias: bool = False,
exclude_unset: bool = False,
exclude_defaults: bool = False,
exclude_none: bool = False,
round_trip: bool = False,
warnings: Union[bool, Literal["none", "warn", "error"]] = True,
context: Optional[Dict[str, Any]] = None,
serialize_as_any: bool = False,
) -> Dict[str, Any]:
if mode not in {"json", "python"}:
raise ValueError("mode must be either 'json' or 'python'")
if round_trip is not False:
raise ValueError("round_trip is only supported in Pydantic v2")
if warnings is not True:
raise ValueError("warnings is only supported in Pydantic v2")
if context is not None:
raise ValueError("context is only supported in Pydantic v2")
if serialize_as_any is not False:
raise ValueError(
"serialize_as_any is only supported in Pydantic v2"
)
dumped = super().dict( # pyright: ignore[reportDeprecated]
include=include,
exclude=exclude,
by_alias=by_alias,
exclude_unset=exclude_unset,
exclude_defaults=exclude_defaults,
exclude_none=exclude_none,
)

return (
cast(Dict[str, Any], _json_safe(dumped))
if mode == "json"
else dumped
)


def _json_safe(data: object) -> object:
"""Translates a mapping / sequence recursively in the same fashion
as `pydantic` v2's `model_dump(mode="json")`.
"""
if isinstance(data, Mapping):
return {
_json_safe(key): _json_safe(value) for key, value in data.items()
}

if isinstance(data, Iterable) and not isinstance(
data, (str, bytes, bytearray)
):
return [_json_safe(item) for item in data]

if isinstance(data, (datetime, date)):
return data.isoformat()

return data


def model_validator(*, mode: Literal["before", "after"]) -> Any:
if PYDANTIC_V2:
return pydantic.model_validator(mode=mode)
else:
return pydantic.root_validator(pre=(mode == "before")) # type: ignore
144 changes: 144 additions & 0 deletions aidial_sdk/_pydantic/_model_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,144 @@
"""
Helper classes that unify model configuration between Pydantic v1 and v2.
"""

from __future__ import annotations

from abc import ABC, abstractmethod
from typing import Any, Callable, Dict, Optional, Type, TypeVar

from pydantic import BaseModel

from aidial_sdk._pydantic import PYDANTIC_V2

_Model = TypeVar("_Model", bound=BaseModel)


class ModelConfigWrapper:
_model_config: ModelConfigBase

def __init__(self, model_config: ModelConfigBase):
self._model_config = model_config

def __getitem__(self, field: str) -> Any:
return self._model_config.get_field(field, None)

def __setitem__(self, field: str, value: Any) -> None:
self._model_config.set_field(field, value)

def post_process_schema(
self, on_schema: Callable[[Dict[str, Any]], None]
) -> None:
attr_name = self._model_config.schema_extra_field
old_schema_extra = self[attr_name]

def _schema_extra(
schema: Dict[str, Any], model: Type[BaseModel]
) -> None:
if old_schema_extra:
old_schema_extra(schema, model)
on_schema(schema)

self[attr_name] = _schema_extra

@classmethod
def create(
cls, base_cls: Optional[Type[_Model]], namespace: Dict[str, Any]
) -> ModelConfigWrapper:
if PYDANTIC_V2:
return cls(_ConfigV2.create(base_cls, namespace))
else:
return cls(_ConfigV1.create(base_cls, namespace))


class ModelConfigBase(ABC):
@abstractmethod
def set_field(self, field: str, value: Any) -> None:
pass

@abstractmethod
def get_field(self, field: str, default: Any) -> Any:
pass

@property
@abstractmethod
def schema_extra_field(self) -> str:
pass

@classmethod
@abstractmethod
def create(
cls, base_cls: Optional[Type[_Model]], namespace: Dict[str, Any]
) -> ModelConfigBase:
pass


class _ConfigV1(ModelConfigBase):
config_cls: type

def __init__(self, config_cls: type):
self.config_cls = config_cls

def set_field(self, field: str, value: Any) -> None:
setattr(self.config_cls, field, value)

def get_field(self, field: str, default: Any) -> Any:
return getattr(self.config_cls, field, default)

@property
def schema_extra_field(self) -> str:
return "schema_extra"

@classmethod
def create(
cls, base_cls: Optional[Type[_Model]], namespace: Dict[str, Any]
) -> ModelConfigBase:
if (config_cls := namespace.get("Config")) is None:
conf_base_cls = (
None if base_cls is None else getattr(base_cls, "Config", None)
)

config_cls = type("Config", (conf_base_cls or object,), {})

if module := namespace.get("__module__"):
config_cls.__module__ = module
if qualname := namespace.get("__qualname__"):
config_cls.__qualname__ = f"{qualname}.{config_cls.__name__}"

namespace["Config"] = config_cls

return cls(config_cls)


class _ConfigV2(ModelConfigBase):
model_config: dict

def __init__(self, model_config: dict):
self.model_config = model_config

def set_field(self, field: str, value: Any) -> None:
self.model_config[field] = value

def get_field(self, field: str, default: Any) -> Any:
return self.model_config.get(field, default)

@property
def schema_extra_field(self) -> str:
return "json_schema_extra"

@classmethod
def create(
cls, base_cls: Optional[Type[_Model]], namespace: Dict[str, Any]
) -> ModelConfigBase:
base_model_config = (
{} if base_cls is None else getattr(base_cls, "model_config", {})
)

curr_model_config = namespace.get("model_config") or {}

model_config = namespace["model_config"] = {
**base_model_config,
**curr_model_config,
}

return cls(model_config)
9 changes: 5 additions & 4 deletions aidial_sdk/application.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
fastapi_exception_handler,
pydantic_validation_exception_handler,
)
from aidial_sdk._pydantic import ValidationError
from aidial_sdk._pydantic._compat import BaseModel
from aidial_sdk.chat_completion.base import ChatCompletion
from aidial_sdk.chat_completion.request import Request as ChatCompletionRequest
from aidial_sdk.chat_completion.response import (
Expand All @@ -26,7 +28,6 @@
from aidial_sdk.embeddings.request import Request as EmbeddingsRequest
from aidial_sdk.exceptions import HTTPException as DIALException
from aidial_sdk.header_propagator import HeaderPropagator
from aidial_sdk.pydantic_v1 import BaseModel, ValidationError
from aidial_sdk.telemetry.types import TelemetryConfig
from aidial_sdk.utils._reflection import get_method_implementation
from aidial_sdk.utils.log_config import LogConfig
Expand All @@ -38,7 +39,7 @@
to_streaming_response,
)

logging.config.dictConfig(LogConfig().dict())
logging.config.dictConfig(LogConfig().model_dump())

RequestType = TypeVar("RequestType", bound=FromRequestMixin)

Expand Down Expand Up @@ -207,7 +208,7 @@ async def _handler(original_request: Request) -> Response:
if isinstance(response, dict):
response_json = response
elif isinstance(response, BaseModel):
response_json = response.dict()
response_json = response.model_dump()
else:
raise ValueError(
f"Unexpected response type from {endpoint}: {type(response)}"
Expand Down Expand Up @@ -291,7 +292,7 @@ async def _handler(original_request: Request):
EmbeddingsRequest, original_request, deployment_id
)
response = await impl.embeddings(request)
response_json = response.dict()
response_json = response.model_dump()
return JSONResponse(content=response_json)

return _handler
Expand Down
Loading