Skip to content

Commit 79463fd

Browse files
authored
Merge pull request #1728 from edeandrea/incorporate-upstream-guardrails
Refactor upstream guardrails
2 parents 32b114b + cbb10d7 commit 79463fd

File tree

130 files changed

+8381
-568
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

130 files changed

+8381
-568
lines changed

core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/AiServicesProcessor.java

Lines changed: 214 additions & 23 deletions
Large diffs are not rendered by default.

core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/DeclarativeAiServiceBuildItem.java

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,8 @@ public final class DeclarativeAiServiceBuildItem extends MultiBuildItem {
3131
private final String moderationModelName;
3232
private final String imageModelName;
3333
private final Optional<String> beanName;
34+
private final DeclarativeAiServiceInputGuardrails inputGuardrails;
35+
private final DeclarativeAiServiceOutputGuardrails outputGuardrails;
3436
private final Integer maxSequentialToolInvocations;
3537

3638
public DeclarativeAiServiceBuildItem(
@@ -51,6 +53,8 @@ public DeclarativeAiServiceBuildItem(
5153
DotName toolProviderClassDotName,
5254
Optional<String> beanName,
5355
DotName toolHallucinationStrategyClassDotName,
56+
DeclarativeAiServiceInputGuardrails inputGuardrails,
57+
DeclarativeAiServiceOutputGuardrails outputGuardrails,
5458
Integer maxSequentialToolInvocations) {
5559
this.serviceClassInfo = serviceClassInfo;
5660
this.chatLanguageModelSupplierClassDotName = chatLanguageModelSupplierClassDotName;
@@ -69,6 +73,8 @@ public DeclarativeAiServiceBuildItem(
6973
this.toolProviderClassDotName = toolProviderClassDotName;
7074
this.beanName = beanName;
7175
this.toolHallucinationStrategyClassDotName = toolHallucinationStrategyClassDotName;
76+
this.inputGuardrails = inputGuardrails;
77+
this.outputGuardrails = outputGuardrails;
7278
this.maxSequentialToolInvocations = maxSequentialToolInvocations;
7379
}
7480

@@ -140,6 +146,35 @@ public DotName getToolHallucinationStrategyClassDotName() {
140146
return toolHallucinationStrategyClassDotName;
141147
}
142148

149+
public DeclarativeAiServiceInputGuardrails getInputGuardrails() {
150+
return inputGuardrails;
151+
}
152+
153+
public DeclarativeAiServiceOutputGuardrails getOutputGuardrails() {
154+
return outputGuardrails;
155+
}
156+
157+
public record DeclarativeAiServiceInputGuardrails(List<ClassInfo> inputGuardrailClassInfos) {
158+
public List<String> asClassNames() {
159+
return this.inputGuardrailClassInfos.stream()
160+
.map(classInfo -> classInfo.name().toString())
161+
.toList();
162+
}
163+
}
164+
165+
public record DeclarativeAiServiceOutputGuardrails(List<ClassInfo> outputGuardrailClassInfos, int maxRetries,
166+
int actualMaxRetries) {
167+
public DeclarativeAiServiceOutputGuardrails(List<ClassInfo> outputGuardrailClassInfos, int maxRetries) {
168+
this(outputGuardrailClassInfos, maxRetries, maxRetries);
169+
}
170+
171+
public List<String> asClassNames() {
172+
return this.outputGuardrailClassInfos.stream()
173+
.map(classInfo -> classInfo.name().toString())
174+
.toList();
175+
}
176+
}
177+
143178
public Integer getMaxSequentialToolInvocations() {
144179
return maxSequentialToolInvocations;
145180
}

core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/GuardrailObservabilityProcessorSupport.java

Lines changed: 95 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -10,21 +10,64 @@
1010
import org.jboss.logging.Logger;
1111

1212
import dev.langchain4j.data.message.UserMessage;
13-
import io.quarkiverse.langchain4j.guardrails.InputGuardrail;
14-
import io.quarkiverse.langchain4j.guardrails.InputGuardrailParams;
15-
import io.quarkiverse.langchain4j.guardrails.InputGuardrailResult;
16-
import io.quarkiverse.langchain4j.guardrails.OutputGuardrail;
17-
import io.quarkiverse.langchain4j.guardrails.OutputGuardrailParams;
18-
import io.quarkiverse.langchain4j.guardrails.OutputGuardrailResult;
13+
import dev.langchain4j.guardrail.InputGuardrail;
14+
import dev.langchain4j.guardrail.InputGuardrailRequest;
15+
import dev.langchain4j.guardrail.InputGuardrailResult;
16+
import dev.langchain4j.guardrail.OutputGuardrail;
17+
import dev.langchain4j.guardrail.OutputGuardrailRequest;
18+
import dev.langchain4j.guardrail.OutputGuardrailResult;
1919

2020
final class GuardrailObservabilityProcessorSupport {
2121
private static final Logger LOG = Logger.getLogger(GuardrailObservabilityProcessorSupport.class);
22-
private static final DotName INPUT_GUARDRAIL_PARAMS = DotName.createSimple(InputGuardrailParams.class);
22+
23+
/**
24+
* @deprecated These tests will go away once the Quarkus-specific guardrail implementation has been fully removed
25+
*/
26+
@Deprecated(forRemoval = true)
27+
private static final DotName QUARKUS_INPUT_GUARDRAIL_PARAMS = DotName
28+
.createSimple(io.quarkiverse.langchain4j.guardrails.InputGuardrailParams.class);
29+
30+
/**
31+
* @deprecated These tests will go away once the Quarkus-specific guardrail implementation has been fully removed
32+
*/
33+
@Deprecated(forRemoval = true)
34+
private static final DotName QUARKUS_INPUT_GUARDRAIL_RESULT = DotName
35+
.createSimple(io.quarkiverse.langchain4j.guardrails.InputGuardrailResult.class);
36+
37+
/**
38+
* @deprecated These tests will go away once the Quarkus-specific guardrail implementation has been fully removed
39+
*/
40+
@Deprecated(forRemoval = true)
41+
private static final DotName QUARKUS_OUTPUT_GUARDRAIL_PARAMS = DotName
42+
.createSimple(io.quarkiverse.langchain4j.guardrails.OutputGuardrailParams.class);
43+
44+
/**
45+
* @deprecated These tests will go away once the Quarkus-specific guardrail implementation has been fully removed
46+
*/
47+
@Deprecated(forRemoval = true)
48+
private static final DotName QUARKUS_OUTPUT_GUARDRAIL_RESULT = DotName
49+
.createSimple(io.quarkiverse.langchain4j.guardrails.OutputGuardrailResult.class);
50+
51+
/**
52+
* @deprecated These tests will go away once the Quarkus-specific guardrail implementation has been fully removed
53+
*/
54+
@Deprecated(forRemoval = true)
55+
private static final DotName QUARKUS_INPUT_GUARDRAIL = DotName
56+
.createSimple(io.quarkiverse.langchain4j.guardrails.InputGuardrail.class);
57+
58+
/**
59+
* @deprecated These tests will go away once the Quarkus-specific guardrail implementation has been fully removed
60+
*/
61+
@Deprecated(forRemoval = true)
62+
private static final DotName QUARKUS_OUTPUT_GUARDRAIL = DotName
63+
.createSimple(io.quarkiverse.langchain4j.guardrails.OutputGuardrail.class);
64+
private static final DotName INPUT_GUARDRAIL_REQUEST = DotName.createSimple(InputGuardrailRequest.class);
2365
private static final DotName INPUT_GUARDRAIL_RESULT = DotName.createSimple(InputGuardrailResult.class);
24-
private static final DotName OUTPUT_GUARDRAIL_PARAMS = DotName.createSimple(OutputGuardrailParams.class);
66+
private static final DotName OUTPUT_GUARDRAIL_REQUEST = DotName.createSimple(OutputGuardrailRequest.class);
2567
private static final DotName OUTPUT_GUARDRAIL_RESULT = DotName.createSimple(OutputGuardrailResult.class);
2668
private static final DotName INPUT_GUARDRAIL = DotName.createSimple(InputGuardrail.class);
2769
private static final DotName OUTPUT_GUARDRAIL = DotName.createSimple(OutputGuardrail.class);
70+
2871
static final DotName MICROMETER_TIMED = DotName.createSimple("io.micrometer.core.annotation.Timed");
2972
static final DotName MICROMETER_COUNTED = DotName.createSimple("io.micrometer.core.annotation.Counted");
3073
static final DotName WITH_SPAN = DotName.createSimple("io.opentelemetry.instrumentation.annotations.WithSpan");
@@ -37,11 +80,17 @@ enum TransformType {
3780
}
3881

3982
enum GuardrailType {
83+
QUARKUS_INPUT,
84+
QUARKUS_OUTPUT,
4085
INPUT,
4186
OUTPUT;
4287

4388
static Optional<GuardrailType> from(IndexView indexView, ClassInfo classToCheck) {
44-
if (indexView.getAllKnownImplementors(INPUT_GUARDRAIL).contains(classToCheck)) {
89+
if (indexView.getAllKnownImplementors(QUARKUS_INPUT_GUARDRAIL).contains(classToCheck)) {
90+
return Optional.of(QUARKUS_INPUT);
91+
} else if (indexView.getAllKnownImplementors(QUARKUS_OUTPUT_GUARDRAIL).contains(classToCheck)) {
92+
return Optional.of(QUARKUS_OUTPUT);
93+
} else if (indexView.getAllKnownImplementors(INPUT_GUARDRAIL).contains(classToCheck)) {
4594
return Optional.of(INPUT);
4695
} else if (indexView.getAllKnownImplementors(OUTPUT_GUARDRAIL).contains(classToCheck)) {
4796
return Optional.of(OUTPUT);
@@ -102,16 +151,18 @@ private static boolean shouldTransformGuardrailValidateMethod(MethodInfo methodI
102151
}
103152

104153
var isOtherValidateMethodVariant = switch (guardrailType) {
105-
case INPUT -> isInputGuardrailValidateMethodWithUserMessage(methodInfo);
106-
case OUTPUT -> isOutputGuardrailValidateMethodWithAiMessage(methodInfo);
154+
case QUARKUS_INPUT, INPUT -> isInputGuardrailValidateMethodWithUserMessage(methodInfo);
155+
case QUARKUS_OUTPUT, OUTPUT -> isOutputGuardrailValidateMethodWithAiMessage(methodInfo);
107156
};
108157

109158
if (isOtherValidateMethodVariant && !doesMethodAlreadyHaveTransformationAnnotation(methodInfo, transformType)) {
110159
// If this is the other method variant, we need to ensure that the
111160
// variant with the params isn't also present on the method's declaring class
112161
var paramType = switch (guardrailType) {
113-
case INPUT -> Type.parse(INPUT_GUARDRAIL_PARAMS.toString());
114-
case OUTPUT -> Type.parse(OUTPUT_GUARDRAIL_PARAMS.toString());
162+
case QUARKUS_INPUT -> Type.parse(QUARKUS_INPUT_GUARDRAIL_PARAMS.toString());
163+
case QUARKUS_OUTPUT -> Type.parse(QUARKUS_OUTPUT_GUARDRAIL_PARAMS.toString());
164+
case INPUT -> Type.parse(INPUT_GUARDRAIL_REQUEST.toString());
165+
case OUTPUT -> Type.parse(OUTPUT_GUARDRAIL_REQUEST.toString());
115166
};
116167

117168
var otherValidateMethod = methodDeclaringClass.method("validate", paramType);
@@ -129,9 +180,16 @@ private static boolean shouldTransformGuardrailValidateMethod(MethodInfo methodI
129180
* Checks the method meets <strong>ALL</strong> the following conditions:
130181
* <ul>
131182
* <li>The method's name is {@link #VALIDATE_METHOD_NAME}</li>
132-
* <li><strong>IF</strong> the method's single parameter's type is {@link InputGuardrailParams} then the return type must be
183+
* <li><strong>IF</strong> the method's single parameter's type is
184+
* {@link io.quarkiverse.langchain4j.guardrails.InputGuardrailParams} then the return type must be
185+
* {@link io.quarkiverse.langchain4j.guardrails.InputGuardrailResult}</li>
186+
* <li><strong>IF</strong> the method's single parameter's type is
187+
* {@link io.quarkiverse.langchain4j.guardrails.OutputGuardrailParams} then the return type must
188+
* be {@link io.quarkiverse.langchain4j.guardrails.OutputGuardrailResult}</li>
189+
* <li><strong>IF</strong> the method's single parameter's type is {@link InputGuardrailRequest} then the return type must
190+
* be
133191
* {@link InputGuardrailResult}</li>
134-
* <li><strong>IF</strong> the method's single parameter's type is {@link OutputGuardrailParams} then the return type must
192+
* <li><strong>IF</strong> the method's single parameter's type is {@link OutputGuardrailRequest} then the return type must
135193
* be {@link OutputGuardrailResult}</li>
136194
* </ul>
137195
*/
@@ -143,7 +201,8 @@ static boolean isGuardrailValidateMethodWithParams(MethodInfo methodInfo) {
143201
* Checks the method meets <strong>ALL</strong> the following conditions:
144202
* <ul>
145203
* <li>The method's name is {@link #VALIDATE_METHOD_NAME}</li>
146-
* <li>The method's return type is {@link InputGuardrailResult}</li>
204+
* <li>The method's return type is {@link io.quarkiverse.langchain4j.guardrails.InputGuardrailResult} or
205+
* {@link InputGuardrailResult}</li>
147206
* <li>The method's single parameter's type is {@link dev.langchain4j.data.message.UserMessage}</li>
148207
* </ul>
149208
*/
@@ -156,7 +215,8 @@ private static boolean isInputGuardrailValidateMethodWithUserMessage(MethodInfo
156215
* Checks the method meets <strong>ALL</strong> the following conditions:
157216
* <ul>
158217
* <li>The method's name is {@link #VALIDATE_METHOD_NAME}</li>
159-
* <li>The method's return type is {@link OutputGuardrailResult}</li>
218+
* <li>The method's return type is {@link io.quarkiverse.langchain4j.guardrails.OutputGuardrailResult} or
219+
* {@link OutputGuardrailResult}</li>
160220
* <li>The method's single parameter's type is {@link dev.langchain4j.data.message.AiMessage}</li>
161221
* </ul>
162222
*/
@@ -168,9 +228,16 @@ private static boolean isOutputGuardrailValidateMethodWithAiMessage(MethodInfo m
168228
/**
169229
* Checks the method meets <strong>ALL</strong> the following conditions:
170230
* <ul>
171-
* <li><strong>IF</strong> the method's single parameter's type is {@link InputGuardrailParams} then the return type must be
231+
* <li><strong>IF</strong> the method's single parameter's type is
232+
* {@link io.quarkiverse.langchain4j.guardrails.InputGuardrailParams} then the return type must be
233+
* {@link io.quarkiverse.langchain4j.guardrails.InputGuardrailResult}</li>
234+
* <li><strong>IF</strong> the method's single parameter's type is
235+
* {@link io.quarkiverse.langchain4j.guardrails.OutputGuardrailParams} then the return type must
236+
* be {@link io.quarkiverse.langchain4j.guardrails.OutputGuardrailResult}</li>
237+
* <li><strong>IF</strong> the method's single parameter's type is {@link InputGuardrailRequest} then the return type must
238+
* be
172239
* {@link InputGuardrailResult}</li>
173-
* <li><strong>IF</strong> the method's single parameter's type is {@link OutputGuardrailParams} then the return type must
240+
* <li><strong>IF</strong> the method's single parameter's type is {@link OutputGuardrailRequest} then the return type must
174241
* be {@link OutputGuardrailResult}</li>
175242
* </ul>
176243
*/
@@ -187,8 +254,13 @@ private static boolean doesValidateMethodWithParamsHaveCorrectSignature(MethodIn
187254
// Also check the return type
188255
var returnType = methodInfo.returnType().name();
189256

190-
return (INPUT_GUARDRAIL_PARAMS.equals(paramTypeName) && INPUT_GUARDRAIL_RESULT.equals(returnType)) ||
191-
(OUTPUT_GUARDRAIL_PARAMS.equals(paramTypeName) && OUTPUT_GUARDRAIL_RESULT.equals(returnType));
257+
return (QUARKUS_INPUT_GUARDRAIL_PARAMS.equals(paramTypeName) && QUARKUS_INPUT_GUARDRAIL_RESULT.equals(returnType))
258+
||
259+
(QUARKUS_OUTPUT_GUARDRAIL_PARAMS.equals(paramTypeName)
260+
&& QUARKUS_OUTPUT_GUARDRAIL_RESULT.equals(returnType))
261+
||
262+
(INPUT_GUARDRAIL_REQUEST.equals(paramTypeName) && INPUT_GUARDRAIL_RESULT.equals(returnType)) ||
263+
(OUTPUT_GUARDRAIL_REQUEST.equals(paramTypeName) && OUTPUT_GUARDRAIL_RESULT.equals(returnType));
192264
}
193265

194266
return false;
@@ -207,7 +279,8 @@ private static boolean doesValidateMethodWithoutParamsHaveCorrectSignature(Metho
207279
var returnType = methodInfo.returnType().name();
208280

209281
return paramType.equals(paramTypeName) &&
210-
(INPUT_GUARDRAIL_RESULT.equals(returnType) || OUTPUT_GUARDRAIL_RESULT.equals(returnType));
282+
(QUARKUS_INPUT_GUARDRAIL_RESULT.equals(returnType) || QUARKUS_OUTPUT_GUARDRAIL_RESULT.equals(returnType) ||
283+
INPUT_GUARDRAIL_RESULT.equals(returnType) || OUTPUT_GUARDRAIL_RESULT.equals(returnType));
211284
}
212285

213286
return false;

core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/LangChain4jDotNames.java

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,8 @@
2626
import dev.langchain4j.service.TokenStream;
2727
import dev.langchain4j.service.UserMessage;
2828
import dev.langchain4j.service.UserName;
29+
import dev.langchain4j.service.guardrail.InputGuardrails;
30+
import dev.langchain4j.service.guardrail.OutputGuardrails;
2931
import dev.langchain4j.service.tool.ToolProvider;
3032
import dev.langchain4j.web.search.WebSearchEngine;
3133
import dev.langchain4j.web.search.WebSearchTool;
@@ -36,8 +38,6 @@
3638
import io.quarkiverse.langchain4j.PdfUrl;
3739
import io.quarkiverse.langchain4j.RegisterAiService;
3840
import io.quarkiverse.langchain4j.SeedMemory;
39-
import io.quarkiverse.langchain4j.guardrails.InputGuardrails;
40-
import io.quarkiverse.langchain4j.guardrails.OutputGuardrails;
4141
import io.quarkiverse.langchain4j.runtime.aiservice.QuarkusAiServiceContextQualifier;
4242

4343
public class LangChain4jDotNames {
@@ -49,6 +49,18 @@ public class LangChain4jDotNames {
4949
public static final DotName IMAGE_MODEL = DotName.createSimple(ImageModel.class);
5050
public static final DotName CHAT_MESSAGE = DotName.createSimple(ChatMessage.class);
5151
public static final DotName TOKEN_STREAM = DotName.createSimple(TokenStream.class);
52+
/**
53+
* @deprecated Will go away once the Quarkus-specific guardrail implementation has been fully removed
54+
*/
55+
@Deprecated(forRemoval = true)
56+
public static final DotName QUARKUS_OUTPUT_GUARDRAILS = DotName
57+
.createSimple(io.quarkiverse.langchain4j.guardrails.OutputGuardrails.class);
58+
/**
59+
* @deprecated Will go away once the Quarkus-specific guardrail implementation has been fully removed
60+
*/
61+
@Deprecated(forRemoval = true)
62+
public static final DotName QUARKUS_INPUT_GUARDRAILS = DotName
63+
.createSimple(io.quarkiverse.langchain4j.guardrails.InputGuardrails.class);
5264
public static final DotName OUTPUT_GUARDRAILS = DotName.createSimple(OutputGuardrails.class);
5365
public static final DotName INPUT_GUARDRAILS = DotName.createSimple(InputGuardrails.class);
5466
static final DotName AI_SERVICES = DotName.createSimple(AiServices.class);

0 commit comments

Comments
 (0)