Skip to content

Commit c8cb902

Browse files
committed
Add types to PathQuery
1 parent b850363 commit c8cb902

File tree

1 file changed

+11
-19
lines changed

1 file changed

+11
-19
lines changed

beets/dbcore/query.py

Lines changed: 11 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -30,8 +30,7 @@
3030
from beets import util
3131

3232
if TYPE_CHECKING:
33-
from beets.dbcore import Model
34-
from beets.dbcore.db import AnyModel
33+
from beets.dbcore.db import AnyModel, Model
3534

3635
P = TypeVar("P", default=Any)
3736
else:
@@ -283,13 +282,11 @@ class PathQuery(FieldQuery[bytes]):
283282
and case-sensitive otherwise.
284283
"""
285284

286-
def __init__(self, field, pattern, fast=True):
285+
def __init__(self, field: str, pattern: bytes, fast: bool = True) -> None:
287286
"""Create a path query.
288287
289288
`pattern` must be a path, either to a file or a directory.
290289
"""
291-
super().__init__(field, pattern, fast)
292-
293290
path = util.normpath(pattern)
294291

295292
# Case sensitivity depends on the filesystem that the query path is located on.
@@ -304,13 +301,10 @@ def __init__(self, field, pattern, fast=True):
304301
# from `col_clause()` do the same thing.
305302
path = path.lower()
306303

307-
# Match the path as a single file.
308-
self.file_path = path
309-
# As a directory (prefix).
310-
self.dir_path = os.path.join(path, b"")
304+
super().__init__(field, path, fast)
311305

312-
@classmethod
313-
def is_path_query(cls, query_part):
306+
@staticmethod
307+
def is_path_query(query_part: str) -> bool:
314308
"""Try to guess whether a unicode query part is a path query.
315309
316310
Condition: separator precedes colon and the file exists.
@@ -328,22 +322,20 @@ def is_path_query(cls, query_part):
328322

329323
return os.path.exists(util.syspath(util.normpath(query_part)))
330324

331-
def match(self, item):
332-
path = item.path if self.case_sensitive else item.path.lower()
333-
return (path == self.file_path) or path.startswith(self.dir_path)
334-
335-
def col_clause(self):
336-
file_blob = BLOB_TYPE(self.file_path)
337-
dir_blob = BLOB_TYPE(self.dir_path)
325+
def match(self, obj: Model) -> bool:
326+
path = obj.path if self.case_sensitive else obj.path.lower()
327+
return path.startswith(self.pattern)
338328

329+
def col_clause(self) -> tuple[str, Sequence[SQLiteType]]:
339330
if self.case_sensitive:
340331
query_part = "({0} = ?) || (substr({0}, 1, ?) = ?)"
341332
else:
342333
query_part = "(BYTELOWER({0}) = BYTELOWER(?)) || \
343334
(substr(BYTELOWER({0}), 1, ?) = BYTELOWER(?))"
344335

336+
dir_blob = BLOB_TYPE(os.path.join(self.pattern, b""))
345337
return query_part.format(self.field), (
346-
file_blob,
338+
BLOB_TYPE(self.pattern),
347339
len(dir_blob),
348340
dir_blob,
349341
)

0 commit comments

Comments
 (0)