Skip to content

Commit d1b7692

Browse files
committed
GH-3002 - Introduce vector search on repository level
Signed-off-by: Gerrit Meier <[email protected]> Closes #3002
1 parent aed6ab6 commit d1b7692

22 files changed

+877
-12
lines changed

src/main/antora/modules/ROOT/nav.adoc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
** xref:repositories/sdn-extension.adoc[]
2727
** xref:repositories/query-keywords-reference.adoc[]
2828
** xref:repositories/query-return-types-reference.adoc[]
29+
** xref:repositories/vector-search.adoc[]
2930
3031
* xref:repositories/projections.adoc[]
3132
** xref:projections/sdn-projections.adoc[]
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
[[sdn-vector-search]]
2+
= Neo4j Vector Search
3+
4+
== The `@VectorSearch` annotation
5+
Spring Data Neo4j supports Neo4j's vector search on the repository level by using the `@VectorSearch` annotation.
6+
For this to work, Neo4j needs to have a vector index in place.
7+
How to create a vector index is explained in the https://neo4j.com/docs/cypher-manual/current/indexes/search-performance-indexes/managing-indexes/[Neo4j documentation].
8+
9+
NOTE: The domain entities do not need to define any (Spring Data) `Vector`-typed property for this to work,
10+
because the search operates exclusively on the index.
11+
12+
The `@VectorSearch` annotation requires two arguments:
13+
the name of the vector index to be used (`indexName`) and the number of nearest neighbours to retrieve (`numberOfNodes`).
14+
15+
For a general vector search over the whole domain, it's possible to use a derived finder method without any property.
16+
[source,java,indent=0,tabsize=4]
17+
----
18+
include::example$integration/imperative/VectorSearchIT.java[tags=sdn-vector-search.usage;sdn-vector-search.usage.findall]
19+
----
20+
21+
The vector index can be combined with any property-based finder method to filter down the results.
22+
23+
NOTE: For technical reasons, the vector search will always be executed before the property search gets invoked.
24+
For example, if the property filter looks for a person named "Helge",
25+
but the vector search only yields "Hannes", there won't be a result.
26+
27+
[source,java,indent=0,tabsize=4]
28+
----
29+
include::example$integration/imperative/VectorSearchIT.java[tags=sdn-vector-search.usage;sdn-vector-search.usage.findbyproperty]
30+
----

src/main/java/org/springframework/data/neo4j/core/Neo4jTemplate.java

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1489,7 +1489,12 @@ private Optional<Neo4jClient.RecordFetchSpec<T>> createFetchSpec() {
14891489
statement = nodesAndRelationshipsById.toStatement(entityMetaData);
14901490
}
14911491
else {
1492-
statement = queryFragments.toStatement();
1492+
if (queryFragmentsAndParameters.hasVectorSearchFragment()) {
1493+
statement = queryFragments.toStatement(queryFragmentsAndParameters.getVectorSearchFragment());
1494+
}
1495+
else {
1496+
statement = queryFragments.toStatement();
1497+
}
14931498
}
14941499
cypherQuery = Neo4jTemplate.this.renderer.render(statement);
14951500
finalParameters = TemplateSupport.mergeParameters(statement, finalParameters);

src/main/java/org/springframework/data/neo4j/core/ReactiveNeo4jTemplate.java

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1408,8 +1408,13 @@ public <T> Mono<ExecutableQuery<T>> toExecutableQuery(PreparedQuery<T> preparedQ
14081408
return new DefaultReactiveExecutableQuery<>(preparedQuery, fetchSpec);
14091409
});
14101410
}
1411-
1412-
Statement statement = queryFragments.toStatement();
1411+
Statement statement = null;
1412+
if (queryFragmentsAndParameters.hasVectorSearchFragment()) {
1413+
statement = queryFragments.toStatement(queryFragmentsAndParameters.getVectorSearchFragment());
1414+
}
1415+
else {
1416+
statement = queryFragments.toStatement();
1417+
}
14131418
cypherQuery = this.renderer.render(statement);
14141419
finalParameters = TemplateSupport.mergeParameters(statement, finalParameters);
14151420
}

src/main/java/org/springframework/data/neo4j/core/mapping/Constants.java

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -163,6 +163,21 @@ public final class Constants {
163163
*/
164164
public static final String TO_ID_PARAMETER_NAME = "toId";
165165

166+
/**
167+
* The name SDN uses for the vector search score.
168+
*/
169+
public static final String NAME_OF_SCORE = "__score__";
170+
171+
/**
172+
* Vector search vector parameter name.
173+
*/
174+
public static final String VECTOR_SEARCH_VECTOR_PARAMETER = "vectorSearchParam";
175+
176+
/**
177+
* Vector search score parameter name.
178+
*/
179+
public static final String VECTOR_SEARCH_SCORE_PARAMETER = "scoreParam";
180+
166181
private Constants() {
167182
}
168183

src/main/java/org/springframework/data/neo4j/repository/query/AbstractNeo4jQuery.java

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
import java.util.function.LongSupplier;
2424
import java.util.function.Supplier;
2525
import java.util.function.UnaryOperator;
26+
import java.util.stream.Collectors;
2627

2728
import org.jspecify.annotations.Nullable;
2829
import org.neo4j.driver.types.MapAccessor;
@@ -32,6 +33,8 @@
3233
import org.springframework.data.domain.Page;
3334
import org.springframework.data.domain.PageRequest;
3435
import org.springframework.data.domain.Pageable;
36+
import org.springframework.data.domain.SearchResult;
37+
import org.springframework.data.domain.SearchResults;
3538
import org.springframework.data.domain.Slice;
3639
import org.springframework.data.domain.SliceImpl;
3740
import org.springframework.data.geo.GeoPage;
@@ -98,11 +101,30 @@ boolean isGeoNearQuery() {
98101
return GeoPage.class.isAssignableFrom(returnType);
99102
}
100103

104+
boolean isVectorSearchQuery() {
105+
var repositoryMethod = this.queryMethod.getMethod();
106+
Class<?> returnType = repositoryMethod.getReturnType();
107+
108+
for (Class<?> type : Neo4jQueryMethod.VECTOR_SEARCH_RESULTS) {
109+
if (type.isAssignableFrom(returnType)) {
110+
return true;
111+
}
112+
}
113+
114+
if (Iterable.class.isAssignableFrom(returnType)) {
115+
TypeInformation<?> from = TypeInformation.fromReturnTypeOf(repositoryMethod);
116+
return from.getComponentType() != null && SearchResult.class.equals(from.getComponentType().getType());
117+
}
118+
119+
return false;
120+
}
121+
101122
@Override
102123
@Nullable public final Object execute(Object[] parameters) {
103124

104125
boolean incrementLimit = this.queryMethod.incrementLimit();
105126
boolean geoNearQuery = isGeoNearQuery();
127+
boolean vectorSearchQuery = isVectorSearchQuery();
106128
Neo4jParameterAccessor parameterAccessor = new Neo4jParameterAccessor(
107129
(Neo4jQueryMethod.Neo4jParameters) this.queryMethod.getParameters(), parameters);
108130

@@ -111,7 +133,7 @@ boolean isGeoNearQuery() {
111133
ReturnedType returnedType = resultProcessor.getReturnedType();
112134
PreparedQuery<?> preparedQuery = prepareQuery(returnedType.getReturnedType(),
113135
PropertyFilterSupport.getInputProperties(resultProcessor, this.factory, this.mappingContext),
114-
parameterAccessor, null, getMappingFunction(resultProcessor, geoNearQuery),
136+
parameterAccessor, null, getMappingFunction(resultProcessor, geoNearQuery, vectorSearchQuery),
115137
incrementLimit ? l -> l + 1 : UnaryOperator.identity());
116138

117139
Object rawResult = new Neo4jQueryExecution.DefaultQueryExecution(this.neo4jOperations).execute(preparedQuery,
@@ -143,6 +165,9 @@ else if (this.queryMethod.isScrollQuery()) {
143165
else if (geoNearQuery) {
144166
rawResult = newGeoResults(rawResult);
145167
}
168+
else if (this.queryMethod.isSearchQuery()) {
169+
rawResult = createSearchResult((List<?>) rawResult, returnedType.getReturnedType());
170+
}
146171

147172
return resultProcessor.processResult(rawResult, preparingConverter);
148173
}
@@ -182,6 +207,13 @@ private Slice<?> createSlice(boolean incrementLimit, Neo4jParameterAccessor para
182207
}
183208
}
184209

210+
private <T> SearchResults<?> createSearchResult(List<?> rawResult, Class<T> returnedType) {
211+
List<SearchResult<T>> searchResults = rawResult.stream()
212+
.map(rawValue -> (SearchResult<T>) rawValue)
213+
.collect(Collectors.toUnmodifiableList());
214+
return new SearchResults<>(searchResults);
215+
}
216+
185217
protected abstract <T> PreparedQuery<T> prepareQuery(Class<T> returnedType,
186218
Collection<PropertyFilter.ProjectedPath> includedProperties, Neo4jParameterAccessor parameterAccessor,
187219
@Nullable Neo4jQueryType queryType,

src/main/java/org/springframework/data/neo4j/repository/query/AbstractReactiveNeo4jQuery.java

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,10 @@
2424
import org.neo4j.driver.types.MapAccessor;
2525
import org.neo4j.driver.types.TypeSystem;
2626
import reactor.core.publisher.Flux;
27+
import reactor.core.publisher.Mono;
2728

2829
import org.springframework.core.convert.converter.Converter;
30+
import org.springframework.data.domain.SearchResult;
2931
import org.springframework.data.geo.GeoResult;
3032
import org.springframework.data.neo4j.core.PreparedQuery;
3133
import org.springframework.data.neo4j.core.PropertyFilterSupport;
@@ -89,11 +91,31 @@ boolean isGeoNearQuery() {
8991
return false;
9092
}
9193

94+
boolean isVectorSearchQuery() {
95+
var repositoryMethod = this.queryMethod.getMethod();
96+
Class<?> returnType = repositoryMethod.getReturnType();
97+
98+
for (Class<?> type : ReactiveNeo4jQueryMethod.VECTOR_SEARCH_RESULTS) {
99+
if (type.isAssignableFrom(returnType)) {
100+
return true;
101+
}
102+
}
103+
104+
if (Flux.class.isAssignableFrom(returnType) || Mono.class.isAssignableFrom(returnType)) {
105+
TypeInformation<?> from = TypeInformation.fromReturnTypeOf(repositoryMethod);
106+
TypeInformation<?> componentType = from.getComponentType();
107+
return componentType != null && SearchResult.class.equals(componentType.getType());
108+
}
109+
110+
return false;
111+
}
112+
92113
@Override
93114
@Nullable public final Object execute(Object[] parameters) {
94115

95116
boolean incrementLimit = this.queryMethod.incrementLimit();
96117
boolean geoNearQuery = isGeoNearQuery();
118+
boolean vectorSearchQuery = isVectorSearchQuery();
97119
Neo4jParameterAccessor parameterAccessor = new Neo4jParameterAccessor(
98120
(Neo4jQueryMethod.Neo4jParameters) this.queryMethod.getParameters(), parameters);
99121
ResultProcessor resultProcessor = this.queryMethod.getResultProcessor()
@@ -102,7 +124,7 @@ boolean isGeoNearQuery() {
102124
ReturnedType returnedType = resultProcessor.getReturnedType();
103125
PreparedQuery<?> preparedQuery = prepareQuery(returnedType.getReturnedType(),
104126
PropertyFilterSupport.getInputProperties(resultProcessor, this.factory, this.mappingContext),
105-
parameterAccessor, null, getMappingFunction(resultProcessor, geoNearQuery),
127+
parameterAccessor, null, getMappingFunction(resultProcessor, geoNearQuery, vectorSearchQuery),
106128
incrementLimit ? l -> l + 1 : UnaryOperator.identity());
107129

108130
Object rawResult = new Neo4jQueryExecution.ReactiveQueryExecution(this.neo4jOperations).execute(preparedQuery,
@@ -126,10 +148,17 @@ parameterAccessor, null, getMappingFunction(resultProcessor, geoNearQuery),
126148
.map(rawResultList -> createWindow(resultProcessor, incrementLimit, parameterAccessor, rawResultList,
127149
preparedQuery.getQueryFragmentsAndParameters()));
128150
}
151+
else if (this.queryMethod.isSearchQuery()) {
152+
rawResult = createSearchResult((Flux<?>) rawResult, returnedType.getReturnedType());
153+
}
129154

130155
return resultProcessor.processResult(rawResult, preparingConverter);
131156
}
132157

158+
private <T> Flux<SearchResult<?>> createSearchResult(Flux<?> rawResult, Class<T> returnedType) {
159+
return rawResult.map(rawValue -> (SearchResult<T>) rawValue);
160+
}
161+
133162
protected abstract <T extends Object> PreparedQuery<T> prepareQuery(Class<T> returnedType,
134163
Collection<PropertyFilter.ProjectedPath> includedProperties, Neo4jParameterAccessor parameterAccessor,
135164
@Nullable Neo4jQueryType queryType, Supplier<BiFunction<TypeSystem, MapAccessor, ?>> mappingFunction,

src/main/java/org/springframework/data/neo4j/repository/query/CypherQueryCreator.java

Lines changed: 35 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -45,8 +45,10 @@
4545
import org.springframework.data.domain.OffsetScrollPosition;
4646
import org.springframework.data.domain.Pageable;
4747
import org.springframework.data.domain.Range;
48+
import org.springframework.data.domain.Score;
4849
import org.springframework.data.domain.ScrollPosition;
4950
import org.springframework.data.domain.Sort;
51+
import org.springframework.data.domain.Vector;
5052
import org.springframework.data.geo.Box;
5153
import org.springframework.data.geo.Circle;
5254
import org.springframework.data.geo.Distance;
@@ -60,7 +62,6 @@
6062
import org.springframework.data.neo4j.core.mapping.Neo4jPersistentProperty;
6163
import org.springframework.data.neo4j.core.mapping.NodeDescription;
6264
import org.springframework.data.neo4j.core.mapping.PropertyFilter;
63-
import org.springframework.data.repository.query.QueryMethod;
6465
import org.springframework.data.repository.query.parser.AbstractQueryCreator;
6566
import org.springframework.data.repository.query.parser.Part;
6667
import org.springframework.data.repository.query.parser.PartTree;
@@ -119,14 +120,22 @@ final class CypherQueryCreator extends AbstractQueryCreator<QueryFragmentsAndPar
119120

120121
private final boolean keysetRequiresSort;
121122

122-
private final List<Expression> distanceExpressions = new ArrayList<>();
123+
private final List<Expression> additionalReturnExpression = new ArrayList<>();
123124

124125
/**
125126
* Can be used to modify the limit of a paged or sliced query.
126127
*/
127128
private final UnaryOperator<Integer> limitModifier;
128129

129-
CypherQueryCreator(Neo4jMappingContext mappingContext, QueryMethod queryMethod, Class<?> domainType,
130+
private final Neo4jQueryMethod queryMethod;
131+
132+
@Nullable
133+
private final Vector vectorSearchParameter;
134+
135+
@Nullable
136+
private final Score scoreParameter;
137+
138+
CypherQueryCreator(Neo4jMappingContext mappingContext, Neo4jQueryMethod queryMethod, Class<?> domainType,
130139
Neo4jQueryType queryType, PartTree tree, Neo4jParameterAccessor actualParameters,
131140
Collection<PropertyFilter.ProjectedPath> includedProperties,
132141
BiFunction<Object, Neo4jPersistentPropertyConverter<?>, Object> parameterConversion,
@@ -148,6 +157,8 @@ final class CypherQueryCreator extends AbstractQueryCreator<QueryFragmentsAndPar
148157

149158
this.pagingParameter = actualParameters.getPageable();
150159
this.scrollPosition = actualParameters.getScrollPosition();
160+
this.vectorSearchParameter = actualParameters.getVector();
161+
this.scoreParameter = actualParameters.getScore();
151162
this.limitModifier = limitModifier;
152163

153164
AtomicInteger symbolicNameIndex = new AtomicInteger();
@@ -160,6 +171,7 @@ final class CypherQueryCreator extends AbstractQueryCreator<QueryFragmentsAndPar
160171

161172
this.keysetRequiresSort = queryMethod.isScrollQuery()
162173
&& actualParameters.getScrollPosition() instanceof KeysetScrollPosition;
174+
this.queryMethod = queryMethod;
163175
}
164176

165177
@Override
@@ -196,6 +208,21 @@ protected QueryFragmentsAndParameters complete(@Nullable Condition condition, So
196208
if (this.keysetRequiresSort && theSort.isUnsorted()) {
197209
throw new UnsupportedOperationException("Unsorted keyset based scrolling is not supported.");
198210
}
211+
if (this.queryMethod.hasVectorSearchAnnotation()) {
212+
var vectorSearchAnnotation = this.queryMethod.getVectorSearchAnnotation().orElseThrow();
213+
var indexName = vectorSearchAnnotation.indexName();
214+
var numberOfNodes = vectorSearchAnnotation.numberOfNodes();
215+
convertedParameters.put(Constants.VECTOR_SEARCH_VECTOR_PARAMETER,
216+
this.vectorSearchParameter.toDoubleArray());
217+
if (this.scoreParameter != null) {
218+
convertedParameters.put(Constants.VECTOR_SEARCH_SCORE_PARAMETER, this.scoreParameter.getValue());
219+
}
220+
var vectorSearchFragment = new VectorSearchFragment(indexName, numberOfNodes,
221+
(this.scoreParameter != null) ? this.scoreParameter.getValue() : null);
222+
var queryFragmentsAndParameters = new QueryFragmentsAndParameters(this.nodeDescription, queryFragments,
223+
vectorSearchFragment, convertedParameters, theSort);
224+
return queryFragmentsAndParameters;
225+
}
199226
return new QueryFragmentsAndParameters(this.nodeDescription, queryFragments, convertedParameters, theSort);
200227
}
201228

@@ -274,11 +301,14 @@ else if (this.scrollPosition instanceof OffsetScrollPosition offsetScrollPositio
274301
? this.maxResults.intValue() : this.pagingParameter.getPageSize()));
275302
}
276303

304+
if (this.queryMethod.hasVectorSearchAnnotation()) {
305+
this.additionalReturnExpression.add(Cypher.name(Constants.NAME_OF_SCORE));
306+
}
277307
var finalSortItems = new ArrayList<>(this.sortItems);
278308
theSort.stream().map(CypherAdapterUtils.sortAdapterFor(this.nodeDescription)).forEach(finalSortItems::add);
279309

280310
queryFragments.setReturnBasedOn(this.nodeDescription, this.includedProperties, this.isDistinct,
281-
this.distanceExpressions);
311+
this.additionalReturnExpression);
282312
queryFragments.setOrderBy(finalSortItems);
283313
}
284314

@@ -438,7 +468,7 @@ else if (p2.isPresent() && p2.get().value instanceof Point) {
438468
// property to be later retrieved and mapped
439469
Neo4jPersistentEntity<?> owner = (Neo4jPersistentEntity<?>) leafProperty.getOwner();
440470
String containerName = getContainerName(path, owner);
441-
this.distanceExpressions
471+
this.additionalReturnExpression
442472
.add(distanceFunction.as("__distance_" + containerName + "_" + leafProperty.getPropertyName() + "__"));
443473

444474
this.sortItems.add(distanceFunction.ascending());

0 commit comments

Comments
 (0)