Skip to content

Commit

Permalink
Add Multi Turn Metrics Support
Browse files Browse the repository at this point in the history
Signed-off-by: elronbandel <[email protected]>
  • Loading branch information
elronbandel committed Feb 5, 2025
1 parent 2ef9091 commit 14401eb
Show file tree
Hide file tree
Showing 9 changed files with 183 additions and 3 deletions.
76 changes: 76 additions & 0 deletions prepare/cards/coqa_multi_turn.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
from typing import Any, Dict, Optional

from unitxt.blocks import LoadHF, TaskCard
from unitxt.collections_operators import DuplicateBySubLists, Pop, Wrap
from unitxt.operator import InstanceOperator
from unitxt.operators import AddID, Copy, FieldOperator, ZipFieldValues
from unitxt.test_utils.card import test_card


class Pass(InstanceOperator):
    """Identity operator: hands every instance back unchanged."""

    def process(
        self, instance: Dict[str, Any], stream_name: Optional[str] = None
    ) -> Dict[str, Any]:
        """Return ``instance`` as-is.

        Args:
            instance: The instance dictionary being processed.
            stream_name: Accepted for interface compatibility; not used here.

        Returns:
            The very same ``instance`` object, unmodified.
        """
        # ``Optional[str]`` is used instead of ``str | None`` because the
        # annotation is evaluated when the method is defined, and the ``|``
        # form raises TypeError on interpreters older than Python 3.10.
        return instance


class ToDialog(FieldOperator):
    """Turn a sequence of (question, answer) pairs into a flat list of turns."""

    def process_value(self, value: Any) -> Any:
        """Expand each pair into a "user" turn followed by an "agent" turn.

        Args:
            value: An iterable whose items unpack into exactly two elements,
                the question first and the answer second.

        Returns:
            A new list of ``{"role": ..., "content": ...}`` dictionaries, two
            per input pair, in the order the pairs were given.
        """
        return [
            turn
            for user_text, agent_text in value
            for turn in (
                {"role": "user", "content": user_text},
                {"role": "agent", "content": agent_text},
            )
        ]


# Task card that recasts CoQA as a multi-turn extractive QA task: every
# resulting instance carries the story as context, a conversation (id + dialog
# so far) as input, and the answer to the last question as the reference.
card = TaskCard(
    loader=LoadHF(path="stanfordnlp/coqa"),
    preprocess_steps=[
        # Catalog splitter; presumably derives the splits for a dataset that
        # ships without a test split — TODO confirm against the catalog entry.
        "splitters.small_no_test",
        # Presumably writes the "id" field that the next step copies — confirm.
        AddID(),
        # Keep the id inside the "conversation" structure as well.
        Copy(field="id", to_field="conversation/id"),
        # Presumably pairs the i-th question with the i-th answer text; the
        # ToDialog step below unpacks each element as (question, answer).
        ZipFieldValues(
            fields=["questions", "answers/input_text"],
            to_field="dialog",
        ),
        # NOTE(review): assumed to emit one instance per sub-list of the
        # dialog, so each turn can be evaluated on its own — confirm.
        DuplicateBySubLists(field="dialog"),
        # Flatten the pairs into alternating "user"/"agent" turn dicts.
        ToDialog(field="dialog"),
        # item=-1 takes the final turn (the agent answer produced by ToDialog)
        # out of "dialog"; list.pop mutates in place, so "dialog" is left
        # ending on the user question while the answer lands in "last_turn".
        Pop(field="dialog", item=-1, to_field="last_turn"),
        Copy(
            field_to_field={"last_turn/content": "answer", "story": "context"},
        ),
        # The task's reference field is "answers", so the single answer is
        # wrapped in a list.
        Wrap(
            field="answer",
            inside="list",
            to_field="answers",
        ),
        # Store the (answer-less) dialog next to the id under "conversation".
        Copy(field="dialog", to_field="conversation/dialog"),
    ],
    task="tasks.qa.extractive.multi_turn",
    templates=["templates.qa.multi_turn.with_context.simple"],
    __tags__={
        "annotations_creators": "crowdsourced",
        "arxiv": ["1808.07042", "1704.04683", "1506.03340"],
        "flags": ["conversational-qa"],
        "language": "en",
        "language_creators": "found",
        "license": "other",
        "multilinguality": "monolingual",
        "region": "us",
        "size_categories": "1K<n<10K",
        "source_datasets": [
            "extended|race",
            "extended|cnn_dailymail",
            "extended|wikipedia",
            "extended|other",
        ],
        "task_categories": "question-answering",
        "task_ids": "extractive-qa",
    },
    __description__=(
        "CoQA is a large-scale dataset for building Conversational Question Answering systems. \n"
        "Our dataset contains 127k questions with answers, obtained from 8k conversations about text passages from seven diverse domains. The questions are conversational, and the answers are free-form text with their corresponding evidence highlighted in the passage. Supported Tasks and Leaderboards More Information Needed… See the full description on the dataset page: https://huggingface.co/datasets/stanfordnlp/coqa."
    ),
)

# Run the card through the project's card test helper when this script executes.
test_card(card)
30 changes: 28 additions & 2 deletions prepare/tasks/qa/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,16 @@

from unitxt.blocks import Task
from unitxt.catalog import add_link_to_catalog, add_to_catalog
from unitxt.types import Audio, Dialog, Document, Image, MultiDocument, Table, Text
from unitxt.types import (
Audio,
Conversation,
Dialog,
Document,
Image,
MultiDocument,
Table,
Text,
)

add_link_to_catalog(
artifact_linked_to="tasks.qa.extractive",
Expand All @@ -19,7 +28,7 @@
input_fields={
"context": Union[Text, Table, Dialog],
"context_type": str,
"question": str,
"question": Union[Text, Dialog],
},
reference_fields={"answers": List[str]},
prediction_type=str,
Expand All @@ -31,6 +40,23 @@
overwrite=True,
)

# Multi-turn variant of the extractive QA task: instead of a single "question"
# field, the input is a "conversation" (see the Conversation type) alongside
# the context; references are a list of answer strings.
add_to_catalog(
    Task(
        # NOTE(review): the description is empty — consider documenting the
        # task here (and regenerating the catalog JSON) before release.
        __description__="""""",
        input_fields={
            "context": Union[Text, Table],
            "conversation": Conversation,
        },
        reference_fields={"answers": List[str]},
        prediction_type=str,
        metrics=["metrics.squad"],
        # NOTE(review): this task has no "question" field; confirm that
        # "templates.qa.extractive" can render such instances, or point the
        # default at a conversation-aware template such as
        # "templates.qa.multi_turn.with_context.simple".
        default_template="templates.qa.extractive",
        augmentable_inputs=["context"],
    ),
    "tasks.qa.extractive.multi_turn",
    overwrite=True,
)

add_to_catalog(
Task(
__description__="""This is the Question Answering Task with provided context (which is a either text, image, audio, table , or dialog).
Expand Down
10 changes: 10 additions & 0 deletions prepare/templates/qa/with_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,3 +131,13 @@
"templates.qa.with_context.all",
overwrite=True,
)


# Template for multi-turn QA with context: the input is the context on a
# "Context: " line followed by the rendered conversation on the next line, and
# the references are read from the "answers" field.
add_to_catalog(
    MultiReferenceTemplate(
        input_format="Context: {context}\n{conversation}",
        references_field="answers",
    ),
    "templates.qa.multi_turn.with_context.simple",
    overwrite=True,
)
2 changes: 1 addition & 1 deletion src/unitxt/catalog/tasks/qa/extractive.json
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
"input_fields": {
"context": "Union[Text, Table, Dialog]",
"context_type": "str",
"question": "str"
"question": "Union[Text, Dialog]"
},
"reference_fields": {
"answers": "List[str]"
Expand Down
19 changes: 19 additions & 0 deletions src/unitxt/catalog/tasks/qa/extractive/multi_turn.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
{
"__type__": "task",
"__description__": "",
"input_fields": {
"context": "Union[Text, Table]",
"conversation": "Conversation"
},
"reference_fields": {
"answers": "List[str]"
},
"prediction_type": "str",
"metrics": [
"metrics.squad"
],
"default_template": "templates.qa.extractive",
"augmentable_inputs": [
"context"
]
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
{
"__type__": "multi_reference_template",
"input_format": "Context: {context}\n{conversation}",
"references_field": "answers"
}
7 changes: 7 additions & 0 deletions src/unitxt/collections_operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,13 @@ def process_value(self, collection: Any) -> Any:
return collection[self.item]


class Pop(FieldOperator):
    """Remove one element from the field's collection and return it.

    The collection is modified in place through its own ``pop`` method, so the
    object that was stored in the field no longer contains the removed element
    after this operator runs.
    """

    # Index (for lists) or key (for dicts) of the element to remove.
    # ``None`` means "no explicit position was configured".
    item: Any = None

    def process_value(self, collection: Any) -> Any:
        """Pop ``self.item`` from ``collection`` and return the removed value.

        Args:
            collection: Any object exposing a ``pop`` method (e.g. list, dict).

        Returns:
            The element removed from ``collection``.

        Raises:
            IndexError / KeyError: Propagated from the underlying ``pop`` when
                the requested element does not exist.
        """
        # With the default ``item=None`` a list would be asked for
        # ``pop(None)``, which raises TypeError. Fall back to the argument-less
        # ``pop()`` (last element for lists). Dicts keep the old behavior,
        # because ``None`` is a legitimate dict key.
        if self.item is None and not isinstance(collection, dict):
            return collection.pop()
        return collection.pop(self.item)


class DuplicateByList(StreamOperator):
field: str
to_field: Optional[str] = None
Expand Down
31 changes: 31 additions & 0 deletions src/unitxt/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -558,6 +558,37 @@ def reduce(self, intermidates: List[IntermediateType]) -> Dict[str, Any]:
pass


class MultiTurnMetric(
    MapReduceMetric[PredictionType, IntermediateType],
    Generic[PredictionType, IntermediateType],
):
    """Score multi-turn conversations by wrapping a single-turn metric.

    Every instance is mapped with the wrapped ``metric`` and tagged with the
    conversation it belongs to and its turn position. At reduce time the
    instances are grouped per conversation, each conversation is reduced with
    the wrapped metric, and the numeric scores are averaged over conversations
    so that long conversations do not outweigh short ones.

    NOTE(review): the per-conversation macro-average is the aggregation chosen
    here; confirm it matches the intended reporting before relying on it.
    """

    # The single-turn metric that does the actual scoring.
    metric: MapReduceMetric[PredictionType, IntermediateType]

    def map(
        self,
        prediction: PredictionType,
        references: List[PredictionType],
        task_data: Dict[str, Any],
    ) -> IntermediateType:
        """Map one instance and attach its conversation id and turn position.

        Args:
            prediction: The model prediction for this turn.
            references: The reference answers for this turn.
            task_data: Must contain ``task_data["conversation"]`` with the keys
                ``"id"`` and ``"dialog"``.

        Returns:
            A tuple ``(intermediate, dialog_id, turn_id)`` where ``turn_id`` is
            the number of turns in the dialog seen so far.

        Raises:
            KeyError: If the conversation structure is missing from task_data.
        """
        intermidate = self.metric.map_stream([(prediction, references, task_data)])[0]

        conversation = task_data["conversation"]
        dialog_id = conversation["id"]
        # The dialog length identifies how far into the conversation we are.
        turn_id = len(conversation["dialog"])

        return (intermidate, dialog_id, turn_id)

    def reduce(self, intermediates: List[IntermediateType]) -> Dict[str, Any]:
        """Aggregate mapped turns into per-conversation averaged scores.

        Args:
            intermediates: Tuples produced by :meth:`map`.

        Returns:
            A dictionary mapping each numeric score name reported by the
            wrapped metric to its mean over conversations. Empty when there
            are no intermediates.
        """
        # Group by conversation; keep only the first result seen for a given
        # (conversation, turn) pair so duplicates do not skew the score.
        per_dialog: Dict[Any, Dict[Any, Any]] = {}
        for intermidate, dialog_id, turn_id in intermediates:
            per_dialog.setdefault(dialog_id, {}).setdefault(turn_id, intermidate)

        totals: Dict[str, float] = {}
        counts: Dict[str, int] = {}
        for turns in per_dialog.values():
            # Reduce one conversation at a time, in turn order.
            ordered = [turns[turn_id] for turn_id in sorted(turns)]
            dialog_scores = self.metric.reduce(ordered)
            for name, value in dialog_scores.items():
                # Only plain numbers can be averaged; bool is excluded even
                # though it is an int subclass.
                if isinstance(value, bool) or not isinstance(value, (int, float)):
                    continue
                totals[name] = totals.get(name, 0.0) + value
                counts[name] = counts.get(name, 0) + 1

        return {name: totals[name] / counts[name] for name in totals}


class DictReduction(AggregationReduction[Dict[str, float]]):
def reduce_list(self, lst: List[float]):
pass
Expand Down
6 changes: 6 additions & 0 deletions src/unitxt/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,11 @@ class RagResponse(TypedDict):
Dialog = NewType("Dialog", List[Turn])


class Conversation(TypedDict):
    """A dialog bundled with an identifier for the conversation."""

    # Identifier of the conversation.
    id: str
    # The turns of the conversation (``Dialog`` is a list of ``Turn``).
    dialog: Dialog


class Image(TypedDict):
image: Any
format: str
Expand Down Expand Up @@ -60,6 +65,7 @@ class SQLDatabase(TypedDict):
register_type(Audio)
register_type(Image)
register_type(Video)
register_type(Conversation)
register_type(Document)
register_type(MultiDocument)
register_type(RagResponse)
Expand Down

0 comments on commit 14401eb

Please sign in to comment.