Skip to content

Commit 18a9aaa

Browse files
committed
[sqlalchemy] allow create/update with object for one/many 2 many
For one-2-many and many-2-many relationships, allow the create and update routes to accept a partial object in the foreign key attribute. For example: client.post("/heros", json={ "name": Bob, "team": {"name": "Avengers"} } Assuming there is already a team called Avengers, Bob will be created, the Team with name "Avengers" will be looked up and used to populate Bob's team_id foreign key attribute. The only setup required is for the input model for the foreign object to specify the Table class that can be used to lookup the object. For example: class Team(Base): """Team DTO.""" __tablename__ = "teams" id = Column(Integer, primary_key=True, index=True) name = Column(String, unique=True) class TeamUpdate(Model): name: str class Meta: orm_model = Team
1 parent 66fd32f commit 18a9aaa

File tree

2 files changed

+362
-6
lines changed

2 files changed

+362
-6
lines changed

fastapi_crudrouter/core/sqlalchemy.py

Lines changed: 48 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from typing import Any, Callable, List, Type, Generator, Optional, Union
2+
from collections.abc import Sequence
23

34
from fastapi import Depends, HTTPException
45

@@ -9,10 +10,12 @@
910
from sqlalchemy.orm import Session
1011
from sqlalchemy.ext.declarative import DeclarativeMeta as Model
1112
from sqlalchemy.exc import IntegrityError
13+
from sqlalchemy import column
1214
except ImportError:
1315
Model = None
1416
Session = None
1517
IntegrityError = None
18+
column = None
1619
sqlalchemy_installed = False
1720
else:
1821
sqlalchemy_installed = True
@@ -39,7 +42,7 @@ def __init__(
3942
update_route: Union[bool, DEPENDENCIES] = True,
4043
delete_one_route: Union[bool, DEPENDENCIES] = True,
4144
delete_all_route: Union[bool, DEPENDENCIES] = True,
42-
**kwargs: Any
45+
**kwargs: Any,
4346
) -> None:
4447
assert (
4548
sqlalchemy_installed
@@ -63,7 +66,7 @@ def __init__(
6366
update_route=update_route,
6467
delete_one_route=delete_one_route,
6568
delete_all_route=delete_all_route,
66-
**kwargs
69+
**kwargs,
6770
)
6871

6972
def _get_all(self, *args: Any, **kwargs: Any) -> CALLABLE_LIST:
@@ -97,13 +100,49 @@ def route(
97100

98101
return route
99102

103+
def _get_orm_object(self, db: Session, orm_model: Model, model: Model) -> Any:
104+
query = db.query(orm_model)
105+
filter_items = 0
106+
for key, val in model.dict().items():
107+
if val:
108+
filter_items += 1
109+
query = query.filter(column(key) == val)
110+
if filter_items == 0:
111+
raise Exception("No attributes for filter found")
112+
return query.one()
113+
114+
def _get_orm_object_or_value(self, db: Session, val: Any) -> Any:
115+
"""Return an inflated database object or a plain value.
116+
117+
If a `val` is a SqlModel type and has defined a Meta.orm model
118+
attribute, lookup the object from the `db` and return it.
119+
Otherwise, just return the `val`. If `val` is a sequence of
120+
objects, return the sequence of objects from the db.
121+
"""
122+
# we want to iterate through sequences but not strings
123+
if not val or isinstance(val, str):
124+
return val
125+
126+
if isinstance(val, Sequence):
127+
return [self._get_orm_object_or_value(db, v) for v in val]
128+
else:
129+
if meta_class := getattr(val, "Meta", None):
130+
if orm_model := getattr(meta_class, "orm_model", None):
131+
return self._get_orm_object(db, orm_model, val)
132+
return val
133+
100134
def _create(self, *args: Any, **kwargs: Any) -> CALLABLE:
101135
def route(
102136
model: self.create_schema, # type: ignore
103137
db: Session = Depends(self.db_func),
104138
) -> Model:
105139
try:
106-
db_model: Model = self.db_model(**model.dict())
140+
db_model: Model = self.db_model()
141+
142+
for key, val in model:
143+
if val:
144+
setattr(db_model, key, self._get_orm_object_or_value(db, val))
145+
107146
db.add(db_model)
108147
db.commit()
109148
db.refresh(db_model)
@@ -123,9 +162,12 @@ def route(
123162
try:
124163
db_model: Model = self._get_one()(item_id, db)
125164

126-
for key, value in model.dict(exclude={self._pk}).items():
127-
if hasattr(db_model, key):
128-
setattr(db_model, key, value)
165+
for key, val in model:
166+
if key != self._pk:
167+
if hasattr(db_model, key):
168+
setattr(
169+
db_model, key, self._get_orm_object_or_value(db, val)
170+
)
129171

130172
db.commit()
131173
db.refresh(db_model)

0 commit comments

Comments
 (0)