Skip to content

Commit 691777f

Browse files
refactor to reuse bulk writing usage for any update scenarios
1 parent 6510723 commit 691777f

File tree

4 files changed

+148
-99
lines changed

4 files changed

+148
-99
lines changed

src/main/java/com/mongodb/hibernate/jdbc/MongoPreparedStatement.java

Lines changed: 2 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -16,19 +16,11 @@
1616

1717
package com.mongodb.hibernate.jdbc;
1818

19-
import static com.mongodb.hibernate.internal.MongoAssertions.assertNotNull;
20-
import static com.mongodb.hibernate.internal.MongoAssertions.assertTrue;
2119
import static com.mongodb.hibernate.internal.MongoAssertions.fail;
2220
import static java.lang.String.format;
2321

2422
import com.mongodb.client.ClientSession;
2523
import com.mongodb.client.MongoClient;
26-
import com.mongodb.client.model.DeleteManyModel;
27-
import com.mongodb.client.model.DeleteOneModel;
28-
import com.mongodb.client.model.InsertOneModel;
29-
import com.mongodb.client.model.UpdateManyModel;
30-
import com.mongodb.client.model.UpdateOneModel;
31-
import com.mongodb.client.model.WriteModel;
3224
import com.mongodb.hibernate.internal.NotYetImplementedException;
3325
import java.io.InputStream;
3426
import java.math.BigDecimal;
@@ -285,67 +277,12 @@ public void clearBatch() throws SQLException {
285277
@Override
286278
public int[] executeBatch() throws SQLException {
287279
checkClosed();
288-
startTransactionIfNeeded();
289280
try {
290281
if (commandBatch.isEmpty()) {
291282
return new int[0];
292283
}
293284

294-
var writeModels = new ArrayList<WriteModel<BsonDocument>>(commandBatch.size());
295-
296-
// Hibernate will group PreparedStatement by both table and mutation type
297-
var commandName = assertNotNull(commandBatch.get(0).getFirstKey());
298-
var collectionName =
299-
assertNotNull(commandBatch.get(0).getString(commandName).getValue());
300-
301-
for (var command : commandBatch) {
302-
303-
assertTrue(commandName.equals(command.getFirstKey()));
304-
assertTrue(collectionName.equals(command.getString(commandName).getValue()));
305-
306-
List<WriteModel<BsonDocument>> subWriteModels;
307-
308-
switch (commandName) {
309-
case "insert":
310-
var documents = command.getArray("documents");
311-
subWriteModels = new ArrayList<>(documents.size());
312-
for (var document : documents) {
313-
subWriteModels.add(new InsertOneModel<>((BsonDocument) document));
314-
}
315-
break;
316-
case "update":
317-
var updates = command.getArray("updates").getValues();
318-
subWriteModels = new ArrayList<>(updates.size());
319-
for (var update : updates) {
320-
var updateDocument = (BsonDocument) update;
321-
WriteModel<BsonDocument> updateModel =
322-
!updateDocument.getBoolean("multi").getValue()
323-
? new UpdateOneModel<>(
324-
updateDocument.getDocument("q"), updateDocument.getDocument("u"))
325-
: new UpdateManyModel<>(
326-
updateDocument.getDocument("q"), updateDocument.getDocument("u"));
327-
subWriteModels.add(updateModel);
328-
}
329-
break;
330-
case "delete":
331-
var deletes = command.getArray("deleted");
332-
subWriteModels = new ArrayList<>(deletes.size());
333-
for (var delete : deletes) {
334-
var deleteDocument = (BsonDocument) delete;
335-
subWriteModels.add(
336-
deleteDocument.getNumber("limit").intValue() == 1
337-
? new DeleteOneModel<>(deleteDocument.getDocument("q"))
338-
: new DeleteManyModel<>(deleteDocument.getDocument("q")));
339-
}
340-
break;
341-
default:
342-
throw new NotYetImplementedException();
343-
}
344-
writeModels.addAll(subWriteModels);
345-
}
346-
getMongoDatabase()
347-
.getCollection(collectionName, BsonDocument.class)
348-
.bulkWrite(getClientSession(), writeModels);
285+
executeBulkWrite(commandBatch);
349286

350287
var rowCounts = new int[commandBatch.size()];
351288

@@ -356,7 +293,7 @@ public int[] executeBatch() throws SQLException {
356293
return rowCounts;
357294

358295
} catch (RuntimeException e) {
359-
throw new SQLException("Failed to run bulk operation: " + e.getMessage(), e);
296+
throw new SQLException("Failed to execute batch operation: " + e.getMessage(), e);
360297
} finally {
361298
clearBatch();
362299
}

src/main/java/com/mongodb/hibernate/jdbc/MongoStatement.java

Lines changed: 80 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -16,20 +16,32 @@
1616

1717
package com.mongodb.hibernate.jdbc;
1818

19+
import static com.mongodb.hibernate.internal.MongoAssertions.assertNotNull;
20+
import static com.mongodb.hibernate.internal.MongoAssertions.assertTrue;
1921
import static com.mongodb.hibernate.internal.VisibleForTesting.AccessModifier.PRIVATE;
2022
import static com.mongodb.hibernate.jdbc.MongoConnection.DATABASE;
2123
import static java.lang.String.format;
24+
import static java.util.Collections.singletonList;
2225

26+
import com.mongodb.bulk.BulkWriteResult;
2327
import com.mongodb.client.ClientSession;
2428
import com.mongodb.client.MongoClient;
2529
import com.mongodb.client.MongoDatabase;
30+
import com.mongodb.client.model.DeleteManyModel;
31+
import com.mongodb.client.model.DeleteOneModel;
32+
import com.mongodb.client.model.InsertOneModel;
33+
import com.mongodb.client.model.UpdateManyModel;
34+
import com.mongodb.client.model.UpdateOneModel;
35+
import com.mongodb.client.model.WriteModel;
2636
import com.mongodb.hibernate.internal.NotYetImplementedException;
2737
import com.mongodb.hibernate.internal.VisibleForTesting;
2838
import java.sql.Connection;
2939
import java.sql.ResultSet;
3040
import java.sql.SQLException;
3141
import java.sql.SQLSyntaxErrorException;
3242
import java.sql.SQLWarning;
43+
import java.util.ArrayList;
44+
import java.util.List;
3345
import org.bson.BsonDocument;
3446
import org.jspecify.annotations.Nullable;
3547

@@ -71,14 +83,76 @@ public int executeUpdate(String mql) throws SQLException {
7183
}
7284

7385
int executeUpdateCommand(BsonDocument command) throws SQLException {
86+
var bulkWriteResult = executeBulkWrite(singletonList(command));
87+
return switch (command.getFirstKey()) {
88+
case "insert" -> bulkWriteResult.getInsertedCount();
89+
case "update" -> bulkWriteResult.getModifiedCount();
90+
case "delete" -> bulkWriteResult.getDeletedCount();
91+
default -> throw new NotYetImplementedException();
92+
};
93+
}
94+
95+
BulkWriteResult executeBulkWrite(List<? extends BsonDocument> commandBatch) throws SQLException {
7496
startTransactionIfNeeded();
97+
7598
try {
76-
return mongoClient
77-
.getDatabase(DATABASE)
78-
.runCommand(clientSession, command)
79-
.getInteger("n");
80-
} catch (Exception e) {
81-
throw new SQLException("Failed to execute update command", e);
99+
var writeModels = new ArrayList<WriteModel<BsonDocument>>(commandBatch.size());
100+
101+
// Hibernate will group PreparedStatement by both table and mutation type
102+
var commandName = assertNotNull(commandBatch.get(0).getFirstKey());
103+
var collectionName =
104+
assertNotNull(commandBatch.get(0).getString(commandName).getValue());
105+
106+
for (var command : commandBatch) {
107+
108+
assertTrue(commandName.equals(command.getFirstKey()));
109+
assertTrue(collectionName.equals(command.getString(commandName).getValue()));
110+
111+
List<WriteModel<BsonDocument>> subWriteModels;
112+
113+
switch (commandName) {
114+
case "insert":
115+
var documents = command.getArray("documents");
116+
subWriteModels = new ArrayList<>(documents.size());
117+
for (var document : documents) {
118+
subWriteModels.add(new InsertOneModel<>((BsonDocument) document));
119+
}
120+
break;
121+
case "update":
122+
var updates = command.getArray("updates").getValues();
123+
subWriteModels = new ArrayList<>(updates.size());
124+
for (var update : updates) {
125+
var updateDocument = (BsonDocument) update;
126+
WriteModel<BsonDocument> updateModel =
127+
!updateDocument.getBoolean("multi").getValue()
128+
? new UpdateOneModel<>(
129+
updateDocument.getDocument("q"), updateDocument.getDocument("u"))
130+
: new UpdateManyModel<>(
131+
updateDocument.getDocument("q"), updateDocument.getDocument("u"));
132+
subWriteModels.add(updateModel);
133+
}
134+
break;
135+
case "delete":
136+
var deletes = command.getArray("deletes");
137+
subWriteModels = new ArrayList<>(deletes.size());
138+
for (var delete : deletes) {
139+
var deleteDocument = (BsonDocument) delete;
140+
subWriteModels.add(
141+
deleteDocument.getNumber("limit").intValue() == 1
142+
? new DeleteOneModel<>(deleteDocument.getDocument("q"))
143+
: new DeleteManyModel<>(deleteDocument.getDocument("q")));
144+
}
145+
break;
146+
default:
147+
throw new NotYetImplementedException();
148+
}
149+
writeModels.addAll(subWriteModels);
150+
}
151+
return getMongoDatabase()
152+
.getCollection(collectionName, BsonDocument.class)
153+
.bulkWrite(getClientSession(), writeModels);
154+
} catch (RuntimeException e) {
155+
throw new SQLException("Failed to run bulk write: " + e.getMessage(), e);
82156
}
83157
}
84158

src/test/java/com/mongodb/hibernate/jdbc/MongoPreparedStatementTests.java

Lines changed: 59 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@
2222
import static org.junit.jupiter.api.Assertions.assertInstanceOf;
2323
import static org.junit.jupiter.api.Assertions.assertSame;
2424
import static org.junit.jupiter.api.Assertions.assertThrows;
25-
import static org.mockito.ArgumentMatchers.any;
2625
import static org.mockito.ArgumentMatchers.anyList;
2726
import static org.mockito.ArgumentMatchers.anyString;
2827
import static org.mockito.ArgumentMatchers.eq;
@@ -31,7 +30,9 @@
3130
import static org.mockito.Mockito.times;
3231
import static org.mockito.Mockito.verify;
3332

33+
import com.mongodb.bulk.BulkWriteInsert;
3434
import com.mongodb.bulk.BulkWriteResult;
35+
import com.mongodb.bulk.BulkWriteUpsert;
3536
import com.mongodb.client.ClientSession;
3637
import com.mongodb.client.MongoClient;
3738
import com.mongodb.client.MongoCollection;
@@ -55,7 +56,6 @@
5556
import java.util.Map;
5657
import java.util.stream.Stream;
5758
import org.bson.BsonDocument;
58-
import org.bson.Document;
5959
import org.junit.jupiter.api.DisplayName;
6060
import org.junit.jupiter.api.Nested;
6161
import org.junit.jupiter.api.Test;
@@ -110,16 +110,52 @@ class ParameterValueSettingTests {
110110
private MongoDatabase mongoDatabase;
111111

112112
@Captor
113-
private ArgumentCaptor<BsonDocument> commandCaptor;
113+
private ArgumentCaptor<List<WriteModel<BsonDocument>>> writeModelsCaptor;
114114

115115
@Test
116116
@DisplayName("Happy path when all parameters are provided values")
117-
void testSuccess() throws SQLException {
117+
void testSuccess(@Mock MongoCollection<BsonDocument> mongoCollection) throws SQLException {
118118
// given
119119
doReturn(mongoDatabase).when(mongoClient).getDatabase(anyString());
120-
doReturn(Document.parse("{ok: 1.0, n: 1}"))
121-
.when(mongoDatabase)
122-
.runCommand(eq(clientSession), any(BsonDocument.class));
120+
doReturn(mongoCollection).when(mongoDatabase).getCollection(anyString(), eq(BsonDocument.class));
121+
doReturn(new BulkWriteResult() {
122+
@Override
123+
public boolean wasAcknowledged() {
124+
return false;
125+
}
126+
127+
@Override
128+
public int getInsertedCount() {
129+
return 1;
130+
}
131+
132+
@Override
133+
public int getMatchedCount() {
134+
return 0;
135+
}
136+
137+
@Override
138+
public int getDeletedCount() {
139+
return 0;
140+
}
141+
142+
@Override
143+
public int getModifiedCount() {
144+
return 0;
145+
}
146+
147+
@Override
148+
public List<BulkWriteInsert> getInserts() {
149+
return List.of();
150+
}
151+
152+
@Override
153+
public List<BulkWriteUpsert> getUpserts() {
154+
return List.of();
155+
}
156+
})
157+
.when(mongoCollection)
158+
.bulkWrite(eq(clientSession), anyList());
123159

124160
// when && then
125161
try (var preparedStatement = createMongoPreparedStatement(EXAMPLE_MQL)) {
@@ -132,26 +168,24 @@ void testSuccess() throws SQLException {
132168

133169
preparedStatement.executeUpdate();
134170

135-
verify(mongoDatabase).runCommand(eq(clientSession), commandCaptor.capture());
136-
var command = commandCaptor.getValue();
171+
verify(mongoCollection).bulkWrite(eq(clientSession), writeModelsCaptor.capture());
172+
var writeModels = writeModelsCaptor.getValue();
173+
assertEquals(1, writeModels.size());
174+
var writeModel = writeModels.get(0);
175+
assertInstanceOf(InsertOneModel.class, writeModel);
176+
var insertDoc = ((InsertOneModel<BsonDocument>) writeModel).getDocument();
137177
var expectedDoc = parse(
138178
"""
139-
{
140-
insert: "books",
141-
documents: [
142-
{
143-
title: "War and Peace",
144-
author: "Leo Tolstoy",
145-
publishYear: 1869,
146-
outOfStock: false,
147-
tags: [
148-
"classic"
149-
]
150-
}
179+
{
180+
title: "War and Peace",
181+
author: "Leo Tolstoy",
182+
publishYear: 1869,
183+
outOfStock: false,
184+
tags: [
185+
"classic"
151186
]
152-
}
153-
""");
154-
assertEquals(expectedDoc, command);
187+
} """);
188+
assertEquals(expectedDoc, insertDoc);
155189
}
156190
}
157191

@@ -440,7 +474,7 @@ private static Stream<Arguments> getBulkWriteModelsArguments() {
440474
"""
441475
{
442476
delete: "books",
443-
deleted: [
477+
deletes: [
444478
{ q: { _id: 1 }, limit: 1 },
445479
{ q: {}, limit: 0 }
446480
]

src/test/java/com/mongodb/hibernate/jdbc/MongoStatementTests.java

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,14 +20,16 @@
2020
import static org.junit.jupiter.api.Assertions.assertEquals;
2121
import static org.junit.jupiter.api.Assertions.assertThrows;
2222
import static org.junit.jupiter.api.Assertions.assertTrue;
23-
import static org.mockito.ArgumentMatchers.any;
23+
import static org.mockito.ArgumentMatchers.anyList;
2424
import static org.mockito.ArgumentMatchers.anyString;
25+
import static org.mockito.ArgumentMatchers.eq;
2526
import static org.mockito.ArgumentMatchers.same;
2627
import static org.mockito.Mockito.doReturn;
2728
import static org.mockito.Mockito.doThrow;
2829

2930
import com.mongodb.client.ClientSession;
3031
import com.mongodb.client.MongoClient;
32+
import com.mongodb.client.MongoCollection;
3133
import com.mongodb.client.MongoDatabase;
3234
import java.sql.SQLException;
3335
import java.sql.SQLSyntaxErrorException;
@@ -90,11 +92,13 @@ void testSQLExceptionThrownWhenCalledWithInvalidMql() {
9092

9193
@Test
9294
@DisplayName("SQLException is thrown when database access error occurs")
93-
void testSQLExceptionThrownWhenDBAccessFailed(@Mock MongoDatabase mongoDatabase) {
95+
void testSQLExceptionThrownWhenDBAccessFailed(
96+
@Mock MongoDatabase mongoDatabase, @Mock MongoCollection<BsonDocument> mongoCollection) {
9497
// given
9598
doReturn(mongoDatabase).when(mongoClient).getDatabase(anyString());
99+
doReturn(mongoCollection).when(mongoDatabase).getCollection(anyString(), eq((BsonDocument.class)));
96100
var dbAccessException = new RuntimeException();
97-
doThrow(dbAccessException).when(mongoDatabase).runCommand(same(clientSession), any(BsonDocument.class));
101+
doThrow(dbAccessException).when(mongoCollection).bulkWrite(same(clientSession), anyList());
98102
String mql =
99103
"""
100104
{

0 commit comments

Comments
 (0)