diff --git a/pom.xml b/pom.xml
index a079df8c0..c4dc98a36 100644
--- a/pom.xml
+++ b/pom.xml
@@ -27,7 +27,7 @@
org.springframework.data
spring-data-neo4j
- 8.0.0-SNAPSHOT
+ 8.0.0-GH-3002-SNAPSHOT
Spring Data Neo4j
Next generation Object-Graph-Mapping for Spring Data.
diff --git a/src/main/antora/modules/ROOT/nav.adoc b/src/main/antora/modules/ROOT/nav.adoc
index dc4006628..30033d420 100644
--- a/src/main/antora/modules/ROOT/nav.adoc
+++ b/src/main/antora/modules/ROOT/nav.adoc
@@ -26,6 +26,7 @@
** xref:repositories/sdn-extension.adoc[]
** xref:repositories/query-keywords-reference.adoc[]
** xref:repositories/query-return-types-reference.adoc[]
+** xref:repositories/vector-search.adoc[]
* xref:repositories/projections.adoc[]
** xref:projections/sdn-projections.adoc[]
diff --git a/src/main/antora/modules/ROOT/pages/repositories/vector-search.adoc b/src/main/antora/modules/ROOT/pages/repositories/vector-search.adoc
new file mode 100644
index 000000000..7d5116b86
--- /dev/null
+++ b/src/main/antora/modules/ROOT/pages/repositories/vector-search.adoc
@@ -0,0 +1,30 @@
+[[sdn-vector-search]]
+= Neo4j Vector Search
+
+== The `@VectorSearch` annotation
+Spring Data Neo4j supports Neo4j's vector search on the repository level by using the `@VectorSearch` annotation.
+For this to work, Neo4j needs to have a vector index in place.
+How to create a vector index is explained in the https://neo4j.com/docs/cypher-manual/current/indexes/search-performance-indexes/managing-indexes/[Neo4j documentation].
+
+NOTE: It's not required to have any (Spring Data) Vector typed property be defined in the domain entities for this to work
+because the search operates exclusively on the index.
+
+The `@VectorSearch` annotation requires two arguments:
+The name of the vector index to be used and the number of nearest neighbours.
+
+For a general vector search over the whole domain, it's possible to use a derived finder method without any property.
+[source,java,indent=0,tabsize=4]
+----
+include::example$integration/imperative/VectorSearchIT.java[tags=sdn-vector-search.usage;sdn-vector-search.usage.findall]
+----
+
+The vector index can be combined with any property-based finder method to filter down the results.
+
+NOTE: For technical reasons, the vector search will always be executed before the property search gets invoked.
+E.g. if the property filter looks for a person named "Helge",
+but the vector search only yields "Hannes", there won't be a result.
+
+[source,java,indent=0,tabsize=4]
+----
+include::example$integration/imperative/VectorSearchIT.java[tags=sdn-vector-search.usage;sdn-vector-search.usage.findbyproperty]
+----
diff --git a/src/main/java/org/springframework/data/neo4j/core/Neo4jTemplate.java b/src/main/java/org/springframework/data/neo4j/core/Neo4jTemplate.java
index e069dc85b..3d0ae1079 100644
--- a/src/main/java/org/springframework/data/neo4j/core/Neo4jTemplate.java
+++ b/src/main/java/org/springframework/data/neo4j/core/Neo4jTemplate.java
@@ -1489,7 +1489,12 @@ private Optional> createFetchSpec() {
statement = nodesAndRelationshipsById.toStatement(entityMetaData);
}
else {
- statement = queryFragments.toStatement();
+ if (queryFragmentsAndParameters.hasVectorSearchFragment()) {
+ statement = queryFragments.toStatement(queryFragmentsAndParameters.getVectorSearchFragment());
+ }
+ else {
+ statement = queryFragments.toStatement();
+ }
}
cypherQuery = Neo4jTemplate.this.renderer.render(statement);
finalParameters = TemplateSupport.mergeParameters(statement, finalParameters);
diff --git a/src/main/java/org/springframework/data/neo4j/core/ReactiveNeo4jTemplate.java b/src/main/java/org/springframework/data/neo4j/core/ReactiveNeo4jTemplate.java
index 099e8c340..b7b1a7b25 100644
--- a/src/main/java/org/springframework/data/neo4j/core/ReactiveNeo4jTemplate.java
+++ b/src/main/java/org/springframework/data/neo4j/core/ReactiveNeo4jTemplate.java
@@ -1408,8 +1408,13 @@ public Mono> toExecutableQuery(PreparedQuery preparedQ
return new DefaultReactiveExecutableQuery<>(preparedQuery, fetchSpec);
});
}
-
- Statement statement = queryFragments.toStatement();
+ Statement statement = null;
+ if (queryFragmentsAndParameters.hasVectorSearchFragment()) {
+ statement = queryFragments.toStatement(queryFragmentsAndParameters.getVectorSearchFragment());
+ }
+ else {
+ statement = queryFragments.toStatement();
+ }
cypherQuery = this.renderer.render(statement);
finalParameters = TemplateSupport.mergeParameters(statement, finalParameters);
}
diff --git a/src/main/java/org/springframework/data/neo4j/core/mapping/Constants.java b/src/main/java/org/springframework/data/neo4j/core/mapping/Constants.java
index a7a36ce0e..d059ede5f 100644
--- a/src/main/java/org/springframework/data/neo4j/core/mapping/Constants.java
+++ b/src/main/java/org/springframework/data/neo4j/core/mapping/Constants.java
@@ -163,6 +163,21 @@ public final class Constants {
*/
public static final String TO_ID_PARAMETER_NAME = "toId";
+ /**
+ * The name SDN uses for vector search score.
+ */
+ public static final String NAME_OF_SCORE = "__score__";
+
+ /**
+ * Vector search vector parameter name.
+ */
+ public static final String VECTOR_SEARCH_VECTOR_PARAMETER = "vectorSearchParam";
+
+ /**
+ * Vector search score parameter name.
+ */
+ public static final String VECTOR_SEARCH_SCORE_PARAMETER = "scoreParam";
+
private Constants() {
}
diff --git a/src/main/java/org/springframework/data/neo4j/repository/query/AbstractNeo4jQuery.java b/src/main/java/org/springframework/data/neo4j/repository/query/AbstractNeo4jQuery.java
index cb4455891..934e3e117 100644
--- a/src/main/java/org/springframework/data/neo4j/repository/query/AbstractNeo4jQuery.java
+++ b/src/main/java/org/springframework/data/neo4j/repository/query/AbstractNeo4jQuery.java
@@ -23,6 +23,7 @@
import java.util.function.LongSupplier;
import java.util.function.Supplier;
import java.util.function.UnaryOperator;
+import java.util.stream.Collectors;
import org.jspecify.annotations.Nullable;
import org.neo4j.driver.types.MapAccessor;
@@ -32,6 +33,8 @@
import org.springframework.data.domain.Page;
import org.springframework.data.domain.PageRequest;
import org.springframework.data.domain.Pageable;
+import org.springframework.data.domain.SearchResult;
+import org.springframework.data.domain.SearchResults;
import org.springframework.data.domain.Slice;
import org.springframework.data.domain.SliceImpl;
import org.springframework.data.geo.GeoPage;
@@ -98,11 +101,30 @@ boolean isGeoNearQuery() {
return GeoPage.class.isAssignableFrom(returnType);
}
+ boolean isVectorSearchQuery() {
+ var repositoryMethod = this.queryMethod.getMethod();
+ Class> returnType = repositoryMethod.getReturnType();
+
+ for (Class> type : Neo4jQueryMethod.VECTOR_SEARCH_RESULTS) {
+ if (type.isAssignableFrom(returnType)) {
+ return true;
+ }
+ }
+
+ if (Iterable.class.isAssignableFrom(returnType)) {
+ TypeInformation> from = TypeInformation.fromReturnTypeOf(repositoryMethod);
+ return from.getComponentType() != null && SearchResult.class.equals(from.getComponentType().getType());
+ }
+
+ return false;
+ }
+
@Override
@Nullable public final Object execute(Object[] parameters) {
boolean incrementLimit = this.queryMethod.incrementLimit();
boolean geoNearQuery = isGeoNearQuery();
+ boolean vectorSearchQuery = isVectorSearchQuery();
Neo4jParameterAccessor parameterAccessor = new Neo4jParameterAccessor(
(Neo4jQueryMethod.Neo4jParameters) this.queryMethod.getParameters(), parameters);
@@ -111,7 +133,7 @@ boolean isGeoNearQuery() {
ReturnedType returnedType = resultProcessor.getReturnedType();
PreparedQuery> preparedQuery = prepareQuery(returnedType.getReturnedType(),
PropertyFilterSupport.getInputProperties(resultProcessor, this.factory, this.mappingContext),
- parameterAccessor, null, getMappingFunction(resultProcessor, geoNearQuery),
+ parameterAccessor, null, getMappingFunction(resultProcessor, geoNearQuery, vectorSearchQuery),
incrementLimit ? l -> l + 1 : UnaryOperator.identity());
Object rawResult = new Neo4jQueryExecution.DefaultQueryExecution(this.neo4jOperations).execute(preparedQuery,
@@ -143,6 +165,9 @@ else if (this.queryMethod.isScrollQuery()) {
else if (geoNearQuery) {
rawResult = newGeoResults(rawResult);
}
+ else if (this.queryMethod.isSearchQuery()) {
+ rawResult = createSearchResult((List>) rawResult, returnedType.getReturnedType());
+ }
return resultProcessor.processResult(rawResult, preparingConverter);
}
@@ -182,6 +207,13 @@ private Slice> createSlice(boolean incrementLimit, Neo4jParameterAccessor para
}
}
+ private SearchResults> createSearchResult(List> rawResult, Class returnedType) {
+ List> searchResults = rawResult.stream()
+ .map(rawValue -> (SearchResult) rawValue)
+ .collect(Collectors.toUnmodifiableList());
+ return new SearchResults<>(searchResults);
+ }
+
protected abstract PreparedQuery prepareQuery(Class returnedType,
Collection includedProperties, Neo4jParameterAccessor parameterAccessor,
@Nullable Neo4jQueryType queryType,
diff --git a/src/main/java/org/springframework/data/neo4j/repository/query/AbstractReactiveNeo4jQuery.java b/src/main/java/org/springframework/data/neo4j/repository/query/AbstractReactiveNeo4jQuery.java
index 5274e4a35..60edae58e 100644
--- a/src/main/java/org/springframework/data/neo4j/repository/query/AbstractReactiveNeo4jQuery.java
+++ b/src/main/java/org/springframework/data/neo4j/repository/query/AbstractReactiveNeo4jQuery.java
@@ -24,8 +24,10 @@
import org.neo4j.driver.types.MapAccessor;
import org.neo4j.driver.types.TypeSystem;
import reactor.core.publisher.Flux;
+import reactor.core.publisher.Mono;
import org.springframework.core.convert.converter.Converter;
+import org.springframework.data.domain.SearchResult;
import org.springframework.data.geo.GeoResult;
import org.springframework.data.neo4j.core.PreparedQuery;
import org.springframework.data.neo4j.core.PropertyFilterSupport;
@@ -89,11 +91,31 @@ boolean isGeoNearQuery() {
return false;
}
+ boolean isVectorSearchQuery() {
+ var repositoryMethod = this.queryMethod.getMethod();
+ Class> returnType = repositoryMethod.getReturnType();
+
+ for (Class> type : ReactiveNeo4jQueryMethod.VECTOR_SEARCH_RESULTS) {
+ if (type.isAssignableFrom(returnType)) {
+ return true;
+ }
+ }
+
+ if (Flux.class.isAssignableFrom(returnType) || Mono.class.isAssignableFrom(returnType)) {
+ TypeInformation> from = TypeInformation.fromReturnTypeOf(repositoryMethod);
+ TypeInformation> componentType = from.getComponentType();
+ return componentType != null && SearchResult.class.equals(componentType.getType());
+ }
+
+ return false;
+ }
+
@Override
@Nullable public final Object execute(Object[] parameters) {
boolean incrementLimit = this.queryMethod.incrementLimit();
boolean geoNearQuery = isGeoNearQuery();
+ boolean vectorSearchQuery = isVectorSearchQuery();
Neo4jParameterAccessor parameterAccessor = new Neo4jParameterAccessor(
(Neo4jQueryMethod.Neo4jParameters) this.queryMethod.getParameters(), parameters);
ResultProcessor resultProcessor = this.queryMethod.getResultProcessor()
@@ -102,7 +124,7 @@ boolean isGeoNearQuery() {
ReturnedType returnedType = resultProcessor.getReturnedType();
PreparedQuery> preparedQuery = prepareQuery(returnedType.getReturnedType(),
PropertyFilterSupport.getInputProperties(resultProcessor, this.factory, this.mappingContext),
- parameterAccessor, null, getMappingFunction(resultProcessor, geoNearQuery),
+ parameterAccessor, null, getMappingFunction(resultProcessor, geoNearQuery, vectorSearchQuery),
incrementLimit ? l -> l + 1 : UnaryOperator.identity());
Object rawResult = new Neo4jQueryExecution.ReactiveQueryExecution(this.neo4jOperations).execute(preparedQuery,
@@ -126,10 +148,17 @@ parameterAccessor, null, getMappingFunction(resultProcessor, geoNearQuery),
.map(rawResultList -> createWindow(resultProcessor, incrementLimit, parameterAccessor, rawResultList,
preparedQuery.getQueryFragmentsAndParameters()));
}
+ else if (this.queryMethod.isSearchQuery()) {
+ rawResult = createSearchResult((Flux>) rawResult, returnedType.getReturnedType());
+ }
return resultProcessor.processResult(rawResult, preparingConverter);
}
+ private Flux> createSearchResult(Flux> rawResult, Class returnedType) {
+ return rawResult.map(rawValue -> (SearchResult) rawValue);
+ }
+
protected abstract PreparedQuery prepareQuery(Class returnedType,
Collection includedProperties, Neo4jParameterAccessor parameterAccessor,
@Nullable Neo4jQueryType queryType, Supplier> mappingFunction,
diff --git a/src/main/java/org/springframework/data/neo4j/repository/query/CypherQueryCreator.java b/src/main/java/org/springframework/data/neo4j/repository/query/CypherQueryCreator.java
index a1785bb76..497b64c9a 100644
--- a/src/main/java/org/springframework/data/neo4j/repository/query/CypherQueryCreator.java
+++ b/src/main/java/org/springframework/data/neo4j/repository/query/CypherQueryCreator.java
@@ -45,8 +45,10 @@
import org.springframework.data.domain.OffsetScrollPosition;
import org.springframework.data.domain.Pageable;
import org.springframework.data.domain.Range;
+import org.springframework.data.domain.Score;
import org.springframework.data.domain.ScrollPosition;
import org.springframework.data.domain.Sort;
+import org.springframework.data.domain.Vector;
import org.springframework.data.geo.Box;
import org.springframework.data.geo.Circle;
import org.springframework.data.geo.Distance;
@@ -60,7 +62,6 @@
import org.springframework.data.neo4j.core.mapping.Neo4jPersistentProperty;
import org.springframework.data.neo4j.core.mapping.NodeDescription;
import org.springframework.data.neo4j.core.mapping.PropertyFilter;
-import org.springframework.data.repository.query.QueryMethod;
import org.springframework.data.repository.query.parser.AbstractQueryCreator;
import org.springframework.data.repository.query.parser.Part;
import org.springframework.data.repository.query.parser.PartTree;
@@ -119,14 +120,22 @@ final class CypherQueryCreator extends AbstractQueryCreator distanceExpressions = new ArrayList<>();
+ private final List additionalReturnExpression = new ArrayList<>();
/**
* Can be used to modify the limit of a paged or sliced query.
*/
private final UnaryOperator limitModifier;
- CypherQueryCreator(Neo4jMappingContext mappingContext, QueryMethod queryMethod, Class> domainType,
+ private final Neo4jQueryMethod queryMethod;
+
+ @Nullable
+ private final Vector vectorSearchParameter;
+
+ @Nullable
+ private final Score scoreParameter;
+
+ CypherQueryCreator(Neo4jMappingContext mappingContext, Neo4jQueryMethod queryMethod, Class> domainType,
Neo4jQueryType queryType, PartTree tree, Neo4jParameterAccessor actualParameters,
Collection includedProperties,
BiFunction