Skip to content

Commit 0d4a824

Browse files
authored
Merge pull request #1793 from mikepapadim/gpu-llama3-integration
Add GPULlama3.java as model provider to run on GPUs
2 parents c778731 + e712762 commit 0d4a824

File tree

26 files changed

+1710
-4
lines changed

26 files changed

+1710
-4
lines changed

.github/workflows/build-pull-request.yml

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,14 +66,15 @@ jobs:
6666
| jq -R -s -c 'split("\n")[:-1]')
6767
6868
# Integration tests (without the in-process embedding models)
69-
# Remove JLama and Llama3 from the list
69+
# Remove JLama, Llama3 and GPU Llama3 from the list
7070
cd integration-tests
7171
IT_MODULES=$( \
7272
find . -mindepth 2 -maxdepth 2 -type f -name 'pom.xml' -exec dirname {} \; \
7373
| sed 's|^\./||' \
7474
| sort -u \
7575
| grep -v jlama \
7676
| grep -v llama3-java \
77+
| grep -v gpu-llama3 \
7778
| grep -v in-process-embedding-models \
7879
| jq -R -s -c 'split("\n")[:-1]')
7980
@@ -143,6 +144,13 @@ jobs:
143144
run: |
144145
./mvnw -B clean install -DskipTests -Dno-format -ntp -f model-providers/jlama/pom.xml
145146
147+
# Build GPU Llama3 if JDK >= 21
148+
# It's not built by default as it requires Java 21+
149+
- name: Build GPU Llama3 extension
150+
if: ${{ matrix.java >= 21 }}
151+
run: |
152+
./mvnw -B clean install -DskipTests -Dno-format -ntp -f model-providers/gpu-llama3/pom.xml
153+
146154
# Build Llama3.java if JDK >= 22. See https://x.com/tjake/status/1849141171475399083?t=EpgVJCPLC17fCXio0FvnhA&s=19 for the reason
147155
- name: Build Llama3-java extension
148156
if: ${{ matrix.java >= 22 }}
Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
11
:project-version: 1.3.1
2-
:langchain4j-version: 1.6.0
3-
:langchain4j-embeddings-version: 1.6.0-beta12
2+
:langchain4j-version: 1.8.0
3+
:langchain4j-embeddings-version: 1.8.0-beta15
44
:examples-dir: ./../examples/
Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
### How to run the integration tests:
2+
3+
#### 1) Install TornadoVM:
4+
5+
```bash
6+
cd ~
7+
git clone git@github.com:beehive-lab/TornadoVM.git
8+
cd ~/TornadoVM
9+
./bin/tornadovm-installer --jdk jdk21 --backend opencl
10+
source setvars.sh
11+
```
12+
13+
Note that the above steps:
14+
- Set the `TORNADO_SDK` environment variable to the path of the TornadoVM SDK.
15+
- Create the `tornado-argfile` under `~/TornadoVM` which contains all the required JVM arguments to enable TornadoVM.
16+
- The argfile is automatically used in Quarkus dev mode; however, in production mode, you need to manually pass the argfile to the JVM (see step 3).
17+
18+
#### 2) Build Quarkus-langchain4j:
19+
20+
```bash
21+
cd ~
22+
git clone git@github.com:mikepapadim/quarkus-langchain4j.git
23+
cd ~/quarkus-langchain4j
24+
git checkout gpu-llama3-integration
25+
mvn clean install -DskipTests
26+
```
27+
28+
#### 3) Run the integration tests:
29+
30+
##### 3.1 Deploy the Quarkus app:
31+
32+
```bash
33+
cd ~/quarkus-langchain4j/integration-tests/gpu-llama3
34+
```
35+
- For *dev* mode, run:
36+
```
37+
mvn quarkus:dev
38+
```
39+
40+
- For *production* mode, run:
41+
```bash
42+
java @~/TornadoVM/tornado-argfile -jar target/quarkus-app/quarkus-run.jar
43+
```
44+
##### 3.2 Send requests to the Quarkus app:
45+
46+
When Quarkus is running, open a new terminal and run:
47+
48+
```bash
49+
curl http://localhost:8080/chat/blocking
50+
```
51+
Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,104 @@
1+
<?xml version="1.0" encoding="UTF-8"?>
2+
<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
3+
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 https://maven.apache.org/xsd/maven-4.0.0.xsd">
4+
<modelVersion>4.0.0</modelVersion>
5+
<parent>
6+
<groupId>io.quarkiverse.langchain4j</groupId>
7+
<artifactId>quarkus-langchain4j-integration-tests-parent</artifactId>
8+
<version>999-SNAPSHOT</version>
9+
</parent>
10+
<artifactId>quarkus-langchain4j-integration-test-gpu-llama3</artifactId>
11+
<name>Quarkus LangChain4j - Integration Tests - GPULlama3</name>
12+
<properties>
13+
<skipITs>true</skipITs>
14+
<maven.compiler.release>21</maven.compiler.release>
15+
<quarkus.version>3.18.0</quarkus.version>
16+
<!-- TornadoVM argfile path -->
17+
<tornado.argfile>${env.TORNADO_SDK}/../../../tornado-argfile</tornado.argfile>
18+
</properties>
19+
<dependencies>
20+
<dependency>
21+
<groupId>io.quarkus</groupId>
22+
<artifactId>quarkus-rest-jackson</artifactId>
23+
</dependency>
24+
<dependency>
25+
<groupId>io.quarkiverse.langchain4j</groupId>
26+
<artifactId>quarkus-langchain4j-gpu-llama3</artifactId>
27+
<version>999-SNAPSHOT</version>
28+
</dependency>
29+
<dependency>
30+
<groupId>io.quarkus</groupId>
31+
<artifactId>quarkus-junit5</artifactId>
32+
<scope>test</scope>
33+
</dependency>
34+
<dependency>
35+
<groupId>io.rest-assured</groupId>
36+
<artifactId>rest-assured</artifactId>
37+
<scope>test</scope>
38+
</dependency>
39+
<dependency>
40+
<groupId>org.assertj</groupId>
41+
<artifactId>assertj-core</artifactId>
42+
<scope>test</scope>
43+
</dependency>
44+
<dependency>
45+
<groupId>io.quarkus</groupId>
46+
<artifactId>quarkus-devtools-testing</artifactId>
47+
<scope>test</scope>
48+
</dependency>
49+
50+
<!-- Make sure the deployment artifact is built before executing this module -->
51+
<dependency>
52+
<groupId>io.quarkiverse.langchain4j</groupId>
53+
<artifactId>quarkus-langchain4j-gpu-llama3-deployment</artifactId>
54+
<version>999-SNAPSHOT</version>
55+
<type>pom</type>
56+
<scope>test</scope>
57+
<exclusions>
58+
<exclusion>
59+
<groupId>*</groupId>
60+
<artifactId>*</artifactId>
61+
</exclusion>
62+
</exclusions>
63+
</dependency>
64+
</dependencies>
65+
<build>
66+
<plugins>
67+
<plugin>
68+
<groupId>io.quarkus</groupId>
69+
<artifactId>quarkus-maven-plugin</artifactId>
70+
<executions>
71+
<execution>
72+
<goals>
73+
<goal>build</goal>
74+
</goals>
75+
</execution>
76+
</executions>
77+
<configuration>
78+
<!-- Pass tornado-argfile to dev mode -->
79+
<jvmArgs>@${tornado.argfile}</jvmArgs>
80+
</configuration>
81+
</plugin>
82+
83+
<plugin>
84+
<artifactId>maven-failsafe-plugin</artifactId>
85+
<executions>
86+
<execution>
87+
<goals>
88+
<goal>integration-test</goal>
89+
<goal>verify</goal>
90+
</goals>
91+
<configuration>
92+
<argLine>@${tornado.argfile}</argLine>
93+
<systemPropertyVariables>
94+
<native.image.path>${project.build.directory}/${project.build.finalName}-runner</native.image.path>
95+
<java.util.logging.manager>org.jboss.logmanager.LogManager</java.util.logging.manager>
96+
<maven.home>${maven.home}</maven.home>
97+
</systemPropertyVariables>
98+
</configuration>
99+
</execution>
100+
</executions>
101+
</plugin>
102+
</plugins>
103+
</build>
104+
</project>
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
package org.acme.example.gpullama3.chat;
2+
3+
import jakarta.ws.rs.GET;
4+
import jakarta.ws.rs.Path;
5+
6+
import dev.langchain4j.model.chat.ChatModel;
7+
8+
/**
 * REST resource rooted at the {@code chat} path that forwards a fixed prompt
 * to the configured {@link ChatModel} and returns the model's reply.
 */
@Path("chat")
9+
public class ChatLanguageModelResource {
10+
11+
private final ChatModel chatModel;
12+
13+
/**
 * Stores the chat model used by the endpoints of this resource.
 * NOTE(review): presumably supplied through constructor injection by the
 * container — confirm against the extension's bean setup.
 *
 * @param chatModel the model that answers the prompts
 */
public ChatLanguageModelResource(ChatModel chatModel) {
14+
this.chatModel = chatModel;
15+
}
16+
17+
/**
 * Handles {@code GET chat/blocking}: sends one hard-coded question to the
 * chat model and returns whatever string {@code chatModel.chat(...)} yields.
 *
 * @return the model's answer as returned by the chat call
 */
@GET
18+
@Path("blocking")
19+
public String blocking() {
20+
return chatModel.chat("When was the nobel prize for economics first awarded?");
21+
}
22+
}
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
quarkus.langchain4j.gpu-llama3.include-models-in-artifact=false
2+
3+
# Configure GPULlama3
4+
quarkus.langchain4j.gpu-llama3.enable-integration=true
5+
quarkus.langchain4j.gpu-llama3.chat-model.model-name=beehive-lab/Llama-3.2-1B-Instruct-GGUF
6+
quarkus.langchain4j.gpu-llama3.chat-model.quantization=FP16
7+
quarkus.langchain4j.gpu-llama3.chat-model.temperature=0.7
8+
quarkus.langchain4j.gpu-llama3.chat-model.max-tokens=513
9+
10+
# other supported models:
11+
#model-name=ggml-org/Qwen3-0.6B-GGUF
12+
#quantization=f16

integration-tests/pom.xml

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,18 @@
5656
<module>llama3-java</module>
5757
</modules>
5858
</profile>
59-
<profile>
59+
<profile>
60+
<id>TornadoVM</id>
61+
<activation>
62+
<property>
63+
<name>tornado</name>
64+
</property>
65+
</activation>
66+
<modules>
67+
<module>gpu-llama3</module>
68+
</modules>
69+
</profile>
70+
<profile>
6071
<id>default-project-deps</id>
6172
<activation>
6273
<property>
@@ -107,6 +118,11 @@
107118
<artifactId>quarkus-langchain4j-easy-rag</artifactId>
108119
<version>999-SNAPSHOT</version>
109120
</dependency>
121+
<dependency>
122+
<groupId>io.quarkiverse.langchain4j</groupId>
123+
<artifactId>quarkus-langchain4j-gpullama3</artifactId>
124+
<version>${quarkus-langchain4j.version}</version>
125+
</dependency>
110126
<dependency>
111127
<groupId>io.quarkiverse.langchain4j</groupId>
112128
<artifactId>quarkus-langchain4j-hugging-face</artifactId>
@@ -122,6 +138,11 @@
122138
<artifactId>quarkus-langchain4j-llama3-java</artifactId>
123139
<version>999-SNAPSHOT</version>
124140
</dependency>
141+
<dependency>
142+
<groupId>io.quarkiverse.langchain4j</groupId>
143+
<artifactId>quarkus-langchain4j-gpu-llama3</artifactId>
144+
<version>999-SNAPSHOT</version>
145+
</dependency>
125146
<dependency>
126147
<groupId>io.quarkiverse.langchain4j</groupId>
127148
<artifactId>quarkus-langchain4j-mcp</artifactId>
Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
<?xml version="1.0" encoding="UTF-8"?>
2+
<project xmlns="http://maven.apache.org/POM/4.0.0"
3+
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
4+
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
5+
<modelVersion>4.0.0</modelVersion>
6+
7+
<parent>
8+
<groupId>io.quarkiverse.langchain4j</groupId>
9+
<artifactId>quarkus-langchain4j-gpu-llama3-parent</artifactId>
10+
<version>999-SNAPSHOT</version>
11+
</parent>
12+
13+
<artifactId>quarkus-langchain4j-gpu-llama3-deployment</artifactId>
14+
<name>Quarkus LangChain4j - GPULlama3 - Deployment</name>
15+
16+
<dependencies>
17+
<dependency>
18+
<groupId>io.quarkiverse.langchain4j</groupId>
19+
<artifactId>quarkus-langchain4j-gpu-llama3</artifactId>
20+
<version>${project.version}</version>
21+
</dependency>
22+
23+
<!-- ArC (Quarkus CDI) deployment module; declared with 'provided' scope -->
24+
<dependency>
25+
<groupId>io.quarkus</groupId>
26+
<artifactId>quarkus-arc-deployment</artifactId>
27+
<scope>provided</scope>
28+
</dependency>
29+
<dependency>
30+
<groupId>io.quarkiverse.langchain4j</groupId>
31+
<artifactId>quarkus-langchain4j-core-deployment</artifactId>
32+
<version>${project.version}</version>
33+
</dependency>
34+
35+
<dependency>
36+
<groupId>io.quarkus</groupId>
37+
<artifactId>quarkus-junit5-internal</artifactId>
38+
<scope>test</scope>
39+
</dependency>
40+
<dependency>
41+
<groupId>org.assertj</groupId>
42+
<artifactId>assertj-core</artifactId>
43+
<scope>test</scope>
44+
</dependency>
45+
</dependencies>
46+
47+
<build>
48+
<plugins>
49+
<plugin>
50+
<artifactId>maven-compiler-plugin</artifactId>
51+
<configuration>
52+
<annotationProcessorPaths>
53+
<path>
54+
<groupId>io.quarkus</groupId>
55+
<artifactId>quarkus-extension-processor</artifactId>
56+
</path>
57+
</annotationProcessorPaths>
58+
</configuration>
59+
</plugin>
60+
</plugins>
61+
</build>
62+
</project>
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
package io.quarkiverse.langchain4j.gpullama3.deployment;
2+
3+
import java.util.Optional;
4+
5+
import io.quarkus.runtime.annotations.ConfigDocDefault;
6+
import io.quarkus.runtime.annotations.ConfigGroup;
7+
8+
/**
 * Configuration group for the GPULlama3 chat model, declared in the
 * extension's deployment package.
 * NOTE(review): the name suggests these are build-time settings — confirm
 * against the config root that references this group.
 */
@ConfigGroup
9+
public interface ChatModelBuildConfig {
10+
11+
/**
12+
* Whether the model should be enabled
13+
*/
14+
// The documented default is "true"; the value itself is an Optional, so it can be empty when unset.
@ConfigDocDefault("true")
15+
Optional<Boolean> enabled();
16+
}

0 commit comments

Comments
 (0)