feat(ml): better multilingual search with nllb models #13567

Open · wants to merge 6 commits into `main`
34 changes: 28 additions & 6 deletions docs/docs/features/smart-search.md
@@ -44,7 +44,7 @@ Some search examples:
</TabItem>
<TabItem value="Mobile" label="Mobile">

<img src={require('./img/moblie-smart-serach.webp').default} width="30%" title='Smart search on mobile' />
<img src={require('./img/mobile-smart-search.webp').default} width="30%" title='Smart search on mobile' />

</TabItem>
</Tabs>
@@ -55,20 +55,42 @@ Navigating to `Administration > Settings > Machine Learning Settings > Smart Search`

### CLIP model

More powerful models can be used for more accurate search results, but are slower and can require more server resources. Check out the models [here][huggingface-clip] for more options!
The default search model is fast, but there are many other options that can provide better search results. The tradeoff is that these models use more memory and are slower, both during Smart Search jobs and when searching. For example, the current best model for English, `ViT-H-14-378-quickgelu__dfn5b`, is roughly 72x slower and uses approximately 4.3GiB of memory, compared to the 801MiB used by the default model, `ViT-B-32__openai`.

The first step in choosing the right model is to decide which languages your users will search in.

If your users will only search in English, then the recommended [CLIP][huggingface-clip] section is the best place to look. This is a curated list of the models that generally perform best for their size class. The models here are ordered from higher to lower quality: the top models will generally rank the most relevant results higher and have a greater capacity to understand descriptive, detailed, and/or niche queries. The models are also generally ordered from larger to smaller, so consider the impact on memory usage, job processing, and search speed when deciding on one. The smaller models in this list are not far behind in quality and are many times faster.

[Multilingual models][huggingface-multilingual-clip] are also available so users can search in their native language. Use these models if you expect non-English searches to be common. They can be separated into two search patterns:

- `nllb` models expect the search query to be in the language specified in the user settings
- `xlm` models understand search text regardless of the current language setting

`nllb` models perform the best and are recommended when users primarily search in their native, non-English language. `xlm` models are more flexible and are recommended for mixed-language search, where the same user might search in different languages at different times.
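To make the distinction concrete: for `nllb` models, the machine learning service prepends a FLORES-200 language token (derived from the user's locale setting) to the query before tokenizing, while `xlm` models tokenize the raw query and infer the language from the text itself. A minimal sketch of that behavior, mirroring the `tokenize` logic in this PR (`prepare_query` is an illustrative helper, not part of the codebase):

```python
# Excerpt of the locale-to-FLORES-200 mapping added in this PR
WEBLATE_TO_FLORES200 = {"de": "deu_Latn", "ja": "jpn_Hira"}

def prepare_query(text: str, language: str | None, is_nllb: bool) -> str:
    """Prepend the FLORES-200 token for nllb models; pass text through otherwise."""
    if not is_nllb:
        return text  # xlm models infer the language from the text itself
    code = WEBLATE_TO_FLORES200.get(language) if language else None
    return f"{code or 'eng_Latn'}{text}"

assert prepare_query("strand am abend", "de", is_nllb=True) == "deu_Latnstrand am abend"
assert prepare_query("strand am abend", "de", is_nllb=False) == "strand am abend"
```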

A third option applies if your users will search entirely in major Western European languages, such as English, Spanish, French, and German. The `ViT-H-14-quickgelu__dfn5b` and `ViT-H-14-378-quickgelu__dfn5b` models perform the best for these languages and are similarly flexible to `xlm` models. However, they understand very few languages compared to the explicitly multilingual `nllb` and `xlm` models, so don't use them for other languages.

[Multilingual models][huggingface-multilingual-clip] are also available so users can search in their native language. These models support over 100 languages; the `nllb` models in particular support 200.
:::note
Multilingual models are much slower and larger and perform slightly worse for English than English-only models. For this reason, only use them if you actually intend to search in a language besides English.

As a special case, the `ViT-H-14-quickgelu__dfn5b` and `ViT-H-14-378-quickgelu__dfn5b` models are excellent at many European languages despite not specifically being multilingual. They're very resource-intensive regardless, however, especially the latter.
:::

Once you've chosen a model, change this setting to the name of the model you chose. Be sure to re-run Smart Search on all assets after this change.
Once you've chosen a model, follow these steps:

1. Copy the name of the model (e.g. `ViT-B-16-SigLIP__webli`)
2. Go to the [Smart Search settings][smart-search-settings]
3. Paste the model name into the Model Name section
4. Save the settings
5. Go to the [Job Status page][job-status-page]
6. Click "All" next to "Smart Search" to begin re-processing your assets with the new model
7. (Optional) Confirm that the logs for the server and machine learning service don't have relevant errors
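If you prefer to script steps 5 and 6, the job can also be started through the admin API. A hedged sketch, assuming a `PUT /api/jobs/{name}` command endpoint and the `smartSearch` job name (verify both against your instance's `/api/docs` before relying on this):

```python
import requests

IMMICH_URL = "https://immich.example.com"  # hypothetical server URL
API_KEY = "..."  # admin API key

# Assumption: force=True matches the "All" button, i.e. reprocess every
# asset rather than only those missing embeddings.
resp = requests.put(
    f"{IMMICH_URL}/api/jobs/smartSearch",
    headers={"x-api-key": API_KEY},
    json={"command": "start", "force": True},
)
resp.raise_for_status()
```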

In rare instances, changing the model can leave incompatible data from the old model in the database, causing errors when processing Smart Search jobs. If you notice such errors in the logs, change the model back to the previous one and save, then switch to the new model again.

:::note
Feel free to make a feature request if there's a model you want to use that we don't currently support.
:::

[huggingface-clip]: https://huggingface.co/collections/immich-app/clip-654eaefb077425890874cd07
[huggingface-multilingual-clip]: https://huggingface.co/collections/immich-app/multilingual-clip-654eb08c2382f591eeb8c2a7
[smart-search-settings]: https://my.immich.app/admin/system-settings?isOpen=machine-learning+smart-search
[job-status-page]: https://my.immich.app/admin/jobs-status
16 changes: 11 additions & 5 deletions machine-learning/app/models/clip/textual.py
@@ -10,6 +10,7 @@

from app.config import log
from app.models.base import InferenceModel
from app.models.constants import WEBLATE_TO_FLORES200
from app.models.transforms import clean_text
from app.schemas import ModelSession, ModelTask, ModelType

@@ -18,8 +19,9 @@ class BaseCLIPTextualEncoder(InferenceModel):
depends = []
identity = (ModelType.TEXTUAL, ModelTask.SEARCH)

def _predict(self, inputs: str, **kwargs: Any) -> NDArray[np.float32]:
res: NDArray[np.float32] = self.session.run(None, self.tokenize(inputs))[0][0]
def _predict(self, inputs: str, language: str | None = None, **kwargs: Any) -> NDArray[np.float32]:
tokens = self.tokenize(inputs, language=language)
res: NDArray[np.float32] = self.session.run(None, tokens)[0][0]
return res

def _load(self) -> ModelSession:
@@ -28,6 +30,7 @@ def _load(self) -> ModelSession:
self.tokenizer = self._load_tokenizer()
tokenizer_kwargs: dict[str, Any] | None = self.text_cfg.get("tokenizer_kwargs")
self.canonicalize = tokenizer_kwargs is not None and tokenizer_kwargs.get("clean") == "canonicalize"
self.is_nllb = self.model_name.startswith("nllb")
log.debug(f"Loaded tokenizer for CLIP model '{self.model_name}'")

return session
@@ -37,7 +40,7 @@ def _load_tokenizer(self) -> Tokenizer:
pass

@abstractmethod
def tokenize(self, text: str) -> dict[str, NDArray[np.int32]]:
def tokenize(self, text: str, language: str | None = None) -> dict[str, NDArray[np.int32]]:
pass

@property
@@ -92,14 +95,17 @@ def _load_tokenizer(self) -> Tokenizer:

return tokenizer

def tokenize(self, text: str) -> dict[str, NDArray[np.int32]]:
def tokenize(self, text: str, language: str | None = None) -> dict[str, NDArray[np.int32]]:
text = clean_text(text, canonicalize=self.canonicalize)
if self.is_nllb:
flores_code = code if language and (code := WEBLATE_TO_FLORES200.get(language)) else "eng_Latn"
text = f"{flores_code}{text}"
tokens: Encoding = self.tokenizer.encode(text)
return {"text": np.array([tokens.ids], dtype=np.int32)}


class MClipTextualEncoder(OpenClipTextualEncoder):
def tokenize(self, text: str) -> dict[str, NDArray[np.int32]]:
def tokenize(self, text: str, language: str | None = None) -> dict[str, NDArray[np.int32]]:
text = clean_text(text, canonicalize=self.canonicalize)
tokens: Encoding = self.tokenizer.encode(text)
return {
59 changes: 59 additions & 0 deletions machine-learning/app/models/constants.py
@@ -66,6 +66,65 @@
SUPPORTED_PROVIDERS = ["CUDAExecutionProvider", "OpenVINOExecutionProvider", "CPUExecutionProvider"]


WEBLATE_TO_FLORES200 = {
**Contributor:** I wonder where this mapping should live. To some extent the server knows which model it is using when it sends the request, so you could argue this belongs there. What do you think?

**Contributor (Author):** I'd say it's better for it to live here; it makes things easier if you query ML directly or use the models in another context, like in notebooks.

"af": "afr_Latn",
"ar": "arb_Arab",
"az": "azj_Latn",
"be": "bel_Cyrl",
"bg": "bul_Cyrl",
"ca": "cat_Latn",
"cs": "ces_Latn",
"da": "dan_Latn",
"de": "deu_Latn",
"el": "ell_Grek",
"en": "eng_Latn",
"es": "spa_Latn",
"et": "est_Latn",
"fa": "pes_Arab",
"fi": "fin_Latn",
"fr": "fra_Latn",
"he": "heb_Hebr",
"hi": "hin_Deva",
"hr": "hrv_Latn",
"hu": "hun_Latn",
"hy": "hye_Armn",
"id": "ind_Latn",
"it": "ita_Latn",
"ja": "jpn_Hira",
"kmr": "kmr_Latn",
"ko": "kor_Hang",
"lb": "ltz_Latn",
"lt": "lit_Latn",
"lv": "lav_Latn",
"mfa": "zsm_Latn",
"mk": "mkd_Cyrl",
"mn": "khk_Cyrl",
"mr": "mar_Deva",
"ms": "zsm_Latn",
"nb_NO": "nob_Latn",
"nl": "nld_Latn",
"pl": "pol_Latn",
"pt_BR": "por_Latn",
"pt": "por_Latn",
"ro": "ron_Latn",
"ru": "rus_Cyrl",
"sk": "slk_Latn",
"sl": "slv_Latn",
"sr_Cyrl": "srp_Cyrl",
"sv": "swe_Latn",
"ta": "tam_Taml",
"te": "tel_Telu",
"th": "tha_Thai",
"tr": "tur_Latn",
"uk": "ukr_Cyrl",
"vi": "vie_Latn",
"zh-CN": "zho_Hans",
"zh-TW": "zho_Hant",
"zh_Hant": "zho_Hant",
"zh_SIMPLIFIED": "zho_Hans",
}
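The keys are Weblate locale codes, including regional variants like `pt_BR` and script variants like `zh-TW`; anything missing or unmapped falls back to English in `OpenClipTextualEncoder.tokenize`. A quick illustration of the lookup semantics (`to_flores` is a hypothetical helper for clarity):

```python
def to_flores(language: str | None) -> str:
    # Same fallback as tokenize(): a missing or unmapped locale resolves
    # to eng_Latn.
    return (language and WEBLATE_TO_FLORES200.get(language)) or "eng_Latn"

assert to_flores("pt_BR") == "por_Latn"  # regional variant mapped explicitly
assert to_flores("zh-TW") == "zho_Hant"  # traditional-script Chinese
assert to_flores("xx") == "eng_Latn"     # unmapped locale falls back to English
assert to_flores(None) == "eng_Latn"
```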


def get_model_source(model_name: str) -> ModelSource | None:
cleaned_name = clean_name(model_name)

40 changes: 40 additions & 0 deletions machine-learning/app/test_main.py
@@ -426,6 +426,46 @@ def test_openclip_tokenizer_canonicalizes_text(
assert np.allclose(tokens["text"], np.array([mock_ids], dtype=np.int32), atol=0)
mock_tokenizer.encode.assert_called_once_with("test search query")

def test_openclip_tokenizer_adds_flores_token_for_nllb(
self,
mocker: MockerFixture,
clip_model_cfg: dict[str, Any],
clip_tokenizer_cfg: Callable[[Path], dict[str, Any]],
) -> None:
mocker.patch.object(OpenClipTextualEncoder, "download")
mocker.patch.object(OpenClipTextualEncoder, "model_cfg", clip_model_cfg)
mocker.patch.object(OpenClipTextualEncoder, "tokenizer_cfg", clip_tokenizer_cfg)
mocker.patch.object(InferenceModel, "_make_session", autospec=True).return_value
mock_tokenizer = mocker.patch("app.models.clip.textual.Tokenizer.from_file", autospec=True).return_value
mock_ids = [randint(0, 50000) for _ in range(77)]
mock_tokenizer.encode.return_value = SimpleNamespace(ids=mock_ids)

clip_encoder = OpenClipTextualEncoder("nllb-clip-base-siglip__mrl", cache_dir="test_cache")
clip_encoder._load()
clip_encoder.tokenize("test search query", language="de")

mock_tokenizer.encode.assert_called_once_with("deu_Latntest search query")

def test_openclip_tokenizer_does_not_add_flores_token_for_non_nllb_model(
self,
mocker: MockerFixture,
clip_model_cfg: dict[str, Any],
clip_tokenizer_cfg: Callable[[Path], dict[str, Any]],
) -> None:
mocker.patch.object(OpenClipTextualEncoder, "download")
mocker.patch.object(OpenClipTextualEncoder, "model_cfg", clip_model_cfg)
mocker.patch.object(OpenClipTextualEncoder, "tokenizer_cfg", clip_tokenizer_cfg)
mocker.patch.object(InferenceModel, "_make_session", autospec=True).return_value
mock_tokenizer = mocker.patch("app.models.clip.textual.Tokenizer.from_file", autospec=True).return_value
mock_ids = [randint(0, 50000) for _ in range(77)]
mock_tokenizer.encode.return_value = SimpleNamespace(ids=mock_ids)

clip_encoder = OpenClipTextualEncoder("ViT-B-32__openai", cache_dir="test_cache")
clip_encoder._load()
clip_encoder.tokenize("test search query", language="de")

mock_tokenizer.encode.assert_called_once_with("test search query")

def test_mclip_tokenizer(
self,
mocker: MockerFixture,
8 changes: 7 additions & 1 deletion mobile/lib/models/search/search_filter.model.dart
@@ -235,6 +235,7 @@ class SearchDisplayFilters {
class SearchFilter {
String? context;
String? filename;
String? language;
Set<Person> people;
SearchLocationFilter location;
SearchCameraFilter camera;
@@ -247,6 +248,7 @@ SearchFilter({
SearchFilter({
this.context,
this.filename,
this.language,
required this.people,
required this.location,
required this.camera,
@@ -258,6 +260,7 @@ SearchFilter copyWith({
SearchFilter copyWith({
String? context,
String? filename,
String? language,
Set<Person>? people,
SearchLocationFilter? location,
SearchCameraFilter? camera,
@@ -268,6 +271,7 @@ return SearchFilter(
return SearchFilter(
context: context,
filename: filename,
language: language,
people: people ?? this.people,
location: location ?? this.location,
camera: camera ?? this.camera,
@@ -279,7 +283,7 @@

@override
String toString() {
return 'SearchFilter(context: $context, filename: $filename, people: $people, location: $location, camera: $camera, date: $date, display: $display, mediaType: $mediaType)';
return 'SearchFilter(context: $context, filename: $filename, language: $language, people: $people, location: $location, camera: $camera, date: $date, display: $display, mediaType: $mediaType)';
}

@override
@@ -288,6 +292,7 @@

return other.context == context &&
other.filename == filename &&
other.language == language &&
other.people == people &&
other.location == location &&
other.camera == camera &&
@@ -300,6 +305,7 @@
int get hashCode {
return context.hashCode ^
filename.hashCode ^
language.hashCode ^
people.hashCode ^
location.hashCode ^
camera.hashCode ^
1 change: 1 addition & 0 deletions mobile/lib/pages/search/search.page.dart
@@ -46,6 +46,7 @@ class SearchPage extends HookConsumerWidget {
isFavorite: false,
),
mediaType: prefilter?.mediaType ?? AssetType.other,
language: context.locale.languageCode,
),
);

1 change: 1 addition & 0 deletions mobile/lib/services/search.service.dart
@@ -58,6 +58,7 @@ class SearchService {
response = await _apiService.searchApi.searchSmart(
SmartSearchDto(
query: filter.context!,
language: filter.language,
country: filter.location.country,
state: filter.location.state,
city: filter.location.city,
19 changes: 18 additions & 1 deletion mobile/openapi/lib/model/smart_search_dto.dart

(Generated file; diff not rendered by default.)

3 changes: 3 additions & 0 deletions open-api/immich-openapi-specs.json
@@ -11319,6 +11319,9 @@
"isVisible": {
"type": "boolean"
},
"language": {
"type": "string"
},
"lensModel": {
"nullable": true,
"type": "string"
1 change: 1 addition & 0 deletions open-api/typescript-sdk/src/fetch-client.ts
@@ -881,6 +881,7 @@ export type SmartSearchDto = {
isNotInAlbum?: boolean;
isOffline?: boolean;
isVisible?: boolean;
language?: string;
lensModel?: string | null;
libraryId?: string | null;
make?: string;
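For API consumers, the new optional `language` field simply rides along with the existing smart-search body. A hedged Python sketch (the `/api/search/smart` path and `x-api-key` header are assumed from Immich's public API conventions; check your instance's `/api/docs`):

```python
import requests

IMMICH_URL = "https://immich.example.com"  # hypothetical server URL
API_KEY = "..."  # user API key

resp = requests.post(
    f"{IMMICH_URL}/api/search/smart",
    headers={"x-api-key": API_KEY},
    json={
        "query": "strand am abend",
        "language": "de",  # Weblate locale; mapped to deu_Latn for nllb models
    },
)
resp.raise_for_status()
# Assumption: the response follows SearchResponseDto, with matches under assets.items
print(len(resp.json()["assets"]["items"]), "matching assets")
```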
5 changes: 5 additions & 0 deletions server/src/dtos/search.dto.ts
@@ -177,6 +177,11 @@ export class SmartSearchDto extends BaseSearchDto {
@IsNotEmpty()
query!: string;

@IsString()
@IsNotEmpty()
@Optional()
language?: string;

@IsInt()
@Min(1)
@Type(() => Number)