Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Multi Turn Metrics Support #1579

Draft
wants to merge 2 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 27 additions & 0 deletions examples/evaluate_multi_turn.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
from unitxt import settings
from unitxt.api import evaluate, load_dataset
from unitxt.inference import CrossProviderInferenceEngine

# End-to-end example for the CoQA multi-turn QA card: load the dataset,
# run inference with a cross-provider engine, and evaluate the predictions.
with settings.context(
    disable_hf_datasets_cache=False,
):
    # Provider-agnostic engine; here routed to watsonx with a small Llama model.
    inference_engine = CrossProviderInferenceEngine(
        model="llama-3-2-1b-instruct", provider="watsonx"
    )

    # Chat-format test split, capped at 100 instances to keep the example fast.
    test_dataset = load_dataset(
        card="cards.coqa.multi_turn",
        format="formats.chat_api",
        split="test",
        max_test_instances=100,
    )

    model_predictions = inference_engine.infer(test_dataset)
    evaluation_results = evaluate(predictions=model_predictions, data=test_dataset)

    print("Global Results:")
    print(evaluation_results.global_scores.summary)

    print("Instance Results:")
    print(evaluation_results.instance_scores.summary)
70 changes: 70 additions & 0 deletions prepare/cards/coqa_multi_turn.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
from typing import Any, Dict, Optional

from unitxt.blocks import LoadHF, TaskCard
from unitxt.catalog import add_to_catalog
from unitxt.collections_operators import DuplicateBySubLists, Pop, Wrap
from unitxt.dialog_operators import ToDialog
from unitxt.operator import InstanceOperator
from unitxt.operators import AddID, Copy, ZipFieldValues
from unitxt.test_utils.card import test_card


class Pass(InstanceOperator):
    """No-op operator that returns every instance unchanged.

    NOTE(review): this class is not referenced anywhere in this file —
    consider removing it if it was only a development placeholder.
    """

    def process(
        self, instance: Dict[str, Any], stream_name: Optional[str] = None
    ) -> Dict[str, Any]:
        # Fix: `str | None` (PEP 604) is evaluated at class-definition time and
        # raises TypeError on Python < 3.10; `Optional[str]` matches the
        # pre-3.10 `typing` style already used in this file (Dict, Any).
        return instance


# CoQA as multi-turn extractive QA: each source conversation (a story plus a
# list of question/answer pairs) is fanned out into one evaluation instance
# per dialog prefix, where the model answers the latest question given the
# story and the conversation so far.
card = TaskCard(
    loader=LoadHF(path="stanfordnlp/coqa"),
    preprocess_steps=[
        "splitters.small_no_test",
        # Give every source conversation a unique id and expose it on the
        # `conversation` object consumed by tasks.qa.extractive.multi_turn.
        AddID(),
        Copy(field="id", to_field="conversation/id"),
        # Pair each question with its gold answer text:
        # dialog = [[q1, a1], [q2, a2], ...].
        ZipFieldValues(
            fields=["questions", "answers/input_text"],
            to_field="dialog",
        ),
        # Fan out one instance per prefix of the QA list, so every turn of the
        # conversation becomes its own evaluation instance.
        DuplicateBySubLists(field="dialog"),
        # Presumably converts the [question, answer] pairs into role-tagged
        # turns — confirm against ToDialog's implementation.
        ToDialog(field="dialog"),
        # Detach the final turn: its content is the gold answer to the latest
        # question, and the remaining dialog becomes the model's input.
        Pop(field="dialog", item=-1, to_field="last_turn"),
        Copy(
            field_to_field={"last_turn/content": "answer", "story": "context"},
        ),
        # The task's reference field expects a list of acceptable answers.
        Wrap(
            field="answer",
            inside="list",
            to_field="answers",
        ),
        Copy(field="dialog", to_field="conversation/dialog"),
    ],
    task="tasks.qa.extractive.multi_turn",
    templates=["templates.qa.multi_turn.with_context.simple"],
    __tags__={
        "annotations_creators": "crowdsourced",
        "arxiv": ["1808.07042", "1704.04683", "1506.03340"],
        "flags": ["conversational-qa"],
        "language": "en",
        "language_creators": "found",
        "license": "other",
        "multilinguality": "monolingual",
        "region": "us",
        "size_categories": "1K<n<10K",
        "source_datasets": [
            "extended|race",
            "extended|cnn_dailymail",
            "extended|wikipedia",
            "extended|other",
        ],
        "task_categories": "question-answering",
        "task_ids": "extractive-qa",
    },
    __description__=(
        "CoQA is a large-scale dataset for building Conversational Question Answering systems. \n"
        "Our dataset contains 127k questions with answers, obtained from 8k conversations about text passages from seven diverse domains. The questions are conversational, and the answers are free-form text with their corresponding evidence highlighted in the passage. Supported Tasks and Leaderboards More Information Needed… See the full description on the dataset page: https://huggingface.co/datasets/stanfordnlp/coqa."
    ),
)

# Run the card's end-to-end sanity checks before registering it in the catalog.
test_card(card)
add_to_catalog(card, "cards.coqa.multi_turn", overwrite=True)
60 changes: 60 additions & 0 deletions prepare/metrics/multi_turn.py/accuracy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
from unitxt import add_to_catalog
from unitxt.metrics import AccuracyFast, MultiTurnMetric
from unitxt.test_utils.metrics import test_metric

# Sanity-check MultiTurnMetric wrapping AccuracyFast, then register it in the
# catalog as metrics.multi_turn.accuracy.
metric = MultiTurnMetric(metric=AccuracyFast())

# Fixture: three instances belonging to two conversations (grouped by
# task_data["conversation"]["id"]):
#   - "aa": two turns (instances 0 and 1), both predictions wrong,
#   - "bb": one turn (instance 2), prediction correct.
predictions = ["A", "B", "C"]
references = [["B"], ["A"], ["C"]]
task_data = [
    {
        "conversation": {
            "id": "aa",
            "dialog": [{"role": "user", "content": "what is it?"}],
        }
    },
    {
        "conversation": {
            "id": "aa",
            "dialog": [
                {"role": "user", "content": "what is it?"},
                {"role": "agent", "content": "A"},
                {"role": "user", "content": "what is it again?"},
            ],
        }
    },
    {
        "conversation": {
            "id": "bb",
            "dialog": [{"role": "user", "content": "what is it?"}],
        }
    },
]

# Per-instance scores are plain exact-match accuracy: 0, 0, 1.
instance_targets = [
    {"accuracy": 0.0, "score": 0.0, "score_name": "accuracy"},
    {"accuracy": 0.0, "score": 0.0, "score_name": "accuracy"},
    {"accuracy": 1.0, "score": 1.0, "score_name": "accuracy"},
]

# The expected global accuracy is 0.5 rather than the per-instance mean (1/3)
# — presumably MultiTurnMetric aggregates one score per conversation
# ("aa" -> 0, "bb" -> 1) before averaging; confirm against the
# MultiTurnMetric implementation.
global_target = {
    "accuracy": 0.5,
    "accuracy_ci_high": 1.0,
    "accuracy_ci_low": 0.0,
    "num_of_instances": 3,
    "score": 0.5,
    "score_ci_high": 1.0,
    "score_ci_low": 0.0,
    "score_name": "accuracy",
}

# test_metric raises if computed instance/global scores differ from the targets.
outputs = test_metric(
    metric=metric,
    predictions=predictions,
    references=references,
    task_data=task_data,
    instance_targets=instance_targets,
    global_target=global_target,
)

add_to_catalog(metric, "metrics.multi_turn.accuracy", overwrite=True)
30 changes: 28 additions & 2 deletions prepare/tasks/qa/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,16 @@

from unitxt.blocks import Task
from unitxt.catalog import add_link_to_catalog, add_to_catalog
from unitxt.types import Audio, Dialog, Document, Image, MultiDocument, Table, Text
from unitxt.types import (
Audio,
Conversation,
Dialog,
Document,
Image,
MultiDocument,
Table,
Text,
)

add_link_to_catalog(
artifact_linked_to="tasks.qa.extractive",
Expand All @@ -19,7 +28,7 @@
input_fields={
"context": Union[Text, Table, Dialog],
"context_type": str,
"question": str,
"question": Union[Text, Dialog],
},
reference_fields={"answers": List[str]},
prediction_type=str,
Expand All @@ -31,6 +40,23 @@
overwrite=True,
)

# Multi-turn variant of the extractive QA task: the question arrives as the
# latest turn of a Conversation rather than as a standalone string.
add_to_catalog(
    Task(
        # Fix: __description__ was an empty string (written as """""");
        # give catalog consumers a real description, consistent with the
        # other QA tasks registered in this file.
        __description__=(
            "Extractive multi-turn question answering: given a context and a "
            "conversation, answer the latest question in the conversation "
            "with a minimal span taken from the context."
        ),
        input_fields={
            "context": Union[Text, Table],
            "conversation": Conversation,
        },
        reference_fields={"answers": List[str]},
        prediction_type=str,
        metrics=["metrics.multi_turn.accuracy"],
        default_template="templates.qa.multi_turn.with_context.simple",
        augmentable_inputs=["context"],
    ),
    "tasks.qa.extractive.multi_turn",
    overwrite=True,
)

add_to_catalog(
Task(
__description__="""This is the Question Answering Task with provided context (which is a either text, image, audio, table , or dialog).
Expand Down
32 changes: 32 additions & 0 deletions prepare/templates/qa/with_context.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,14 @@
from unitxt.catalog import add_to_catalog
from unitxt.serializers import (
ConversationSerializer,
DialogSerializer,
ImageSerializer,
ListSerializer,
MultiTypeSerializer,
SQLDatabaseAsSchemaSerializer,
TableSerializer,
VideoSerializer,
)
from unitxt.templates import MultiReferenceTemplate, TemplatesList

add_to_catalog(
Expand Down Expand Up @@ -131,3 +141,25 @@
"templates.qa.with_context.all",
overwrite=True,
)


# Template for multi-turn extractive QA: renders the context followed by the
# serialized conversation, and instructs the model to answer the last question
# with a minimal span from the context.
add_to_catalog(
    MultiReferenceTemplate(
        instruction="Read the context and answer the last question in the conversation. Answer with the minimal span from the context answering the question.",
        input_format="Context: {context}\n\nConversation:\n{conversation}",
        # Gold answers come from the task's `answers` list field.
        references_field="answers",
        # NOTE(review): MultiTypeSerializer presumably dispatches to the first
        # serializer that handles a value's type, so the order of this list is
        # significant — confirm before reordering.
        serializer=MultiTypeSerializer(
            serializers=[
                ImageSerializer(),
                VideoSerializer(),
                TableSerializer(),
                DialogSerializer(),
                ConversationSerializer(),
                ListSerializer(),
                SQLDatabaseAsSchemaSerializer(),
            ]
        ),
    ),
    "templates.qa.multi_turn.with_context.simple",
    overwrite=True,
)
88 changes: 88 additions & 0 deletions src/unitxt/catalog/cards/coqa/multi_turn.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
{
"__type__": "task_card",
"loader": {
"__type__": "load_hf",
"path": "stanfordnlp/coqa"
},
"preprocess_steps": [
"splitters.small_no_test",
{
"__type__": "add_id"
},
{
"__type__": "copy",
"field": "id",
"to_field": "conversation/id"
},
{
"__type__": "zip_field_values",
"fields": [
"questions",
"answers/input_text"
],
"to_field": "dialog"
},
{
"__type__": "duplicate_by_sub_lists",
"field": "dialog"
},
{
"__type__": "to_dialog",
"field": "dialog"
},
{
"__type__": "pop",
"field": "dialog",
"item": -1,
"to_field": "last_turn"
},
{
"__type__": "copy",
"field_to_field": {
"last_turn/content": "answer",
"story": "context"
}
},
{
"__type__": "wrap",
"field": "answer",
"inside": "list",
"to_field": "answers"
},
{
"__type__": "copy",
"field": "dialog",
"to_field": "conversation/dialog"
}
],
"task": "tasks.qa.extractive.multi_turn",
"templates": [
"templates.qa.multi_turn.with_context.simple"
],
"__tags__": {
"annotations_creators": "crowdsourced",
"arxiv": [
"1808.07042",
"1704.04683",
"1506.03340"
],
"flags": [
"conversational-qa"
],
"language": "en",
"language_creators": "found",
"license": "other",
"multilinguality": "monolingual",
"region": "us",
"size_categories": "1K<n<10K",
"source_datasets": [
"extended|race",
"extended|cnn_dailymail",
"extended|wikipedia",
"extended|other"
],
"task_categories": "question-answering",
"task_ids": "extractive-qa"
},
"__description__": "CoQA is a large-scale dataset for building Conversational Question Answering systems. \nOur dataset contains 127k questions with answers, obtained from 8k conversations about text passages from seven diverse domains. The questions are conversational, and the answers are free-form text with their corresponding evidence highlighted in the passage. Supported Tasks and Leaderboards More Information Needed… See the full description on the dataset page: https://huggingface.co/datasets/stanfordnlp/coqa."
}
6 changes: 6 additions & 0 deletions src/unitxt/catalog/metrics/multi_turn/accuracy.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
{
"__type__": "multi_turn_metric",
"metric": {
"__type__": "accuracy_fast"
}
}
2 changes: 1 addition & 1 deletion src/unitxt/catalog/tasks/qa/extractive.json
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
"input_fields": {
"context": "Union[Text, Table, Dialog]",
"context_type": "str",
"question": "str"
"question": "Union[Text, Dialog]"
},
"reference_fields": {
"answers": "List[str]"
Expand Down
19 changes: 19 additions & 0 deletions src/unitxt/catalog/tasks/qa/extractive/multi_turn.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
{
"__type__": "task",
"__description__": "",
"input_fields": {
"context": "Union[Text, Table]",
"conversation": "Conversation"
},
"reference_fields": {
"answers": "List[str]"
},
"prediction_type": "str",
"metrics": [
"metrics.multi_turn.accuracy"
],
"default_template": "templates.qa.multi_turn.with_context.simple",
"augmentable_inputs": [
"context"
]
}
Loading
Loading