diff --git a/elasticsearch/esql/esql.py b/elasticsearch/esql/esql.py index 05f4e3e3e..3fa176e2e 100644 --- a/elasticsearch/esql/esql.py +++ b/elasticsearch/esql/esql.py @@ -143,7 +143,7 @@ def _is_forked(self) -> bool: return False def change_point(self, value: FieldType) -> "ChangePoint": - """`CHANGE_POINT` detects spikes, dips, and change points in a metric. + """``CHANGE_POINT`` detects spikes, dips, and change points in a metric. :param value: The column with the metric in which you want to detect a change point. @@ -163,17 +163,18 @@ def change_point(self, value: FieldType) -> "ChangePoint": def completion( self, *prompt: ExpressionType, **named_prompt: ExpressionType ) -> "Completion": - """The `COMPLETION` command allows you to send prompts and context to a Large + """The ``COMPLETION`` command allows you to send prompts and context to a Large Language Model (LLM) directly within your ES|QL queries, to perform text generation tasks. :param prompt: The input text or expression used to prompt the LLM. This can be a string literal or a reference to a column containing text. :param named_prompt: The input text or expresion, given as a keyword argument. - The argument name is used for the column name. If not - specified, the results will be stored in a column named - `completion`. If the specified column already exists, it - will be overwritten with the new results. + The argument name is used for the column name. If the + prompt is given as a positional argument, the results will + be stored in a column named ``completion``. If the + specified column already exists, it will be overwritten + with the new results. Examples:: @@ -439,6 +440,54 @@ def rename(self, **columns: FieldType) -> "Rename": """ return Rename(self, **columns) + def rerank(self, *query: ExpressionType, **named_query: ExpressionType) -> "Rerank": + """The ``RERANK`` command uses an inference model to compute a new relevance score + for an initial set of documents, directly within your ES|QL queries. + + :param query: The query text used to rerank the documents. This is typically the + same query used in the initial search. + :param named_query: The query text used to rerank the documents, given as a + keyword argument. The argument name is used for the column + name. If the query is given as a positional argument, the + results will be stored in a column named `_score`. If the + specified column already exists, it will be overwritten with + the new results. + + Examples:: + + query1 = ( + ESQL.from_("books").metadata("_score") + .where('MATCH(description, "hobbit")') + .sort("_score DESC") + .limit(100) + .rerank("hobbit").on("description").with_(inference_id="test_reranker") + .limit(3) + .keep("title", "_score") + ) + query2 = ( + ESQL.from_("books").metadata("_score") + .where('MATCH(description, "hobbit") OR MATCH(author, "Tolkien")') + .sort("_score DESC") + .limit(100) + .rerank(rerank_score="hobbit").on("description", "author").with_(inference_id="test_reranker") + .sort("rerank_score") + .limit(3) + .keep("title", "_score", "rerank_score") + ) + query3 = ( + ESQL.from_("books").metadata("_score") + .where('MATCH(description, "hobbit") OR MATCH(author, "Tolkien")') + .sort("_score DESC") + .limit(100) + .rerank(rerank_score="hobbit").on("description", "author").with_(inference_id="test_reranker") + .eval(original_score="_score", _score="rerank_score + original_score") + .sort("_score") + .limit(3) + .keep("title", "original_score", "rerank_score", "_score") + ) + """ + return Rerank(self, *query, **named_query) + def sample(self, probability: float) -> "Sample": """The ``SAMPLE`` command samples a fraction of the table rows. @@ -1046,6 +1095,65 @@ def _render_internal(self) -> str: return f'RENAME {", ".join([f"{self._format_id(old_name)} AS {self._format_id(new_name)}" for old_name, new_name in self._columns.items()])}' +class Rerank(ESQLBase): + """Implementation of the ``RERANK`` processing command. + + This class inherits from :class:`ESQLBase `, + to make it possible to chain all the commands that belong to an ES|QL query + in a single expression. + """ + + def __init__( + self, parent: ESQLBase, *query: ExpressionType, **named_query: ExpressionType + ): + if len(query) + len(named_query) > 1: + raise ValueError( + "this method requires either one positional or one keyword argument only" + ) + super().__init__(parent) + self._query = query + self._named_query = named_query + self._fields: Optional[Tuple[str, ...]] = None + self._inference_id: Optional[str] = None + + def on(self, *fields: str) -> "Rerank": + self._fields = fields + return self + + def with_(self, inference_id: str) -> "Rerank": + """Continuation of the `COMPLETION` command. + + :param inference_id: The ID of the inference endpoint to use for the task. The + inference endpoint must be configured with the completion + task type. + """ + self._inference_id = inference_id + return self + + def _render_internal(self) -> str: + if self._fields is None: + raise ValueError( + "The rerank command requires one or more fields to rerank on" + ) + if self._inference_id is None: + raise ValueError("The completion command requires an inference ID") + with_ = {"inference_id": self._inference_id} + if self._named_query: + column = list(self._named_query.keys())[0] + query = list(self._named_query.values())[0] + return ( + f"RERANK {self._format_id(column)} = {json.dumps(query)} " + f"ON {', '.join([self._format_id(field) for field in self._fields])} " + f"WITH {json.dumps(with_)}" + ) + else: + return ( + f"RERANK {json.dumps(self._query[0])} " + f"ON {', '.join([self._format_id(field) for field in self._fields])} " + f"WITH {json.dumps(with_)}" + ) + + class Sample(ESQLBase): """Implementation of the ``SAMPLE`` processing command. diff --git a/test_elasticsearch/test_esql.py b/test_elasticsearch/test_esql.py index 35b026fb5..e33f288da 100644 --- a/test_elasticsearch/test_esql.py +++ b/test_elasticsearch/test_esql.py @@ -422,6 +422,83 @@ def test_rename(): ) +def test_rerank(): + query = ( + ESQL.from_("books") + .metadata("_score") + .where('MATCH(description, "hobbit")') + .sort("_score DESC") + .limit(100) + .rerank("hobbit") + .on("description") + .with_(inference_id="test_reranker") + .limit(3) + .keep("title", "_score") + ) + assert ( + query.render() + == """FROM books METADATA _score +| WHERE MATCH(description, "hobbit") +| SORT _score DESC +| LIMIT 100 +| RERANK "hobbit" ON description WITH {"inference_id": "test_reranker"} +| LIMIT 3 +| KEEP title, _score""" + ) + + query = ( + ESQL.from_("books") + .metadata("_score") + .where('MATCH(description, "hobbit") OR MATCH(author, "Tolkien")') + .sort("_score DESC") + .limit(100) + .rerank(rerank_score="hobbit") + .on("description", "author") + .with_(inference_id="test_reranker") + .sort("rerank_score") + .limit(3) + .keep("title", "_score", "rerank_score") + ) + assert ( + query.render() + == """FROM books METADATA _score +| WHERE MATCH(description, "hobbit") OR MATCH(author, "Tolkien") +| SORT _score DESC +| LIMIT 100 +| RERANK rerank_score = "hobbit" ON description, author WITH {"inference_id": "test_reranker"} +| SORT rerank_score +| LIMIT 3 +| KEEP title, _score, rerank_score""" + ) + + query = ( + ESQL.from_("books") + .metadata("_score") + .where('MATCH(description, "hobbit") OR MATCH(author, "Tolkien")') + .sort("_score DESC") + .limit(100) + .rerank(rerank_score="hobbit") + .on("description", "author") + .with_(inference_id="test_reranker") + .eval(original_score="_score", _score="rerank_score + original_score") + .sort("_score") + .limit(3) + .keep("title", "original_score", "rerank_score", "_score") + ) + assert ( + query.render() + == """FROM books METADATA _score +| WHERE MATCH(description, "hobbit") OR MATCH(author, "Tolkien") +| SORT _score DESC +| LIMIT 100 +| RERANK rerank_score = "hobbit" ON description, author WITH {"inference_id": "test_reranker"} +| EVAL original_score = _score, _score = rerank_score + original_score +| SORT _score +| LIMIT 3 +| KEEP title, original_score, rerank_score, _score""" + ) + + def test_sample(): query = ESQL.from_("employees").keep("emp_no").sample(0.05) assert (