From ad062ba78e5a294830b180ae8261861e8ca70f01 Mon Sep 17 00:00:00 2001 From: mertalev <101130780+mertalev@users.noreply.github.com> Date: Thu, 17 Oct 2024 18:07:49 -0400 Subject: [PATCH 1/6] use user language for search --- machine-learning/app/models/clip/textual.py | 18 ++++-- machine-learning/app/models/constants.py | 59 +++++++++++++++++++ server/src/dtos/search.dto.ts | 5 ++ .../interfaces/machine-learning.interface.ts | 7 ++- .../machine-learning.repository.ts | 5 +- server/src/services/search.service.ts | 10 ++-- .../[[assetId=id]]/+page.svelte | 3 +- 7 files changed, 91 insertions(+), 16 deletions(-) diff --git a/machine-learning/app/models/clip/textual.py b/machine-learning/app/models/clip/textual.py index 32c28ea2bb145..b164dcc17cca5 100644 --- a/machine-learning/app/models/clip/textual.py +++ b/machine-learning/app/models/clip/textual.py @@ -10,6 +10,7 @@ from app.config import log from app.models.base import InferenceModel +from app.models.constants import WEBLATE_TO_FLORES200 from app.models.transforms import clean_text from app.schemas import ModelSession, ModelTask, ModelType @@ -18,8 +19,9 @@ class BaseCLIPTextualEncoder(InferenceModel): depends = [] identity = (ModelType.TEXTUAL, ModelTask.SEARCH) - def _predict(self, inputs: str, **kwargs: Any) -> NDArray[np.float32]: - res: NDArray[np.float32] = self.session.run(None, self.tokenize(inputs))[0][0] + def _predict(self, inputs: str, language: str | None = None, **kwargs: Any) -> NDArray[np.float32]: + tokens = self.tokenize(inputs, language=language) + res: NDArray[np.float32] = self.session.run(None, tokens)[0][0] return res def _load(self) -> ModelSession: @@ -28,6 +30,7 @@ def _load(self) -> ModelSession: self.tokenizer = self._load_tokenizer() tokenizer_kwargs: dict[str, Any] | None = self.text_cfg.get("tokenizer_kwargs") self.canonicalize = tokenizer_kwargs is not None and tokenizer_kwargs.get("clean") == "canonicalize" + self.is_nllb = self.model_name.startswith("nllb") log.debug(f"Loaded tokenizer for CLIP model '{self.model_name}'") return session @@ -37,7 +40,7 @@ def _load_tokenizer(self) -> Tokenizer: pass @abstractmethod - def tokenize(self, text: str) -> dict[str, NDArray[np.int32]]: + def tokenize(self, text: str, language: str | None = None) -> dict[str, NDArray[np.int32]]: pass @property @@ -92,14 +95,19 @@ def _load_tokenizer(self) -> Tokenizer: return tokenizer - def tokenize(self, text: str) -> dict[str, NDArray[np.int32]]: + def tokenize(self, text: str, language: str | None = None) -> dict[str, NDArray[np.int32]]: text = clean_text(text, canonicalize=self.canonicalize) + if self.is_nllb: + flores_code = code if language and (code := WEBLATE_TO_FLORES200.get(language)) else "eng_Latn" + print(f"{language=}") + print(f"{flores_code=}") + text = f"{flores_code}{text}" tokens: Encoding = self.tokenizer.encode(text) return {"text": np.array([tokens.ids], dtype=np.int32)} class MClipTextualEncoder(OpenClipTextualEncoder): - def tokenize(self, text: str) -> dict[str, NDArray[np.int32]]: + def tokenize(self, text: str, language: str | None = None) -> dict[str, NDArray[np.int32]]: text = clean_text(text, canonicalize=self.canonicalize) tokens: Encoding = self.tokenizer.encode(text) return { diff --git a/machine-learning/app/models/constants.py b/machine-learning/app/models/constants.py index 338a481594f4d..a84c3802ec97c 100644 --- a/machine-learning/app/models/constants.py +++ b/machine-learning/app/models/constants.py @@ -66,6 +66,65 @@ SUPPORTED_PROVIDERS = ["CUDAExecutionProvider", 
"OpenVINOExecutionProvider", "CPUExecutionProvider"] +WEBLATE_TO_FLORES200 = { + "af": "afr_Latn", + "ar": "arb_Arab", + "az": "azj_Latn", + "be": "bel_Cyrl", + "bg": "bul_Cyrl", + "ca": "cat_Latn", + "cs": "ces_Latn", + "da": "dan_Latn", + "de": "deu_Latn", + "el": "ell_Grek", + "en": "eng_Latn", + "es": "spa_Latn", + "et": "est_Latn", + "fa": "pes_Arab", + "fi": "fin_Latn", + "fr": "fra_Latn", + "he": "heb_Hebr", + "hi": "hin_Deva", + "hr": "hrv_Latn", + "hu": "hun_Latn", + "hy": "hye_Armn", + "id": "ind_Latn", + "it": "ita_Latn", + "ja": "jpn_Hira", + "kmr": "kmr_Latn", + "ko": "kor_Hang", + "lb": "ltz_Latn", + "lt": "lit_Latn", + "lv": "lav_Latn", + "mfa": "zsm_Latn", + "mk": "mkd_Cyrl", + "mn": "khk_Cyrl", + "mr": "mar_Deva", + "ms": "zsm_Latn", + "nb_NO": "nob_Latn", + "nl": "nld_Latn", + "pl": "pol_Latn", + "pt_BR": "por_Latn", + "pt": "por_Latn", + "ro": "ron_Latn", + "ru": "rus_Cyrl", + "sk": "slk_Latn", + "sl": "slv_Latn", + "sr_Cyrl": "srp_Cyrl", + "sv": "swe_Latn", + "ta": "tam_Taml", + "te": "tel_Telu", + "th": "tha_Thai", + "tr": "tur_Latn", + "uk": "ukr_Cyrl", + "vi": "vie_Latn", + "zh-CN": "zho_Hans", + "zh-TW": "zho_Hant", + "zh_Hant": "zho_Hant", + "zh_SIMPLIFIED": "zho_Hans", +} + + def get_model_source(model_name: str) -> ModelSource | None: cleaned_name = clean_name(model_name) diff --git a/server/src/dtos/search.dto.ts b/server/src/dtos/search.dto.ts index 5c5dce1a1190a..434bb3562fa4a 100644 --- a/server/src/dtos/search.dto.ts +++ b/server/src/dtos/search.dto.ts @@ -177,6 +177,11 @@ export class SmartSearchDto extends BaseSearchDto { @IsNotEmpty() query!: string; + @IsString() + @IsNotEmpty() + @Optional() + language?: string; + @IsInt() @Min(1) @Type(() => Number) diff --git a/server/src/interfaces/machine-learning.interface.ts b/server/src/interfaces/machine-learning.interface.ts index 5342030c8fde7..b140755a27c2f 100644 --- a/server/src/interfaces/machine-learning.interface.ts +++ b/server/src/interfaces/machine-learning.interface.ts @@ -30,7 +30,9 @@ type VisualResponse = { imageHeight: number; imageWidth: number }; export type ClipVisualRequest = { [ModelTask.SEARCH]: { [ModelType.VISUAL]: ModelOptions } }; export type ClipVisualResponse = { [ModelTask.SEARCH]: number[] } & VisualResponse; -export type ClipTextualRequest = { [ModelTask.SEARCH]: { [ModelType.TEXTUAL]: ModelOptions } }; +export type ClipTextualRequest = { + [ModelTask.SEARCH]: { [ModelType.TEXTUAL]: ModelOptions & { options: { language?: string } } }; +}; export type ClipTextualResponse = { [ModelTask.SEARCH]: number[] }; export type FacialRecognitionRequest = { @@ -49,9 +51,10 @@ export interface Face { export type FacialRecognitionResponse = { [ModelTask.FACIAL_RECOGNITION]: Face[] } & VisualResponse; export type DetectedFaces = { faces: Face[] } & VisualResponse; export type MachineLearningRequest = ClipVisualRequest | ClipTextualRequest | FacialRecognitionRequest; +export type TextEncodingOptions = ModelOptions & { language?: string }; export interface IMachineLearningRepository { encodeImage(url: string, imagePath: string, config: ModelOptions): Promise; - encodeText(url: string, text: string, config: ModelOptions): Promise; + encodeText(url: string, text: string, config: TextEncodingOptions): Promise; detectFaces(url: string, imagePath: string, config: FaceDetectionOptions): Promise; } diff --git a/server/src/repositories/machine-learning.repository.ts b/server/src/repositories/machine-learning.repository.ts index b9404022efffa..f15364c24fa43 100644 --- 
a/server/src/repositories/machine-learning.repository.ts +++ b/server/src/repositories/machine-learning.repository.ts @@ -11,6 +11,7 @@ import { ModelPayload, ModelTask, ModelType, + TextEncodingOptions, } from 'src/interfaces/machine-learning.interface'; import { Instrumentation } from 'src/utils/instrumentation'; @@ -55,8 +56,8 @@ export class MachineLearningRepository implements IMachineLearningRepository { return response[ModelTask.SEARCH]; } - async encodeText(url: string, text: string, { modelName }: CLIPConfig) { - const request = { [ModelTask.SEARCH]: { [ModelType.TEXTUAL]: { modelName } } }; + async encodeText(url: string, text: string, { language, modelName }: TextEncodingOptions) { + const request = { [ModelTask.SEARCH]: { [ModelType.TEXTUAL]: { modelName, options: { language } } } }; const response = await this.predict(url, { text }, request); return response[ModelTask.SEARCH]; } diff --git a/server/src/services/search.service.ts b/server/src/services/search.service.ts index 03ffbe97db14e..db6c70b143921 100644 --- a/server/src/services/search.service.ts +++ b/server/src/services/search.service.ts @@ -86,12 +86,10 @@ export class SearchService extends BaseService { } const userIds = await this.getUserIdsToSearch(auth); - - const embedding = await this.machineLearningRepository.encodeText( - machineLearning.url, - dto.query, - machineLearning.clip, - ); + const embedding = await this.machineLearningRepository.encodeText(machineLearning.url, dto.query, { + modelName: machineLearning.clip.modelName, + language: dto.language, + }); const page = dto.page ?? 1; const size = dto.size || 100; const { hasNextPage, items } = await this.searchRepository.searchSmart( diff --git a/web/src/routes/(user)/search/[[photos=photos]]/[[assetId=id]]/+page.svelte b/web/src/routes/(user)/search/[[photos=photos]]/[[assetId=id]]/+page.svelte index eb0c493204f1b..0efee47472bb3 100644 --- a/web/src/routes/(user)/search/[[photos=photos]]/[[assetId=id]]/+page.svelte +++ b/web/src/routes/(user)/search/[[photos=photos]]/[[assetId=id]]/+page.svelte @@ -31,7 +31,7 @@ } from '@immich/sdk'; import { mdiArrowLeft, mdiDotsVertical, mdiImageOffOutline, mdiPlus, mdiSelectAll } from '@mdi/js'; import type { Viewport } from '$lib/stores/assets.store'; - import { locale } from '$lib/stores/preferences.store'; + import { lang, locale } from '$lib/stores/preferences.store'; import LoadingSpinner from '$lib/components/shared-components/loading-spinner.svelte'; import { handlePromiseError } from '$lib/utils'; import { parseUtcDate } from '$lib/utils/date-time'; @@ -144,6 +144,7 @@ page: nextPage, withExif: true, isVisible: true, + language: $lang, ...terms, }; From 5652bb3bfae6c5aa489cc4c9c36ebdd295cbc8a2 Mon Sep 17 00:00:00 2001 From: mertalev <101130780+mertalev@users.noreply.github.com> Date: Thu, 17 Oct 2024 18:27:21 -0400 Subject: [PATCH 2/6] update api --- .../openapi/lib/model/smart_search_dto.dart | 19 ++++++++++++++++++- open-api/immich-openapi-specs.json | 3 +++ open-api/typescript-sdk/src/fetch-client.ts | 1 + 3 files changed, 22 insertions(+), 1 deletion(-) diff --git a/mobile/openapi/lib/model/smart_search_dto.dart b/mobile/openapi/lib/model/smart_search_dto.dart index 4e1408cafa737..4f438411a9fe3 100644 --- a/mobile/openapi/lib/model/smart_search_dto.dart +++ b/mobile/openapi/lib/model/smart_search_dto.dart @@ -25,6 +25,7 @@ class SmartSearchDto { this.isNotInAlbum, this.isOffline, this.isVisible, + this.language, this.lensModel, this.libraryId, this.make, @@ -130,6 +131,14 @@ class SmartSearchDto { /// bool? 
isVisible; + /// + /// Please note: This property should have been non-nullable! Since the specification file + /// does not include a default value (using the "default:" property), however, the generated + /// source code must fall back to having a nullable type. + /// Consider adding a "default:" property in the specification file to hide this note. + /// + String? language; + String? lensModel; String? libraryId; @@ -257,6 +266,7 @@ class SmartSearchDto { other.isNotInAlbum == isNotInAlbum && other.isOffline == isOffline && other.isVisible == isVisible && + other.language == language && other.lensModel == lensModel && other.libraryId == libraryId && other.make == make && @@ -292,6 +302,7 @@ class SmartSearchDto { (isNotInAlbum == null ? 0 : isNotInAlbum!.hashCode) + (isOffline == null ? 0 : isOffline!.hashCode) + (isVisible == null ? 0 : isVisible!.hashCode) + + (language == null ? 0 : language!.hashCode) + (lensModel == null ? 0 : lensModel!.hashCode) + (libraryId == null ? 0 : libraryId!.hashCode) + (make == null ? 0 : make!.hashCode) + @@ -313,7 +324,7 @@ class SmartSearchDto { (withExif == null ? 0 : withExif!.hashCode); @override - String toString() => 'SmartSearchDto[city=$city, country=$country, createdAfter=$createdAfter, createdBefore=$createdBefore, deviceId=$deviceId, isArchived=$isArchived, isEncoded=$isEncoded, isFavorite=$isFavorite, isMotion=$isMotion, isNotInAlbum=$isNotInAlbum, isOffline=$isOffline, isVisible=$isVisible, lensModel=$lensModel, libraryId=$libraryId, make=$make, model=$model, page=$page, personIds=$personIds, query=$query, size=$size, state=$state, takenAfter=$takenAfter, takenBefore=$takenBefore, trashedAfter=$trashedAfter, trashedBefore=$trashedBefore, type=$type, updatedAfter=$updatedAfter, updatedBefore=$updatedBefore, withArchived=$withArchived, withDeleted=$withDeleted, withExif=$withExif]'; + String toString() => 'SmartSearchDto[city=$city, country=$country, createdAfter=$createdAfter, createdBefore=$createdBefore, deviceId=$deviceId, isArchived=$isArchived, isEncoded=$isEncoded, isFavorite=$isFavorite, isMotion=$isMotion, isNotInAlbum=$isNotInAlbum, isOffline=$isOffline, isVisible=$isVisible, language=$language, lensModel=$lensModel, libraryId=$libraryId, make=$make, model=$model, page=$page, personIds=$personIds, query=$query, size=$size, state=$state, takenAfter=$takenAfter, takenBefore=$takenBefore, trashedAfter=$trashedAfter, trashedBefore=$trashedBefore, type=$type, updatedAfter=$updatedAfter, updatedBefore=$updatedBefore, withArchived=$withArchived, withDeleted=$withDeleted, withExif=$withExif]'; Map toJson() { final json = {}; @@ -377,6 +388,11 @@ class SmartSearchDto { } else { // json[r'isVisible'] = null; } + if (this.language != null) { + json[r'language'] = this.language; + } else { + // json[r'language'] = null; + } if (this.lensModel != null) { json[r'lensModel'] = this.lensModel; } else { @@ -484,6 +500,7 @@ class SmartSearchDto { isNotInAlbum: mapValueOfType(json, r'isNotInAlbum'), isOffline: mapValueOfType(json, r'isOffline'), isVisible: mapValueOfType(json, r'isVisible'), + language: mapValueOfType(json, r'language'), lensModel: mapValueOfType(json, r'lensModel'), libraryId: mapValueOfType(json, r'libraryId'), make: mapValueOfType(json, r'make'), diff --git a/open-api/immich-openapi-specs.json b/open-api/immich-openapi-specs.json index 415cc663f4adf..97730a53008dd 100644 --- a/open-api/immich-openapi-specs.json +++ b/open-api/immich-openapi-specs.json @@ -11319,6 +11319,9 @@ "isVisible": { "type": "boolean" }, + "language": { + 
"type": "string" + }, "lensModel": { "nullable": true, "type": "string" diff --git a/open-api/typescript-sdk/src/fetch-client.ts b/open-api/typescript-sdk/src/fetch-client.ts index 2077943bf8218..baf70617a267f 100644 --- a/open-api/typescript-sdk/src/fetch-client.ts +++ b/open-api/typescript-sdk/src/fetch-client.ts @@ -881,6 +881,7 @@ export type SmartSearchDto = { isNotInAlbum?: boolean; isOffline?: boolean; isVisible?: boolean; + language?: string; lensModel?: string | null; libraryId?: string | null; make?: string; From 0bcfbc9ca7675c85e94c3cd43df12ebea300303f Mon Sep 17 00:00:00 2001 From: mertalev <101130780+mertalev@users.noreply.github.com> Date: Thu, 17 Oct 2024 18:52:06 -0400 Subject: [PATCH 3/6] add server unit tests --- server/src/services/search.service.spec.ts | 106 ++++++++++++++++++++- 1 file changed, 105 insertions(+), 1 deletion(-) diff --git a/server/src/services/search.service.spec.ts b/server/src/services/search.service.spec.ts index e0b03f31aee3b..548a058c79f7a 100644 --- a/server/src/services/search.service.spec.ts +++ b/server/src/services/search.service.spec.ts @@ -1,11 +1,16 @@ +import { BadRequestException } from '@nestjs/common'; import { mapAsset } from 'src/dtos/asset-response.dto'; import { SearchSuggestionType } from 'src/dtos/search.dto'; import { IAssetRepository } from 'src/interfaces/asset.interface'; +import { IMachineLearningRepository } from 'src/interfaces/machine-learning.interface'; +import { IPartnerRepository } from 'src/interfaces/partner.interface'; import { IPersonRepository } from 'src/interfaces/person.interface'; import { ISearchRepository } from 'src/interfaces/search.interface'; +import { ISystemMetadataRepository } from 'src/interfaces/system-metadata.interface'; import { SearchService } from 'src/services/search.service'; import { assetStub } from 'test/fixtures/asset.stub'; import { authStub } from 'test/fixtures/auth.stub'; +import { partnerStub } from 'test/fixtures/partner.stub'; import { personStub } from 'test/fixtures/person.stub'; import { newTestService } from 'test/utils'; import { Mocked, beforeEach, vitest } from 'vitest'; @@ -16,11 +21,15 @@ describe(SearchService.name, () => { let sut: SearchService; let assetMock: Mocked; + let machineLearningMock: Mocked; + let partnerMock: Mocked; let personMock: Mocked; let searchMock: Mocked; + let systemMock: Mocked; beforeEach(() => { - ({ sut, assetMock, personMock, searchMock } = newTestService(SearchService)); + ({ sut, assetMock, machineLearningMock, partnerMock, personMock, searchMock, systemMock } = + newTestService(SearchService)); }); it('should work', () => { @@ -80,4 +89,99 @@ describe(SearchService.name, () => { expect(searchMock.getCountries).toHaveBeenCalledWith([authStub.user1.user.id]); }); }); + + describe('searchSmart', () => { + beforeEach(() => { + searchMock.searchSmart.mockResolvedValue({ hasNextPage: false, items: [] }); + machineLearningMock.encodeText.mockResolvedValue([1, 2, 3]); + }); + + it('should raise a BadRequestException if machine learning is disabled', async () => { + systemMock.get.mockResolvedValue({ + machineLearning: { enabled: false }, + }); + + await expect(sut.searchSmart(authStub.user1, { query: 'test' })).rejects.toThrowError( + new BadRequestException('Smart search is not enabled'), + ); + }); + + it('should raise a BadRequestException if smart search is disabled', async () => { + systemMock.get.mockResolvedValue({ + machineLearning: { clip: { enabled: false } }, + }); + + await expect(sut.searchSmart(authStub.user1, { query: 'test' 
})).rejects.toThrowError( + new BadRequestException('Smart search is not enabled'), + ); + }); + + it('should work', async () => { + await sut.searchSmart(authStub.user1, { query: 'test' }); + + expect(machineLearningMock.encodeText).toHaveBeenCalledWith( + expect.any(String), + 'test', + expect.objectContaining({ modelName: expect.any(String) }), + ); + expect(searchMock.searchSmart).toHaveBeenCalledWith( + { page: 1, size: 100 }, + { query: 'test', embedding: [1, 2, 3], userIds: [authStub.user1.user.id] }, + ); + }); + + it('should include partner shared assets', async () => { + partnerMock.getAll.mockResolvedValue([partnerStub.adminToUser1]); + + await sut.searchSmart(authStub.user1, { query: 'test' }); + + expect(machineLearningMock.encodeText).toHaveBeenCalledWith( + expect.any(String), + 'test', + expect.objectContaining({ modelName: expect.any(String) }), + ); + expect(searchMock.searchSmart).toHaveBeenCalledWith( + { page: 1, size: 100 }, + { query: 'test', embedding: [1, 2, 3], userIds: [authStub.user1.user.id, authStub.admin.user.id] }, + ); + }); + + it('should consider page and size parameters', async () => { + await sut.searchSmart(authStub.user1, { query: 'test', page: 2, size: 50 }); + + expect(machineLearningMock.encodeText).toHaveBeenCalledWith( + expect.any(String), + 'test', + expect.objectContaining({ modelName: expect.any(String) }), + ); + expect(searchMock.searchSmart).toHaveBeenCalledWith( + { page: 2, size: 50 }, + expect.objectContaining({ query: 'test', embedding: [1, 2, 3], userIds: [authStub.user1.user.id] }), + ); + }); + + it('should use clip model specified in config', async () => { + systemMock.get.mockResolvedValue({ + machineLearning: { clip: { modelName: 'ViT-B-16-SigLIP__webli' } }, + }); + + await sut.searchSmart(authStub.user1, { query: 'test' }); + + expect(machineLearningMock.encodeText).toHaveBeenCalledWith( + expect.any(String), + 'test', + expect.objectContaining({ modelName: 'ViT-B-16-SigLIP__webli' }), + ); + }); + + it('should use language specified in request', async () => { + await sut.searchSmart(authStub.user1, { query: 'test', language: 'de' }); + + expect(machineLearningMock.encodeText).toHaveBeenCalledWith( + expect.any(String), + 'test', + expect.objectContaining({ language: 'de' }), + ); + }); + }); }); From c2ecf825501f002f7b34f457d9eaf3a4fcd28871 Mon Sep 17 00:00:00 2001 From: mertalev <101130780+mertalev@users.noreply.github.com> Date: Thu, 17 Oct 2024 18:58:49 -0400 Subject: [PATCH 4/6] add ml unit tests --- machine-learning/app/models/clip/textual.py | 2 -- machine-learning/app/test_main.py | 40 +++++++++++++++++++++ 2 files changed, 40 insertions(+), 2 deletions(-) diff --git a/machine-learning/app/models/clip/textual.py b/machine-learning/app/models/clip/textual.py index b164dcc17cca5..28e5c8102c93c 100644 --- a/machine-learning/app/models/clip/textual.py +++ b/machine-learning/app/models/clip/textual.py @@ -99,8 +99,6 @@ def tokenize(self, text: str, language: str | None = None) -> dict[str, NDArray[ text = clean_text(text, canonicalize=self.canonicalize) if self.is_nllb: flores_code = code if language and (code := WEBLATE_TO_FLORES200.get(language)) else "eng_Latn" - print(f"{language=}") - print(f"{flores_code=}") text = f"{flores_code}{text}" tokens: Encoding = self.tokenizer.encode(text) return {"text": np.array([tokens.ids], dtype=np.int32)} diff --git a/machine-learning/app/test_main.py b/machine-learning/app/test_main.py index 50ec188aa4ed6..362c76a50daf0 100644 --- a/machine-learning/app/test_main.py +++ 
b/machine-learning/app/test_main.py @@ -426,6 +426,46 @@ def test_openclip_tokenizer_canonicalizes_text( assert np.allclose(tokens["text"], np.array([mock_ids], dtype=np.int32), atol=0) mock_tokenizer.encode.assert_called_once_with("test search query") + def test_openclip_tokenizer_adds_flores_token_for_nllb( + self, + mocker: MockerFixture, + clip_model_cfg: dict[str, Any], + clip_tokenizer_cfg: Callable[[Path], dict[str, Any]], + ) -> None: + mocker.patch.object(OpenClipTextualEncoder, "download") + mocker.patch.object(OpenClipTextualEncoder, "model_cfg", clip_model_cfg) + mocker.patch.object(OpenClipTextualEncoder, "tokenizer_cfg", clip_tokenizer_cfg) + mocker.patch.object(InferenceModel, "_make_session", autospec=True).return_value + mock_tokenizer = mocker.patch("app.models.clip.textual.Tokenizer.from_file", autospec=True).return_value + mock_ids = [randint(0, 50000) for _ in range(77)] + mock_tokenizer.encode.return_value = SimpleNamespace(ids=mock_ids) + + clip_encoder = OpenClipTextualEncoder("nllb-clip-base-siglip__mrl", cache_dir="test_cache") + clip_encoder._load() + clip_encoder.tokenize("test search query", language="de") + + mock_tokenizer.encode.assert_called_once_with("deu_Latntest search query") + + def test_openclip_tokenizer_does_not_add_flores_token_for_non_nllb_model( + self, + mocker: MockerFixture, + clip_model_cfg: dict[str, Any], + clip_tokenizer_cfg: Callable[[Path], dict[str, Any]], + ) -> None: + mocker.patch.object(OpenClipTextualEncoder, "download") + mocker.patch.object(OpenClipTextualEncoder, "model_cfg", clip_model_cfg) + mocker.patch.object(OpenClipTextualEncoder, "tokenizer_cfg", clip_tokenizer_cfg) + mocker.patch.object(InferenceModel, "_make_session", autospec=True).return_value + mock_tokenizer = mocker.patch("app.models.clip.textual.Tokenizer.from_file", autospec=True).return_value + mock_ids = [randint(0, 50000) for _ in range(77)] + mock_tokenizer.encode.return_value = SimpleNamespace(ids=mock_ids) + + clip_encoder = OpenClipTextualEncoder("ViT-B-32__openai", cache_dir="test_cache") + clip_encoder._load() + clip_encoder.tokenize("test search query", language="de") + + mock_tokenizer.encode.assert_called_once_with("test search query") + def test_mclip_tokenizer( self, mocker: MockerFixture, From b7888cb345d61950b59e2be6d9d88b319f7140d6 Mon Sep 17 00:00:00 2001 From: mertalev <101130780+mertalev@users.noreply.github.com> Date: Thu, 17 Oct 2024 19:13:26 -0400 Subject: [PATCH 5/6] support mobile --- mobile/lib/models/search/search_filter.model.dart | 8 +++++++- mobile/lib/pages/search/search.page.dart | 1 + mobile/lib/services/search.service.dart | 1 + 3 files changed, 9 insertions(+), 1 deletion(-) diff --git a/mobile/lib/models/search/search_filter.model.dart b/mobile/lib/models/search/search_filter.model.dart index 47baf356b7f6a..53dd79fa48538 100644 --- a/mobile/lib/models/search/search_filter.model.dart +++ b/mobile/lib/models/search/search_filter.model.dart @@ -235,6 +235,7 @@ class SearchDisplayFilters { class SearchFilter { String? context; String? filename; + String? language; Set people; SearchLocationFilter location; SearchCameraFilter camera; @@ -247,6 +248,7 @@ class SearchFilter { SearchFilter({ this.context, this.filename, + this.language, required this.people, required this.location, required this.camera, @@ -258,6 +260,7 @@ class SearchFilter { SearchFilter copyWith({ String? context, String? filename, + String? language, Set? people, SearchLocationFilter? location, SearchCameraFilter? 
camera, @@ -268,6 +271,7 @@ class SearchFilter { return SearchFilter( context: context, filename: filename, + language: language, people: people ?? this.people, location: location ?? this.location, camera: camera ?? this.camera, @@ -279,7 +283,7 @@ class SearchFilter { @override String toString() { - return 'SearchFilter(context: $context, filename: $filename, people: $people, location: $location, camera: $camera, date: $date, display: $display, mediaType: $mediaType)'; + return 'SearchFilter(context: $context, filename: $filename, language: $language, people: $people, location: $location, camera: $camera, date: $date, display: $display, mediaType: $mediaType)'; } @override @@ -288,6 +292,7 @@ class SearchFilter { return other.context == context && other.filename == filename && + other.language == language && other.people == people && other.location == location && other.camera == camera && @@ -300,6 +305,7 @@ class SearchFilter { int get hashCode { return context.hashCode ^ filename.hashCode ^ + language.hashCode ^ people.hashCode ^ location.hashCode ^ camera.hashCode ^ diff --git a/mobile/lib/pages/search/search.page.dart b/mobile/lib/pages/search/search.page.dart index 60e61da4cc5d5..632c5b6e983bf 100644 --- a/mobile/lib/pages/search/search.page.dart +++ b/mobile/lib/pages/search/search.page.dart @@ -46,6 +46,7 @@ class SearchPage extends HookConsumerWidget { isFavorite: false, ), mediaType: prefilter?.mediaType ?? AssetType.other, + language: context.locale.languageCode, ), ); diff --git a/mobile/lib/services/search.service.dart b/mobile/lib/services/search.service.dart index 336fe450108d3..8312ad7c05348 100644 --- a/mobile/lib/services/search.service.dart +++ b/mobile/lib/services/search.service.dart @@ -58,6 +58,7 @@ class SearchService { response = await _apiService.searchApi.searchSmart( SmartSearchDto( query: filter.context!, + language: filter.language, country: filter.location.country, state: filter.location.state, city: filter.location.city, From 8ea15a99c893667b7235b1db7a9febc1cd83ca18 Mon Sep 17 00:00:00 2001 From: mertalev <101130780+mertalev@users.noreply.github.com> Date: Thu, 17 Oct 2024 20:31:55 -0400 Subject: [PATCH 6/6] add docs formatting clean up docs wording --- ...t-serach.webp => mobile-smart-search.webp} | Bin docs/docs/features/smart-search.md | 34 ++++++++++++++---- 2 files changed, 28 insertions(+), 6 deletions(-) rename docs/docs/features/img/{moblie-smart-serach.webp => mobile-smart-search.webp} (100%) diff --git a/docs/docs/features/img/moblie-smart-serach.webp b/docs/docs/features/img/mobile-smart-search.webp similarity index 100% rename from docs/docs/features/img/moblie-smart-serach.webp rename to docs/docs/features/img/mobile-smart-search.webp diff --git a/docs/docs/features/smart-search.md b/docs/docs/features/smart-search.md index dcbed4f3b4534..e869a872ec729 100644 --- a/docs/docs/features/smart-search.md +++ b/docs/docs/features/smart-search.md @@ -44,7 +44,7 @@ Some search examples: - + @@ -55,16 +55,36 @@ Navigating to `Administration > Settings > Machine Learning Settings > Smart Sea ### CLIP model -More powerful models can be used for more accurate search results, but are slower and can require more server resources. Check out the models [here][huggingface-clip] for more options! +The default search model is fast, but there are many other options that can provide better search results. The tradeoff of using these models is that they use more memory and are slower (both during Smart Search jobs and when searching). 
For example, the current best model for English, `ViT-H-14-378-quickgelu__dfn5b`, is roughly 72x slower and uses approximately 4.3GiB of memory compared to 801MiB with the default model `ViT-B-32__openai`. + +The first step of choosing the right model for you is to decide which languages your users will search in. + +If your users will only search in English, then the recommended [CLIP][huggingface-clip] section is the best place to look. This is a curated list of the models that generally perform the best for their size class. The models here are ordered from higher to lower quality. This means that the top models will generally rank the most relevant results higher and have a higher capacity to understand descriptive, detailed, and/or niche queries. The models are also generally ordered from larger to smaller, so consider the impact on memory usage, job processing and search speed when deciding on one. The smaller models in this list are not too different in quality and many times faster. + +[Multilingual models][huggingface-multilingual-clip] are also available so users can search in their native language. Use these models if you expect non-English searches to be common. They can be separated into two search patterns: + +- `nllb` models expect the search query to be in the language specified in the user settings +- `xlm` models understand search text regardless of the current language setting + +`nllb` models perform the best and are recommended when users primarily search in their native, non-English language. `xlm` models are more flexible and are recommended for mixed language search, where the same user might search in different languages at different times. + +A third option is if your users will search entirely in major Western European languages, such as English, Spanish, French and German. The `ViT-H-14-quickgelu__dfn5b` and `ViT-H-14-378-quickgelu__dfn5b` models perform the best for these languages and are about as flexible as `xlm` models. However, they understand very few languages compared to the explicitly multilingual `nllb` and `xlm` models, so don't use them for other languages. -[Multilingual models][huggingface-multilingual-clip] are also available so users can search in their native language. These models support over 100 languages; the `nllb` models in particular support 200. :::note Multilingual models are much slower and larger and perform slightly worse for English than English-only models. For this reason, only use them if you actually intend to search in a language besides English. - -As a special case, the `ViT-H-14-quickgelu__dfn5b` and `ViT-H-14-378-quickgelu__dfn5b` models are excellent at many European languages despite not specifically being multilingual. They're very intensive regardless, however - especially the latter. ::: -Once you've chosen a model, change this setting to the name of the model you chose. Be sure to re-run Smart Search on all assets after this change. +Once you've chosen a model, follow these steps: + +1. Copy the name of the model (e.g. `ViT-B-16-SigLIP__webli`) +2. Go to the [Smart Search settings][smart-search-settings] +3. Paste the model name into the Model Name section +4. Save the settings +5. Go to the [Job Status page][job-status-page] +6. Click "All" next to "Smart Search" to begin re-processing your assets with the new model +7. 
(Optional) Confirm that the logs for the server and machine learning service don't have relevant errors + +In rare instances, changing the model might leave bits of the old model's incompatible data in the database, causing errors when processing Smart Search jobs. If you notice errors like this in the logs, you can change the model back to the previous one and save, then back again to the new model. :::note Feel free to make a feature request if there's a model you want to use that we don't currently support. @@ -72,3 +92,5 @@ Feel free to make a feature request if there's a model you want to use that we d [huggingface-clip]: https://huggingface.co/collections/immich-app/clip-654eaefb077425890874cd07 [huggingface-multilingual-clip]: https://huggingface.co/collections/immich-app/multilingual-clip-654eb08c2382f591eeb8c2a7 +[smart-search-settings]: https://my.immich.app/admin/system-settings?isOpen=machine-learning+smart-search +[job-status-page]: https://my.immich.app/admin/jobs-status
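For reviewers, here is a minimal standalone sketch (not part of the patch series) of the language handling that PATCH 1 adds to `OpenClipTextualEncoder.tokenize`: when the loaded model is NLLB-based, the query text is prefixed with a FLORES-200 language code looked up from the user's Weblate locale, falling back to `eng_Latn`. The `prefix_for_nllb` helper name is hypothetical and the mapping below is only a small excerpt; the full table is `WEBLATE_TO_FLORES200` in `machine-learning/app/models/constants.py`.

```python
# Illustrative sketch only -- not part of the patches above. The helper name
# prefix_for_nllb is hypothetical; the real logic lives inline in
# OpenClipTextualEncoder.tokenize (PATCH 1), and this mapping is a small
# excerpt of WEBLATE_TO_FLORES200 from machine-learning/app/models/constants.py.

WEBLATE_TO_FLORES200 = {
    "de": "deu_Latn",
    "en": "eng_Latn",
    "pt_BR": "por_Latn",
    "zh-CN": "zho_Hans",
}


def prefix_for_nllb(text: str, language: str | None, is_nllb: bool) -> str:
    """Prepend the FLORES-200 code an NLLB tokenizer expects, defaulting to English."""
    if not is_nllb:
        return text  # non-NLLB CLIP models are tokenized as-is
    flores_code = WEBLATE_TO_FLORES200.get(language) if language else None
    return f"{flores_code or 'eng_Latn'}{text}"


# Mirrors the expectations in the PATCH 4 unit tests:
assert prefix_for_nllb("test search query", "de", is_nllb=True) == "deu_Latntest search query"
assert prefix_for_nllb("test search query", "de", is_nllb=False) == "test search query"
assert prefix_for_nllb("test search query", None, is_nllb=True) == "eng_Latntest search query"
```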