Skip to content

Commit

Permalink
Apply linter
Browse files Browse the repository at this point in the history
Signed-off-by: Martín Santillán Cooper <[email protected]>
  • Loading branch information
martinscooper committed Feb 10, 2025
1 parent dc93f54 commit d570a32
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 11 deletions.
1 change: 0 additions & 1 deletion examples/evaluate_granite_guardian_rag_risks.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
from typing import List

from unitxt import evaluate
from unitxt.api import create_dataset
Expand Down
14 changes: 6 additions & 8 deletions src/unitxt/metrics.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import ast
from enum import Enum
import json
import math
import os
Expand All @@ -10,6 +9,7 @@
from abc import ABC, abstractmethod
from collections import Counter, defaultdict
from dataclasses import field
from enum import Enum
from functools import lru_cache
from typing import (
Any,
Expand Down Expand Up @@ -46,9 +46,9 @@
from .inference import (
HFPipelineBasedInferenceEngine,
InferenceEngine,
LogProbInferenceEngine,
TorchDeviceMixin,
WMLInferenceEngineGeneration,
LogProbInferenceEngine,
)
from .logging_utils import get_logger
from .metric_utils import InstanceInput, MetricRequest, MetricResponse
Expand Down Expand Up @@ -5852,7 +5852,7 @@ def compute(


class RiskType(str, Enum):
"""Risk type for the Granite Guardian models"""
"""Risk type for the Granite Guardian models."""

RAG = "rag_risk"
USER_MESSAGE = "user_risk"
Expand Down Expand Up @@ -5946,7 +5946,7 @@ def verify_guardian_config(self, task_data):
)
elif self.risk_name == RiskType.USER_MESSAGE or (
self.risk_name in self.available_risks[RiskType.USER_MESSAGE]
and not self.assistant_message_field in task_data
and self.assistant_message_field not in task_data
):
# User message risks only require the user message field and are the same as the assistant message risks, except for jailbreak
self.risk_type = RiskType.USER_MESSAGE
Expand Down Expand Up @@ -6004,15 +6004,13 @@ def get_prompt(self, messages):
if self.risk_type == RiskType.CUSTOM_RISK:
guardian_config["risk_definition"] = self.risk_definition

processed_input = self._tokenizer.apply_chat_template(
return self._tokenizer.apply_chat_template(
messages,
guardian_config=guardian_config,
tokenize=False,
add_generation_prompt=True,
)

return processed_input

def compute(self, references: List[Any], prediction: Any, task_data: Dict) -> dict:
from transformers import AutoTokenizer

Expand Down Expand Up @@ -6052,7 +6050,7 @@ def create_message(self, role: str, content: str) -> List[Dict[str, str]]:

def process_input_fields(self, task_data):
messages = []
logger.debug(f"Preparing messages for Granite Guardian.")
logger.debug("Preparing messages for Granite Guardian.")
if self.risk_type == RiskType.RAG:
if self.risk_name == "context_relevance":
messages += self.create_message(
Expand Down
4 changes: 2 additions & 2 deletions utils/.secrets.baseline
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@
"filename": "src/unitxt/metrics.py",
"hashed_secret": "fa172616e9af3d2a24b5597f264eab963fe76889",
"is_verified": false,
"line_number": 68,
"line_number": 70,
"is_secret": false
}
],
Expand All @@ -184,5 +184,5 @@
}
]
},
"generated_at": "2025-02-09T13:52:43Z"
"generated_at": "2025-02-10T13:12:47Z"
}

0 comments on commit d570a32

Please sign in to comment.