1414# limitations under the License.
1515
1616import logging
17- from typing import Optional
17+ from typing import Literal , Optional
1818
1919import httpx
20- from pydantic import BaseModel
20+ from pydantic import BaseModel , Field
21+ from pydantic import field_validator as validator
22+ from pydantic import model_validator
2123from pydantic_core import to_json
2224from typing_extensions import cast
2325
2830
2931
3032class Guard (BaseModel ):
33+ """
34+ Represents a guard entity with a single string attribute.
35+
36+ Attributes:
37+ guard (str): The input text for guard analysis.
38+ """
39+
3140 guard : str
3241
3342
3443class GuardResult (BaseModel ):
35- action : str
36- reason : str
44+ """
45+ Represents the result of a guard analysis, specifying the action to take and the reason.
46+
47+ Attributes:
48+ action (Literal["Block", "Allow"]): The action to take based on guard analysis.
49+ Must be either "Block" or "Allow".
50+ reason (str): Explanation for the chosen action. Must be a non-empty string.
51+ """
52+
53+ action : Literal ["Block" , "Allow" ] = Field (
54+ ..., description = "Action to take based on " "guard analysis"
55+ )
56+ reason : str = Field (..., min_length = 1 , description = "Explanation for the action" )
57+ blocked : bool = Field (
58+ default = False , description = "True if action is 'Block', else False"
59+ )
60+
61+ @validator ("action" )
62+ def validate_action (cls , v ):
63+ log .error (f"Validating action: { v } " )
64+ if v not in ["Block" , "Allow" ]:
65+ return "Allow"
66+ return v
67+
68+ @model_validator (mode = "before" )
69+ def set_blocked (cls , values ):
70+ a = values .get ("action" )
71+ values ["blocked" ] = a .lower () == "block"
72+ return values
3773
3874
3975def get_config (config : RailsConfig ) -> TrendMicroRailConfig :
76+ """
77+ Retrieves the TrendMicroRailConfig from the provided RailsConfig object.
78+
79+ Args:
80+ config (RailsConfig): The Rails configuration object containing possible
81+ Trend Micro settings.
82+
83+ Returns:
84+ TrendMicroRailConfig: The Trend Micro configuration, either from the provided
85+ config or a default instance.
86+ """
4087 if (
4188 not hasattr (config .rails .config , "trend_micro" )
4289 or config .rails .config .trend_micro is None
@@ -46,7 +93,12 @@ def get_config(config: RailsConfig) -> TrendMicroRailConfig:
4693 return cast (TrendMicroRailConfig , config .rails .config .trend_micro )
4794
4895
49- @action (is_system_action = True )
96+ def trend_ai_guard_mapping (result : GuardResult ) -> bool :
97+ """Convert Trend Micro result to boolean for flow logic."""
98+ return result .action .lower () == "block"
99+
100+
101+ @action (is_system_action = True , output_mapping = trend_ai_guard_mapping )
50102async def trend_ai_guard (config : RailsConfig , text : Optional [str ] = None ):
51103 """
52104 Custom action to invoke the Trend Ai Guard
@@ -59,10 +111,11 @@ async def trend_ai_guard(config: RailsConfig, text: Optional[str] = None):
59111
60112 v1_api_key = trend_config .get_api_key ()
61113 if not v1_api_key :
62- raise ValueError ("Trend Micro Vision One API Key not found" )
63-
64- if text is None :
65- raise ValueError ("No prompt/response found in the last event." )
114+ log .error ("Trend Micro Vision One API Key not found" )
115+ return GuardResult (
116+ action = "Block" ,
117+ reason = "Trend Micro Vision One API Key not found" ,
118+ )
66119
67120 async with httpx .AsyncClient () as client :
68121 data = Guard (guard = text ).model_dump ()
@@ -80,10 +133,10 @@ async def trend_ai_guard(config: RailsConfig, text: Optional[str] = None):
80133 response .raise_for_status ()
81134 guard_result = GuardResult (** response .json ())
82135 log .debug ("Trend Micro AI Guard Result: %s" , guard_result )
83- except Exception as e :
136+ except httpx . HTTPStatusError as e :
84137 log .error ("Error calling Trend Micro AI Guard API: %s" , e )
85138 return GuardResult (
86- action = "allow " ,
139+ action = "Allow " ,
87140 reason = "An error occurred while calling the Trend Micro AI Guard API." ,
88141 )
89142 return guard_result
0 commit comments