|  | 
| 1 | 1 | package io.quarkiverse.langchain4j.runtime; | 
| 2 | 2 | 
 | 
| 3 |  | -import static dev.langchain4j.service.TypeUtils.getRawClass; | 
|  | 3 | +import java.lang.reflect.*; | 
|  | 4 | +import java.util.*; | 
|  | 5 | +import java.util.regex.Matcher; | 
|  | 6 | +import java.util.regex.Pattern; | 
| 4 | 7 | 
 | 
| 5 |  | -import java.lang.reflect.Type; | 
|  | 8 | +import com.fasterxml.jackson.databind.ObjectMapper; | 
| 6 | 9 | 
 | 
|  | 10 | +import dev.langchain4j.data.message.AiMessage; | 
|  | 11 | +import dev.langchain4j.model.output.Response; | 
|  | 12 | +import dev.langchain4j.model.output.structured.Description; | 
|  | 13 | +import dev.langchain4j.service.Result; | 
|  | 14 | +import dev.langchain4j.service.TokenStream; | 
|  | 15 | +import dev.langchain4j.service.TypeUtils; | 
|  | 16 | +//import dev.langchain4j.service.output.OutputParser; | 
| 7 | 17 | import dev.langchain4j.service.output.ServiceOutputParser; | 
|  | 18 | +import io.quarkiverse.langchain4j.QuarkusJsonCodecFactory; | 
| 8 | 19 | import io.smallrye.mutiny.Multi; | 
| 9 | 20 | 
 | 
| 10 | 21 | public class QuarkusServiceOutputParser extends ServiceOutputParser { | 
|  | 22 | +    private static final Pattern JSON_BLOCK_PATTERN = Pattern.compile("(?s)\\{.*\\}|\\[.*\\]"); | 
| 11 | 23 | 
 | 
| 12 | 24 |     @Override | 
| 13 | 25 |     public String outputFormatInstructions(Type returnType) { | 
| 14 | 26 |         Class<?> rawClass = getRawClass(returnType); | 
| 15 |  | -        if (Multi.class.equals(rawClass)) { | 
| 16 |  | -            // when Multi is used as the return type, Multi<String> is the only supported type, thus we don't need want any formatting instructions | 
| 17 |  | -            return ""; | 
|  | 27 | + | 
|  | 28 | +        if (rawClass != String.class && rawClass != AiMessage.class && rawClass != TokenStream.class | 
|  | 29 | +                && rawClass != Response.class && !Multi.class.equals(rawClass)) { | 
|  | 30 | +            try { | 
|  | 31 | +                var schema = this.toJsonSchema(returnType); | 
|  | 32 | +                return "You must answer strictly with json according to the following json schema format: " + schema; | 
|  | 33 | +            } catch (Exception e) { | 
|  | 34 | +                return ""; | 
|  | 35 | +            } | 
|  | 36 | +        } | 
|  | 37 | + | 
|  | 38 | +        return ""; | 
|  | 39 | +    } | 
|  | 40 | + | 
|  | 41 | +    public Object parse(Response<AiMessage> response, Type returnType) { | 
|  | 42 | +        QuarkusJsonCodecFactory factory = new QuarkusJsonCodecFactory(); | 
|  | 43 | +        var codec = factory.create(); | 
|  | 44 | + | 
|  | 45 | +        if (TypeUtils.typeHasRawClass(returnType, Result.class)) { | 
|  | 46 | +            returnType = TypeUtils.resolveFirstGenericParameterClass(returnType); | 
|  | 47 | +        } | 
|  | 48 | + | 
|  | 49 | +        Class<?> rawReturnClass = TypeUtils.getRawClass(returnType); | 
|  | 50 | + | 
|  | 51 | +        if (rawReturnClass == Response.class) { | 
|  | 52 | +            return response; | 
|  | 53 | +        } else { | 
|  | 54 | +            AiMessage aiMessage = response.content(); | 
|  | 55 | +            if (rawReturnClass == AiMessage.class) { | 
|  | 56 | +                return aiMessage; | 
|  | 57 | +            } else { | 
|  | 58 | +                String text = aiMessage.text(); | 
|  | 59 | +                if (rawReturnClass == String.class) { | 
|  | 60 | +                    return text; | 
|  | 61 | +                } else { | 
|  | 62 | +                    try { | 
|  | 63 | +                        return codec.fromJson(text, returnType); | 
|  | 64 | +                    } catch (Exception var10) { | 
|  | 65 | +                        String jsonBlock = this.extractJsonBlock(text); | 
|  | 66 | +                        return codec.fromJson(jsonBlock, returnType); | 
|  | 67 | +                    } | 
|  | 68 | +                } | 
|  | 69 | +            } | 
|  | 70 | +        } | 
|  | 71 | +    } | 
|  | 72 | + | 
|  | 73 | +    private String extractJsonBlock(String text) { | 
|  | 74 | +        Matcher matcher = JSON_BLOCK_PATTERN.matcher(text); | 
|  | 75 | +        return matcher.find() ? matcher.group() : text; | 
|  | 76 | +    } | 
|  | 77 | + | 
|  | 78 | +    public String toJsonSchema(Type type) throws Exception { | 
|  | 79 | +        Map<String, Object> schema = new HashMap<>(); | 
|  | 80 | +        Class<?> rawClass = getRawClass(type); | 
|  | 81 | + | 
|  | 82 | +        if (type instanceof WildcardType wildcardType) { | 
|  | 83 | +            Type boundType = wildcardType.getUpperBounds().length > 0 ? wildcardType.getUpperBounds()[0] | 
|  | 84 | +                    : wildcardType.getLowerBounds()[0]; | 
|  | 85 | +            return toJsonSchema(boundType); | 
|  | 86 | +        } | 
|  | 87 | + | 
|  | 88 | +        if (rawClass == String.class || rawClass == Character.class) { | 
|  | 89 | +            schema.put("type", "string"); | 
|  | 90 | +        } else if (rawClass == Boolean.class || rawClass == boolean.class) { | 
|  | 91 | +            schema.put("type", "boolean"); | 
|  | 92 | +        } else if (Number.class.isAssignableFrom(rawClass) || rawClass.isPrimitive()) { | 
|  | 93 | +            schema.put("type", (rawClass == double.class || rawClass == float.class) ? "number" : "integer"); | 
|  | 94 | +        } else if (Collection.class.isAssignableFrom(rawClass) || rawClass.isArray()) { | 
|  | 95 | +            schema.put("type", "array"); | 
|  | 96 | + | 
|  | 97 | +            Type elementType = getElementType(type); | 
|  | 98 | +            Map<String, Object> itemsSchema = toJsonSchemaMap(elementType); | 
|  | 99 | +            schema.put("items", itemsSchema); | 
|  | 100 | +        } else if (rawClass.isEnum()) { | 
|  | 101 | +            schema.put("type", "string"); | 
|  | 102 | +            schema.put("enum", getEnumConstants(rawClass)); | 
|  | 103 | +        } else { | 
|  | 104 | +            schema.put("type", "object"); | 
|  | 105 | +            Map<String, Object> properties = new HashMap<>(); | 
|  | 106 | + | 
|  | 107 | +            for (Field field : rawClass.getDeclaredFields()) { | 
|  | 108 | +                field.setAccessible(true); | 
|  | 109 | +                Map<String, Object> fieldSchema = toJsonSchemaMap(field.getGenericType()); | 
|  | 110 | +                properties.put(field.getName(), fieldSchema); | 
|  | 111 | +                if (field.isAnnotationPresent(Description.class)) { | 
|  | 112 | +                    Description description = field.getAnnotation(Description.class); | 
|  | 113 | +                    fieldSchema.put("description", description.value()); | 
|  | 114 | +                } | 
|  | 115 | +            } | 
|  | 116 | +            schema.put("properties", properties); | 
|  | 117 | +        } | 
|  | 118 | + | 
|  | 119 | +        ObjectMapper mapper = new ObjectMapper(); | 
|  | 120 | +        return mapper.writeValueAsString(schema); // Convert the schema map to a JSON string | 
|  | 121 | +    } | 
|  | 122 | + | 
|  | 123 | +    private Class<?> getRawClass(Type type) { | 
|  | 124 | +        if (type instanceof Class<?>) { | 
|  | 125 | +            return (Class<?>) type; | 
|  | 126 | +        } else if (type instanceof ParameterizedType) { | 
|  | 127 | +            return (Class<?>) ((ParameterizedType) type).getRawType(); | 
|  | 128 | +        } else if (type instanceof GenericArrayType) { | 
|  | 129 | +            Type componentType = ((GenericArrayType) type).getGenericComponentType(); | 
|  | 130 | +            return Array.newInstance(getRawClass(componentType), 0).getClass(); | 
|  | 131 | +        } else if (type instanceof WildcardType) { | 
|  | 132 | +            Type boundType = ((WildcardType) type).getUpperBounds().length > 0 ? ((WildcardType) type).getUpperBounds()[0] | 
|  | 133 | +                    : ((WildcardType) type).getLowerBounds()[0]; | 
|  | 134 | +            return getRawClass(boundType); | 
|  | 135 | +        } | 
|  | 136 | +        throw new IllegalArgumentException("Unsupported type: " + type); | 
|  | 137 | +    } | 
|  | 138 | + | 
|  | 139 | +    private Type getElementType(Type type) { | 
|  | 140 | +        if (type instanceof ParameterizedType) { | 
|  | 141 | +            return ((ParameterizedType) type).getActualTypeArguments()[0]; | 
|  | 142 | +        } else if (type instanceof GenericArrayType) { | 
|  | 143 | +            return ((GenericArrayType) type).getGenericComponentType(); | 
|  | 144 | +        } else if (type instanceof Class<?> && ((Class<?>) type).isArray()) { | 
|  | 145 | +            return ((Class<?>) type).getComponentType(); | 
|  | 146 | +        } | 
|  | 147 | +        return Object.class; // Fallback for cases where element type cannot be determined | 
|  | 148 | +    } | 
|  | 149 | + | 
|  | 150 | +    private Map<String, Object> toJsonSchemaMap(Type type) throws Exception { | 
|  | 151 | +        String jsonSchema = toJsonSchema(type); | 
|  | 152 | +        ObjectMapper mapper = new ObjectMapper(); | 
|  | 153 | +        return mapper.readValue(jsonSchema, Map.class); | 
|  | 154 | +    } | 
|  | 155 | + | 
|  | 156 | +    private List<String> getEnumConstants(Class<?> enumClass) { | 
|  | 157 | +        List<String> constants = new ArrayList<>(); | 
|  | 158 | +        for (Object constant : enumClass.getEnumConstants()) { | 
|  | 159 | +            constants.add(constant.toString()); | 
| 18 | 160 |         } | 
| 19 |  | -        return super.outputFormatInstructions(returnType); | 
|  | 161 | +        return constants; | 
| 20 | 162 |     } | 
| 21 | 163 | } | 
0 commit comments