Skip to content

add rerank() command to the ES|QL query builder #3043

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
120 changes: 114 additions & 6 deletions elasticsearch/esql/esql.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ def _is_forked(self) -> bool:
return False

def change_point(self, value: FieldType) -> "ChangePoint":
"""`CHANGE_POINT` detects spikes, dips, and change points in a metric.
"""``CHANGE_POINT`` detects spikes, dips, and change points in a metric.

:param value: The column with the metric in which you want to detect a change point.

Expand All @@ -163,17 +163,18 @@ def change_point(self, value: FieldType) -> "ChangePoint":
def completion(
self, *prompt: ExpressionType, **named_prompt: ExpressionType
) -> "Completion":
"""The `COMPLETION` command allows you to send prompts and context to a Large
"""The ``COMPLETION`` command allows you to send prompts and context to a Large
Language Model (LLM) directly within your ES|QL queries, to perform text
generation tasks.

:param prompt: The input text or expression used to prompt the LLM. This can
be a string literal or a reference to a column containing text.
:param named_prompt: The input text or expression, given as a keyword argument.
The argument name is used for the column name. If not
specified, the results will be stored in a column named
`completion`. If the specified column already exists, it
will be overwritten with the new results.
The argument name is used for the column name. If the
prompt is given as a positional argument, the results will
be stored in a column named ``completion``. If the
specified column already exists, it will be overwritten
with the new results.

Examples::

Expand Down Expand Up @@ -439,6 +440,54 @@ def rename(self, **columns: FieldType) -> "Rename":
"""
return Rename(self, **columns)

def rerank(self, *query: ExpressionType, **named_query: ExpressionType) -> "Rerank":
    """The ``RERANK`` command uses an inference model to compute a new relevance score
    for an initial set of documents, directly within your ES|QL queries.

    Only one query may be given, either as a positional argument or as a
    keyword argument. The returned command must be completed with ``on()``
    and ``with_()`` before the query can be rendered.

    :param query: The query text used to rerank the documents. This is typically the
                  same query used in the initial search.
    :param named_query: The query text used to rerank the documents, given as a
                        keyword argument. The argument name is used for the column
                        name. If the query is given as a positional argument, the
                        results will be stored in a column named ``_score``. If the
                        specified column already exists, it will be overwritten with
                        the new results.

    Examples::

        query1 = (
            ESQL.from_("books").metadata("_score")
            .where('MATCH(description, "hobbit")')
            .sort("_score DESC")
            .limit(100)
            .rerank("hobbit").on("description").with_(inference_id="test_reranker")
            .limit(3)
            .keep("title", "_score")
        )
        query2 = (
            ESQL.from_("books").metadata("_score")
            .where('MATCH(description, "hobbit") OR MATCH(author, "Tolkien")')
            .sort("_score DESC")
            .limit(100)
            .rerank(rerank_score="hobbit").on("description", "author").with_(inference_id="test_reranker")
            .sort("rerank_score")
            .limit(3)
            .keep("title", "_score", "rerank_score")
        )
        query3 = (
            ESQL.from_("books").metadata("_score")
            .where('MATCH(description, "hobbit") OR MATCH(author, "Tolkien")')
            .sort("_score DESC")
            .limit(100)
            .rerank(rerank_score="hobbit").on("description", "author").with_(inference_id="test_reranker")
            .eval(original_score="_score", _score="rerank_score + original_score")
            .sort("_score")
            .limit(3)
            .keep("title", "original_score", "rerank_score", "_score")
        )
    """
    return Rerank(self, *query, **named_query)

def sample(self, probability: float) -> "Sample":
"""The ``SAMPLE`` command samples a fraction of the table rows.

Expand Down Expand Up @@ -1046,6 +1095,65 @@ def _render_internal(self) -> str:
return f'RENAME {", ".join([f"{self._format_id(old_name)} AS {self._format_id(new_name)}" for old_name, new_name in self._columns.items()])}'


class Rerank(ESQLBase):
    """Implementation of the ``RERANK`` processing command.

    This class inherits from :class:`ESQLBase <elasticsearch.esql.esql.ESQLBase>`,
    to make it possible to chain all the commands that belong to an ES|QL query
    in a single expression.
    """

    def __init__(
        self, parent: ESQLBase, *query: ExpressionType, **named_query: ExpressionType
    ):
        """Create the command with exactly one query.

        :param parent: The command that precedes this one in the query chain.
        :param query: The query text, given as a single positional argument.
        :param named_query: The query text, given as a single keyword argument
                            whose name is used as the target column.
        :raises ValueError: If zero or more than one query is given.
        """
        # Exactly one query is required. Rejecting the zero-argument case here
        # avoids an opaque IndexError later, when the command is rendered.
        if len(query) + len(named_query) != 1:
            raise ValueError(
                "this method requires either one positional or one keyword argument only"
            )
        super().__init__(parent)
        self._query = query
        self._named_query = named_query
        self._fields: Optional[Tuple[str, ...]] = None
        self._inference_id: Optional[str] = None

    def on(self, *fields: str) -> "Rerank":
        """Continuation of the ``RERANK`` command.

        :param fields: One or more fields to rerank on. They are rendered, in
                       the given order, in the ``ON`` clause of the command.
        """
        self._fields = fields
        return self

    def with_(self, inference_id: str) -> "Rerank":
        """Continuation of the ``RERANK`` command.

        :param inference_id: The ID of the inference endpoint to use for the task. The
                             inference endpoint must be configured with the ``rerank``
                             task type.
        """
        self._inference_id = inference_id
        return self

    def _render_internal(self) -> str:
        """Render the command as ES|QL text.

        :raises ValueError: If ``on()`` was not called with at least one field,
                            or if ``with_()`` was not called.
        """
        # ``not`` also rejects ``on()`` called without arguments, which would
        # otherwise render an empty ``ON`` clause.
        if not self._fields:
            raise ValueError(
                "The rerank command requires one or more fields to rerank on"
            )
        if self._inference_id is None:
            raise ValueError("The rerank command requires an inference ID")
        with_ = {"inference_id": self._inference_id}
        fields = ", ".join(self._format_id(field) for field in self._fields)
        if self._named_query:
            # The constructor guarantees a single entry: target column -> query.
            column, query = next(iter(self._named_query.items()))
            target = f"{self._format_id(column)} = {json.dumps(query)}"
        else:
            target = json.dumps(self._query[0])
        return f"RERANK {target} ON {fields} WITH {json.dumps(with_)}"


class Sample(ESQLBase):
"""Implementation of the ``SAMPLE`` processing command.

Expand Down
77 changes: 77 additions & 0 deletions test_elasticsearch/test_esql.py
Original file line number Diff line number Diff line change
Expand Up @@ -422,6 +422,83 @@ def test_rename():
)


def test_rerank():
    """Check the ES|QL text rendered for three ``RERANK`` query shapes."""
    # Positional query, reranking on a single field.
    positional = (
        ESQL.from_("books")
        .metadata("_score")
        .where('MATCH(description, "hobbit")')
        .sort("_score DESC")
        .limit(100)
        .rerank("hobbit")
        .on("description")
        .with_(inference_id="test_reranker")
        .limit(3)
        .keep("title", "_score")
    )
    expected_positional = """FROM books METADATA _score
| WHERE MATCH(description, "hobbit")
| SORT _score DESC
| LIMIT 100
| RERANK "hobbit" ON description WITH {"inference_id": "test_reranker"}
| LIMIT 3
| KEEP title, _score"""
    assert positional.render() == expected_positional

    # Keyword query: the argument name becomes the target column, two fields.
    named = (
        ESQL.from_("books")
        .metadata("_score")
        .where('MATCH(description, "hobbit") OR MATCH(author, "Tolkien")')
        .sort("_score DESC")
        .limit(100)
        .rerank(rerank_score="hobbit")
        .on("description", "author")
        .with_(inference_id="test_reranker")
        .sort("rerank_score")
        .limit(3)
        .keep("title", "_score", "rerank_score")
    )
    expected_named = """FROM books METADATA _score
| WHERE MATCH(description, "hobbit") OR MATCH(author, "Tolkien")
| SORT _score DESC
| LIMIT 100
| RERANK rerank_score = "hobbit" ON description, author WITH {"inference_id": "test_reranker"}
| SORT rerank_score
| LIMIT 3
| KEEP title, _score, rerank_score"""
    assert named.render() == expected_named

    # Keyword query followed by an EVAL that combines both scores.
    combined = (
        ESQL.from_("books")
        .metadata("_score")
        .where('MATCH(description, "hobbit") OR MATCH(author, "Tolkien")')
        .sort("_score DESC")
        .limit(100)
        .rerank(rerank_score="hobbit")
        .on("description", "author")
        .with_(inference_id="test_reranker")
        .eval(original_score="_score", _score="rerank_score + original_score")
        .sort("_score")
        .limit(3)
        .keep("title", "original_score", "rerank_score", "_score")
    )
    expected_combined = """FROM books METADATA _score
| WHERE MATCH(description, "hobbit") OR MATCH(author, "Tolkien")
| SORT _score DESC
| LIMIT 100
| RERANK rerank_score = "hobbit" ON description, author WITH {"inference_id": "test_reranker"}
| EVAL original_score = _score, _score = rerank_score + original_score
| SORT _score
| LIMIT 3
| KEEP title, original_score, rerank_score, _score"""
    assert combined.render() == expected_combined


def test_sample():
query = ESQL.from_("employees").keep("emp_no").sample(0.05)
assert (
Expand Down