Skip to content

Commit 4aeeb4c

Browse files
committed
make memory optional in conversational agent
Signed-off-by: Jing Zhang <[email protected]>
1 parent e005cab commit 4aeeb4c

File tree

1 file changed

+27
-17
lines changed

1 file changed

+27
-17
lines changed

ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java

+27-17
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
import java.util.concurrent.atomic.AtomicInteger;
3737
import java.util.concurrent.atomic.AtomicReference;
3838

39+
import org.apache.commons.lang3.StringUtils;
3940
import org.apache.commons.text.StringSubstitutor;
4041
import org.opensearch.action.ActionRequest;
4142
import org.opensearch.action.StepListener;
@@ -127,6 +128,11 @@ public void run(MLAgent mlAgent, Map<String, String> params, ActionListener<Obje
127128
String title = params.get(MLAgentExecutor.QUESTION);
128129
int messageHistoryLimit = getMessageHistoryLimit(params);
129130

131+
if (StringUtils.isEmpty(memoryType)) {
132+
runAgent(mlAgent, params, listener, null, null);
133+
return;
134+
}
135+
130136
ConversationIndexMemory.Factory conversationIndexMemoryFactory = (ConversationIndexMemory.Factory) memoryFactoryMap.get(memoryType);
131137
conversationIndexMemoryFactory.create(title, memoryId, appType, ActionListener.<ConversationIndexMemory>wrap(memory -> {
132138
// TODO: call runAgent directly if messageHistoryLimit == 0
@@ -151,8 +157,8 @@ public void run(MLAgent mlAgent, Map<String, String> params, ActionListener<Obje
151157
);
152158
}
153159

154-
StringBuilder chatHistoryBuilder = new StringBuilder();
155160
if (!messageList.isEmpty()) {
161+
StringBuilder chatHistoryBuilder = new StringBuilder();
156162
String chatHistoryPrefix = params.getOrDefault(PROMPT_CHAT_HISTORY_PREFIX, CHAT_HISTORY_PREFIX);
157163
chatHistoryBuilder.append(chatHistoryPrefix);
158164
for (Message message : messageList) {
@@ -220,7 +226,9 @@ private void runReAct(
220226
AtomicReference<String> newPrompt = new AtomicReference<>(tmpSubstitutor.replace(prompt));
221227
tmpParameters.put(PROMPT, newPrompt.get());
222228

223-
List<ModelTensors> traceTensors = createModelTensors(sessionId, parentInteractionId);
229+
List<ModelTensors> traceTensors = (conversationIndexMemory == null)
230+
? new ArrayList<>()
231+
: createModelTensors(sessionId, parentInteractionId);
224232
int maxIterations = Integer.parseInt(tmpParameters.getOrDefault(MAX_ITERATION, "3")) * 2;
225233
for (int i = 0; i < maxIterations; i++) {
226234
int finalI = i;
@@ -401,8 +409,8 @@ private void runReAct(
401409
client.execute(MLPredictionTaskAction.INSTANCE, request, firstListener);
402410
}
403411

404-
private static List<ModelTensors> createFinalAnswerTensors(List<ModelTensors> sessionId, List<ModelTensor> lastThought) {
405-
List<ModelTensors> finalModelTensors = sessionId;
412+
private static List<ModelTensors> createFinalAnswerTensors(List<ModelTensors> modelTensorsList, List<ModelTensor> lastThought) {
413+
List<ModelTensors> finalModelTensors = modelTensorsList;
406414
finalModelTensors.add(ModelTensors.builder().mlModelTensors(lastThought).build());
407415
return finalModelTensors;
408416
}
@@ -572,19 +580,21 @@ private void sendFinalAnswer(
572580
private static List<ModelTensors> createModelTensors(String sessionId, String parentInteractionId) {
573581
List<ModelTensors> cotModelTensors = new ArrayList<>();
574582

575-
cotModelTensors
576-
.add(
577-
ModelTensors
578-
.builder()
579-
.mlModelTensors(
580-
List
581-
.of(
582-
ModelTensor.builder().name(MLAgentExecutor.MEMORY_ID).result(sessionId).build(),
583-
ModelTensor.builder().name(MLAgentExecutor.PARENT_INTERACTION_ID).result(parentInteractionId).build()
584-
)
585-
)
586-
.build()
587-
);
583+
if (!StringUtils.isEmpty(sessionId)) {
584+
cotModelTensors
585+
.add(
586+
ModelTensors
587+
.builder()
588+
.mlModelTensors(
589+
List
590+
.of(
591+
ModelTensor.builder().name(MLAgentExecutor.MEMORY_ID).result(sessionId).build(),
592+
ModelTensor.builder().name(MLAgentExecutor.PARENT_INTERACTION_ID).result(parentInteractionId).build()
593+
)
594+
)
595+
.build()
596+
);
597+
}
588598
return cotModelTensors;
589599
}
590600

0 commit comments

Comments
 (0)