Skip to content

Commit 91eb9d5

Browse files
orionpapadakis
authored and geoand committed
Introduce ModelRegistry for GPULlama3 for automatic model management.
1 parent b9f3b76 commit 91eb9d5

File tree

12 files changed

+568
-93
lines changed

12 files changed

+568
-93
lines changed
Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,12 @@
1+
quarkus.langchain4j.gpu-llama3.include-models-in-artifact=false
2+
13
# Configure GPULlama3
2-
quarkus.langchain4j.gpu-llama3.chat-model.model-path=/Users/orion/LLMModels/beehive-llama-3.2-1b-instruct-fp16.gguf
34
quarkus.langchain4j.gpu-llama3.enable-integration=true
5+
quarkus.langchain4j.gpu-llama3.chat-model.model-name=beehive-lab/Llama-3.2-1B-Instruct-GGUF
6+
quarkus.langchain4j.gpu-llama3.chat-model.quantization=FP16
47
quarkus.langchain4j.gpu-llama3.chat-model.temperature=0.7
5-
quarkus.langchain4j.gpu-llama3.chat-model.max-tokens=100
8+
quarkus.langchain4j.gpu-llama3.chat-model.max-tokens=513
9+
10+
# other supported models:
11+
#model-name=ggml-org/Qwen3-0.6B-GGUF
12+
#quantization=f16

model-providers/gpu-llama3/deployment/src/main/java/io/quarkiverse/langchain4j/gpullama3/deployment/GPULlama3Processor.java

Lines changed: 159 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,22 +2,44 @@
22

33
import static io.quarkiverse.langchain4j.deployment.LangChain4jDotNames.CHAT_MODEL;
44

5+
import java.io.IOException;
6+
import java.math.BigDecimal;
7+
import java.math.RoundingMode;
8+
import java.nio.file.Files;
9+
import java.nio.file.Path;
10+
import java.util.ArrayList;
511
import java.util.List;
12+
import java.util.Optional;
13+
import java.util.concurrent.TimeUnit;
14+
import java.util.concurrent.atomic.AtomicReference;
615

716
import jakarta.enterprise.context.ApplicationScoped;
817

18+
import org.jboss.logging.Logger;
19+
920
import io.quarkiverse.langchain4j.deployment.items.ChatModelProviderCandidateBuildItem;
1021
import io.quarkiverse.langchain4j.deployment.items.SelectedChatModelProviderBuildItem;
22+
import io.quarkiverse.langchain4j.gpullama3.GPULlama3ModelRegistry;
1123
import io.quarkiverse.langchain4j.gpullama3.runtime.GPULlama3Recorder;
24+
import io.quarkiverse.langchain4j.gpullama3.runtime.NameAndQuantization;
25+
import io.quarkiverse.langchain4j.gpullama3.runtime.config.ChatModelFixedRuntimeConfig;
26+
import io.quarkiverse.langchain4j.gpullama3.runtime.config.LangChain4jGPULlama3FixedRuntimeConfig;
27+
import io.quarkiverse.langchain4j.runtime.NamedConfigUtil;
1228
import io.quarkus.arc.deployment.SyntheticBeanBuildItem;
13-
import io.quarkus.deployment.annotations.BuildProducer;
14-
import io.quarkus.deployment.annotations.BuildStep;
15-
import io.quarkus.deployment.annotations.ExecutionTime;
29+
import io.quarkus.builder.item.MultiBuildItem;
30+
import io.quarkus.deployment.annotations.*;
1631
import io.quarkus.deployment.annotations.Record;
1732
import io.quarkus.deployment.builditem.FeatureBuildItem;
33+
import io.quarkus.deployment.builditem.LaunchModeBuildItem;
34+
import io.quarkus.deployment.builditem.ServiceStartBuildItem;
35+
import io.quarkus.deployment.console.ConsoleInstalledBuildItem;
36+
import io.quarkus.deployment.console.StartupLogCompressor;
37+
import io.quarkus.deployment.logging.LoggingSetupBuildItem;
1838

1939
public class GPULlama3Processor {
2040

41+
private final static Logger LOG = Logger.getLogger(GPULlama3Processor.class);
42+
2143
private static final String PROVIDER = "gpu-llama3";
2244
private static final String FEATURE = "langchain4j-gpu-llama3";
2345

@@ -55,4 +77,138 @@ void generateBeans(GPULlama3Recorder recorder,
5577
}
5678
}
5779
}
80+
81+
@SuppressWarnings("OptionalUsedAsFieldOrParameterType")
82+
@Produce(ServiceStartBuildItem.class)
83+
@BuildStep
84+
void downloadModels(List<SelectedChatModelProviderBuildItem> selectedChatModels,
85+
LoggingSetupBuildItem loggingSetupBuildItem,
86+
Optional<ConsoleInstalledBuildItem> consoleInstalledBuildItem,
87+
LaunchModeBuildItem launchMode,
88+
LangChain4jGPULlama3BuildTimeConfig buildTimeConfig,
89+
LangChain4jGPULlama3FixedRuntimeConfig fixedRuntimeConfig,
90+
BuildProducer<ModelDownloadedBuildItem> modelDownloadedProducer) {
91+
if (!buildTimeConfig.includeModelsInArtifact()) {
92+
return;
93+
}
94+
GPULlama3ModelRegistry registry = GPULlama3ModelRegistry.getOrCreate(fixedRuntimeConfig.modelsPath());
95+
96+
BigDecimal ONE_HUNDRED = new BigDecimal("100");
97+
98+
if (buildTimeConfig.chatModel().enabled().orElse(true)) {
99+
List<NameAndQuantization> modelsNeeded = new ArrayList<>();
100+
for (var selected : selectedChatModels) {
101+
if (PROVIDER.equals(selected.getProvider())) {
102+
String configName = selected.getConfigName();
103+
104+
ChatModelFixedRuntimeConfig matchingConfig = NamedConfigUtil.isDefault(configName)
105+
? fixedRuntimeConfig.defaultConfig().chatModel()
106+
: fixedRuntimeConfig.namedConfig().get(configName).chatModel();
107+
modelsNeeded.add(new NameAndQuantization(matchingConfig.modelName(), matchingConfig.quantization()));
108+
}
109+
}
110+
111+
if (!modelsNeeded.isEmpty()) {
112+
StartupLogCompressor compressor = new StartupLogCompressor(
113+
(launchMode.isTest() ? "(test) " : "") + "GPULlama3.java model pull:",
114+
consoleInstalledBuildItem,
115+
loggingSetupBuildItem);
116+
117+
for (var model : modelsNeeded) {
118+
GPULlama3ModelRegistry.ModelInfo modelInfo = GPULlama3ModelRegistry.ModelInfo.from(model.name());
119+
Path pathOfModelDirOnDisk = registry.constructModelDirectoryPath(modelInfo);
120+
// Check if the model is already downloaded
121+
// this is done automatically by download model, but we want to provide a good progress experience, so we do it again here
122+
if (Files.exists(pathOfModelDirOnDisk.resolve(GPULlama3ModelRegistry.FINISHED_MARKER))) {
123+
LOG.debug("Model " + model.name() + " already exists in " + pathOfModelDirOnDisk);
124+
} else {
125+
// we pull one model at a time and provide progress updates to the user via logging
126+
LOG.info("Pulling model " + model.name());
127+
128+
AtomicReference<Long> LAST_UPDATE_REF = new AtomicReference<>();
129+
130+
try {
131+
registry.downloadModel(model.name(), model.quantization(), Optional.empty(),
132+
Optional.of(new GPULlama3ModelRegistry.ProgressReporter() {
133+
@Override
134+
public void update(String filename, long sizeDownloaded, long totalSize) {
135+
// GPULlama3 downloads a bunch of files for each model of which only the
136+
// weights file is large
137+
// and makes sense to report progress on
138+
if (totalSize < 100_000) {
139+
return;
140+
}
141+
142+
if (!logUpdate(LAST_UPDATE_REF.get())) {
143+
return;
144+
}
145+
146+
LAST_UPDATE_REF.set(System.nanoTime());
147+
148+
BigDecimal percentage = new BigDecimal(sizeDownloaded)
149+
.divide(new BigDecimal(totalSize),
150+
4,
151+
RoundingMode.HALF_DOWN)
152+
.multiply(ONE_HUNDRED);
153+
BigDecimal progress = percentage.setScale(2, RoundingMode.HALF_DOWN);
154+
if (progress.compareTo(ONE_HUNDRED) >= 0) {
155+
// avoid showing 100% for too long
156+
LOG.infof("Verifying and cleaning up\n", progress);
157+
} else {
158+
LOG.infof("%s - Progress: %s%%\n", model.name(), progress);
159+
}
160+
}
161+
162+
/**
163+
* @param lastUpdate The last update time in nanoseconds
164+
* Determines whether we should log an update.
165+
* This is done in order to not overwhelm the console with updates which might
166+
* make
167+
* canceling the download difficult. See
168+
* <a href=
169+
* "https://github.com/quarkiverse/quarkus-langchain4j/issues/1044">this</a>
170+
*/
171+
private boolean logUpdate(Long lastUpdate) {
172+
if (lastUpdate == null) {
173+
return true;
174+
} else {
175+
return TimeUnit.NANOSECONDS.toMillis(System.nanoTime())
176+
- TimeUnit.NANOSECONDS.toMillis(lastUpdate) > 1_000;
177+
}
178+
}
179+
}));
180+
} catch (IOException e) {
181+
compressor.closeAndDumpCaptured();
182+
} catch (InterruptedException e) {
183+
throw new RuntimeException(e);
184+
}
185+
}
186+
187+
modelDownloadedProducer.produce(new ModelDownloadedBuildItem(model, pathOfModelDirOnDisk));
188+
}
189+
190+
compressor.close();
191+
}
192+
}
193+
194+
}
195+
196+
public static final class ModelDownloadedBuildItem extends MultiBuildItem {
197+
198+
private final NameAndQuantization model;
199+
private final Path directory;
200+
201+
public ModelDownloadedBuildItem(NameAndQuantization model, Path directory) {
202+
this.model = model;
203+
this.directory = directory;
204+
}
205+
206+
public NameAndQuantization getModel() {
207+
return model;
208+
}
209+
210+
public Path getDirectory() {
211+
return directory;
212+
}
213+
}
58214
}

model-providers/gpu-llama3/deployment/src/main/java/io/quarkiverse/langchain4j/gpullama3/deployment/LangChain4jGPULlama3BuildTimeConfig.java

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,19 @@
44

55
import io.quarkus.runtime.annotations.ConfigRoot;
66
import io.smallrye.config.ConfigMapping;
7+
import io.smallrye.config.WithDefault;
78

89
@ConfigRoot(phase = BUILD_TIME)
910
@ConfigMapping(prefix = "quarkus.langchain4j.gpu-llama3")
1011
public interface LangChain4jGPULlama3BuildTimeConfig {
1112

13+
/**
14+
* Determines whether the necessary GPULlama3 models are downloaded and included in the jar at build time.
15+
* Currently, this option is only valid for {@code fast-jar} deployments.
16+
*/
17+
@WithDefault("true")
18+
boolean includeModelsInArtifact();
19+
1220
/**
1321
* Chat model related settings
1422
*/
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
package io.quarkiverse.langchain4j.gpullama3;
2+
3+
public final class Consts {
4+
5+
private Consts() {
6+
}
7+
8+
/**
9+
* working links:
10+
* https://huggingface.co/beehive-lab/Llama-3.2-1B-Instruct-GGUF/blob/main/Llama-3.2-1B-Instruct-FP16.gguf
11+
* https://huggingface.co/ggml-org/Qwen3-0.6B-GGUF/resolve/main/Qwen3-0.6B-f16.gguf
12+
*/
13+
14+
public static final String DEFAULT_CHAT_MODEL_NAME = "beehive-lab/Llama-3.2-1B-Instruct-GGUF";
15+
public static final String DEFAULT_CHAT_MODEL_QUANTIZATION = "FP16";
16+
17+
}

model-providers/gpu-llama3/runtime/src/main/java/io/quarkiverse/langchain4j/gpullama3/GPULlama3ChatModel.java

Lines changed: 34 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
11
package io.quarkiverse.langchain4j.gpullama3;
22

33
import static dev.langchain4j.internal.Utils.getOrDefault;
4-
import static java.util.Objects.requireNonNull;
54

5+
import java.io.IOException;
6+
import java.io.UncheckedIOException;
67
import java.nio.file.Path;
8+
import java.util.Optional;
79

810
import dev.langchain4j.data.message.AiMessage;
911
import dev.langchain4j.internal.ChatRequestValidationUtils;
@@ -16,13 +18,22 @@ public class GPULlama3ChatModel extends GPULlama3BaseModel implements ChatModel
1618

1719
// @formatter:off
1820
private GPULlama3ChatModel(Builder builder) {
19-
init(
20-
requireNonNull(builder.modelPath, "modelPath is required and must be specified"),
21-
getOrDefault(builder.temperature, 0.1),
22-
getOrDefault(builder.topP, 1.0),
23-
getOrDefault(builder.seed, 12345),
24-
getOrDefault(builder.maxTokens, 512),
25-
getOrDefault(builder.onGPU, Boolean.TRUE));
21+
GPULlama3ModelRegistry gpuLlama3ModelRegistry = GPULlama3ModelRegistry.getOrCreate(builder.modelCachePath);
22+
try {
23+
Path modelPath = gpuLlama3ModelRegistry.downloadModel(builder.modelName, builder.quantization,
24+
Optional.empty(), Optional.empty());
25+
init(
26+
modelPath,
27+
getOrDefault(builder.temperature, 0.1),
28+
getOrDefault(builder.topP, 1.0),
29+
getOrDefault(builder.seed, 12345),
30+
getOrDefault(builder.maxTokens, 512),
31+
getOrDefault(builder.onGPU, Boolean.TRUE));
32+
} catch (IOException e) {
33+
throw new UncheckedIOException(e);
34+
} catch (InterruptedException e) {
35+
throw new RuntimeException(e);
36+
}
2637
}
2738
// @formatter:on
2839

@@ -58,7 +69,9 @@ public ChatResponse doChat(ChatRequest chatRequest) {
5869

5970
public static class Builder {
6071

61-
protected Path modelPath;
72+
private Optional<Path> modelCachePath;
73+
private String modelName = Consts.DEFAULT_CHAT_MODEL_NAME;
74+
private String quantization = Consts.DEFAULT_CHAT_MODEL_QUANTIZATION;
6275
protected Double temperature;
6376
protected Double topP;
6477
protected Integer seed;
@@ -69,8 +82,18 @@ public Builder() {
6982
// This is public so it can be extended
7083
}
7184

72-
public Builder modelPath(Path modelPath) {
73-
this.modelPath = modelPath;
85+
public Builder modelCachePath(Optional<Path> modelCachePath) {
86+
this.modelCachePath = modelCachePath;
87+
return this;
88+
}
89+
90+
public Builder modelName(String modelName) {
91+
this.modelName = modelName;
92+
return this;
93+
}
94+
95+
public Builder quantization(String quantization) {
96+
this.quantization = quantization;
7497
return this;
7598
}
7699

0 commit comments

Comments
 (0)