Skip to content

Commit ed41865

Browse files
authored
chore(dev): add inequality support to sqlstore where clause (#3272)
# What does this PR do? add the ability to use inequalities in the where clause of the sqlstore. this is infrastructure for files expiration. ## Test Plan unit tests
1 parent 30117de commit ed41865

File tree

2 files changed

+85
-3
lines changed

2 files changed

+85
-3
lines changed

llama_stack/providers/utils/sqlstore/sqlalchemy_sqlstore.py

Lines changed: 28 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
)
2424
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
2525
from sqlalchemy.ext.asyncio.engine import AsyncEngine
26+
from sqlalchemy.sql.elements import ColumnElement
2627

2728
from llama_stack.apis.common.responses import PaginatedResponse
2829
from llama_stack.log import get_logger
@@ -43,6 +44,30 @@
4344
}
4445

4546

47+
def _build_where_expr(column: ColumnElement, value: Any) -> ColumnElement:
48+
"""Return a SQLAlchemy expression for a where condition.
49+
50+
`value` may be a simple scalar (equality) or a mapping like {">": 123}.
51+
The returned expression is a SQLAlchemy ColumnElement usable in query.where(...).
52+
"""
53+
if isinstance(value, Mapping):
54+
if len(value) != 1:
55+
raise ValueError(f"Operator mapping must have a single operator, got: {value}")
56+
op, operand = next(iter(value.items()))
57+
if op == "==" or op == "=":
58+
return column == operand
59+
if op == ">":
60+
return column > operand
61+
if op == "<":
62+
return column < operand
63+
if op == ">=":
64+
return column >= operand
65+
if op == "<=":
66+
return column <= operand
67+
raise ValueError(f"Unsupported operator '{op}' in where mapping")
68+
return column == value
69+
70+
4671
class SqlAlchemySqlStoreImpl(SqlStore):
4772
def __init__(self, config: SqlAlchemySqlStoreConfig):
4873
self.config = config
@@ -111,7 +136,7 @@ async def fetch_all(
111136

112137
if where:
113138
for key, value in where.items():
114-
query = query.where(table_obj.c[key] == value)
139+
query = query.where(_build_where_expr(table_obj.c[key], value))
115140

116141
if where_sql:
117142
query = query.where(text(where_sql))
@@ -222,7 +247,7 @@ async def update(
222247
async with self.async_session() as session:
223248
stmt = self.metadata.tables[table].update()
224249
for key, value in where.items():
225-
stmt = stmt.where(self.metadata.tables[table].c[key] == value)
250+
stmt = stmt.where(_build_where_expr(self.metadata.tables[table].c[key], value))
226251
await session.execute(stmt, data)
227252
await session.commit()
228253

@@ -233,7 +258,7 @@ async def delete(self, table: str, where: Mapping[str, Any]) -> None:
233258
async with self.async_session() as session:
234259
stmt = self.metadata.tables[table].delete()
235260
for key, value in where.items():
236-
stmt = stmt.where(self.metadata.tables[table].c[key] == value)
261+
stmt = stmt.where(_build_where_expr(self.metadata.tables[table].c[key], value))
237262
await session.execute(stmt)
238263
await session.commit()
239264

tests/unit/utils/sqlstore/test_sqlstore.py

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -332,6 +332,63 @@ async def test_sqlstore_pagination_error_handling():
332332
)
333333

334334

335+
async def test_where_operator_gt_and_update_delete():
336+
with TemporaryDirectory() as tmp_dir:
337+
db_path = tmp_dir + "/test.db"
338+
store = SqlAlchemySqlStoreImpl(SqliteSqlStoreConfig(db_path=db_path))
339+
340+
await store.create_table(
341+
"items",
342+
{
343+
"id": ColumnType.INTEGER,
344+
"value": ColumnType.INTEGER,
345+
"name": ColumnType.STRING,
346+
},
347+
)
348+
349+
await store.insert("items", {"id": 1, "value": 10, "name": "one"})
350+
await store.insert("items", {"id": 2, "value": 20, "name": "two"})
351+
await store.insert("items", {"id": 3, "value": 30, "name": "three"})
352+
353+
result = await store.fetch_all("items", where={"value": {">": 15}})
354+
assert {r["id"] for r in result.data} == {2, 3}
355+
356+
row = await store.fetch_one("items", where={"value": {">=": 30}})
357+
assert row["id"] == 3
358+
359+
await store.update("items", {"name": "small"}, {"value": {"<": 25}})
360+
rows = (await store.fetch_all("items")).data
361+
names = {r["id"]: r["name"] for r in rows}
362+
assert names[1] == "small"
363+
assert names[2] == "small"
364+
assert names[3] == "three"
365+
366+
await store.delete("items", {"id": {"==": 2}})
367+
rows_after = (await store.fetch_all("items")).data
368+
assert {r["id"] for r in rows_after} == {1, 3}
369+
370+
371+
async def test_where_operator_edge_cases():
372+
with TemporaryDirectory() as tmp_dir:
373+
db_path = tmp_dir + "/test.db"
374+
store = SqlAlchemySqlStoreImpl(SqliteSqlStoreConfig(db_path=db_path))
375+
376+
await store.create_table(
377+
"events",
378+
{"id": ColumnType.STRING, "ts": ColumnType.INTEGER},
379+
)
380+
381+
base = 1024
382+
await store.insert("events", {"id": "a", "ts": base - 10})
383+
await store.insert("events", {"id": "b", "ts": base + 10})
384+
385+
row = await store.fetch_one("events", where={"id": "a"})
386+
assert row["id"] == "a"
387+
388+
with pytest.raises(ValueError, match="Unsupported operator"):
389+
await store.fetch_all("events", where={"ts": {"!=": base}})
390+
391+
335392
async def test_sqlstore_pagination_custom_key_column():
336393
"""Test pagination with custom primary key column (not 'id')."""
337394
with TemporaryDirectory() as tmp_dir:

0 commit comments

Comments
 (0)