Skip to content

Commit 45843a1

Browse files
add rerank() command to the ES|QL query builder
1 parent 6bfbdaf commit 45843a1

File tree

2 files changed

+191
-6
lines changed

2 files changed

+191
-6
lines changed

elasticsearch/esql/esql.py

Lines changed: 114 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -143,7 +143,7 @@ def _is_forked(self) -> bool:
143143
return False
144144

145145
def change_point(self, value: FieldType) -> "ChangePoint":
146-
"""`CHANGE_POINT` detects spikes, dips, and change points in a metric.
146+
"""``CHANGE_POINT`` detects spikes, dips, and change points in a metric.
147147
148148
:param value: The column with the metric in which you want to detect a change point.
149149
@@ -163,17 +163,18 @@ def change_point(self, value: FieldType) -> "ChangePoint":
163163
def completion(
164164
self, *prompt: ExpressionType, **named_prompt: ExpressionType
165165
) -> "Completion":
166-
"""The `COMPLETION` command allows you to send prompts and context to a Large
166+
"""The ``COMPLETION`` command allows you to send prompts and context to a Large
167167
Language Model (LLM) directly within your ES|QL queries, to perform text
168168
generation tasks.
169169
170170
:param prompt: The input text or expression used to prompt the LLM. This can
171171
be a string literal or a reference to a column containing text.
172172
:param named_prompt: The input text or expresion, given as a keyword argument.
173-
The argument name is used for the column name. If not
174-
specified, the results will be stored in a column named
175-
`completion`. If the specified column already exists, it
176-
will be overwritten with the new results.
173+
The argument name is used for the column name. If the
174+
prompt is given as a positional argument, the results will
175+
be stored in a column named ``completion``. If the
176+
specified column already exists, it will be overwritten
177+
with the new results.
177178
178179
Examples::
179180
@@ -439,6 +440,54 @@ def rename(self, **columns: FieldType) -> "Rename":
439440
"""
440441
return Rename(self, **columns)
441442

443+
def rerank(self, *query: ExpressionType, **named_query: ExpressionType) -> "Rerank":
444+
"""The ``RERANK`` command uses an inference model to compute a new relevance score
445+
for an initial set of documents, directly within your ES|QL queries.
446+
447+
:param query: The query text used to rerank the documents. This is typically the
448+
same query used in the initial search.
449+
:param named_query: The query text used to rerank the documents, given as a
450+
keyword argument. The argument name is used for the column
451+
name. If the query is given as a positional argument, the
452+
results will be stored in a column named `_score`. If the
453+
specified column already exists, it will be overwritten with
454+
the new results.
455+
456+
Examples::
457+
458+
query1 = (
459+
ESQL.from_("books").metadata("_score")
460+
.where('MATCH(description, "hobbit")')
461+
.sort("_score DESC")
462+
.limit(100)
463+
.rerank("hobbit").on("description").with_(inference_id="test_reranker")
464+
.limit(3)
465+
.keep("title", "_score")
466+
)
467+
query2 = (
468+
ESQL.from_("books").metadata("_score")
469+
.where('MATCH(description, "hobbit") OR MATCH(author, "Tolkien")')
470+
.sort("_score DESC")
471+
.limit(100)
472+
.rerank(rerank_score="hobbit").on("description", "author").with_(inference_id="test_reranker")
473+
.sort("rerank_score")
474+
.limit(3)
475+
.keep("title", "_score", "rerank_score")
476+
)
477+
query3 = (
478+
ESQL.from_("books").metadata("_score")
479+
.where('MATCH(description, "hobbit") OR MATCH(author, "Tolkien")')
480+
.sort("_score DESC")
481+
.limit(100)
482+
.rerank(rerank_score="hobbit").on("description", "author").with_(inference_id="test_reranker")
483+
.eval(original_score="_score", _score="rerank_score + original_score")
484+
.sort("_score")
485+
.limit(3)
486+
.keep("title", "original_score", "rerank_score", "_score")
487+
)
488+
"""
489+
return Rerank(self, *query, **named_query)
490+
442491
def sample(self, probability: float) -> "Sample":
443492
"""The ``SAMPLE`` command samples a fraction of the table rows.
444493
@@ -1046,6 +1095,65 @@ def _render_internal(self) -> str:
10461095
return f'RENAME {", ".join([f"{self._format_id(old_name)} AS {self._format_id(new_name)}" for old_name, new_name in self._columns.items()])}'
10471096

10481097

1098+
class Rerank(ESQLBase):
1099+
"""Implementation of the ``RERANK`` processing command.
1100+
1101+
This class inherits from :class:`ESQLBase <elasticsearch.esql.esql.ESQLBase>`,
1102+
to make it possible to chain all the commands that belong to an ES|QL query
1103+
in a single expression.
1104+
"""
1105+
1106+
def __init__(
1107+
self, parent: ESQLBase, *query: ExpressionType, **named_query: ExpressionType
1108+
):
1109+
if len(query) + len(named_query) > 1:
1110+
raise ValueError(
1111+
"this method requires either one positional or one keyword argument only"
1112+
)
1113+
super().__init__(parent)
1114+
self._query = query
1115+
self._named_query = named_query
1116+
self._fields: Optional[Tuple[str]] = None
1117+
self._inference_id: Optional[str] = None
1118+
1119+
def on(self, *fields) -> "Rerank":
1120+
self._fields = fields
1121+
return self
1122+
1123+
def with_(self, inference_id: str) -> "Rerank":
1124+
"""Continuation of the `COMPLETION` command.
1125+
1126+
:param inference_id: The ID of the inference endpoint to use for the task. The
1127+
inference endpoint must be configured with the completion
1128+
task type.
1129+
"""
1130+
self._inference_id = inference_id
1131+
return self
1132+
1133+
def _render_internal(self) -> str:
1134+
if self._fields is None:
1135+
raise ValueError(
1136+
"The rerank command requires one or more fields to rerank on"
1137+
)
1138+
if self._inference_id is None:
1139+
raise ValueError("The completion command requires an inference ID")
1140+
with_ = {"inference_id": self._inference_id}
1141+
if self._named_query:
1142+
column = list(self._named_query.keys())[0]
1143+
query = list(self._named_query.values())[0]
1144+
return (
1145+
f"RERANK {self._format_id(column)} = {json.dumps(query)} "
1146+
f"ON {', '.join([self._format_id(field) for field in self._fields])} "
1147+
f"WITH {json.dumps(with_)}"
1148+
)
1149+
else:
1150+
return (
1151+
f"RERANK {json.dumps(self._query[0])} "
1152+
f"ON {', '.join([self._format_id(field) for field in self._fields])} "
1153+
f"WITH {json.dumps(with_)}"
1154+
)
1155+
1156+
10491157
class Sample(ESQLBase):
10501158
"""Implementation of the ``SAMPLE`` processing command.
10511159

test_elasticsearch/test_esql.py

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -422,6 +422,83 @@ def test_rename():
422422
)
423423

424424

425+
def test_rerank():
426+
query = (
427+
ESQL.from_("books")
428+
.metadata("_score")
429+
.where('MATCH(description, "hobbit")')
430+
.sort("_score DESC")
431+
.limit(100)
432+
.rerank("hobbit")
433+
.on("description")
434+
.with_(inference_id="test_reranker")
435+
.limit(3)
436+
.keep("title", "_score")
437+
)
438+
assert (
439+
query.render()
440+
== """FROM books METADATA _score
441+
| WHERE MATCH(description, "hobbit")
442+
| SORT _score DESC
443+
| LIMIT 100
444+
| RERANK "hobbit" ON description WITH {"inference_id": "test_reranker"}
445+
| LIMIT 3
446+
| KEEP title, _score"""
447+
)
448+
449+
query = (
450+
ESQL.from_("books")
451+
.metadata("_score")
452+
.where('MATCH(description, "hobbit") OR MATCH(author, "Tolkien")')
453+
.sort("_score DESC")
454+
.limit(100)
455+
.rerank(rerank_score="hobbit")
456+
.on("description", "author")
457+
.with_(inference_id="test_reranker")
458+
.sort("rerank_score")
459+
.limit(3)
460+
.keep("title", "_score", "rerank_score")
461+
)
462+
assert (
463+
query.render()
464+
== """FROM books METADATA _score
465+
| WHERE MATCH(description, "hobbit") OR MATCH(author, "Tolkien")
466+
| SORT _score DESC
467+
| LIMIT 100
468+
| RERANK rerank_score = "hobbit" ON description, author WITH {"inference_id": "test_reranker"}
469+
| SORT rerank_score
470+
| LIMIT 3
471+
| KEEP title, _score, rerank_score"""
472+
)
473+
474+
query = (
475+
ESQL.from_("books")
476+
.metadata("_score")
477+
.where('MATCH(description, "hobbit") OR MATCH(author, "Tolkien")')
478+
.sort("_score DESC")
479+
.limit(100)
480+
.rerank(rerank_score="hobbit")
481+
.on("description", "author")
482+
.with_(inference_id="test_reranker")
483+
.eval(original_score="_score", _score="rerank_score + original_score")
484+
.sort("_score")
485+
.limit(3)
486+
.keep("title", "original_score", "rerank_score", "_score")
487+
)
488+
assert (
489+
query.render()
490+
== """FROM books METADATA _score
491+
| WHERE MATCH(description, "hobbit") OR MATCH(author, "Tolkien")
492+
| SORT _score DESC
493+
| LIMIT 100
494+
| RERANK rerank_score = "hobbit" ON description, author WITH {"inference_id": "test_reranker"}
495+
| EVAL original_score = _score, _score = rerank_score + original_score
496+
| SORT _score
497+
| LIMIT 3
498+
| KEEP title, original_score, rerank_score, _score"""
499+
)
500+
501+
425502
def test_sample():
426503
query = ESQL.from_("employees").keep("emp_no").sample(0.05)
427504
assert (

0 commit comments

Comments
 (0)