36
36
import java .util .concurrent .atomic .AtomicInteger ;
37
37
import java .util .concurrent .atomic .AtomicReference ;
38
38
39
+ import org .apache .commons .lang3 .StringUtils ;
39
40
import org .apache .commons .text .StringSubstitutor ;
40
41
import org .opensearch .action .ActionRequest ;
41
42
import org .opensearch .action .StepListener ;
@@ -127,6 +128,11 @@ public void run(MLAgent mlAgent, Map<String, String> params, ActionListener<Obje
127
128
String title = params .get (MLAgentExecutor .QUESTION );
128
129
int messageHistoryLimit = getMessageHistoryLimit (params );
129
130
131
+ if (StringUtils .isEmpty (memoryType )) {
132
+ runAgent (mlAgent , params , listener , null , null );
133
+ return ;
134
+ }
135
+
130
136
ConversationIndexMemory .Factory conversationIndexMemoryFactory = (ConversationIndexMemory .Factory ) memoryFactoryMap .get (memoryType );
131
137
conversationIndexMemoryFactory .create (title , memoryId , appType , ActionListener .<ConversationIndexMemory >wrap (memory -> {
132
138
// TODO: call runAgent directly if messageHistoryLimit == 0
@@ -151,8 +157,8 @@ public void run(MLAgent mlAgent, Map<String, String> params, ActionListener<Obje
151
157
);
152
158
}
153
159
154
- StringBuilder chatHistoryBuilder = new StringBuilder ();
155
160
if (!messageList .isEmpty ()) {
161
+ StringBuilder chatHistoryBuilder = new StringBuilder ();
156
162
String chatHistoryPrefix = params .getOrDefault (PROMPT_CHAT_HISTORY_PREFIX , CHAT_HISTORY_PREFIX );
157
163
chatHistoryBuilder .append (chatHistoryPrefix );
158
164
for (Message message : messageList ) {
@@ -220,7 +226,9 @@ private void runReAct(
220
226
AtomicReference <String > newPrompt = new AtomicReference <>(tmpSubstitutor .replace (prompt ));
221
227
tmpParameters .put (PROMPT , newPrompt .get ());
222
228
223
- List <ModelTensors > traceTensors = createModelTensors (sessionId , parentInteractionId );
229
+ List <ModelTensors > traceTensors = (conversationIndexMemory == null )
230
+ ? new ArrayList <>()
231
+ : createModelTensors (sessionId , parentInteractionId );
224
232
int maxIterations = Integer .parseInt (tmpParameters .getOrDefault (MAX_ITERATION , "3" )) * 2 ;
225
233
for (int i = 0 ; i < maxIterations ; i ++) {
226
234
int finalI = i ;
@@ -401,8 +409,8 @@ private void runReAct(
401
409
client .execute (MLPredictionTaskAction .INSTANCE , request , firstListener );
402
410
}
403
411
404
- private static List <ModelTensors > createFinalAnswerTensors (List <ModelTensors > sessionId , List <ModelTensor > lastThought ) {
405
- List <ModelTensors > finalModelTensors = sessionId ;
412
+ private static List <ModelTensors > createFinalAnswerTensors (List <ModelTensors > modelTensorsList , List <ModelTensor > lastThought ) {
413
+ List <ModelTensors > finalModelTensors = modelTensorsList ;
406
414
finalModelTensors .add (ModelTensors .builder ().mlModelTensors (lastThought ).build ());
407
415
return finalModelTensors ;
408
416
}
@@ -572,19 +580,21 @@ private void sendFinalAnswer(
572
580
private static List <ModelTensors > createModelTensors (String sessionId , String parentInteractionId ) {
573
581
List <ModelTensors > cotModelTensors = new ArrayList <>();
574
582
575
- cotModelTensors
576
- .add (
577
- ModelTensors
578
- .builder ()
579
- .mlModelTensors (
580
- List
581
- .of (
582
- ModelTensor .builder ().name (MLAgentExecutor .MEMORY_ID ).result (sessionId ).build (),
583
- ModelTensor .builder ().name (MLAgentExecutor .PARENT_INTERACTION_ID ).result (parentInteractionId ).build ()
584
- )
585
- )
586
- .build ()
587
- );
583
+ if (!StringUtils .isEmpty (sessionId )) {
584
+ cotModelTensors
585
+ .add (
586
+ ModelTensors
587
+ .builder ()
588
+ .mlModelTensors (
589
+ List
590
+ .of (
591
+ ModelTensor .builder ().name (MLAgentExecutor .MEMORY_ID ).result (sessionId ).build (),
592
+ ModelTensor .builder ().name (MLAgentExecutor .PARENT_INTERACTION_ID ).result (parentInteractionId ).build ()
593
+ )
594
+ )
595
+ .build ()
596
+ );
597
+ }
588
598
return cotModelTensors ;
589
599
}
590
600
0 commit comments