Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 5 additions & 65 deletions ccflow/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,14 @@

import collections.abc
import copy
import inspect
import logging
import pathlib
import platform
import sys
import warnings
from types import GenericAlias, MappingProxyType
from typing import Any, Callable, ClassVar, Dict, Generic, List, Optional, Tuple, Type, TypeVar, Union, get_args, get_origin
from types import MappingProxyType
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, TypeVar

from omegaconf import DictConfig
from packaging import version
from pydantic import (
BaseModel as PydanticBaseModel,
ConfigDict,
Expand Down Expand Up @@ -88,66 +85,7 @@ def get_registry_dependencies(self, types: Optional[Tuple["ModelType"]] = None)
return deps


# Pydantic 2 has different handling of serialization.
# This requires some workarounds at the moment until the feature is added to easily get a mode that
# is compatible with Pydantic 1
# This is done by adjusting annotations via a MetaClass for any annotation that includes a BaseModel,
# such that the new annotation contains SerializeAsAny
# https://docs.pydantic.dev/latest/concepts/serialization/#serializing-with-duck-typing
# https://github.com/pydantic/pydantic/issues/6423
# https://github.com/pydantic/pydantic-core/pull/740
# See https://github.com/pydantic/pydantic/issues/6381 for inspiration on implementation
# NOTE: For this logic to be removed, require https://github.com/pydantic/pydantic-core/pull/1478
from pydantic._internal._model_construction import ModelMetaclass # noqa: E402

_IS_PY39 = version.parse(platform.python_version()) < version.parse("3.10")


def _adjust_annotations(annotation):
    """Return ``annotation`` with pydantic model types wrapped in ``SerializeAsAny``.

    The annotation is walked recursively:

    * a ``types.GenericAlias`` instance, or a class deriving from pydantic's
      ``BaseModel``, is wrapped as ``SerializeAsAny[annotation]``;
    * a parametrised annotation whose origin is ``type`` or a ``Generic``
      subclass is returned unchanged;
    * ``ClassVar[X]`` is rebuilt around the adjusted ``X``;
    * any other parametrised annotation is rebuilt from its origin with every
      argument adjusted;
    * everything else is returned unchanged.

    Raises:
        TypeError: if the origin cannot be re-subscripted with the adjusted
            arguments. The underlying error is attached as ``__cause__``.
    """
    origin = get_origin(annotation)
    args = get_args(annotation)
    if not _IS_PY39:
        # ``types.UnionType`` is only imported on interpreters newer than 3.9.
        from types import UnionType

        if origin is UnionType:
            # Normalise the ``X | Y`` form so it is rebuilt through ``typing.Union`` below.
            origin = Union

    if isinstance(annotation, GenericAlias) or (inspect.isclass(annotation) and issubclass(annotation, PydanticBaseModel)):
        return SerializeAsAny[annotation]
    elif origin and args:
        # Filter out typing.Type and generic types
        if origin is type or (inspect.isclass(origin) and issubclass(origin, Generic)):
            return annotation
        elif origin is ClassVar:  # ClassVar doesn't accept a tuple of length 1 in py39
            return ClassVar[_adjust_annotations(args[0])]
        else:
            try:
                return origin[tuple(_adjust_annotations(arg) for arg in args)]
            except TypeError as err:
                # Chain explicitly so the traceback shows the original failure as the cause.
                raise TypeError(f"Could not adjust annotations for {origin}") from err
    else:
        return annotation


class _SerializeAsAnyMeta(ModelMetaclass):
    """Metaclass that passes field annotations through ``_adjust_annotations`` before the class is built."""

    def __new__(self, name: str, bases: Tuple[type], namespaces: Dict[str, Any], **kwargs):
        """Adjust the ``__annotations__`` entry of ``namespaces`` and delegate to ``ModelMetaclass``.

        Names starting with a double underscore are left untouched.
        """
        annotations: dict = namespaces.get("__annotations__", {})

        # For each base that has pydantic's BaseModel in its MRO, merge in the
        # annotations declared on that BaseModel class itself.
        for parent in bases:
            if any(klass is PydanticBaseModel for klass in parent.__mro__):
                annotations.update(PydanticBaseModel.__annotations__)

        # Keys are snapshotted first; only existing entries are reassigned.
        for field_name in list(annotations):
            if field_name.startswith("__"):
                continue
            annotations[field_name] = _adjust_annotations(annotations[field_name])

        namespaces["__annotations__"] = annotations

        return super().__new__(self, name, bases, namespaces, **kwargs)


class BaseModel(PydanticBaseModel, _RegistryMixin, metaclass=_SerializeAsAnyMeta):
class BaseModel(PydanticBaseModel, _RegistryMixin):
"""BaseModel is a base class for all pydantic models within the cubist flow framework.

This gives us a way to add functionality to the framework, including
Expand Down Expand Up @@ -179,6 +117,8 @@ def type_(self) -> PyObjectPath:
# where the default behavior is just to drop the mis-named value. This prevents that
extra="forbid",
ser_json_timedelta="float",
# Polymorphic serialization is the behavior of allowing a subclass of a model (or Pydantic dataclass) to override serialization so that the subclass's serialization is used, rather than the original model type's serialization. This will expose all the data defined on the subclass in the serialized payload.
polymorphic_serialization=True,
)

def __str__(self):
Expand Down
11 changes: 10 additions & 1 deletion ccflow/callable.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,16 @@
from inspect import Signature, isclass, signature
from typing import Any, ClassVar, Dict, Generic, List, Optional, Tuple, Type, TypeVar, Union, get_args, get_origin

from pydantic import BaseModel as PydanticBaseModel, ConfigDict, Field, InstanceOf, PrivateAttr, TypeAdapter, field_validator, model_validator
from pydantic import (
BaseModel as PydanticBaseModel,
ConfigDict,
Field,
InstanceOf,
PrivateAttr,
TypeAdapter,
field_validator,
model_validator,
)
from typing_extensions import override

from .base import (
Expand Down
43 changes: 1 addition & 42 deletions ccflow/tests/test_base_serialize.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,8 @@
import pickle
import platform
import unittest
from typing import Annotated, ClassVar, Dict, List, Optional, Type, Union
from typing import Annotated, Optional

import numpy as np
from packaging import version
from pydantic import BaseModel as PydanticBaseModel, ConfigDict, Field, ValidationError

from ccflow import BaseModel, NDArray
Expand Down Expand Up @@ -213,45 +211,6 @@ class C(PydanticBaseModel):
# C implements the normal pydantic BaseModel, which should allow extra fields.
_ = C(extra_field1=1)

def test_serialize_as_any(self):
    """Check that annotations on a ``BaseModel`` subclass are rewritten to use ``SerializeAsAny``."""
    # https://docs.pydantic.dev/latest/concepts/serialization/#serializing-with-duck-typing
    # https://github.com/pydantic/pydantic/issues/6423
    # This test could be removed once there is a different solution to the issue above
    from pydantic import SerializeAsAny
    from pydantic.types import constr

    # The ``A | int`` expression is only evaluated on Python 3.10 and newer.
    supports_pipe_union = version.parse(platform.python_version()) >= version.parse("3.10")
    pipe_union = (A | int) if supports_pipe_union else Union[A, int]

    class MyNestedModel(BaseModel):
        a1: A
        a2: Optional[Union[A, int]]
        a3: Dict[str, Optional[List[A]]]
        a4: ClassVar[A]
        a5: Type[A]
        a6: constr(min_length=1)
        a7: pipe_union

    expected = {
        "a1": SerializeAsAny[A],
        "a2": Optional[Union[SerializeAsAny[A], int]],
        "a3": dict[str, Optional[list[SerializeAsAny[A]]]],
        "a4": ClassVar[SerializeAsAny[A]],
        "a5": Type[A],
        "a6": constr(min_length=1),  # Uses Annotation
        "a7": Union[SerializeAsAny[A], int],
    }
    actual = MyNestedModel.__annotations__
    # Compare string forms field by field, in declaration order.
    for field_name in ("a1", "a2", "a3", "a4", "a5", "a6", "a7"):
        self.assertEqual(str(actual[field_name]), str(expected[field_name]))

def test_pickle_consistency(self):
model = MultiAttributeModel(z=1, y="test", x=3.14, w=True)
serialized = pickle.dumps(model)
Expand Down
79 changes: 79 additions & 0 deletions ccflow/tests/test_evaluation_context_serialization.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
import json
from datetime import date

from ccflow import DateContext
from ccflow.callable import ModelEvaluationContext
from ccflow.evaluators import GraphEvaluator, LoggingEvaluator, MultiEvaluator
from ccflow.tests.evaluators.util import NodeModel


def _make_nested_mec(model):
    """Return the evaluation context of ``model`` for 2022-01-01, asserting that it is nested."""
    fixed_day = DateContext(date=date(2022, 1, 1))
    outer = model.__call__.get_evaluation_context(model, fixed_day)
    assert isinstance(outer, ModelEvaluationContext)
    # Nesting check: the context held by the outer object must itself be a ModelEvaluationContext.
    assert isinstance(outer.context, ModelEvaluationContext)
    return outer


def test_mec_model_dump_basic():
    """Dump a nested evaluation context of a single node as dict and JSON."""
    node = NodeModel()
    mec = _make_nested_mec(node)

    dumped = mec.model_dump()
    assert isinstance(dumped, dict)
    assert all(key in dumped for key in ("fn", "model", "context", "options"))

    # The JSON dump must parse and agree with the dict dump on "fn".
    as_json = mec.model_dump_json()
    reloaded = json.loads(as_json)
    assert reloaded["fn"] == dumped["fn"]

    # Also verify mode-specific dumps
    python_dump = mec.model_dump(mode="python")
    assert isinstance(python_dump, dict)
    json_dump = mec.model_dump(mode="json")
    assert isinstance(json_dump, dict)
    json.dumps(json_dump)


def test_mec_model_dump_diamond_graph():
    """Dump the evaluation context of a root node whose two dependencies share one dependency."""
    shared = NodeModel()
    left = NodeModel(deps_model=[shared])
    right = NodeModel(deps_model=[shared])
    top = NodeModel(deps_model=[left, right])

    mec = _make_nested_mec(top)

    dumped = mec.model_dump()
    assert isinstance(dumped, dict)
    assert {"fn", "model", "context", "options"}.issubset(dumped.keys())

    # The JSON dump only has to be parseable here.
    json.loads(mec.model_dump_json())

    # verify mode dumps
    python_dump = mec.model_dump(mode="python")
    assert isinstance(python_dump, dict)
    json_dump = mec.model_dump(mode="json")
    assert isinstance(json_dump, dict)
    json.dumps(json_dump)


def test_mec_model_dump_with_multi_evaluator():
    """Dump an evaluation context whose model is a ``MultiEvaluator``."""
    node = NodeModel()
    _ = LoggingEvaluator()  # ensure import/validation
    multi = MultiEvaluator(evaluators=[LoggingEvaluator(), GraphEvaluator()])

    # Simulate how Flow builds evaluation context with a custom evaluator
    fixed_day = DateContext(date=date(2022, 1, 1))
    inner = node.__call__.get_evaluation_context(node, fixed_day)
    mec = ModelEvaluationContext(model=multi, context=inner)

    dumped = mec.model_dump()
    assert isinstance(dumped, dict)
    assert all(key in dumped for key in ("fn", "model", "context"))
    json.loads(mec.model_dump_json())

    # verify mode dumps
    python_dump = mec.model_dump(mode="python")
    assert isinstance(python_dump, dict)
    json_dump = mec.model_dump(mode="json")
    assert isinstance(json_dump, dict)
    json.dumps(json_dump)
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ dependencies = [
"orjson",
"pandas",
"pyarrow",
"pydantic>=2.6,<3",
"pydantic>=2.13,<3",
"smart_open",
"tenacity",
]
Expand Down
Loading