Skip to content

Commit d160109

Browse files
ibriqgeoand
authored andcommitted
Add GPULlama3.java as model provider to run on GPUs
1 parent 66bd02a commit d160109

File tree

21 files changed

+1206
-0
lines changed

21 files changed

+1206
-0
lines changed
Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,104 @@
1+
<?xml version="1.0" encoding="UTF-8"?>
2+
<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
3+
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 https://maven.apache.org/xsd/maven-4.0.0.xsd">
4+
<modelVersion>4.0.0</modelVersion>
5+
<parent>
6+
<groupId>io.quarkiverse.langchain4j</groupId>
7+
<artifactId>quarkus-langchain4j-integration-tests-parent</artifactId>
8+
<version>999-SNAPSHOT</version>
9+
</parent>
10+
<artifactId>quarkus-langchain4j-integration-test-gpu-llama3</artifactId>
11+
<name>Quarkus LangChain4j - Integration Tests - GPULlama3</name>
12+
<properties>
13+
<skipITs>true</skipITs>
14+
<maven.compiler.release>21</maven.compiler.release>
15+
<quarkus.version>3.18.0</quarkus.version>
16+
<!-- TornadoVM argfile path -->
17+
<tornado.argfile>${env.TORNADO_SDK}/../../../tornado-argfile</tornado.argfile>
18+
</properties>
19+
<dependencies>
20+
<dependency>
21+
<groupId>io.quarkus</groupId>
22+
<artifactId>quarkus-rest-jackson</artifactId>
23+
</dependency>
24+
<dependency>
25+
<groupId>io.quarkiverse.langchain4j</groupId>
26+
<artifactId>quarkus-langchain4j-gpu-llama3</artifactId>
27+
<version>999-SNAPSHOT</version>
28+
</dependency>
29+
<dependency>
30+
<groupId>io.quarkus</groupId>
31+
<artifactId>quarkus-junit5</artifactId>
32+
<scope>test</scope>
33+
</dependency>
34+
<dependency>
35+
<groupId>io.rest-assured</groupId>
36+
<artifactId>rest-assured</artifactId>
37+
<scope>test</scope>
38+
</dependency>
39+
<dependency>
40+
<groupId>org.assertj</groupId>
41+
<artifactId>assertj-core</artifactId>
42+
<scope>test</scope>
43+
</dependency>
44+
<dependency>
45+
<groupId>io.quarkus</groupId>
46+
<artifactId>quarkus-devtools-testing</artifactId>
47+
<scope>test</scope>
48+
</dependency>
49+
50+
<!-- Make sure the deployment artifact is built before executing this module -->
51+
<dependency>
52+
<groupId>io.quarkiverse.langchain4j</groupId>
53+
<artifactId>quarkus-langchain4j-gpu-llama3-deployment</artifactId>
54+
<version>999-SNAPSHOT</version>
55+
<type>pom</type>
56+
<scope>test</scope>
57+
<exclusions>
58+
<exclusion>
59+
<groupId>*</groupId>
60+
<artifactId>*</artifactId>
61+
</exclusion>
62+
</exclusions>
63+
</dependency>
64+
</dependencies>
65+
<build>
66+
<plugins>
67+
<plugin>
68+
<groupId>io.quarkus</groupId>
69+
<artifactId>quarkus-maven-plugin</artifactId>
70+
<executions>
71+
<execution>
72+
<goals>
73+
<goal>build</goal>
74+
</goals>
75+
</execution>
76+
</executions>
77+
<configuration>
78+
<!-- Pass tornado-argfile to dev mode -->
79+
<jvmArgs>@${tornado.argfile}</jvmArgs>
80+
</configuration>
81+
</plugin>
82+
83+
<plugin>
84+
<artifactId>maven-failsafe-plugin</artifactId>
85+
<executions>
86+
<execution>
87+
<goals>
88+
<goal>integration-test</goal>
89+
<goal>verify</goal>
90+
</goals>
91+
<configuration>
92+
<argLine>@${tornado.argfile}</argLine>
93+
<systemPropertyVariables>
94+
<native.image.path>${project.build.directory}/${project.build.finalName}-runner</native.image.path>
95+
<java.util.logging.manager>org.jboss.logmanager.LogManager</java.util.logging.manager>
96+
<maven.home>${maven.home}</maven.home>
97+
</systemPropertyVariables>
98+
</configuration>
99+
</execution>
100+
</executions>
101+
</plugin>
102+
</plugins>
103+
</build>
104+
</project>
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
package org.acme.example.gpullama3.chat;
2+
3+
import jakarta.ws.rs.GET;
4+
import jakarta.ws.rs.Path;
5+
6+
import dev.langchain4j.model.chat.ChatModel;
7+
8+
@Path("chat")
9+
public class ChatLanguageModelResource {
10+
11+
private final ChatModel chatModel;
12+
13+
public ChatLanguageModelResource(ChatModel chatModel) {
14+
this.chatModel = chatModel;
15+
}
16+
17+
@GET
18+
@Path("blocking")
19+
public String blocking() {
20+
return chatModel.chat("When was the nobel prize for economics first awarded?");
21+
}
22+
}
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
# Configure GPULlama3
2+
quarkus.langchain4j.gpu-llama3.chat-model.model-path=/Users/orion/LLMModels/beehive-llama-3.2-1b-instruct-fp16.gguf
3+
quarkus.langchain4j.gpu-llama3.enable-integration=true
4+
quarkus.langchain4j.gpu-llama3.chat-model.temperature=0.7
5+
quarkus.langchain4j.gpu-llama3.chat-model.max-tokens=100

integration-tests/pom.xml

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@
4545
</activation>
4646
<modules>
4747
<module>jlama</module>
48+
<module>gpu-llama3</module>
4849
</modules>
4950
</profile>
5051
<profile>
@@ -107,6 +108,11 @@
107108
<artifactId>quarkus-langchain4j-easy-rag</artifactId>
108109
<version>999-SNAPSHOT</version>
109110
</dependency>
111+
<dependency>
112+
<groupId>io.quarkiverse.langchain4j</groupId>
113+
<artifactId>quarkus-langchain4j-gpullama3</artifactId>
114+
<version>${quarkus-langchain4j.version}</version>
115+
</dependency>
110116
<dependency>
111117
<groupId>io.quarkiverse.langchain4j</groupId>
112118
<artifactId>quarkus-langchain4j-hugging-face</artifactId>
@@ -122,6 +128,11 @@
122128
<artifactId>quarkus-langchain4j-llama3-java</artifactId>
123129
<version>999-SNAPSHOT</version>
124130
</dependency>
131+
<dependency>
132+
<groupId>io.quarkiverse.langchain4j</groupId>
133+
<artifactId>quarkus-langchain4j-gpu-llama3</artifactId>
134+
<version>999-SNAPSHOT</version>
135+
</dependency>
125136
<dependency>
126137
<groupId>io.quarkiverse.langchain4j</groupId>
127138
<artifactId>quarkus-langchain4j-mcp</artifactId>
Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
### How to run the integrated tests:
2+
3+
#### 1) Install TornadoVM:
4+
5+
```bash
6+
cd ~
7+
git clone [email protected]:beehive-lab/TornadoVM.git
8+
cd ~/TornadoVM
9+
./bin/tornadovm-installer --jdk jdk21 --backend opencl
10+
source setvars.sh
11+
```
12+
13+
Note that the above steps:
14+
- Set `TORNADOVM_SDK` environment variable to the path of the TornadoVM SDK.
15+
- Create the `tornado-argfile` under `~/TornadoVM` which contains all the required JVM arguments to enable TornadoVM.
16+
- The argfile is automatically used in Quarkus dev mode; however, in production mode, you need to manually pass the argfile to the JVM (see step 3).
17+
18+
#### 2) Build Quarkus-langchain4j:
19+
20+
```bash
21+
cd ~
22+
git clone [email protected]:mikepapadim/quarkus-langchain4j.git
23+
cd ~/quarkus-langchain4j
24+
git checkout gpu-llama3-integration
25+
mvn clean install -DskipTests
26+
```
27+
28+
#### 3) Run the integrated tests:
29+
30+
##### 3.1 Deploy the Quarkus app:
31+
32+
```bash
33+
cd ~/quarkus-langchain4j/integration-tests/gpullama3
34+
```
35+
- For *dev* mode, run:
36+
```
37+
mvn quarkus:dev
38+
```
39+
40+
- For *production* mode, run:
41+
```bash
42+
java @~/TornadoVM/tornado-argfile -jar target/quarkus-app/quarkus-run.jar
43+
```
44+
##### 3.2 Send requests to the Quarkus app:
45+
46+
when quarkus is running, open a new terminal and run:
47+
48+
```bash
49+
curl http://localhost:8080/chat/blocking
50+
```
51+
Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
<?xml version="1.0" encoding="UTF-8"?>
2+
<project xmlns="http://maven.apache.org/POM/4.0.0"
3+
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
4+
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
5+
<modelVersion>4.0.0</modelVersion>
6+
7+
<parent>
8+
<groupId>io.quarkiverse.langchain4j</groupId>
9+
<artifactId>quarkus-langchain4j-gpu-llama3-parent</artifactId>
10+
<version>999-SNAPSHOT</version>
11+
</parent>
12+
13+
<artifactId>quarkus-langchain4j-gpu-llama3-deployment</artifactId>
14+
<name>Quarkus LangChain4j - GPULlama3 - Deployment</name>
15+
16+
<dependencies>
17+
<dependency>
18+
<groupId>io.quarkiverse.langchain4j</groupId>
19+
<artifactId>quarkus-langchain4j-gpu-llama3</artifactId>
20+
<version>${project.version}</version>
21+
</dependency>
22+
23+
<!-- what is this -->
24+
<dependency>
25+
<groupId>io.quarkus</groupId>
26+
<artifactId>quarkus-arc-deployment</artifactId>
27+
<scope>provided</scope>
28+
</dependency>
29+
<dependency>
30+
<groupId>io.quarkiverse.langchain4j</groupId>
31+
<artifactId>quarkus-langchain4j-core-deployment</artifactId>
32+
<version>${project.version}</version>
33+
</dependency>
34+
35+
<dependency>
36+
<groupId>io.quarkus</groupId>
37+
<artifactId>quarkus-junit5-internal</artifactId>
38+
<scope>test</scope>
39+
</dependency>
40+
<dependency>
41+
<groupId>org.assertj</groupId>
42+
<artifactId>assertj-core</artifactId>
43+
<scope>test</scope>
44+
</dependency>
45+
</dependencies>
46+
47+
<build>
48+
<plugins>
49+
<plugin>
50+
<artifactId>maven-compiler-plugin</artifactId>
51+
<configuration>
52+
<annotationProcessorPaths>
53+
<path>
54+
<groupId>io.quarkus</groupId>
55+
<artifactId>quarkus-extension-processor</artifactId>
56+
</path>
57+
</annotationProcessorPaths>
58+
</configuration>
59+
</plugin>
60+
</plugins>
61+
</build>
62+
</project>
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
package io.quarkiverse.langchain4j.gpullama3.deployment;
2+
3+
import java.util.Optional;
4+
5+
import io.quarkus.runtime.annotations.ConfigDocDefault;
6+
import io.quarkus.runtime.annotations.ConfigGroup;
7+
8+
@ConfigGroup
9+
public interface ChatModelBuildConfig {
10+
11+
/**
12+
* Whether the model should be enabled
13+
*/
14+
@ConfigDocDefault("true")
15+
Optional<Boolean> enabled();
16+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
package io.quarkiverse.langchain4j.gpullama3.deployment;
2+
3+
import static io.quarkiverse.langchain4j.deployment.LangChain4jDotNames.CHAT_MODEL;
4+
5+
import java.util.List;
6+
7+
import jakarta.enterprise.context.ApplicationScoped;
8+
9+
import io.quarkiverse.langchain4j.deployment.items.ChatModelProviderCandidateBuildItem;
10+
import io.quarkiverse.langchain4j.deployment.items.SelectedChatModelProviderBuildItem;
11+
import io.quarkiverse.langchain4j.gpullama3.runtime.GPULlama3Recorder;
12+
import io.quarkus.arc.deployment.SyntheticBeanBuildItem;
13+
import io.quarkus.deployment.annotations.BuildProducer;
14+
import io.quarkus.deployment.annotations.BuildStep;
15+
import io.quarkus.deployment.annotations.ExecutionTime;
16+
import io.quarkus.deployment.annotations.Record;
17+
import io.quarkus.deployment.builditem.FeatureBuildItem;
18+
19+
public class GPULlama3Processor {
20+
21+
private static final String PROVIDER = "gpu-llama3";
22+
private static final String FEATURE = "langchain4j-gpu-llama3";
23+
24+
@BuildStep
25+
FeatureBuildItem feature() {
26+
return new FeatureBuildItem(FEATURE);
27+
}
28+
29+
@BuildStep
30+
public void providerCandidates(BuildProducer<ChatModelProviderCandidateBuildItem> chatProducer,
31+
LangChain4jGPULlama3BuildTimeConfig config) {
32+
if (config.chatModel().enabled().isEmpty() || config.chatModel().enabled().get()) {
33+
chatProducer.produce(new ChatModelProviderCandidateBuildItem(PROVIDER));
34+
}
35+
}
36+
37+
@BuildStep
38+
@Record(ExecutionTime.RUNTIME_INIT)
39+
void generateBeans(GPULlama3Recorder recorder,
40+
List<SelectedChatModelProviderBuildItem> selectedChatModels,
41+
BuildProducer<SyntheticBeanBuildItem> beanProducer) {
42+
43+
for (var selected : selectedChatModels) {
44+
if (PROVIDER.equals(selected.getProvider())) {
45+
String configName = selected.getConfigName();
46+
47+
var builder = SyntheticBeanBuildItem
48+
.configure(CHAT_MODEL)
49+
.setRuntimeInit()
50+
.defaultBean()
51+
.scope(ApplicationScoped.class)
52+
.supplier(recorder.chatModel(configName));
53+
54+
beanProducer.produce(builder.done());
55+
}
56+
}
57+
}
58+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
package io.quarkiverse.langchain4j.gpullama3.deployment;
2+
3+
import static io.quarkus.runtime.annotations.ConfigPhase.BUILD_TIME;
4+
5+
import io.quarkus.runtime.annotations.ConfigRoot;
6+
import io.smallrye.config.ConfigMapping;
7+
8+
@ConfigRoot(phase = BUILD_TIME)
9+
@ConfigMapping(prefix = "quarkus.langchain4j.gpu-llama3")
10+
public interface LangChain4jGPULlama3BuildTimeConfig {
11+
12+
/**
13+
* Chat model related settings
14+
*/
15+
ChatModelBuildConfig chatModel();
16+
}

0 commit comments

Comments
 (0)