Commit 0c868c6

review: type hint fixes

fix

1 parent 09f4638 commit 0c868c6

File tree

2 files changed: +74 -69 lines changed

nemoguardrails/rails/llm/llmrails.py

Lines changed: 72 additions & 68 deletions
@@ -873,6 +873,19 @@ async def generate_async(
             The completion (when a prompt is provided) or the next message.

         System messages are not yet supported."""
+        # convert options to gen_options of type GenerationOptions
+        gen_options: Optional[GenerationOptions] = None
+
+        if prompt is None and messages is None:
+            raise ValueError("Either prompt or messages must be provided.")
+
+        if prompt is not None and messages is not None:
+            raise ValueError("Only one of prompt or messages can be provided.")
+
+        if prompt is not None:
+            # Currently, we transform the prompt request into a single turn conversation
+            messages = [{"role": "user", "content": prompt}]
+
         # If a state object is specified, then we switch to "generation options" mode.
         # This is because we want the output to be a GenerationResponse which will contain
         # the output state.
@@ -882,15 +895,25 @@ async def generate_async(
             state = json_to_state(state["state"])

             if options is None:
-                options = GenerationOptions()
-
-            # We allow options to be specified both as a dict and as an object.
-            if options and isinstance(options, dict):
-                options = GenerationOptions(**options)
+                gen_options = GenerationOptions()
+            elif isinstance(options, dict):
+                gen_options = GenerationOptions(**options)
+            else:
+                gen_options = options
+        else:
+            # We allow options to be specified both as a dict and as an object.
+            if options and isinstance(options, dict):
+                gen_options = GenerationOptions(**options)
+            elif isinstance(options, GenerationOptions):
+                gen_options = options
+            elif options is None:
+                gen_options = None
+            else:
+                raise TypeError("options must be a dict or GenerationOptions")

         # Save the generation options in the current async context.
-        # At this point, options is either None or GenerationOptions
-        generation_options_var.set(options if not isinstance(options, dict) else None)
+        # At this point, gen_options is either None or GenerationOptions
+        generation_options_var.set(gen_options)

         if streaming_handler:
             streaming_handler_var.set(streaming_handler)
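
The two branches above implement the same normalization: whatever the caller passed as `options` (a dict, a `GenerationOptions` instance, or `None`) ends up in a single typed local, `gen_options: Optional[GenerationOptions]`. A minimal standalone sketch of that pattern (the helper name `_coerce_options` is hypothetical, not part of this commit; note that when a state object is present, the commit additionally defaults `None` to `GenerationOptions()`):

    from typing import Optional, Union

    from nemoguardrails.rails.llm.options import GenerationOptions


    def _coerce_options(
        options: Union[dict, GenerationOptions, None],
    ) -> Optional[GenerationOptions]:
        """Normalize user-supplied options into Optional[GenerationOptions]."""
        if options is None:
            return None
        if isinstance(options, dict):
            # Pydantic validates the dict fields on construction.
            return GenerationOptions(**options)
        if isinstance(options, GenerationOptions):
            return options
        raise TypeError("options must be a dict or GenerationOptions")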
@@ -900,23 +923,14 @@ async def generate_async(
         # requests are made.
         self.explain_info = self._ensure_explain_info()

-        if prompt is not None:
-            # Currently, we transform the prompt request into a single turn conversation
-            messages = [{"role": "user", "content": prompt}]
-            raw_llm_request.set(prompt)
-        else:
-            raw_llm_request.set(messages)
+        raw_llm_request.set(messages)

         # If we have generation options, we also add them to the context
-        if options:
+        if gen_options:
             messages = [
                 {
                     "role": "context",
-                    "content": {
-                        "generation_options": getattr(
-                            options, "dict", lambda: options
-                        )()
-                    },
+                    "content": {"generation_options": gen_options.model_dump()},
                 }
             ] + (messages or [])

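Replacing the defensive `getattr(options, "dict", lambda: options)()` call with a direct `gen_options.model_dump()` is safe here because the branch is only entered when `gen_options` is a `GenerationOptions` instance; `model_dump()` is the standard Pydantic v2 method for serializing a model to a plain dict. Roughly:

    opts = GenerationOptions()
    context_message = {
        "role": "context",
        # model_dump() yields a plain dict of the model's fields,
        # suitable for embedding in the message list.
        "content": {"generation_options": opts.model_dump()},
    }
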
@@ -926,9 +940,8 @@ async def generate_async(
         if (
             messages
             and messages[-1]["role"] == "assistant"
-            and options
-            and hasattr(options, "rails")
-            and getattr(getattr(options, "rails", None), "dialog", None) is False
+            and gen_options
+            and gen_options.rails.dialog is False
         ):
             # We already have the first message with a context update, so we use that
             messages[0]["content"]["bot_message"] = messages[-1]["content"]
@@ -945,7 +958,7 @@ async def generate_async(
         processing_log = []

         # The array of events corresponding to the provided sequence of messages.
-        events = self._get_events_for_messages(messages or [], state)
+        events = self._get_events_for_messages(messages, state)  # type: ignore

         if self.config.colang_version == "1.0":
             # If we had a state object, we also need to prepend the events from the state.
@@ -1064,7 +1077,7 @@ async def generate_async(
         # If a state object is not used, then we use the implicit caching
         if state is None:
             # Save the new events in the history and update the cache
-            cache_key = get_history_cache_key((messages or []) + [new_message])
+            cache_key = get_history_cache_key((messages) + [new_message])  # type: ignore
             self.events_history_cache[cache_key] = events
         else:
             output_state = {"events": events}
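
The two `# type: ignore` comments above exist because the new up-front validation guarantees `messages` is not `None` by this point, but static checkers cannot propagate that fact this deep into the function. A hypothetical alternative (not what the commit does) would narrow the Optional explicitly so no suppression is needed:

    # Sketch only: raising on None narrows Optional[List[dict]] to List[dict]
    # for checkers such as mypy, making the suppressions unnecessary.
    from typing import List, Optional

    def require_messages(messages: Optional[List[dict]]) -> List[dict]:
        if messages is None:
            raise ValueError("Either prompt or messages must be provided.")
        return messages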
@@ -1092,30 +1105,26 @@ async def generate_async(
         # IF tracing is enabled we need to set GenerationLog attrs
         original_log_options = None
         if self.config.tracing.enabled:
-            if options is None:
-                options = GenerationOptions()
+            if gen_options is None:
+                gen_options = GenerationOptions()
             else:
-                # create a copy of the options to avoid modifying the original
-                if isinstance(options, GenerationOptions):
-                    options = options.model_copy(deep=True)
-                else:
-                    # If options is a dict, convert it to GenerationOptions
-                    options = GenerationOptions(**options)
-            original_log_options = options.log.model_copy(deep=True)
+                # create a copy of the gen_options to avoid modifying the original
+                gen_options = gen_options.model_copy(deep=True)
+            original_log_options = gen_options.log.model_copy(deep=True)

             # enable log options
             # it is aggressive, but these are required for tracing
             if (
-                not options.log.activated_rails
-                or not options.log.llm_calls
-                or not options.log.internal_events
+                not gen_options.log.activated_rails
+                or not gen_options.log.llm_calls
+                or not gen_options.log.internal_events
             ):
-                options.log.activated_rails = True
-                options.log.llm_calls = True
-                options.log.internal_events = True
+                gen_options.log.activated_rails = True
+                gen_options.log.llm_calls = True
+                gen_options.log.internal_events = True

         # If we have generation options, we prepare a GenerationResponse instance.
-        if options:
+        if gen_options:
             # If a prompt was used, we only need to return the content of the message.
             if prompt:
                 res = GenerationResponse(response=new_message["content"])
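
Because tracing force-enables the log flags, the options are now deep-copied before being mutated, so the caller's object is left untouched. A small illustration of the deep-copy semantics (assuming the log flags default to False, as the flag-enabling code above implies):

    opts = GenerationOptions()
    traced = opts.model_copy(deep=True)
    traced.log.llm_calls = True  # mutate the nested model on the copy only
    assert opts.log.llm_calls is False  # the original is unaffected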
@@ -1136,9 +1145,9 @@ async def generate_async(

         if self.config.colang_version == "1.0":
             # If output variables are specified, we extract their values
-            if getattr(options, "output_vars", None):
+            if gen_options and gen_options.output_vars:
                 context = compute_context(events)
-                output_vars = getattr(options, "output_vars", None)
+                output_vars = gen_options.output_vars
                 if isinstance(output_vars, list):
                     # If we have only a selection of keys, we filter to only that.
                     res.output_data = {k: context.get(k) for k in output_vars}
@@ -1149,65 +1158,64 @@ async def generate_async(
                 _log = compute_generation_log(processing_log)

                 # Include information about activated rails and LLM calls if requested
-                log_options = getattr(options, "log", None)
+                log_options = gen_options.log if gen_options else None
                 if log_options and (
-                    getattr(log_options, "activated_rails", False)
-                    or getattr(log_options, "llm_calls", False)
+                    log_options.activated_rails or log_options.llm_calls
                 ):
                     res.log = GenerationLog()

                     # We always include the stats
                     res.log.stats = _log.stats

-                    if getattr(log_options, "activated_rails", False):
+                    if log_options.activated_rails:
                         res.log.activated_rails = _log.activated_rails

-                    if getattr(log_options, "llm_calls", False):
+                    if log_options.llm_calls:
                         res.log.llm_calls = []
                         for activated_rail in _log.activated_rails:
                             for executed_action in activated_rail.executed_actions:
                                 res.log.llm_calls.extend(executed_action.llm_calls)

                 # Include internal events if requested
-                if getattr(log_options, "internal_events", False):
+                if log_options and log_options.internal_events:
                     if res.log is None:
                         res.log = GenerationLog()

                     res.log.internal_events = new_events

                 # Include the Colang history if requested
-                if getattr(log_options, "colang_history", False):
+                if log_options and log_options.colang_history:
                     if res.log is None:
                         res.log = GenerationLog()

                     res.log.colang_history = get_colang_history(events)

                 # Include the raw llm output if requested
-                if getattr(options, "llm_output", False):
+                if gen_options and gen_options.llm_output:
                     # Currently, we include the output from the generation LLM calls.
                     for activated_rail in _log.activated_rails:
                         if activated_rail.type == "generation":
                             for executed_action in activated_rail.executed_actions:
                                 for llm_call in executed_action.llm_calls:
                                     res.llm_output = llm_call.raw_response
             else:
-                if getattr(options, "output_vars", None):
+                if gen_options and gen_options.output_vars:
                     raise ValueError(
                         "The `output_vars` option is not supported for Colang 2.0 configurations."
                     )

-                log_options = getattr(options, "log", None)
+                log_options = gen_options.log if gen_options else None
                 if log_options and (
-                    getattr(log_options, "activated_rails", False)
-                    or getattr(log_options, "llm_calls", False)
-                    or getattr(log_options, "internal_events", False)
-                    or getattr(log_options, "colang_history", False)
+                    log_options.activated_rails
+                    or log_options.llm_calls
+                    or log_options.internal_events
+                    or log_options.colang_history
                 ):
                     raise ValueError(
                         "The `log` option is not supported for Colang 2.0 configurations."
                     )

-                if getattr(options, "llm_output", False):
+                if gen_options and gen_options.llm_output:
                     raise ValueError(
                         "The `llm_output` option is not supported for Colang 2.0 configurations."
                     )
@@ -1241,25 +1249,21 @@ async def generate_async(
         if original_log_options:
             if not any(
                 (
-                    getattr(original_log_options, "internal_events", False),
-                    getattr(original_log_options, "activated_rails", False),
-                    getattr(original_log_options, "llm_calls", False),
-                    getattr(original_log_options, "colang_history", False),
+                    original_log_options.internal_events,
+                    original_log_options.activated_rails,
+                    original_log_options.llm_calls,
+                    original_log_options.colang_history,
                 )
             ):
                 res.log = None
             else:
                 # Ensure res.log exists before setting attributes
                 if res.log is not None:
-                    if not getattr(
-                        original_log_options, "internal_events", False
-                    ):
+                    if not original_log_options.internal_events:
                         res.log.internal_events = []
-                    if not getattr(
-                        original_log_options, "activated_rails", False
-                    ):
+                    if not original_log_options.activated_rails:
                         res.log.activated_rails = []
-                    if not getattr(original_log_options, "llm_calls", False):
+                    if not original_log_options.llm_calls:
                         res.log.llm_calls = []

         return res
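
Taken together, the refactor is internal: `generate_async` still accepts `options` as either a dict or a `GenerationOptions` object. A hedged usage sketch (`rails` stands for an already-configured `LLMRails` instance; illustration only):

    # Equivalent after the internal normalization to gen_options:
    res_a = await rails.generate_async(
        messages=[{"role": "user", "content": "Hello"}],
        options={"log": {"llm_calls": True}},
    )
    res_b = await rails.generate_async(
        messages=[{"role": "user", "content": "Hello"}],
        options=GenerationOptions(log={"llm_calls": True}),
    )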

nemoguardrails/rails/llm/options.py

Lines changed: 2 additions & 1 deletion
@@ -76,6 +76,7 @@
 # {..., log: {"llm_calls": [...]}}

 """
+
 from typing import Any, Dict, List, Optional, Union

 from pydantic import BaseModel, Field, root_validator
@@ -146,7 +147,7 @@ class GenerationOptions(BaseModel):
         default=None,
         description="Additional parameters that should be used for the LLM call",
     )
-    llm_output: Optional[bool] = Field(
+    llm_output: bool = Field(
         default=False,
         description="Whether the response should also include any custom LLM output.",
     )
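
Tightening `llm_output` from `Optional[bool]` to `bool` matches its `default=False`: the field can no longer be `None`, so call sites can test it directly (as the llmrails.py changes above now do) instead of going through `getattr`. For instance:

    opts = GenerationOptions()
    assert opts.llm_output is False  # always a bool now, never None
    if GenerationOptions(llm_output=True).llm_output:
        print("custom LLM output will be included in the response")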
