diff --git a/pom.xml b/pom.xml index a079df8c0..c4dc98a36 100644 --- a/pom.xml +++ b/pom.xml @@ -27,7 +27,7 @@ org.springframework.data spring-data-neo4j - 8.0.0-SNAPSHOT + 8.0.0-GH-3002-SNAPSHOT Spring Data Neo4j Next generation Object-Graph-Mapping for Spring Data. diff --git a/src/main/antora/modules/ROOT/nav.adoc b/src/main/antora/modules/ROOT/nav.adoc index dc4006628..30033d420 100644 --- a/src/main/antora/modules/ROOT/nav.adoc +++ b/src/main/antora/modules/ROOT/nav.adoc @@ -26,6 +26,7 @@ ** xref:repositories/sdn-extension.adoc[] ** xref:repositories/query-keywords-reference.adoc[] ** xref:repositories/query-return-types-reference.adoc[] +** xref:repositories/vector-search.adoc[] * xref:repositories/projections.adoc[] ** xref:projections/sdn-projections.adoc[] diff --git a/src/main/antora/modules/ROOT/pages/repositories/vector-search.adoc b/src/main/antora/modules/ROOT/pages/repositories/vector-search.adoc new file mode 100644 index 000000000..7d5116b86 --- /dev/null +++ b/src/main/antora/modules/ROOT/pages/repositories/vector-search.adoc @@ -0,0 +1,30 @@ +[[sdn-vector-search]] += Neo4j Vector Search + +== The `@VectorSearch` annotation +Spring Data Neo4j supports Neo4j's vector search on the repository level by using the `@VectorSearch` annotation. +For this to work, Neo4j needs to have a vector index in place. +How to create a vector index is explained in the https://neo4j.com/docs/cypher-manual/current/indexes/search-performance-indexes/managing-indexes/[Neo4j documentation]. + +NOTE: It's not required to have any (Spring Data) Vector typed property be defined in the domain entities for this to work +because the search operates exclusively on the index. + +The `@VectorSearch` annotation requires two arguments: +The name of the vector index to be used and the number of nearest neighbours. + +For a general vector search over the whole domain, it's possible to use a derived finder method without any property. +[source,java,indent=0,tabsize=4] +---- +include::example$integration/imperative/VectorSearchIT.java[tags=sdn-vector-search.usage;sdn-vector-search.usage.findall] +---- + +The vector index can be combined with any property-based finder method to filter down the results. + +NOTE: For technical reasons, the vector search will always be executed before the property search gets invoked. +E.g. if the property filter looks for a person named "Helge", +but the vector search only yields "Hannes", there won't be a result. + +[source,java,indent=0,tabsize=4] +---- +include::example$integration/imperative/VectorSearchIT.java[tags=sdn-vector-search.usage;sdn-vector-search.usage.findbyproperty] +---- diff --git a/src/main/java/org/springframework/data/neo4j/core/Neo4jTemplate.java b/src/main/java/org/springframework/data/neo4j/core/Neo4jTemplate.java index e069dc85b..3d0ae1079 100644 --- a/src/main/java/org/springframework/data/neo4j/core/Neo4jTemplate.java +++ b/src/main/java/org/springframework/data/neo4j/core/Neo4jTemplate.java @@ -1489,7 +1489,12 @@ private Optional> createFetchSpec() { statement = nodesAndRelationshipsById.toStatement(entityMetaData); } else { - statement = queryFragments.toStatement(); + if (queryFragmentsAndParameters.hasVectorSearchFragment()) { + statement = queryFragments.toStatement(queryFragmentsAndParameters.getVectorSearchFragment()); + } + else { + statement = queryFragments.toStatement(); + } } cypherQuery = Neo4jTemplate.this.renderer.render(statement); finalParameters = TemplateSupport.mergeParameters(statement, finalParameters); diff --git a/src/main/java/org/springframework/data/neo4j/core/ReactiveNeo4jTemplate.java b/src/main/java/org/springframework/data/neo4j/core/ReactiveNeo4jTemplate.java index 099e8c340..b7b1a7b25 100644 --- a/src/main/java/org/springframework/data/neo4j/core/ReactiveNeo4jTemplate.java +++ b/src/main/java/org/springframework/data/neo4j/core/ReactiveNeo4jTemplate.java @@ -1408,8 +1408,13 @@ public Mono> toExecutableQuery(PreparedQuery preparedQ return new DefaultReactiveExecutableQuery<>(preparedQuery, fetchSpec); }); } - - Statement statement = queryFragments.toStatement(); + Statement statement = null; + if (queryFragmentsAndParameters.hasVectorSearchFragment()) { + statement = queryFragments.toStatement(queryFragmentsAndParameters.getVectorSearchFragment()); + } + else { + statement = queryFragments.toStatement(); + } cypherQuery = this.renderer.render(statement); finalParameters = TemplateSupport.mergeParameters(statement, finalParameters); } diff --git a/src/main/java/org/springframework/data/neo4j/core/mapping/Constants.java b/src/main/java/org/springframework/data/neo4j/core/mapping/Constants.java index a7a36ce0e..d059ede5f 100644 --- a/src/main/java/org/springframework/data/neo4j/core/mapping/Constants.java +++ b/src/main/java/org/springframework/data/neo4j/core/mapping/Constants.java @@ -163,6 +163,21 @@ public final class Constants { */ public static final String TO_ID_PARAMETER_NAME = "toId"; + /** + * The name SDN uses for vector search score. + */ + public static final String NAME_OF_SCORE = "__score__"; + + /** + * Vector search vector parameter name. + */ + public static final String VECTOR_SEARCH_VECTOR_PARAMETER = "vectorSearchParam"; + + /** + * Vector search score parameter name. + */ + public static final String VECTOR_SEARCH_SCORE_PARAMETER = "scoreParam"; + private Constants() { } diff --git a/src/main/java/org/springframework/data/neo4j/repository/query/AbstractNeo4jQuery.java b/src/main/java/org/springframework/data/neo4j/repository/query/AbstractNeo4jQuery.java index cb4455891..934e3e117 100644 --- a/src/main/java/org/springframework/data/neo4j/repository/query/AbstractNeo4jQuery.java +++ b/src/main/java/org/springframework/data/neo4j/repository/query/AbstractNeo4jQuery.java @@ -23,6 +23,7 @@ import java.util.function.LongSupplier; import java.util.function.Supplier; import java.util.function.UnaryOperator; +import java.util.stream.Collectors; import org.jspecify.annotations.Nullable; import org.neo4j.driver.types.MapAccessor; @@ -32,6 +33,8 @@ import org.springframework.data.domain.Page; import org.springframework.data.domain.PageRequest; import org.springframework.data.domain.Pageable; +import org.springframework.data.domain.SearchResult; +import org.springframework.data.domain.SearchResults; import org.springframework.data.domain.Slice; import org.springframework.data.domain.SliceImpl; import org.springframework.data.geo.GeoPage; @@ -98,11 +101,30 @@ boolean isGeoNearQuery() { return GeoPage.class.isAssignableFrom(returnType); } + boolean isVectorSearchQuery() { + var repositoryMethod = this.queryMethod.getMethod(); + Class returnType = repositoryMethod.getReturnType(); + + for (Class type : Neo4jQueryMethod.VECTOR_SEARCH_RESULTS) { + if (type.isAssignableFrom(returnType)) { + return true; + } + } + + if (Iterable.class.isAssignableFrom(returnType)) { + TypeInformation from = TypeInformation.fromReturnTypeOf(repositoryMethod); + return from.getComponentType() != null && SearchResult.class.equals(from.getComponentType().getType()); + } + + return false; + } + @Override @Nullable public final Object execute(Object[] parameters) { boolean incrementLimit = this.queryMethod.incrementLimit(); boolean geoNearQuery = isGeoNearQuery(); + boolean vectorSearchQuery = isVectorSearchQuery(); Neo4jParameterAccessor parameterAccessor = new Neo4jParameterAccessor( (Neo4jQueryMethod.Neo4jParameters) this.queryMethod.getParameters(), parameters); @@ -111,7 +133,7 @@ boolean isGeoNearQuery() { ReturnedType returnedType = resultProcessor.getReturnedType(); PreparedQuery preparedQuery = prepareQuery(returnedType.getReturnedType(), PropertyFilterSupport.getInputProperties(resultProcessor, this.factory, this.mappingContext), - parameterAccessor, null, getMappingFunction(resultProcessor, geoNearQuery), + parameterAccessor, null, getMappingFunction(resultProcessor, geoNearQuery, vectorSearchQuery), incrementLimit ? l -> l + 1 : UnaryOperator.identity()); Object rawResult = new Neo4jQueryExecution.DefaultQueryExecution(this.neo4jOperations).execute(preparedQuery, @@ -143,6 +165,9 @@ else if (this.queryMethod.isScrollQuery()) { else if (geoNearQuery) { rawResult = newGeoResults(rawResult); } + else if (this.queryMethod.isSearchQuery()) { + rawResult = createSearchResult((List) rawResult, returnedType.getReturnedType()); + } return resultProcessor.processResult(rawResult, preparingConverter); } @@ -182,6 +207,13 @@ private Slice createSlice(boolean incrementLimit, Neo4jParameterAccessor para } } + private SearchResults createSearchResult(List rawResult, Class returnedType) { + List> searchResults = rawResult.stream() + .map(rawValue -> (SearchResult) rawValue) + .collect(Collectors.toUnmodifiableList()); + return new SearchResults<>(searchResults); + } + protected abstract PreparedQuery prepareQuery(Class returnedType, Collection includedProperties, Neo4jParameterAccessor parameterAccessor, @Nullable Neo4jQueryType queryType, diff --git a/src/main/java/org/springframework/data/neo4j/repository/query/AbstractReactiveNeo4jQuery.java b/src/main/java/org/springframework/data/neo4j/repository/query/AbstractReactiveNeo4jQuery.java index 5274e4a35..60edae58e 100644 --- a/src/main/java/org/springframework/data/neo4j/repository/query/AbstractReactiveNeo4jQuery.java +++ b/src/main/java/org/springframework/data/neo4j/repository/query/AbstractReactiveNeo4jQuery.java @@ -24,8 +24,10 @@ import org.neo4j.driver.types.MapAccessor; import org.neo4j.driver.types.TypeSystem; import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; import org.springframework.core.convert.converter.Converter; +import org.springframework.data.domain.SearchResult; import org.springframework.data.geo.GeoResult; import org.springframework.data.neo4j.core.PreparedQuery; import org.springframework.data.neo4j.core.PropertyFilterSupport; @@ -89,11 +91,31 @@ boolean isGeoNearQuery() { return false; } + boolean isVectorSearchQuery() { + var repositoryMethod = this.queryMethod.getMethod(); + Class returnType = repositoryMethod.getReturnType(); + + for (Class type : ReactiveNeo4jQueryMethod.VECTOR_SEARCH_RESULTS) { + if (type.isAssignableFrom(returnType)) { + return true; + } + } + + if (Flux.class.isAssignableFrom(returnType) || Mono.class.isAssignableFrom(returnType)) { + TypeInformation from = TypeInformation.fromReturnTypeOf(repositoryMethod); + TypeInformation componentType = from.getComponentType(); + return componentType != null && SearchResult.class.equals(componentType.getType()); + } + + return false; + } + @Override @Nullable public final Object execute(Object[] parameters) { boolean incrementLimit = this.queryMethod.incrementLimit(); boolean geoNearQuery = isGeoNearQuery(); + boolean vectorSearchQuery = isVectorSearchQuery(); Neo4jParameterAccessor parameterAccessor = new Neo4jParameterAccessor( (Neo4jQueryMethod.Neo4jParameters) this.queryMethod.getParameters(), parameters); ResultProcessor resultProcessor = this.queryMethod.getResultProcessor() @@ -102,7 +124,7 @@ boolean isGeoNearQuery() { ReturnedType returnedType = resultProcessor.getReturnedType(); PreparedQuery preparedQuery = prepareQuery(returnedType.getReturnedType(), PropertyFilterSupport.getInputProperties(resultProcessor, this.factory, this.mappingContext), - parameterAccessor, null, getMappingFunction(resultProcessor, geoNearQuery), + parameterAccessor, null, getMappingFunction(resultProcessor, geoNearQuery, vectorSearchQuery), incrementLimit ? l -> l + 1 : UnaryOperator.identity()); Object rawResult = new Neo4jQueryExecution.ReactiveQueryExecution(this.neo4jOperations).execute(preparedQuery, @@ -126,10 +148,17 @@ parameterAccessor, null, getMappingFunction(resultProcessor, geoNearQuery), .map(rawResultList -> createWindow(resultProcessor, incrementLimit, parameterAccessor, rawResultList, preparedQuery.getQueryFragmentsAndParameters())); } + else if (this.queryMethod.isSearchQuery()) { + rawResult = createSearchResult((Flux) rawResult, returnedType.getReturnedType()); + } return resultProcessor.processResult(rawResult, preparingConverter); } + private Flux> createSearchResult(Flux rawResult, Class returnedType) { + return rawResult.map(rawValue -> (SearchResult) rawValue); + } + protected abstract PreparedQuery prepareQuery(Class returnedType, Collection includedProperties, Neo4jParameterAccessor parameterAccessor, @Nullable Neo4jQueryType queryType, Supplier> mappingFunction, diff --git a/src/main/java/org/springframework/data/neo4j/repository/query/CypherQueryCreator.java b/src/main/java/org/springframework/data/neo4j/repository/query/CypherQueryCreator.java index a1785bb76..497b64c9a 100644 --- a/src/main/java/org/springframework/data/neo4j/repository/query/CypherQueryCreator.java +++ b/src/main/java/org/springframework/data/neo4j/repository/query/CypherQueryCreator.java @@ -45,8 +45,10 @@ import org.springframework.data.domain.OffsetScrollPosition; import org.springframework.data.domain.Pageable; import org.springframework.data.domain.Range; +import org.springframework.data.domain.Score; import org.springframework.data.domain.ScrollPosition; import org.springframework.data.domain.Sort; +import org.springframework.data.domain.Vector; import org.springframework.data.geo.Box; import org.springframework.data.geo.Circle; import org.springframework.data.geo.Distance; @@ -60,7 +62,6 @@ import org.springframework.data.neo4j.core.mapping.Neo4jPersistentProperty; import org.springframework.data.neo4j.core.mapping.NodeDescription; import org.springframework.data.neo4j.core.mapping.PropertyFilter; -import org.springframework.data.repository.query.QueryMethod; import org.springframework.data.repository.query.parser.AbstractQueryCreator; import org.springframework.data.repository.query.parser.Part; import org.springframework.data.repository.query.parser.PartTree; @@ -119,14 +120,22 @@ final class CypherQueryCreator extends AbstractQueryCreator distanceExpressions = new ArrayList<>(); + private final List additionalReturnExpression = new ArrayList<>(); /** * Can be used to modify the limit of a paged or sliced query. */ private final UnaryOperator limitModifier; - CypherQueryCreator(Neo4jMappingContext mappingContext, QueryMethod queryMethod, Class domainType, + private final Neo4jQueryMethod queryMethod; + + @Nullable + private final Vector vectorSearchParameter; + + @Nullable + private final Score scoreParameter; + + CypherQueryCreator(Neo4jMappingContext mappingContext, Neo4jQueryMethod queryMethod, Class domainType, Neo4jQueryType queryType, PartTree tree, Neo4jParameterAccessor actualParameters, Collection includedProperties, BiFunction, Object> parameterConversion, @@ -148,6 +157,8 @@ final class CypherQueryCreator extends AbstractQueryCreator(this.sortItems); theSort.stream().map(CypherAdapterUtils.sortAdapterFor(this.nodeDescription)).forEach(finalSortItems::add); queryFragments.setReturnBasedOn(this.nodeDescription, this.includedProperties, this.isDistinct, - this.distanceExpressions); + this.additionalReturnExpression); queryFragments.setOrderBy(finalSortItems); } @@ -438,7 +468,7 @@ else if (p2.isPresent() && p2.get().value instanceof Point) { // property to be later retrieved and mapped Neo4jPersistentEntity owner = (Neo4jPersistentEntity) leafProperty.getOwner(); String containerName = getContainerName(path, owner); - this.distanceExpressions + this.additionalReturnExpression .add(distanceFunction.as("__distance_" + containerName + "_" + leafProperty.getPropertyName() + "__")); this.sortItems.add(distanceFunction.ascending()); diff --git a/src/main/java/org/springframework/data/neo4j/repository/query/Neo4jQueryMethod.java b/src/main/java/org/springframework/data/neo4j/repository/query/Neo4jQueryMethod.java index 1427a7143..2a8743927 100644 --- a/src/main/java/org/springframework/data/neo4j/repository/query/Neo4jQueryMethod.java +++ b/src/main/java/org/springframework/data/neo4j/repository/query/Neo4jQueryMethod.java @@ -25,6 +25,8 @@ import org.springframework.core.MethodParameter; import org.springframework.core.annotation.AnnotatedElementUtils; +import org.springframework.data.domain.SearchResult; +import org.springframework.data.domain.SearchResults; import org.springframework.data.geo.GeoPage; import org.springframework.data.geo.GeoResult; import org.springframework.data.geo.GeoResults; @@ -55,12 +57,18 @@ class Neo4jQueryMethod extends QueryMethod { static final List> GEO_NEAR_RESULTS = List.of(GeoResult.class, GeoResults.class, GeoPage.class); + static final List> VECTOR_SEARCH_RESULTS = List.of(SearchResults.class, + SearchResult.class); + /** * Optional query annotation of the method. */ @Nullable private final Query queryAnnotation; + @Nullable + private final VectorSearch vectorSearchAnnotation; + private final String repositoryName; private final boolean cypherBasedProjection; @@ -94,6 +102,7 @@ class Neo4jQueryMethod extends QueryMethod { this.repositoryName = this.method.getDeclaringClass().getName(); this.cypherBasedProjection = cypherBasedProjection; this.queryAnnotation = AnnotatedElementUtils.findMergedAnnotation(this.method, Query.class); + this.vectorSearchAnnotation = AnnotatedElementUtils.findMergedAnnotation(this.method, VectorSearch.class); } String getRepositoryName() { @@ -126,6 +135,14 @@ Optional getQueryAnnotation() { return Optional.ofNullable(this.queryAnnotation); } + boolean hasVectorSearchAnnotation() { + return this.vectorSearchAnnotation != null; + } + + Optional getVectorSearchAnnotation() { + return Optional.ofNullable(this.vectorSearchAnnotation); + } + @Override public Class getReturnedObjectType() { Class returnedObjectType = super.getReturnedObjectType(); @@ -143,7 +160,7 @@ boolean incrementLimit() { boolean asCollectionQuery() { return this.isCollectionLikeQuery() || this.isPageQuery() || this.isSliceQuery() || this.isScrollQuery() - || GeoResults.class.isAssignableFrom(this.method.getReturnType()); + || GeoResults.class.isAssignableFrom(this.method.getReturnType()) || this.isSearchQuery(); } Method getMethod() { diff --git a/src/main/java/org/springframework/data/neo4j/repository/query/Neo4jQuerySupport.java b/src/main/java/org/springframework/data/neo4j/repository/query/Neo4jQuerySupport.java index d9c173a72..c51a64a11 100644 --- a/src/main/java/org/springframework/data/neo4j/repository/query/Neo4jQuerySupport.java +++ b/src/main/java/org/springframework/data/neo4j/repository/query/Neo4jQuerySupport.java @@ -44,8 +44,10 @@ import org.springframework.data.domain.KeysetScrollPosition; import org.springframework.data.domain.OffsetScrollPosition; import org.springframework.data.domain.Range; +import org.springframework.data.domain.Score; import org.springframework.data.domain.ScrollPosition; import org.springframework.data.domain.ScrollPosition.Direction; +import org.springframework.data.domain.SearchResult; import org.springframework.data.domain.Window; import org.springframework.data.expression.ValueExpressionParser; import org.springframework.data.geo.Box; @@ -135,6 +137,25 @@ else if (distances.size() > 1) { }; } + static BiFunction decorateAsVectorSearchResult( + BiFunction target) { + return (t, r) -> { + Object intermediateResult = target.apply(t, r); + var distances = StreamSupport.stream(r.keys().spliterator(), false) + .filter(k -> k.equals(Constants.NAME_OF_SCORE)) + .toList(); + if (distances.isEmpty()) { + throw new RuntimeException("No score has been returned by the query, cannot create `SearchResult`"); + } + else if (distances.size() > 1) { + throw new RuntimeException( + "More than one score has been returned by the query, cannot create `SearchResult`"); + } + var searchResult = Score.of(r.get(distances.get(0)).asDouble()); + return new SearchResult<>(intermediateResult, searchResult); + }; + } + private static boolean hasValidReturnTypeForDelete(Neo4jQueryMethod queryMethod) { return VALID_RETURN_TYPES_FOR_DELETE .contains(queryMethod.getResultProcessor().getReturnedType().getReturnedType()); @@ -192,7 +213,7 @@ else if (distance.getMetric() == Metrics.MILES) { } protected final Supplier> getMappingFunction( - final ResultProcessor resultProcessor, boolean isGeoNearQuery) { + final ResultProcessor resultProcessor, boolean isGeoNearQuery, boolean isVectorSearchQuery) { return () -> { final ReturnedType returnedTypeMetadata = resultProcessor.getReturnedType(); @@ -213,6 +234,10 @@ else if (returnedTypeMetadata.isProjecting()) { else if (isGeoNearQuery) { mappingFunction = decorateAsGeoResult(this.mappingContext.getRequiredMappingFunctionFor(domainType)); } + else if (isVectorSearchQuery) { + mappingFunction = decorateAsVectorSearchResult( + this.mappingContext.getRequiredMappingFunctionFor(domainType)); + } else { mappingFunction = this.mappingContext.getRequiredMappingFunctionFor(domainType); } diff --git a/src/main/java/org/springframework/data/neo4j/repository/query/PartTreeNeo4jQuery.java b/src/main/java/org/springframework/data/neo4j/repository/query/PartTreeNeo4jQuery.java index f50f45634..09bde0684 100644 --- a/src/main/java/org/springframework/data/neo4j/repository/query/PartTreeNeo4jQuery.java +++ b/src/main/java/org/springframework/data/neo4j/repository/query/PartTreeNeo4jQuery.java @@ -57,6 +57,11 @@ private PartTreeNeo4jQuery(Neo4jOperations neo4jOperations, Neo4jMappingContext static RepositoryQuery create(Neo4jOperations neo4jOperations, Neo4jMappingContext mappingContext, Neo4jQueryMethod queryMethod, ProjectionFactory factory) { + if (queryMethod.hasVectorSearchAnnotation() + && queryMethod.getVectorSearchAnnotation().get().numberOfNodes() < 1) { + throw new IllegalArgumentException("Number of nodes in the vector search %s#%s has to be greater than zero." + .formatted(queryMethod.getRepositoryName(), queryMethod.getMethod().getName())); + } return new PartTreeNeo4jQuery(neo4jOperations, mappingContext, queryMethod, new PartTree(queryMethod.getName(), getDomainType(queryMethod)), factory); } diff --git a/src/main/java/org/springframework/data/neo4j/repository/query/QueryFragments.java b/src/main/java/org/springframework/data/neo4j/repository/query/QueryFragments.java index 621e0ce24..d8bb4f5e9 100644 --- a/src/main/java/org/springframework/data/neo4j/repository/query/QueryFragments.java +++ b/src/main/java/org/springframework/data/neo4j/repository/query/QueryFragments.java @@ -28,11 +28,13 @@ import org.neo4j.cypherdsl.core.Condition; import org.neo4j.cypherdsl.core.Cypher; import org.neo4j.cypherdsl.core.Expression; +import org.neo4j.cypherdsl.core.Node; import org.neo4j.cypherdsl.core.PatternElement; import org.neo4j.cypherdsl.core.SortItem; import org.neo4j.cypherdsl.core.Statement; import org.neo4j.cypherdsl.core.StatementBuilder; +import org.springframework.data.neo4j.core.mapping.Constants; import org.springframework.data.neo4j.core.mapping.CypherGenerator; import org.springframework.data.neo4j.core.mapping.Neo4jPersistentEntity; import org.springframework.data.neo4j.core.mapping.Neo4jPersistentProperty; @@ -193,6 +195,60 @@ public Statement toStatement() { return statement; } + public Statement toStatement(VectorSearchFragment vectorSearchFragment) { + + if (this.matchOn.isEmpty()) { + throw new IllegalStateException("No pattern to match on"); + } + var vectorSearch = Cypher.call("db.index.vector.queryNodes") + .withArgs(Cypher.literalOf(vectorSearchFragment.indexName()), + Cypher.literalOf(vectorSearchFragment.numberOfNodes()), + Cypher.parameter(Constants.VECTOR_SEARCH_VECTOR_PARAMETER)) + .yield("node", "score") + .with(Cypher.name("node").as(((Node) this.matchOn.get(0)).getRequiredSymbolicName().getValue()), + Cypher.name("score").as(Constants.NAME_OF_SCORE)); + + StatementBuilder.OngoingReadingWithoutWhere match = null; + if (vectorSearchFragment.hasScore()) { + match = vectorSearch + .where(Cypher.raw(Constants.NAME_OF_SCORE) + .gte(Cypher.parameter(Constants.VECTOR_SEARCH_SCORE_PARAMETER))) + .match(this.matchOn.get(0)); + } + else { + match = vectorSearch.match(this.matchOn.get(0)); + } + + if (this.matchOn.size() > 1) { + for (PatternElement patternElement : this.matchOn.subList(1, this.matchOn.size())) { + match = match.match(patternElement); + } + } + + StatementBuilder.OngoingReadingWithWhere matchWithWhere = match.where(this.condition); + + if (this.deleteExpression != null) { + matchWithWhere = (StatementBuilder.OngoingReadingWithWhere) matchWithWhere + .detachDelete(this.deleteExpression); + } + + StatementBuilder.OngoingReadingAndReturn returnPart = isDistinctReturn() + ? matchWithWhere.returningDistinct(getReturnExpressionsForVectorSearch()) + : matchWithWhere.returning(getReturnExpressionsForVectorSearch()); + + Statement statement = returnPart.orderBy(getOrderBy()).skip(this.skip).limit(this.limit).build(); + + statement.setRenderConstantsAsParameters(false); + return statement; + } + + private Collection getReturnExpressionsForVectorSearch() { + return (this.returnExpressions.isEmpty() && this.returnTuple != null) ? CypherGenerator.INSTANCE + .createReturnStatementForMatch((Neo4jPersistentEntity) this.returnTuple.nodeDescription, + this::includeField, this.returnTuple.additionalExpressions.toArray(Expression[]::new)) + : this.returnExpressions; + } + private Collection getReturnExpressions() { return (this.returnExpressions.isEmpty() && this.returnTuple != null) ? CypherGenerator.INSTANCE .createReturnStatementForMatch((Neo4jPersistentEntity) this.returnTuple.nodeDescription, diff --git a/src/main/java/org/springframework/data/neo4j/repository/query/QueryFragmentsAndParameters.java b/src/main/java/org/springframework/data/neo4j/repository/query/QueryFragmentsAndParameters.java index dc28dd608..bb20f8b7e 100644 --- a/src/main/java/org/springframework/data/neo4j/repository/query/QueryFragmentsAndParameters.java +++ b/src/main/java/org/springframework/data/neo4j/repository/query/QueryFragmentsAndParameters.java @@ -63,6 +63,8 @@ public final class QueryFragmentsAndParameters { private final QueryFragments queryFragments; + private final VectorSearchFragment vectorSearchFragment; + @Nullable private final String cypherQuery; @@ -73,10 +75,21 @@ public final class QueryFragmentsAndParameters { @Nullable private NodeDescription nodeDescription; + public QueryFragmentsAndParameters(@Nullable NodeDescription nodeDescription, QueryFragments queryFragments, + VectorSearchFragment vectorSearchFragment, Map parameters, @Nullable Sort sort) { + this.nodeDescription = nodeDescription; + this.queryFragments = queryFragments; + this.vectorSearchFragment = vectorSearchFragment; + this.parameters = parameters; + this.cypherQuery = null; + this.sort = (sort != null) ? sort : Sort.unsorted(); + } + public QueryFragmentsAndParameters(@Nullable NodeDescription nodeDescription, QueryFragments queryFragments, Map parameters, @Nullable Sort sort) { this.nodeDescription = nodeDescription; this.queryFragments = queryFragments; + this.vectorSearchFragment = null; this.parameters = parameters; this.cypherQuery = null; this.sort = (sort != null) ? sort : Sort.unsorted(); @@ -89,6 +102,7 @@ public QueryFragmentsAndParameters(@NonNull String cypherQuery) { public QueryFragmentsAndParameters(@NonNull String cypherQuery, Map parameters) { this.cypherQuery = cypherQuery; this.queryFragments = new QueryFragments(); + this.vectorSearchFragment = null; this.parameters = parameters; this.sort = Sort.unsorted(); } @@ -384,6 +398,14 @@ public QueryFragments getQueryFragments() { return this.queryFragments; } + public boolean hasVectorSearchFragment() { + return this.vectorSearchFragment != null; + } + + public VectorSearchFragment getVectorSearchFragment() { + return this.vectorSearchFragment; + } + @Nullable public String getCypherQuery() { return this.cypherQuery; } diff --git a/src/main/java/org/springframework/data/neo4j/repository/query/ReactiveNeo4jQueryMethod.java b/src/main/java/org/springframework/data/neo4j/repository/query/ReactiveNeo4jQueryMethod.java index efaaad052..f0c2297cc 100644 --- a/src/main/java/org/springframework/data/neo4j/repository/query/ReactiveNeo4jQueryMethod.java +++ b/src/main/java/org/springframework/data/neo4j/repository/query/ReactiveNeo4jQueryMethod.java @@ -15,11 +15,14 @@ */ package org.springframework.data.neo4j.repository.query; +import java.io.Serializable; import java.lang.reflect.Method; +import java.util.List; import org.springframework.dao.InvalidDataAccessApiUsageException; import org.springframework.data.domain.Page; import org.springframework.data.domain.Pageable; +import org.springframework.data.domain.SearchResult; import org.springframework.data.domain.Slice; import org.springframework.data.neo4j.repository.support.ReactiveCypherdslStatementExecutor; import org.springframework.data.projection.ProjectionFactory; @@ -43,6 +46,8 @@ */ final class ReactiveNeo4jQueryMethod extends Neo4jQueryMethod { + static final List> VECTOR_SEARCH_RESULTS = List.of(SearchResult.class); + @SuppressWarnings("rawtypes") private static final TypeInformation PAGE_TYPE = TypeInformation.of(Page.class); diff --git a/src/main/java/org/springframework/data/neo4j/repository/query/ReactivePartTreeNeo4jQuery.java b/src/main/java/org/springframework/data/neo4j/repository/query/ReactivePartTreeNeo4jQuery.java index c074690a3..ce7b0580a 100644 --- a/src/main/java/org/springframework/data/neo4j/repository/query/ReactivePartTreeNeo4jQuery.java +++ b/src/main/java/org/springframework/data/neo4j/repository/query/ReactivePartTreeNeo4jQuery.java @@ -57,6 +57,11 @@ private ReactivePartTreeNeo4jQuery(ReactiveNeo4jOperations neo4jOperations, Neo4 static RepositoryQuery create(ReactiveNeo4jOperations neo4jOperations, Neo4jMappingContext mappingContext, Neo4jQueryMethod queryMethod, ProjectionFactory factory) { + if (queryMethod.hasVectorSearchAnnotation() + && queryMethod.getVectorSearchAnnotation().get().numberOfNodes() < 1) { + throw new IllegalArgumentException("Number of nodes in the vector search %s#%s has to be greater than zero." + .formatted(queryMethod.getRepositoryName(), queryMethod.getMethod().getName())); + } return new ReactivePartTreeNeo4jQuery(neo4jOperations, mappingContext, queryMethod, new PartTree(queryMethod.getName(), getDomainType(queryMethod)), factory); } diff --git a/src/main/java/org/springframework/data/neo4j/repository/query/VectorSearch.java b/src/main/java/org/springframework/data/neo4j/repository/query/VectorSearch.java new file mode 100644 index 000000000..53e95ccea --- /dev/null +++ b/src/main/java/org/springframework/data/neo4j/repository/query/VectorSearch.java @@ -0,0 +1,39 @@ +/* + * Copyright 2011-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.data.neo4j.repository.query; + +import java.lang.annotation.Documented; +import java.lang.annotation.ElementType; +import java.lang.annotation.Retention; +import java.lang.annotation.RetentionPolicy; +import java.lang.annotation.Target; + +/** + * Indicates a vector search on a repository. + * + * @author Gerrit Meier + * @since 8.0.0 + */ +@Retention(RetentionPolicy.RUNTIME) +@Target({ ElementType.METHOD, ElementType.ANNOTATION_TYPE }) +@Documented +public @interface VectorSearch { + + String indexName(); + + int numberOfNodes(); + +} diff --git a/src/main/java/org/springframework/data/neo4j/repository/query/VectorSearchFragment.java b/src/main/java/org/springframework/data/neo4j/repository/query/VectorSearchFragment.java new file mode 100644 index 000000000..ed445f9ee --- /dev/null +++ b/src/main/java/org/springframework/data/neo4j/repository/query/VectorSearchFragment.java @@ -0,0 +1,33 @@ +/* + * Copyright 2011-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.data.neo4j.repository.query; + +import org.jspecify.annotations.Nullable; + +/** + * Collected params for vector search. + * + * @author Gerrit Meier + * @param indexName name of the index to use for vector search + * @param numberOfNodes number of nodes to fetch from the index search + * @param score score filter + */ +record VectorSearchFragment(String indexName, int numberOfNodes, @Nullable Double score) { + + boolean hasScore() { + return this.score != null; + } +} diff --git a/src/test/java/org/springframework/data/neo4j/integration/imperative/VectorSearchIT.java b/src/test/java/org/springframework/data/neo4j/integration/imperative/VectorSearchIT.java new file mode 100644 index 000000000..8729ed990 --- /dev/null +++ b/src/test/java/org/springframework/data/neo4j/integration/imperative/VectorSearchIT.java @@ -0,0 +1,204 @@ +/* + * Copyright 2011-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.data.neo4j.integration.imperative; + +import java.util.List; + +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.neo4j.driver.Driver; +import org.neo4j.driver.Session; + +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Configuration; +import org.springframework.data.domain.Score; +import org.springframework.data.domain.SearchResult; +import org.springframework.data.domain.SearchResults; +import org.springframework.data.domain.Vector; +import org.springframework.data.neo4j.core.DatabaseSelectionProvider; +import org.springframework.data.neo4j.core.transaction.Neo4jBookmarkManager; +import org.springframework.data.neo4j.core.transaction.Neo4jTransactionManager; +import org.springframework.data.neo4j.integration.shared.common.EntityWithVector; +import org.springframework.data.neo4j.repository.Neo4jRepository; +import org.springframework.data.neo4j.repository.config.EnableNeo4jRepositories; +import org.springframework.data.neo4j.repository.query.VectorSearch; +import org.springframework.data.neo4j.test.BookmarkCapture; +import org.springframework.data.neo4j.test.Neo4jExtension; +import org.springframework.data.neo4j.test.Neo4jImperativeTestConfiguration; +import org.springframework.data.neo4j.test.Neo4jIntegrationTest; +import org.springframework.transaction.PlatformTransactionManager; +import org.springframework.transaction.annotation.EnableTransactionManagement; +import org.springframework.transaction.support.TransactionTemplate; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * @author Gerrit Meier + */ +@Neo4jIntegrationTest +class VectorSearchIT { + + protected static Neo4jExtension.Neo4jConnectionSupport neo4jConnectionSupport; + + @BeforeEach + void setupData(@Autowired BookmarkCapture bookmarkCapture, @Autowired Driver driver) { + try (Session session = driver.session(bookmarkCapture.createSessionConfig())) { + session.run("MATCH (n) detach delete n"); + session.run(""" + CREATE VECTOR INDEX entityIndex IF NOT EXISTS + FOR (m:EntityWithVector) + ON m.myVector + OPTIONS { indexConfig: { + `vector.dimensions`: 3, + `vector.similarity_function`: 'cosine' + }}""").consume(); + session.run( + "CREATE (e:EntityWithVector{name:'dings'}) WITH e CALL db.create.setNodeVectorProperty(e, 'myVector', [0.1, 0.1, 0.1])") + .consume(); + session.run( + "CREATE (e:EntityWithVector{name:'dings2'}) WITH e CALL db.create.setNodeVectorProperty(e, 'myVector', [0.7, 0.0, 0.3])") + .consume(); + session.run("CALL db.awaitIndexes()").consume(); + bookmarkCapture.seedWith(session.lastBookmarks()); + } + } + + @AfterEach + void removeIndex(@Autowired BookmarkCapture bookmarkCapture, @Autowired Driver driver) { + try (Session session = driver.session(bookmarkCapture.createSessionConfig())) { + session.run("DROP INDEX `entityIndex` IF EXISTS"); + } + } + + @Test + void findAllWithVectorIndex(@Autowired VectorSearchRepository repository) { + var result = repository.findBy(Vector.of(new double[] { 0.1d, 0.1d, 0.1d })); + assertThat(result).hasSize(2); + } + + @Test + void findAllAsSearchResultsWithVectorIndex(@Autowired VectorSearchRepository repository) { + var result = repository.findAllBy(Vector.of(new double[] { 0.1d, 0.1d, 0.1d })); + assertThat(result).hasSize(2); + assertThat(result.getContent()).hasSize(2); + assertThat(result.getContent()).allSatisfy(content -> { + assertThat(content.getContent()).isNotNull(); + assertThat(content.getScore().getValue()).isGreaterThanOrEqualTo(0.8d); + }); + } + + @Test + void findSingleAsSearchResultWithVectorIndex(@Autowired VectorSearchRepository repository) { + var result = repository.findBy(Vector.of(new double[] { 0.1d, 0.1d, 0.1d }), Score.of(0.9d)); + assertThat(result).isNotNull(); + assertThat(result.getContent()).isNotNull(); + assertThat(result.getScore().getValue()).isGreaterThanOrEqualTo(0.9d); + } + + @Test + void findSearchResultsOfSearchResults(@Autowired VectorSearchRepository repository) { + var result = repository.findDistinctByName("dings", Vector.of(new double[] { 0.1d, 0.1d, 0.1d })); + assertThat(result).isNotNull(); + } + + @Test + void findByNameWithVectorIndex(@Autowired VectorSearchRepository repository) { + var result = repository.findByName("dings", Vector.of(new double[] { 0.1d, 0.1d, 0.1d })); + assertThat(result).hasSize(1); + } + + @Test + void findByNameWithVectorIndexAndScore(@Autowired VectorSearchRepository repository) { + var result = repository.findByName("dings", Vector.of(new double[] { -0.7d, 0.0d, -0.7d }), Score.of(0.01d)); + assertThat(result).hasSize(1); + } + + @Test + void dontFindByNameWithVectorIndexAndScore(@Autowired VectorSearchRepository repository) { + var result = repository.findByName("dings", Vector.of(new double[] { -0.7d, 0.0d, -0.7d }), Score.of(0.8d)); + assertThat(result).hasSize(0); + } + + // tag::sdn-vector-search.usage[] + interface VectorSearchRepository extends Neo4jRepository { + + // end::sdn-vector-search.usage[] + // tag::sdn-vector-search.usage.findall[] + @VectorSearch(indexName = "entityIndex", numberOfNodes = 2) + List findBy(Vector searchVector); + // end::sdn-vector-search.usage.findall[] + + // tag::sdn-vector-search.usage.findbyproperty[] + @VectorSearch(indexName = "entityIndex", numberOfNodes = 1) + List findByName(String name, Vector searchVector); + // end::sdn-vector-search.usage.findbyproperty[] + + @VectorSearch(indexName = "entityIndex", numberOfNodes = 1) + List findDistinctByName(String name, Vector searchVector); + + @VectorSearch(indexName = "entityIndex", numberOfNodes = 2) + List findByName(String name, Vector searchVector, Score score); + + @VectorSearch(indexName = "entityIndex", numberOfNodes = 2) + SearchResults findAllBy(Vector searchVector); + + @VectorSearch(indexName = "entityIndex", numberOfNodes = 2) + SearchResult findBy(Vector searchVector, Score score); + + // tag::sdn-vector-search.usage[] + + } + // end::sdn-vector-search.usage[] + + @Configuration + @EnableTransactionManagement + @EnableNeo4jRepositories(considerNestedRepositories = true) + static class Config extends Neo4jImperativeTestConfiguration { + + @Bean + @Override + public Driver driver() { + return neo4jConnectionSupport.getDriver(); + } + + @Bean + BookmarkCapture bookmarkCapture() { + return new BookmarkCapture(); + } + + @Override + public PlatformTransactionManager transactionManager(Driver driver, + DatabaseSelectionProvider databaseNameProvider) { + BookmarkCapture bookmarkCapture = bookmarkCapture(); + return new Neo4jTransactionManager(driver, databaseNameProvider, + Neo4jBookmarkManager.create(bookmarkCapture)); + } + + @Bean + TransactionTemplate transactionTemplate(PlatformTransactionManager transactionManager) { + return new TransactionTemplate(transactionManager); + } + + @Override + public boolean isCypher5Compatible() { + return neo4jConnectionSupport.isCypher5SyntaxCompatible(); + } + + } + +} diff --git a/src/test/java/org/springframework/data/neo4j/integration/reactive/ReactiveVectorSearchIT.java b/src/test/java/org/springframework/data/neo4j/integration/reactive/ReactiveVectorSearchIT.java new file mode 100644 index 000000000..93433b90e --- /dev/null +++ b/src/test/java/org/springframework/data/neo4j/integration/reactive/ReactiveVectorSearchIT.java @@ -0,0 +1,185 @@ +/* + * Copyright 2011-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.data.neo4j.integration.reactive; + +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.neo4j.driver.Driver; +import org.neo4j.driver.Session; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; +import reactor.test.StepVerifier; + +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Configuration; +import org.springframework.data.domain.Score; +import org.springframework.data.domain.SearchResult; +import org.springframework.data.domain.Vector; +import org.springframework.data.neo4j.core.ReactiveDatabaseSelectionProvider; +import org.springframework.data.neo4j.core.transaction.Neo4jBookmarkManager; +import org.springframework.data.neo4j.core.transaction.ReactiveNeo4jTransactionManager; +import org.springframework.data.neo4j.integration.shared.common.EntityWithVector; +import org.springframework.data.neo4j.repository.ReactiveNeo4jRepository; +import org.springframework.data.neo4j.repository.config.EnableReactiveNeo4jRepositories; +import org.springframework.data.neo4j.repository.query.VectorSearch; +import org.springframework.data.neo4j.test.BookmarkCapture; +import org.springframework.data.neo4j.test.Neo4jExtension; +import org.springframework.data.neo4j.test.Neo4jIntegrationTest; +import org.springframework.data.neo4j.test.Neo4jReactiveTestConfiguration; +import org.springframework.transaction.ReactiveTransactionManager; +import org.springframework.transaction.annotation.EnableTransactionManagement; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * @author Gerrit Meier + */ +@Neo4jIntegrationTest +class ReactiveVectorSearchIT { + + protected static Neo4jExtension.Neo4jConnectionSupport neo4jConnectionSupport; + + @BeforeEach + void setupData(@Autowired BookmarkCapture bookmarkCapture, @Autowired Driver driver) { + try (Session session = driver.session(bookmarkCapture.createSessionConfig())) { + session.run("MATCH (n) detach delete n"); + session.run(""" + CREATE VECTOR INDEX dingsIndex IF NOT EXISTS + FOR (m:EntityWithVector) + ON m.myVector + OPTIONS { indexConfig: { + `vector.dimensions`: 3, + `vector.similarity_function`: 'cosine' + }}""").consume(); + session.run( + "CREATE (e:EntityWithVector{name:'dings'}) WITH e CALL db.create.setNodeVectorProperty(e, 'myVector', [0.1, 0.1, 0.1])") + .consume(); + session.run( + "CREATE (e:EntityWithVector{name:'dings2'}) WITH e CALL db.create.setNodeVectorProperty(e, 'myVector', [0.7, 0.0, 0.3])") + .consume(); + session.run("CALL db.awaitIndexes()").consume(); + bookmarkCapture.seedWith(session.lastBookmarks()); + } + } + + @AfterEach + void removeIndex(@Autowired BookmarkCapture bookmarkCapture, @Autowired Driver driver) { + try (Session session = driver.session(bookmarkCapture.createSessionConfig())) { + session.run("DROP INDEX `dingsIndex` IF EXISTS"); + } + } + + @Test + void findAllWithVectorIndex(@Autowired VectorSearchRepository repository) { + StepVerifier.create(repository.findBy("dings", Vector.of(new double[] { 0.1d, 0.1d, 0.1d }))) + .expectNextCount(2) + .verifyComplete(); + } + + @Test + void findAllAsSearchResultsWithVectorIndex(@Autowired VectorSearchRepository repository) { + StepVerifier.create(repository.findAllBy(Vector.of(new double[] { 0.1d, 0.1d, 0.1d }))) + .assertNext(result -> assertThat(result.getContent()).isNotNull()) + .assertNext(result -> assertThat(result.getContent()).isNotNull()) + .verifyComplete(); + } + + @Test + void findSingleAsSearchResultWithVectorIndex(@Autowired VectorSearchRepository repository) { + StepVerifier.create(repository.findBy(Vector.of(new double[] { 0.1d, 0.1d, 0.1d }), Score.of(0.9d))) + .assertNext(result -> { + assertThat(result).isNotNull(); + assertThat(result.getContent()).isNotNull(); + assertThat(result.getScore().getValue()).isGreaterThanOrEqualTo(0.9d); + }) + .verifyComplete(); + } + + @Test + void findByNameWithVectorIndex(@Autowired VectorSearchRepository repository) { + StepVerifier.create(repository.findByName("dings", Vector.of(new double[] { 0.1d, 0.1d, 0.1d }))) + .expectNextCount(1) + .verifyComplete(); + } + + @Test + void findByNameWithVectorIndexAndScore(@Autowired VectorSearchRepository repository) { + StepVerifier + .create(repository.findByName("dings", Vector.of(new double[] { -0.7d, 0.0d, -0.7d }), Score.of(0.01d))) + .expectNextCount(1) + .verifyComplete(); + } + + @Test + void dontFindByNameWithVectorIndexAndScore(@Autowired VectorSearchRepository repository) { + StepVerifier + .create(repository.findByName("dings", Vector.of(new double[] { -0.7d, 0.0d, -0.7d }), Score.of(0.8d))) + .verifyComplete(); + } + + interface VectorSearchRepository extends ReactiveNeo4jRepository { + + @VectorSearch(indexName = "dingsIndex", numberOfNodes = 2) + Flux findBy(String name, Vector searchVector); + + @VectorSearch(indexName = "dingsIndex", numberOfNodes = 1) + Flux findByName(String name, Vector searchVector); + + @VectorSearch(indexName = "dingsIndex", numberOfNodes = 2) + Flux findByName(String name, Vector searchVector, Score score); + + @VectorSearch(indexName = "dingsIndex", numberOfNodes = 2) + Flux> findAllBy(Vector searchVector); + + @VectorSearch(indexName = "dingsIndex", numberOfNodes = 2) + Mono> findBy(Vector searchVector, Score score); + + } + + @Configuration + @EnableTransactionManagement + @EnableReactiveNeo4jRepositories(considerNestedRepositories = true) + static class Config extends Neo4jReactiveTestConfiguration { + + @Bean + @Override + public Driver driver() { + return neo4jConnectionSupport.getDriver(); + } + + @Bean + BookmarkCapture bookmarkCapture() { + return new BookmarkCapture(); + } + + @Override + public ReactiveTransactionManager reactiveTransactionManager(Driver driver, + ReactiveDatabaseSelectionProvider databaseNameProvider) { + BookmarkCapture bookmarkCapture = bookmarkCapture(); + return new ReactiveNeo4jTransactionManager(driver, databaseNameProvider, + Neo4jBookmarkManager.create(bookmarkCapture)); + } + + @Override + public boolean isCypher5Compatible() { + return neo4jConnectionSupport.isCypher5SyntaxCompatible(); + } + + } + +} diff --git a/src/test/java/org/springframework/data/neo4j/integration/shared/common/EntityWithVector.java b/src/test/java/org/springframework/data/neo4j/integration/shared/common/EntityWithVector.java new file mode 100644 index 000000000..a9df684d1 --- /dev/null +++ b/src/test/java/org/springframework/data/neo4j/integration/shared/common/EntityWithVector.java @@ -0,0 +1,50 @@ +/* + * Copyright 2011-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.data.neo4j.integration.shared.common; + +import org.springframework.data.neo4j.core.schema.GeneratedValue; +import org.springframework.data.neo4j.core.schema.Id; +import org.springframework.data.neo4j.core.schema.Node; + +/** + * @author Gerrit Meier + */ +@Node +public class EntityWithVector { + + @Id + @GeneratedValue + String id; + + String name; + + String getId() { + return this.id; + } + + void setId(String id) { + this.id = id; + } + + String getName() { + return this.name; + } + + void setName(String name) { + this.name = name; + } + +} diff --git a/src/test/java/org/springframework/data/neo4j/repository/query/ReactiveRepositoryQueryTests.java b/src/test/java/org/springframework/data/neo4j/repository/query/ReactiveRepositoryQueryTests.java index 805245dc3..e16aa6f6d 100644 --- a/src/test/java/org/springframework/data/neo4j/repository/query/ReactiveRepositoryQueryTests.java +++ b/src/test/java/org/springframework/data/neo4j/repository/query/ReactiveRepositoryQueryTests.java @@ -19,6 +19,7 @@ import java.util.Collections; import java.util.List; import java.util.Map; +import java.util.Optional; import java.util.function.UnaryOperator; import java.util.regex.Pattern; @@ -29,6 +30,7 @@ import org.mockito.Mock; import org.mockito.Mockito; import org.mockito.junit.jupiter.MockitoExtension; +import org.neo4j.cypherdsl.core.renderer.Configuration; import org.neo4j.driver.Values; import org.neo4j.driver.types.Point; import reactor.core.publisher.Flux; @@ -39,6 +41,7 @@ import org.springframework.data.domain.PageRequest; import org.springframework.data.domain.Pageable; import org.springframework.data.domain.Sort; +import org.springframework.data.domain.Vector; import org.springframework.data.mapping.MappingException; import org.springframework.data.neo4j.core.PreparedQuery; import org.springframework.data.neo4j.core.ReactiveNeo4jOperations; @@ -48,6 +51,7 @@ import org.springframework.data.neo4j.test.LogbackCapturingExtension; import org.springframework.data.projection.ProjectionFactory; import org.springframework.data.projection.SpelAwareProxyProjectionFactory; +import org.springframework.data.repository.core.NamedQueries; import org.springframework.data.repository.core.RepositoryMetadata; import org.springframework.data.repository.core.support.DefaultRepositoryMetadata; import org.springframework.data.repository.query.Param; @@ -87,6 +91,9 @@ final class ReactiveRepositoryQueryTests { @Mock private ProjectionFactory projectionFactory; + @Mock + NamedQueries namedQueries; + private ReactiveRepositoryQueryTests() { } @@ -133,6 +140,41 @@ Flux makeStaticThingsDynamic(@Param("aDynamicLabelPt1") String aDyna Mono findByDontDoThisInRealLiveNamed(@Param("location") org.neo4j.driver.types.Point location, @Param("name") String name, @Param("firstName") String aFirstName); + @VectorSearch(indexName = "testIndex", numberOfNodes = 2) + Flux annotatedVectorSearch(Vector vector); + + @VectorSearch(indexName = "testIndex", numberOfNodes = 0) + Flux illegalAnnotatedVectorSearch(Vector vector); + + } + + @Nested + class ReactiveNeo4jPartTreeTest { + + @Test + void findVectorSearchAnnotation() { + + Neo4jQueryMethod neo4jQueryMethod = reactiveNeo4jQueryMethod("annotatedVectorSearch", Vector.class); + + Optional optionalVectorSearchAnnotation = neo4jQueryMethod.getVectorSearchAnnotation(); + assertThat(optionalVectorSearchAnnotation).isPresent(); + } + + @Test + void failOnZeroNodesVectorSearchAnnotation() { + var lookupStrategy = new ReactiveNeo4jQueryLookupStrategy(ReactiveRepositoryQueryTests.this.neo4jOperations, + ReactiveRepositoryQueryTests.this.neo4jMappingContext, ValueExpressionDelegate.create(), + Configuration.defaultConfig()); + + assertThatExceptionOfType(IllegalArgumentException.class) + .isThrownBy(() -> lookupStrategy.resolveQuery( + reactiveNeo4jQueryMethod("illegalAnnotatedVectorSearch", Vector.class).getMethod(), + TEST_REPOSITORY_METADATA, PROJECTION_FACTORY, ReactiveRepositoryQueryTests.this.namedQueries)) + .withMessage("Number of nodes in the vector search " + + "org.springframework.data.neo4j.repository.query.ReactiveRepositoryQueryTests$TestRepository#illegalAnnotatedVectorSearch " + + "has to be greater than zero."); + } + } @Nested diff --git a/src/test/java/org/springframework/data/neo4j/repository/query/RepositoryQueryTests.java b/src/test/java/org/springframework/data/neo4j/repository/query/RepositoryQueryTests.java index c9a28c5e4..5690a3b86 100644 --- a/src/test/java/org/springframework/data/neo4j/repository/query/RepositoryQueryTests.java +++ b/src/test/java/org/springframework/data/neo4j/repository/query/RepositoryQueryTests.java @@ -50,6 +50,7 @@ import org.springframework.data.domain.Pageable; import org.springframework.data.domain.Slice; import org.springframework.data.domain.Sort; +import org.springframework.data.domain.Vector; import org.springframework.data.expression.ValueExpressionParser; import org.springframework.data.mapping.MappingException; import org.springframework.data.neo4j.core.Neo4jOperations; @@ -166,6 +167,12 @@ Optional findByDontDoThisInRealLiveNamed(@Param("location") org.neo4 @Query("MATCH (n:Test) WHERE n.name = $0 OR n.name = $1") List annotatedQueryWithValidTemplate(String name, String anotherName); + @VectorSearch(indexName = "testIndex", numberOfNodes = 2) + List annotatedVectorSearch(Vector vector); + + @VectorSearch(indexName = "testIndex", numberOfNodes = 0) + List illegalAnnotatedVectorSearch(Vector vector); + @Query(CUSTOM_CYPHER_QUERY) List annotatedQueryWithValidTemplate(); @@ -271,6 +278,29 @@ void findQueryAnnotation() { assertThat(optionalQueryAnnotation).isPresent(); } + @Test + void findVectorSearchAnnotation() { + + Neo4jQueryMethod neo4jQueryMethod = neo4jQueryMethod("annotatedVectorSearch", Vector.class); + + Optional optionalVectorSearchAnnotation = neo4jQueryMethod.getVectorSearchAnnotation(); + assertThat(optionalVectorSearchAnnotation).isPresent(); + } + + @Test + void failOnZeroNodesVectorSearchAnnotation() { + final Neo4jQueryLookupStrategy lookupStrategy = new Neo4jQueryLookupStrategy( + RepositoryQueryTests.this.neo4jOperations, RepositoryQueryTests.this.neo4jMappingContext, + ValueExpressionDelegate.create(), Configuration.defaultConfig()); + + assertThatExceptionOfType(IllegalArgumentException.class) + .isThrownBy(() -> lookupStrategy.resolveQuery(queryMethod("illegalAnnotatedVectorSearch", Vector.class), + TEST_REPOSITORY_METADATA, PROJECTION_FACTORY, RepositoryQueryTests.this.namedQueries)) + .withMessage("Number of nodes in the vector search " + + "org.springframework.data.neo4j.repository.query.RepositoryQueryTests$TestRepository#illegalAnnotatedVectorSearch " + + "has to be greater than zero."); + } + @Test void streamQueriesShouldBeTreatedAsCollectionQueries() {