Skip to content

Vector search on repository level #3037

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@

<groupId>org.springframework.data</groupId>
<artifactId>spring-data-neo4j</artifactId>
<version>8.0.0-SNAPSHOT</version>
<version>8.0.0-GH-3002-SNAPSHOT</version>

<name>Spring Data Neo4j</name>
<description>Next generation Object-Graph-Mapping for Spring Data.</description>
Expand Down
1 change: 1 addition & 0 deletions src/main/antora/modules/ROOT/nav.adoc
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
** xref:repositories/sdn-extension.adoc[]
** xref:repositories/query-keywords-reference.adoc[]
** xref:repositories/query-return-types-reference.adoc[]
** xref:repositories/vector-search.adoc[]

* xref:repositories/projections.adoc[]
** xref:projections/sdn-projections.adoc[]
Expand Down
30 changes: 30 additions & 0 deletions src/main/antora/modules/ROOT/pages/repositories/vector-search.adoc
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
[[sdn-vector-search]]
= Neo4j Vector Search

== The `@VectorSearch` annotation
Spring Data Neo4j supports Neo4j's vector search on the repository level by using the `@VectorSearch` annotation.
For this to work, Neo4j needs to have a vector index in place.
How to create a vector index is explained in the https://neo4j.com/docs/cypher-manual/current/indexes/search-performance-indexes/managing-indexes/[Neo4j documentation].

NOTE: It's not required to have any (Spring Data) Vector typed property be defined in the domain entities for this to work
because the search operates exclusively on the index.

The `@VectorSearch` annotation requires two arguments:
The name of the vector index to be used and the number of nearest neighbours.

For a general vector search over the whole domain, it's possible to use a derived finder method without any property.
[source,java,indent=0,tabsize=4]
----
include::example$integration/imperative/VectorSearchIT.java[tags=sdn-vector-search.usage;sdn-vector-search.usage.findall]
----

The vector index can be combined with any property-based finder method to filter down the results.

NOTE: For technical reasons, the vector search will always be executed before the property search gets invoked.
E.g. if the property filter looks for a person named "Helge",
but the vector search only yields "Hannes", there won't be a result.

[source,java,indent=0,tabsize=4]
----
include::example$integration/imperative/VectorSearchIT.java[tags=sdn-vector-search.usage;sdn-vector-search.usage.findbyproperty]
----
Original file line number Diff line number Diff line change
Expand Up @@ -1489,7 +1489,12 @@ private Optional<Neo4jClient.RecordFetchSpec<T>> createFetchSpec() {
statement = nodesAndRelationshipsById.toStatement(entityMetaData);
}
else {
statement = queryFragments.toStatement();
if (queryFragmentsAndParameters.hasVectorSearchFragment()) {
statement = queryFragments.toStatement(queryFragmentsAndParameters.getVectorSearchFragment());
}
else {
statement = queryFragments.toStatement();
}
}
cypherQuery = Neo4jTemplate.this.renderer.render(statement);
finalParameters = TemplateSupport.mergeParameters(statement, finalParameters);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1408,8 +1408,13 @@ public <T> Mono<ExecutableQuery<T>> toExecutableQuery(PreparedQuery<T> preparedQ
return new DefaultReactiveExecutableQuery<>(preparedQuery, fetchSpec);
});
}

Statement statement = queryFragments.toStatement();
Statement statement = null;
if (queryFragmentsAndParameters.hasVectorSearchFragment()) {
statement = queryFragments.toStatement(queryFragmentsAndParameters.getVectorSearchFragment());
}
else {
statement = queryFragments.toStatement();
}
cypherQuery = this.renderer.render(statement);
finalParameters = TemplateSupport.mergeParameters(statement, finalParameters);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,21 @@ public final class Constants {
*/
public static final String TO_ID_PARAMETER_NAME = "toId";

/**
* The name SDN uses for vector search score.
*/
public static final String NAME_OF_SCORE = "__score__";

/**
* Vector search vector parameter name.
*/
public static final String VECTOR_SEARCH_VECTOR_PARAMETER = "vectorSearchParam";

/**
* Vector search score parameter name.
*/
public static final String VECTOR_SEARCH_SCORE_PARAMETER = "scoreParam";

private Constants() {
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import java.util.function.LongSupplier;
import java.util.function.Supplier;
import java.util.function.UnaryOperator;
import java.util.stream.Collectors;

import org.jspecify.annotations.Nullable;
import org.neo4j.driver.types.MapAccessor;
Expand All @@ -32,6 +33,8 @@
import org.springframework.data.domain.Page;
import org.springframework.data.domain.PageRequest;
import org.springframework.data.domain.Pageable;
import org.springframework.data.domain.SearchResult;
import org.springframework.data.domain.SearchResults;
import org.springframework.data.domain.Slice;
import org.springframework.data.domain.SliceImpl;
import org.springframework.data.geo.GeoPage;
Expand Down Expand Up @@ -98,11 +101,30 @@ boolean isGeoNearQuery() {
return GeoPage.class.isAssignableFrom(returnType);
}

boolean isVectorSearchQuery() {
var repositoryMethod = this.queryMethod.getMethod();
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could QueryMethod.isSearchQuery() be useful here?

Would it also make sense to check for the presence of the @VectorSearch annotation?

Class<?> returnType = repositoryMethod.getReturnType();

for (Class<?> type : Neo4jQueryMethod.VECTOR_SEARCH_RESULTS) {
if (type.isAssignableFrom(returnType)) {
return true;
}
}

if (Iterable.class.isAssignableFrom(returnType)) {
TypeInformation<?> from = TypeInformation.fromReturnTypeOf(repositoryMethod);
return from.getComponentType() != null && SearchResult.class.equals(from.getComponentType().getType());
}

return false;
}

@Override
@Nullable public final Object execute(Object[] parameters) {

boolean incrementLimit = this.queryMethod.incrementLimit();
boolean geoNearQuery = isGeoNearQuery();
boolean vectorSearchQuery = isVectorSearchQuery();
Neo4jParameterAccessor parameterAccessor = new Neo4jParameterAccessor(
(Neo4jQueryMethod.Neo4jParameters) this.queryMethod.getParameters(), parameters);

Expand All @@ -111,7 +133,7 @@ boolean isGeoNearQuery() {
ReturnedType returnedType = resultProcessor.getReturnedType();
PreparedQuery<?> preparedQuery = prepareQuery(returnedType.getReturnedType(),
PropertyFilterSupport.getInputProperties(resultProcessor, this.factory, this.mappingContext),
parameterAccessor, null, getMappingFunction(resultProcessor, geoNearQuery),
parameterAccessor, null, getMappingFunction(resultProcessor, geoNearQuery, vectorSearchQuery),
incrementLimit ? l -> l + 1 : UnaryOperator.identity());

Object rawResult = new Neo4jQueryExecution.DefaultQueryExecution(this.neo4jOperations).execute(preparedQuery,
Expand Down Expand Up @@ -143,6 +165,9 @@ else if (this.queryMethod.isScrollQuery()) {
else if (geoNearQuery) {
rawResult = newGeoResults(rawResult);
}
else if (this.queryMethod.isSearchQuery()) {
rawResult = createSearchResult((List<?>) rawResult, returnedType.getReturnedType());
}

return resultProcessor.processResult(rawResult, preparingConverter);
}
Expand Down Expand Up @@ -182,6 +207,13 @@ private Slice<?> createSlice(boolean incrementLimit, Neo4jParameterAccessor para
}
}

private <T> SearchResults<?> createSearchResult(List<?> rawResult, Class<T> returnedType) {
List<SearchResult<T>> searchResults = rawResult.stream()
.map(rawValue -> (SearchResult<T>) rawValue)
.collect(Collectors.toUnmodifiableList());
return new SearchResults<>(searchResults);
}

protected abstract <T> PreparedQuery<T> prepareQuery(Class<T> returnedType,
Collection<PropertyFilter.ProjectedPath> includedProperties, Neo4jParameterAccessor parameterAccessor,
@Nullable Neo4jQueryType queryType,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,10 @@
import org.neo4j.driver.types.MapAccessor;
import org.neo4j.driver.types.TypeSystem;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;

import org.springframework.core.convert.converter.Converter;
import org.springframework.data.domain.SearchResult;
import org.springframework.data.geo.GeoResult;
import org.springframework.data.neo4j.core.PreparedQuery;
import org.springframework.data.neo4j.core.PropertyFilterSupport;
Expand Down Expand Up @@ -89,11 +91,31 @@ boolean isGeoNearQuery() {
return false;
}

boolean isVectorSearchQuery() {
var repositoryMethod = this.queryMethod.getMethod();
Class<?> returnType = repositoryMethod.getReturnType();

for (Class<?> type : ReactiveNeo4jQueryMethod.VECTOR_SEARCH_RESULTS) {
if (type.isAssignableFrom(returnType)) {
return true;
}
}

if (Flux.class.isAssignableFrom(returnType) || Mono.class.isAssignableFrom(returnType)) {
TypeInformation<?> from = TypeInformation.fromReturnTypeOf(repositoryMethod);
TypeInformation<?> componentType = from.getComponentType();
return componentType != null && SearchResult.class.equals(componentType.getType());
}

return false;
}

@Override
@Nullable public final Object execute(Object[] parameters) {

boolean incrementLimit = this.queryMethod.incrementLimit();
boolean geoNearQuery = isGeoNearQuery();
boolean vectorSearchQuery = isVectorSearchQuery();
Neo4jParameterAccessor parameterAccessor = new Neo4jParameterAccessor(
(Neo4jQueryMethod.Neo4jParameters) this.queryMethod.getParameters(), parameters);
ResultProcessor resultProcessor = this.queryMethod.getResultProcessor()
Expand All @@ -102,7 +124,7 @@ boolean isGeoNearQuery() {
ReturnedType returnedType = resultProcessor.getReturnedType();
PreparedQuery<?> preparedQuery = prepareQuery(returnedType.getReturnedType(),
PropertyFilterSupport.getInputProperties(resultProcessor, this.factory, this.mappingContext),
parameterAccessor, null, getMappingFunction(resultProcessor, geoNearQuery),
parameterAccessor, null, getMappingFunction(resultProcessor, geoNearQuery, vectorSearchQuery),
incrementLimit ? l -> l + 1 : UnaryOperator.identity());

Object rawResult = new Neo4jQueryExecution.ReactiveQueryExecution(this.neo4jOperations).execute(preparedQuery,
Expand All @@ -126,10 +148,17 @@ parameterAccessor, null, getMappingFunction(resultProcessor, geoNearQuery),
.map(rawResultList -> createWindow(resultProcessor, incrementLimit, parameterAccessor, rawResultList,
preparedQuery.getQueryFragmentsAndParameters()));
}
else if (this.queryMethod.isSearchQuery()) {
rawResult = createSearchResult((Flux<?>) rawResult, returnedType.getReturnedType());
}

return resultProcessor.processResult(rawResult, preparingConverter);
}

private <T> Flux<SearchResult<?>> createSearchResult(Flux<?> rawResult, Class<T> returnedType) {
return rawResult.map(rawValue -> (SearchResult<T>) rawValue);
}

protected abstract <T extends Object> PreparedQuery<T> prepareQuery(Class<T> returnedType,
Collection<PropertyFilter.ProjectedPath> includedProperties, Neo4jParameterAccessor parameterAccessor,
@Nullable Neo4jQueryType queryType, Supplier<BiFunction<TypeSystem, MapAccessor, ?>> mappingFunction,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,10 @@
import org.springframework.data.domain.OffsetScrollPosition;
import org.springframework.data.domain.Pageable;
import org.springframework.data.domain.Range;
import org.springframework.data.domain.Score;
import org.springframework.data.domain.ScrollPosition;
import org.springframework.data.domain.Sort;
import org.springframework.data.domain.Vector;
import org.springframework.data.geo.Box;
import org.springframework.data.geo.Circle;
import org.springframework.data.geo.Distance;
Expand All @@ -60,7 +62,6 @@
import org.springframework.data.neo4j.core.mapping.Neo4jPersistentProperty;
import org.springframework.data.neo4j.core.mapping.NodeDescription;
import org.springframework.data.neo4j.core.mapping.PropertyFilter;
import org.springframework.data.repository.query.QueryMethod;
import org.springframework.data.repository.query.parser.AbstractQueryCreator;
import org.springframework.data.repository.query.parser.Part;
import org.springframework.data.repository.query.parser.PartTree;
Expand Down Expand Up @@ -119,14 +120,22 @@ final class CypherQueryCreator extends AbstractQueryCreator<QueryFragmentsAndPar

private final boolean keysetRequiresSort;

private final List<Expression> distanceExpressions = new ArrayList<>();
private final List<Expression> additionalReturnExpression = new ArrayList<>();

/**
* Can be used to modify the limit of a paged or sliced query.
*/
private final UnaryOperator<Integer> limitModifier;

CypherQueryCreator(Neo4jMappingContext mappingContext, QueryMethod queryMethod, Class<?> domainType,
private final Neo4jQueryMethod queryMethod;

@Nullable
private final Vector vectorSearchParameter;

@Nullable
private final Score scoreParameter;

CypherQueryCreator(Neo4jMappingContext mappingContext, Neo4jQueryMethod queryMethod, Class<?> domainType,
Neo4jQueryType queryType, PartTree tree, Neo4jParameterAccessor actualParameters,
Collection<PropertyFilter.ProjectedPath> includedProperties,
BiFunction<Object, Neo4jPersistentPropertyConverter<?>, Object> parameterConversion,
Expand All @@ -148,6 +157,8 @@ final class CypherQueryCreator extends AbstractQueryCreator<QueryFragmentsAndPar

this.pagingParameter = actualParameters.getPageable();
this.scrollPosition = actualParameters.getScrollPosition();
this.vectorSearchParameter = actualParameters.getVector();
this.scoreParameter = actualParameters.getScore();
this.limitModifier = limitModifier;

AtomicInteger symbolicNameIndex = new AtomicInteger();
Expand All @@ -160,6 +171,7 @@ final class CypherQueryCreator extends AbstractQueryCreator<QueryFragmentsAndPar

this.keysetRequiresSort = queryMethod.isScrollQuery()
&& actualParameters.getScrollPosition() instanceof KeysetScrollPosition;
this.queryMethod = queryMethod;
}

@Override
Expand Down Expand Up @@ -196,6 +208,21 @@ protected QueryFragmentsAndParameters complete(@Nullable Condition condition, So
if (this.keysetRequiresSort && theSort.isUnsorted()) {
throw new UnsupportedOperationException("Unsorted keyset based scrolling is not supported.");
}
if (this.queryMethod.hasVectorSearchAnnotation()) {
var vectorSearchAnnotation = this.queryMethod.getVectorSearchAnnotation().orElseThrow();
var indexName = vectorSearchAnnotation.indexName();
var numberOfNodes = vectorSearchAnnotation.numberOfNodes();
convertedParameters.put(Constants.VECTOR_SEARCH_VECTOR_PARAMETER,
this.vectorSearchParameter.toDoubleArray());
if (this.scoreParameter != null) {
convertedParameters.put(Constants.VECTOR_SEARCH_SCORE_PARAMETER, this.scoreParameter.getValue());
}
var vectorSearchFragment = new VectorSearchFragment(indexName, numberOfNodes,
(this.scoreParameter != null) ? this.scoreParameter.getValue() : null);
var queryFragmentsAndParameters = new QueryFragmentsAndParameters(this.nodeDescription, queryFragments,
vectorSearchFragment, convertedParameters, theSort);
return queryFragmentsAndParameters;
}
return new QueryFragmentsAndParameters(this.nodeDescription, queryFragments, convertedParameters, theSort);
}

Expand Down Expand Up @@ -274,11 +301,14 @@ else if (this.scrollPosition instanceof OffsetScrollPosition offsetScrollPositio
? this.maxResults.intValue() : this.pagingParameter.getPageSize()));
}

if (this.queryMethod.hasVectorSearchAnnotation()) {
this.additionalReturnExpression.add(Cypher.name(Constants.NAME_OF_SCORE));
}
var finalSortItems = new ArrayList<>(this.sortItems);
theSort.stream().map(CypherAdapterUtils.sortAdapterFor(this.nodeDescription)).forEach(finalSortItems::add);

queryFragments.setReturnBasedOn(this.nodeDescription, this.includedProperties, this.isDistinct,
this.distanceExpressions);
this.additionalReturnExpression);
queryFragments.setOrderBy(finalSortItems);
}

Expand Down Expand Up @@ -438,7 +468,7 @@ else if (p2.isPresent() && p2.get().value instanceof Point) {
// property to be later retrieved and mapped
Neo4jPersistentEntity<?> owner = (Neo4jPersistentEntity<?>) leafProperty.getOwner();
String containerName = getContainerName(path, owner);
this.distanceExpressions
this.additionalReturnExpression
.add(distanceFunction.as("__distance_" + containerName + "_" + leafProperty.getPropertyName() + "__"));

this.sortItems.add(distanceFunction.ascending());
Expand Down
Loading