@@ -29,6 +29,7 @@
import com.google.genai.types.FunctionDeclaration;
import com.google.genai.types.FunctionResponse;
import com.google.genai.types.GenerateContentConfig;
import com.google.genai.types.GenerateContentResponseUsageMetadata;
import com.google.genai.types.Part;
import com.google.genai.types.Schema;
import com.google.genai.types.ToolConfig;
@@ -51,6 +52,7 @@
import dev.langchain4j.data.pdf.PdfFile;
import dev.langchain4j.data.video.Video;
import dev.langchain4j.exception.UnsupportedFeatureException;
import dev.langchain4j.model.TokenCountEstimator;
import dev.langchain4j.model.chat.ChatModel;
import dev.langchain4j.model.chat.StreamingChatModel;
import dev.langchain4j.model.chat.request.ChatRequest;
@@ -64,6 +66,7 @@
import dev.langchain4j.model.chat.request.json.JsonStringSchema;
import dev.langchain4j.model.chat.response.ChatResponse;
import dev.langchain4j.model.chat.response.StreamingChatResponseHandler;
import dev.langchain4j.model.output.TokenUsage;
import io.reactivex.rxjava3.core.BackpressureStrategy;
import io.reactivex.rxjava3.core.Flowable;
import java.util.ArrayList;
@@ -83,24 +86,109 @@ public class LangChain4j extends BaseLlm {
private final ChatModel chatModel;
private final StreamingChatModel streamingChatModel;
private final ObjectMapper objectMapper;
private final TokenCountEstimator tokenCountEstimator;

public static Builder builder() {
return new Builder();
}

public static class Builder {
private ChatModel chatModel;
private StreamingChatModel streamingChatModel;
private String modelName;
private TokenCountEstimator tokenCountEstimator;

private Builder() {}

public Builder chatModel(ChatModel chatModel) {
this.chatModel = chatModel;
return this;
}

public Builder streamingChatModel(StreamingChatModel streamingChatModel) {
this.streamingChatModel = streamingChatModel;
return this;
}

public Builder modelName(String modelName) {
this.modelName = modelName;
return this;
}

public Builder tokenCountEstimator(TokenCountEstimator tokenCountEstimator) {
this.tokenCountEstimator = tokenCountEstimator;
return this;
}

public LangChain4j build() {
if (chatModel == null && streamingChatModel == null) {
throw new IllegalStateException(
"At least one of chatModel or streamingChatModel must be provided");
}

String effectiveModelName = modelName;
if (effectiveModelName == null) {
if (chatModel != null) {
effectiveModelName = chatModel.defaultRequestParameters().modelName();
} else {
effectiveModelName = streamingChatModel.defaultRequestParameters().modelName();
}
}

if (effectiveModelName == null) {
throw new IllegalStateException("Model name cannot be null");
}

return new LangChain4j(
chatModel, streamingChatModel, effectiveModelName, tokenCountEstimator);
}
}

private LangChain4j(
ChatModel chatModel,
StreamingChatModel streamingChatModel,
String modelName,
TokenCountEstimator tokenCountEstimator) {
super(Objects.requireNonNull(modelName, "model name cannot be null"));
this.chatModel = chatModel;
this.streamingChatModel = streamingChatModel;
this.objectMapper = new ObjectMapper();
this.tokenCountEstimator = tokenCountEstimator;
}

public LangChain4j(ChatModel chatModel) {
this(chatModel, (TokenCountEstimator) null);
}

public LangChain4j(ChatModel chatModel, TokenCountEstimator tokenCountEstimator) {
super(
Objects.requireNonNull(
chatModel.defaultRequestParameters().modelName(), "chat model name cannot be null"));
this.chatModel = Objects.requireNonNull(chatModel, "chatModel cannot be null");
this.streamingChatModel = null;
this.objectMapper = new ObjectMapper();
this.tokenCountEstimator = tokenCountEstimator;
}

public LangChain4j(ChatModel chatModel, String modelName) {
this(chatModel, modelName, (TokenCountEstimator) null);
}

public LangChain4j(
ChatModel chatModel, String modelName, TokenCountEstimator tokenCountEstimator) {
super(Objects.requireNonNull(modelName, "chat model name cannot be null"));
this.chatModel = Objects.requireNonNull(chatModel, "chatModel cannot be null");
this.streamingChatModel = null;
this.objectMapper = new ObjectMapper();
this.tokenCountEstimator = tokenCountEstimator;
}

public LangChain4j(StreamingChatModel streamingChatModel) {
this(streamingChatModel, (TokenCountEstimator) null);
}

public LangChain4j(
StreamingChatModel streamingChatModel, TokenCountEstimator tokenCountEstimator) {
super(
Objects.requireNonNull(
streamingChatModel.defaultRequestParameters().modelName(),
@@ -109,22 +197,23 @@ public LangChain4j(StreamingChatModel streamingChatModel) {
this.streamingChatModel =
Objects.requireNonNull(streamingChatModel, "streamingChatModel cannot be null");
this.objectMapper = new ObjectMapper();
this.tokenCountEstimator = tokenCountEstimator;
}

public LangChain4j(StreamingChatModel streamingChatModel, String modelName) {
- super(Objects.requireNonNull(modelName, "streaming chat model name cannot be null"));
- this.chatModel = null;
- this.streamingChatModel =
- Objects.requireNonNull(streamingChatModel, "streamingChatModel cannot be null");
- this.objectMapper = new ObjectMapper();
+ this(streamingChatModel, modelName, (TokenCountEstimator) null);
}

- public LangChain4j(ChatModel chatModel, StreamingChatModel streamingChatModel, String modelName) {
- super(Objects.requireNonNull(modelName, "model name cannot be null"));
- this.chatModel = Objects.requireNonNull(chatModel, "chatModel cannot be null");
+ public LangChain4j(
+ StreamingChatModel streamingChatModel,
+ String modelName,
+ TokenCountEstimator tokenCountEstimator) {
+ super(Objects.requireNonNull(modelName, "streaming chat model name cannot be null"));
+ this.chatModel = null;
this.streamingChatModel =
Objects.requireNonNull(streamingChatModel, "streamingChatModel cannot be null");
this.objectMapper = new ObjectMapper();
+ this.tokenCountEstimator = tokenCountEstimator;
}

@Override
@@ -185,7 +274,7 @@ public void onError(Throwable throwable) {

ChatRequest chatRequest = toChatRequest(llmRequest);
ChatResponse chatResponse = chatModel.chat(chatRequest);
- LlmResponse llmResponse = toLlmResponse(chatResponse);
+ LlmResponse llmResponse = toLlmResponse(chatResponse, chatRequest);

return Flowable.just(llmResponse);
}
@@ -496,11 +585,38 @@ private JsonSchemaElement toJsonSchemaElement(Schema schema) {
}
}

- private LlmResponse toLlmResponse(ChatResponse chatResponse) {
+ private LlmResponse toLlmResponse(ChatResponse chatResponse, ChatRequest chatRequest) {
Content content =
Content.builder().role("model").parts(toParts(chatResponse.aiMessage())).build();

- return LlmResponse.builder().content(content).build();
+ LlmResponse.Builder builder = LlmResponse.builder().content(content);
TokenUsage tokenUsage = chatResponse.tokenUsage();
if (tokenCountEstimator != null) {
try {
int estimatedInput =
tokenCountEstimator.estimateTokenCountInMessages(chatRequest.messages());
int estimatedOutput =
tokenCountEstimator.estimateTokenCountInText(chatResponse.aiMessage().text());
int estimatedTotal = estimatedInput + estimatedOutput;
builder.usageMetadata(
GenerateContentResponseUsageMetadata.builder()
.promptTokenCount(estimatedInput)
.candidatesTokenCount(estimatedOutput)
.totalTokenCount(estimatedTotal)
.build());
} catch (Exception e) {
e.printStackTrace();
}
} else if (tokenUsage != null) {
builder.usageMetadata(
GenerateContentResponseUsageMetadata.builder()
.promptTokenCount(tokenUsage.inputTokenCount())
.candidatesTokenCount(tokenUsage.outputTokenCount())
.totalTokenCount(tokenUsage.totalTokenCount())
.build());
}

return builder.build();
}

private List<Part> toParts(AiMessage aiMessage) {
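For illustration, a minimal sketch of wiring the new builder with an estimator so usage metadata is populated even when the provider omits token counts. OpenAiChatModel and OpenAiTokenCountEstimator are assumed to come from the langchain4j-open-ai module (not part of this PR); any ChatModel / TokenCountEstimator pair works, and the ADK import paths are left to your project setup.

// Hypothetical wiring sketch; not part of this diff.
// Assumes langchain4j-open-ai on the classpath, plus the ADK classes
// touched above (LangChain4j, LlmRequest, LlmResponse).
import com.google.genai.types.Content;
import com.google.genai.types.Part;
import dev.langchain4j.model.openai.OpenAiChatModel;
import dev.langchain4j.model.openai.OpenAiTokenCountEstimator;
import java.util.List;

public class TokenUsageExample {
  public static void main(String[] args) {
    OpenAiChatModel chatModel =
        OpenAiChatModel.builder()
            .apiKey(System.getenv("OPENAI_API_KEY"))
            .modelName("gpt-4o-mini")
            .build();

    LangChain4j llm =
        LangChain4j.builder()
            .chatModel(chatModel)
            .modelName("gpt-4o-mini")
            // The estimator takes priority over provider-reported TokenUsage
            // (see the tests below).
            .tokenCountEstimator(new OpenAiTokenCountEstimator("gpt-4o-mini"))
            .build();

    LlmRequest request =
        LlmRequest.builder()
            .contents(List.of(Content.fromParts(Part.fromText("Hello"))))
            .build();

    LlmResponse response = llm.generateContent(request, false).blockingFirst();
    // usageMetadata is populated from the estimator.
    response.usageMetadata().ifPresent(usage -> System.out.println(usage.totalTokenCount()));
  }
}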
@@ -26,13 +26,15 @@
import dev.langchain4j.agent.tool.ToolExecutionRequest;
import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.data.message.UserMessage;
import dev.langchain4j.model.TokenCountEstimator;
import dev.langchain4j.model.chat.ChatModel;
import dev.langchain4j.model.chat.StreamingChatModel;
import dev.langchain4j.model.chat.request.ChatRequest;
import dev.langchain4j.model.chat.request.json.JsonObjectSchema;
import dev.langchain4j.model.chat.request.json.JsonStringSchema;
import dev.langchain4j.model.chat.response.ChatResponse;
import dev.langchain4j.model.chat.response.StreamingChatResponseHandler;
import dev.langchain4j.model.output.TokenUsage;
import io.reactivex.rxjava3.core.Flowable;
import java.util.ArrayList;
import java.util.List;
@@ -688,4 +690,140 @@ void testGenerateContentWithStructuredResponseJsonSchema() {
final UserMessage userMessage = (UserMessage) capturedRequest.messages().get(0);
assertThat(userMessage.singleText()).isEqualTo("Give me information about John Doe");
}

@Test
@DisplayName(
"Should use TokenCountEstimator to estimate token usage when TokenUsage is not available")
void testTokenCountEstimatorFallback() {
// Given
// Create a mock TokenCountEstimator
final TokenCountEstimator tokenCountEstimator = mock(TokenCountEstimator.class);
when(tokenCountEstimator.estimateTokenCountInMessages(any())).thenReturn(50); // Input tokens
when(tokenCountEstimator.estimateTokenCountInText(any())).thenReturn(20); // Output tokens

// Create LangChain4j with the TokenCountEstimator using Builder
final LangChain4j langChain4jWithEstimator =
LangChain4j.builder()
.chatModel(chatModel)
.modelName(MODEL_NAME)
.tokenCountEstimator(tokenCountEstimator)
.build();

// Create a LlmRequest
final LlmRequest llmRequest =
LlmRequest.builder()
.contents(List.of(Content.fromParts(Part.fromText("What is the weather today?"))))
.build();

// Mock ChatResponse WITHOUT TokenUsage (simulating when LLM doesn't provide token counts)
final ChatResponse chatResponse = mock(ChatResponse.class);
final AiMessage aiMessage = AiMessage.from("The weather is sunny today.");
when(chatResponse.aiMessage()).thenReturn(aiMessage);
when(chatResponse.tokenUsage()).thenReturn(null); // No token usage from LLM
when(chatModel.chat(any(ChatRequest.class))).thenReturn(chatResponse);

// When
final LlmResponse response =
langChain4jWithEstimator.generateContent(llmRequest, false).blockingFirst();

// Then
// Verify the response has usage metadata estimated by TokenCountEstimator
assertThat(response).isNotNull();
assertThat(response.content()).isPresent();
assertThat(response.content().get().text()).isEqualTo("The weather is sunny today.");

// IMPORTANT: Verify that token usage was estimated via the TokenCountEstimator
assertThat(response.usageMetadata()).isPresent();
final GenerateContentResponseUsageMetadata usageMetadata = response.usageMetadata().get();
assertThat(usageMetadata.promptTokenCount()).isEqualTo(Optional.of(50)); // From estimator
assertThat(usageMetadata.candidatesTokenCount()).isEqualTo(Optional.of(20)); // From estimator
assertThat(usageMetadata.totalTokenCount()).isEqualTo(Optional.of(70)); // 50 + 20

// Verify the estimator was actually called
verify(tokenCountEstimator).estimateTokenCountInMessages(any());
verify(tokenCountEstimator).estimateTokenCountInText("The weather is sunny today.");
}

@Test
@DisplayName("Should prioritize TokenCountEstimator over TokenUsage when estimator is provided")
void testTokenCountEstimatorPriority() {
// Given
// Create a mock TokenCountEstimator
final TokenCountEstimator tokenCountEstimator = mock(TokenCountEstimator.class);
when(tokenCountEstimator.estimateTokenCountInMessages(any())).thenReturn(100); // From estimator
when(tokenCountEstimator.estimateTokenCountInText(any())).thenReturn(50); // From estimator

// Create LangChain4j with the TokenCountEstimator using Builder
final LangChain4j langChain4jWithEstimator =
LangChain4j.builder()
.chatModel(chatModel)
.modelName(MODEL_NAME)
.tokenCountEstimator(tokenCountEstimator)
.build();

// Create a LlmRequest
final LlmRequest llmRequest =
LlmRequest.builder()
.contents(List.of(Content.fromParts(Part.fromText("What is the weather today?"))))
.build();

// Mock ChatResponse WITH actual TokenUsage from the LLM
final ChatResponse chatResponse = mock(ChatResponse.class);
final AiMessage aiMessage = AiMessage.from("The weather is sunny today.");
final TokenUsage actualTokenUsage = new TokenUsage(30, 15, 45); // Actual token counts from LLM
when(chatResponse.aiMessage()).thenReturn(aiMessage);
when(chatResponse.tokenUsage()).thenReturn(actualTokenUsage); // LLM provides token usage
when(chatModel.chat(any(ChatRequest.class))).thenReturn(chatResponse);

// When
final LlmResponse response =
langChain4jWithEstimator.generateContent(llmRequest, false).blockingFirst();

// Then
// IMPORTANT: When TokenCountEstimator is present, it takes priority over TokenUsage
assertThat(response).isNotNull();
assertThat(response.usageMetadata()).isPresent();
final GenerateContentResponseUsageMetadata usageMetadata = response.usageMetadata().get();
assertThat(usageMetadata.promptTokenCount()).isEqualTo(Optional.of(100)); // From estimator
assertThat(usageMetadata.candidatesTokenCount()).isEqualTo(Optional.of(50)); // From estimator
assertThat(usageMetadata.totalTokenCount()).isEqualTo(Optional.of(150)); // 100 + 50

// Verify the estimator was called (it takes priority)
verify(tokenCountEstimator).estimateTokenCountInMessages(any());
verify(tokenCountEstimator).estimateTokenCountInText("The weather is sunny today.");
}

@Test
@DisplayName("Should not include usageMetadata when TokenUsage is null and no estimator provided")
void testNoUsageMetadataWithoutEstimator() {
// Given
// Create LangChain4j WITHOUT TokenCountEstimator (default behavior)
final LangChain4j langChain4jNoEstimator = new LangChain4j(chatModel, MODEL_NAME);

// Create a LlmRequest
final LlmRequest llmRequest =
LlmRequest.builder()
.contents(List.of(Content.fromParts(Part.fromText("Hello, world!"))))
.build();

// Mock ChatResponse WITHOUT TokenUsage
final ChatResponse chatResponse = mock(ChatResponse.class);
final AiMessage aiMessage = AiMessage.from("Hello! How can I help you?");
when(chatResponse.aiMessage()).thenReturn(aiMessage);
when(chatResponse.tokenUsage()).thenReturn(null); // No token usage from LLM
when(chatModel.chat(any(ChatRequest.class))).thenReturn(chatResponse);

// When
final LlmResponse response =
langChain4jNoEstimator.generateContent(llmRequest, false).blockingFirst();

// Then
// Verify the response does NOT have usage metadata
assertThat(response).isNotNull();
assertThat(response.content()).isPresent();
assertThat(response.content().get().text()).isEqualTo("Hello! How can I help you?");

// IMPORTANT: usageMetadata should be empty when no TokenUsage and no estimator
assertThat(response.usageMetadata()).isEmpty();
}
}
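Where no provider-specific estimator is available, a naive heuristic is enough to exercise the fallback path tested above. A rough sketch, assuming the three-method TokenCountEstimator interface of recent LangChain4j releases; this integration only calls estimateTokenCountInText and estimateTokenCountInMessages, and estimateTokenCountInMessage is included only to satisfy the assumed interface.

import dev.langchain4j.data.message.ChatMessage;
import dev.langchain4j.model.TokenCountEstimator;

// Naive ~4-characters-per-token heuristic: adequate for smoke tests,
// not for billing or context-window accounting.
class NaiveTokenCountEstimator implements TokenCountEstimator {
  @Override
  public int estimateTokenCountInText(String text) {
    return text == null ? 0 : Math.max(1, text.length() / 4);
  }

  @Override
  public int estimateTokenCountInMessage(ChatMessage message) {
    return estimateTokenCountInText(message.toString());
  }

  @Override
  public int estimateTokenCountInMessages(Iterable<ChatMessage> messages) {
    int total = 0;
    for (ChatMessage message : messages) {
      total += estimateTokenCountInMessage(message);
    }
    return total;
  }
}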
23 changes: 23 additions & 0 deletions contrib/samples/a2a_basic/bin/.project
@@ -0,0 +1,23 @@
<?xml version="1.0" encoding="UTF-8"?>
<projectDescription>
<name>google-adk-sample-a2a-basic</name>
<comment></comment>
<projects>
</projects>
<buildSpec>
<buildCommand>
<name>org.eclipse.jdt.core.javabuilder</name>
<arguments>
</arguments>
</buildCommand>
<buildCommand>
<name>org.eclipse.m2e.core.maven2Builder</name>
<arguments>
</arguments>
</buildCommand>
</buildSpec>
<natures>
<nature>org.eclipse.jdt.core.javanature</nature>
<nature>org.eclipse.m2e.core.maven2Nature</nature>
</natures>
</projectDescription>
@@ -0,0 +1,2 @@
eclipse.preferences.version=1
encoding/<project>=UTF-8