[ML] SPLADE embedding support #131679


Open · wants to merge 6 commits into base: main

Conversation

daixque (Contributor)

@daixque daixque commented Jul 22, 2025

Overview

Elasticsearch supports sparse embeddings from non-ELSER models, which was introduced by #116935.

A notable example of a sparse vector model is SPLADE, a reference model for ELSER. However, the output format of SPLADE is not the same as ELSER's, so we need to implement a post-processing step for it.

Interface change

This PR introduces a new expansion_type parameter to the trained model resource.

GET /_ml/trained_models/hotchpotch__japanese-splade-v2?pretty
{
  "count" : 1,
  "trained_model_configs" : [
    {
      "model_id" : "hotchpotch__japanese-splade-v2",
      "model_type" : "pytorch",
      "created_by" : "api_user",
      "version" : "12.0.0",
      "create_time" : 1753171035706,
      "model_size_bytes" : 0,
      "estimated_operations" : 0,
      "license_level" : "platinum",
      "description" : "Model hotchpotch/japanese-splade-v2 for task type 'text_expansion'",
      "tags" : [ ],
      "metadata" : {
        "per_allocation_memory_bytes" : 392242176,
        "per_deployment_memory_bytes" : 545637376
      },
      "input" : {
        "field_names" : [
          "text_field"
        ]
      },
      "inference_config" : {
        "text_expansion" : {
          "vocabulary" : {
            "index" : ".ml-inference-native-000002"
          },
          "tokenization" : {
            "bert_ja" : {
              "do_lower_case" : false,
              "with_special_tokens" : true,
              "max_sequence_length" : 512,
              "truncate" : "first",
              "span" : -1
            }
          },
          "expansion_type" : "splade"
        }
      },
      "location" : {
        "index" : {
          "name" : ".ml-inference-native-000002"
        }
      }
    }
  ]
}

expansion_type can be one of elser or splade. If it is not specified, it defaults to elser.

To use the SPLADE model, eland needs to be updated to support the expansion_type parameter.
elastic/eland#802

Logic for SPLADE

The SPLADE model outputs an embedding with the shape [1, input_token_size, vocab_size]. The second dimension differs from ELSER's, which is [1, chunk_size, vocab_size].
For SPLADE, we need to apply the saturation function log(1 + relu(x)) to the output, and then apply max pooling over the second dimension.
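The saturation-plus-max-pooling step described above can be sketched as follows. This is a minimal illustration, not the actual Elasticsearch implementation; the class and method names are hypothetical:

```java
// Minimal sketch of the SPLADE post-processing described above:
// saturate each logit with log(1 + relu(x)), then max-pool over the
// token dimension. Names are illustrative only.
public class SpladeMaxPooling {

    // tokenVocabScores has shape [input_token_size][vocab_size];
    // the result has shape [vocab_size].
    static double[] spladeMaxPooling(double[][] tokenVocabScores) {
        int vocabSize = tokenVocabScores[0].length;
        // zero-initialized: log1p(relu(x)) is always >= 0
        double[] pooled = new double[vocabSize];
        for (double[] tokenScores : tokenVocabScores) {
            for (int v = 0; v < vocabSize; v++) {
                // saturation: log(1 + relu(x))
                double saturated = Math.log1p(Math.max(0.0, tokenScores[v]));
                if (saturated > pooled[v]) {
                    pooled[v] = saturated; // max pooling over tokens
                }
            }
        }
        return pooled;
    }
}
```

For example, with two tokens and a vocabulary of two entries, a negative logit is clamped to zero by the relu before the log, so it never contributes to the pooled weight.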

To be considered

  • Currently, we only implement max pooling for SPLADE. Should we also implement sum and/or other pooling strategies?
  • Top-k expansion is not supported. Some models produce a large amount of output from a small input. Should we implement it? If so, we need to add another option, and it cannot be determined by eland automatically.
  • I observed an OOM when I sent an infer request with a fairly long text to the japanese-splade-v2 model.
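If top-k expansion were added, one possible shape would be to keep only the k largest pooled weights and drop the rest. The sketch below is purely illustrative and not part of this PR:

```java
import java.util.Arrays;

// Illustrative sketch only: keep the k largest weights of a pooled
// SPLADE vector and zero out the rest. Not part of this PR.
public class TopKExpansion {

    static double[] topK(double[] pooled, int k) {
        if (k >= pooled.length) {
            return pooled.clone();
        }
        double[] sorted = pooled.clone();
        Arrays.sort(sorted); // ascending order
        double threshold = sorted[sorted.length - k]; // k-th largest value
        double[] result = new double[pooled.length];
        int kept = 0;
        for (int i = 0; i < pooled.length; i++) {
            // ties at the threshold could keep more than k entries;
            // cap at k for simplicity
            if (pooled[i] >= threshold && kept < k) {
                result[i] = pooled[i];
                kept++;
            }
        }
        return result;
    }
}
```

A top-k cap like this would bound the size of the expansion regardless of the model's output, which is the concern the bullet above raises about models that produce a large output from a small input.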

Reference

  • I referred to this for the spladeMaxPooling implementation.

@daixque daixque requested a review from davidkyle July 22, 2025 08:44
@elasticsearchmachine elasticsearchmachine added needs:triage Requires assignment of a team area label v9.2.0 external-contributor Pull request authored by a developer outside the Elasticsearch team labels Jul 22, 2025
@daixque daixque added :ml Machine learning Team:ML Meta label for the ML team labels Jul 22, 2025
@daixque daixque changed the title SPLADE embedding support [ML] SPLADE embedding support Jul 22, 2025
@elasticsearchmachine elasticsearchmachine removed the needs:triage Requires assignment of a team area label label Jul 22, 2025
@elasticsearchmachine (Collaborator)

Pinging @elastic/ml-core (Team:ML)

@@ -121,11 +135,12 @@ public InferenceConfig apply(InferenceConfigUpdate update) {
         return new TextExpansionConfig(
             vocabularyConfig,
             configUpdate.tokenizationUpdate == null ? tokenization : configUpdate.tokenizationUpdate.apply(tokenization),
-            Optional.ofNullable(configUpdate.getResultsField()).orElse(resultsField)
+            Optional.ofNullable(configUpdate.getResultsField()).orElse(resultsField),
+            Optional.ofNullable(configUpdate.getExpansionType()).orElse(expansionType)
Member


Changing the expansionType changes how the output is processed and may not be compatible with different model types. Switching this value for the ELSER model would break the processing.

Fixing expansionType when the config is created and not allowing it to be overridden at inference time is good enough. If the wrong expansionType value is set, the user will have to recreate the model.

@@ -24,22 +24,24 @@

import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.NlpConfig.RESULTS_FIELD;
import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.NlpConfig.TOKENIZATION;
import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.TextExpansionConfig.EXPANSION_TYPE;
Member


Same comment as above regarding updating expansion type. I think the changes in this file can be reverted.

}

public TextExpansionConfig(StreamInput in) throws IOException {
vocabularyConfig = new VocabularyConfig(in);
tokenization = in.readNamedWriteable(Tokenization.class);
resultsField = in.readOptionalString();
expansionType = in.readOptionalString();
Member


In a mixed-version cluster, older nodes will not expect this new field to be serialised. We protect against this using TransportVersions. Please add a new version to TransportVersions.java at

public static final TransportVersion ESQL_TOPN_TIMINGS = def(9_128_0_00);

Suggested change
expansionType = in.readOptionalString();
if (in.getTransportVersion().onOrAfter(TransportVersions.ML_EXPANSION_TYPE)) {
expansionType = in.readOptionalString();
}

}

@Override
public void writeTo(StreamOutput out) throws IOException {
vocabularyConfig.writeTo(out);
out.writeNamedWriteable(tokenization);
out.writeOptionalString(resultsField);
out.writeOptionalString(expansionType);
Member


Suggested change
out.writeOptionalString(expansionType);
if (out.getTransportVersion().onOrAfter(TransportVersions.ML_EXPANSION_TYPE)) {
out.writeOptionalString(expansionType);
}

Labels
>enhancement external-contributor Pull request authored by a developer outside the Elasticsearch team :ml Machine learning Team:ML Meta label for the ML team v9.2.0

3 participants