diff --git a/client/public/locales/en/common.json b/client/public/locales/en/common.json index 88ff2ae85..049923006 100644 --- a/client/public/locales/en/common.json +++ b/client/public/locales/en/common.json @@ -326,7 +326,7 @@ }, "extra_fields": { "tab": "Extra Fields", - "description": "

Here you can add extra custom fields to your entities.

Once a field is added, you can not change its key or type, and for choice type fields you can not remove choices or change the multi choice state. If you remove a field, the associated data for all entities will be deleted.

The key is what other programs read/write the data as, so if your custom field is supposed to integrate with a third-party program, make sure to set it correctly. Default value is only applied to new items.

Extra fields can not be sorted or filtered in the table views.

", + "description": "

Here you can add extra custom fields to your entities.

Once a field is added, you cannot change its key or type, and for choice type fields you cannot remove choices or change the multi-choice state. If you remove a field, the associated data for all entities will be deleted.

The key is what other programs read/write the data as, so if your custom field is supposed to integrate with a third-party program, make sure to set it correctly. Default value is only applied to new items.

Extra fields can be shown and sorted in the table views. Choice and boolean fields expose filter options there as well, and all extra field types support filtering for empty values.

", "params": { "key": "Key", "name": "Name", diff --git a/client/src/components/column.tsx b/client/src/components/column.tsx index 059b607f0..d5a5f7658 100644 --- a/client/src/components/column.tsx +++ b/client/src/components/column.tsx @@ -40,7 +40,7 @@ export interface Action { interface BaseColumnProps { id: string | string[]; - dataId?: keyof Obj & string; + dataId?: (keyof Obj & string) | string; // Allow string values for custom fields i18ncat?: string; i18nkey?: string; title?: string; @@ -98,10 +98,8 @@ function Column( // Sorting if (props.sorter) { columnProps.sorter = true; - columnProps.sortOrder = getSortOrderForField( - typeSorters(props.tableState.sorters), - props.dataId ?? (props.id as keyof Obj), - ); + const sortField = props.dataId ?? (Array.isArray(props.id) ? props.id.join(".") : props.id); + columnProps.sortOrder = getSortOrderForField(typeSorters(props.tableState.sorters), sortField); } // Filter @@ -207,11 +205,12 @@ export function FilteredQueryColumn(props: FilteredQueryColu } filters.push({ text: "", - value: "", + value: "", }); const typedFilters = typeFilters(props.tableState.filters); - const filteredValue = getFiltersForField(typedFilters, props.dataId ?? (props.id as keyof Obj)); + const filterField = props.dataId ?? (Array.isArray(props.id) ? props.id.join(".") : props.id); + const filteredValue = getFiltersForField(typedFilters, filterField); const onFilterDropdownOpen = () => { query.refetch(); @@ -325,7 +324,8 @@ export function SpoolIconColumn(props: SpoolIconColumnProps< }); const typedFilters = typeFilters(props.tableState.filters); - const filteredValue = getFiltersForField(typedFilters, props.dataId ?? (props.id as keyof Obj)); + const filterField = props.dataId ?? (Array.isArray(props.id) ? 
props.id.join(".") : props.id); + const filteredValue = getFiltersForField(typedFilters, filterField); const onFilterDropdownOpen = () => { query.refetch(); @@ -389,13 +389,53 @@ export function NumberRangeColumn(props: NumberColumnProps { + filters.push({ + text: choice, + value: `"${choice}"`, // Exact match + }); + }); + } + + // For boolean fields, add true/false options + if (field.field_type === FieldType.boolean) { + filters.push({ text: "Yes", value: "true" }, { text: "No", value: "false" }); + } + + // Add empty option for all field types + filters.push({ + text: "", + value: "", + }); + + return filters; +} + export function CustomFieldColumn(props: Omit, "id"> & { field: Field }) { const field = props.field; + const fieldId = `extra.${field.key}`; + + // Get filtered values for this field + const typedFilters = typeFilters(props.tableState.filters); + const filteredValue = getFiltersForField(typedFilters, fieldId); + + // Create filters based on field type + const filters = createCustomFieldFilters(field); + const commonProps = { ...props, id: ["extra", field.key], title: field.name, - sorter: false, + sorter: true, // Enable sorting for custom fields + dataId: fieldId, // Set the dataId for sorting + filters: filters, // Add filters + filteredValue: filteredValue, // Set filtered values transform: (value: unknown) => { if (value === null || value === undefined) { return undefined; diff --git a/client/src/components/dataProvider.ts b/client/src/components/dataProvider.ts index 95efd97a9..7bfa36978 100644 --- a/client/src/components/dataProvider.ts +++ b/client/src/components/dataProvider.ts @@ -2,6 +2,8 @@ import { DataProvider } from "@refinedev/core"; import { axiosInstance } from "@refinedev/simple-rest"; import { AxiosInstance } from "axios"; import { stringify } from "query-string"; +import { getCustomFieldFilters } from "../utils/filtering"; +import { isCustomField } from "../utils/queryFields"; type MethodTypes = "get" | "delete" | "head" | 
"options"; type MethodTypesWithBody = "post" | "put" | "patch"; @@ -25,20 +27,30 @@ const dataProvider = ( } if (sorters && sorters.length > 0) { + // Map all sorters, including custom field sorters queryParams["sort"] = sorters .map((sort) => { const field = sort.field; + // Custom field sorters are already in the correct format (extra.field_key) return `${field}:${sort.order}`; }) .join(","); } if (filters && filters.length > 0) { + // Process regular filters filters.forEach((filter) => { if (!("field" in filter)) { throw Error("Filter must be a LogicalFilter."); } + const field = filter.field; + + // Skip custom fields, they'll be handled separately + if (typeof field === "string" && isCustomField(field)) { + return; + } + if (filter.value.length > 0) { const filterValueArray = Array.isArray(filter.value) ? filter.value : [filter.value]; @@ -54,6 +66,14 @@ const dataProvider = ( queryParams[field] = filterValue; } }); + + // Process custom field filters + const customFieldFilters = getCustomFieldFilters(filters); + Object.entries(customFieldFilters).forEach(([key, values]) => { + if (values.length > 0) { + queryParams[`extra.${key}`] = values.map((value) => (value === "" ? "" : value)).join(","); + } + }); } const { data, headers } = await httpClient[requestMethod](`${url}`, { diff --git a/client/src/utils/filtering.ts b/client/src/utils/filtering.ts index da99b43c4..6acf9a8b3 100644 --- a/client/src/utils/filtering.ts +++ b/client/src/utils/filtering.ts @@ -1,7 +1,8 @@ import { CrudFilter, CrudOperators } from "@refinedev/core"; +import { Field, FieldType, getCustomFieldKey, isCustomField } from "./queryFields"; interface TypedCrudFilter { - field: keyof Obj; + field: keyof Obj | string; operator: Exclude; value: string[]; } @@ -16,10 +17,7 @@ export function typeFilters(filters: CrudFilter[]): TypedCrudFilter[] * @param field The field to get the filter values for. * @returns An array of filter values for the given field. 
*/ -export function getFiltersForField( - filters: TypedCrudFilter[], - field: Field, -): string[] { +export function getFiltersForField(filters: TypedCrudFilter[], field: Field | string): string[] { const filterValues: string[] = []; filters.forEach((filter) => { if (filter.field === field) { @@ -29,6 +27,89 @@ export function getFiltersForField( return filterValues; } +/** + * Creates a filter value for a custom field based on its type + * @param field The custom field definition + * @param value The value to filter by + * @returns The formatted filter value + */ +type CustomFieldFilterValue = + | string + | number + | boolean + | Date + | [number | null | undefined, number | null | undefined] + | null + | undefined; + +export function formatCustomFieldFilterValue(field: Field, value: CustomFieldFilterValue): string { + switch (field.field_type) { + case FieldType.text: + case FieldType.choice: + // For text and choice fields, we can use the value directly + // If it's an exact match, surround with quotes + if (typeof value === "string" && !value.startsWith('"') && !value.endsWith('"')) { + // Check if we need an exact match (no wildcards) + if (!value.includes("*") && !value.includes("?")) { + return `"${value}"`; + } + } + return value == null ? "" : String(value); + + case FieldType.integer: + case FieldType.float: + // For numeric fields, we can use the value directly + return value == null ? "" : value.toString(); + + case FieldType.boolean: + // For boolean fields, convert to "true" or "false" + return value ? "true" : "false"; + + case FieldType.datetime: + // For datetime fields, format as ISO string + if (value instanceof Date) { + return value.toISOString(); + } + return value == null ? "" : String(value); + + case FieldType.integer_range: + case FieldType.float_range: + // For range fields, format as min:max + if (Array.isArray(value) && value.length === 2) { + return `${value[0] ?? ""}:${value[1] ?? ""}`; + } + return value == null ? 
"" : String(value); + + default: + return value == null ? "" : String(value); + } +} + +/** + * Extracts all custom field filters from a list of filters + * @param filters The list of filters + * @returns An object with custom field keys and their filter values + */ +export function getCustomFieldFilters( + filters: CrudFilter[] | TypedCrudFilter[], +): Record { + const customFieldFilters: Record = {}; + + filters.forEach((filter) => { + if (!("field" in filter)) { + return; // Skip non-field filters + } + + const field = filter.field.toString(); + if (isCustomField(field)) { + const key = getCustomFieldKey(field); + customFieldFilters[key] = filter.value as string[]; + } + }); + + return customFieldFilters; +} + /** * Function that returns an array with all undefined values removed. */ diff --git a/client/src/utils/queryFields.ts b/client/src/utils/queryFields.ts index 7cde38c05..f7a048a74 100644 --- a/client/src/utils/queryFields.ts +++ b/client/src/utils/queryFields.ts @@ -110,6 +110,24 @@ export function useSetField(entity_type: EntityType) { }); } +/** + * Checks if a field is a custom field (starts with "extra.") + * @param field The field to check + * @returns True if the field is a custom field + */ +export function isCustomField(field: string): boolean { + return field.startsWith("extra."); +} + +/** + * Extracts the key from a custom field (removes the "extra." prefix) + * @param field The custom field + * @returns The key of the custom field + */ +export function getCustomFieldKey(field: string): string { + return field.substring(6); // Remove "extra." 
prefix +} + export function useDeleteField(entity_type: EntityType) { const queryClient = useQueryClient(); diff --git a/client/src/utils/sorting.ts b/client/src/utils/sorting.ts index 543d72546..27ee36b3d 100644 --- a/client/src/utils/sorting.ts +++ b/client/src/utils/sorting.ts @@ -1,8 +1,9 @@ import { CrudSort } from "@refinedev/core"; import { SortOrder } from "antd/es/table/interface"; +import { Field, getCustomFieldKey, isCustomField } from "./queryFields"; interface TypedCrudSort { - field: keyof Obj; + field: keyof Obj | string; order: "asc" | "desc"; } @@ -12,10 +13,7 @@ interface TypedCrudSort { * @param field The field to get the sort order for. * @returns The sort order for the given field, or undefined if the field is not being sorted. */ -export function getSortOrderForField( - sorters: TypedCrudSort[], - field: Field, -): SortOrder | undefined { +export function getSortOrderForField(sorters: TypedCrudSort[], field: Field | string): SortOrder | undefined { const sorter = sorters.find((s) => s.field === field); if (sorter) { return sorter.order === "asc" ? 
"ascend" : "descend"; @@ -26,3 +24,33 @@ export function getSortOrderForField( export function typeSorters(sorters: CrudSort[]): TypedCrudSort[] { return sorters as TypedCrudSort[]; // <-- Unsafe cast } + +/** + * Checks if a sorter is for a custom field + * @param sorter The sorter to check + * @returns True if the sorter is for a custom field + */ +export function isCustomFieldSorter(sorter: TypedCrudSort | CrudSort): boolean { + return typeof sorter.field === "string" && isCustomField(sorter.field); +} + +/** + * Extracts all custom field sorters from a list of sorters + * @param sorters The list of sorters + * @returns An object with custom field keys and their sort orders + */ +export function getCustomFieldSorters( + sorters: TypedCrudSort[] | CrudSort[], +): Record { + const customFieldSorters: Record = {}; + + sorters.forEach((sorter) => { + if (isCustomFieldSorter(sorter)) { + const field = sorter.field.toString(); + const key = getCustomFieldKey(field); + customFieldSorters[key] = sorter.order; + } + }); + + return customFieldSorters; +} diff --git a/spoolman/api/v1/filament.py b/spoolman/api/v1/filament.py index 3e3f859af..0549f5dbc 100644 --- a/spoolman/api/v1/filament.py +++ b/spoolman/api/v1/filament.py @@ -4,7 +4,7 @@ import logging from typing import Annotated -from fastapi import APIRouter, Depends, Query, WebSocket, WebSocketDisconnect +from fastapi import APIRouter, Depends, Query, Request, WebSocket, WebSocketDisconnect from fastapi.encoders import jsonable_encoder from fastapi.responses import JSONResponse from pydantic import BaseModel, Field, field_validator, model_validator @@ -201,6 +201,7 @@ def prevent_none(cls: type["FilamentUpdateParameters"], v: float | None) -> floa ) async def find( *, + request: Request, db: Annotated[AsyncSession, Depends(get_db_session)], vendor_name_old: Annotated[ str | None, @@ -342,19 +343,31 @@ async def find( else: filter_by_ids = None - db_items, total_count = await filament.find( - db=db, - 
ids=filter_by_ids, - vendor_name=vendor_name if vendor_name is not None else vendor_name_old, - vendor_id=vendor_ids, - name=name, - material=material, - article_number=article_number, - external_id=external_id, - sort_by=sort_by, - limit=limit, - offset=offset, - ) + # Extract custom field filters from query parameters + extra_field_filters = {} + query_params = request.query_params + for key, value in query_params.items(): + if key.startswith("extra."): + field_key = key[6:] # Remove "extra." prefix + extra_field_filters[field_key] = value + + try: + db_items, total_count = await filament.find( + db=db, + ids=filter_by_ids, + vendor_name=vendor_name if vendor_name is not None else vendor_name_old, + vendor_id=vendor_ids, + name=name, + material=material, + article_number=article_number, + external_id=external_id, + extra_field_filters=extra_field_filters if extra_field_filters else None, + sort_by=sort_by, + limit=limit, + offset=offset, + ) + except ValueError as e: + return JSONResponse(status_code=400, content=Message(message=str(e)).dict()) # Set x-total-count header for pagination return JSONResponse( diff --git a/spoolman/api/v1/spool.py b/spoolman/api/v1/spool.py index 8f667e3da..fa86b385b 100644 --- a/spoolman/api/v1/spool.py +++ b/spoolman/api/v1/spool.py @@ -5,7 +5,7 @@ from datetime import datetime from typing import Annotated -from fastapi import APIRouter, Depends, Query, WebSocket, WebSocketDisconnect +from fastapi import APIRouter, Depends, Query, Request, WebSocket, WebSocketDisconnect from fastapi.encoders import jsonable_encoder from fastapi.responses import JSONResponse from pydantic import BaseModel, Field, field_validator @@ -127,6 +127,7 @@ class SpoolMeasureParameters(BaseModel): ) async def find( *, + request: Request, db: Annotated[AsyncSession, Depends(get_db_session)], filament_name_old: Annotated[ str | None, @@ -285,20 +286,32 @@ async def find( else: filament_vendor_ids = None - db_items, total_count = await spool.find( - db=db, - 
filament_name=filament_name if filament_name is not None else filament_name_old, - filament_id=filament_ids, - filament_material=filament_material if filament_material is not None else filament_material_old, - vendor_name=filament_vendor_name if filament_vendor_name is not None else vendor_name_old, - vendor_id=filament_vendor_ids, - location=location, - lot_nr=lot_nr, - allow_archived=allow_archived, - sort_by=sort_by, - limit=limit, - offset=offset, - ) + # Extract custom field filters from query parameters + extra_field_filters = {} + query_params = request.query_params + for key, value in query_params.items(): + if key.startswith("extra."): + field_key = key[6:] # Remove "extra." prefix + extra_field_filters[field_key] = value + + try: + db_items, total_count = await spool.find( + db=db, + filament_name=filament_name if filament_name is not None else filament_name_old, + filament_id=filament_ids, + filament_material=filament_material if filament_material is not None else filament_material_old, + vendor_name=filament_vendor_name if filament_vendor_name is not None else vendor_name_old, + vendor_id=filament_vendor_ids, + location=location, + lot_nr=lot_nr, + allow_archived=allow_archived, + extra_field_filters=extra_field_filters if extra_field_filters else None, + sort_by=sort_by, + limit=limit, + offset=offset, + ) + except ValueError as e: + return JSONResponse(status_code=400, content=Message(message=str(e)).dict()) # Set x-total-count header for pagination return JSONResponse( diff --git a/spoolman/api/v1/vendor.py b/spoolman/api/v1/vendor.py index 9216fba30..f9395a004 100644 --- a/spoolman/api/v1/vendor.py +++ b/spoolman/api/v1/vendor.py @@ -3,7 +3,7 @@ import asyncio from typing import Annotated -from fastapi import APIRouter, Depends, Query, WebSocket, WebSocketDisconnect +from fastapi import APIRouter, Depends, Query, Request, WebSocket, WebSocketDisconnect from fastapi.encoders import jsonable_encoder from fastapi.responses import JSONResponse from 
pydantic import BaseModel, Field, field_validator @@ -79,6 +79,7 @@ def prevent_none(cls: type["VendorUpdateParameters"], v: str | None) -> str | No }, ) async def find( + request: Request, db: Annotated[AsyncSession, Depends(get_db_session)], name: Annotated[ str | None, @@ -124,14 +125,27 @@ async def find( field, direction = sort_item.split(":") sort_by[field] = SortOrder[direction.upper()] - db_items, total_count = await vendor.find( - db=db, - name=name, - external_id=external_id, - sort_by=sort_by, - limit=limit, - offset=offset, - ) + # Extract custom field filters from query parameters + extra_field_filters = {} + query_params = request.query_params + for key, value in query_params.items(): + if key.startswith("extra."): + field_key = key[6:] # Remove "extra." prefix + extra_field_filters[field_key] = value + + try: + db_items, total_count = await vendor.find( + db=db, + name=name, + external_id=external_id, + extra_field_filters=extra_field_filters if extra_field_filters else None, + sort_by=sort_by, + limit=limit, + offset=offset, + ) + except ValueError as e: + return JSONResponse(status_code=400, content=Message(message=str(e)).dict()) + # Set x-total-count header for pagination return JSONResponse( content=jsonable_encoder( diff --git a/spoolman/database/extra_field_query.py b/spoolman/database/extra_field_query.py new file mode 100644 index 000000000..43e658cb9 --- /dev/null +++ b/spoolman/database/extra_field_query.py @@ -0,0 +1,246 @@ +"""Helpers for filtering and sorting extra fields.""" + +from __future__ import annotations + +import json +from typing import TYPE_CHECKING + +import sqlalchemy +from sqlalchemy import Select + +from spoolman.database import models +from spoolman.database.utils import SortOrder +from spoolman.extra_field_registry import EntityType, ExtraField, ExtraFieldType, get_extra_fields + +if TYPE_CHECKING: + from sqlalchemy.ext.asyncio import AsyncSession + from sqlalchemy.orm.attributes import InstrumentedAttribute + + +def 
_get_field_table_for_entity(entity_type: EntityType) -> type[models.Base]: + """Map an entity type to its extra-field table.""" + if entity_type == EntityType.spool: + return models.SpoolField + if entity_type == EntityType.filament: + return models.FilamentField + if entity_type == EntityType.vendor: + return models.VendorField + raise ValueError(f"Unknown entity type: {entity_type}") + + +def _get_entity_id_column(field_table: type[models.Base]) -> InstrumentedAttribute[int]: + """Map an extra-field table to its owning entity id column.""" + if field_table == models.SpoolField: + return models.SpoolField.spool_id + if field_table == models.FilamentField: + return models.FilamentField.filament_id + if field_table == models.VendorField: + return models.VendorField.vendor_id + raise ValueError(f"Unknown field table: {field_table}") + + +def _parse_boolean_filter(value: str) -> bool: + """Parse a boolean filter using explicit true/false tokens only.""" + normalized = value.strip().lower() + if normalized == "true": + return True + if normalized == "false": + return False + raise ValueError(f"Invalid boolean filter value: {value}") + + +async def apply_extra_field_filters_and_sort( + *, + db: AsyncSession, + stmt: Select, + base_obj: type[models.Base], + entity_type: EntityType, + extra_field_filters: dict[str, str] | None, + sort_by: dict[str, SortOrder] | None, +) -> Select: + """Apply extra-field filtering and sorting to a query.""" + if not extra_field_filters and not (sort_by is not None and any(field.startswith("extra.") for field in sort_by)): + return stmt + + extra_fields = await get_extra_fields(db, entity_type) + extra_fields_dict: dict[str, ExtraField] = {field.key: field for field in extra_fields} + + if extra_field_filters: + for field_key, value in extra_field_filters.items(): + field = extra_fields_dict.get(field_key) + if field is None: + continue + stmt = add_where_clause_extra_field( + stmt=stmt, + base_obj=base_obj, + entity_type=entity_type, + 
field_key=field_key, + field_type=field.field_type, + value=value, + multi_choice=field.multi_choice if field.field_type == ExtraFieldType.choice else None, + ) + + if sort_by is not None: + for field_name, order in sort_by.items(): + if not field_name.startswith("extra."): + continue + + field_key = field_name[6:] + extra_field = extra_fields_dict.get(field_key) + if extra_field is None: + continue + + stmt = add_order_by_extra_field( + stmt=stmt, + base_obj=base_obj, + entity_type=entity_type, + field_key=field_key, + field_type=extra_field.field_type, + order=order, + ) + + return stmt + + +def add_where_clause_extra_field( # noqa: C901, PLR0912, PLR0915 + stmt: Select, + base_obj: type[models.Base], + entity_type: EntityType, + field_key: str, + field_type: ExtraFieldType, + value: str, + *, + multi_choice: bool | None = None, +) -> Select: + """Add a where clause to a select statement for an extra field.""" + field_table = _get_field_table_for_entity(entity_type) + entity_id_column = _get_entity_id_column(field_table) + base_id_column = base_obj.id + + conditions = [] + for value_part in value.split(","): + # Empty-string filters follow the existing string-query API semantics. 
+ if len(value_part) == 0: + empty_conditions = [ + field_table.value.is_(None), + field_table.value == "null", + ] + if field_type == ExtraFieldType.boolean: + empty_conditions.append(field_table.value == json.dumps(bool(0))) + + field_has_empty_value = sqlalchemy.select(entity_id_column).where( + sqlalchemy.and_(field_table.key == field_key, sqlalchemy.or_(*empty_conditions)) + ) + field_missing_entirely = sqlalchemy.select(base_id_column).where( + base_id_column.not_in(sqlalchemy.select(entity_id_column).where(field_table.key == field_key)) + ) + conditions.append(base_id_column.in_(field_has_empty_value)) + conditions.append(base_id_column.in_(field_missing_entirely)) + continue + + exact_match = value_part.startswith('"') and value_part.endswith('"') + parsed_value = value_part[1:-1] if exact_match else value_part + + if field_type == ExtraFieldType.text: + field_condition = ( + field_table.value == json.dumps(parsed_value) + if exact_match + else field_table.value.ilike(f"%{parsed_value}%") + ) + elif field_type == ExtraFieldType.integer: + try: + field_condition = field_table.value == json.dumps(int(parsed_value)) + except ValueError as exc: + raise ValueError(f"Invalid integer filter value for '{field_key}': {parsed_value}") from exc + elif field_type == ExtraFieldType.float: + try: + field_condition = field_table.value == json.dumps(float(parsed_value)) + except ValueError as exc: + raise ValueError(f"Invalid float filter value for '{field_key}': {parsed_value}") from exc + elif field_type == ExtraFieldType.boolean: + field_condition = field_table.value == json.dumps(_parse_boolean_filter(parsed_value)) + elif field_type == ExtraFieldType.choice: + if multi_choice: + field_condition = field_table.value.like(f'%"{parsed_value}"%') + else: + field_condition = field_table.value == json.dumps(parsed_value) + elif field_type == ExtraFieldType.datetime: + field_condition = field_table.value == json.dumps(parsed_value) + elif field_type in 
(ExtraFieldType.integer_range, ExtraFieldType.float_range): + if ":" not in parsed_value: + raise ValueError( + f"Invalid range filter value for '{field_key}': {parsed_value}. Expected ':'." + ) + min_val_str, max_val_str = parsed_value.split(":", 1) + converter = int if field_type == ExtraFieldType.integer_range else float + range_conditions = [] + try: + if min_val_str: + range_field = field_table.value[0] + cast_type = sqlalchemy.Integer if field_type == ExtraFieldType.integer_range else sqlalchemy.Float + range_conditions.append(sqlalchemy.cast(range_field, cast_type) >= converter(min_val_str)) + if max_val_str: + range_field = field_table.value[1] + cast_type = sqlalchemy.Integer if field_type == ExtraFieldType.integer_range else sqlalchemy.Float + range_conditions.append(sqlalchemy.cast(range_field, cast_type) <= converter(max_val_str)) + except (ValueError, TypeError) as exc: + range_kind = "integer" if field_type == ExtraFieldType.integer_range else "float" + raise ValueError(f"Invalid {range_kind} range filter value for '{field_key}': {parsed_value}") from exc + if not range_conditions: + raise ValueError( + f"Invalid range filter value for '{field_key}': {parsed_value}. Expected ':'." 
+ ) + field_condition = sqlalchemy.and_(*range_conditions) + else: + raise ValueError(f"Unsupported extra field type for '{field_key}': {field_type}") + + matching_entities = sqlalchemy.select(entity_id_column).where( + sqlalchemy.and_(field_table.key == field_key, field_condition) + ) + conditions.append(base_id_column.in_(matching_entities)) + + if not conditions: + return stmt + + return stmt.where(sqlalchemy.or_(*conditions)) + + +def add_order_by_extra_field( + stmt: Select, + base_obj: type[models.Base], + entity_type: EntityType, + field_key: str, + field_type: ExtraFieldType, + order: SortOrder, +) -> Select: + """Add an order-by clause to a select statement for an extra field.""" + field_table = _get_field_table_for_entity(entity_type) + entity_id_column = _get_entity_id_column(field_table) + + value_subq = ( + sqlalchemy.select(field_table.value) + .where( + sqlalchemy.and_( + field_table.key == field_key, + entity_id_column == base_obj.id, + ) + ) + .scalar_subquery() + .correlate(base_obj) + ) + + if field_type == ExtraFieldType.integer: + sort_expr = sqlalchemy.cast(value_subq, sqlalchemy.Integer) + elif field_type == ExtraFieldType.float: + sort_expr = sqlalchemy.cast(value_subq, sqlalchemy.Float) + elif field_type in (ExtraFieldType.integer_range, ExtraFieldType.float_range): + sort_expr = sqlalchemy.cast( + value_subq[0], + sqlalchemy.Integer if field_type == ExtraFieldType.integer_range else sqlalchemy.Float, + ) + else: + sort_expr = value_subq + + if order == SortOrder.ASC: + return stmt.order_by(sort_expr.asc()) + return stmt.order_by(sort_expr.desc()) diff --git a/spoolman/database/filament.py b/spoolman/database/filament.py index e2d742758..7c8548b80 100644 --- a/spoolman/database/filament.py +++ b/spoolman/database/filament.py @@ -12,6 +12,7 @@ from spoolman.api.v1.models import EventType, Filament, FilamentEvent, MultiColorDirection from spoolman.database import models, vendor +from spoolman.database.extra_field_query import 
apply_extra_field_filters_and_sort from spoolman.database.utils import ( SortOrder, add_where_clause_int_in, @@ -21,6 +22,7 @@ parse_nested_field, ) from spoolman.exceptions import ItemDeleteError, ItemNotFoundError +from spoolman.extra_field_registry import EntityType from spoolman.math import delta_e, hex_to_rgb, rgb_to_lab from spoolman.ws import websocket_manager @@ -102,6 +104,7 @@ async def find( material: str | None = None, article_number: str | None = None, external_id: str | None = None, + extra_field_filters: dict[str, str] | None = None, sort_by: dict[str, SortOrder] | None = None, limit: int | None = None, offset: int = 0, @@ -129,20 +132,32 @@ async def find( total_count = None - if limit is not None: - total_count_stmt = stmt.with_only_columns(func.count(), maintain_column_froms=True) - total_count = (await db.execute(total_count_stmt)).scalar() - - stmt = stmt.offset(offset).limit(limit) + stmt = await apply_extra_field_filters_and_sort( + db=db, + stmt=stmt, + base_obj=models.Filament, + entity_type=EntityType.filament, + extra_field_filters=extra_field_filters, + sort_by=sort_by, + ) if sort_by is not None: for fieldstr, order in sort_by.items(): + # Check if this is a custom field sort + if fieldstr.startswith("extra."): + continue + field = parse_nested_field(models.Filament, fieldstr) if order == SortOrder.ASC: stmt = stmt.order_by(field.asc()) elif order == SortOrder.DESC: stmt = stmt.order_by(field.desc()) + if limit is not None: + total_count_stmt = stmt.with_only_columns(func.count(), maintain_column_froms=True).order_by(None) + total_count = (await db.execute(total_count_stmt)).scalar() + stmt = stmt.offset(offset).limit(limit) + rows = await db.execute( stmt, execution_options={"populate_existing": True}, diff --git a/spoolman/database/spool.py b/spoolman/database/spool.py index 5c190ce65..f1bb25954 100644 --- a/spoolman/database/spool.py +++ b/spoolman/database/spool.py @@ -13,6 +13,7 @@ from spoolman.api.v1.models import EventType, 
Spool, SpoolEvent from spoolman.database import filament, models +from spoolman.database.extra_field_query import apply_extra_field_filters_and_sort from spoolman.database.utils import ( SortOrder, add_where_clause_int, @@ -22,6 +23,7 @@ parse_nested_field, ) from spoolman.exceptions import ItemCreateError, ItemNotFoundError, SpoolMeasureError +from spoolman.extra_field_registry import EntityType from spoolman.math import weight_from_length from spoolman.ws import websocket_manager @@ -122,6 +124,7 @@ async def find( # noqa: C901, PLR0912 location: str | None = None, lot_nr: str | None = None, allow_archived: bool = False, + extra_field_filters: dict[str, str] | None = None, sort_by: dict[str, SortOrder] | None = None, limit: int | None = None, offset: int = 0, @@ -159,20 +162,29 @@ async def find( # noqa: C901, PLR0912 total_count = None - if limit is not None: - total_count_stmt = stmt.with_only_columns(func.count(), maintain_column_froms=True) - total_count = (await db.execute(total_count_stmt)).scalar() - - stmt = stmt.offset(offset).limit(limit) + stmt = await apply_extra_field_filters_and_sort( + db=db, + stmt=stmt, + base_obj=models.Spool, + entity_type=EntityType.spool, + extra_field_filters=extra_field_filters, + sort_by=sort_by, + ) if sort_by is not None: for fieldstr, order in sort_by.items(): + # Check if this is a custom field sort + if fieldstr.startswith("extra."): + continue + sorts = [] if fieldstr == "remaining_weight": - sorts.append(coalesce(models.Spool.initial_weight, models.Filament.weight) - models.Spool.used_weight) + sorts.append( + coalesce(models.Spool.initial_weight, models.Filament.weight) - models.Spool.used_weight, + ) elif fieldstr == "remaining_length": - # Simplified weight -> length formula. Absolute value is not correct but the proportionality is still - # kept, which means the sort order is correct. + # Simplified weight -> length formula. 
Absolute value is not correct but the proportionality + # is still kept, which means the sort order is correct. sorts.append( (coalesce(models.Spool.initial_weight, models.Filament.weight) - models.Spool.used_weight) / models.Filament.density @@ -197,6 +209,11 @@ async def find( # noqa: C901, PLR0912 elif order == SortOrder.DESC: stmt = stmt.order_by(*(f.desc() for f in sorts)) + if limit is not None: + total_count_stmt = stmt.with_only_columns(func.count(), maintain_column_froms=True).order_by(None) + total_count = (await db.execute(total_count_stmt)).scalar() + stmt = stmt.offset(offset).limit(limit) + rows = await db.execute( stmt, execution_options={"populate_existing": True}, diff --git a/spoolman/database/vendor.py b/spoolman/database/vendor.py index f2e83018e..f3b3358b9 100644 --- a/spoolman/database/vendor.py +++ b/spoolman/database/vendor.py @@ -9,8 +9,10 @@ from spoolman.api.v1.models import EventType, Vendor, VendorEvent from spoolman.database import models +from spoolman.database.extra_field_query import apply_extra_field_filters_and_sort from spoolman.database.utils import SortOrder, add_where_clause_str, add_where_clause_str_opt from spoolman.exceptions import ItemNotFoundError +from spoolman.extra_field_registry import EntityType from spoolman.ws import websocket_manager logger = logging.getLogger(__name__) @@ -53,6 +55,7 @@ async def find( db: AsyncSession, name: str | None = None, external_id: str | None = None, + extra_field_filters: dict[str, str] | None = None, sort_by: dict[str, SortOrder] | None = None, limit: int | None = None, offset: int = 0, @@ -68,20 +71,32 @@ async def find( total_count = None - if limit is not None: - total_count_stmt = stmt.with_only_columns(func.count(), maintain_column_froms=True) - total_count = (await db.execute(total_count_stmt)).scalar() - - stmt = stmt.offset(offset).limit(limit) + stmt = await apply_extra_field_filters_and_sort( + db=db, + stmt=stmt, + base_obj=models.Vendor, + entity_type=EntityType.vendor, + 
extra_field_filters=extra_field_filters, + sort_by=sort_by, + ) if sort_by is not None: for fieldstr, order in sort_by.items(): + # Check if this is a custom field sort + if fieldstr.startswith("extra."): + continue + field = getattr(models.Vendor, fieldstr) if order == SortOrder.ASC: stmt = stmt.order_by(field.asc()) elif order == SortOrder.DESC: stmt = stmt.order_by(field.desc()) + if limit is not None: + total_count_stmt = stmt.with_only_columns(func.count(), maintain_column_froms=True).order_by(None) + total_count = (await db.execute(total_count_stmt)).scalar() + stmt = stmt.offset(offset).limit(limit) + rows = await db.execute( stmt, execution_options={"populate_existing": True}, diff --git a/spoolman/extra_field_registry.py b/spoolman/extra_field_registry.py new file mode 100644 index 000000000..173b2420a --- /dev/null +++ b/spoolman/extra_field_registry.py @@ -0,0 +1,196 @@ +"""Shared extra-field definitions and settings access.""" + +from __future__ import annotations + +import json +import logging +from enum import Enum +from typing import TYPE_CHECKING + +from fastapi.encoders import jsonable_encoder +from pydantic import BaseModel, Field + +from spoolman.database import setting as db_setting +from spoolman.exceptions import ItemNotFoundError +from spoolman.settings import parse_setting + +if TYPE_CHECKING: + from sqlalchemy.ext.asyncio import AsyncSession + +logger = logging.getLogger(__name__) + + +class EntityType(Enum): + vendor = "vendor" + filament = "filament" + spool = "spool" + + +class ExtraFieldType(Enum): + text = "text" + integer = "integer" + integer_range = "integer_range" + float = "float" + float_range = "float_range" + datetime = "datetime" + boolean = "boolean" + choice = "choice" + + +class ExtraFieldParameters(BaseModel): + name: str = Field(description="Nice name", min_length=1, max_length=128) + order: int = Field(0, description="Order of the field") + unit: str | None = Field(None, description="Unit of the value", min_length=1, 
max_length=16) + field_type: ExtraFieldType = Field(description="Type of the field") + default_value: str | None = Field(None, description="Default value of the field") + choices: list[str] | None = Field( + None, + description="Choices for the field, only for field type choice", + min_length=1, + ) + multi_choice: bool | None = Field(None, description="Whether multiple choices can be selected") + + +class ExtraField(ExtraFieldParameters): + key: str = Field(description="Unique key", pattern="^[a-z0-9_]+$", min_length=1, max_length=64) + entity_type: EntityType = Field(description="Entity type this field is for") + + +def validate_extra_field_value(field: ExtraFieldParameters, value: str) -> None: # noqa: C901, PLR0912 + """Validate that the value has the correct type.""" + try: + data = json.loads(value) + except json.JSONDecodeError: + raise ValueError("Value is not valid JSON.") from None + + if field.field_type == ExtraFieldType.text: + if not isinstance(data, str): + raise ValueError("Value is not a string.") + elif field.field_type == ExtraFieldType.integer: + if not isinstance(data, int): + raise ValueError("Value is not an integer.") + elif field.field_type == ExtraFieldType.integer_range: + if not isinstance(data, list): + raise ValueError("Value is not a list.") + if len(data) != 2: # noqa: PLR2004 + raise ValueError("Value list must have exactly two values.") + if not all(isinstance(item, int) or item is None for item in data): + raise ValueError("Value list must contain only integers or null.") + elif field.field_type == ExtraFieldType.float: + if not isinstance(data, (float, int)) or isinstance(data, bool): + raise ValueError("Value is not a float.") + elif field.field_type == ExtraFieldType.float_range: + if not isinstance(data, list): + raise ValueError("Value is not a list.") + if len(data) != 2: # noqa: PLR2004 + raise ValueError("Value list must have exactly two values.") + if not all((isinstance(item, (float, int)) or item is None) and not 
isinstance(item, bool) for item in data): + raise ValueError("Value list must contain only floats or null.") + elif field.field_type == ExtraFieldType.datetime: + if not isinstance(data, str): + raise ValueError("Value is not a string.") + elif field.field_type == ExtraFieldType.boolean: + if not isinstance(data, bool): + raise ValueError("Value is not a boolean.") + elif field.field_type == ExtraFieldType.choice: + if field.multi_choice: + if not isinstance(data, list): + raise ValueError("Value is not a list.") + if not all(isinstance(item, str) for item in data): + raise ValueError("Value list must contain only strings.") + if field.choices is not None and not all(item in field.choices for item in data): + raise ValueError("Value list contains invalid choices.") + else: + if not isinstance(data, str): + raise ValueError("Value is not a string.") + if field.choices is not None and data not in field.choices: + raise ValueError("Value is not a valid choice.") + else: + raise ValueError(f"Unknown field type {field.field_type}.") + + +def validate_extra_field(field: ExtraFieldParameters) -> None: + """Validate an extra field.""" + if field.field_type == ExtraFieldType.choice: + if field.choices is None: + raise ValueError("Choices must be set for field type choice.") + if field.multi_choice is None: + raise ValueError("Multi choice must be set for field type choice.") + else: + if field.choices is not None: + raise ValueError("Choices must not be set for field type other than choice.") + if field.multi_choice is not None: + raise ValueError("Multi choice must not be set for field type other than choice.") + + if field.default_value is not None: + try: + validate_extra_field_value(field, field.default_value) + except ValueError as e: + raise ValueError(f"Default value is not valid: {e}") from None + + +def validate_extra_field_dict(all_fields: list[ExtraField], fields_input: dict[str, str]) -> None: + """Validate a dict of extra fields.""" + all_field_lookup = 
{field.key: field for field in all_fields} + for key, value in fields_input.items(): + if key not in all_field_lookup: + raise ValueError(f"Unknown extra field {key}.") + field = all_field_lookup[key] + try: + validate_extra_field_value(field, value) + except ValueError as e: + raise ValueError(f"Invalid extra field for key {key}: {e!s}") from None + + +extra_field_cache: dict[EntityType, list[ExtraField]] = {} + + +async def get_extra_fields(db: AsyncSession, entity_type: EntityType) -> list[ExtraField]: + """Get all extra fields for a specific entity type.""" + if entity_type in extra_field_cache: + return extra_field_cache[entity_type] + + setting_def = parse_setting(f"extra_fields_{entity_type.name}") + try: + setting = await db_setting.get(db, setting_def) + setting_value = setting.value + except ItemNotFoundError: + setting_value = setting_def.default + + setting_array = json.loads(setting_value) + if not isinstance(setting_array, list): + logger.warning("Setting %s is not a list, using default.", setting_def.key) + setting_array = [] + + fields = [ExtraField.parse_obj(obj) for obj in setting_array] + extra_field_cache[entity_type] = fields + return fields + + +async def add_or_update_extra_field(db: AsyncSession, entity_type: EntityType, extra_field: ExtraField) -> None: + """Add or update an extra field for a specific entity type.""" + validate_extra_field(extra_field) + + extra_fields = await get_extra_fields(db, entity_type) + existing_field = next((field for field in extra_fields if field.key == extra_field.key), None) + if existing_field is not None: + if existing_field.field_type != extra_field.field_type: + raise ValueError("Field type cannot be changed.") + if extra_field.field_type == ExtraFieldType.choice: + if existing_field.multi_choice != extra_field.multi_choice: + raise ValueError("Multi choice cannot be changed.") + if ( + existing_field.choices is not None + and extra_field.choices is not None + and not all(choice in extra_field.choices for 
choice in existing_field.choices) + ): + raise ValueError("Cannot remove existing choices.") + + extra_fields = [field for field in extra_fields if field.key != extra_field.key] + extra_fields.append(extra_field) + + setting_def = parse_setting(f"extra_fields_{entity_type.name}") + await db_setting.update(db=db, definition=setting_def, value=json.dumps(jsonable_encoder(extra_fields))) + + extra_field_cache[entity_type] = extra_fields + logger.info("Added/updated extra field %s for entity type %s.", extra_field.key, entity_type.name) diff --git a/spoolman/extra_fields.py b/spoolman/extra_fields.py index 2be157e3d..958acc900 100644 --- a/spoolman/extra_fields.py +++ b/spoolman/extra_fields.py @@ -2,10 +2,8 @@ import json import logging -from enum import Enum from fastapi.encoders import jsonable_encoder -from pydantic import BaseModel, Field from sqlalchemy.ext.asyncio import AsyncSession from spoolman.database import filament as db_filament @@ -13,196 +11,36 @@ from spoolman.database import spool as db_spool from spoolman.database import vendor as db_vendor from spoolman.exceptions import ItemNotFoundError +from spoolman.extra_field_registry import ( + EntityType, + ExtraField, + ExtraFieldParameters, + ExtraFieldType, + add_or_update_extra_field, + extra_field_cache, + get_extra_fields, + validate_extra_field, + validate_extra_field_dict, + validate_extra_field_value, +) from spoolman.settings import parse_setting logger = logging.getLogger(__name__) - -class EntityType(Enum): - vendor = "vendor" - filament = "filament" - spool = "spool" - - -class ExtraFieldType(Enum): - text = "text" - integer = "integer" - integer_range = "integer_range" - float = "float" - float_range = "float_range" - datetime = "datetime" - boolean = "boolean" - choice = "choice" - - -class ExtraFieldParameters(BaseModel): - name: str = Field(description="Nice name", min_length=1, max_length=128) - order: int = Field(0, description="Order of the field") - unit: str | None = Field(None, 
description="Unit of the value", min_length=1, max_length=16) - field_type: ExtraFieldType = Field(description="Type of the field") - default_value: str | None = Field(None, description="Default value of the field") - choices: list[str] | None = Field( - None, - description="Choices for the field, only for field type choice", - min_length=1, - ) - multi_choice: bool | None = Field(None, description="Whether multiple choices can be selected") - - -class ExtraField(ExtraFieldParameters): - key: str = Field(description="Unique key", pattern="^[a-z0-9_]+$", min_length=1, max_length=64) - entity_type: EntityType = Field(description="Entity type this field is for") - - -def validate_extra_field_value(field: ExtraFieldParameters, value: str) -> None: # noqa: C901, PLR0912 - """Validate that the value has the correct type.""" - try: - data = json.loads(value) - except json.JSONDecodeError: - raise ValueError("Value is not valid JSON.") from None - - if field.field_type == ExtraFieldType.text: - if not isinstance(data, str): - raise ValueError("Value is not a string.") - elif field.field_type == ExtraFieldType.integer: - if not isinstance(data, int): - raise ValueError("Value is not an integer.") - elif field.field_type == ExtraFieldType.integer_range: - if not isinstance(data, list): - raise ValueError("Value is not a list.") - if len(data) != 2: # noqa: PLR2004 - raise ValueError("Value list must have exactly two values.") - if not all(isinstance(value, int) or value is None for value in data): - raise ValueError("Value list must contain only integers or null.") - elif field.field_type == ExtraFieldType.float: - if not isinstance(data, (float, int)) or isinstance(data, bool): - raise ValueError("Value is not a float.") - elif field.field_type == ExtraFieldType.float_range: - if not isinstance(data, list): - raise ValueError("Value is not a list.") - if len(data) != 2: # noqa: PLR2004 - raise ValueError("Value list must have exactly two values.") - if not all( - 
(isinstance(value, (float, int)) or value is None) and not isinstance(value, bool) for value in data - ): - raise ValueError("Value list must contain only floats or null.") - elif field.field_type == ExtraFieldType.datetime: - if not isinstance(data, str): - raise ValueError("Value is not a string.") - elif field.field_type == ExtraFieldType.boolean: - if not isinstance(data, bool): - raise ValueError("Value is not a boolean.") - elif field.field_type == ExtraFieldType.choice: - if field.multi_choice: - if not isinstance(data, list): - raise ValueError("Value is not a list.") - if not all(isinstance(value, str) for value in data): - raise ValueError("Value list must contain only strings.") - if field.choices is not None and not all(value in field.choices for value in data): - raise ValueError("Value list contains invalid choices.") - else: - if not isinstance(data, str): - raise ValueError("Value is not a string.") - if field.choices is not None and data not in field.choices: - raise ValueError("Value is not a valid choice.") - else: - raise ValueError(f"Unknown field type {field.field_type}.") - - -def validate_extra_field(field: ExtraFieldParameters) -> None: - """Validate an extra field.""" - # Validate choices exist if field type is choice - if field.field_type == ExtraFieldType.choice: - if field.choices is None: - raise ValueError("Choices must be set for field type choice.") - if field.multi_choice is None: - raise ValueError("Multi choice must be set for field type choice.") - else: - if field.choices is not None: - raise ValueError("Choices must not be set for field type other than choice.") - if field.multi_choice is not None: - raise ValueError("Multi choice must not be set for field type other than choice.") - - # Validate default value data type - if field.default_value is not None: - try: - validate_extra_field_value(field, field.default_value) - except ValueError as e: - raise ValueError(f"Default value is not valid: {e}") from None - - -def 
validate_extra_field_dict(all_fields: list[ExtraField], fields_input: dict[str, str]) -> None: - """Validate a dict of extra fields.""" - all_field_lookup = {field.key: field for field in all_fields} - for key, value in fields_input.items(): - if key not in all_field_lookup: - raise ValueError(f"Unknown extra field {key}.") - field = all_field_lookup[key] - try: - validate_extra_field_value(field, value) - except ValueError as e: - raise ValueError(f"Invalid extra field for key {key}: {e!s}") from None - - -extra_field_cache = {} - - -async def get_extra_fields(db: AsyncSession, entity_type: EntityType) -> list[ExtraField]: - """Get all extra fields for a specific entity type.""" - if entity_type in extra_field_cache: - return extra_field_cache[entity_type] - - setting_def = parse_setting(f"extra_fields_{entity_type.name}") - try: - setting = await db_setting.get(db, setting_def) - setting_value = setting.value - except ItemNotFoundError: - setting_value = setting_def.default - - setting_array = json.loads(setting_value) - if not isinstance(setting_array, list): - logger.warning("Setting %s is not a list, using default.", setting_def.key) - setting_array = [] - - fields = [ExtraField.parse_obj(obj) for obj in setting_array] - extra_field_cache[entity_type] = fields - return fields - - -async def add_or_update_extra_field(db: AsyncSession, entity_type: EntityType, extra_field: ExtraField) -> None: - """Add or update an extra field for a specific entity type.""" - validate_extra_field(extra_field) - - extra_fields = await get_extra_fields(db, entity_type) - - # If the field already exists, verify that we don't do anything that would break existing data - existing_field = next((field for field in extra_fields if field.key == extra_field.key), None) - if existing_field is not None: - if existing_field.field_type != extra_field.field_type: - raise ValueError("Field type cannot be changed.") - if extra_field.field_type == ExtraFieldType.choice: - # Can't change multi 
choice since that would break existing data - if existing_field.multi_choice != extra_field.multi_choice: - raise ValueError("Multi choice cannot be changed.") - - # Verify that we have only added new choices, not removed any - if ( - existing_field.choices is not None - and extra_field.choices is not None - and not all(choice in extra_field.choices for choice in existing_field.choices) - ): - raise ValueError("Cannot remove existing choices.") - - extra_fields = [field for field in extra_fields if field.key != extra_field.key] - extra_fields.append(extra_field) - - setting_def = parse_setting(f"extra_fields_{entity_type.name}") - await db_setting.update(db=db, definition=setting_def, value=json.dumps(jsonable_encoder(extra_fields))) - - # Update cache - extra_field_cache[entity_type] = extra_fields - - logger.info("Added/updated extra field %s for entity type %s.", extra_field.key, entity_type.name) +__all__ = [ + "EntityType", + "ExtraField", + "ExtraFieldParameters", + "ExtraFieldType", + "add_or_update_extra_field", + "delete_extra_field", + "extra_field_cache", + "get_extra_fields", + "populate_with_defaults", + "validate_extra_field", + "validate_extra_field_dict", + "validate_extra_field_value", +] async def delete_extra_field(db: AsyncSession, entity_type: EntityType, key: str) -> None: diff --git a/tests_integration/tests/fields/test_filter_sort.py b/tests_integration/tests/fields/test_filter_sort.py new file mode 100644 index 000000000..f2b47e5f9 --- /dev/null +++ b/tests_integration/tests/fields/test_filter_sort.py @@ -0,0 +1,576 @@ +"""Tests for filtering and sorting by custom fields.""" + +import json +from typing import Any + +import httpx +import pytest + +from ..conftest import URL, assert_httpx_success + + +@pytest.mark.asyncio +async def test_filter_by_custom_field(random_filament: dict[str, Any]): + """Add a custom text field.""" + result = httpx.post( + f"{URL}/api/v1/field/spool/test_field", + json={ + "name": "Test field", + "field_type": 
"text", + "default_value": json.dumps("Hello World"), + }, + ) + assert_httpx_success(result) + + """Test filtering by custom field.""" + # Create a spool with a custom field + result = httpx.post( + f"{URL}/api/v1/spool", + json={"filament_id": random_filament["id"], "extra": {"test_field": json.dumps("test_value")}}, + ) + assert_httpx_success(result) + spool_id1 = result.json()["id"] + + # Create another spool with a different custom field value + result = httpx.post( + f"{URL}/api/v1/spool", + json={"filament_id": random_filament["id"], "extra": {"test_field": json.dumps("other_value")}}, + ) + assert_httpx_success(result) + spool_id2 = result.json()["id"] + + # Filter by custom field + result = httpx.get(f"{URL}/api/v1/spool", params={"extra.test_field": "test_value"}) + assert_httpx_success(result) + data = result.json() + assert len(data) == 1 + assert data[0]["id"] == spool_id1 + + # Filter by custom field with exact match + result = httpx.get(f"{URL}/api/v1/spool", params={"extra.test_field": '"test_value"'}) + assert_httpx_success(result) + data = result.json() + assert len(data) == 1 + assert data[0]["id"] == spool_id1 + + # Filter by custom field with multiple values + result = httpx.get(f"{URL}/api/v1/spool", params={"extra.test_field": "test_value,other_value"}) + assert_httpx_success(result) + data = result.json() + assert len(data) == 2 + assert {item["id"] for item in data} == {spool_id1, spool_id2} + + # Clean up + result = httpx.delete(f"{URL}/api/v1/field/spool/test_field") + assert_httpx_success(result) + httpx.delete(f"{URL}/api/v1/spool/{spool_id1}").raise_for_status() + httpx.delete(f"{URL}/api/v1/spool/{spool_id2}").raise_for_status() + + +@pytest.mark.asyncio +async def test_sort_by_custom_field(random_filament: dict[str, Any]): + """Add a custom text field.""" + result = httpx.post( + f"{URL}/api/v1/field/spool/text_field", + json={ + "name": "Text field", + "field_type": "text", + }, + ) + assert_httpx_success(result) + + """Test sorting 
by custom field.""" + # Create spools with custom fields of different types + # Text field + result = httpx.post( + f"{URL}/api/v1/spool", + json={"filament_id": random_filament["id"], "extra": {"text_field": json.dumps("B value")}}, + ) + assert_httpx_success(result) + spool_id1 = result.json()["id"] + + result = httpx.post( + f"{URL}/api/v1/spool", + json={"filament_id": random_filament["id"], "extra": {"text_field": json.dumps("A value")}}, + ) + assert_httpx_success(result) + spool_id2 = result.json()["id"] + + # Sort by custom field ascending + result = httpx.get(f"{URL}/api/v1/spool", params={"sort": "extra.text_field:asc"}) + assert_httpx_success(result) + data = result.json() + assert len(data) >= 2 + # Find our test spools in the results + test_spools = [item for item in data if item["id"] in (spool_id1, spool_id2)] + assert len(test_spools) == 2 + assert test_spools[0]["id"] == spool_id2 # A value should come first + assert test_spools[1]["id"] == spool_id1 # B value should come second + + # Sort by custom field descending + result = httpx.get(f"{URL}/api/v1/spool", params={"sort": "extra.text_field:desc"}) + assert_httpx_success(result) + data = result.json() + assert len(data) >= 2 + # Find our test spools in the results + test_spools = [item for item in data if item["id"] in (spool_id1, spool_id2)] + assert len(test_spools) == 2 + assert test_spools[0]["id"] == spool_id1 # B value should come first + assert test_spools[1]["id"] == spool_id2 # A value should come second + + # Clean up + result = httpx.delete(f"{URL}/api/v1/field/spool/text_field") + assert_httpx_success(result) + httpx.delete(f"{URL}/api/v1/spool/{spool_id1}").raise_for_status() + httpx.delete(f"{URL}/api/v1/spool/{spool_id2}").raise_for_status() + + +@pytest.mark.asyncio +async def test_filter_by_numeric_custom_field(random_filament: dict[str, Any]): + """Add a custom numeric field.""" + result = httpx.post( + f"{URL}/api/v1/field/spool/numeric_field", + json={ + "name": "Numeric 
field", + "field_type": "integer", + }, + ) + assert_httpx_success(result) + + """Test filtering by numeric custom field.""" + # Create a spool with a numeric custom field + result = httpx.post( + f"{URL}/api/v1/spool", + json={"filament_id": random_filament["id"], "extra": {"numeric_field": json.dumps(100)}}, + ) + assert_httpx_success(result) + spool_id1 = result.json()["id"] + + # Create another spool with a different numeric value + result = httpx.post( + f"{URL}/api/v1/spool", + json={"filament_id": random_filament["id"], "extra": {"numeric_field": json.dumps(200)}}, + ) + assert_httpx_success(result) + spool_id2 = result.json()["id"] + + # Filter by numeric custom field + result = httpx.get(f"{URL}/api/v1/spool", params={"extra.numeric_field": "100"}) + assert_httpx_success(result) + data = result.json() + assert len(data) == 1 + assert data[0]["id"] == spool_id1 + + # Sort by numeric custom field ascending + result = httpx.get(f"{URL}/api/v1/spool", params={"sort": "extra.numeric_field:asc"}) + assert_httpx_success(result) + data = result.json() + # Find our test spools in the results + test_spools = [item for item in data if item["id"] in (spool_id1, spool_id2)] + assert len(test_spools) == 2 + assert test_spools[0]["id"] == spool_id1 # 100 should come first + assert test_spools[1]["id"] == spool_id2 # 200 should come second + + # Clean up + result = httpx.delete(f"{URL}/api/v1/field/spool/numeric_field") + assert_httpx_success(result) + httpx.delete(f"{URL}/api/v1/spool/{spool_id1}").raise_for_status() + httpx.delete(f"{URL}/api/v1/spool/{spool_id2}").raise_for_status() + + +@pytest.mark.asyncio +async def test_filter_by_boolean_custom_field(random_filament: dict[str, Any]): + """Add a custom boolean field.""" + result = httpx.post( + f"{URL}/api/v1/field/spool/boolean_field", + json={ + "name": "Boolean field", + "field_type": "boolean", + }, + ) + assert_httpx_success(result) + + """Test filtering by boolean custom field.""" + # Create a spool with a 
boolean custom field + result = httpx.post( + f"{URL}/api/v1/spool", + json={"filament_id": random_filament["id"], "extra": {"boolean_field": json.dumps(bool(1))}}, + ) + assert_httpx_success(result) + spool_id1 = result.json()["id"] + + # Create another spool with a different boolean value + result = httpx.post( + f"{URL}/api/v1/spool", + json={"filament_id": random_filament["id"], "extra": {"boolean_field": json.dumps(bool(0))}}, + ) + assert_httpx_success(result) + spool_id2 = result.json()["id"] + + # Filter by boolean custom field + result = httpx.get(f"{URL}/api/v1/spool", params={"extra.boolean_field": "true"}) + assert_httpx_success(result) + data = result.json() + assert len(data) == 1 + assert data[0]["id"] == spool_id1 + + # Clean up + result = httpx.delete(f"{URL}/api/v1/field/spool/boolean_field") + assert_httpx_success(result) + httpx.delete(f"{URL}/api/v1/spool/{spool_id1}").raise_for_status() + httpx.delete(f"{URL}/api/v1/spool/{spool_id2}").raise_for_status() + + +@pytest.mark.asyncio +async def test_filter_and_sort_float_custom_field(random_filament: dict[str, Any]): + """Test filtering and sorting by a float custom field.""" + result = httpx.post( + f"{URL}/api/v1/field/spool/float_field", + json={"name": "Float field", "field_type": "float"}, + ) + assert_httpx_success(result) + + result = httpx.post( + f"{URL}/api/v1/spool", + json={"filament_id": random_filament["id"], "extra": {"float_field": json.dumps(1.5)}}, + ) + assert_httpx_success(result) + spool_id1 = result.json()["id"] + + result = httpx.post( + f"{URL}/api/v1/spool", + json={"filament_id": random_filament["id"], "extra": {"float_field": json.dumps(2.5)}}, + ) + assert_httpx_success(result) + spool_id2 = result.json()["id"] + + # Filter by exact float value + result = httpx.get(f"{URL}/api/v1/spool", params={"extra.float_field": "1.5"}) + assert_httpx_success(result) + ids = {item["id"] for item in result.json()} + assert spool_id1 in ids + assert spool_id2 not in ids + + # Sort 
ascending: 1.5 before 2.5 + result = httpx.get(f"{URL}/api/v1/spool", params={"sort": "extra.float_field:asc"}) + assert_httpx_success(result) + test_spools = [item for item in result.json() if item["id"] in (spool_id1, spool_id2)] + assert len(test_spools) == 2 + assert test_spools[0]["id"] == spool_id1 + assert test_spools[1]["id"] == spool_id2 + + # Sort descending: 2.5 before 1.5 + result = httpx.get(f"{URL}/api/v1/spool", params={"sort": "extra.float_field:desc"}) + assert_httpx_success(result) + test_spools = [item for item in result.json() if item["id"] in (spool_id1, spool_id2)] + assert len(test_spools) == 2 + assert test_spools[0]["id"] == spool_id2 + assert test_spools[1]["id"] == spool_id1 + + # Clean up + httpx.delete(f"{URL}/api/v1/field/spool/float_field").raise_for_status() + httpx.delete(f"{URL}/api/v1/spool/{spool_id1}").raise_for_status() + httpx.delete(f"{URL}/api/v1/spool/{spool_id2}").raise_for_status() + + +@pytest.mark.asyncio +async def test_filter_single_choice_custom_field(random_filament: dict[str, Any]): + """Test filtering by a single-choice custom field.""" + result = httpx.post( + f"{URL}/api/v1/field/spool/choice_field", + json={ + "name": "Choice field", + "field_type": "choice", + "choices": ["OptionA", "OptionB", "OptionC"], + "multi_choice": False, + }, + ) + assert_httpx_success(result) + + result = httpx.post( + f"{URL}/api/v1/spool", + json={"filament_id": random_filament["id"], "extra": {"choice_field": json.dumps("OptionA")}}, + ) + assert_httpx_success(result) + spool_id1 = result.json()["id"] + + result = httpx.post( + f"{URL}/api/v1/spool", + json={"filament_id": random_filament["id"], "extra": {"choice_field": json.dumps("OptionB")}}, + ) + assert_httpx_success(result) + spool_id2 = result.json()["id"] + + # Filter by a single choice value + result = httpx.get(f"{URL}/api/v1/spool", params={"extra.choice_field": "OptionA"}) + assert_httpx_success(result) + ids = {item["id"] for item in result.json()} + assert spool_id1 
in ids + assert spool_id2 not in ids + + # Filter by multiple choices (OR) — both should be returned + result = httpx.get(f"{URL}/api/v1/spool", params={"extra.choice_field": "OptionA,OptionB"}) + assert_httpx_success(result) + ids = {item["id"] for item in result.json()} + assert spool_id1 in ids + assert spool_id2 in ids + + # Clean up + httpx.delete(f"{URL}/api/v1/field/spool/choice_field").raise_for_status() + httpx.delete(f"{URL}/api/v1/spool/{spool_id1}").raise_for_status() + httpx.delete(f"{URL}/api/v1/spool/{spool_id2}").raise_for_status() + + +@pytest.mark.asyncio +async def test_filter_multi_choice_custom_field(random_filament: dict[str, Any]): + """Test filtering by a multi-choice custom field.""" + result = httpx.post( + f"{URL}/api/v1/field/spool/multi_choice_field", + json={ + "name": "Multi-choice field", + "field_type": "choice", + "choices": ["A", "B", "C"], + "multi_choice": True, + }, + ) + assert_httpx_success(result) + + # Spool 1 has choices A and B + result = httpx.post( + f"{URL}/api/v1/spool", + json={"filament_id": random_filament["id"], "extra": {"multi_choice_field": json.dumps(["A", "B"])}}, + ) + assert_httpx_success(result) + spool_id1 = result.json()["id"] + + # Spool 2 has only choice C + result = httpx.post( + f"{URL}/api/v1/spool", + json={"filament_id": random_filament["id"], "extra": {"multi_choice_field": json.dumps(["C"])}}, + ) + assert_httpx_success(result) + spool_id2 = result.json()["id"] + + # Filter by A — only spool 1 has A + result = httpx.get(f"{URL}/api/v1/spool", params={"extra.multi_choice_field": "A"}) + assert_httpx_success(result) + ids = {item["id"] for item in result.json()} + assert spool_id1 in ids + assert spool_id2 not in ids + + # Filter by C — only spool 2 has C + result = httpx.get(f"{URL}/api/v1/spool", params={"extra.multi_choice_field": "C"}) + assert_httpx_success(result) + ids = {item["id"] for item in result.json()} + assert spool_id2 in ids + assert spool_id1 not in ids + + # Filter by A,C (OR) — 
both should be returned + result = httpx.get(f"{URL}/api/v1/spool", params={"extra.multi_choice_field": "A,C"}) + assert_httpx_success(result) + ids = {item["id"] for item in result.json()} + assert spool_id1 in ids + assert spool_id2 in ids + + # Clean up + httpx.delete(f"{URL}/api/v1/field/spool/multi_choice_field").raise_for_status() + httpx.delete(f"{URL}/api/v1/spool/{spool_id1}").raise_for_status() + httpx.delete(f"{URL}/api/v1/spool/{spool_id2}").raise_for_status() + + +@pytest.mark.asyncio +async def test_filter_empty_custom_field(random_filament: dict[str, Any]): + """Test the empty-string filter returns items that have no value set for a custom field.""" + result = httpx.post( + f"{URL}/api/v1/field/spool/optional_field", + json={"name": "Optional field", "field_type": "text"}, + ) + assert_httpx_success(result) + + # Spool 1 has the field set + result = httpx.post( + f"{URL}/api/v1/spool", + json={"filament_id": random_filament["id"], "extra": {"optional_field": json.dumps("has_value")}}, + ) + assert_httpx_success(result) + spool_id1 = result.json()["id"] + + # Spool 2 does NOT have the field set + result = httpx.post( + f"{URL}/api/v1/spool", + json={"filament_id": random_filament["id"]}, + ) + assert_httpx_success(result) + spool_id2 = result.json()["id"] + + # Filter by empty string — spool 2 (no field row) should appear, spool 1 should not + result = httpx.get(f"{URL}/api/v1/spool", params={"extra.optional_field": ""}) + assert_httpx_success(result) + ids = {item["id"] for item in result.json()} + assert spool_id2 in ids + assert spool_id1 not in ids + + # Filter by the value — spool 1 should appear, spool 2 should not + result = httpx.get(f"{URL}/api/v1/spool", params={"extra.optional_field": "has_value"}) + assert_httpx_success(result) + ids = {item["id"] for item in result.json()} + assert spool_id1 in ids + assert spool_id2 not in ids + + # Clean up + httpx.delete(f"{URL}/api/v1/field/spool/optional_field").raise_for_status() + 
httpx.delete(f"{URL}/api/v1/spool/{spool_id1}").raise_for_status() + httpx.delete(f"{URL}/api/v1/spool/{spool_id2}").raise_for_status() + + +@pytest.mark.asyncio +async def test_invalid_numeric_custom_field_filters_return_400(): + """Invalid numeric custom-field filters should fail explicitly instead of being ignored.""" + result = httpx.post( + f"{URL}/api/v1/field/spool/numeric_field", + json={ + "name": "Numeric field", + "field_type": "integer", + }, + ) + assert_httpx_success(result) + + result = httpx.get(f"{URL}/api/v1/spool", params={"extra.numeric_field": "abc"}) + assert result.status_code == 400 + assert "Invalid integer filter value" in result.json()["message"] + + httpx.delete(f"{URL}/api/v1/field/spool/numeric_field").raise_for_status() + + +@pytest.mark.asyncio +async def test_invalid_boolean_custom_field_filters_return_400(): + """Invalid boolean custom-field filters should fail explicitly instead of being coerced.""" + result = httpx.post( + f"{URL}/api/v1/field/spool/boolean_field", + json={ + "name": "Boolean field", + "field_type": "boolean", + }, + ) + assert_httpx_success(result) + + result = httpx.get(f"{URL}/api/v1/spool", params={"extra.boolean_field": "maybe"}) + assert result.status_code == 400 + assert "Invalid boolean filter value" in result.json()["message"] + + result = httpx.get(f"{URL}/api/v1/spool", params={"extra.boolean_field": "yes"}) + assert result.status_code == 400 + assert "Invalid boolean filter value" in result.json()["message"] + + httpx.delete(f"{URL}/api/v1/field/spool/boolean_field").raise_for_status() + + +@pytest.mark.asyncio +async def test_filter_sort_filament_custom_field(random_filament: dict[str, Any]): + """Test filtering and sorting filaments by a custom field.""" + vendor_id = random_filament["vendor"]["id"] + + result = httpx.post( + f"{URL}/api/v1/field/filament/filament_tag", + json={"name": "Filament tag", "field_type": "text"}, + ) + assert_httpx_success(result) + + result = httpx.post( + 
f"{URL}/api/v1/filament", + json={"vendor_id": vendor_id, "density": 1.24, "diameter": 1.75, "extra": {"filament_tag": json.dumps("beta")}}, + ) + assert_httpx_success(result) + filament_id1 = result.json()["id"] + + result = httpx.post( + f"{URL}/api/v1/filament", + json={ + "vendor_id": vendor_id, + "density": 1.24, + "diameter": 1.75, + "extra": {"filament_tag": json.dumps("alpha")}, + }, + ) + assert_httpx_success(result) + filament_id2 = result.json()["id"] + + # Filter by custom field — only filament with "beta" should appear + result = httpx.get(f"{URL}/api/v1/filament", params={"extra.filament_tag": "beta"}) + assert_httpx_success(result) + ids = {item["id"] for item in result.json()} + assert filament_id1 in ids + assert filament_id2 not in ids + + # Sort ascending: alpha before beta + result = httpx.get(f"{URL}/api/v1/filament", params={"sort": "extra.filament_tag:asc"}) + assert_httpx_success(result) + test_filaments = [item for item in result.json() if item["id"] in (filament_id1, filament_id2)] + assert len(test_filaments) == 2 + assert test_filaments[0]["id"] == filament_id2 # alpha first + assert test_filaments[1]["id"] == filament_id1 # beta second + + # Sort descending: beta before alpha + result = httpx.get(f"{URL}/api/v1/filament", params={"sort": "extra.filament_tag:desc"}) + assert_httpx_success(result) + test_filaments = [item for item in result.json() if item["id"] in (filament_id1, filament_id2)] + assert len(test_filaments) == 2 + assert test_filaments[0]["id"] == filament_id1 # beta first + assert test_filaments[1]["id"] == filament_id2 # alpha second + + # Clean up + httpx.delete(f"{URL}/api/v1/field/filament/filament_tag").raise_for_status() + httpx.delete(f"{URL}/api/v1/filament/{filament_id1}").raise_for_status() + httpx.delete(f"{URL}/api/v1/filament/{filament_id2}").raise_for_status() + + +@pytest.mark.asyncio +async def test_filter_sort_vendor_custom_field(): + """Test filtering and sorting vendors by a custom field.""" + result = 
httpx.post( + f"{URL}/api/v1/field/vendor/vendor_tier", + json={"name": "Vendor tier", "field_type": "text"}, + ) + assert_httpx_success(result) + + result = httpx.post( + f"{URL}/api/v1/vendor", + json={"name": "Vendor Gold", "extra": {"vendor_tier": json.dumps("gold")}}, + ) + assert_httpx_success(result) + vendor_id1 = result.json()["id"] + + result = httpx.post( + f"{URL}/api/v1/vendor", + json={"name": "Vendor Silver", "extra": {"vendor_tier": json.dumps("silver")}}, + ) + assert_httpx_success(result) + vendor_id2 = result.json()["id"] + + # Filter by vendor custom field — only gold vendor should appear + result = httpx.get(f"{URL}/api/v1/vendor", params={"extra.vendor_tier": "gold"}) + assert_httpx_success(result) + ids = {item["id"] for item in result.json()} + assert vendor_id1 in ids + assert vendor_id2 not in ids + + # Sort ascending: gold before silver + result = httpx.get(f"{URL}/api/v1/vendor", params={"sort": "extra.vendor_tier:asc"}) + assert_httpx_success(result) + test_vendors = [item for item in result.json() if item["id"] in (vendor_id1, vendor_id2)] + assert len(test_vendors) == 2 + assert test_vendors[0]["id"] == vendor_id1 # gold first + assert test_vendors[1]["id"] == vendor_id2 # silver second + + # Sort descending: silver before gold + result = httpx.get(f"{URL}/api/v1/vendor", params={"sort": "extra.vendor_tier:desc"}) + assert_httpx_success(result) + test_vendors = [item for item in result.json() if item["id"] in (vendor_id1, vendor_id2)] + assert len(test_vendors) == 2 + assert test_vendors[0]["id"] == vendor_id2 # silver first + assert test_vendors[1]["id"] == vendor_id1 # gold second + + # Clean up + httpx.delete(f"{URL}/api/v1/field/vendor/vendor_tier").raise_for_status() + httpx.delete(f"{URL}/api/v1/vendor/{vendor_id1}").raise_for_status() + httpx.delete(f"{URL}/api/v1/vendor/{vendor_id2}").raise_for_status()