From 4b22adb74d0251782ee7f2470253af17c4006ff8 Mon Sep 17 00:00:00 2001 From: Jing Zhang Date: Tue, 4 Mar 2025 11:26:01 -0800 Subject: [PATCH] make memory optional in conversational agent Signed-off-by: Jing Zhang --- .../algorithms/agent/MLChatAgentRunner.java | 44 ++++++++++++------- 1 file changed, 27 insertions(+), 17 deletions(-) diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java index 3992b9f341..f54a46509d 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java @@ -36,6 +36,7 @@ import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicReference; +import org.apache.commons.lang3.StringUtils; import org.apache.commons.text.StringSubstitutor; import org.opensearch.action.ActionRequest; import org.opensearch.action.StepListener; @@ -127,6 +128,11 @@ public void run(MLAgent mlAgent, Map params, ActionListenerwrap(memory -> { // TODO: call runAgent directly if messageHistoryLimit == 0 @@ -151,8 +157,8 @@ public void run(MLAgent mlAgent, Map params, ActionListener newPrompt = new AtomicReference<>(tmpSubstitutor.replace(prompt)); tmpParameters.put(PROMPT, newPrompt.get()); - List traceTensors = createModelTensors(sessionId, parentInteractionId); + List traceTensors = (conversationIndexMemory == null) + ? new ArrayList<>() + : createModelTensors(sessionId, parentInteractionId); int maxIterations = Integer.parseInt(tmpParameters.getOrDefault(MAX_ITERATION, "3")) * 2; for (int i = 0; i < maxIterations; i++) { int finalI = i; @@ -401,8 +409,8 @@ private void runReAct( client.execute(MLPredictionTaskAction.INSTANCE, request, firstListener); } - private static List createFinalAnswerTensors(List sessionId, List lastThought) { - List finalModelTensors = sessionId; + private static List createFinalAnswerTensors(List modelTensorsList, List lastThought) { + List finalModelTensors = modelTensorsList; finalModelTensors.add(ModelTensors.builder().mlModelTensors(lastThought).build()); return finalModelTensors; } @@ -572,19 +580,21 @@ private void sendFinalAnswer( private static List createModelTensors(String sessionId, String parentInteractionId) { List cotModelTensors = new ArrayList<>(); - cotModelTensors - .add( - ModelTensors - .builder() - .mlModelTensors( - List - .of( - ModelTensor.builder().name(MLAgentExecutor.MEMORY_ID).result(sessionId).build(), - ModelTensor.builder().name(MLAgentExecutor.PARENT_INTERACTION_ID).result(parentInteractionId).build() - ) - ) - .build() - ); + if (!StringUtils.isEmpty(sessionId)) { + cotModelTensors + .add( + ModelTensors + .builder() + .mlModelTensors( + List + .of( + ModelTensor.builder().name(MLAgentExecutor.MEMORY_ID).result(sessionId).build(), + ModelTensor.builder().name(MLAgentExecutor.PARENT_INTERACTION_ID).result(parentInteractionId).build() + ) + ) + .build() + ); + } return cotModelTensors; }