Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 29 additions & 16 deletions sqrl-openai/src/main/java/com/datasqrl/openai/FunctionExecutor.java
Original file line number Diff line number Diff line change
Expand Up @@ -4,32 +4,45 @@
import org.apache.flink.table.functions.FunctionContext;

import java.util.concurrent.Callable;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;

import static com.datasqrl.openai.RetryUtil.executeWithRetry;

public class FunctionExecutor {

private static final String POOL_SIZE = "ASYNC_FUNCTION_THREAD_POOL_SIZE";

private final FunctionMetricTracker metricTracker;
private final ExecutorService executorService;

public FunctionExecutor(FunctionContext context, String functionName) {
this.metricTracker = new FunctionMetricTracker(context, functionName);
this.executorService = Executors.newFixedThreadPool(getPoolSize());
}

public <T> T execute(Callable<T> task) {
metricTracker.increaseCallCount();

final long start = System.nanoTime();

final T ret = executeWithRetry(
task,
metricTracker::increaseErrorCount,
metricTracker::increaseRetryCount
);

final long elapsedTime = System.nanoTime() - start;
metricTracker.recordLatency(TimeUnit.NANOSECONDS.toMillis(elapsedTime));
public <T> CompletableFuture<T> executeAsync(Callable<T> task) {
final CompletableFuture<T> future = new CompletableFuture<>();
executorService.submit(() -> {
try {
metricTracker.increaseCallCount();
final long start = System.nanoTime();

final T result = task.call();

final long elapsedTime = System.nanoTime() - start;
metricTracker.recordLatency(TimeUnit.NANOSECONDS.toMillis(elapsedTime));

future.complete(result);
} catch (Exception e) {
metricTracker.increaseErrorCount();
future.completeExceptionally(e);
}
});
return future;
}

return ret;
private int getPoolSize() {
return Integer.parseInt(System.getenv().getOrDefault(POOL_SIZE, "10"));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -33,19 +33,23 @@ public OpenAICompletions(HttpClient httpClient) {
}

public String callCompletions(String prompt, String modelName, Boolean requireJsonOutput, Integer maxOutputTokens, Double temperature, Double topP) throws IOException, InterruptedException {
if (prompt == null || modelName == null) {
return null;
}

// Create the request body JSON
ObjectNode requestBody = createRequestBody(prompt, modelName, requireJsonOutput, maxOutputTokens, temperature, topP);
final ObjectNode requestBody = createRequestBody(prompt, modelName, requireJsonOutput, maxOutputTokens, temperature, topP);

// Build the HTTP request
HttpRequest request = HttpRequest.newBuilder()
final HttpRequest request = HttpRequest.newBuilder()
.uri(URI.create(COMPLETIONS_API))
.header("Authorization", "Bearer " + System.getenv(API_KEY))
.header("Content-Type", "application/json")
.POST(HttpRequest.BodyPublishers.ofString(requestBody.toString(), StandardCharsets.UTF_8))
.build();

// Send the request and get the response
HttpResponse<String> response = httpClient.send(request, HttpResponse.BodyHandlers.ofString());
final HttpResponse<String> response = httpClient.send(request, HttpResponse.BodyHandlers.ofString());

// Handle the response
if (response.statusCode() == 200) {
Expand All @@ -56,11 +60,11 @@ public String callCompletions(String prompt, String modelName, Boolean requireJs
}

private ObjectNode createRequestBody(String prompt, String modelName, Boolean requireJsonOutput, Integer maxOutputTokens, Double temperature, Double topP) {
ObjectNode requestBody = objectMapper.createObjectNode();
final ObjectNode requestBody = objectMapper.createObjectNode();
requestBody.put("model", modelName);

// Create the messages array as required by the chat completions endpoint
ArrayNode messagesArray = objectMapper.createArrayNode();
final ArrayNode messagesArray = objectMapper.createArrayNode();

if (requireJsonOutput) {
// when the model supports JSON output, both setting is needed otherwise the API call will fail
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,10 @@ public OpenAIEmbeddings(HttpClient httpClient) {
}

/**
 * Embeds {@code text} with the given model using the default token limit.
 * Returns null when either argument is null (SQL null-in/null-out semantics).
 *
 * @throws IOException          on transport failure
 * @throws InterruptedException if the HTTP call is interrupted
 */
public double[] vectorEmbedd(String text, String modelName) throws IOException, InterruptedException {
    return (text == null || modelName == null)
            ? null
            : vectorEmbedd(text, modelName, TOKEN_LIMIT);
}

Expand Down
41 changes: 0 additions & 41 deletions sqrl-openai/src/main/java/com/datasqrl/openai/RetryUtil.java

This file was deleted.

27 changes: 14 additions & 13 deletions sqrl-openai/src/main/java/com/datasqrl/openai/completions.java
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,14 @@

import com.datasqrl.openai.util.FunctionMetricTracker;
import com.google.auto.service.AutoService;
import org.apache.flink.table.functions.AsyncScalarFunction;
import org.apache.flink.table.functions.FunctionContext;
import org.apache.flink.table.functions.ScalarFunction;

import java.util.concurrent.CompletableFuture;

@AutoService(ScalarFunction.class)
public class completions extends ScalarFunction {
public class completions extends AsyncScalarFunction {

private transient OpenAICompletions openAICompletions;
private transient FunctionExecutor executor;
Expand All @@ -25,23 +28,21 @@ protected FunctionMetricTracker createMetricTracker(FunctionContext context, Str
return new FunctionMetricTracker(context, functionName);
}

public String eval(String prompt, String modelName) {
return eval(prompt, modelName, null, null, null);
public void eval(CompletableFuture<String> result, String prompt, String modelName) {
eval(result, prompt, modelName, null, null, null);
}

public String eval(String prompt, String modelName, Integer maxOutputTokens) {
return eval(prompt, modelName, maxOutputTokens, null, null);
public void eval(CompletableFuture<String> result, String prompt, String modelName, Integer maxOutputTokens) {
eval(result, prompt, modelName, maxOutputTokens, null, null);
}

public String eval(String prompt, String modelName, Integer maxOutputTokens, Double temperature) {
return eval(prompt, modelName, maxOutputTokens, temperature, null);
public void eval(CompletableFuture<String> result, String prompt, String modelName, Integer maxOutputTokens, Double temperature) {
eval(result, prompt, modelName, maxOutputTokens, temperature, null);
}

public String eval(String prompt, String modelName, Integer maxOutputTokens, Double temperature, Double topP) {
if (prompt == null || modelName == null) return null;

return executor.execute(
() -> openAICompletions.callCompletions(prompt, modelName, false, maxOutputTokens, temperature, topP)
);
public void eval(CompletableFuture<String> result, String prompt, String modelName, Integer maxOutputTokens, Double temperature, Double topP) {
executor.executeAsync(() -> openAICompletions.callCompletions(prompt, modelName, false, maxOutputTokens, temperature, topP))
.thenAccept(result::complete)
.exceptionally(ex -> { result.completeExceptionally(ex); return null; });
}
}
23 changes: 12 additions & 11 deletions sqrl-openai/src/main/java/com/datasqrl/openai/extract_json.java
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
package com.datasqrl.openai;

import com.google.auto.service.AutoService;
import org.apache.flink.table.functions.AsyncScalarFunction;
import org.apache.flink.table.functions.FunctionContext;
import org.apache.flink.table.functions.ScalarFunction;

import java.util.concurrent.CompletableFuture;

@AutoService(ScalarFunction.class)
public class extract_json extends ScalarFunction {
public class extract_json extends AsyncScalarFunction {

private transient OpenAICompletions openAICompletions;
private transient FunctionExecutor executor;
Expand All @@ -20,19 +23,17 @@ protected OpenAICompletions createOpenAICompletions() {
return new OpenAICompletions();
}

public String eval(String prompt, String modelName) {
return eval(prompt, modelName, null, null);
public void eval(CompletableFuture<String> result, String prompt, String modelName) {
eval(result, prompt, modelName, null, null);
}

public String eval(String prompt, String modelName, Double temperature) {
return eval(prompt, modelName, temperature, null);
public void eval(CompletableFuture<String> result, String prompt, String modelName, Double temperature) {
eval(result, prompt, modelName, temperature, null);
}

public String eval(String prompt, String modelName, Double temperature, Double topP) {
if (prompt == null || modelName == null) return null;

return executor.execute(
() -> openAICompletions.callCompletions(prompt, modelName, true, null, temperature, topP)
);
public void eval(CompletableFuture<String> result, String prompt, String modelName, Double temperature, Double topP) {
executor.executeAsync(() -> openAICompletions.callCompletions(prompt, modelName, true, null, temperature, topP))
.thenAccept(result::complete)
.exceptionally(ex -> { result.completeExceptionally(ex); return null; });
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -9,23 +9,19 @@ public class FunctionMetricTracker {
public static final String P99_METRIC = "com.datasqrl.openai.%s.p99";
public static final String CALL_COUNT = "com.datasqrl.openai.%s.callCount";
public static final String ERROR_COUNT = "com.datasqrl.openai.%s.errorCount";
public static final String RETRY_COUNT = "com.datasqrl.openai.%s.retryCount";

private final P99LatencyTracker latencyTracker = new P99LatencyTracker();
private final Counter callCount;
private final Counter errorCount;
private final Counter retryCount;

/**
 * Registers per-function metrics on the Flink metric group: a p99 latency
 * gauge plus call and error counters. The retry counter was removed along
 * with RetryUtil; its stale deleted lines are stripped from this span.
 */
public FunctionMetricTracker(FunctionContext context, String functionName) {
    final String p99MetricName = String.format(P99_METRIC, functionName);
    final String callCountName = String.format(CALL_COUNT, functionName);
    final String errorCountName = String.format(ERROR_COUNT, functionName);

    context.getMetricGroup().gauge(p99MetricName, (Gauge<Long>) latencyTracker::getP99Latency);
    callCount = context.getMetricGroup().counter(callCountName);
    errorCount = context.getMetricGroup().counter(errorCountName);
}

public void increaseCallCount() {
Expand All @@ -36,10 +32,6 @@ public void increaseErrorCount() {
errorCount.inc();
}

public void increaseRetryCount() {
retryCount.inc();
}

// Feeds one sample (milliseconds) into the tracker backing the p99 gauge.
public void recordLatency(long latencyMs) {
latencyTracker.recordLatency(latencyMs);
}
Expand Down
15 changes: 8 additions & 7 deletions sqrl-openai/src/main/java/com/datasqrl/openai/vector_embedd.java
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,14 @@

import com.datasqrl.openai.util.FunctionMetricTracker;
import com.google.auto.service.AutoService;
import org.apache.flink.table.functions.AsyncScalarFunction;
import org.apache.flink.table.functions.FunctionContext;
import org.apache.flink.table.functions.ScalarFunction;

import java.util.concurrent.CompletableFuture;

@AutoService(ScalarFunction.class)
public class vector_embedd extends ScalarFunction {
public class vector_embedd extends AsyncScalarFunction {

private transient OpenAIEmbeddings openAIEmbeddings;
private transient FunctionExecutor executor;
Expand All @@ -25,11 +28,9 @@ protected FunctionMetricTracker createMetricTracker(FunctionContext context, Str
return new FunctionMetricTracker(context, functionName);
}

public double[] eval(String text, String modelName) {
if (text == null || modelName == null) return null;

return executor.execute(
() -> openAIEmbeddings.vectorEmbedd(text, modelName)
);
public void eval(CompletableFuture<double[]> result, String text, String modelName) {
executor.executeAsync(() -> openAIEmbeddings.vectorEmbedd(text, modelName))
.thenAccept(result::complete)
.exceptionally(ex -> { result.completeExceptionally(ex); return null; });
}
}
Loading
Loading