diff --git a/labml_db/driver/__init__.py b/labml_db/driver/__init__.py index 20d9500..9bc2626 100644 --- a/labml_db/driver/__init__.py +++ b/labml_db/driver/__init__.py @@ -1,6 +1,6 @@ -from typing import List, Type, TYPE_CHECKING, Optional +from typing import List, Type, TYPE_CHECKING, Optional, Tuple -from ..types import ModelDict +from ..types import ModelDict, QueryDict, SortDict if TYPE_CHECKING: from .. import Serializer, Model @@ -28,3 +28,8 @@ def msave_dict(self, key: List[str], data: List[ModelDict]): def get_all(self) -> List[str]: raise NotImplementedError + + def search(self, text_query: Optional[str], filters: Optional[QueryDict], sort: Optional[SortDict], + randomize: bool = False, limit: Optional[int] = None, sort_by_text_score: bool = False) -> Tuple[ + List[Tuple[str, ModelDict]], int]: + raise NotImplementedError diff --git a/labml_db/driver/mongo.py b/labml_db/driver/mongo.py index 1777764..39b3803 100644 --- a/labml_db/driver/mongo.py +++ b/labml_db/driver/mongo.py @@ -1,12 +1,13 @@ -from typing import List, Type, TYPE_CHECKING, Optional, Dict +from collections import OrderedDict +from typing import List, Type, TYPE_CHECKING, Optional, Dict, Tuple, OrderedDict -from labml_db.serializer.utils import encode_keys, decode_keys +import pymongo +from ..serializer.utils import encode_keys, decode_keys from . import DbDriver -from ..types import ModelDict +from ..types import ModelDict, QueryDict, SortDict if TYPE_CHECKING: - import pymongo from ..model import Model @@ -70,3 +71,54 @@ def get_all(self) -> List[str]: cur = self._collection.find(projection=['_id']) keys = [self._to_key(d['_id']) for d in cur] return keys + + def search(self, text_query: Optional[str], filters: Optional[QueryDict], sort: Optional[SortDict], + randomize: bool = False, limit: Optional[int] = None, sort_by_text_score: bool = False) -> Tuple[ + List[Tuple[str, ModelDict]], int]: + pipeline = [] + + match = dict() + if filters: + for property_name, item in filters.items(): + value, equal = item + if equal: + match[property_name] = value + else: + match[property_name] = {'$ne': value} + if text_query: + match['$text'] = {'$search': text_query} + if len(match) > 0: + pipeline.append({'$match': match}) + + if randomize: + pipeline.append({'$facet': {'data': [{'$sample': {'size': limit}}], 'count': [{'$count': 'count'}]}}) + else: + sort_query = OrderedDict() + if sort_by_text_score: + sort_query['score'] = {'$meta': 'textScore'} + if sort is not None and len(sort) > 0: + for k, v in sort: + sort_query[k] = pymongo.ASCENDING if v else pymongo.DESCENDING + + if len(sort_query) > 0: + pipeline.append({'$sort': sort_query}) + + if limit: + pipeline.append({'$facet': {'data': [{'$limit': limit}], 'count': [{'$count': 'count'}]}}) + + cursor = self._collection.aggregate(pipeline) + res = [] + count = 0 + if limit: + for item in cursor: + for c in item['count']: + count += c['count'] + for d in item['data']: + res.append((d['_id'], self._load_data(d))) + else: + for d in cursor: + res.append((d['_id'], self._load_data(d))) + + count = len(res) + + return res, count diff --git a/labml_db/model.py b/labml_db/model.py index 08a1760..ab08383 100644 --- a/labml_db/model.py +++ b/labml_db/model.py @@ -1,9 +1,9 @@ import copy import warnings -from typing import Generic, Union, Any +from typing import Generic, Union, Any, Tuple from typing import TypeVar, List, Dict, Type, Set, Optional, _GenericAlias, TYPE_CHECKING -from .types import Primitive, ModelDict +from .types import Primitive, ModelDict, QueryDict, SortDict if TYPE_CHECKING: from .driver import DbDriver @@ -192,6 +192,10 @@ def __init__(self, key: Optional[str] = None, **kwargs): for k, v in kwargs.items(): setattr(self, k, v) + for k, v in self._defaults.items(): + if k not in kwargs: + setattr(self, k, v) + def __init_subclass__(cls, **kwargs): if cls.__name__ in Model.__models: warnings.warn(f"{cls.__name__} already used") @@ -308,8 +312,8 @@ def from_dict_transform(cls, data: ModelDict) -> Dict[str, Any]: def to_dict(self) -> ModelDict: values = {} for k, v in self._values.items(): - if k not in self._defaults or self._defaults[k] != v: - values[k] = v + # TODO: exclude defaults from the saved data based on a flag + values[k] = v values = self.to_dict_transform(values) return values @@ -340,3 +344,23 @@ def __repr__(self): kv = [f'{k}={repr(v)}' for k, v in self._values.items()] kv = ', '.join(kv) return f'{self.__class__.__name__}({kv})' + + @classmethod + def search(cls, text_query: Optional[str] = None, filters: Optional[QueryDict] = None, + sort: Optional[SortDict] = None, randomize: bool = False, limit: Optional[int] = None, + sort_by_text_score: bool = False) -> Tuple[List[_KT], int]: + if sort is not None and len(sort) > 0 and randomize: + raise ValueError('Cannot have both randomize and sort criteria') + if limit is not None and limit <= 0: + raise ValueError('Limit should be higher than 0') + if randomize and not limit: + raise ValueError('A limit should be provided when results are randomized') + if sort_by_text_score and not text_query: + raise ValueError("Cannot search by text score when there's no text query") + if randomize and sort_by_text_score: + raise ValueError('Cannot have both randomize and sort by text score') + + db_driver = Model.__db_drivers[cls.__name__] + data, total_count = db_driver.search(text_query=text_query, filters=filters, sort=sort, randomize=randomize, + limit=limit, sort_by_text_score=sort_by_text_score) + return [Model._to_model(k, d) for k, d in data], total_count diff --git a/labml_db/serializer/json.py b/labml_db/serializer/json.py index 13ff059..20d69a8 100644 --- a/labml_db/serializer/json.py +++ b/labml_db/serializer/json.py @@ -10,6 +10,7 @@ class JsonSerializer(Serializer): file_extension = 'json' def to_string(self, data: ModelDict) -> str: + assert data return json.dumps(encode_keys(data)) def from_string(self, data: Optional[str]) -> Optional[ModelDict]: diff --git a/labml_db/serializer/utils.py b/labml_db/serializer/utils.py index d6b230b..8d52809 100644 --- a/labml_db/serializer/utils.py +++ b/labml_db/serializer/utils.py @@ -1,7 +1,7 @@ from typing import Dict -from labml_db import Key -from labml_db.types import Primitive +from .. import Key +from ..types import Primitive def encode_key(key: Key) -> Dict[str, str]: diff --git a/labml_db/serializer/yaml.py b/labml_db/serializer/yaml.py index cb7a610..c58d2bd 100644 --- a/labml_db/serializer/yaml.py +++ b/labml_db/serializer/yaml.py @@ -10,6 +10,7 @@ class YamlSerializer(Serializer): def to_string(self, data: ModelDict) -> str: import yaml + assert data return yaml.dump(encode_keys(data), default_flow_style=False) def from_string(self, data: Optional[str]) -> Optional[ModelDict]: diff --git a/labml_db/types.py b/labml_db/types.py index 0ac16d8..0dc1ffb 100644 --- a/labml_db/types.py +++ b/labml_db/types.py @@ -1,5 +1,7 @@ -from typing import List, Dict, Union +from typing import List, Dict, Union, Tuple Primitive = Union[Dict[str, 'Primitive'], List['Primitive'], int, str, float, bool, None] ModelDict = Dict[str, Primitive] - +# {Property: (value, equal/not_equal)} +QueryDict = Dict[str, Tuple[Union[List['Primitive'], int, str, float, bool], bool]] +SortDict = List[Tuple[str, bool]]