diff --git a/.github/workflows/build-pull-request.yml b/.github/workflows/build-pull-request.yml
index 372f118f3..4ee6963fb 100644
--- a/.github/workflows/build-pull-request.yml
+++ b/.github/workflows/build-pull-request.yml
@@ -66,7 +66,7 @@ jobs:
            | jq -R -s -c 'split("\n")[:-1]')
          # Integration tests (without the in-process embedding models)
-         # Remove JLama and Llama3 from the list
+         # Remove JLama, Llama3 and GPU Llama3 from the list
          cd integration-tests
          IT_MODULES=$( \
            find . -mindepth 2 -maxdepth 2 -type f -name 'pom.xml' -exec dirname {} \; \
@@ -74,6 +74,7 @@ jobs:
            | sort -u \
            | grep -v jlama \
            | grep -v llama3-java \
+           | grep -v gpu-llama3 \
            | grep -v in-process-embedding-models \
            | jq -R -s -c 'split("\n")[:-1]')
@@ -143,6 +144,13 @@ jobs:
        run: |
          ./mvnw -B clean install -DskipTests -Dno-format -ntp -f model-providers/jlama/pom.xml

+     # Build GPU Llama3 if JDK >= 21
+     # It's not built by default as it requires Java 21+
+     - name: Build GPU Llama3 extension
+       if: ${{ matrix.java >= 21 }}
+       run: |
+         ./mvnw -B clean install -DskipTests -Dno-format -ntp -f model-providers/gpu-llama3/pom.xml
+
      # Build Llama3.java if JDK >= 22. See https://x.com/tjake/status/1849141171475399083?t=EpgVJCPLC17fCXio0FvnhA&s=19 for the reason
      - name: Build Llama3-java extension
        if: ${{ matrix.java >= 22 }}
diff --git a/docs/modules/ROOT/pages/includes/attributes.adoc b/docs/modules/ROOT/pages/includes/attributes.adoc
index caf7b472a..092ee1183 100644
--- a/docs/modules/ROOT/pages/includes/attributes.adoc
+++ b/docs/modules/ROOT/pages/includes/attributes.adoc
@@ -1,4 +1,4 @@
 :project-version: 1.3.1
-:langchain4j-version: 1.6.0
-:langchain4j-embeddings-version: 1.6.0-beta12
+:langchain4j-version: 1.8.0
+:langchain4j-embeddings-version: 1.8.0-beta15
 :examples-dir: ./../examples/
diff --git a/integration-tests/gpu-llama3/README.md b/integration-tests/gpu-llama3/README.md
new file mode 100644
index 000000000..2140aee9e
--- /dev/null
+++ b/integration-tests/gpu-llama3/README.md
@@ -0,0 +1,51 @@
+### How to run the integration tests:
+
+#### 1) Install TornadoVM:
+
+```bash
+cd ~
+git clone git@github.com:beehive-lab/TornadoVM.git
+cd ~/TornadoVM
+./bin/tornadovm-installer --jdk jdk21 --backend opencl
+source setvars.sh
+```
+
+Note that the above steps:
+- Set the `TORNADOVM_SDK` environment variable to the path of the TornadoVM SDK.
+- Create the `tornado-argfile` under `~/TornadoVM`, which contains all the JVM arguments required to enable TornadoVM.
+- The argfile is used automatically in Quarkus dev mode; in production mode, however, you need to pass it to the JVM manually (see step 3 and the sanity check below).
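A quick, optional way to confirm the installation above worked. This is only a sketch: it assumes the `tornado` launcher is on your `PATH` after `source setvars.sh`, and the exact device list depends on your hardware and the backend you chose:

```bash
# List the accelerators TornadoVM detected; at least one device should be reported
tornado --devices

# The argfile that step 3 passes to the JVM should now exist
ls ~/TornadoVM/tornado-argfile
```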
+ +#### 2) Build Quarkus-langchain4j: + +```bash +cd ~ +git clone git@github.com:mikepapadim/quarkus-langchain4j.git +cd ~/quarkus-langchain4j +git checkout gpu-llama3-integration +mvn clean install -DskipTests +``` + +#### 3) Run the integrated tests: + +##### 3.1 Deploy the Quarkus app: + +```bash +cd ~/quarkus-langchain4j/integration-tests/gpullama3 +``` +- For *dev* mode, run: +``` +mvn quarkus:dev +``` + +- For *production* mode, run: +```bash +java @~/TornadoVM/tornado-argfile -jar target/quarkus-app/quarkus-run.jar +``` +##### 3.2 Send requests to the Quarkus app: + +when quarkus is running, open a new terminal and run: + +```bash +curl http://localhost:8080/chat/blocking +``` + diff --git a/integration-tests/gpu-llama3/pom.xml b/integration-tests/gpu-llama3/pom.xml new file mode 100644 index 000000000..41f9fce8e --- /dev/null +++ b/integration-tests/gpu-llama3/pom.xml @@ -0,0 +1,104 @@ + + + 4.0.0 + + io.quarkiverse.langchain4j + quarkus-langchain4j-integration-tests-parent + 999-SNAPSHOT + + quarkus-langchain4j-integration-test-gpu-llama3 + Quarkus LangChain4j - Integration Tests - GPULlama3 + + true + 21 + 3.18.0 + + ${env.TORNADO_SDK}/../../../tornado-argfile + + + + io.quarkus + quarkus-rest-jackson + + + io.quarkiverse.langchain4j + quarkus-langchain4j-gpu-llama3 + 999-SNAPSHOT + + + io.quarkus + quarkus-junit5 + test + + + io.rest-assured + rest-assured + test + + + org.assertj + assertj-core + test + + + io.quarkus + quarkus-devtools-testing + test + + + + + io.quarkiverse.langchain4j + quarkus-langchain4j-gpu-llama3-deployment + 999-SNAPSHOT + pom + test + + + * + * + + + + + + + + io.quarkus + quarkus-maven-plugin + + + + build + + + + + + @${tornado.argfile} + + + + + maven-failsafe-plugin + + + + integration-test + verify + + + @${tornado.argfile} + + ${project.build.directory}/${project.build.finalName}-runner + org.jboss.logmanager.LogManager + ${maven.home} + + + + + + + + diff --git a/integration-tests/gpu-llama3/src/main/java/org/acme/example/gpullama3/chat/ChatLanguageModelResource.java b/integration-tests/gpu-llama3/src/main/java/org/acme/example/gpullama3/chat/ChatLanguageModelResource.java new file mode 100644 index 000000000..e53da3dce --- /dev/null +++ b/integration-tests/gpu-llama3/src/main/java/org/acme/example/gpullama3/chat/ChatLanguageModelResource.java @@ -0,0 +1,22 @@ +package org.acme.example.gpullama3.chat; + +import jakarta.ws.rs.GET; +import jakarta.ws.rs.Path; + +import dev.langchain4j.model.chat.ChatModel; + +@Path("chat") +public class ChatLanguageModelResource { + + private final ChatModel chatModel; + + public ChatLanguageModelResource(ChatModel chatModel) { + this.chatModel = chatModel; + } + + @GET + @Path("blocking") + public String blocking() { + return chatModel.chat("When was the nobel prize for economics first awarded?"); + } +} diff --git a/integration-tests/gpu-llama3/src/main/resources/application.properties b/integration-tests/gpu-llama3/src/main/resources/application.properties new file mode 100644 index 000000000..88445d9f2 --- /dev/null +++ b/integration-tests/gpu-llama3/src/main/resources/application.properties @@ -0,0 +1,12 @@ +quarkus.langchain4j.gpu-llama3.include-models-in-artifact=false + +# Configure GPULlama3 +quarkus.langchain4j.gpu-llama3.enable-integration=true +quarkus.langchain4j.gpu-llama3.chat-model.model-name=beehive-lab/Llama-3.2-1B-Instruct-GGUF +quarkus.langchain4j.gpu-llama3.chat-model.quantization=FP16 +quarkus.langchain4j.gpu-llama3.chat-model.temperature=0.7 
+quarkus.langchain4j.gpu-llama3.chat-model.max-tokens=513 + +# other supported models: +#model-name=ggml-org/Qwen3-0.6B-GGUF +#quantization=f16 \ No newline at end of file diff --git a/integration-tests/pom.xml b/integration-tests/pom.xml index f97887299..1e551a94e 100644 --- a/integration-tests/pom.xml +++ b/integration-tests/pom.xml @@ -56,7 +56,18 @@ llama3-java - + + TornadoVM + + + tornado + + + + gpu-llama3 + + + default-project-deps @@ -107,6 +118,11 @@ quarkus-langchain4j-easy-rag 999-SNAPSHOT + + io.quarkiverse.langchain4j + quarkus-langchain4j-gpullama3 + ${quarkus-langchain4j.version} + io.quarkiverse.langchain4j quarkus-langchain4j-hugging-face @@ -122,6 +138,11 @@ quarkus-langchain4j-llama3-java 999-SNAPSHOT + + io.quarkiverse.langchain4j + quarkus-langchain4j-gpu-llama3 + 999-SNAPSHOT + io.quarkiverse.langchain4j quarkus-langchain4j-mcp diff --git a/model-providers/gpu-llama3/deployment/pom.xml b/model-providers/gpu-llama3/deployment/pom.xml new file mode 100644 index 000000000..84d81d6cc --- /dev/null +++ b/model-providers/gpu-llama3/deployment/pom.xml @@ -0,0 +1,62 @@ + + + 4.0.0 + + + io.quarkiverse.langchain4j + quarkus-langchain4j-gpu-llama3-parent + 999-SNAPSHOT + + + quarkus-langchain4j-gpu-llama3-deployment + Quarkus LangChain4j - GPULlama3 - Deployment + + + + io.quarkiverse.langchain4j + quarkus-langchain4j-gpu-llama3 + ${project.version} + + + + + io.quarkus + quarkus-arc-deployment + provided + + + io.quarkiverse.langchain4j + quarkus-langchain4j-core-deployment + ${project.version} + + + + io.quarkus + quarkus-junit5-internal + test + + + org.assertj + assertj-core + test + + + + + + + maven-compiler-plugin + + + + io.quarkus + quarkus-extension-processor + + + + + + + diff --git a/model-providers/gpu-llama3/deployment/src/main/java/io/quarkiverse/langchain4j/gpullama3/deployment/ChatModelBuildConfig.java b/model-providers/gpu-llama3/deployment/src/main/java/io/quarkiverse/langchain4j/gpullama3/deployment/ChatModelBuildConfig.java new file mode 100644 index 000000000..1abc98143 --- /dev/null +++ b/model-providers/gpu-llama3/deployment/src/main/java/io/quarkiverse/langchain4j/gpullama3/deployment/ChatModelBuildConfig.java @@ -0,0 +1,16 @@ +package io.quarkiverse.langchain4j.gpullama3.deployment; + +import java.util.Optional; + +import io.quarkus.runtime.annotations.ConfigDocDefault; +import io.quarkus.runtime.annotations.ConfigGroup; + +@ConfigGroup +public interface ChatModelBuildConfig { + + /** + * Whether the model should be enabled + */ + @ConfigDocDefault("true") + Optional enabled(); +} diff --git a/model-providers/gpu-llama3/deployment/src/main/java/io/quarkiverse/langchain4j/gpullama3/deployment/GPULlama3Processor.java b/model-providers/gpu-llama3/deployment/src/main/java/io/quarkiverse/langchain4j/gpullama3/deployment/GPULlama3Processor.java new file mode 100644 index 000000000..5fbc1924f --- /dev/null +++ b/model-providers/gpu-llama3/deployment/src/main/java/io/quarkiverse/langchain4j/gpullama3/deployment/GPULlama3Processor.java @@ -0,0 +1,214 @@ +package io.quarkiverse.langchain4j.gpullama3.deployment; + +import static io.quarkiverse.langchain4j.deployment.LangChain4jDotNames.CHAT_MODEL; + +import java.io.IOException; +import java.math.BigDecimal; +import java.math.RoundingMode; +import java.nio.file.Files; +import java.nio.file.Path; +import java.util.ArrayList; +import java.util.List; +import java.util.Optional; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicReference; + +import 
jakarta.enterprise.context.ApplicationScoped; + +import org.jboss.logging.Logger; + +import io.quarkiverse.langchain4j.deployment.items.ChatModelProviderCandidateBuildItem; +import io.quarkiverse.langchain4j.deployment.items.SelectedChatModelProviderBuildItem; +import io.quarkiverse.langchain4j.gpullama3.GPULlama3ModelRegistry; +import io.quarkiverse.langchain4j.gpullama3.runtime.GPULlama3Recorder; +import io.quarkiverse.langchain4j.gpullama3.runtime.NameAndQuantization; +import io.quarkiverse.langchain4j.gpullama3.runtime.config.ChatModelFixedRuntimeConfig; +import io.quarkiverse.langchain4j.gpullama3.runtime.config.LangChain4jGPULlama3FixedRuntimeConfig; +import io.quarkiverse.langchain4j.runtime.NamedConfigUtil; +import io.quarkus.arc.deployment.SyntheticBeanBuildItem; +import io.quarkus.builder.item.MultiBuildItem; +import io.quarkus.deployment.annotations.*; +import io.quarkus.deployment.annotations.Record; +import io.quarkus.deployment.builditem.FeatureBuildItem; +import io.quarkus.deployment.builditem.LaunchModeBuildItem; +import io.quarkus.deployment.builditem.ServiceStartBuildItem; +import io.quarkus.deployment.console.ConsoleInstalledBuildItem; +import io.quarkus.deployment.console.StartupLogCompressor; +import io.quarkus.deployment.logging.LoggingSetupBuildItem; + +public class GPULlama3Processor { + + private final static Logger LOG = Logger.getLogger(GPULlama3Processor.class); + + private static final String PROVIDER = "gpu-llama3"; + private static final String FEATURE = "langchain4j-gpu-llama3"; + + @BuildStep + FeatureBuildItem feature() { + return new FeatureBuildItem(FEATURE); + } + + @BuildStep + public void providerCandidates(BuildProducer chatProducer, + LangChain4jGPULlama3BuildTimeConfig config) { + if (config.chatModel().enabled().isEmpty() || config.chatModel().enabled().get()) { + chatProducer.produce(new ChatModelProviderCandidateBuildItem(PROVIDER)); + } + } + + @BuildStep + @Record(ExecutionTime.RUNTIME_INIT) + void generateBeans(GPULlama3Recorder recorder, + List selectedChatModels, + BuildProducer beanProducer) { + + for (var selected : selectedChatModels) { + if (PROVIDER.equals(selected.getProvider())) { + String configName = selected.getConfigName(); + + var builder = SyntheticBeanBuildItem + .configure(CHAT_MODEL) + .setRuntimeInit() + .defaultBean() + .scope(ApplicationScoped.class) + .supplier(recorder.chatModel(configName)); + + beanProducer.produce(builder.done()); + } + } + } + + @SuppressWarnings("OptionalUsedAsFieldOrParameterType") + @Produce(ServiceStartBuildItem.class) + @BuildStep + void downloadModels(List selectedChatModels, + LoggingSetupBuildItem loggingSetupBuildItem, + Optional consoleInstalledBuildItem, + LaunchModeBuildItem launchMode, + LangChain4jGPULlama3BuildTimeConfig buildTimeConfig, + LangChain4jGPULlama3FixedRuntimeConfig fixedRuntimeConfig, + BuildProducer modelDownloadedProducer) { + if (!buildTimeConfig.includeModelsInArtifact()) { + return; + } + GPULlama3ModelRegistry registry = GPULlama3ModelRegistry.getOrCreate(fixedRuntimeConfig.modelsPath()); + + BigDecimal ONE_HUNDRED = new BigDecimal("100"); + + if (buildTimeConfig.chatModel().enabled().orElse(true)) { + List modelsNeeded = new ArrayList<>(); + for (var selected : selectedChatModels) { + if (PROVIDER.equals(selected.getProvider())) { + String configName = selected.getConfigName(); + + ChatModelFixedRuntimeConfig matchingConfig = NamedConfigUtil.isDefault(configName) + ? 
fixedRuntimeConfig.defaultConfig().chatModel() + : fixedRuntimeConfig.namedConfig().get(configName).chatModel(); + modelsNeeded.add(new NameAndQuantization(matchingConfig.modelName(), matchingConfig.quantization())); + } + } + + if (!modelsNeeded.isEmpty()) { + StartupLogCompressor compressor = new StartupLogCompressor( + (launchMode.isTest() ? "(test) " : "") + "GPULlama3.java model pull:", + consoleInstalledBuildItem, + loggingSetupBuildItem); + + for (var model : modelsNeeded) { + GPULlama3ModelRegistry.ModelInfo modelInfo = GPULlama3ModelRegistry.ModelInfo.from(model.name()); + Path pathOfModelDirOnDisk = registry.constructModelDirectoryPath(modelInfo); + // Check if the model is already downloaded + // this is done automatically by download model, but we want to provide a good progress experience, so we do it again here + if (Files.exists(pathOfModelDirOnDisk.resolve(GPULlama3ModelRegistry.FINISHED_MARKER))) { + LOG.debug("Model " + model.name() + "already exists in " + pathOfModelDirOnDisk); + } else { + // we pull one model at a time and provide progress updates to the user via logging + LOG.info("Pulling model " + model.name()); + + AtomicReference LAST_UPDATE_REF = new AtomicReference<>(); + + try { + registry.downloadModel(model.name(), model.quantization(), Optional.empty(), + Optional.of(new GPULlama3ModelRegistry.ProgressReporter() { + @Override + public void update(String filename, long sizeDownloaded, long totalSize) { + // Jlama downloads a bunch of files for each mode of which only the + // weights file is large + // and makes sense to report progress on + if (totalSize < 100_000) { + return; + } + + if (!logUpdate(LAST_UPDATE_REF.get())) { + return; + } + + LAST_UPDATE_REF.set(System.nanoTime()); + + BigDecimal percentage = new BigDecimal(sizeDownloaded) + .divide(new BigDecimal(totalSize), + 4, + RoundingMode.HALF_DOWN) + .multiply(ONE_HUNDRED); + BigDecimal progress = percentage.setScale(2, RoundingMode.HALF_DOWN); + if (progress.compareTo(ONE_HUNDRED) >= 0) { + // avoid showing 100% for too long + LOG.infof("Verifying and cleaning up\n", progress); + } else { + LOG.infof("%s - Progress: %s%%\n", model.name(), progress); + } + } + + /** + * @param lastUpdate The last update time in nanoseconds + * Determines whether we should log an update. + * This is done in order to not overwhelm the console with updates which might + * make + * canceling the download difficult. 
See + * this + */ + private boolean logUpdate(Long lastUpdate) { + if (lastUpdate == null) { + return true; + } else { + return TimeUnit.NANOSECONDS.toMillis(System.nanoTime()) + - TimeUnit.NANOSECONDS.toMillis(lastUpdate) > 1_000; + } + } + })); + } catch (IOException e) { + compressor.closeAndDumpCaptured(); + } catch (InterruptedException e) { + throw new RuntimeException(e); + } + } + + modelDownloadedProducer.produce(new ModelDownloadedBuildItem(model, pathOfModelDirOnDisk)); + } + + compressor.close(); + } + } + + } + + public static final class ModelDownloadedBuildItem extends MultiBuildItem { + + private final NameAndQuantization model; + private final Path directory; + + public ModelDownloadedBuildItem(NameAndQuantization model, Path directory) { + this.model = model; + this.directory = directory; + } + + public NameAndQuantization getModel() { + return model; + } + + public Path getDirectory() { + return directory; + } + } +} diff --git a/model-providers/gpu-llama3/deployment/src/main/java/io/quarkiverse/langchain4j/gpullama3/deployment/LangChain4jGPULlama3BuildTimeConfig.java b/model-providers/gpu-llama3/deployment/src/main/java/io/quarkiverse/langchain4j/gpullama3/deployment/LangChain4jGPULlama3BuildTimeConfig.java new file mode 100644 index 000000000..52ff04989 --- /dev/null +++ b/model-providers/gpu-llama3/deployment/src/main/java/io/quarkiverse/langchain4j/gpullama3/deployment/LangChain4jGPULlama3BuildTimeConfig.java @@ -0,0 +1,24 @@ +package io.quarkiverse.langchain4j.gpullama3.deployment; + +import static io.quarkus.runtime.annotations.ConfigPhase.BUILD_TIME; + +import io.quarkus.runtime.annotations.ConfigRoot; +import io.smallrye.config.ConfigMapping; +import io.smallrye.config.WithDefault; + +@ConfigRoot(phase = BUILD_TIME) +@ConfigMapping(prefix = "quarkus.langchain4j.gpu-llama3") +public interface LangChain4jGPULlama3BuildTimeConfig { + + /** + * Determines whether the necessary GPULlama3 models are downloaded and included in the jar at build time. + * Currently, this option is only valid for {@code fast-jar} deployments. 
+ */ + @WithDefault("true") + boolean includeModelsInArtifact(); + + /** + * Chat model related settings + */ + ChatModelBuildConfig chatModel(); +} diff --git a/model-providers/gpu-llama3/pom.xml b/model-providers/gpu-llama3/pom.xml new file mode 100644 index 000000000..57b0f51e4 --- /dev/null +++ b/model-providers/gpu-llama3/pom.xml @@ -0,0 +1,26 @@ + + + 4.0.0 + + + io.quarkiverse.langchain4j + quarkus-langchain4j-parent + 999-SNAPSHOT + ../../pom.xml + + + quarkus-langchain4j-gpu-llama3-parent + Quarkus LangChain4j - GPULlama3.java - Parent + pom + + + 21 + + + + runtime + deployment + + diff --git a/model-providers/gpu-llama3/runtime/pom.xml b/model-providers/gpu-llama3/runtime/pom.xml new file mode 100644 index 000000000..249563efc --- /dev/null +++ b/model-providers/gpu-llama3/runtime/pom.xml @@ -0,0 +1,105 @@ + + + 4.0.0 + + + io.quarkiverse.langchain4j + quarkus-langchain4j-gpu-llama3-parent + 999-SNAPSHOT + + + quarkus-langchain4j-gpu-llama3 + Quarkus LangChain4j - GPULlama3 - Runtime + + + + io.quarkus + quarkus-arc + + + io.quarkiverse.langchain4j + quarkus-langchain4j-core + ${project.version} + + + + io.quarkus + quarkus-junit5-internal + test + + + org.mockito + mockito-core + test + + + org.assertj + assertj-core + ${assertj.version} + test + + + + io.github.beehive-lab + gpu-llama3 + ${gpu-llama3.version} + + + + + + + io.quarkus + quarkus-extension-maven-plugin + ${quarkus.version} + + + compile + + extension-descriptor + + + ${project.groupId}:${project.artifactId}-deployment:${project.version} + + + + + + org.graalvm:graal-sdk + io.github.beehive-lab:gpu-llama3 + + + org.graalvm:graal-sdk + io.github.beehive-lab:gpu-llama3 + + + + + maven-compiler-plugin + + + 21 + 21 + + + --add-modules=jdk.incubator.vector + + + + io.quarkus + quarkus-extension-processor + ${quarkus.version} + + + + + + + + diff --git a/model-providers/gpu-llama3/runtime/src/main/java/io/quarkiverse/langchain4j/gpullama3/Consts.java b/model-providers/gpu-llama3/runtime/src/main/java/io/quarkiverse/langchain4j/gpullama3/Consts.java new file mode 100644 index 000000000..678492408 --- /dev/null +++ b/model-providers/gpu-llama3/runtime/src/main/java/io/quarkiverse/langchain4j/gpullama3/Consts.java @@ -0,0 +1,17 @@ +package io.quarkiverse.langchain4j.gpullama3; + +public final class Consts { + + private Consts() { + } + + /** + * working links: + * https://huggingface.co/beehive-lab/Llama-3.2-1B-Instruct-GGUF/blob/main/Llama-3.2-1B-Instruct-FP16.gguf + * https://huggingface.co/ggml-org/Qwen3-0.6B-GGUF/resolve/main/Qwen3-0.6B-f16.gguf + */ + + public static final String DEFAULT_CHAT_MODEL_NAME = "beehive-lab/Llama-3.2-1B-Instruct-GGUF"; + public static final String DEFAULT_CHAT_MODEL_QUANTIZATION = "FP16"; + +} \ No newline at end of file diff --git a/model-providers/gpu-llama3/runtime/src/main/java/io/quarkiverse/langchain4j/gpullama3/GPULlama3BaseModel.java b/model-providers/gpu-llama3/runtime/src/main/java/io/quarkiverse/langchain4j/gpullama3/GPULlama3BaseModel.java new file mode 100644 index 000000000..e1d2d719a --- /dev/null +++ b/model-providers/gpu-llama3/runtime/src/main/java/io/quarkiverse/langchain4j/gpullama3/GPULlama3BaseModel.java @@ -0,0 +1,150 @@ +package io.quarkiverse.langchain4j.gpullama3; + +import java.io.IOException; +import java.nio.file.Path; +import java.util.ArrayList; +import java.util.List; +import java.util.Set; +import java.util.function.IntConsumer; + +import org.beehive.gpullama3.auxiliary.LastRunMetrics; +import org.beehive.gpullama3.inference.sampler.Sampler; +import 
org.beehive.gpullama3.inference.state.State; +import org.beehive.gpullama3.model.Model; +import org.beehive.gpullama3.model.format.ChatFormat; +import org.beehive.gpullama3.model.loader.ModelLoader; +import org.beehive.gpullama3.tornadovm.TornadoVMMasterPlan; + +import dev.langchain4j.data.message.AiMessage; +import dev.langchain4j.data.message.ChatMessage; +import dev.langchain4j.data.message.SystemMessage; +import dev.langchain4j.data.message.UserMessage; +import dev.langchain4j.model.chat.request.ChatRequest; + +abstract class GPULlama3BaseModel { + State state; + List promptTokens; + ChatFormat chatFormat; + TornadoVMMasterPlan tornadoVMPlan; + private Integer maxTokens; + private Boolean onGPU; + private Model model; + private Sampler sampler; + + // @formatter:off + public void init( + Path modelPath, + Double temperature, + Double topP, + Integer seed, + Integer maxTokens, + Boolean onGPU) { + this.maxTokens = maxTokens; + this.onGPU = onGPU; + + try { + this.model = ModelLoader.loadModel(modelPath, maxTokens, true, onGPU); + this.state = model.createNewState(); + this.sampler = Sampler.selectSampler( + model.configuration().vocabularySize(), temperature.floatValue(), topP.floatValue(), seed); + this.chatFormat = model.chatFormat(); + if (onGPU) { + tornadoVMPlan = TornadoVMMasterPlan.initializeTornadoVMPlan(state, model); + // cleanup ? + } else { + tornadoVMPlan = null; + } + } catch (IOException e) { + throw new RuntimeException("Failed to load model from " + modelPath, e); + } + } + + public Model getModel() { + return model; + } + + public Sampler getSampler() { + return sampler; + } + + public String modelResponse(ChatRequest request, IntConsumer tokenConsumer) { + this.promptTokens = new ArrayList<>(); + + if (model.shouldAddBeginOfText()) { + promptTokens.add(chatFormat.getBeginOfText()); + } + + processPromptMessages(request.messages()); + + Set stopTokens = chatFormat.getStopTokens(); + List responseTokens; + + if (onGPU) { + responseTokens = model.generateTokensGPU( + state, + 0, + promptTokens.subList(0, promptTokens.size()), + stopTokens, + maxTokens, + sampler, + false, + tokenConsumer, + tornadoVMPlan); + } else { + responseTokens = model.generateTokens( + state, + 0, + promptTokens.subList(0, promptTokens.size()), + stopTokens, + maxTokens, + sampler, + false, + tokenConsumer); + } + + Integer stopToken = null; + if (!responseTokens.isEmpty() && stopTokens.contains(responseTokens.getLast())) { + stopToken = responseTokens.getLast(); + responseTokens.removeLast(); + } + + String responseText = model.tokenizer().decode(responseTokens); + + // Add the response content tokens to conversation history + promptTokens.addAll(responseTokens); + + // Add the stop token to complete the message + if (stopToken != null) { + promptTokens.add(stopToken); + } + + if (stopToken == null) { + return "Ran out of context length...\n Increase context length with by passing to llama-tornado --max-tokens XXX"; + } else { + return responseText; + } + } + // @formatter:on + + public void printLastMetrics() { + LastRunMetrics.printMetrics(); + } + + private void processPromptMessages(List messageList) { + for (ChatMessage msg : messageList) { + if (msg instanceof UserMessage userMessage) { + promptTokens.addAll(chatFormat.encodeMessage( + new ChatFormat.Message(ChatFormat.Role.USER, userMessage.singleText()))); + } else if (msg instanceof SystemMessage systemMessage && model.shouldAddSystemPrompt()) { + promptTokens.addAll( + chatFormat.encodeMessage(new 
ChatFormat.Message(ChatFormat.Role.SYSTEM, systemMessage.text()))); + } else if (msg instanceof AiMessage aiMessage) { + promptTokens.addAll( + chatFormat.encodeMessage(new ChatFormat.Message(ChatFormat.Role.ASSISTANT, aiMessage.text()))); + } + } + + // EncodeHeader to prime the model to start generating a new assistant response. + promptTokens.addAll(chatFormat.encodeHeader(new ChatFormat.Message(ChatFormat.Role.ASSISTANT, ""))); + } +} diff --git a/model-providers/gpu-llama3/runtime/src/main/java/io/quarkiverse/langchain4j/gpullama3/GPULlama3ChatModel.java b/model-providers/gpu-llama3/runtime/src/main/java/io/quarkiverse/langchain4j/gpullama3/GPULlama3ChatModel.java new file mode 100644 index 000000000..4301a2f58 --- /dev/null +++ b/model-providers/gpu-llama3/runtime/src/main/java/io/quarkiverse/langchain4j/gpullama3/GPULlama3ChatModel.java @@ -0,0 +1,129 @@ +package io.quarkiverse.langchain4j.gpullama3; + +import static dev.langchain4j.internal.Utils.getOrDefault; + +import java.io.IOException; +import java.io.UncheckedIOException; +import java.nio.file.Path; +import java.util.Optional; + +import dev.langchain4j.data.message.AiMessage; +import dev.langchain4j.internal.ChatRequestValidationUtils; +import dev.langchain4j.model.chat.ChatModel; +import dev.langchain4j.model.chat.request.ChatRequest; +import dev.langchain4j.model.chat.request.ChatRequestParameters; +import dev.langchain4j.model.chat.response.ChatResponse; + +public class GPULlama3ChatModel extends GPULlama3BaseModel implements ChatModel { + + // @formatter:off + private GPULlama3ChatModel(Builder builder) { + GPULlama3ModelRegistry gpuLlama3ModelRegistry = GPULlama3ModelRegistry.getOrCreate(builder.modelCachePath); + try { + Path modelPath = gpuLlama3ModelRegistry.downloadModel(builder.modelName, builder.quantization, + Optional.empty(), Optional.empty()); + init( + modelPath, + getOrDefault(builder.temperature, 0.1), + getOrDefault(builder.topP, 1.0), + getOrDefault(builder.seed, 12345), + getOrDefault(builder.maxTokens, 512), + getOrDefault(builder.onGPU, Boolean.TRUE)); + } catch (IOException e) { + throw new UncheckedIOException(e); + } catch (InterruptedException e) { + throw new RuntimeException(e); + } + } + // @formatter:on + + public static Builder builder() { + return new Builder(); + } + + @Override + public ChatResponse doChat(ChatRequest chatRequest) { + ChatRequestValidationUtils.validateMessages(chatRequest.messages()); + ChatRequestParameters parameters = chatRequest.parameters(); + ChatRequestValidationUtils.validateParameters(parameters); + ChatRequestValidationUtils.validate(parameters.toolChoice()); + ChatRequestValidationUtils.validate(parameters.responseFormat()); + + try { + // Generate a raw response from the model + String rawResponse = modelResponse(chatRequest, null); + + // Parse thinking and actual response using the GPULlama3ResponseParser + GPULlama3ResponseParser.ParsedResponse parsed = GPULlama3ResponseParser.parseResponse(rawResponse); + + return ChatResponse.builder() + .aiMessage(AiMessage.builder() + .text(parsed.getActualResponse()) + .thinking(parsed.getThinkingContent()) + .build()) + .build(); + } catch (Exception e) { + throw new RuntimeException("Failed to generate response from GPULlama3", e); + } + } + + public static class Builder { + + private Optional modelCachePath; + private String modelName = Consts.DEFAULT_CHAT_MODEL_NAME; + private String quantization = Consts.DEFAULT_CHAT_MODEL_QUANTIZATION; + protected Double temperature; + protected Double topP; + protected Integer 
seed; + protected Integer maxTokens; + protected Boolean onGPU; + + public Builder() { + // This is public so it can be extended + } + + public Builder modelCachePath(Optional modelCachePath) { + this.modelCachePath = modelCachePath; + return this; + } + + public Builder modelName(String modelName) { + this.modelName = modelName; + return this; + } + + public Builder quantization(String quantization) { + this.quantization = quantization; + return this; + } + + public Builder onGPU(Boolean onGPU) { + this.onGPU = onGPU; + return this; + } + + public Builder temperature(Double temperature) { + this.temperature = temperature; + return this; + } + + public Builder topP(Double topP) { + this.topP = topP; + return this; + } + + public Builder maxTokens(Integer maxTokens) { + this.maxTokens = maxTokens; + return this; + } + + public Builder seed(Integer seed) { + this.seed = seed; + return this; + } + + public GPULlama3ChatModel build() { + return new GPULlama3ChatModel(this); + } + } +} diff --git a/model-providers/gpu-llama3/runtime/src/main/java/io/quarkiverse/langchain4j/gpullama3/GPULlama3ModelRegistry.java b/model-providers/gpu-llama3/runtime/src/main/java/io/quarkiverse/langchain4j/gpullama3/GPULlama3ModelRegistry.java new file mode 100644 index 000000000..de62ca701 --- /dev/null +++ b/model-providers/gpu-llama3/runtime/src/main/java/io/quarkiverse/langchain4j/gpullama3/GPULlama3ModelRegistry.java @@ -0,0 +1,249 @@ +package io.quarkiverse.langchain4j.gpullama3; + +import java.io.*; +import java.net.URI; +import java.net.http.HttpClient; +import java.net.http.HttpRequest; +import java.net.http.HttpResponse; +import java.nio.file.Files; +import java.nio.file.Path; +import java.nio.file.Paths; +import java.nio.file.StandardCopyOption; +import java.util.Objects; +import java.util.Optional; +import java.util.concurrent.CompletableFuture; + +import org.jboss.logging.Logger; + +/** + * A registry for managing GPULlama3.java models on local disk. + *

+ * Models are downloaded from the Beehive Lab HuggingFace repository.
+ *

+ * Reused implementation of {@link io.quarkiverse.langchain4j.llama3.Llama3ModelRegistry} + */ +@SuppressWarnings("OptionalUsedAsFieldOrParameterType") +public class GPULlama3ModelRegistry { + + private static final Logger LOG = Logger.getLogger(GPULlama3ModelRegistry.class); + + private static final String DEFAULT_MODEL_CACHE_PATH = System.getProperty("user.home", "") + File.separator + ".langchain4j" + + File.separator + "models"; + + public static String FINISHED_MARKER = ".finished"; + + private final Path modelCachePath; + + private GPULlama3ModelRegistry(Path modelCachePath) { + this.modelCachePath = modelCachePath; + if (!Files.exists(modelCachePath)) { + try { + Files.createDirectories(modelCachePath); + } catch (IOException e) { + throw new IOError(e); + } + } + } + + public static GPULlama3ModelRegistry getOrCreate(Optional modelCachePath) { + return new GPULlama3ModelRegistry(modelCachePath.orElse(Path.of(DEFAULT_MODEL_CACHE_PATH))); + } + + public Path constructModelDirectoryPath(ModelInfo modelInfo) { + return Paths.get(modelCachePath.toAbsolutePath().toString(), modelInfo.owner() + "_" + modelInfo.name()); + } + + public Path constructGgufModelFilePath(ModelInfo modelInfo, String quantization) { + String effectiveFileName = getEffectiveFileName(modelInfo, quantization); + Path modelDirectory = constructModelDirectoryPath(modelInfo); + return modelDirectory.resolve(effectiveFileName); + } + + public Path downloadModel(String modelName, String quantization, Optional authToken, + Optional maybeProgressReporter) + throws IOException, InterruptedException { + ModelInfo modelInfo = ModelInfo.from(modelName); + + String effectiveFileName = getEffectiveFileName(modelInfo, quantization); + Path modelDirectory = constructModelDirectoryPath(modelInfo); + Path result = modelDirectory.resolve(effectiveFileName); + if (Files.exists(result) && Files.exists(modelDirectory.resolve(FINISHED_MARKER))) { + return result; + } + + HttpClient client = HttpClient.newBuilder().followRedirects(HttpClient.Redirect.ALWAYS).build(); + URI uri = URI.create( + String.format("https://huggingface.co/%s/%s/resolve/main/%s", modelInfo.owner(), modelInfo.name(), + effectiveFileName)); + HttpRequest request = HttpRequest.newBuilder().uri(uri).build(); + HttpResponse httpResponse = client.send(request, HttpResponse.BodyHandlers.ofInputStream()); + if (httpResponse.statusCode() != 200) { + throw new RuntimeException( + "Unable to download model " + modelName + ". 
Response code from " + uri + " is : " + + httpResponse.statusCode()); + } + Files.createDirectories(result.getParent()); + long totalBytes = httpResponse.headers().firstValueAsLong("content-length").orElse(-1); + ProgressReporter progressReporter = maybeProgressReporter.orElse((filename, sizeDownloaded, totalSize) -> { + }); + + if (maybeProgressReporter.isEmpty()) { + LOG.info("Downloading file " + result.toAbsolutePath()); + } + String resultFileName = result.getFileName().toString(); + progressReporter.update(resultFileName, 0L, totalBytes); + + try (CountingInputStream inStream = new CountingInputStream(httpResponse.body())) { + CompletableFuture cf = CompletableFuture.supplyAsync(() -> { + try { + return Files.copy(inStream, result, StandardCopyOption.REPLACE_EXISTING); + } catch (IOException e) { + throw new RuntimeException(e); + } + }); + while (!cf.isDone()) { + progressReporter.update(resultFileName, inStream.count, totalBytes); + } + if (cf.isCompletedExceptionally()) { + progressReporter.update(resultFileName, inStream.count, totalBytes); + } else { + progressReporter.update(resultFileName, totalBytes, totalBytes); + } + + try { + cf.get(); + } catch (Throwable e) { + throw new IOException("Failed to download file: " + resultFileName, e); + } + if (maybeProgressReporter.isEmpty()) { + LOG.info("Downloaded file " + result.toAbsolutePath()); + } + } + + // create a finished marker + Files.createFile(modelDirectory.resolve(FINISHED_MARKER)); + return result; + } + + private String getEffectiveFileName(ModelInfo modelInfo, String quantization) { + String effectiveFileName = modelInfo.name(); + if (effectiveFileName.endsWith("-GGUF")) { + effectiveFileName = effectiveFileName.substring(0, effectiveFileName.length() - 5); + } + effectiveFileName = effectiveFileName + "-" + quantization + ".gguf"; + return effectiveFileName; + } + + /** + * This interface reports the progress of a .gguf file download. + * The implementation of the update method is used to communicate this progress. + */ + public interface ProgressReporter { + + void update(String filename, long sizeDownloaded, long totalSize); + } + + /** + * ModelInfo is a simple data class that represents a model's owner and name. + *

+ * Reused implementation of {@link io.quarkiverse.langchain4j.llama3.Llama3ModelRegistry} + */ + public record ModelInfo(String owner, String name) { + + public static ModelInfo from(String modelName) { + String[] parts = modelName.split("/"); + if (parts.length == 0 || parts.length > 2) { + throw new IllegalArgumentException("Model must be in the form owner/name"); + } + + String owner; + String name; + + if (parts.length == 1) { + owner = null; + name = modelName; + } else { + owner = parts[0]; + name = parts[1]; + } + + return new ModelInfo(owner, name); + } + + public String toFileName() { + return owner + "_" + name; + } + } + + /** + * An {@link InputStream} that counts the number of bytes read. + * + * @author Chris Nokleberg + * + * Copied from Guava + */ + public static final class CountingInputStream extends FilterInputStream { + + private long count; + private long mark = -1; + + /** + * Wraps another input stream, counting the number of bytes read. + * + * @param in the input stream to be wrapped + */ + public CountingInputStream(InputStream in) { + super(Objects.requireNonNull(in)); + } + + /** Returns the number of bytes read. */ + public long getCount() { + return count; + } + + @Override + public int read() throws IOException { + int result = in.read(); + if (result != -1) { + count++; + } + return result; + } + + @Override + public int read(byte[] b, int off, int len) throws IOException { + int result = in.read(b, off, len); + if (result != -1) { + count += result; + } + return result; + } + + @Override + public long skip(long n) throws IOException { + long result = in.skip(n); + count += result; + return result; + } + + @Override + public synchronized void mark(int readlimit) { + in.mark(readlimit); + mark = count; + // it's okay to mark even if mark isn't supported, as reset won't work + } + + @Override + public synchronized void reset() throws IOException { + if (!in.markSupported()) { + throw new IOException("Mark not supported"); + } + if (mark == -1) { + throw new IOException("Mark not set"); + } + + in.reset(); + count = mark; + } + } +} diff --git a/model-providers/gpu-llama3/runtime/src/main/java/io/quarkiverse/langchain4j/gpullama3/GPULlama3ResponseParser.java b/model-providers/gpu-llama3/runtime/src/main/java/io/quarkiverse/langchain4j/gpullama3/GPULlama3ResponseParser.java new file mode 100644 index 000000000..76a50a562 --- /dev/null +++ b/model-providers/gpu-llama3/runtime/src/main/java/io/quarkiverse/langchain4j/gpullama3/GPULlama3ResponseParser.java @@ -0,0 +1,201 @@ +package io.quarkiverse.langchain4j.gpullama3; + +import dev.langchain4j.model.chat.response.PartialThinking; +import dev.langchain4j.model.chat.response.StreamingChatResponseHandler; + +public class GPULlama3ResponseParser { + + private GPULlama3ResponseParser() { + // Utility class - prevent instantiation + } + + public static class ParsedResponse { + private final String thinkingContent; + private final String actualResponse; + + /** + * Creates a new ParsedResponse. + * + * @param thinkingContent the thinking content including tags, or null if none + * @param actualResponse the cleaned response content + */ + public ParsedResponse(String thinkingContent, String actualResponse) { + this.thinkingContent = thinkingContent; + this.actualResponse = actualResponse; + } + + /** + * Returns the thinking content including <think> and </think> tags. 
+ * + * @return the thinking content with tags, or null if no thinking content was found + */ + public String getThinkingContent() { + return thinkingContent; + } + + /** + * Returns the actual response content with thinking tags removed. + * + * @return the cleaned response content + */ + public String getActualResponse() { + return actualResponse; + } + + /** + * Returns true if the response contained thinking content. + * + * @return true if thinking content was found, false otherwise + */ + public boolean hasThinking() { + return thinkingContent != null && !thinkingContent.trim().isEmpty(); + } + } + + public static ParsedResponse parseResponse(String rawResponse) { + if (rawResponse == null) { + throw new IllegalArgumentException("Raw response cannot be null"); + } + + String thinking = null; + String actualResponse = rawResponse; + + // Find and positions + int thinkStart = rawResponse.indexOf(""); + int thinkEnd = rawResponse.indexOf(""); + + if (thinkStart != -1 && thinkEnd != -1 && thinkEnd > thinkStart) { + // Extract thinking content INCLUDING the tags + thinking = rawResponse.substring(thinkStart, thinkEnd + 8).trim(); // Include + + // Remove the entire thinking block from response + String beforeThink = rawResponse.substring(0, thinkStart); + String afterThink = rawResponse.substring(thinkEnd + 8); // Skip + actualResponse = (beforeThink + afterThink).trim(); + + // Clean up any extra whitespace + actualResponse = actualResponse.replaceAll("\\s+", " ").trim(); + } + + return new ParsedResponse(thinking, actualResponse); + } + + public static String extractThinking(String rawResponse) { + return parseResponse(rawResponse).getThinkingContent(); + } + + public static String extractResponse(String rawResponse) { + return parseResponse(rawResponse).getActualResponse(); + } + + public static StreamingParser createStreamingParser( + StreamingChatResponseHandler handler, org.beehive.gpullama3.model.Model model) { + return new StreamingParser(handler, model); + } + + /** + * Parser for handling streaming responses with real-time thinking content separation. + *

+ * This parser detects thinking content as tokens are generated and routes it to + * the appropriate handler methods (onPartialThinking vs onPartialResponse). + * The thinking tags are preserved and streamed as part of the thinking content. + */ + public static class StreamingParser { + private final StreamingChatResponseHandler handler; + private final org.beehive.gpullama3.model.Model model; + private final StringBuilder buffer = new StringBuilder(); + private boolean insideThinking = false; + private int lastProcessedLength = 0; + + /** + * Creates a new streaming parser. + * + * @param handler the streaming response handler + * @param model the GPULlama3 model instance for token decoding + */ + public StreamingParser(StreamingChatResponseHandler handler, org.beehive.gpullama3.model.Model model) { + this.handler = handler; + this.model = model; + } + + /** + * Processes each token as it's generated by the model. + * + * @param tokenId the token ID generated by the model + */ + public void onToken(int tokenId) { + // Check if this is a stop token and skip it + if (model.chatFormat().getStopTokens().contains(tokenId)) { + return; // Don't stream stop tokens like <|im_end|> + } + + // Decode the token and add to buffer + String tokenStr = model.tokenizer().decode(java.util.List.of(tokenId)); + buffer.append(tokenStr); + + String currentText = buffer.toString(); + + // Process any new content since last time + processNewContent(currentText); + } + + /** + * Processes new content in the buffer, detecting thinking state transitions + * and routing content to appropriate handler methods. + */ + private void processNewContent(String currentText) { + if (currentText.length() <= lastProcessedLength) { + return; // No new content + } + + String newContent = currentText.substring(lastProcessedLength); + + // Process each character in the new content + for (int i = 0; i < newContent.length(); i++) { + int currentPosition = lastProcessedLength + i; + + // Check if we're starting thinking + if (!insideThinking && isStartOfThinkTag(currentText, currentPosition)) { + insideThinking = true; + // Stream the opening tag as thinking + handler.onPartialThinking(new PartialThinking("")); + i += 6; // Skip the rest of "" + continue; + } + + // Check if we're ending thinking + if (insideThinking && isStartOfEndThinkTag(currentText, currentPosition)) { + // Stream the closing tag as thinking + handler.onPartialThinking(new PartialThinking("")); + insideThinking = false; + i += 7; // Skip the rest of "" + continue; + } + + // Stream the character to appropriate handler + char c = newContent.charAt(i); + if (insideThinking) { + handler.onPartialThinking(new PartialThinking(String.valueOf(c))); + } else { + handler.onPartialResponse(String.valueOf(c)); + } + } + + lastProcessedLength = currentText.length(); + } + + /** + * Checks if the text at the given position starts with "<think>". + */ + private boolean isStartOfThinkTag(String text, int position) { + return position + 7 <= text.length() && text.regionMatches(position, "", 0, 7); + } + + /** + * Checks if the text at the given position starts with "</think>". 
+ */ + private boolean isStartOfEndThinkTag(String text, int position) { + return position + 8 <= text.length() && text.regionMatches(position, "", 0, 8); + } + } +} diff --git a/model-providers/gpu-llama3/runtime/src/main/java/io/quarkiverse/langchain4j/gpullama3/runtime/GPULlama3Recorder.java b/model-providers/gpu-llama3/runtime/src/main/java/io/quarkiverse/langchain4j/gpullama3/runtime/GPULlama3Recorder.java new file mode 100644 index 000000000..9b738c269 --- /dev/null +++ b/model-providers/gpu-llama3/runtime/src/main/java/io/quarkiverse/langchain4j/gpullama3/runtime/GPULlama3Recorder.java @@ -0,0 +1,89 @@ +package io.quarkiverse.langchain4j.gpullama3.runtime; + +import java.util.function.Supplier; + +import org.jboss.logging.Logger; + +import dev.langchain4j.model.chat.ChatModel; +import dev.langchain4j.model.chat.DisabledChatModel; +import io.quarkiverse.langchain4j.gpullama3.GPULlama3ChatModel; +import io.quarkiverse.langchain4j.gpullama3.runtime.config.LangChain4jGPULlama3FixedRuntimeConfig; +import io.quarkiverse.langchain4j.gpullama3.runtime.config.LangChain4jGPULlama3RuntimeConfig; +import io.quarkiverse.langchain4j.runtime.NamedConfigUtil; +import io.quarkus.runtime.RuntimeValue; +import io.quarkus.runtime.annotations.Recorder; + +@Recorder +public class GPULlama3Recorder { + + private static final Logger LOG = Logger.getLogger(GPULlama3Recorder.class); + + private final RuntimeValue runtimeConfig; + private final RuntimeValue fixedRuntimeConfig; + + public GPULlama3Recorder(RuntimeValue runtimeConfig, + RuntimeValue fixedRuntimeConfig) { + this.runtimeConfig = runtimeConfig; + this.fixedRuntimeConfig = fixedRuntimeConfig; + } + + public Supplier chatModel(String configName) { + var gpuLlama3Config = correspondingConfig(configName); + var gpuLlama3FixedRuntimeConfig = correspondingFixedConfig(configName); + + if (gpuLlama3Config.enableIntegration()) { + LOG.info("Creating GPULlama3ChatModel for config: " + configName); + var chatModelConfig = gpuLlama3Config.chatModel(); + + var builder = GPULlama3ChatModel.builder() + .modelName(gpuLlama3FixedRuntimeConfig.chatModel().modelName()) + .quantization(gpuLlama3FixedRuntimeConfig.chatModel().quantization()) + .onGPU(Boolean.TRUE) + .modelCachePath(fixedRuntimeConfig.getValue().modelsPath()); + + if (chatModelConfig.temperature().isPresent()) { + builder.temperature(chatModelConfig.temperature().getAsDouble()); + } + if (chatModelConfig.topP().isPresent()) { + builder.topP(chatModelConfig.topP().getAsDouble()); + } + if (chatModelConfig.maxTokens().isPresent()) { + builder.maxTokens(chatModelConfig.maxTokens().getAsInt()); + } + if (chatModelConfig.seed().isPresent()) { + builder.seed(chatModelConfig.seed().getAsInt()); + } + + return new Supplier<>() { + @Override + public ChatModel get() { + return builder.build(); + } + }; + } else { + return new Supplier<>() { + @Override + public ChatModel get() { + return new DisabledChatModel(); + } + }; + } + } + + private LangChain4jGPULlama3RuntimeConfig.GPULlama3Config correspondingConfig(String configName) { + return NamedConfigUtil.isDefault(configName) + ? runtimeConfig.getValue().defaultConfig() + : runtimeConfig.getValue().namedConfig().get(configName); + } + + private LangChain4jGPULlama3FixedRuntimeConfig.GPULlama3Config correspondingFixedConfig(String configName) { + return NamedConfigUtil.isDefault(configName) + ? 
fixedRuntimeConfig.getValue().defaultConfig() + : fixedRuntimeConfig.getValue().namedConfig().get(configName); + } + + private boolean inDebugMode() { + return LOG.isDebugEnabled(); + } + +} diff --git a/model-providers/gpu-llama3/runtime/src/main/java/io/quarkiverse/langchain4j/gpullama3/runtime/NameAndQuantization.java b/model-providers/gpu-llama3/runtime/src/main/java/io/quarkiverse/langchain4j/gpullama3/runtime/NameAndQuantization.java new file mode 100644 index 000000000..a229037da --- /dev/null +++ b/model-providers/gpu-llama3/runtime/src/main/java/io/quarkiverse/langchain4j/gpullama3/runtime/NameAndQuantization.java @@ -0,0 +1,4 @@ +package io.quarkiverse.langchain4j.gpullama3.runtime; + +public record NameAndQuantization(String name, String quantization) { +} diff --git a/model-providers/gpu-llama3/runtime/src/main/java/io/quarkiverse/langchain4j/gpullama3/runtime/config/ChatModelConfig.java b/model-providers/gpu-llama3/runtime/src/main/java/io/quarkiverse/langchain4j/gpullama3/runtime/config/ChatModelConfig.java new file mode 100644 index 000000000..c177b8008 --- /dev/null +++ b/model-providers/gpu-llama3/runtime/src/main/java/io/quarkiverse/langchain4j/gpullama3/runtime/config/ChatModelConfig.java @@ -0,0 +1,41 @@ +package io.quarkiverse.langchain4j.gpullama3.runtime.config; + +import java.util.OptionalDouble; +import java.util.OptionalInt; + +import io.quarkus.runtime.annotations.ConfigDocDefault; +import io.quarkus.runtime.annotations.ConfigGroup; +import io.smallrye.config.WithDefault; + +@ConfigGroup +public interface ChatModelConfig { + + /** + * What sampling temperature to use, between 0.0 and 1.0. + */ + @ConfigDocDefault("0.3") + @WithDefault("${quarkus.langchain4j.temperature}") + OptionalDouble temperature(); + + /** + * What sampling topP to use, between 0.0 and 1.0. + */ + @ConfigDocDefault("0.85") + @WithDefault("${quarkus.langchain4j.top-p}") + OptionalDouble topP(); + + /** + * What seed value to use. + * + * @return + */ + @ConfigDocDefault("1234") + @WithDefault("${quarkus.langchain4j.seed}") + OptionalInt seed(); + + /** + * The maximum number of tokens to generate in the completion. 
+ */ + @ConfigDocDefault("512") + OptionalInt maxTokens(); +} \ No newline at end of file diff --git a/model-providers/gpu-llama3/runtime/src/main/java/io/quarkiverse/langchain4j/gpullama3/runtime/config/ChatModelFixedRuntimeConfig.java b/model-providers/gpu-llama3/runtime/src/main/java/io/quarkiverse/langchain4j/gpullama3/runtime/config/ChatModelFixedRuntimeConfig.java new file mode 100644 index 000000000..c7f8359ce --- /dev/null +++ b/model-providers/gpu-llama3/runtime/src/main/java/io/quarkiverse/langchain4j/gpullama3/runtime/config/ChatModelFixedRuntimeConfig.java @@ -0,0 +1,21 @@ +package io.quarkiverse.langchain4j.gpullama3.runtime.config; + +import io.quarkiverse.langchain4j.gpullama3.Consts; +import io.quarkus.runtime.annotations.ConfigGroup; +import io.smallrye.config.WithDefault; + +@ConfigGroup +public interface ChatModelFixedRuntimeConfig { + + /** + * Model name to use + */ + @WithDefault(Consts.DEFAULT_CHAT_MODEL_NAME) + String modelName(); + + /** + * Quantization of the model to use + */ + @WithDefault(Consts.DEFAULT_CHAT_MODEL_QUANTIZATION) + String quantization(); +} \ No newline at end of file diff --git a/model-providers/gpu-llama3/runtime/src/main/java/io/quarkiverse/langchain4j/gpullama3/runtime/config/LangChain4jGPULlama3FixedRuntimeConfig.java b/model-providers/gpu-llama3/runtime/src/main/java/io/quarkiverse/langchain4j/gpullama3/runtime/config/LangChain4jGPULlama3FixedRuntimeConfig.java new file mode 100644 index 000000000..f6835644c --- /dev/null +++ b/model-providers/gpu-llama3/runtime/src/main/java/io/quarkiverse/langchain4j/gpullama3/runtime/config/LangChain4jGPULlama3FixedRuntimeConfig.java @@ -0,0 +1,67 @@ +package io.quarkiverse.langchain4j.gpullama3.runtime.config; + +import static io.quarkus.runtime.annotations.ConfigPhase.BUILD_AND_RUN_TIME_FIXED; + +import java.nio.file.Path; +import java.util.Map; +import java.util.Optional; + +import io.quarkus.runtime.annotations.*; +import io.smallrye.config.ConfigMapping; +import io.smallrye.config.WithDefaults; +import io.smallrye.config.WithParentName; + +/** + * Fixed runtime configuration for GPULlama3 extension. + *

+ * This configuration is read at build time and remains fixed for the lifetime of the application. + * It includes settings that cannot be changed after the application is built, such as + * the model file path. These values are baked into the application during the build process. + *

+ * To change these settings, the application must be rebuilt with the new configuration values. + * This ensures optimal performance and allows for build-time validation and optimization. + *

+ * Example configuration:
+ *
+ * <pre>
+ * quarkus.langchain4j.gpu-llama3.chat-model.model-name=beehive-lab/Llama-3.2-1B-Instruct-GGUF
+ * </pre>
+ *
+ * Note: These properties must be set in {@code application.properties} at build time + * and cannot be overridden at runtime through environment variables or system properties. + */ +@ConfigRoot(phase = BUILD_AND_RUN_TIME_FIXED) +@ConfigMapping(prefix = "quarkus.langchain4j.gpu-llama3") +public interface LangChain4jGPULlama3FixedRuntimeConfig { + + /** + * Default model config. + */ + @WithParentName + GPULlama3Config defaultConfig(); + + /** + * Named model config. + */ + @ConfigDocSection + @ConfigDocMapKey("model-name") + @WithParentName + @WithDefaults + Map namedConfig(); + + /** + * Location on the file-system which serves as a cache for the models + * + */ + @ConfigDocDefault("${user.home}/.langchain4j/models") + Optional modelsPath(); + + @ConfigGroup + interface GPULlama3Config { + + /** + * Chat model related settings + */ + ChatModelFixedRuntimeConfig chatModel(); + } +} \ No newline at end of file diff --git a/model-providers/gpu-llama3/runtime/src/main/java/io/quarkiverse/langchain4j/gpullama3/runtime/config/LangChain4jGPULlama3RuntimeConfig.java b/model-providers/gpu-llama3/runtime/src/main/java/io/quarkiverse/langchain4j/gpullama3/runtime/config/LangChain4jGPULlama3RuntimeConfig.java new file mode 100644 index 000000000..f85dc7f24 --- /dev/null +++ b/model-providers/gpu-llama3/runtime/src/main/java/io/quarkiverse/langchain4j/gpullama3/runtime/config/LangChain4jGPULlama3RuntimeConfig.java @@ -0,0 +1,71 @@ +package io.quarkiverse.langchain4j.gpullama3.runtime.config; + +import static io.quarkus.runtime.annotations.ConfigPhase.RUN_TIME; + +import java.util.Map; +import java.util.Optional; + +import io.quarkus.runtime.annotations.ConfigDocDefault; +import io.quarkus.runtime.annotations.ConfigDocMapKey; +import io.quarkus.runtime.annotations.ConfigDocSection; +import io.quarkus.runtime.annotations.ConfigRoot; +import io.smallrye.config.ConfigMapping; +import io.smallrye.config.WithDefault; +import io.smallrye.config.WithDefaults; +import io.smallrye.config.WithParentName; + +/** + * Runtime configuration for GPULlama3 extension. + *

+ * This configuration is read at runtime and can be changed without rebuilding the application. + * It includes dynamic settings such as model parameters (temperature, max tokens), + * logging preferences, and integration control. + */ +@ConfigRoot(phase = RUN_TIME) +@ConfigMapping(prefix = "quarkus.langchain4j.gpu-llama3") +public interface LangChain4jGPULlama3RuntimeConfig { + + /** + * Default model config. + */ + @WithParentName + GPULlama3Config defaultConfig(); + + /** + * Named model config. + */ + @ConfigDocSection + @ConfigDocMapKey("model-name") + @WithParentName + @WithDefaults + Map namedConfig(); + + interface GPULlama3Config { + + /** + * Chat model related settings + */ + ChatModelConfig chatModel(); + + /** + * Whether to enable the integration. Set to {@code false} to disable + * all requests. + */ + @WithDefault("true") + Boolean enableIntegration(); + + /** + * Whether GPULlama3 should log requests + */ + @ConfigDocDefault("false") + @WithDefault("${quarkus.langchain4j.log-requests}") + Optional logRequests(); + + /** + * Whether GPULlama3 client should log responses + */ + @ConfigDocDefault("false") + @WithDefault("${quarkus.langchain4j.log-responses}") + Optional logResponses(); + } +} \ No newline at end of file diff --git a/model-providers/pom.xml b/model-providers/pom.xml index 0b0e65720..429ab7b18 100644 --- a/model-providers/pom.xml +++ b/model-providers/pom.xml @@ -38,6 +38,7 @@ jlama + gpu-llama3 diff --git a/pom.xml b/pom.xml index b62a35367..e538dc932 100644 --- a/pom.xml +++ b/pom.xml @@ -51,6 +51,7 @@ 5.6.0 0.8.4 1.37.1 + 0.2.2 0.9.2
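For reference, beyond the CDI wiring exercised by the integration test, the chat model added in this PR can also be constructed programmatically through the builder in `GPULlama3ChatModel`. The following is a minimal sketch, not part of the PR itself: the model name, quantization and cache path are just the defaults from `Consts` and the test's `application.properties`, and the JVM still has to be started with the TornadoVM argfile (for example `java @~/TornadoVM/tornado-argfile ...`) for GPU execution to work.

```java
import java.nio.file.Path;
import java.util.Optional;

import dev.langchain4j.model.chat.ChatModel;
import io.quarkiverse.langchain4j.gpullama3.GPULlama3ChatModel;

public class GpuLlama3Smoke {

    public static void main(String[] args) {
        // Mirrors what GPULlama3Recorder does when it wires the synthetic ChatModel bean:
        // the GGUF file is downloaded into the cache directory on first use and then
        // loaded through TornadoVM because onGPU is true.
        ChatModel model = GPULlama3ChatModel.builder()
                .modelName("beehive-lab/Llama-3.2-1B-Instruct-GGUF") // default from Consts
                .quantization("FP16")                                // default from Consts
                .temperature(0.7)
                .maxTokens(512)
                .onGPU(Boolean.TRUE)
                // default cache location used by GPULlama3ModelRegistry (~/.langchain4j/models)
                .modelCachePath(Optional.of(Path.of(System.getProperty("user.home"), ".langchain4j", "models")))
                .build();

        System.out.println(model.chat("When was the Nobel prize for economics first awarded?"));
    }
}
```

In a Quarkus application the same model is obtained by injecting `ChatModel`, as `ChatLanguageModelResource` does in the integration test; `GPULlama3Recorder` then applies the `quarkus.langchain4j.gpu-llama3.*` configuration at runtime.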