@@ -873,6 +873,19 @@ async def generate_async(
873
873
The completion (when a prompt is provided) or the next message.
874
874
875
875
System messages are not yet supported."""
876
+ # convert options to gen_options of type GenerationOptions
877
+ gen_options : Optional [GenerationOptions ] = None
878
+
879
+ if prompt is None and messages is None :
880
+ raise ValueError ("Either prompt or messages must be provided." )
881
+
882
+ if prompt is not None and messages is not None :
883
+ raise ValueError ("Only one of prompt or messages can be provided." )
884
+
885
+ if prompt is not None :
886
+ # Currently, we transform the prompt request into a single turn conversation
887
+ messages = [{"role" : "user" , "content" : prompt }]
888
+
876
889
# If a state object is specified, then we switch to "generation options" mode.
877
890
# This is because we want the output to be a GenerationResponse which will contain
878
891
# the output state.
@@ -882,15 +895,25 @@ async def generate_async(
882
895
state = json_to_state (state ["state" ])
883
896
884
897
if options is None :
885
- options = GenerationOptions ()
886
-
887
- # We allow options to be specified both as a dict and as an object.
888
- if options and isinstance (options , dict ):
889
- options = GenerationOptions (** options )
898
+ gen_options = GenerationOptions ()
899
+ elif isinstance (options , dict ):
900
+ gen_options = GenerationOptions (** options )
901
+ else :
902
+ gen_options = options
903
+ else :
904
+ # We allow options to be specified both as a dict and as an object.
905
+ if options and isinstance (options , dict ):
906
+ gen_options = GenerationOptions (** options )
907
+ elif isinstance (options , GenerationOptions ):
908
+ gen_options = options
909
+ elif options is None :
910
+ gen_options = None
911
+ else :
912
+ raise TypeError ("options must be a dict or GenerationOptions" )
890
913
891
914
# Save the generation options in the current async context.
892
- # At this point, options is either None or GenerationOptions
893
- generation_options_var .set (options if not isinstance ( options , dict ) else None )
915
+ # At this point, gen_options is either None or GenerationOptions
916
+ generation_options_var .set (gen_options )
894
917
895
918
if streaming_handler :
896
919
streaming_handler_var .set (streaming_handler )
@@ -900,23 +923,14 @@ async def generate_async(
900
923
# requests are made.
901
924
self .explain_info = self ._ensure_explain_info ()
902
925
903
- if prompt is not None :
904
- # Currently, we transform the prompt request into a single turn conversation
905
- messages = [{"role" : "user" , "content" : prompt }]
906
- raw_llm_request .set (prompt )
907
- else :
908
- raw_llm_request .set (messages )
926
+ raw_llm_request .set (messages )
909
927
910
928
# If we have generation options, we also add them to the context
911
- if options :
929
+ if gen_options :
912
930
messages = [
913
931
{
914
932
"role" : "context" ,
915
- "content" : {
916
- "generation_options" : getattr (
917
- options , "dict" , lambda : options
918
- )()
919
- },
933
+ "content" : {"generation_options" : gen_options .model_dump ()},
920
934
}
921
935
] + (messages or [])
922
936
@@ -926,9 +940,8 @@ async def generate_async(
926
940
if (
927
941
messages
928
942
and messages [- 1 ]["role" ] == "assistant"
929
- and options
930
- and hasattr (options , "rails" )
931
- and getattr (getattr (options , "rails" , None ), "dialog" , None ) is False
943
+ and gen_options
944
+ and gen_options .rails .dialog is False
932
945
):
933
946
# We already have the first message with a context update, so we use that
934
947
messages [0 ]["content" ]["bot_message" ] = messages [- 1 ]["content" ]
@@ -945,7 +958,7 @@ async def generate_async(
945
958
processing_log = []
946
959
947
960
# The array of events corresponding to the provided sequence of messages.
948
- events = self ._get_events_for_messages (messages or [] , state )
961
+ events = self ._get_events_for_messages (messages , state ) # type: ignore
949
962
950
963
if self .config .colang_version == "1.0" :
951
964
# If we had a state object, we also need to prepend the events from the state.
@@ -1064,7 +1077,7 @@ async def generate_async(
1064
1077
# If a state object is not used, then we use the implicit caching
1065
1078
if state is None :
1066
1079
# Save the new events in the history and update the cache
1067
- cache_key = get_history_cache_key ((messages or [] ) + [new_message ])
1080
+ cache_key = get_history_cache_key ((messages ) + [new_message ]) # type: ignore
1068
1081
self .events_history_cache [cache_key ] = events
1069
1082
else :
1070
1083
output_state = {"events" : events }
@@ -1092,30 +1105,26 @@ async def generate_async(
1092
1105
# IF tracing is enabled we need to set GenerationLog attrs
1093
1106
original_log_options = None
1094
1107
if self .config .tracing .enabled :
1095
- if options is None :
1096
- options = GenerationOptions ()
1108
+ if gen_options is None :
1109
+ gen_options = GenerationOptions ()
1097
1110
else :
1098
- # create a copy of the options to avoid modifying the original
1099
- if isinstance (options , GenerationOptions ):
1100
- options = options .model_copy (deep = True )
1101
- else :
1102
- # If options is a dict, convert it to GenerationOptions
1103
- options = GenerationOptions (** options )
1104
- original_log_options = options .log .model_copy (deep = True )
1111
+ # create a copy of the gen_options to avoid modifying the original
1112
+ gen_options = gen_options .model_copy (deep = True )
1113
+ original_log_options = gen_options .log .model_copy (deep = True )
1105
1114
1106
1115
# enable log options
1107
1116
# it is aggressive, but these are required for tracing
1108
1117
if (
1109
- not options .log .activated_rails
1110
- or not options .log .llm_calls
1111
- or not options .log .internal_events
1118
+ not gen_options .log .activated_rails
1119
+ or not gen_options .log .llm_calls
1120
+ or not gen_options .log .internal_events
1112
1121
):
1113
- options .log .activated_rails = True
1114
- options .log .llm_calls = True
1115
- options .log .internal_events = True
1122
+ gen_options .log .activated_rails = True
1123
+ gen_options .log .llm_calls = True
1124
+ gen_options .log .internal_events = True
1116
1125
1117
1126
# If we have generation options, we prepare a GenerationResponse instance.
1118
- if options :
1127
+ if gen_options :
1119
1128
# If a prompt was used, we only need to return the content of the message.
1120
1129
if prompt :
1121
1130
res = GenerationResponse (response = new_message ["content" ])
@@ -1136,9 +1145,9 @@ async def generate_async(
1136
1145
1137
1146
if self .config .colang_version == "1.0" :
1138
1147
# If output variables are specified, we extract their values
1139
- if getattr ( options , "output_vars" , None ) :
1148
+ if gen_options and gen_options . output_vars :
1140
1149
context = compute_context (events )
1141
- output_vars = getattr ( options , " output_vars" , None )
1150
+ output_vars = gen_options . output_vars
1142
1151
if isinstance (output_vars , list ):
1143
1152
# If we have only a selection of keys, we filter to only that.
1144
1153
res .output_data = {k : context .get (k ) for k in output_vars }
@@ -1149,65 +1158,64 @@ async def generate_async(
1149
1158
_log = compute_generation_log (processing_log )
1150
1159
1151
1160
# Include information about activated rails and LLM calls if requested
1152
- log_options = getattr ( options , " log" , None )
1161
+ log_options = gen_options . log if gen_options else None
1153
1162
if log_options and (
1154
- getattr (log_options , "activated_rails" , False )
1155
- or getattr (log_options , "llm_calls" , False )
1163
+ log_options .activated_rails or log_options .llm_calls
1156
1164
):
1157
1165
res .log = GenerationLog ()
1158
1166
1159
1167
# We always include the stats
1160
1168
res .log .stats = _log .stats
1161
1169
1162
- if getattr ( log_options , " activated_rails" , False ) :
1170
+ if log_options . activated_rails :
1163
1171
res .log .activated_rails = _log .activated_rails
1164
1172
1165
- if getattr ( log_options , " llm_calls" , False ) :
1173
+ if log_options . llm_calls :
1166
1174
res .log .llm_calls = []
1167
1175
for activated_rail in _log .activated_rails :
1168
1176
for executed_action in activated_rail .executed_actions :
1169
1177
res .log .llm_calls .extend (executed_action .llm_calls )
1170
1178
1171
1179
# Include internal events if requested
1172
- if getattr ( log_options , "internal_events" , False ) :
1180
+ if log_options and log_options . internal_events :
1173
1181
if res .log is None :
1174
1182
res .log = GenerationLog ()
1175
1183
1176
1184
res .log .internal_events = new_events
1177
1185
1178
1186
# Include the Colang history if requested
1179
- if getattr ( log_options , "colang_history" , False ) :
1187
+ if log_options and log_options . colang_history :
1180
1188
if res .log is None :
1181
1189
res .log = GenerationLog ()
1182
1190
1183
1191
res .log .colang_history = get_colang_history (events )
1184
1192
1185
1193
# Include the raw llm output if requested
1186
- if getattr ( options , "llm_output" , False ) :
1194
+ if gen_options and gen_options . llm_output :
1187
1195
# Currently, we include the output from the generation LLM calls.
1188
1196
for activated_rail in _log .activated_rails :
1189
1197
if activated_rail .type == "generation" :
1190
1198
for executed_action in activated_rail .executed_actions :
1191
1199
for llm_call in executed_action .llm_calls :
1192
1200
res .llm_output = llm_call .raw_response
1193
1201
else :
1194
- if getattr ( options , "output_vars" , None ) :
1202
+ if gen_options and gen_options . output_vars :
1195
1203
raise ValueError (
1196
1204
"The `output_vars` option is not supported for Colang 2.0 configurations."
1197
1205
)
1198
1206
1199
- log_options = getattr ( options , " log" , None )
1207
+ log_options = gen_options . log if gen_options else None
1200
1208
if log_options and (
1201
- getattr ( log_options , " activated_rails" , False )
1202
- or getattr ( log_options , " llm_calls" , False )
1203
- or getattr ( log_options , " internal_events" , False )
1204
- or getattr ( log_options , " colang_history" , False )
1209
+ log_options . activated_rails
1210
+ or log_options . llm_calls
1211
+ or log_options . internal_events
1212
+ or log_options . colang_history
1205
1213
):
1206
1214
raise ValueError (
1207
1215
"The `log` option is not supported for Colang 2.0 configurations."
1208
1216
)
1209
1217
1210
- if getattr ( options , "llm_output" , False ) :
1218
+ if gen_options and gen_options . llm_output :
1211
1219
raise ValueError (
1212
1220
"The `llm_output` option is not supported for Colang 2.0 configurations."
1213
1221
)
@@ -1241,25 +1249,21 @@ async def generate_async(
1241
1249
if original_log_options :
1242
1250
if not any (
1243
1251
(
1244
- getattr ( original_log_options , " internal_events" , False ) ,
1245
- getattr ( original_log_options , " activated_rails" , False ) ,
1246
- getattr ( original_log_options , " llm_calls" , False ) ,
1247
- getattr ( original_log_options , " colang_history" , False ) ,
1252
+ original_log_options . internal_events ,
1253
+ original_log_options . activated_rails ,
1254
+ original_log_options . llm_calls ,
1255
+ original_log_options . colang_history ,
1248
1256
)
1249
1257
):
1250
1258
res .log = None
1251
1259
else :
1252
1260
# Ensure res.log exists before setting attributes
1253
1261
if res .log is not None :
1254
- if not getattr (
1255
- original_log_options , "internal_events" , False
1256
- ):
1262
+ if not original_log_options .internal_events :
1257
1263
res .log .internal_events = []
1258
- if not getattr (
1259
- original_log_options , "activated_rails" , False
1260
- ):
1264
+ if not original_log_options .activated_rails :
1261
1265
res .log .activated_rails = []
1262
- if not getattr ( original_log_options , " llm_calls" , False ) :
1266
+ if not original_log_options . llm_calls :
1263
1267
res .log .llm_calls = []
1264
1268
1265
1269
return res
0 commit comments