Skip to content

Commit

Permalink
Apply linter
Browse files Browse the repository at this point in the history
Signed-off-by: Martín Santillán Cooper <[email protected]>
  • Loading branch information
martinscooper committed Feb 10, 2025
1 parent abf1548 commit 0cf1858
Show file tree
Hide file tree
Showing 5 changed files with 8 additions and 12 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from unitxt import evaluate
from unitxt.api import create_dataset
from unitxt.blocks import Task
from unitxt.metrics import GraniteGuardianAssistantRisk, RiskType
from unitxt.metrics import GraniteGuardianAssistantRisk
from unitxt.templates import NullTemplate

print("Assistant response risks")
Expand Down
2 changes: 1 addition & 1 deletion examples/evaluate_granite_guardian_custom_risks.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from unitxt import evaluate
from unitxt.api import create_dataset
from unitxt.blocks import Task
from unitxt.metrics import GraniteGuardianCustomRisk, RiskType
from unitxt.metrics import GraniteGuardianCustomRisk
from unitxt.templates import NullTemplate

print("Bring your own risk")
Expand Down
2 changes: 1 addition & 1 deletion examples/evaluate_granite_guardian_user_message_risks.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from unitxt import evaluate
from unitxt.api import create_dataset
from unitxt.blocks import Task
from unitxt.metrics import GraniteGuardianUserRisk, RiskType
from unitxt.metrics import GraniteGuardianUserRisk
from unitxt.templates import NullTemplate

print("User prompt risks")
Expand Down
2 changes: 1 addition & 1 deletion prepare/metrics/granite_guardian.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from unitxt import add_to_catalog
from unitxt.metrics import GraniteGuardianBase, RISK_TYPE_TO_CLASS
from unitxt.metrics import RISK_TYPE_TO_CLASS, GraniteGuardianBase

for risk_type, risk_names in GraniteGuardianBase.available_risks.items():
for risk_name in risk_names:
Expand Down
12 changes: 4 additions & 8 deletions src/unitxt/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -5868,6 +5868,7 @@ class GraniteGuardianBase(InstanceMetric):
main_score = None
reduction_map = {}
wml_model_name: str = "ibm/granite-guardian-3-8b"
hf_model_name: str = "ibm-granite/granite-guardian-3.1-8b"

wml_params = {
"decoding_method": "greedy",
Expand All @@ -5880,8 +5881,6 @@ class GraniteGuardianBase(InstanceMetric):
},
}

hf_model_name: str = "ibm-granite/granite-guardian-3.1-8b"

safe_token = "No"
unsafe_token = "Yes"

Expand Down Expand Up @@ -5936,8 +5935,6 @@ def process_input_fields(self, task_data):

@classmethod
def get_available_risk_names(cls):
print(cls.risk_type)
print(cls.available_risks)
return cls.available_risks[cls.risk_type]

def set_main_score(self):
Expand Down Expand Up @@ -5974,7 +5971,6 @@ def compute(self, references: List[Any], prediction: Any, task_data: Dict) -> di
messages = self.process_input_fields(task_data)
prompt = self.get_prompt(messages)
result = self.inference_engine.infer_log_probs([{"source": prompt}])
print(' '.join([r['text'] for r in result[0]]))
generated_tokens_list = result[0]
label, prob_of_risk = self.parse_output(generated_tokens_list)
confidence_score = (
Expand Down Expand Up @@ -6138,8 +6134,8 @@ class GraniteGuardianCustomRisk(GraniteGuardianBase):

def verify(self):
super().verify()
assert self.risk_type != None, UnitxtError("In a custom risk, risk_type must be defined")
assert self.risk_type is not None, UnitxtError("In a custom risk, risk_type must be defined")

def verify_granite_guardian_config(self, task_data):
# even though this is a custom risks, we will limit the
# message roles to be a subset of the roles Granite Guardian
Expand Down Expand Up @@ -6176,7 +6172,7 @@ def process_input_fields(self, task_data):

RISK_TYPE_TO_CLASS: Dict[RiskType, GraniteGuardianBase] = {
RiskType.USER_MESSAGE: GraniteGuardianUserRisk,
RiskType.ASSISTANT_MESSAGE: GraniteGuardianAssistantRisk,
RiskType.ASSISTANT_MESSAGE: GraniteGuardianAssistantRisk,
RiskType.RAG: GraniteGuardianRagRisk,
RiskType.AGENTIC: GraniteGuardianAgenticRisk,
}
Expand Down

0 comments on commit 0cf1858

Please sign in to comment.