diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportInferenceAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportInferenceAction.java index 7d24b7766baa3..c28bd485f0e48 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportInferenceAction.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportInferenceAction.java @@ -74,6 +74,8 @@ protected void doInference( InferenceService service, ActionListener listener ) { + // var rerankChunker = new RerankRequestChunker(request.getInput()); + service.infer( model, request.getQuery(), diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/RerankRequestChunker.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/RerankRequestChunker.java new file mode 100644 index 0000000000000..322294034b8b9 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/RerankRequestChunker.java @@ -0,0 +1,99 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.chunking; + +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.inference.ChunkingSettings; +import org.elasticsearch.xpack.core.inference.action.InferenceAction; +import org.elasticsearch.xpack.core.inference.results.RankedDocsResults; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +public class RerankRequestChunker { + + private final ChunkingSettings chunkingSettings; + private final List inputs; + private final Map rerankChunks; + + public RerankRequestChunker(List inputs) { + // TODO: Make chunking settings dependent on the model being used. + // There may be a way to do this dynamically knowing the max token size for the model/service and query size + // instead of hardcoding it ona model/service basis. + this.chunkingSettings = new WordBoundaryChunkingSettings(100, 10); + this.inputs = inputs; + this.rerankChunks = chunk(inputs, chunkingSettings); + } + + private Map chunk(List inputs, ChunkingSettings chunkingSettings) { + var chunker = ChunkerBuilder.fromChunkingStrategy(chunkingSettings.getChunkingStrategy()); + var chunks = new HashMap(); + var chunkIndex = 0; + for (int i = 0; i < inputs.size(); i++) { + var chunksForInput = chunker.chunk(inputs.get(i), chunkingSettings); + for (var chunk : chunksForInput) { + chunks.put(chunkIndex, new RerankChunks(i, inputs.get(i).substring(chunk.start(), chunk.end()))); + chunkIndex++; + } + } + return chunks; + } + + public List getChunkedInputs() { + List chunkedInputs = new ArrayList<>(); + for (RerankChunks chunk : rerankChunks.values()) { + chunkedInputs.add(chunk.chunkString()); + } + // TODO: Score the inputs here and only return the top N chunks for each document + return chunkedInputs; + } + + public ActionListener parseChunkedRerankResultsListener(ActionListener listener) { + return ActionListener.wrap(results -> { + if (results.getResults() instanceof RankedDocsResults rankedDocsResults) { + listener.onResponse(new InferenceAction.Response(parseRankedDocResultsForChunks(rankedDocsResults))); + // TODO: Figure out if the above correctly creates the response or if it loses any info + + } else { + listener.onFailure(new IllegalArgumentException("Expected RankedDocsResults but got: " + results.getClass())); + } + + }, listener::onFailure); + } + + private RankedDocsResults parseRankedDocResultsForChunks(RankedDocsResults rankedDocsResults) { + Map bestRankedDocResultPerDoc = new HashMap<>(); + for (var rankedDoc : rankedDocsResults.getRankedDocs()) { + int chunkIndex = rankedDoc.index(); + int docIndex = rerankChunks.get(chunkIndex).docIndex(); + if (bestRankedDocResultPerDoc.containsKey(docIndex)) { + RankedDocsResults.RankedDoc existingDoc = bestRankedDocResultPerDoc.get(docIndex); + if (rankedDoc.relevanceScore() > existingDoc.relevanceScore()) { + bestRankedDocResultPerDoc.put( + docIndex, + new RankedDocsResults.RankedDoc(docIndex, rankedDoc.relevanceScore(), inputs.get(docIndex)) + ); + } + } else { + bestRankedDocResultPerDoc.put( + docIndex, + new RankedDocsResults.RankedDoc(docIndex, rankedDoc.relevanceScore(), inputs.get(docIndex)) + ); + } + } + var bestRankedDocResultPerDocList = new ArrayList<>(bestRankedDocResultPerDoc.values()); + bestRankedDocResultPerDocList.sort( + (RankedDocsResults.RankedDoc d1, RankedDocsResults.RankedDoc d2) -> Float.compare(d2.relevanceScore(), d1.relevanceScore()) + ); + return new RankedDocsResults(bestRankedDocResultPerDocList); + } + + public record RerankChunks(int docIndex, String chunkString) {}; +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankFeaturePhaseRankCoordinatorContext.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankFeaturePhaseRankCoordinatorContext.java index 27221dc1f5caf..c231f48ceadac 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankFeaturePhaseRankCoordinatorContext.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankFeaturePhaseRankCoordinatorContext.java @@ -17,6 +17,7 @@ import org.elasticsearch.xpack.core.inference.action.GetInferenceModelAction; import org.elasticsearch.xpack.core.inference.action.InferenceAction; import org.elasticsearch.xpack.core.inference.results.RankedDocsResults; +import org.elasticsearch.xpack.inference.chunking.RerankRequestChunker; import org.elasticsearch.xpack.inference.services.cohere.rerank.CohereRerankTaskSettings; import org.elasticsearch.xpack.inference.services.googlevertexai.rerank.GoogleVertexAiRerankTaskSettings; import org.elasticsearch.xpack.inference.services.huggingface.rerank.HuggingFaceRerankTaskSettings; @@ -119,9 +120,16 @@ protected void computeScores(RankFeatureDoc[] featureDocs, ActionListener featureData = Arrays.stream(featureDocs).map(x -> x.featureData).toList(); - InferenceAction.Request inferenceRequest = generateRequest(featureData); + RerankRequestChunker chunker = new RerankRequestChunker(featureData); + InferenceAction.Request inferenceRequest = generateRequest(chunker.getChunkedInputs()); try { - executeAsyncWithOrigin(client, INFERENCE_ORIGIN, InferenceAction.INSTANCE, inferenceRequest, inferenceListener); + executeAsyncWithOrigin( + client, + INFERENCE_ORIGIN, + InferenceAction.INSTANCE, + inferenceRequest, + chunker.parseChunkedRerankResultsListener(inferenceListener) + ); } finally { inferenceRequest.decRef(); } @@ -156,6 +164,7 @@ protected RankFeatureDoc[] preprocess(RankFeatureDoc[] originalDocs, boolean rer } protected InferenceAction.Request generateRequest(List docFeatures) { + // TODO: Try running the RerankRequestChunker here. return new InferenceAction.Request( TaskType.RERANK, inferenceId,