@@ -88,41 +88,107 @@ async def _async_insert_and_get_row(
88
88
conn : AsyncConnection ,
89
89
table : sa .Table ,
90
90
values : dict [str , Any ],
91
- pk_col : sa .Column ,
91
+ pk_col : sa .Column | None = None ,
92
92
pk_value : Any | None = None ,
93
+ pk_cols : list [sa .Column ] | None = None ,
94
+ pk_values : list [Any ] | None = None ,
93
95
) -> sa .engine .Row :
94
- result = await conn .execute (table .insert ().values (** values ).returning (pk_col ))
96
+ # Validate parameters
97
+ single_pk_provided = pk_col is not None
98
+ composite_pk_provided = pk_cols is not None
99
+
100
+ if single_pk_provided == composite_pk_provided :
101
+ msg = "Must provide either pk_col or pk_cols, but not both"
102
+ raise ValueError (msg )
103
+
104
+ if composite_pk_provided :
105
+ if pk_values is not None and len (pk_cols ) != len (pk_values ):
106
+ msg = "pk_cols and pk_values must have the same length"
107
+ raise ValueError (msg )
108
+ returning_cols = pk_cols
109
+ else :
110
+ returning_cols = [pk_col ]
111
+
112
+ result = await conn .execute (
113
+ table .insert ().values (** values ).returning (* returning_cols )
114
+ )
95
115
row = result .one ()
96
116
97
- # Get the pk_value from the row if not provided
98
- if pk_value is None :
99
- pk_value = getattr (row , pk_col .name )
117
+ if composite_pk_provided :
118
+ # Handle composite primary keys
119
+ if pk_values is None :
120
+ pk_values = [getattr (row , col .name ) for col in pk_cols ]
121
+ else :
122
+ for col , expected_value in zip (pk_cols , pk_values , strict = True ):
123
+ assert getattr (row , col .name ) == expected_value
124
+
125
+ # Build WHERE clause for composite key
126
+ where_clause = sa .and_ (
127
+ * [col == val for col , val in zip (pk_cols , pk_values , strict = True )]
128
+ )
100
129
else :
101
- # NOTE: DO NO USE row[pk_col] since you will get a deprecation error (Background on SQLAlchemy 2.0 at: https://sqlalche.me/e/b8d9)
102
- assert getattr (row , pk_col .name ) == pk_value
130
+ # Handle single primary key (existing logic)
131
+ if pk_value is None :
132
+ pk_value = getattr (row , pk_col .name )
133
+ else :
134
+ assert getattr (row , pk_col .name ) == pk_value
135
+
136
+ where_clause = pk_col == pk_value
103
137
104
- result = await conn .execute (sa .select (table ).where (pk_col == pk_value ))
138
+ result = await conn .execute (sa .select (table ).where (where_clause ))
105
139
return result .one ()
106
140
107
141
108
142
def _sync_insert_and_get_row (
109
143
conn : sa .engine .Connection ,
110
144
table : sa .Table ,
111
145
values : dict [str , Any ],
112
- pk_col : sa .Column ,
146
+ pk_col : sa .Column | None = None ,
113
147
pk_value : Any | None = None ,
148
+ pk_cols : list [sa .Column ] | None = None ,
149
+ pk_values : list [Any ] | None = None ,
114
150
) -> sa .engine .Row :
115
- result = conn .execute (table .insert ().values (** values ).returning (pk_col ))
151
+ # Validate parameters
152
+ single_pk_provided = pk_col is not None
153
+ composite_pk_provided = pk_cols is not None
154
+
155
+ if single_pk_provided == composite_pk_provided :
156
+ msg = "Must provide either pk_col or pk_cols, but not both"
157
+ raise ValueError (msg )
158
+
159
+ if composite_pk_provided :
160
+ if pk_values is not None and len (pk_cols ) != len (pk_values ):
161
+ msg = "pk_cols and pk_values must have the same length"
162
+ raise ValueError (msg )
163
+ returning_cols = pk_cols
164
+ else :
165
+ returning_cols = [pk_col ]
166
+
167
+ result = conn .execute (table .insert ().values (** values ).returning (* returning_cols ))
116
168
row = result .one ()
117
169
118
- # Get the pk_value from the row if not provided
119
- if pk_value is None :
120
- pk_value = getattr (row , pk_col .name )
170
+ if composite_pk_provided :
171
+ # Handle composite primary keys
172
+ if pk_values is None :
173
+ pk_values = [getattr (row , col .name ) for col in pk_cols ]
174
+ else :
175
+ for col , expected_value in zip (pk_cols , pk_values , strict = True ):
176
+ assert getattr (row , col .name ) == expected_value
177
+
178
+ # Build WHERE clause for composite key
179
+ where_clause = sa .and_ (
180
+ * [col == val for col , val in zip (pk_cols , pk_values , strict = True )]
181
+ )
121
182
else :
122
- # NOTE: DO NO USE row[pk_col] since you will get a deprecation error (Background on SQLAlchemy 2.0 at: https://sqlalche.me/e/b8d9)
123
- assert getattr (row , pk_col .name ) == pk_value
183
+ # Handle single primary key (existing logic)
184
+ if pk_value is None :
185
+ pk_value = getattr (row , pk_col .name )
186
+ else :
187
+ assert getattr (row , pk_col .name ) == pk_value
188
+
189
+ where_clause = pk_col == pk_value
124
190
125
- result = conn .execute (sa .select (table ).where (pk_col == pk_value ))
191
+ result = conn .execute (sa .select (table ).where (where_clause ))
126
192
return result .one ()
127
193
128
194
@@ -132,27 +198,135 @@ async def insert_and_get_row_lifespan(
132
198
* ,
133
199
table : sa .Table ,
134
200
values : dict [str , Any ],
135
- pk_col : sa .Column ,
201
+ pk_col : sa .Column | None = None ,
136
202
pk_value : Any | None = None ,
203
+ pk_cols : list [sa .Column ] | None = None ,
204
+ pk_values : list [Any ] | None = None ,
137
205
) -> AsyncIterator [dict [str , Any ]]:
206
+ """
207
+ Context manager that inserts a row into a table and automatically deletes it on exit.
208
+
209
+ Args:
210
+ sqlalchemy_async_engine: Async SQLAlchemy engine
211
+ table: The table to insert into
212
+ values: Dictionary of column values to insert
213
+ pk_col: Primary key column for deletion (for single-column primary keys)
214
+ pk_value: Optional primary key value (if None, will be taken from inserted row)
215
+ pk_cols: List of primary key columns (for composite primary keys)
216
+ pk_values: Optional list of primary key values (if None, will be taken from inserted row)
217
+
218
+ Yields:
219
+ dict: The inserted row as a dictionary
220
+
221
+ Examples:
222
+ ## Single primary key usage:
223
+
224
+ @pytest.fixture
225
+ async def user_in_db(asyncpg_engine: AsyncEngine) -> AsyncIterator[dict]:
226
+ user_data = random_user(name="test_user", email="[email protected] ")
227
+ async with insert_and_get_row_lifespan(
228
+ asyncpg_engine,
229
+ table=users,
230
+ values=user_data,
231
+ pk_col=users.c.id,
232
+ ) as row:
233
+ yield row
234
+
235
+ ##Composite primary key usage:
236
+
237
+ @pytest.fixture
238
+ async def service_in_db(asyncpg_engine: AsyncEngine) -> AsyncIterator[dict]:
239
+ service_data = {"key": "simcore/services/comp/test", "version": "1.0.0", "name": "Test Service"}
240
+ async with insert_and_get_row_lifespan(
241
+ asyncpg_engine,
242
+ table=services,
243
+ values=service_data,
244
+ pk_cols=[services.c.key, services.c.version],
245
+ ) as row:
246
+ yield row
247
+
248
+ ##Multiple rows with single primary keys using AsyncExitStack:
249
+
250
+ @pytest.fixture
251
+ async def users_in_db(asyncpg_engine: AsyncEngine) -> AsyncIterator[list[dict]]:
252
+ users_data = [
253
+ random_user(name="user1", email="[email protected] "),
254
+ random_user(name="user2", email="[email protected] "),
255
+ ]
256
+
257
+ async with AsyncExitStack() as stack:
258
+ created_users = []
259
+ for user_data in users_data:
260
+ row = await stack.enter_async_context(
261
+ insert_and_get_row_lifespan(
262
+ asyncpg_engine,
263
+ table=users,
264
+ values=user_data,
265
+ pk_col=users.c.id,
266
+ )
267
+ )
268
+ created_users.append(row)
269
+
270
+ yield created_users
271
+
272
+ ## Multiple rows with composite primary keys using AsyncExitStack:
273
+
274
+ @pytest.fixture
275
+ async def services_in_db(asyncpg_engine: AsyncEngine) -> AsyncIterator[list[dict]]:
276
+ services_data = [
277
+ {"key": "simcore/services/comp/service1", "version": "1.0.0", "name": "Service 1"},
278
+ {"key": "simcore/services/comp/service2", "version": "2.0.0", "name": "Service 2"},
279
+ {"key": "simcore/services/comp/service1", "version": "2.0.0", "name": "Service 1 v2"},
280
+ ]
281
+
282
+ async with AsyncExitStack() as stack:
283
+ created_services = []
284
+ for service_data in services_data:
285
+ row = await stack.enter_async_context(
286
+ insert_and_get_row_lifespan(
287
+ asyncpg_engine,
288
+ table=services,
289
+ values=service_data,
290
+ pk_cols=[services.c.key, services.c.version],
291
+ )
292
+ )
293
+ created_services.append(row)
294
+
295
+ yield created_services
296
+ """
138
297
# SETUP: insert & get
139
298
async with sqlalchemy_async_engine .begin () as conn :
140
299
row = await _async_insert_and_get_row (
141
- conn , table = table , values = values , pk_col = pk_col , pk_value = pk_value
300
+ conn ,
301
+ table = table ,
302
+ values = values ,
303
+ pk_col = pk_col ,
304
+ pk_value = pk_value ,
305
+ pk_cols = pk_cols ,
306
+ pk_values = pk_values ,
142
307
)
143
- # If pk_value was None, get it from the row for deletion later
144
- if pk_value is None :
145
- pk_value = getattr (row , pk_col .name )
308
+
309
+ # Get pk values for deletion
310
+ if pk_cols is not None :
311
+ if pk_values is None :
312
+ pk_values = [getattr (row , col .name ) for col in pk_cols ]
313
+ where_clause = sa .and_ (
314
+ * [col == val for col , val in zip (pk_cols , pk_values , strict = True )]
315
+ )
316
+ else :
317
+ if pk_value is None :
318
+ pk_value = getattr (row , pk_col .name )
319
+ where_clause = pk_col == pk_value
146
320
147
321
assert row
148
322
149
323
# NOTE: DO NO USE dict(row) since you will get a deprecation error (Background on SQLAlchemy 2.0 at: https://sqlalche.me/e/b8d9)
150
324
# pylint: disable=protected-access
151
325
yield row ._asdict ()
152
326
153
- # TEAD-DOWN : delete row
327
+ # TEARDOWN : delete row
154
328
async with sqlalchemy_async_engine .begin () as conn :
155
- await conn .execute (table .delete ().where (pk_col == pk_value ))
329
+ await conn .execute (table .delete ().where (where_clause ))
156
330
157
331
158
332
@contextmanager
@@ -161,23 +335,43 @@ def sync_insert_and_get_row_lifespan(
161
335
* ,
162
336
table : sa .Table ,
163
337
values : dict [str , Any ],
164
- pk_col : sa .Column ,
338
+ pk_col : sa .Column | None = None ,
165
339
pk_value : Any | None = None ,
340
+ pk_cols : list [sa .Column ] | None = None ,
341
+ pk_values : list [Any ] | None = None ,
166
342
) -> Iterator [dict [str , Any ]]:
167
343
"""sync version of insert_and_get_row_lifespan.
168
344
169
345
TIP: more convenient for **module-scope fixtures** that setup the
170
346
database tables before the app starts since it does not require an `event_loop`
171
- fixture (which is funcition-scoped )
347
+ fixture (which is function-scoped)
348
+
349
+ Supports both single and composite primary keys using the same parameter patterns
350
+ as the async version.
172
351
"""
173
352
# SETUP: insert & get
174
353
with sqlalchemy_sync_engine .begin () as conn :
175
354
row = _sync_insert_and_get_row (
176
- conn , table = table , values = values , pk_col = pk_col , pk_value = pk_value
355
+ conn ,
356
+ table = table ,
357
+ values = values ,
358
+ pk_col = pk_col ,
359
+ pk_value = pk_value ,
360
+ pk_cols = pk_cols ,
361
+ pk_values = pk_values ,
177
362
)
178
- # If pk_value was None, get it from the row for deletion later
179
- if pk_value is None :
180
- pk_value = getattr (row , pk_col .name )
363
+
364
+ # Get pk values for deletion
365
+ if pk_cols is not None :
366
+ if pk_values is None :
367
+ pk_values = [getattr (row , col .name ) for col in pk_cols ]
368
+ where_clause = sa .and_ (
369
+ * [col == val for col , val in zip (pk_cols , pk_values , strict = True )]
370
+ )
371
+ else :
372
+ if pk_value is None :
373
+ pk_value = getattr (row , pk_col .name )
374
+ where_clause = pk_col == pk_value
181
375
182
376
assert row
183
377
@@ -187,4 +381,4 @@ def sync_insert_and_get_row_lifespan(
187
381
188
382
# TEARDOWN: delete row
189
383
with sqlalchemy_sync_engine .begin () as conn :
190
- conn .execute (table .delete ().where (pk_col == pk_value ))
384
+ conn .execute (table .delete ().where (where_clause ))
0 commit comments