Skip to content

Commit bfee596

Browse files
committed
Vector search on repository level
Signed-off-by: Gerrit Meier <[email protected]>
1 parent 7162881 commit bfee596

File tree

15 files changed

+583
-10
lines changed

15 files changed

+583
-10
lines changed

src/main/java/org/springframework/data/neo4j/core/Neo4jTemplate.java

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1489,7 +1489,12 @@ private Optional<Neo4jClient.RecordFetchSpec<T>> createFetchSpec() {
14891489
statement = nodesAndRelationshipsById.toStatement(entityMetaData);
14901490
}
14911491
else {
1492-
statement = queryFragments.toStatement();
1492+
if (queryFragmentsAndParameters.hasVectorSearchFragment()) {
1493+
statement = queryFragments.toStatement(queryFragmentsAndParameters.getVectorSearchFragment());
1494+
}
1495+
else {
1496+
statement = queryFragments.toStatement();
1497+
}
14931498
}
14941499
cypherQuery = Neo4jTemplate.this.renderer.render(statement);
14951500
finalParameters = TemplateSupport.mergeParameters(statement, finalParameters);

src/main/java/org/springframework/data/neo4j/core/mapping/Constants.java

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -163,6 +163,21 @@ public final class Constants {
163163
*/
164164
public static final String TO_ID_PARAMETER_NAME = "toId";
165165

166+
/**
167+
* The name SDN uses for vector search score.
168+
*/
169+
public static final String NAME_OF_SCORE = "__score__";
170+
171+
/**
172+
* Vector search vector parameter name.
173+
*/
174+
public static final String VECTOR_SEARCH_VECTOR_PARAMETER = "__vector_search__";
175+
176+
/**
177+
* Vector search score parameter name.
178+
*/
179+
public static final String VECTOR_SEARCH_SCORE_PARAMETER = "__score__";
180+
166181
private Constants() {
167182
}
168183

src/main/java/org/springframework/data/neo4j/repository/query/AbstractNeo4jQuery.java

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
import java.util.function.LongSupplier;
2424
import java.util.function.Supplier;
2525
import java.util.function.UnaryOperator;
26+
import java.util.stream.Collectors;
2627

2728
import org.jspecify.annotations.Nullable;
2829
import org.neo4j.driver.types.MapAccessor;
@@ -32,6 +33,8 @@
3233
import org.springframework.data.domain.Page;
3334
import org.springframework.data.domain.PageRequest;
3435
import org.springframework.data.domain.Pageable;
36+
import org.springframework.data.domain.SearchResult;
37+
import org.springframework.data.domain.SearchResults;
3538
import org.springframework.data.domain.Slice;
3639
import org.springframework.data.domain.SliceImpl;
3740
import org.springframework.data.geo.GeoPage;
@@ -98,11 +101,30 @@ boolean isGeoNearQuery() {
98101
return GeoPage.class.isAssignableFrom(returnType);
99102
}
100103

104+
boolean isVectorSearchQuery() {
105+
var repositoryMethod = this.queryMethod.getMethod();
106+
Class<?> returnType = repositoryMethod.getReturnType();
107+
108+
for (Class<?> type : Neo4jQueryMethod.VECTOR_SEARCH_RESULTS) {
109+
if (type.isAssignableFrom(returnType)) {
110+
return true;
111+
}
112+
}
113+
114+
if (Iterable.class.isAssignableFrom(returnType)) {
115+
TypeInformation<?> from = TypeInformation.fromReturnTypeOf(repositoryMethod);
116+
return from.getComponentType() != null && SearchResult.class.equals(from.getComponentType().getType());
117+
}
118+
119+
return false;
120+
}
121+
101122
@Override
102123
@Nullable public final Object execute(Object[] parameters) {
103124

104125
boolean incrementLimit = this.queryMethod.incrementLimit();
105126
boolean geoNearQuery = isGeoNearQuery();
127+
boolean vectorSearchQuery = isVectorSearchQuery();
106128
Neo4jParameterAccessor parameterAccessor = new Neo4jParameterAccessor(
107129
(Neo4jQueryMethod.Neo4jParameters) this.queryMethod.getParameters(), parameters);
108130

@@ -111,7 +133,7 @@ boolean isGeoNearQuery() {
111133
ReturnedType returnedType = resultProcessor.getReturnedType();
112134
PreparedQuery<?> preparedQuery = prepareQuery(returnedType.getReturnedType(),
113135
PropertyFilterSupport.getInputProperties(resultProcessor, this.factory, this.mappingContext),
114-
parameterAccessor, null, getMappingFunction(resultProcessor, geoNearQuery),
136+
parameterAccessor, null, getMappingFunction(resultProcessor, geoNearQuery, vectorSearchQuery),
115137
incrementLimit ? l -> l + 1 : UnaryOperator.identity());
116138

117139
Object rawResult = new Neo4jQueryExecution.DefaultQueryExecution(this.neo4jOperations).execute(preparedQuery,
@@ -143,6 +165,9 @@ else if (this.queryMethod.isScrollQuery()) {
143165
else if (geoNearQuery) {
144166
rawResult = newGeoResults(rawResult);
145167
}
168+
else if (this.queryMethod.isSearchQuery()) {
169+
rawResult = createSearchResult((List<?>) rawResult, returnedType.getReturnedType());
170+
}
146171

147172
return resultProcessor.processResult(rawResult, preparingConverter);
148173
}
@@ -182,6 +207,13 @@ private Slice<?> createSlice(boolean incrementLimit, Neo4jParameterAccessor para
182207
}
183208
}
184209

210+
private <T> SearchResults<?> createSearchResult(List<?> rawResult, Class<T> returnedType) {
211+
List<SearchResult<T>> searchResults = rawResult.stream()
212+
.map(rawValue -> (SearchResult<T>) rawValue)
213+
.collect(Collectors.toUnmodifiableList());
214+
return new SearchResults<>(searchResults);
215+
}
216+
185217
protected abstract <T> PreparedQuery<T> prepareQuery(Class<T> returnedType,
186218
Collection<PropertyFilter.ProjectedPath> includedProperties, Neo4jParameterAccessor parameterAccessor,
187219
@Nullable Neo4jQueryType queryType,

src/main/java/org/springframework/data/neo4j/repository/query/AbstractReactiveNeo4jQuery.java

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
import reactor.core.publisher.Flux;
2727

2828
import org.springframework.core.convert.converter.Converter;
29+
import org.springframework.data.domain.SearchResult;
2930
import org.springframework.data.geo.GeoResult;
3031
import org.springframework.data.neo4j.core.PreparedQuery;
3132
import org.springframework.data.neo4j.core.PropertyFilterSupport;
@@ -89,11 +90,31 @@ boolean isGeoNearQuery() {
8990
return false;
9091
}
9192

93+
boolean isVectorSearchQuery() {
94+
var repositoryMethod = this.queryMethod.getMethod();
95+
Class<?> returnType = repositoryMethod.getReturnType();
96+
97+
for (Class<?> type : Neo4jQueryMethod.VECTOR_SEARCH_RESULTS) {
98+
if (type.isAssignableFrom(returnType)) {
99+
return true;
100+
}
101+
}
102+
103+
if (Flux.class.isAssignableFrom(returnType)) {
104+
TypeInformation<?> from = TypeInformation.fromReturnTypeOf(repositoryMethod);
105+
TypeInformation<?> componentType = from.getComponentType();
106+
return componentType != null && SearchResult.class.equals(componentType.getType());
107+
}
108+
109+
return false;
110+
}
111+
92112
@Override
93113
@Nullable public final Object execute(Object[] parameters) {
94114

95115
boolean incrementLimit = this.queryMethod.incrementLimit();
96116
boolean geoNearQuery = isGeoNearQuery();
117+
boolean vectorSearchQuery = isVectorSearchQuery();
97118
Neo4jParameterAccessor parameterAccessor = new Neo4jParameterAccessor(
98119
(Neo4jQueryMethod.Neo4jParameters) this.queryMethod.getParameters(), parameters);
99120
ResultProcessor resultProcessor = this.queryMethod.getResultProcessor()
@@ -102,7 +123,7 @@ boolean isGeoNearQuery() {
102123
ReturnedType returnedType = resultProcessor.getReturnedType();
103124
PreparedQuery<?> preparedQuery = prepareQuery(returnedType.getReturnedType(),
104125
PropertyFilterSupport.getInputProperties(resultProcessor, this.factory, this.mappingContext),
105-
parameterAccessor, null, getMappingFunction(resultProcessor, geoNearQuery),
126+
parameterAccessor, null, getMappingFunction(resultProcessor, geoNearQuery, vectorSearchQuery),
106127
incrementLimit ? l -> l + 1 : UnaryOperator.identity());
107128

108129
Object rawResult = new Neo4jQueryExecution.ReactiveQueryExecution(this.neo4jOperations).execute(preparedQuery,

src/main/java/org/springframework/data/neo4j/repository/query/CypherQueryCreator.java

Lines changed: 35 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -45,8 +45,10 @@
4545
import org.springframework.data.domain.OffsetScrollPosition;
4646
import org.springframework.data.domain.Pageable;
4747
import org.springframework.data.domain.Range;
48+
import org.springframework.data.domain.Score;
4849
import org.springframework.data.domain.ScrollPosition;
4950
import org.springframework.data.domain.Sort;
51+
import org.springframework.data.domain.Vector;
5052
import org.springframework.data.geo.Box;
5153
import org.springframework.data.geo.Circle;
5254
import org.springframework.data.geo.Distance;
@@ -60,7 +62,6 @@
6062
import org.springframework.data.neo4j.core.mapping.Neo4jPersistentProperty;
6163
import org.springframework.data.neo4j.core.mapping.NodeDescription;
6264
import org.springframework.data.neo4j.core.mapping.PropertyFilter;
63-
import org.springframework.data.repository.query.QueryMethod;
6465
import org.springframework.data.repository.query.parser.AbstractQueryCreator;
6566
import org.springframework.data.repository.query.parser.Part;
6667
import org.springframework.data.repository.query.parser.PartTree;
@@ -119,14 +120,22 @@ final class CypherQueryCreator extends AbstractQueryCreator<QueryFragmentsAndPar
119120

120121
private final boolean keysetRequiresSort;
121122

122-
private final List<Expression> distanceExpressions = new ArrayList<>();
123+
private final List<Expression> additionalReturnExpression = new ArrayList<>();
123124

124125
/**
125126
* Can be used to modify the limit of a paged or sliced query.
126127
*/
127128
private final UnaryOperator<Integer> limitModifier;
128129

129-
CypherQueryCreator(Neo4jMappingContext mappingContext, QueryMethod queryMethod, Class<?> domainType,
130+
private final Neo4jQueryMethod queryMethod;
131+
132+
@Nullable
133+
private final Vector vectorSearchParameter;
134+
135+
@Nullable
136+
private final Score scoreParameter;
137+
138+
CypherQueryCreator(Neo4jMappingContext mappingContext, Neo4jQueryMethod queryMethod, Class<?> domainType,
130139
Neo4jQueryType queryType, PartTree tree, Neo4jParameterAccessor actualParameters,
131140
Collection<PropertyFilter.ProjectedPath> includedProperties,
132141
BiFunction<Object, Neo4jPersistentPropertyConverter<?>, Object> parameterConversion,
@@ -148,6 +157,8 @@ final class CypherQueryCreator extends AbstractQueryCreator<QueryFragmentsAndPar
148157

149158
this.pagingParameter = actualParameters.getPageable();
150159
this.scrollPosition = actualParameters.getScrollPosition();
160+
this.vectorSearchParameter = actualParameters.getVector();
161+
this.scoreParameter = actualParameters.getScore();
151162
this.limitModifier = limitModifier;
152163

153164
AtomicInteger symbolicNameIndex = new AtomicInteger();
@@ -160,6 +171,7 @@ final class CypherQueryCreator extends AbstractQueryCreator<QueryFragmentsAndPar
160171

161172
this.keysetRequiresSort = queryMethod.isScrollQuery()
162173
&& actualParameters.getScrollPosition() instanceof KeysetScrollPosition;
174+
this.queryMethod = queryMethod;
163175
}
164176

165177
@Override
@@ -196,6 +208,21 @@ protected QueryFragmentsAndParameters complete(@Nullable Condition condition, So
196208
if (this.keysetRequiresSort && theSort.isUnsorted()) {
197209
throw new UnsupportedOperationException("Unsorted keyset based scrolling is not supported.");
198210
}
211+
if (this.queryMethod.hasVectorSearchAnnotation()) {
212+
var vectorSearchAnnotation = this.queryMethod.getVectorSearchAnnotation().orElseThrow();
213+
var indexName = vectorSearchAnnotation.indexName();
214+
var numberOfNodes = vectorSearchAnnotation.numberOfNodes();
215+
convertedParameters.put(Constants.VECTOR_SEARCH_VECTOR_PARAMETER,
216+
this.vectorSearchParameter.toDoubleArray());
217+
if (this.scoreParameter != null) {
218+
convertedParameters.put(Constants.VECTOR_SEARCH_SCORE_PARAMETER, this.scoreParameter.getValue());
219+
}
220+
var vectorSearchFragment = new VectorSearchFragment(indexName, numberOfNodes,
221+
(this.scoreParameter != null) ? this.scoreParameter.getValue() : null);
222+
var queryFragmentsAndParameters = new QueryFragmentsAndParameters(this.nodeDescription, queryFragments,
223+
vectorSearchFragment, convertedParameters, theSort);
224+
return queryFragmentsAndParameters;
225+
}
199226
return new QueryFragmentsAndParameters(this.nodeDescription, queryFragments, convertedParameters, theSort);
200227
}
201228

@@ -274,11 +301,14 @@ else if (this.scrollPosition instanceof OffsetScrollPosition offsetScrollPositio
274301
? this.maxResults.intValue() : this.pagingParameter.getPageSize()));
275302
}
276303

304+
if (this.queryMethod.hasVectorSearchAnnotation()) {
305+
this.additionalReturnExpression.add(Cypher.name(Constants.NAME_OF_SCORE));
306+
}
277307
var finalSortItems = new ArrayList<>(this.sortItems);
278308
theSort.stream().map(CypherAdapterUtils.sortAdapterFor(this.nodeDescription)).forEach(finalSortItems::add);
279309

280310
queryFragments.setReturnBasedOn(this.nodeDescription, this.includedProperties, this.isDistinct,
281-
this.distanceExpressions);
311+
this.additionalReturnExpression);
282312
queryFragments.setOrderBy(finalSortItems);
283313
}
284314

@@ -438,7 +468,7 @@ else if (p2.isPresent() && p2.get().value instanceof Point) {
438468
// property to be later retrieved and mapped
439469
Neo4jPersistentEntity<?> owner = (Neo4jPersistentEntity<?>) leafProperty.getOwner();
440470
String containerName = getContainerName(path, owner);
441-
this.distanceExpressions
471+
this.additionalReturnExpression
442472
.add(distanceFunction.as("__distance_" + containerName + "_" + leafProperty.getPropertyName() + "__"));
443473

444474
this.sortItems.add(distanceFunction.ascending());

src/main/java/org/springframework/data/neo4j/repository/query/Neo4jQueryMethod.java

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,8 @@
2525

2626
import org.springframework.core.MethodParameter;
2727
import org.springframework.core.annotation.AnnotatedElementUtils;
28+
import org.springframework.data.domain.SearchResult;
29+
import org.springframework.data.domain.SearchResults;
2830
import org.springframework.data.geo.GeoPage;
2931
import org.springframework.data.geo.GeoResult;
3032
import org.springframework.data.geo.GeoResults;
@@ -55,12 +57,18 @@ class Neo4jQueryMethod extends QueryMethod {
5557
static final List<Class<? extends Serializable>> GEO_NEAR_RESULTS = List.of(GeoResult.class, GeoResults.class,
5658
GeoPage.class);
5759

60+
static final List<Class<? extends Serializable>> VECTOR_SEARCH_RESULTS = List.of(SearchResults.class,
61+
SearchResult.class);
62+
5863
/**
5964
* Optional query annotation of the method.
6065
*/
6166
@Nullable
6267
private final Query queryAnnotation;
6368

69+
@Nullable
70+
private final VectorSearch vectorSearchAnnotation;
71+
6472
private final String repositoryName;
6573

6674
private final boolean cypherBasedProjection;
@@ -94,6 +102,7 @@ class Neo4jQueryMethod extends QueryMethod {
94102
this.repositoryName = this.method.getDeclaringClass().getName();
95103
this.cypherBasedProjection = cypherBasedProjection;
96104
this.queryAnnotation = AnnotatedElementUtils.findMergedAnnotation(this.method, Query.class);
105+
this.vectorSearchAnnotation = AnnotatedElementUtils.findMergedAnnotation(this.method, VectorSearch.class);
97106
}
98107

99108
String getRepositoryName() {
@@ -126,6 +135,14 @@ Optional<Query> getQueryAnnotation() {
126135
return Optional.ofNullable(this.queryAnnotation);
127136
}
128137

138+
boolean hasVectorSearchAnnotation() {
139+
return this.vectorSearchAnnotation != null;
140+
}
141+
142+
Optional<VectorSearch> getVectorSearchAnnotation() {
143+
return Optional.ofNullable(this.vectorSearchAnnotation);
144+
}
145+
129146
@Override
130147
public Class<?> getReturnedObjectType() {
131148
Class<?> returnedObjectType = super.getReturnedObjectType();
@@ -143,7 +160,7 @@ boolean incrementLimit() {
143160

144161
boolean asCollectionQuery() {
145162
return this.isCollectionLikeQuery() || this.isPageQuery() || this.isSliceQuery() || this.isScrollQuery()
146-
|| GeoResults.class.isAssignableFrom(this.method.getReturnType());
163+
|| GeoResults.class.isAssignableFrom(this.method.getReturnType()) || this.isSearchQuery();
147164
}
148165

149166
Method getMethod() {

src/main/java/org/springframework/data/neo4j/repository/query/Neo4jQuerySupport.java

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,8 +44,10 @@
4444
import org.springframework.data.domain.KeysetScrollPosition;
4545
import org.springframework.data.domain.OffsetScrollPosition;
4646
import org.springframework.data.domain.Range;
47+
import org.springframework.data.domain.Score;
4748
import org.springframework.data.domain.ScrollPosition;
4849
import org.springframework.data.domain.ScrollPosition.Direction;
50+
import org.springframework.data.domain.SearchResult;
4951
import org.springframework.data.domain.Window;
5052
import org.springframework.data.expression.ValueExpressionParser;
5153
import org.springframework.data.geo.Box;
@@ -135,6 +137,25 @@ else if (distances.size() > 1) {
135137
};
136138
}
137139

140+
static BiFunction<TypeSystem, MapAccessor, ?> decorateAsVectorSearchResult(
141+
BiFunction<TypeSystem, MapAccessor, ?> target) {
142+
return (t, r) -> {
143+
Object intermediateResult = target.apply(t, r);
144+
var distances = StreamSupport.stream(r.keys().spliterator(), false)
145+
.filter(k -> k.equals(Constants.NAME_OF_SCORE))
146+
.toList();
147+
if (distances.isEmpty()) {
148+
throw new RuntimeException("No score has been returned by the query, cannot create `SearchResult`");
149+
}
150+
else if (distances.size() > 1) {
151+
throw new RuntimeException(
152+
"More than one score has been returned by the query, cannot create `SearchResult`");
153+
}
154+
var searchResult = Score.of(r.get(distances.get(0)).asDouble());
155+
return new SearchResult<>(intermediateResult, searchResult);
156+
};
157+
}
158+
138159
private static boolean hasValidReturnTypeForDelete(Neo4jQueryMethod queryMethod) {
139160
return VALID_RETURN_TYPES_FOR_DELETE
140161
.contains(queryMethod.getResultProcessor().getReturnedType().getReturnedType());
@@ -192,7 +213,7 @@ else if (distance.getMetric() == Metrics.MILES) {
192213
}
193214

194215
protected final Supplier<BiFunction<TypeSystem, MapAccessor, ?>> getMappingFunction(
195-
final ResultProcessor resultProcessor, boolean isGeoNearQuery) {
216+
final ResultProcessor resultProcessor, boolean isGeoNearQuery, boolean isVectorSearchQuery) {
196217

197218
return () -> {
198219
final ReturnedType returnedTypeMetadata = resultProcessor.getReturnedType();
@@ -213,6 +234,10 @@ else if (returnedTypeMetadata.isProjecting()) {
213234
else if (isGeoNearQuery) {
214235
mappingFunction = decorateAsGeoResult(this.mappingContext.getRequiredMappingFunctionFor(domainType));
215236
}
237+
else if (isVectorSearchQuery) {
238+
mappingFunction = decorateAsVectorSearchResult(
239+
this.mappingContext.getRequiredMappingFunctionFor(domainType));
240+
}
216241
else {
217242
mappingFunction = this.mappingContext.getRequiredMappingFunctionFor(domainType);
218243
}

0 commit comments

Comments
 (0)