Skip to content

Commit 8e42f06

Browse files
committed
Make sure that managed QuarkusAiServiceContext is always used
1 parent 49c042d commit 8e42f06

File tree

5 files changed

+89
-38
lines changed

5 files changed

+89
-38
lines changed

core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/AiServicesProcessor.java

Lines changed: 63 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -115,8 +115,10 @@
115115
import io.quarkiverse.langchain4j.spi.PromptTemplateFactoryContentFilterProvider;
116116
import io.quarkus.arc.Arc;
117117
import io.quarkus.arc.ArcContainer;
118+
import io.quarkus.arc.DefaultBean;
118119
import io.quarkus.arc.InstanceHandle;
119120
import io.quarkus.arc.deployment.AdditionalBeanBuildItem;
121+
import io.quarkus.arc.deployment.AnnotationsTransformerBuildItem;
120122
import io.quarkus.arc.deployment.CustomScopeAnnotationsBuildItem;
121123
import io.quarkus.arc.deployment.GeneratedBeanBuildItem;
122124
import io.quarkus.arc.deployment.GeneratedBeanGizmoAdaptor;
@@ -324,6 +326,7 @@ private static Set<String> gatherToolNames(ClassInfo toolClass) {
324326
@BuildStep
325327
public void findDeclarativeServices(CombinedIndexBuildItem indexBuildItem,
326328
CustomScopeAnnotationsBuildItem customScopes,
329+
List<AnnotationsImpliesAiServiceBuildItem> annotationsImpliesAiServiceItems,
327330
BuildProducer<RequestChatModelBeanBuildItem> requestChatModelBeanProducer,
328331
BuildProducer<RequestModerationModelBeanBuildItem> requestModerationModelBeanProducer,
329332
BuildProducer<RequestImageModelBeanBuildItem> requestImageModelBeanProducer,
@@ -338,11 +341,29 @@ public void findDeclarativeServices(CombinedIndexBuildItem indexBuildItem,
338341
Set<String> imageModelNames = new HashSet<>();
339342
List<ToolProviderInfo> toolProviderInfos = new ArrayList<>();
340343
ClassOutput generatedClassOutput = new GeneratedClassGizmoAdaptor(generatedClassProducer, true);
341-
for (AnnotationInstance instance : index.getAnnotations(LangChain4jDotNames.REGISTER_AI_SERVICES)) {
344+
345+
Set<DotName> annotationsThatImplyAiService = annotationsImpliesAiServiceItems.stream().flatMap(
346+
bi -> bi.getAnnotationNames().stream())
347+
.collect(Collectors.toSet());
348+
349+
Collection<AnnotationInstance> registerAiServicesInstances = new ArrayList<>(
350+
index.getAnnotations(REGISTER_AI_SERVICES));
351+
352+
Set<AnnotationInstance> impliedRegisterAiServiceInstance = determinedImpliedRegisterAiService(
353+
annotationsThatImplyAiService, index);
354+
Set<DotName> impliedRegisterAiServiceTarget = impliedRegisterAiServiceInstance.stream()
355+
.map(ai -> ai.target().asClass().name()).collect(Collectors.toSet());
356+
registerAiServicesInstances.addAll(impliedRegisterAiServiceInstance);
357+
Set<DotName> alreadyHandled = new HashSet<>();
358+
for (AnnotationInstance instance : registerAiServicesInstances) {
342359
if (instance.target().kind() != AnnotationTarget.Kind.CLASS) {
343360
continue; // should never happen
344361
}
345362
ClassInfo declarativeAiServiceClassInfo = instance.target().asClass();
363+
if (alreadyHandled.contains(declarativeAiServiceClassInfo.name())) {
364+
continue;
365+
}
366+
alreadyHandled.add(declarativeAiServiceClassInfo.name());
346367

347368
DotName chatLanguageModelSupplierClassDotName = getSupplierDotName(instance.value("chatLanguageModelSupplier"),
348369
LangChain4jDotNames.BEAN_CHAT_MODEL_SUPPLIER,
@@ -462,7 +483,9 @@ public void findDeclarativeServices(CombinedIndexBuildItem indexBuildItem,
462483
toolHallucinationStrategy(instance),
463484
classInputGuardrails(declarativeAiServiceClassInfo, index),
464485
classOutputGuardrails(declarativeAiServiceClassInfo, index),
465-
maxSequentialToolInvocations));
486+
maxSequentialToolInvocations,
487+
// we need to make these @DefaultBean because there could be other CDI beans of the same type that need to take precedence
488+
impliedRegisterAiServiceTarget.contains(declarativeAiServiceClassInfo.name())));
466489

467490
}
468491
toolProviderProducer.produce(new ToolProviderMetaBuildItem(toolProviderInfos));
@@ -480,6 +503,35 @@ public void findDeclarativeServices(CombinedIndexBuildItem indexBuildItem,
480503
}
481504
}
482505

506+
private static Set<AnnotationInstance> determinedImpliedRegisterAiService(Set<DotName> annotationsThatImplyAiService,
507+
IndexView index) {
508+
Set<AnnotationInstance> impliedDefaultRegisterAiService = new HashSet<>();
509+
for (DotName ann : annotationsThatImplyAiService) {
510+
index.getAnnotations(ann).forEach(instance -> {
511+
ClassInfo ci;
512+
switch (instance.target().kind()) {
513+
case METHOD -> {
514+
ci = instance.target().asMethod().declaringClass();
515+
}
516+
case CLASS -> {
517+
ci = instance.target().asClass();
518+
}
519+
case FIELD -> {
520+
ci = instance.target().asField().declaringClass();
521+
}
522+
default -> {
523+
ci = null;
524+
}
525+
}
526+
if (ci == null) {
527+
return;
528+
}
529+
impliedDefaultRegisterAiService.add(AnnotationInstance.builder(REGISTER_AI_SERVICES).buildWithTarget(ci));
530+
});
531+
}
532+
return impliedDefaultRegisterAiService;
533+
}
534+
483535
private static String chatModelName(AnnotationInstance instance, DotName chatLanguageModelSupplierClassDotName,
484536
DotName streamingChatLanguageModelSupplierClassDotName, Set<String> chatModelNames) {
485537
String chatModelName = NamedConfigUtil.DEFAULT_NAME;
@@ -793,6 +845,7 @@ public void handleDeclarativeServices(AiServicesRecorder recorder,
793845
String moderationModelName = bi.getModerationModelName();
794846
SyntheticBeanBuildItem.ExtendedBeanConfigurator configurator = SyntheticBeanBuildItem
795847
.configure(QuarkusAiServiceContext.class)
848+
.unremovable()
796849
.forceApplicationClass()
797850
.createWith(recorder.createDeclarativeAiService(
798851
new DeclarativeAiServiceCreateInfo(
@@ -1166,6 +1219,11 @@ AnnotationsImpliesAiServiceBuildItem implyAiService() {
11661219
LangChain4jDotNames.MODERATE));
11671220
}
11681221

1222+
@BuildStep
1223+
public void annotationTransformations(BuildProducer<AnnotationsTransformerBuildItem> producer) {
1224+
1225+
}
1226+
11691227
@BuildStep
11701228
@Record(ExecutionTime.STATIC_INIT)
11711229
public void handleAiServices(
@@ -1319,6 +1377,9 @@ public void handleAiServices(
13191377
classCreator.addAnnotation(
13201378
AnnotationInstance.builder(NAMED).add("value", matchingBI.getBeanName().get()).build());
13211379
}
1380+
if (matchingBI.isMakeDefaultBean()) {
1381+
classCreator.addAnnotation(DefaultBean.class);
1382+
}
13221383
}
13231384

13241385
FieldDescriptor contextField = classCreator.getFieldCreator("context", QuarkusAiServiceContext.class)

core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/DeclarativeAiServiceBuildItem.java

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ public final class DeclarativeAiServiceBuildItem extends MultiBuildItem {
3434
private final DeclarativeAiServiceInputGuardrails inputGuardrails;
3535
private final DeclarativeAiServiceOutputGuardrails outputGuardrails;
3636
private final Integer maxSequentialToolInvocations;
37+
private final boolean makeDefaultBean;
3738

3839
public DeclarativeAiServiceBuildItem(
3940
ClassInfo serviceClassInfo,
@@ -55,7 +56,7 @@ public DeclarativeAiServiceBuildItem(
5556
DotName toolHallucinationStrategyClassDotName,
5657
DeclarativeAiServiceInputGuardrails inputGuardrails,
5758
DeclarativeAiServiceOutputGuardrails outputGuardrails,
58-
Integer maxSequentialToolInvocations) {
59+
Integer maxSequentialToolInvocations, boolean makeDefaultBean) {
5960
this.serviceClassInfo = serviceClassInfo;
6061
this.chatLanguageModelSupplierClassDotName = chatLanguageModelSupplierClassDotName;
6162
this.streamingChatLanguageModelSupplierClassDotName = streamingChatLanguageModelSupplierClassDotName;
@@ -76,6 +77,7 @@ public DeclarativeAiServiceBuildItem(
7677
this.inputGuardrails = inputGuardrails;
7778
this.outputGuardrails = outputGuardrails;
7879
this.maxSequentialToolInvocations = maxSequentialToolInvocations;
80+
this.makeDefaultBean = makeDefaultBean;
7981
}
8082

8183
public ClassInfo getServiceClassInfo() {
@@ -154,6 +156,10 @@ public DeclarativeAiServiceOutputGuardrails getOutputGuardrails() {
154156
return outputGuardrails;
155157
}
156158

159+
public boolean isMakeDefaultBean() {
160+
return makeDefaultBean;
161+
}
162+
157163
public record DeclarativeAiServiceInputGuardrails(List<ClassInfo> inputGuardrailClassInfos) {
158164
public List<String> asClassNames() {
159165
return this.inputGuardrailClassInfos.stream()

core/runtime/src/main/java/io/quarkiverse/langchain4j/QuarkusAiServiceContextFactory.java

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,20 @@
33
import dev.langchain4j.service.AiServiceContext;
44
import dev.langchain4j.spi.services.AiServiceContextFactory;
55
import io.quarkiverse.langchain4j.runtime.aiservice.QuarkusAiServiceContext;
6+
import io.quarkiverse.langchain4j.runtime.aiservice.QuarkusAiServiceContextQualifier;
7+
import io.quarkus.arc.Arc;
8+
import io.quarkus.arc.InstanceHandle;
69

710
public class QuarkusAiServiceContextFactory implements AiServiceContextFactory {
811

912
@Override
1013
public AiServiceContext create(Class<?> aiServiceClass) {
11-
return new QuarkusAiServiceContext(aiServiceClass);
14+
InstanceHandle<QuarkusAiServiceContext> instance = Arc.container().instance(QuarkusAiServiceContext.class,
15+
QuarkusAiServiceContextQualifier.Literal.of(
16+
aiServiceClass.getName()));
17+
if (instance.isAvailable()) {
18+
return instance.get();
19+
}
20+
return new QuarkusAiServiceContext(aiServiceClass); // just create a default context
1221
}
1322
}

model-providers/openai/openai-vanilla/deployment/src/test/java/org/acme/examples/aiservices/AiServicesTest.java

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@
5858
import io.opentelemetry.instrumentation.annotations.SpanAttribute;
5959
import io.quarkiverse.langchain4j.openai.testing.internal.OpenAiBaseTest;
6060
import io.quarkiverse.langchain4j.runtime.LangChain4jUtil;
61+
import io.quarkus.arc.Arc;
6162
import io.quarkus.test.QuarkusUnitTest;
6263

6364
public class AiServicesTest extends OpenAiBaseTest {
@@ -68,7 +69,8 @@ public class AiServicesTest extends OpenAiBaseTest {
6869
() -> ShrinkWrap.create(JavaArchive.class)
6970
.addAsResource("messages/recipe-user.txt")
7071
.addAsResource("messages/translate-user.txt")
71-
.addAsResource("messages/translate-system"));
72+
.addAsResource("messages/translate-system"))
73+
.overrideRuntimeConfigKey("quarkus.langchain4j.openai.api-key", "whatever");
7274

7375
private OpenAiChatModel createChatModel() {
7476
return OpenAiChatModel.builder().baseUrl(resolvedWiremockUrl("/v1"))
@@ -93,6 +95,11 @@ private static MessageWindowChatMemory createChatMemory() {
9395
void setup() {
9496
resetRequests();
9597
resetMappings();
98+
clearDefaultMemory();
99+
}
100+
101+
private void clearDefaultMemory() {
102+
Arc.container().instance(ChatMemoryStore.class).get().deleteMessages("default");
96103
}
97104

98105
interface Assistant {
@@ -312,6 +319,7 @@ void test_create_recipe_from_list_of_ingredients() throws IOException {
312319

313320
resetRequests();
314321
resetMappings();
322+
clearDefaultMemory();
315323

316324
setChatCompletionMessageContent(value);
317325
Chef chef = AiServices.create(Chef.class, createChatModel());

model-providers/openai/openai-vanilla/deployment/src/test/java/org/acme/examples/aiservices/NoModerationProvidedTest.java

Lines changed: 0 additions & 33 deletions
This file was deleted.

0 commit comments

Comments
 (0)