36
36
import java .util .concurrent .atomic .AtomicInteger ;
37
37
import java .util .concurrent .atomic .AtomicReference ;
38
38
39
+ import org .apache .commons .lang3 .StringUtils ;
39
40
import org .apache .commons .text .StringSubstitutor ;
40
41
import org .opensearch .action .ActionRequest ;
41
42
import org .opensearch .action .StepListener ;
@@ -121,12 +122,17 @@ public MLChatAgentRunner(
121
122
122
123
@ Override
123
124
public void run (MLAgent mlAgent , Map <String , String > params , ActionListener <Object > listener ) {
124
- String memoryType = mlAgent .getMemory ().getType ();
125
+ String memoryType = mlAgent .getMemory () == null ? null : mlAgent . getMemory () .getType ();
125
126
String memoryId = params .get (MLAgentExecutor .MEMORY_ID );
126
127
String appType = mlAgent .getAppType ();
127
128
String title = params .get (MLAgentExecutor .QUESTION );
128
129
int messageHistoryLimit = getMessageHistoryLimit (params );
129
130
131
+ if (StringUtils .isEmpty (memoryType )) {
132
+ runAgent (mlAgent , params , listener , null , null );
133
+ return ;
134
+ }
135
+
130
136
ConversationIndexMemory .Factory conversationIndexMemoryFactory = (ConversationIndexMemory .Factory ) memoryFactoryMap .get (memoryType );
131
137
conversationIndexMemoryFactory .create (title , memoryId , appType , ActionListener .<ConversationIndexMemory >wrap (memory -> {
132
138
// TODO: call runAgent directly if messageHistoryLimit == 0
@@ -151,8 +157,8 @@ public void run(MLAgent mlAgent, Map<String, String> params, ActionListener<Obje
151
157
);
152
158
}
153
159
154
- StringBuilder chatHistoryBuilder = new StringBuilder ();
155
- if ( messageList . size () > 0 ) {
160
+ if (! messageList . isEmpty ()) {
161
+ StringBuilder chatHistoryBuilder = new StringBuilder ();
156
162
String chatHistoryPrefix = params .getOrDefault (PROMPT_CHAT_HISTORY_PREFIX , CHAT_HISTORY_PREFIX );
157
163
chatHistoryBuilder .append (chatHistoryPrefix );
158
164
for (Message message : messageList ) {
@@ -219,7 +225,9 @@ private void runReAct(
219
225
AtomicReference <String > newPrompt = new AtomicReference <>(tmpSubstitutor .replace (prompt ));
220
226
tmpParameters .put (PROMPT , newPrompt .get ());
221
227
222
- List <ModelTensors > traceTensors = createModelTensors (sessionId , parentInteractionId );
228
+ List <ModelTensors > traceTensors = (conversationIndexMemory == null )
229
+ ? new ArrayList <>()
230
+ : createModelTensors (sessionId , parentInteractionId );
223
231
int maxIterations = Integer .parseInt (tmpParameters .getOrDefault (MAX_ITERATION , "3" )) * 2 ;
224
232
for (int i = 0 ; i < maxIterations ; i ++) {
225
233
int finalI = i ;
@@ -396,8 +404,8 @@ private void runReAct(
396
404
client .execute (MLPredictionTaskAction .INSTANCE , request , firstListener );
397
405
}
398
406
399
- private static List <ModelTensors > createFinalAnswerTensors (List <ModelTensors > sessionId , List <ModelTensor > lastThought ) {
400
- List <ModelTensors > finalModelTensors = sessionId ;
407
+ private static List <ModelTensors > createFinalAnswerTensors (List <ModelTensors > modelTensorsList , List <ModelTensor > lastThought ) {
408
+ List <ModelTensors > finalModelTensors = modelTensorsList ;
401
409
finalModelTensors .add (ModelTensors .builder ().mlModelTensors (lastThought ).build ());
402
410
return finalModelTensors ;
403
411
}
@@ -567,19 +575,21 @@ private void sendFinalAnswer(
567
575
private static List <ModelTensors > createModelTensors (String sessionId , String parentInteractionId ) {
568
576
List <ModelTensors > cotModelTensors = new ArrayList <>();
569
577
570
- cotModelTensors
571
- .add (
572
- ModelTensors
573
- .builder ()
574
- .mlModelTensors (
575
- List
576
- .of (
577
- ModelTensor .builder ().name (MLAgentExecutor .MEMORY_ID ).result (sessionId ).build (),
578
- ModelTensor .builder ().name (MLAgentExecutor .PARENT_INTERACTION_ID ).result (parentInteractionId ).build ()
579
- )
580
- )
581
- .build ()
582
- );
578
+ if (!StringUtils .isEmpty (sessionId )) {
579
+ cotModelTensors
580
+ .add (
581
+ ModelTensors
582
+ .builder ()
583
+ .mlModelTensors (
584
+ List
585
+ .of (
586
+ ModelTensor .builder ().name (MLAgentExecutor .MEMORY_ID ).result (sessionId ).build (),
587
+ ModelTensor .builder ().name (MLAgentExecutor .PARENT_INTERACTION_ID ).result (parentInteractionId ).build ()
588
+ )
589
+ )
590
+ .build ()
591
+ );
592
+ }
583
593
return cotModelTensors ;
584
594
}
585
595
0 commit comments