Skip to content

Commit 92c444f

Browse files
Fixup: Address additional PR comments
Added aditional doc strings, redefined flows modified how failures are handled added output_mapping
1 parent 3412335 commit 92c444f

File tree

5 files changed

+92
-30
lines changed

5 files changed

+92
-30
lines changed

examples/configs/trend_micro_v2/rails.co

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,7 @@ import guardrails
22
import nemoguardrails.library.trend_micro
33

44
flow input rails $input_text
5-
$result = await TrendAiGuardAction(text=$input_text)
6-
7-
if $result.action == "Block"
8-
send AiGuardException(message="AI Guard detection: " + $result.reason)
9-
abort
5+
trend ai guard input $input_text
106

117
flow output rails $output_text
12-
trend ai guard $output_text
8+
trend ai guard output $output_text

nemoguardrails/library/trend_micro/actions.py

Lines changed: 64 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,12 @@
1414
# limitations under the License.
1515

1616
import logging
17-
from typing import Optional
17+
from typing import Literal, Optional
1818

1919
import httpx
20-
from pydantic import BaseModel
20+
from pydantic import BaseModel, Field
21+
from pydantic import field_validator as validator
22+
from pydantic import model_validator
2123
from pydantic_core import to_json
2224
from typing_extensions import cast
2325

@@ -28,15 +30,60 @@
2830

2931

3032
class Guard(BaseModel):
33+
"""
34+
Represents a guard entity with a single string attribute.
35+
36+
Attributes:
37+
guard (str): The input text for guard analysis.
38+
"""
39+
3140
guard: str
3241

3342

3443
class GuardResult(BaseModel):
35-
action: str
36-
reason: str
44+
"""
45+
Represents the result of a guard analysis, specifying the action to take and the reason.
46+
47+
Attributes:
48+
action (Literal["Block", "Allow"]): The action to take based on guard analysis.
49+
Must be either "Block" or "Allow".
50+
reason (str): Explanation for the chosen action. Must be a non-empty string.
51+
"""
52+
53+
action: Literal["Block", "Allow"] = Field(
54+
..., description="Action to take based on " "guard analysis"
55+
)
56+
reason: str = Field(..., min_length=1, description="Explanation for the action")
57+
blocked: bool = Field(
58+
default=False, description="True if action is 'Block', else False"
59+
)
60+
61+
@validator("action")
62+
def validate_action(cls, v):
63+
log.error(f"Validating action: {v}")
64+
if v not in ["Block", "Allow"]:
65+
return "Allow"
66+
return v
67+
68+
@model_validator(mode="before")
69+
def set_blocked(cls, values):
70+
a = values.get("action")
71+
values["blocked"] = a.lower() == "block"
72+
return values
3773

3874

3975
def get_config(config: RailsConfig) -> TrendMicroRailConfig:
76+
"""
77+
Retrieves the TrendMicroRailConfig from the provided RailsConfig object.
78+
79+
Args:
80+
config (RailsConfig): The Rails configuration object containing possible
81+
Trend Micro settings.
82+
83+
Returns:
84+
TrendMicroRailConfig: The Trend Micro configuration, either from the provided
85+
config or a default instance.
86+
"""
4087
if (
4188
not hasattr(config.rails.config, "trend_micro")
4289
or config.rails.config.trend_micro is None
@@ -46,7 +93,12 @@ def get_config(config: RailsConfig) -> TrendMicroRailConfig:
4693
return cast(TrendMicroRailConfig, config.rails.config.trend_micro)
4794

4895

49-
@action(is_system_action=True)
96+
def trend_ai_guard_mapping(result: GuardResult) -> bool:
97+
"""Convert Trend Micro result to boolean for flow logic."""
98+
return result.action.lower() == "block"
99+
100+
101+
@action(is_system_action=True, output_mapping=trend_ai_guard_mapping)
50102
async def trend_ai_guard(config: RailsConfig, text: Optional[str] = None):
51103
"""
52104
Custom action to invoke the Trend Ai Guard
@@ -59,10 +111,11 @@ async def trend_ai_guard(config: RailsConfig, text: Optional[str] = None):
59111

60112
v1_api_key = trend_config.get_api_key()
61113
if not v1_api_key:
62-
raise ValueError("Trend Micro Vision One API Key not found")
63-
64-
if text is None:
65-
raise ValueError("No prompt/response found in the last event.")
114+
log.error("Trend Micro Vision One API Key not found")
115+
return GuardResult(
116+
action="Block",
117+
reason="Trend Micro Vision One API Key not found",
118+
)
66119

67120
async with httpx.AsyncClient() as client:
68121
data = Guard(guard=text).model_dump()
@@ -80,10 +133,10 @@ async def trend_ai_guard(config: RailsConfig, text: Optional[str] = None):
80133
response.raise_for_status()
81134
guard_result = GuardResult(**response.json())
82135
log.debug("Trend Micro AI Guard Result: %s", guard_result)
83-
except Exception as e:
136+
except httpx.HTTPStatusError as e:
84137
log.error("Error calling Trend Micro AI Guard API: %s", e)
85138
return GuardResult(
86-
action="allow",
139+
action="Allow",
87140
reason="An error occurred while calling the Trend Micro AI Guard API.",
88141
)
89142
return guard_result
Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,22 @@
11
# INPUT AND/OR OUTPUT RAIL
2-
flow trend ai guard $text
2+
flow trend ai guard input $text
33
$result = await TrendAiGuardAction(text=$text)
44

5-
if $result.action == "Block" # Fails open if AI Guard service has an error
5+
if $result.blocked # Fails open if AI Guard service has an error
66
if $system.config.enable_rails_exceptions
7-
send TrendAiGuardException(message="Blocked by the 'trend ai guard' flow: " + $result.reason)
7+
send TrendAiGuardRailException(message="Blocked by the 'trend ai guard input' flow: " + $result.reason)
88
else
99
bot refuse to respond
10-
abort
10+
abort
11+
12+
13+
# OUTPUT RAIL
14+
flow trend ai guard output $text
15+
$result = await TrendAiGuardAction(text=$text)
16+
17+
if $result.blocked # Fails open if AI Guard service has an error
18+
if $system.config.enable_rails_exceptions
19+
send TrendAiGuardRailException(message="Blocked by the 'trend ai guard output' flow: " + $result.reason)
20+
else
21+
bot refuse to respond
22+
abort

nemoguardrails/library/trend_micro/flows.v1.co

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,22 +2,22 @@
22
define subflow trend ai guard input
33
$result = execute trend_ai_guard(text=$user_message)
44

5-
if $result.action == "Block" # Fails open if AI Guard service has an error
5+
if $result.blocked # Fails open if AI Guard service has an error
66
if $config.enable_rails_exceptions
77
$msg = "Blocked by the 'trend ai guard input' flow: " + $result.reason
8-
create event TrendAiGuardException(message=$msg)
8+
create event TrendAiGuardRailException(message=$msg)
99
else
1010
bot refuse to respond
11-
stop
11+
stop
1212

1313
# OUTPUT RAIL
1414
define subflow trend ai guard output
1515
$result = execute trend_ai_guard(text=$bot_message)
1616

17-
if $result.action == "Block" # Fails open if AI Guard service has an error
17+
if $result.blocked # Fails open if AI Guard service has an error
1818
if $config.enable_rails_exceptions
1919
$msg = "Blocked by the 'trend ai guard output' flow: " + $result.reason
20-
create event TrendAiGuardException(message=$msg)
20+
create event TrendAiGuardRailException(message=$msg)
2121
else
2222
bot refuse to respond
23-
stop
23+
stop

tests/test_trend_ai_guard.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ def test_trend_ai_guard_blocked(httpx_mock: HTTPXMock, monkeypatch: pytest.Monke
5252
monkeypatch.setenv("V1_API_KEY", "test-token")
5353
httpx_mock.add_response(
5454
is_reusable=True,
55-
json={"action": "Block", "reason": "Prompt Attack Detected"},
55+
json={"action": "Block", "reason": "Prompt Attack Detected", "blocked": True},
5656
)
5757

5858
chat = TestChat(
@@ -86,8 +86,9 @@ def test_trend_ai_guard_error(
8686
@pytest.mark.unit
8787
def test_trend_ai_guard_missing_env_var():
8888
chat = TestChat(input_rail_config, llm_completions=[])
89+
8990
chat >> "Hi!"
90-
chat << "I'm sorry, an internal error has occurred."
91+
chat << "I'm sorry, I can't respond to that."
9192

9293

9394
@pytest.mark.unit
@@ -104,4 +105,4 @@ def test_trend_ai_guard_malformed_response(
104105

105106
# Should fail open
106107
chat >> "What is the air-speed velocity of an unladen swallow?"
107-
chat << "What do you mean? An African or a European swallow?"
108+
chat << "I'm sorry, an internal error has occurred."

0 commit comments

Comments
 (0)