Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Create new IntersectCorrespondingFields operator #1531

Merged
merged 12 commits into from
Jan 30, 2025
37 changes: 22 additions & 15 deletions src/unitxt/operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -1448,33 +1448,40 @@ def prepare(self):
self.min_frequency_percent = 0


class Intersect(FieldOperator):
class Intersect(InstanceOperator):
"""Intersects the value of a field, which must be a list, with a given list.

Args:
allowed_values (list) - list to intersect.
allowed_field_values (list) - list to intersect.
fields_to_intersect (list) - list of fields to be filtered
"""

allowed_values: List[Any]
allowed_field_values: List[str]
fields_to_intersect: List[str]

def verify(self):
super().verify()
if self.process_every_value:
raise ValueError(
"'process_every_value=True' is not supported in Intersect operator"
)

if not isinstance(self.allowed_values, list):
if not isinstance(self.allowed_field_values, list):
raise ValueError(
f"The allowed_values is not a list but '{self.allowed_values}'"
f"The allowed_field_values is not a type list but '{type(self.allowed_field_values)}'"
)

def process_value(self, value: Any) -> Any:
super().process_value(value)
if not isinstance(value, list):
raise ValueError(f"The value in field is not a list but '{value}'")
return [e for e in value if e in self.allowed_values]

def process(
    self, instance: Dict[str, Any], stream_name: Optional[str] = None
) -> Dict[str, Any]:
    """Keep only the positions whose label is in ``allowed_field_values``.

    The indices of the entries of ``instance['labels']`` that appear in
    ``self.allowed_field_values`` are computed once, and every field named
    in ``self.fields_to_intersect`` is reduced to those indices. All other
    fields are passed through unchanged.

    Args:
        instance: The instance to filter. It must contain a ``'labels'``
            entry; a missing key raises ``KeyError``.
        stream_name: Not used by this method.

    Returns:
        The same ``instance`` object when the set of its labels equals the
        set of allowed values; otherwise a new dict in which the selected
        fields are filtered. Filtered ``list``/``tuple`` fields are
        returned as lists.
    """
    labels = instance["labels"]
    # Build the lookup set once instead of once per label.
    allowed = set(self.allowed_field_values)

    # Nothing to drop: hand back the instance untouched.
    if allowed == set(labels):
        return instance

    keep = [i for i, label in enumerate(labels) if label in allowed]

    def select(value: Any) -> Any:
        # Built-in sequences reject a list of indices (TypeError), so pick
        # the elements one by one; any other container keeps the original
        # index-list lookup.
        if isinstance(value, (list, tuple)):
            return [value[i] for i in keep]
        return value[keep]

    return {
        key: select(value) if key in self.fields_to_intersect else value
        for key, value in instance.items()
    }

class RemoveValues(FieldOperator):
"""Removes elements in a field, which must be a list, using a given list of unallowed.
Expand Down
98 changes: 60 additions & 38 deletions tests/library/test_operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -579,47 +579,68 @@ def test_execute_expression(self):
)

def test_intersect(self):
inputs = [
{"label": ["a", "b"]},
{"label": ["a", "c", "d"]},
{"label": ["a", "b", "f"]},
]

targets = [
{"label": ["b"]},
{"label": []},
{"label": ["b", "f"]},
]

check_operator(
operator=Intersect(field="label", allowed_values=["b", "f"]),
inputs=inputs,
targets=targets,
tester=self,
)
with self.assertRaises(ValueError) as cm:
check_operator(
operator=Intersect(field="label", allowed_values=3),
inputs=inputs,
targets=targets,
tester=self,
)
self.assertEqual(str(cm.exception), "The allowed_values is not a list but '3'")

with self.assertRaises(ValueError) as cm:
check_operator(
operator=Intersect(
field="label", allowed_values=["3"], process_every_value=True
),

def __test_intersect(inputs, targets, fields_to_intersect, allowed_field_values):
return check_operator(
operator=Intersect(fields_to_intersect, allowed_field_values),
inputs=inputs,
targets=targets,
tester=self,
)
self.assertEqual(
str(cm.exception),
"'process_every_value=True' is not supported in Intersect operator",
)

## basic test
__test_intersect(
inputs=[{"label": [1,2]}],
targets=[{"label": [1]}],
fields_to_intersect=["label"],
allowed_field_values=[1]
)

# multiple fields to intersect ("label" and "name"), each instance holding one of them
__test_intersect(
inputs = [
{"label": ["a", "b"]},
{"label": ["a", "c", "d"]},
{"name": ["a", "b", "f"]},
],
targets = [
{"label": ["b"]},
{"label": []},
{"name": ["b", "f"]},
],
fields_to_intersect=["label",'name'],
allowed_field_values=["b", "f"]
)

__test_intersect(
inputs = [
{"label": ["a", "b"]},
{"label": ["a", "c", "d"]},
{"label": ["a", "b", "f"]},
],
targets = [
{"label": ["b"]},
{"label": []},
{"label": ["b", "f"]},
],
fields_to_intersect=["label"],
allowed_field_values=["b", "f"]
)


with self.assertRaises(ValueError) as cm:
__test_intersect(
inputs = [
{"label": ["a", "b"]},
],
targets = [
{"label": ["b"]},
],
fields_to_intersect=["label"],
allowed_field_values=3
)
self.assertEqual(str(cm.exception), "The allowed_field_values is not a list but '<class 'int'>'")


inputs = [
{"label": "b"},
]
Expand All @@ -629,7 +650,7 @@ def test_intersect(self):
"The value in field is not a list but 'b'",
]
check_operator_exception(
operator=Intersect(field="label", allowed_values=["c"]),
operator=Intersect(field=["label"], allowed_field_values=["c"]),
inputs=inputs,
exception_texts=exception_texts,
tester=self,
Expand Down Expand Up @@ -3125,3 +3146,4 @@ def test_select_fields(self):
}
]
TestOperators().compare_streams(joined_stream, expected_joined_stream)

Loading