Skip to content

Commit ebdda57

Browse files
committed
make memory optional in conversational agent
Signed-off-by: Jing Zhang <[email protected]>
1 parent 6ea4772 commit ebdda57

File tree

2 files changed

+82
-19
lines changed

2 files changed

+82
-19
lines changed

ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java

+29-19
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
import java.util.concurrent.atomic.AtomicInteger;
3737
import java.util.concurrent.atomic.AtomicReference;
3838

39+
import org.apache.commons.lang3.StringUtils;
3940
import org.apache.commons.text.StringSubstitutor;
4041
import org.opensearch.action.ActionRequest;
4142
import org.opensearch.action.StepListener;
@@ -121,12 +122,17 @@ public MLChatAgentRunner(
121122

122123
@Override
123124
public void run(MLAgent mlAgent, Map<String, String> params, ActionListener<Object> listener) {
124-
String memoryType = mlAgent.getMemory().getType();
125+
String memoryType = mlAgent.getMemory() == null ? null : mlAgent.getMemory().getType();
125126
String memoryId = params.get(MLAgentExecutor.MEMORY_ID);
126127
String appType = mlAgent.getAppType();
127128
String title = params.get(MLAgentExecutor.QUESTION);
128129
int messageHistoryLimit = getMessageHistoryLimit(params);
129130

131+
if (StringUtils.isEmpty(memoryType)) {
132+
runAgent(mlAgent, params, listener, null, null);
133+
return;
134+
}
135+
130136
ConversationIndexMemory.Factory conversationIndexMemoryFactory = (ConversationIndexMemory.Factory) memoryFactoryMap.get(memoryType);
131137
conversationIndexMemoryFactory.create(title, memoryId, appType, ActionListener.<ConversationIndexMemory>wrap(memory -> {
132138
// TODO: call runAgent directly if messageHistoryLimit == 0
@@ -151,8 +157,8 @@ public void run(MLAgent mlAgent, Map<String, String> params, ActionListener<Obje
151157
);
152158
}
153159

154-
StringBuilder chatHistoryBuilder = new StringBuilder();
155-
if (messageList.size() > 0) {
160+
if (!messageList.isEmpty()) {
161+
StringBuilder chatHistoryBuilder = new StringBuilder();
156162
String chatHistoryPrefix = params.getOrDefault(PROMPT_CHAT_HISTORY_PREFIX, CHAT_HISTORY_PREFIX);
157163
chatHistoryBuilder.append(chatHistoryPrefix);
158164
for (Message message : messageList) {
@@ -219,7 +225,9 @@ private void runReAct(
219225
AtomicReference<String> newPrompt = new AtomicReference<>(tmpSubstitutor.replace(prompt));
220226
tmpParameters.put(PROMPT, newPrompt.get());
221227

222-
List<ModelTensors> traceTensors = createModelTensors(sessionId, parentInteractionId);
228+
List<ModelTensors> traceTensors = (conversationIndexMemory == null)
229+
? new ArrayList<>()
230+
: createModelTensors(sessionId, parentInteractionId);
223231
int maxIterations = Integer.parseInt(tmpParameters.getOrDefault(MAX_ITERATION, "3")) * 2;
224232
for (int i = 0; i < maxIterations; i++) {
225233
int finalI = i;
@@ -396,8 +404,8 @@ private void runReAct(
396404
client.execute(MLPredictionTaskAction.INSTANCE, request, firstListener);
397405
}
398406

399-
private static List<ModelTensors> createFinalAnswerTensors(List<ModelTensors> sessionId, List<ModelTensor> lastThought) {
400-
List<ModelTensors> finalModelTensors = sessionId;
407+
private static List<ModelTensors> createFinalAnswerTensors(List<ModelTensors> modelTensorsList, List<ModelTensor> lastThought) {
408+
List<ModelTensors> finalModelTensors = modelTensorsList;
401409
finalModelTensors.add(ModelTensors.builder().mlModelTensors(lastThought).build());
402410
return finalModelTensors;
403411
}
@@ -567,19 +575,21 @@ private void sendFinalAnswer(
567575
private static List<ModelTensors> createModelTensors(String sessionId, String parentInteractionId) {
568576
List<ModelTensors> cotModelTensors = new ArrayList<>();
569577

570-
cotModelTensors
571-
.add(
572-
ModelTensors
573-
.builder()
574-
.mlModelTensors(
575-
List
576-
.of(
577-
ModelTensor.builder().name(MLAgentExecutor.MEMORY_ID).result(sessionId).build(),
578-
ModelTensor.builder().name(MLAgentExecutor.PARENT_INTERACTION_ID).result(parentInteractionId).build()
579-
)
580-
)
581-
.build()
582-
);
578+
if (!StringUtils.isEmpty(sessionId)) {
579+
cotModelTensors
580+
.add(
581+
ModelTensors
582+
.builder()
583+
.mlModelTensors(
584+
List
585+
.of(
586+
ModelTensor.builder().name(MLAgentExecutor.MEMORY_ID).result(sessionId).build(),
587+
ModelTensor.builder().name(MLAgentExecutor.PARENT_INTERACTION_ID).result(parentInteractionId).build()
588+
)
589+
)
590+
.build()
591+
);
592+
}
583593
return cotModelTensors;
584594
}
585595

ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunnerTest.java

+53
Original file line numberDiff line numberDiff line change
@@ -879,6 +879,42 @@ public void testToolExecutionWithChatHistoryParameter() {
879879
Assert.assertTrue(toolParamsCapture.getValue().containsKey(MLChatAgentRunner.CHAT_HISTORY));
880880
}
881881

882+
@Test
883+
public void testParsingJsonBlockFromResponseNoMemory() {
884+
// Prepare the response with JSON block
885+
String jsonBlock = "{\"thought\":\"parsed thought\", \"action\":\"parsed action\", "
886+
+ "\"action_input\":\"parsed action input\", \"final_answer\":\"parsed final answer\"}";
887+
String responseWithJsonBlock = "Some text```json" + jsonBlock + "```More text";
888+
889+
// Mock LLM response to not contain "thought" but contain "response" with JSON block
890+
Map<String, String> llmResponse = new HashMap<>();
891+
llmResponse.put("response", responseWithJsonBlock);
892+
doAnswer(getLLMAnswer(llmResponse))
893+
.when(client)
894+
.execute(any(ActionType.class), any(ActionRequest.class), isA(ActionListener.class));
895+
896+
// Create an MLAgent and run the MLChatAgentRunner
897+
MLAgent mlAgent = createMLAgentNoMemory();
898+
Map<String, String> params = new HashMap<>();
899+
params.put("verbose", "true");
900+
mlChatAgentRunner.run(mlAgent, params, agentActionListener);
901+
902+
// Capture the response passed to the listener
903+
ArgumentCaptor<Object> responseCaptor = ArgumentCaptor.forClass(Object.class);
904+
verify(agentActionListener).onResponse(responseCaptor.capture());
905+
906+
// Extract the captured response
907+
Object capturedResponse = responseCaptor.getValue();
908+
assertTrue(capturedResponse instanceof ModelTensorOutput);
909+
ModelTensorOutput modelTensorOutput = (ModelTensorOutput) capturedResponse;
910+
911+
ModelTensor modelTensor = modelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors().get(0);
912+
913+
assertEquals(1, modelTensorOutput.getMlModelOutputs().size());
914+
assertEquals(1, modelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors().size());
915+
assertEquals("parsed final answer", modelTensor.getResult());
916+
}
917+
882918
// Helper methods to create MLAgent and parameters
883919
private MLAgent createMLAgentWithTools() {
884920
LLMSpec llmSpec = LLMSpec.builder().modelId("MODEL_ID").build();
@@ -917,6 +953,23 @@ private MLAgent createMLAgentWithToolsConfig(Map<String, String> configMap) {
917953
.build();
918954
}
919955

956+
private MLAgent createMLAgentNoMemory() {
957+
LLMSpec llmSpec = LLMSpec.builder().modelId("MODEL_ID").build();
958+
MLToolSpec firstToolSpec = MLToolSpec
959+
.builder()
960+
.name(FIRST_TOOL)
961+
.type(FIRST_TOOL)
962+
.parameters(ImmutableMap.of("key1", "value1", "key2", "value2"))
963+
.build();
964+
return MLAgent
965+
.builder()
966+
.name("TestAgent")
967+
.type(MLAgentType.CONVERSATIONAL.name())
968+
.tools(Arrays.asList(firstToolSpec))
969+
.llm(llmSpec)
970+
.build();
971+
}
972+
920973
private Map<String, String> createAgentParamsWithAction(String action, String actionInput) {
921974
Map<String, String> params = new HashMap<>();
922975
params.put("action", action);

0 commit comments

Comments
 (0)