Skip to content

Commit 090d52a

Browse files
authored
Merge pull request #1796 from quarkiverse/agentic-context-cdi
Make sure that managed QuarkusAiServiceContext is always used
2 parents 2997d4e + 8e42f06 commit 090d52a

File tree

9 files changed

+495
-3184
lines changed

9 files changed

+495
-3184
lines changed

core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/AiServicesProcessor.java

Lines changed: 63 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -115,8 +115,10 @@
115115
import io.quarkiverse.langchain4j.spi.PromptTemplateFactoryContentFilterProvider;
116116
import io.quarkus.arc.Arc;
117117
import io.quarkus.arc.ArcContainer;
118+
import io.quarkus.arc.DefaultBean;
118119
import io.quarkus.arc.InstanceHandle;
119120
import io.quarkus.arc.deployment.AdditionalBeanBuildItem;
121+
import io.quarkus.arc.deployment.AnnotationsTransformerBuildItem;
120122
import io.quarkus.arc.deployment.CustomScopeAnnotationsBuildItem;
121123
import io.quarkus.arc.deployment.GeneratedBeanBuildItem;
122124
import io.quarkus.arc.deployment.GeneratedBeanGizmoAdaptor;
@@ -324,6 +326,7 @@ private static Set<String> gatherToolNames(ClassInfo toolClass) {
324326
@BuildStep
325327
public void findDeclarativeServices(CombinedIndexBuildItem indexBuildItem,
326328
CustomScopeAnnotationsBuildItem customScopes,
329+
List<AnnotationsImpliesAiServiceBuildItem> annotationsImpliesAiServiceItems,
327330
BuildProducer<RequestChatModelBeanBuildItem> requestChatModelBeanProducer,
328331
BuildProducer<RequestModerationModelBeanBuildItem> requestModerationModelBeanProducer,
329332
BuildProducer<RequestImageModelBeanBuildItem> requestImageModelBeanProducer,
@@ -338,11 +341,29 @@ public void findDeclarativeServices(CombinedIndexBuildItem indexBuildItem,
338341
Set<String> imageModelNames = new HashSet<>();
339342
List<ToolProviderInfo> toolProviderInfos = new ArrayList<>();
340343
ClassOutput generatedClassOutput = new GeneratedClassGizmoAdaptor(generatedClassProducer, true);
341-
for (AnnotationInstance instance : index.getAnnotations(LangChain4jDotNames.REGISTER_AI_SERVICES)) {
344+
345+
Set<DotName> annotationsThatImplyAiService = annotationsImpliesAiServiceItems.stream().flatMap(
346+
bi -> bi.getAnnotationNames().stream())
347+
.collect(Collectors.toSet());
348+
349+
Collection<AnnotationInstance> registerAiServicesInstances = new ArrayList<>(
350+
index.getAnnotations(REGISTER_AI_SERVICES));
351+
352+
Set<AnnotationInstance> impliedRegisterAiServiceInstance = determinedImpliedRegisterAiService(
353+
annotationsThatImplyAiService, index);
354+
Set<DotName> impliedRegisterAiServiceTarget = impliedRegisterAiServiceInstance.stream()
355+
.map(ai -> ai.target().asClass().name()).collect(Collectors.toSet());
356+
registerAiServicesInstances.addAll(impliedRegisterAiServiceInstance);
357+
Set<DotName> alreadyHandled = new HashSet<>();
358+
for (AnnotationInstance instance : registerAiServicesInstances) {
342359
if (instance.target().kind() != AnnotationTarget.Kind.CLASS) {
343360
continue; // should never happen
344361
}
345362
ClassInfo declarativeAiServiceClassInfo = instance.target().asClass();
363+
if (alreadyHandled.contains(declarativeAiServiceClassInfo.name())) {
364+
continue;
365+
}
366+
alreadyHandled.add(declarativeAiServiceClassInfo.name());
346367

347368
DotName chatLanguageModelSupplierClassDotName = getSupplierDotName(instance.value("chatLanguageModelSupplier"),
348369
LangChain4jDotNames.BEAN_CHAT_MODEL_SUPPLIER,
@@ -462,7 +483,9 @@ public void findDeclarativeServices(CombinedIndexBuildItem indexBuildItem,
462483
toolHallucinationStrategy(instance),
463484
classInputGuardrails(declarativeAiServiceClassInfo, index),
464485
classOutputGuardrails(declarativeAiServiceClassInfo, index),
465-
maxSequentialToolInvocations));
486+
maxSequentialToolInvocations,
487+
// we need to make these @DefaultBean because there could be other CDI beans of the same type that need to take precedence
488+
impliedRegisterAiServiceTarget.contains(declarativeAiServiceClassInfo.name())));
466489

467490
}
468491
toolProviderProducer.produce(new ToolProviderMetaBuildItem(toolProviderInfos));
@@ -480,6 +503,35 @@ public void findDeclarativeServices(CombinedIndexBuildItem indexBuildItem,
480503
}
481504
}
482505

506+
private static Set<AnnotationInstance> determinedImpliedRegisterAiService(Set<DotName> annotationsThatImplyAiService,
507+
IndexView index) {
508+
Set<AnnotationInstance> impliedDefaultRegisterAiService = new HashSet<>();
509+
for (DotName ann : annotationsThatImplyAiService) {
510+
index.getAnnotations(ann).forEach(instance -> {
511+
ClassInfo ci;
512+
switch (instance.target().kind()) {
513+
case METHOD -> {
514+
ci = instance.target().asMethod().declaringClass();
515+
}
516+
case CLASS -> {
517+
ci = instance.target().asClass();
518+
}
519+
case FIELD -> {
520+
ci = instance.target().asField().declaringClass();
521+
}
522+
default -> {
523+
ci = null;
524+
}
525+
}
526+
if (ci == null) {
527+
return;
528+
}
529+
impliedDefaultRegisterAiService.add(AnnotationInstance.builder(REGISTER_AI_SERVICES).buildWithTarget(ci));
530+
});
531+
}
532+
return impliedDefaultRegisterAiService;
533+
}
534+
483535
private static String chatModelName(AnnotationInstance instance, DotName chatLanguageModelSupplierClassDotName,
484536
DotName streamingChatLanguageModelSupplierClassDotName, Set<String> chatModelNames) {
485537
String chatModelName = NamedConfigUtil.DEFAULT_NAME;
@@ -793,6 +845,7 @@ public void handleDeclarativeServices(AiServicesRecorder recorder,
793845
String moderationModelName = bi.getModerationModelName();
794846
SyntheticBeanBuildItem.ExtendedBeanConfigurator configurator = SyntheticBeanBuildItem
795847
.configure(QuarkusAiServiceContext.class)
848+
.unremovable()
796849
.forceApplicationClass()
797850
.createWith(recorder.createDeclarativeAiService(
798851
new DeclarativeAiServiceCreateInfo(
@@ -1166,6 +1219,11 @@ AnnotationsImpliesAiServiceBuildItem implyAiService() {
11661219
LangChain4jDotNames.MODERATE));
11671220
}
11681221

1222+
@BuildStep
1223+
public void annotationTransformations(BuildProducer<AnnotationsTransformerBuildItem> producer) {
1224+
1225+
}
1226+
11691227
@BuildStep
11701228
@Record(ExecutionTime.STATIC_INIT)
11711229
public void handleAiServices(
@@ -1319,6 +1377,9 @@ public void handleAiServices(
13191377
classCreator.addAnnotation(
13201378
AnnotationInstance.builder(NAMED).add("value", matchingBI.getBeanName().get()).build());
13211379
}
1380+
if (matchingBI.isMakeDefaultBean()) {
1381+
classCreator.addAnnotation(DefaultBean.class);
1382+
}
13221383
}
13231384

13241385
FieldDescriptor contextField = classCreator.getFieldCreator("context", QuarkusAiServiceContext.class)

core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/DeclarativeAiServiceBuildItem.java

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ public final class DeclarativeAiServiceBuildItem extends MultiBuildItem {
3434
private final DeclarativeAiServiceInputGuardrails inputGuardrails;
3535
private final DeclarativeAiServiceOutputGuardrails outputGuardrails;
3636
private final Integer maxSequentialToolInvocations;
37+
private final boolean makeDefaultBean;
3738

3839
public DeclarativeAiServiceBuildItem(
3940
ClassInfo serviceClassInfo,
@@ -55,7 +56,7 @@ public DeclarativeAiServiceBuildItem(
5556
DotName toolHallucinationStrategyClassDotName,
5657
DeclarativeAiServiceInputGuardrails inputGuardrails,
5758
DeclarativeAiServiceOutputGuardrails outputGuardrails,
58-
Integer maxSequentialToolInvocations) {
59+
Integer maxSequentialToolInvocations, boolean makeDefaultBean) {
5960
this.serviceClassInfo = serviceClassInfo;
6061
this.chatLanguageModelSupplierClassDotName = chatLanguageModelSupplierClassDotName;
6162
this.streamingChatLanguageModelSupplierClassDotName = streamingChatLanguageModelSupplierClassDotName;
@@ -76,6 +77,7 @@ public DeclarativeAiServiceBuildItem(
7677
this.inputGuardrails = inputGuardrails;
7778
this.outputGuardrails = outputGuardrails;
7879
this.maxSequentialToolInvocations = maxSequentialToolInvocations;
80+
this.makeDefaultBean = makeDefaultBean;
7981
}
8082

8183
public ClassInfo getServiceClassInfo() {
@@ -154,6 +156,10 @@ public DeclarativeAiServiceOutputGuardrails getOutputGuardrails() {
154156
return outputGuardrails;
155157
}
156158

159+
public boolean isMakeDefaultBean() {
160+
return makeDefaultBean;
161+
}
162+
157163
public record DeclarativeAiServiceInputGuardrails(List<ClassInfo> inputGuardrailClassInfos) {
158164
public List<String> asClassNames() {
159165
return this.inputGuardrailClassInfos.stream()

core/runtime/src/main/java/io/quarkiverse/langchain4j/QuarkusAiServiceContextFactory.java

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,20 @@
33
import dev.langchain4j.service.AiServiceContext;
44
import dev.langchain4j.spi.services.AiServiceContextFactory;
55
import io.quarkiverse.langchain4j.runtime.aiservice.QuarkusAiServiceContext;
6+
import io.quarkiverse.langchain4j.runtime.aiservice.QuarkusAiServiceContextQualifier;
7+
import io.quarkus.arc.Arc;
8+
import io.quarkus.arc.InstanceHandle;
69

710
public class QuarkusAiServiceContextFactory implements AiServiceContextFactory {
811

912
@Override
1013
public AiServiceContext create(Class<?> aiServiceClass) {
11-
return new QuarkusAiServiceContext(aiServiceClass);
14+
InstanceHandle<QuarkusAiServiceContext> instance = Arc.container().instance(QuarkusAiServiceContext.class,
15+
QuarkusAiServiceContextQualifier.Literal.of(
16+
aiServiceClass.getName()));
17+
if (instance.isAvailable()) {
18+
return instance.get();
19+
}
20+
return new QuarkusAiServiceContext(aiServiceClass); // just create a default context
1221
}
1322
}

0 commit comments

Comments
 (0)