Skip to content

Commit e4b7e1a

Browse files
committed
Reactive part
Signed-off-by: Gerrit Meier <[email protected]>
1 parent bfee596 commit e4b7e1a

File tree

6 files changed

+212
-5
lines changed

6 files changed

+212
-5
lines changed

src/main/java/org/springframework/data/neo4j/core/ReactiveNeo4jTemplate.java

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1408,8 +1408,13 @@ public <T> Mono<ExecutableQuery<T>> toExecutableQuery(PreparedQuery<T> preparedQ
14081408
return new DefaultReactiveExecutableQuery<>(preparedQuery, fetchSpec);
14091409
});
14101410
}
1411-
1412-
Statement statement = queryFragments.toStatement();
1411+
Statement statement = null;
1412+
if (queryFragmentsAndParameters.hasVectorSearchFragment()) {
1413+
statement = queryFragments.toStatement(queryFragmentsAndParameters.getVectorSearchFragment());
1414+
}
1415+
else {
1416+
statement = queryFragments.toStatement();
1417+
}
14131418
cypherQuery = this.renderer.render(statement);
14141419
finalParameters = TemplateSupport.mergeParameters(statement, finalParameters);
14151420
}

src/main/java/org/springframework/data/neo4j/repository/query/AbstractReactiveNeo4jQuery.java

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
import org.neo4j.driver.types.MapAccessor;
2525
import org.neo4j.driver.types.TypeSystem;
2626
import reactor.core.publisher.Flux;
27+
import reactor.core.publisher.Mono;
2728

2829
import org.springframework.core.convert.converter.Converter;
2930
import org.springframework.data.domain.SearchResult;
@@ -94,13 +95,13 @@ boolean isVectorSearchQuery() {
9495
var repositoryMethod = this.queryMethod.getMethod();
9596
Class<?> returnType = repositoryMethod.getReturnType();
9697

97-
for (Class<?> type : Neo4jQueryMethod.VECTOR_SEARCH_RESULTS) {
98+
for (Class<?> type : ReactiveNeo4jQueryMethod.VECTOR_SEARCH_RESULTS) {
9899
if (type.isAssignableFrom(returnType)) {
99100
return true;
100101
}
101102
}
102103

103-
if (Flux.class.isAssignableFrom(returnType)) {
104+
if (Flux.class.isAssignableFrom(returnType) || Mono.class.isAssignableFrom(returnType)) {
104105
TypeInformation<?> from = TypeInformation.fromReturnTypeOf(repositoryMethod);
105106
TypeInformation<?> componentType = from.getComponentType();
106107
return componentType != null && SearchResult.class.equals(componentType.getType());
@@ -147,10 +148,17 @@ parameterAccessor, null, getMappingFunction(resultProcessor, geoNearQuery, vecto
147148
.map(rawResultList -> createWindow(resultProcessor, incrementLimit, parameterAccessor, rawResultList,
148149
preparedQuery.getQueryFragmentsAndParameters()));
149150
}
151+
else if (this.queryMethod.isSearchQuery()) {
152+
rawResult = createSearchResult((Flux<?>) rawResult, returnedType.getReturnedType());
153+
}
150154

151155
return resultProcessor.processResult(rawResult, preparingConverter);
152156
}
153157

158+
private <T> Flux<SearchResult<?>> createSearchResult(Flux<?> rawResult, Class<T> returnedType) {
159+
return rawResult.map(rawValue -> (SearchResult<T>) rawValue);
160+
}
161+
154162
protected abstract <T extends Object> PreparedQuery<T> prepareQuery(Class<T> returnedType,
155163
Collection<PropertyFilter.ProjectedPath> includedProperties, Neo4jParameterAccessor parameterAccessor,
156164
@Nullable Neo4jQueryType queryType, Supplier<BiFunction<TypeSystem, MapAccessor, ?>> mappingFunction,

src/main/java/org/springframework/data/neo4j/repository/query/ReactiveNeo4jQueryMethod.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,14 @@
1515
*/
1616
package org.springframework.data.neo4j.repository.query;
1717

18+
import java.io.Serializable;
1819
import java.lang.reflect.Method;
20+
import java.util.List;
1921

2022
import org.springframework.dao.InvalidDataAccessApiUsageException;
2123
import org.springframework.data.domain.Page;
2224
import org.springframework.data.domain.Pageable;
25+
import org.springframework.data.domain.SearchResult;
2326
import org.springframework.data.domain.Slice;
2427
import org.springframework.data.neo4j.repository.support.ReactiveCypherdslStatementExecutor;
2528
import org.springframework.data.projection.ProjectionFactory;
@@ -43,6 +46,8 @@
4346
*/
4447
final class ReactiveNeo4jQueryMethod extends Neo4jQueryMethod {
4548

49+
static final List<Class<? extends Serializable>> VECTOR_SEARCH_RESULTS = List.of(SearchResult.class);
50+
4651
@SuppressWarnings("rawtypes")
4752
private static final TypeInformation<Page> PAGE_TYPE = TypeInformation.of(Page.class);
4853

src/main/java/org/springframework/data/neo4j/repository/query/ReactivePartTreeNeo4jQuery.java

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,10 @@ private ReactivePartTreeNeo4jQuery(ReactiveNeo4jOperations neo4jOperations, Neo4
5757

5858
static RepositoryQuery create(ReactiveNeo4jOperations neo4jOperations, Neo4jMappingContext mappingContext,
5959
Neo4jQueryMethod queryMethod, ProjectionFactory factory) {
60+
if (queryMethod.hasVectorSearchAnnotation()
61+
&& queryMethod.getVectorSearchAnnotation().get().numberOfNodes() < 1) {
62+
throw new IllegalArgumentException("Number of nodes in a vector search has to be greater than zero.");
63+
}
6064
return new ReactivePartTreeNeo4jQuery(neo4jOperations, mappingContext, queryMethod,
6165
new PartTree(queryMethod.getName(), getDomainType(queryMethod)), factory);
6266
}

src/test/java/org/springframework/data/neo4j/integration/imperative/VectorSearchIT.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,7 @@ void findAllAsSearchResultsWithVectorIndex(@Autowired VectorSearchRepository rep
103103
}
104104

105105
@Test
106-
void findSingeAsSearchResultWithVectorIndex(@Autowired VectorSearchRepository repository) {
106+
void findSingleAsSearchResultWithVectorIndex(@Autowired VectorSearchRepository repository) {
107107
var result = repository.findBy(Vector.of(new double[] { 0.1d, 0.1d, 0.1d }), Score.of(0.9d));
108108
assertThat(result).isNotNull();
109109
assertThat(result.getContent()).isNotNull();
Lines changed: 185 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,185 @@
1+
/*
2+
* Copyright 2011-2025 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
package org.springframework.data.neo4j.integration.reactive;
17+
18+
import org.junit.jupiter.api.AfterEach;
19+
import org.junit.jupiter.api.BeforeEach;
20+
import org.junit.jupiter.api.Test;
21+
import org.neo4j.driver.Driver;
22+
import org.neo4j.driver.Session;
23+
import reactor.core.publisher.Flux;
24+
import reactor.core.publisher.Mono;
25+
import reactor.test.StepVerifier;
26+
27+
import org.springframework.beans.factory.annotation.Autowired;
28+
import org.springframework.context.annotation.Bean;
29+
import org.springframework.context.annotation.Configuration;
30+
import org.springframework.data.domain.Score;
31+
import org.springframework.data.domain.SearchResult;
32+
import org.springframework.data.domain.Vector;
33+
import org.springframework.data.neo4j.core.ReactiveDatabaseSelectionProvider;
34+
import org.springframework.data.neo4j.core.transaction.Neo4jBookmarkManager;
35+
import org.springframework.data.neo4j.core.transaction.ReactiveNeo4jTransactionManager;
36+
import org.springframework.data.neo4j.integration.shared.common.EntityWithVector;
37+
import org.springframework.data.neo4j.repository.ReactiveNeo4jRepository;
38+
import org.springframework.data.neo4j.repository.config.EnableReactiveNeo4jRepositories;
39+
import org.springframework.data.neo4j.repository.query.VectorSearch;
40+
import org.springframework.data.neo4j.test.BookmarkCapture;
41+
import org.springframework.data.neo4j.test.Neo4jExtension;
42+
import org.springframework.data.neo4j.test.Neo4jIntegrationTest;
43+
import org.springframework.data.neo4j.test.Neo4jReactiveTestConfiguration;
44+
import org.springframework.transaction.ReactiveTransactionManager;
45+
import org.springframework.transaction.annotation.EnableTransactionManagement;
46+
47+
import static org.assertj.core.api.Assertions.assertThat;
48+
49+
/**
50+
* @author Gerrit Meier
51+
*/
52+
@Neo4jIntegrationTest
53+
class ReactiveVectorSearchIT {
54+
55+
protected static Neo4jExtension.Neo4jConnectionSupport neo4jConnectionSupport;
56+
57+
@BeforeEach
58+
void setupData(@Autowired BookmarkCapture bookmarkCapture, @Autowired Driver driver) {
59+
try (Session session = driver.session(bookmarkCapture.createSessionConfig())) {
60+
session.run("MATCH (n) detach delete n");
61+
session.run("""
62+
CREATE VECTOR INDEX dingsIndex IF NOT EXISTS
63+
FOR (m:EntityWithVector)
64+
ON m.myVector
65+
OPTIONS { indexConfig: {
66+
`vector.dimensions`: 3,
67+
`vector.similarity_function`: 'cosine'
68+
}}""").consume();
69+
session.run(
70+
"CREATE (e:EntityWithVector{name:'dings'}) WITH e CALL db.create.setNodeVectorProperty(e, 'myVector', [0.1, 0.1, 0.1])")
71+
.consume();
72+
session.run(
73+
"CREATE (e:EntityWithVector{name:'dings2'}) WITH e CALL db.create.setNodeVectorProperty(e, 'myVector', [0.7, 0.0, 0.3])")
74+
.consume();
75+
session.run("CALL db.awaitIndexes()").consume();
76+
bookmarkCapture.seedWith(session.lastBookmarks());
77+
}
78+
}
79+
80+
@AfterEach
81+
void removeIndex(@Autowired BookmarkCapture bookmarkCapture, @Autowired Driver driver) {
82+
try (Session session = driver.session(bookmarkCapture.createSessionConfig())) {
83+
session.run("DROP INDEX `dingsIndex` IF EXISTS");
84+
}
85+
}
86+
87+
@Test
88+
void findAllWithVectorIndex(@Autowired VectorSearchRepository repository) {
89+
StepVerifier.create(repository.findBy("dings", Vector.of(new double[] { 0.1d, 0.1d, 0.1d })))
90+
.expectNextCount(2)
91+
.verifyComplete();
92+
}
93+
94+
@Test
95+
void findAllAsSearchResultsWithVectorIndex(@Autowired VectorSearchRepository repository) {
96+
StepVerifier.create(repository.findAllBy(Vector.of(new double[] { 0.1d, 0.1d, 0.1d })))
97+
.assertNext(result -> assertThat(result.getContent()).isNotNull())
98+
.assertNext(result -> assertThat(result.getContent()).isNotNull())
99+
.verifyComplete();
100+
}
101+
102+
@Test
103+
void findSingleAsSearchResultWithVectorIndex(@Autowired VectorSearchRepository repository) {
104+
StepVerifier.create(repository.findBy(Vector.of(new double[] { 0.1d, 0.1d, 0.1d }), Score.of(0.9d)))
105+
.assertNext(result -> {
106+
assertThat(result).isNotNull();
107+
assertThat(result.getContent()).isNotNull();
108+
assertThat(result.getScore().getValue()).isGreaterThanOrEqualTo(0.9d);
109+
})
110+
.verifyComplete();
111+
}
112+
113+
@Test
114+
void findByNameWithVectorIndex(@Autowired VectorSearchRepository repository) {
115+
StepVerifier.create(repository.findByName("dings", Vector.of(new double[] { 0.1d, 0.1d, 0.1d })))
116+
.expectNextCount(1)
117+
.verifyComplete();
118+
}
119+
120+
@Test
121+
void findByNameWithVectorIndexAndScore(@Autowired VectorSearchRepository repository) {
122+
StepVerifier
123+
.create(repository.findByName("dings", Vector.of(new double[] { -0.7d, 0.0d, -0.7d }), Score.of(0.01d)))
124+
.expectNextCount(1)
125+
.verifyComplete();
126+
}
127+
128+
@Test
129+
void dontFindByNameWithVectorIndexAndScore(@Autowired VectorSearchRepository repository) {
130+
StepVerifier
131+
.create(repository.findByName("dings", Vector.of(new double[] { -0.7d, 0.0d, -0.7d }), Score.of(0.8d)))
132+
.verifyComplete();
133+
}
134+
135+
interface VectorSearchRepository extends ReactiveNeo4jRepository<EntityWithVector, String> {
136+
137+
@VectorSearch(indexName = "dingsIndex", numberOfNodes = 2)
138+
Flux<EntityWithVector> findBy(String name, Vector searchVector);
139+
140+
@VectorSearch(indexName = "dingsIndex", numberOfNodes = 1)
141+
Flux<EntityWithVector> findByName(String name, Vector searchVector);
142+
143+
@VectorSearch(indexName = "dingsIndex", numberOfNodes = 2)
144+
Flux<EntityWithVector> findByName(String name, Vector searchVector, Score score);
145+
146+
@VectorSearch(indexName = "dingsIndex", numberOfNodes = 2)
147+
Flux<SearchResult<EntityWithVector>> findAllBy(Vector searchVector);
148+
149+
@VectorSearch(indexName = "dingsIndex", numberOfNodes = 2)
150+
Mono<SearchResult<EntityWithVector>> findBy(Vector searchVector, Score score);
151+
152+
}
153+
154+
@Configuration
155+
@EnableTransactionManagement
156+
@EnableReactiveNeo4jRepositories(considerNestedRepositories = true)
157+
static class Config extends Neo4jReactiveTestConfiguration {
158+
159+
@Bean
160+
@Override
161+
public Driver driver() {
162+
return neo4jConnectionSupport.getDriver();
163+
}
164+
165+
@Bean
166+
BookmarkCapture bookmarkCapture() {
167+
return new BookmarkCapture();
168+
}
169+
170+
@Override
171+
public ReactiveTransactionManager reactiveTransactionManager(Driver driver,
172+
ReactiveDatabaseSelectionProvider databaseNameProvider) {
173+
BookmarkCapture bookmarkCapture = bookmarkCapture();
174+
return new ReactiveNeo4jTransactionManager(driver, databaseNameProvider,
175+
Neo4jBookmarkManager.create(bookmarkCapture));
176+
}
177+
178+
@Override
179+
public boolean isCypher5Compatible() {
180+
return neo4jConnectionSupport.isCypher5SyntaxCompatible();
181+
}
182+
183+
}
184+
185+
}

0 commit comments

Comments
 (0)