diff --git a/sqlmodel/main.py b/sqlmodel/main.py index 7c916f79af..8e1007f2a6 100644 --- a/sqlmodel/main.py +++ b/sqlmodel/main.py @@ -476,6 +476,17 @@ def Relationship( return relationship_info +# Helper function to support Pydantic 2.12+ compatibility +def _find_field_info(cls: type, field_name: str) -> Optional[FieldInfo]: + for c in cls.__mro__: + annotated = get_annotations(c.__dict__).get(field_name) # type: ignore[arg-type] + if annotated: + for meta in getattr(annotated, "__metadata__", ()): + if isinstance(meta, FieldInfo): + return meta + return None + + @__dataclass_transform__(kw_only_default=True, field_descriptors=(Field, FieldInfo)) class SQLModelMetaclass(ModelMetaclass, DeclarativeMeta): __sqlmodel_relationships__: Dict[str, RelationshipInfo] @@ -562,7 +573,17 @@ def get_config(name: str) -> Any: # If it was passed by kwargs, ensure it's also set in config set_config_value(model=new_cls, parameter="table", value=config_table) for k, v in get_model_fields(new_cls).items(): - col = get_column_from_field(v) + if PYDANTIC_MINOR_VERSION >= (2, 12): + original_field = getattr(v, "_original_assignment", Undefined) + # Get the original sqlmodel FieldInfo, pydantic >=v2.12 changes the model + if isinstance(original_field, FieldInfo): + field = original_field + else: + field = _find_field_info(new_cls, field_name=k) or v # type: ignore[assignment] + field.annotation = v.annotation + col = get_column_from_field(field) + else: + col = get_column_from_field(v) setattr(new_cls, k, col) # Set a config flag to tell FastAPI that this should be read with a field # in orm_mode instead of preemptively converting it to a dict. diff --git a/tests/test_declaration_syntax.py b/tests/test_declaration_syntax.py new file mode 100644 index 0000000000..803cb9f0fd --- /dev/null +++ b/tests/test_declaration_syntax.py @@ -0,0 +1,26 @@ +from sqlmodel import Field, SQLModel +from typing_extensions import Annotated + + +def test_declaration_syntax_1(): + class Person1(SQLModel): + name: str = Field(primary_key=True) + + class Person1Final(Person1, table=True): + pass + + +def test_declaration_syntax_2(): + class Person2(SQLModel): + name: Annotated[str, Field(primary_key=True)] + + class Person2Final(Person2, table=True): + pass + + +def test_declaration_syntax_3(): + class Person3(SQLModel): + name: Annotated[str, ...] = Field(primary_key=True) + + class Person3Final(Person3, table=True): + pass diff --git a/tests/test_main.py b/tests/test_main.py index 60d5c40ebb..c0e936ee72 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -1,9 +1,14 @@ from typing import List, Optional import pytest +from sqlalchemy import inspect +from sqlalchemy.engine.reflection import Inspector from sqlalchemy.exc import IntegrityError from sqlalchemy.orm import RelationshipProperty from sqlmodel import Field, Relationship, Session, SQLModel, create_engine, select +from typing_extensions import Annotated + +from .conftest import needs_pydanticv2 def test_should_allow_duplicate_row_if_unique_constraint_is_not_passed(clear_sqlmodel): @@ -125,3 +130,73 @@ class Hero(SQLModel, table=True): # The next statement should not raise an AttributeError assert hero_rusty_man.team assert hero_rusty_man.team.name == "Preventers" + + +def test_composite_primary_key(clear_sqlmodel): + class UserPermission(SQLModel, table=True): + user_id: int = Field(primary_key=True) + resource_id: int = Field(primary_key=True) + permission: str + + engine = create_engine("sqlite://") + SQLModel.metadata.create_all(engine) + + insp: Inspector = inspect(engine) + pk_constraint = insp.get_pk_constraint(str(UserPermission.__tablename__)) + + assert len(pk_constraint["constrained_columns"]) == 2 + assert "user_id" in pk_constraint["constrained_columns"] + assert "resource_id" in pk_constraint["constrained_columns"] + + with Session(engine) as session: + perm1 = UserPermission(user_id=1, resource_id=1, permission="read") + perm2 = UserPermission(user_id=1, resource_id=2, permission="write") + session.add(perm1) + session.add(perm2) + session.commit() + + with pytest.raises(IntegrityError): + with Session(engine) as session: + perm3 = UserPermission(user_id=1, resource_id=1, permission="admin") + session.add(perm3) + session.commit() + + +@needs_pydanticv2 +def test_composite_primary_key_and_validator(clear_sqlmodel): + from pydantic import AfterValidator + + def validate_resource_id(value: int) -> int: + if value < 1: + raise ValueError("Resource ID must be positive") + return value + + class UserPermission(SQLModel, table=True): + user_id: int = Field(primary_key=True) + resource_id: Annotated[int, AfterValidator(validate_resource_id)] = Field( + primary_key=True + ) + permission: str + + engine = create_engine("sqlite://") + SQLModel.metadata.create_all(engine) + + insp: Inspector = inspect(engine) + pk_constraint = insp.get_pk_constraint(str(UserPermission.__tablename__)) + + assert len(pk_constraint["constrained_columns"]) == 2 + assert "user_id" in pk_constraint["constrained_columns"] + assert "resource_id" in pk_constraint["constrained_columns"] + + with Session(engine) as session: + perm1 = UserPermission(user_id=1, resource_id=1, permission="read") + perm2 = UserPermission(user_id=1, resource_id=2, permission="write") + session.add(perm1) + session.add(perm2) + session.commit() + + with pytest.raises(IntegrityError): + with Session(engine) as session: + perm3 = UserPermission(user_id=1, resource_id=1, permission="admin") + session.add(perm3) + session.commit()