Add EmbeddingValue union type and Base64 support for embeddings #519
This implementation adds support for Base64-encoded embeddings as the default response format, while maintaining complete backward compatibility with existing `List<Float>` usage.

### 1. Default Behavior Change

- **New Default**: Embedding requests now default to the Base64 encoding format
- **Backward Compatibility**: Existing code using the `embedding()` method continues to work unchanged
- **Performance**: Base64 encoding reduces network payload size significantly compared with a JSON array of decimal floats

Introduces the `EmbeddingValue` class to support both float list and Base64-encoded embedding data, enabling efficient handling and backward compatibility. `Embedding`, `EmbeddingCreateParams`, and related classes are updated to use `EmbeddingValue`, with automatic decoding and encoding between formats. Adds `EmbeddingDefaults` for global default encoding configuration, and comprehensive tests for new behaviors and compatibility.
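For readers unfamiliar with the Base64 format, the automatic decoding described above can be sketched roughly as follows. This is a minimal illustration, not the code in this PR: `Base64EmbeddingCodec` is a hypothetical helper, and it assumes the payload is a packed sequence of little-endian 32-bit floats.

```java
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Base64;
import java.util.List;

// Hypothetical helper for illustration only; not the SDK's implementation.
final class Base64EmbeddingCodec {
    private Base64EmbeddingCodec() {}

    // Decodes a Base64 string into floats.
    // Assumption: the bytes are packed little-endian 32-bit floats.
    static List<Float> decode(String base64) {
        byte[] bytes = Base64.getDecoder().decode(base64);
        ByteBuffer buffer = ByteBuffer.wrap(bytes).order(ByteOrder.LITTLE_ENDIAN);
        List<Float> floats = new ArrayList<>(bytes.length / Float.BYTES);
        while (buffer.remaining() >= Float.BYTES) {
            floats.add(buffer.getFloat());
        }
        return floats;
    }

    // Encodes floats into a Base64 string using the same byte layout.
    static String encode(List<Float> floats) {
        ByteBuffer buffer = ByteBuffer.allocate(floats.size() * Float.BYTES)
                .order(ByteOrder.LITTLE_ENDIAN);
        for (float value : floats) {
            buffer.putFloat(value);
        }
        return Base64.getEncoder().encodeToString(buffer.array());
    }

    public static void main(String[] args) {
        List<Float> original = Arrays.asList(0.25f, -1.5f, 3.0f);
        String encoded = encode(original);
        // Round trip: the decoded list should equal the original list.
        System.out.println(encoded + " -> " + decode(encoded));
    }
}
```

Each float costs four bytes before Base64 expansion, whereas a decimal rendering in JSON typically needs more characters per value, which is where the payload saving comes from.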
I verified the change against an Azure OpenAI environment with the following code:
package com.openai.example;
import com.openai.azure.AzureOpenAIServiceVersion;
import com.openai.azure.credential.AzureApiKeyCredential;
import com.openai.client.OpenAIClient;
import com.openai.client.OpenAIClientAsync;
import com.openai.client.okhttp.OpenAIOkHttpClient;
import com.openai.client.okhttp.OpenAIOkHttpClientAsync;
import com.openai.models.embeddings.CreateEmbeddingResponse;
import com.openai.models.embeddings.EmbeddingCreateParams;
import com.openai.models.embeddings.EmbeddingModel;
import com.openai.models.embeddings.EmbeddingValue;
import com.openai.services.blocking.EmbeddingService;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CountDownLatch;
/**
* Sample code using Azure OpenAI Embedding API
* Demonstrates how to retrieve and process embedding data in Base64 and Float formats
*/
public final class EmbeddingsExampleAzure {
private EmbeddingsExampleAzure() {}
// Azure OpenAI endpoint and key configuration
// Please replace with actual values
private static final String AZURE_OPENAI_ENDPOINT = "https://***********.openai.azure.com";
private static final String AZURE_OPENAI_KEY = "***********";
private static OpenAIClient client;
public static void main(String[] args) {
// Initialize Azure OpenAI client
client = OpenAIOkHttpClient.builder()
.baseUrl(AZURE_OPENAI_ENDPOINT)
.credential(AzureApiKeyCredential.create(AZURE_OPENAI_KEY))
.azureServiceVersion(AzureOpenAIServiceVersion.getV2024_02_15_PREVIEW())
.build();
EmbeddingsExampleAzure example = new EmbeddingsExampleAzure();
example.basicSample();
example.multipleDataSample();
example.asyncSample();
}
/**
* Basic embedding retrieval sample
* Demonstrates usage of default format, Float format, and Base64 format
*/
public void basicSample() {
EmbeddingService embeddings = client.embeddings();
String singlePoem = "In the quiet night, stars whisper secrets, dreams take flight.";
System.out.println("=== Basic Embedding Sample ===");
// 1. Default format (Base64 is default)
System.out.println("\n1. Getting embeddings in default format:");
EmbeddingCreateParams embeddingCreateParams = EmbeddingCreateParams.builder()
.input(singlePoem)
.model(EmbeddingModel.TEXT_EMBEDDING_3_SMALL.asString())
.build();
embeddings.create(embeddingCreateParams).data().forEach(embedding -> {
System.out.println("Embedding (default format): " + embedding.toString());
// Use EmbeddingValue to check the original format
EmbeddingValue embeddingValue = embedding.embeddingValue();
if (embeddingValue.isBase64String()) {
System.out.println(
"Received in Base64 format: " + embeddingValue.base64String().substring(0, 50) + "...");
} else if (embeddingValue.isFloatList()) {
System.out.println("Received in Float format: " + embeddingValue.floatList().size() + " elements");
}
// embedding() method always returns List<Float> (Base64 is automatically decoded)
List<Float> floats = embedding.embedding();
System.out.println(
"Retrieved as Float array: " + floats.size() + " elements, first 5: " + floats.subList(0, Math.min(5, floats.size())));
});
System.out.println("\n------------------------------------------------");
// 2. Explicitly specify Float format
System.out.println("\n2. Explicitly specifying Float format:");
EmbeddingCreateParams embeddingCreateParams2 = EmbeddingCreateParams.builder()
.input(singlePoem)
.model(EmbeddingModel.TEXT_EMBEDDING_3_SMALL.asString())
.encodingFormat(EmbeddingCreateParams.EncodingFormat.FLOAT)
.build();
embeddings.create(embeddingCreateParams2).data().forEach(embedding -> {
EmbeddingValue embeddingValue = embedding.embeddingValue();
if (embeddingValue.isFloatList()) {
System.out.println("Received in Float format: " + embeddingValue.floatList().size() + " elements");
System.out.println("First 5 values: "
+ embeddingValue
.floatList()
.subList(
0,
Math.min(5, embeddingValue.floatList().size())));
}
// Can also convert to Base64 format
String base64 = embeddingValue.asBase64String();
System.out.println("Converted to Base64 format: " + base64.substring(0, 50) + "...");
});
System.out.println("\n------------------------------------------------");
// 3. Explicitly specify Base64 format
System.out.println("\n3. Explicitly specifying Base64 format:");
EmbeddingCreateParams embeddingCreateParams3 = EmbeddingCreateParams.builder()
.input(singlePoem)
.model(EmbeddingModel.TEXT_EMBEDDING_3_SMALL.asString())
.encodingFormat(EmbeddingCreateParams.EncodingFormat.BASE64)
.build();
embeddings.create(embeddingCreateParams3).data().forEach(embedding -> {
EmbeddingValue embeddingValue = embedding.embeddingValue();
if (embeddingValue.isBase64String()) {
System.out.println(
"Received in Base64 format: " + embeddingValue.base64String().substring(0, 50) + "...");
}
// Automatically convert to Float array
List<Float> floats = embeddingValue.asFloatList();
System.out.println(
"Auto-converted to Float array: " + floats.size() + " elements, first 5: " + floats.subList(0, Math.min(5, floats.size())));
});
System.out.println("\n================================================");
}
/**
* Multiple data embedding retrieval sample
*/
public void multipleDataSample() {
EmbeddingService embeddings = client.embeddings();
System.out.println("\n=== Multiple Data Embedding Sample ===");
getPoems().forEach(poem -> {
System.out.println("\nPoem (start): " + poem);
EmbeddingCreateParams embeddingCreateParams = EmbeddingCreateParams.builder()
.input(poem)
.model(EmbeddingModel.TEXT_EMBEDDING_3_SMALL.asString())
.build(); // Use default format (Base64)
embeddings.create(embeddingCreateParams).data().forEach(embedding -> {
List<Float> floats = embedding.embedding();
System.out.println("Embedding (default): " + floats.size() + " dimensions");
// Check original format
EmbeddingValue embeddingValue = embedding.embeddingValue();
if (embeddingValue.isBase64String()) {
System.out.println("Original format: Base64");
} else {
System.out.println("Original format: Float array");
}
});
System.out.println("Poem (end)");
});
System.out.println("\n================================================");
}
/**
* Asynchronous embedding retrieval sample
*/
public void asyncSample() {
System.out.println("\n=== Asynchronous Embedding Sample ===");
CountDownLatch latch = new CountDownLatch(1);
try {
OpenAIClientAsync asyncClient = OpenAIOkHttpClientAsync.builder()
.baseUrl(AZURE_OPENAI_ENDPOINT)
.credential(AzureApiKeyCredential.create(AZURE_OPENAI_KEY))
.azureServiceVersion(AzureOpenAIServiceVersion.getV2024_02_15_PREVIEW())
.build();
CompletableFuture<CreateEmbeddingResponse> completableFuture = asyncClient
.embeddings()
.create(EmbeddingCreateParams.builder()
.input("The quick brown fox jumped over the lazy dog")
.model(EmbeddingModel.TEXT_EMBEDDING_3_SMALL)
.encodingFormat(EmbeddingCreateParams.EncodingFormat.FLOAT)
.user("user-1234")
.build());
completableFuture
.thenAccept(response -> {
response.validate();
response.data().forEach(embedding -> {
System.out.println("Asynchronous embedding retrieval completed:");
System.out.println("Embedding info: " + embedding.toString());
EmbeddingValue embeddingValue = embedding.embeddingValue();
if (embeddingValue.isFloatList()) {
System.out.println(
"Float format: " + embeddingValue.floatList().size() + " dimensions");
}
// Visitor pattern usage example
String result = embeddingValue.accept(new EmbeddingValue.Visitor<String>() {
@Override
public String visitFloatList(List<Float> floatList) {
return "Processing Float array: " + floatList.size() + " elements";
}
@Override
public String visitBase64String(String base64String) {
return "Processing Base64 string: " + base64String.length() + " characters";
}
});
System.out.println("Visitor pattern result: " + result);
});
// Release the latch once per response, even if it contains no embeddings
latch.countDown();
})
.exceptionally(ex -> {
System.err.println("Error: " + ex.getMessage());
latch.countDown();
return null;
});
latch.await();
System.out.println("Asynchronous processing completed");
} catch (InterruptedException e) {
Thread.currentThread().interrupt();
System.err.println("Processing was interrupted");
}
System.out.println("\n================================================");
System.out.println("All samples execution completed");
}
/**
* Get sample poems list
*/
private List<String> getPoems() {
List<String> poems = new ArrayList<>();
poems.add("In the quiet night, stars whisper secrets, dreams take flight.");
poems.add("Beneath the moon's glow, shadows dance, hearts begin to know.");
poems.add("Waves crash on the shore, time stands still, love forevermore.");
poems.add("Autumn leaves fall, painting the ground, nature's final call.");
poems.add("Morning dew glistens, a new day dawns, hope always listens.");
poems.add("Mountains stand tall, silent guardians, witnessing it all.");
poems.add("In a field of green, flowers bloom bright, a serene scene.");
poems.add("Winter's chill bites, fireside warmth, cozy, long nights.");
poems.add("Spring's gentle breeze, life awakens, hearts find ease.");
poems.add("Sunset hues blend, day meets night, a perfect end.");
return poems;
}
}
Deleted invalid EmbeddingDefaults.kt
@@ -22,6 +22,7 @@ import kotlin.jvm.optionals.getOrNull
class Embedding
private constructor(
    private val embedding: JsonField<List<Float>>,
Since the new `EmbeddingValue` class supports `List<Float>`, having both `private val embedding: JsonField<List<Float>>` and `private val embeddingValue: JsonField<EmbeddingValue>?` is unnecessary, no?
We can change the underlying data model while keeping backwards compat by implementing the existing methods in terms of the new data. For example, `fun embedding()` can be implemented as `embeddingValue.getRequired("embedding").asFloatList()`?
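A rough Java sketch of that suggestion, with a single stored field and the legacy accessor derived from it. All names here (`SingleFieldSketch`, `Value`, `Embedding`) are hypothetical stand-ins rather than the SDK's classes, and the Base64 layout is assumed to be little-endian 32-bit floats.

```java
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.util.ArrayList;
import java.util.Base64;
import java.util.List;

// Hypothetical sketch: one stored union value, legacy accessor derived from it.
final class SingleFieldSketch {
    // Stand-in for the union type: holds either floats or a Base64 string.
    static final class Value {
        private final List<Float> floats;
        private final String base64;

        private Value(List<Float> floats, String base64) {
            this.floats = floats;
            this.base64 = base64;
        }

        static Value ofFloats(List<Float> floats) { return new Value(floats, null); }
        static Value ofBase64(String base64) { return new Value(null, base64); }

        boolean isBase64() { return base64 != null; }

        // Returns the floats directly, or decodes the Base64 payload on demand.
        // Assumption: packed little-endian 32-bit floats.
        List<Float> asFloats() {
            if (floats != null) {
                return floats;
            }
            ByteBuffer buffer = ByteBuffer.wrap(Base64.getDecoder().decode(base64))
                    .order(ByteOrder.LITTLE_ENDIAN);
            List<Float> decoded = new ArrayList<>();
            while (buffer.remaining() >= Float.BYTES) {
                decoded.add(buffer.getFloat());
            }
            return decoded;
        }
    }

    // Stand-in for the model class: no separate List<Float> field is kept.
    static final class Embedding {
        private final Value value;

        Embedding(Value value) { this.value = value; }

        // Existing accessor keeps its signature but delegates to the union.
        List<Float> embedding() { return value.asFloats(); }

        // New accessor exposes the original wire format.
        Value embeddingValue() { return value; }
    }

    public static void main(String[] args) {
        Embedding embedding = new Embedding(Value.ofBase64("AACAPw=="));
        System.out.println(embedding.embeddingValue().isBase64() + " " + embedding.embedding());
    }
}
```

Callers of `embedding()` see no change, while the duplicate nullable field and the fallback copying logic have nothing left to reconcile.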
To address your comments, I can see three possible approaches.
Which solution do you prefer?
Option 1: Single Field Architecture
Overview: Use only the `embeddingValue` field and remove the legacy `embedding` field.
Changes:
- Remove `private val embedding: JsonField<List<Float>>`
- Change `private val embeddingValue: JsonField<EmbeddingValue>` to a required field (remove nullable)
- Implement all existing methods based on the `embeddingValue` field
- Unify the Builder pattern to use `EmbeddingValue`
Advantages:
- Simplifies the code and removes redundancy
- Unifies the data structure
- Improves type safety (removes nullable)
- Enhances performance (reduces unnecessary conversions)
Disadvantages:
- Requires significant changes to existing JSON deserialization
- Major changes to internal implementation
Option 2: Legacy Field Deprecation
Overview: Maintain the current two-field structure while gradually deprecating the legacy field.
Changes:
- Mark the `embedding` field as `@Deprecated`
- Use only `embeddingValue` in new constructors
- Implement existing methods based on `embeddingValue`, keeping the `embedding` field for backward compatibility
- Recommend using `embeddingValue` in the Builder pattern
Advantages:
- Ensures complete backward compatibility with existing APIs
- Allows for gradual migration
- Minimizes impact on existing code
Disadvantages:
- Redundancy remains in the code
- May leave technical debt
- Maintenance becomes more complex
Option 3: Hybrid Approach with Smart Migration
Overview: Automatically select the optimal format at runtime and perform internal migration.
Changes:
- Check both fields during initialization and unify to `embeddingValue`
- Automatically convert the `embedding` field to `EmbeddingValue.ofFloatList()` if present
- Keep the public API unchanged, optimizing only the internal implementation
- Use lazy initialization pattern
Advantages:
- Ensures complete compatibility with existing APIs
- Optimized internally
- Transparent migration
Disadvantages:
- Overhead during initialization
- Complex implementation
- Debugging may become difficult
try {
    this.embedding = JsonField.of(embedding.embedding().toMutableList())
} catch (e: Exception) {
    // Fallback to field-level copying if embedding() method fails
    this.embedding = embedding.embedding.map { it.toMutableList() }
}
I don't understand what this is for. Is this because we have the nullable and non-nullable fields? If so, then I guess this will go away once we implement my other suggestion?
// Apply default encoding format if not explicitly set
if (body._encodingFormat().isMissing()) {
    body.encodingFormat(EmbeddingDefaults.defaultEncodingFormat)
}
We shouldn't apply the default like this. We can just update the builder field to start out as:
private var encodingFormat: JsonField<EncodingFormat> = JsonField.of(EmbeddingDefaults.defaultEncodingFormat)
Then we also don't need that new internal method
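To illustrate the idea of putting the default in the builder field itself, here is a minimal Java sketch. `ParamsSketch` and its members are hypothetical stand-ins and omit the SDK's `JsonField` wrapper; the point is only that `build()` needs no "is it missing?" check when the field starts out with the default.

```java
// Hypothetical sketch of a params builder whose field starts at the default.
final class ParamsSketch {
    enum EncodingFormat { FLOAT, BASE64 }

    private final String input;
    private final EncodingFormat encodingFormat;

    private ParamsSketch(String input, EncodingFormat encodingFormat) {
        this.input = input;
        this.encodingFormat = encodingFormat;
    }

    String input() { return input; }

    EncodingFormat encodingFormat() { return encodingFormat; }

    static Builder builder() { return new Builder(); }

    static final class Builder {
        private String input;
        // The default lives in the field initializer, so no post-processing
        // step is needed before the request is sent.
        private EncodingFormat encodingFormat = EncodingFormat.BASE64;

        Builder input(String input) {
            this.input = input;
            return this;
        }

        // An explicit call simply overwrites the default.
        Builder encodingFormat(EncodingFormat encodingFormat) {
            this.encodingFormat = encodingFormat;
            return this;
        }

        ParamsSketch build() {
            if (input == null) {
                throw new IllegalStateException("`input` is required");
            }
            return new ParamsSketch(input, encodingFormat);
        }
    }

    public static void main(String[] args) {
        ParamsSketch defaults = ParamsSketch.builder().input("hello").build();
        ParamsSketch explicit = ParamsSketch.builder()
                .input("hello")
                .encodingFormat(EncodingFormat.FLOAT)
                .build();
        System.out.println(defaults.encodingFormat() + " / " + explicit.encodingFormat());
    }
}
```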
Thank you for the feedback, I will change it.
@JsonSerialize(using = EmbeddingValue.Serializer::class)
class EmbeddingValue
private constructor(
    private val floatList: List<Float>? = null,
nit: here and throughout, we can call these `floats` instead of `floatList`
- private val floatList: List<Float>? = null,
+ private val floats: List<Float>? = null,
Thank you, I will rename the fields.
openai-java-core/src/main/kotlin/com/openai/models/embeddings/EmbeddingValue.kt
This seems pretty overkill to me. I think we can just swap the default to base64 and people can set the encoding explicitly on the params object if they want floats over the wire?
In that case we can just delete this class and inline `EncodingFormat.BASE64` in the params builder default
object : Visitor<Unit> {
    override fun visitFloatList(floatList: List<Float>) {
        // Validate that float list is not empty and contains valid values
        if (floatList.isEmpty()) {
            throw OpenAIInvalidDataException("Float list cannot be empty")
        }
        floatList.forEach { value ->
            if (!value.isFinite()) {
                throw OpenAIInvalidDataException("Float values must be finite")
            }
        }
    }

    override fun visitBase64String(base64String: String) {
        // Validate base64 format
        try {
            Base64.getDecoder().decode(base64String)
        } catch (e: IllegalArgumentException) {
            throw OpenAIInvalidDataException("Invalid base64 string", e)
        }
    }
In general the `validate()` methods in the SDK don't validate anything other than the "shape" of the data being correct (e.g. required fields are set).
In this case, I think we just want to check that the union is not a `_json`. So like:
accept(
    object : Visitor<Unit> {
        override fun visitFloatList(floatList: List<Float>) {}
        override fun visitBase64String(base64String: String) {}
    }
)
is sufficient. This will throw if it's a `_json`.
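The reason the empty visitor is enough can be sketched in Java as follows. This is an illustrative stand-in, not the SDK's code: `UnionValidationSketch`, its `unknown` hook, and the use of `IllegalStateException` in place of `OpenAIInvalidDataException` are all assumptions made for the example.

```java
import java.util.Arrays;
import java.util.List;

// Hypothetical sketch of shape-only validation on a union type.
final class UnionValidationSketch {
    interface Visitor<T> {
        T visitFloats(List<Float> floats);

        T visitBase64(String base64);

        // Reached when the value matched neither known variant (raw JSON).
        // Assumption: the default behaviour is to throw.
        default T unknown(Object raw) {
            throw new IllegalStateException("Unknown EmbeddingValue: " + raw);
        }
    }

    static final class Value {
        private final List<Float> floats;
        private final String base64;
        private final Object raw;

        private Value(List<Float> floats, String base64, Object raw) {
            this.floats = floats;
            this.base64 = base64;
            this.raw = raw;
        }

        static Value ofFloats(List<Float> floats) { return new Value(floats, null, null); }
        static Value ofBase64(String base64) { return new Value(null, base64, null); }
        static Value ofRaw(Object raw) { return new Value(null, null, raw); }

        <T> T accept(Visitor<T> visitor) {
            if (floats != null) return visitor.visitFloats(floats);
            if (base64 != null) return visitor.visitBase64(base64);
            return visitor.unknown(raw);
        }

        // Shape-only check: known variants pass untouched, unknown ones throw
        // through the visitor's default method. Contents are not inspected.
        Value validate() {
            accept(new Visitor<Void>() {
                @Override
                public Void visitFloats(List<Float> floats) { return null; }

                @Override
                public Void visitBase64(String base64) { return null; }
            });
            return this;
        }
    }

    public static void main(String[] args) {
        Value.ofFloats(Arrays.asList(0.1f, 0.2f)).validate();
        try {
            Value.ofRaw(42).validate();
        } catch (IllegalStateException e) {
            System.out.println("Rejected: " + e.getMessage());
        }
    }
}
```

Note that a malformed Base64 string would pass this check; under the convention described above, content errors surface later when the value is actually decoded.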
I see, I will modify it.
        }
    }
)
return this // Return this instance if validation succeeds
We can just use `apply` like we do in other classes
Improvement from previous PR
This PR fixes issue #211.
It is related to PR #303.