diff --git a/src/integrationTest/java/com/mongodb/hibernate/jdbc/BatchUpdateIntegrationTests.java b/src/integrationTest/java/com/mongodb/hibernate/jdbc/BatchUpdateIntegrationTests.java new file mode 100644 index 00000000..bf0fc616 --- /dev/null +++ b/src/integrationTest/java/com/mongodb/hibernate/jdbc/BatchUpdateIntegrationTests.java @@ -0,0 +1,344 @@ +/* + * Copyright 2025-present MongoDB, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.mongodb.hibernate.jdbc; + +import static org.assertj.core.api.Assertions.assertThat; + +import com.mongodb.client.MongoClient; +import com.mongodb.client.MongoClients; +import com.mongodb.client.MongoCollection; +import com.mongodb.event.CommandFailedEvent; +import com.mongodb.event.CommandListener; +import com.mongodb.event.CommandStartedEvent; +import com.mongodb.hibernate.cfg.MongoConfigurator; +import com.mongodb.hibernate.internal.cfg.MongoConfigurationBuilder; +import com.mongodb.hibernate.service.spi.MongoConfigurationContributor; +import jakarta.persistence.Column; +import jakarta.persistence.Entity; +import jakarta.persistence.Id; +import jakarta.persistence.Table; +import java.io.Serial; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import org.bson.BsonDocument; +import org.bson.BsonString; +import org.hibernate.cfg.AvailableSettings; +import org.hibernate.testing.orm.junit.DomainModel; +import org.hibernate.testing.orm.junit.ServiceRegistry; +import org.hibernate.testing.orm.junit.SessionFactory; +import org.hibernate.testing.orm.junit.SessionFactoryScope; +import org.hibernate.testing.orm.junit.SessionFactoryScopeAware; +import org.hibernate.testing.orm.junit.Setting; +import org.junit.jupiter.api.AutoClose; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +@SessionFactory(exportSchema = false) +@DomainModel(annotatedClasses = {BatchUpdateIntegrationTests.Movie.class}) +@ServiceRegistry( + settings = @Setting(name = AvailableSettings.STATEMENT_BATCH_SIZE, value = "3"), + services = + @ServiceRegistry.Service( + role = MongoConfigurationContributor.class, + impl = BatchUpdateIntegrationTests.TestingMongoConfigurationContributor.class)) +class BatchUpdateIntegrationTests implements SessionFactoryScopeAware { + + private static class TestingCommandListener implements CommandListener { + private final List successfulCommands = new ArrayList<>(); + private final List failedCommandEvents = new ArrayList<>(0); + + @Override + public void commandStarted(CommandStartedEvent event) { + successfulCommands.add(event.getCommand().clone()); + } + + @Override + public void commandFailed(CommandFailedEvent event) { + failedCommandEvents.add(event); + } + + void clear() { + successfulCommands.clear(); + failedCommandEvents.clear(); + } + + List getSuccessfulCommands() { + return Collections.unmodifiableList(successfulCommands); + } + + List getFailedCommandEvents() { + return Collections.unmodifiableList(failedCommandEvents); + } + } + + private static final TestingCommandListener TESTING_COMMAND_LISTENER = new TestingCommandListener(); + + public static class TestingMongoConfigurationContributor implements MongoConfigurationContributor { + + @Serial + private static final long serialVersionUID = 1L; + + @Override + public void configure(MongoConfigurator configurator) { + configurator.applyToMongoClientSettings(builder -> builder.addCommandListener(TESTING_COMMAND_LISTENER)); + } + } + + @AutoClose + private MongoClient mongoClient; + + private MongoCollection collection; + + private SessionFactoryScope sessionFactoryScope; + + @Override + public void injectSessionFactoryScope(SessionFactoryScope sessionFactoryScope) { + this.sessionFactoryScope = sessionFactoryScope; + } + + @BeforeAll + void beforeAll() { + var config = new MongoConfigurationBuilder( + sessionFactoryScope.getSessionFactory().getProperties()) + .build(); + mongoClient = MongoClients.create(config.mongoClientSettings()); + collection = mongoClient.getDatabase(config.databaseName()).getCollection("movies", BsonDocument.class); + } + + @BeforeEach + void beforeEach() { + collection.drop(); + TESTING_COMMAND_LISTENER.clear(); + } + + @Test + void batchInsertTest() { + var movies = new ArrayList<>(); + for (var i = 1; i <= 8; i++) { + var movie = new Movie(); + movie.id = i; + movie.title = "title_" + i; + movies.add(movie); + } + sessionFactoryScope.inTransaction(session -> { + movies.forEach(session::persist); + }); + + assertThat(TESTING_COMMAND_LISTENER.getFailedCommandEvents()).isEmpty(); + assertThat(TESTING_COMMAND_LISTENER.getSuccessfulCommands()) + .satisfiesExactly( + command1 -> { + assertThat(command1.entrySet()).contains(Map.entry("insert", new BsonString("movies"))); + assertThat(command1.getArray("documents").getValues()) + .containsExactly( + BsonDocument.parse( + """ + {_id: 1, title: "title_1"} + """), + BsonDocument.parse( + """ + {_id: 2, title: "title_2"} + """), + BsonDocument.parse( + """ + {_id: 3, title: "title_3"} + """)); + }, + command2 -> { + assertThat(command2.entrySet()).contains(Map.entry("insert", new BsonString("movies"))); + assertThat(command2.getArray("documents").getValues()) + .containsExactly( + BsonDocument.parse( + """ + {_id: 4, title: "title_4"} + """), + BsonDocument.parse( + """ + {_id: 5, title: "title_5"} + """), + BsonDocument.parse( + """ + {_id: 6, title: "title_6"} + """)); + }, + command3 -> { + assertThat(command3.entrySet()).contains(Map.entry("insert", new BsonString("movies"))); + assertThat(command3.getArray("documents").getValues()) + .containsExactly( + BsonDocument.parse( + """ + {_id: 7, title: "title_7"} + """), + BsonDocument.parse( + """ + {_id: 8, title: "title_8"} + """)); + }, + command4 -> assertThat(command4.getFirstKey()).isEqualTo("commitTransaction")); + } + + @Test + void batchDeleteTest() { + var movies = new ArrayList<>(); + for (var i = 1; i <= 8; i++) { + var movie = new Movie(); + movie.id = i; + movie.title = "title_" + i; + movies.add(movie); + } + sessionFactoryScope.inTransaction(session -> { + movies.forEach(session::persist); + session.flush(); + TESTING_COMMAND_LISTENER.clear(); + movies.forEach(session::remove); + }); + + assertThat(TESTING_COMMAND_LISTENER.getFailedCommandEvents()).isEmpty(); + assertThat(TESTING_COMMAND_LISTENER.getSuccessfulCommands()) + .satisfiesExactly( + command1 -> { + assertThat(command1.entrySet()).contains(Map.entry("delete", new BsonString("movies"))); + assertThat(command1.getArray("deletes").getValues()) + .containsExactly( + BsonDocument.parse( + """ + {q: {_id: {$eq: 1}}, limit: 0} + """), + BsonDocument.parse( + """ + {q: {_id: {$eq: 2}}, limit: 0} + """), + BsonDocument.parse( + """ + {q: {_id: {$eq: 3}}, limit: 0} + """)); + }, + command2 -> { + assertThat(command2.entrySet()).contains(Map.entry("delete", new BsonString("movies"))); + assertThat(command2.getArray("deletes").getValues()) + .containsExactly( + BsonDocument.parse( + """ + {q: {_id: {$eq: 4}}, limit: 0} + """), + BsonDocument.parse( + """ + {q: {_id: {$eq: 5}}, limit: 0} + """), + BsonDocument.parse( + """ + {q: {_id: {$eq: 6}}, limit: 0} + """)); + }, + command3 -> { + assertThat(command3.entrySet()).contains(Map.entry("delete", new BsonString("movies"))); + assertThat(command3.getArray("deletes").getValues()) + .containsExactly( + BsonDocument.parse( + """ + {q: {_id: {$eq: 7}}, limit: 0} + """), + BsonDocument.parse( + """ + {q: {_id: {$eq: 8}}, limit: 0} + """)); + }, + command4 -> assertThat(command4.getFirstKey()).isEqualTo("commitTransaction")); + } + + @Test + void batchUpdateTest() { + var movies = new ArrayList(); + for (var i = 1; i <= 8; i++) { + var movie = new Movie(); + movie.id = i; + movie.title = "title_" + i; + movies.add(movie); + } + sessionFactoryScope.inTransaction(session -> { + movies.forEach(session::persist); + session.flush(); + TESTING_COMMAND_LISTENER.clear(); + movies.forEach(movie -> movie.title = movie.title + "_"); + }); + + assertThat(TESTING_COMMAND_LISTENER.getFailedCommandEvents()).isEmpty(); + assertThat(TESTING_COMMAND_LISTENER.getSuccessfulCommands()) + .satisfiesExactly( + command1 -> { + assertThat(command1.entrySet()).contains(Map.entry("update", new BsonString("movies"))); + assertThat(command1.getArray("updates").getValues()) + .containsExactly( + BsonDocument.parse( + """ + {q: {_id: {$eq: 1}}, u: {$set: {"title": "title_1_"}}, multi: true} + """), + BsonDocument.parse( + """ + {q: {_id: {$eq: 2}}, u: {$set: {"title": "title_2_"}}, multi: true} + """), + BsonDocument.parse( + """ + {q: {_id: {$eq: 3}}, u: {$set: {"title": "title_3_"}}, multi: true} + """)); + }, + command2 -> { + assertThat(command2.entrySet()).contains(Map.entry("update", new BsonString("movies"))); + assertThat(command2.getArray("updates").getValues()) + .containsExactly( + BsonDocument.parse( + """ + {q: {_id: {$eq: 4}}, u: {$set: {"title": "title_4_"}}, multi: true} + """), + BsonDocument.parse( + """ + {q: {_id: {$eq: 5}}, u: {$set: {"title": "title_5_"}}, multi: true} + """), + BsonDocument.parse( + """ + {q: {_id: {$eq: 6}}, u: {$set: {"title": "title_6_"}}, multi: true} + """)); + }, + command3 -> { + assertThat(command3.entrySet()).contains(Map.entry("update", new BsonString("movies"))); + assertThat(command3.getArray("updates").getValues()) + .containsExactly( + BsonDocument.parse( + """ + {q: {_id: {$eq: 7}}, u: {$set: {"title": "title_7_"}}, multi: true} + """), + BsonDocument.parse( + """ + {q: {_id: {$eq: 8}}, u: {$set: {"title": "title_8_"}}, multi: true} + """)); + }, + command4 -> assertThat(command4.getFirstKey()).isEqualTo("commitTransaction")); + } + + @Entity + @Table(name = "movies") + static class Movie { + @Id + @Column(name = "_id") + int id; + + String title; + } +} diff --git a/src/integrationTest/java/com/mongodb/hibernate/jdbc/MongoPreparedStatementIntegrationTests.java b/src/integrationTest/java/com/mongodb/hibernate/jdbc/MongoPreparedStatementIntegrationTests.java index 29dc22b2..dad43318 100644 --- a/src/integrationTest/java/com/mongodb/hibernate/jdbc/MongoPreparedStatementIntegrationTests.java +++ b/src/integrationTest/java/com/mongodb/hibernate/jdbc/MongoPreparedStatementIntegrationTests.java @@ -26,11 +26,13 @@ import java.sql.Connection; import java.sql.SQLException; import java.util.ArrayList; +import java.util.Collections; import java.util.List; import java.util.function.Function; import org.bson.BsonDocument; import org.hibernate.Session; import org.hibernate.SessionFactory; +import org.hibernate.cfg.AvailableSettings; import org.hibernate.cfg.Configuration; import org.junit.jupiter.api.AutoClose; import org.junit.jupiter.api.BeforeAll; @@ -41,6 +43,35 @@ class MongoPreparedStatementIntegrationTests { + private static final String INIT_INSERT_MQL = + """ + { + insert: "books", + documents: [ + { + _id: 1, + title: "War and Peace", + author: "Leo Tolstoy", + outOfStock: false, + tags: [ "classic", "tolstoy" ] + }, + { + _id: 2, + title: "Anna Karenina", + author: "Leo Tolstoy", + outOfStock: false, + tags: [ "classic", "tolstoy" ] + }, + { + _id: 3, + title: "Crime and Punishment", + author: "Fyodor Dostoevsky", + outOfStock: false, + tags: [ "classic", "dostoevsky", "literature" ] + } + ] + }"""; + @AutoClose private static SessionFactory sessionFactory; @@ -68,54 +99,11 @@ void beforeEach() { @Nested class ExecuteUpdateTests { - @BeforeEach - void beforeEach() { - session.doWork(conn -> { - conn.createStatement() - .executeUpdate( - """ - { - delete: "books", - deletes: [ - { q: {}, limit: 0 } - ] - }"""); - }); - } - - private static final String INSERT_MQL = - """ - { - insert: "books", - documents: [ - { - _id: 1, - title: "War and Peace", - author: "Leo Tolstoy", - outOfStock: false, - tags: [ "classic", "tolstoy" ] - }, - { - _id: 2, - title: "Anna Karenina", - author: "Leo Tolstoy", - outOfStock: false, - tags: [ "classic", "tolstoy" ] - }, - { - _id: 3, - title: "Crime and Punishment", - author: "Fyodor Dostoevsky", - outOfStock: false, - tags: [ "classic", "dostoevsky", "literature" ] - } - ] - }"""; - @ParameterizedTest @ValueSource(booleans = {true, false}) void testUpdate(boolean autoCommit) { + clearData(); prepareData(); var expectedDocs = List.of( @@ -177,14 +165,6 @@ void testUpdate(boolean autoCommit) { assertExecuteUpdate(pstmtProvider, autoCommit, 2, expectedDocs); } - private void prepareData() { - session.doWork(connection -> { - connection.setAutoCommit(true); - var statement = connection.createStatement(); - statement.executeUpdate(INSERT_MQL); - }); - } - private void assertExecuteUpdate( Function pstmtProvider, boolean autoCommit, @@ -207,4 +187,432 @@ private void assertExecuteUpdate( }); } } + + @Nested + class BatchTests { + private static final int BATCH_SIZE = 2; + + @AutoClose + private static SessionFactory batchableSessionFactory; + + @AutoClose + private Session batchableSession; + + @BeforeAll + static void beforeAll() { + batchableSessionFactory = new Configuration() + .setProperty(AvailableSettings.STATEMENT_BATCH_SIZE, BATCH_SIZE) + .buildSessionFactory(); + } + + @BeforeEach + void beforeEach() { + batchableSession = batchableSessionFactory.openSession(); + } + + @Nested + class InsertTests { + + private static final String MQL = + """ + { + insert: "books", + documents: [ + { + _id: { $undefined: true }, + title: { $undefined: true } + } + ] + }"""; + + @BeforeEach + void beforeEach() { + clearData(); + } + + @ParameterizedTest + @ValueSource(booleans = {true, false}) + void test(boolean autoCommit) { + batchableSession.doWork(connection -> { + connection.setAutoCommit(autoCommit); + try (var pstmt = connection.prepareStatement(MQL)) { + try { + pstmt.setInt(1, 1); + pstmt.setString(2, "War and Peace"); + pstmt.addBatch(); + + pstmt.setInt(1, 2); + pstmt.setString(2, "Anna Karenina"); + pstmt.addBatch(); + + pstmt.executeBatch(); + + pstmt.setInt(1, 3); + pstmt.setString(2, "Crime and Punishment"); + pstmt.addBatch(); + + pstmt.setInt(1, 4); + pstmt.setString(2, "Notes from Underground"); + pstmt.addBatch(); + + pstmt.executeBatch(); + + pstmt.setInt(1, 5); + pstmt.setString(2, "Fathers and Sons"); + + pstmt.addBatch(); + + pstmt.executeBatch(); + } finally { + if (!autoCommit) { + connection.commit(); + } + pstmt.clearBatch(); + } + + var expectedDocuments = List.of( + BsonDocument.parse( + """ + { + _id: 1, + title: "War and Peace" + }"""), + BsonDocument.parse( + """ + { + _id: 2, + title: "Anna Karenina" + }"""), + BsonDocument.parse( + """ + { + _id: 3, + title: "Crime and Punishment" + }"""), + BsonDocument.parse( + """ + { + _id: 4, + title: "Notes from Underground" + }"""), + BsonDocument.parse( + """ + { + _id: 5, + title: "Fathers and Sons" + }""")); + + var realDocuments = ((MongoPreparedStatement) pstmt) + .getMongoDatabase() + .getCollection("books", BsonDocument.class) + .find() + .sort(Sorts.ascending("_id")) + .into(new ArrayList<>()); + assertEquals(expectedDocuments, realDocuments); + } + }); + } + } + + @Nested + class UpdateTests { + private static final String UPDATE_ONE_MQL = + """ + { + update: "books", + updates: [ + { + q: { _id: { $eq: { $undefined: true } } }, + u: { $set: { title: { $undefined: true } } }, + multi: false + } + ] + }"""; + + private static final String UPDATE_MANY_MQL = + """ + { + update: "books", + updates: [ + { + q: { author: { $eq: { $undefined: true } } }, + u: { $push: { tags: { $undefined: true } } }, + multi: true + } + ] + }"""; + + @BeforeEach + void beforeEach() { + clearData(); + prepareData(); + } + + @ParameterizedTest + @ValueSource(booleans = {true, false}) + void testUpdateOne(boolean autoCommit) { + batchableSession.doWork(connection -> { + connection.setAutoCommit(autoCommit); + try (var pstmt = connection.prepareStatement(UPDATE_ONE_MQL)) { + try { + pstmt.setInt(1, 1); + pstmt.setString(2, "Insurrection"); + pstmt.addBatch(); + + pstmt.setInt(1, 2); + pstmt.setString(2, "Hadji Murat"); + pstmt.addBatch(); + + pstmt.executeBatch(); + + pstmt.setInt(1, 3); + pstmt.setString(2, "The Brothers Karamazov"); + pstmt.addBatch(); + + pstmt.executeBatch(); + + } finally { + if (!autoCommit) { + connection.commit(); + } + pstmt.clearBatch(); + } + + var expectedDocuments = List.of( + BsonDocument.parse( + """ + { + _id: 1, + title: "Insurrection", + author: "Leo Tolstoy", + outOfStock: false, + tags: [ "classic", "tolstoy" ] + }"""), + BsonDocument.parse( + """ + { + _id: 2, + title: "Hadji Murat", + author: "Leo Tolstoy", + outOfStock: false, + tags: [ "classic", "tolstoy" ] + }"""), + BsonDocument.parse( + """ + { + _id: 3, + title: "The Brothers Karamazov", + author: "Fyodor Dostoevsky", + outOfStock: false, + tags: [ "classic", "dostoevsky", "literature" ] + }""")); + + var realDocuments = ((MongoPreparedStatement) pstmt) + .getMongoDatabase() + .getCollection("books", BsonDocument.class) + .find() + .sort(Sorts.ascending("_id")) + .into(new ArrayList<>()); + assertEquals(expectedDocuments, realDocuments); + } + }); + } + + @ParameterizedTest + @ValueSource(booleans = {true, false}) + void testUpdateMany(boolean autoCommit) { + batchableSession.doWork(connection -> { + connection.setAutoCommit(autoCommit); + try (var pstmt = connection.prepareStatement(UPDATE_MANY_MQL)) { + try { + pstmt.setString(1, "Leo Tolstoy"); + pstmt.setString(2, "russian"); + pstmt.addBatch(); + + pstmt.setString(1, "Fyodor Dostoevsky"); + pstmt.setString(2, "russian"); + pstmt.addBatch(); + + pstmt.executeBatch(); + + pstmt.setString(1, "Leo Tolstoy"); + pstmt.setString(2, "literature"); + pstmt.addBatch(); + + pstmt.executeBatch(); + + } finally { + if (!autoCommit) { + connection.commit(); + } + pstmt.clearBatch(); + } + + var expectedDocuments = List.of( + BsonDocument.parse( + """ + { + _id: 1, + title: "War and Peace", + author: "Leo Tolstoy", + outOfStock: false, + tags: [ "classic", "tolstoy", "russian", "literature" ] + }"""), + BsonDocument.parse( + """ + { + _id: 2, + title: "Anna Karenina", + author: "Leo Tolstoy", + outOfStock: false, + tags: [ "classic", "tolstoy", "russian", "literature" ] + }"""), + BsonDocument.parse( + """ + { + _id: 3, + title: "Crime and Punishment", + author: "Fyodor Dostoevsky", + outOfStock: false, + tags: [ "classic", "dostoevsky", "literature", "russian" ] + }""")); + + var realDocuments = ((MongoPreparedStatement) pstmt) + .getMongoDatabase() + .getCollection("books", BsonDocument.class) + .find() + .sort(Sorts.ascending("_id")) + .into(new ArrayList<>()); + assertEquals(expectedDocuments, realDocuments); + } + }); + } + } + + @Nested + class DeleteTests { + private static final String DELETE_ONE_MQL = + """ + { + delete: "books", + deletes: [ + { + q: { _id: { $eq: { $undefined: true } } }, + limit: 1 + } + ] + }"""; + + private static final String DELETE_MANY_MQL = + """ + { + delete: "books", + deletes: [ + { + q: { author: { $eq: { $undefined: true } } }, + limit: 0 + } + ] + }"""; + + @BeforeEach + void beforeEach() { + clearData(); + prepareData(); + } + + @ParameterizedTest + @ValueSource(booleans = {true, false}) + void testDeleteOne(boolean autoCommit) { + batchableSession.doWork(connection -> { + connection.setAutoCommit(autoCommit); + try (var pstmt = connection.prepareStatement(DELETE_ONE_MQL)) { + try { + pstmt.setInt(1, 1); + pstmt.addBatch(); + + pstmt.setInt(1, 2); + pstmt.addBatch(); + + pstmt.executeBatch(); + + pstmt.setInt(1, 3); + pstmt.addBatch(); + + pstmt.executeBatch(); + + } finally { + if (!autoCommit) { + connection.commit(); + } + pstmt.clearBatch(); + } + + var realDocuments = ((MongoPreparedStatement) pstmt) + .getMongoDatabase() + .getCollection("books", BsonDocument.class) + .find() + .sort(Sorts.ascending("_id")) + .into(new ArrayList<>()); + assertEquals(Collections.emptyList(), realDocuments); + } + }); + } + + @ParameterizedTest + @ValueSource(booleans = {true, false}) + void testDeleteMany(boolean autoCommit) { + batchableSession.doWork(connection -> { + connection.setAutoCommit(autoCommit); + try (var pstmt = connection.prepareStatement(DELETE_MANY_MQL)) { + try { + pstmt.setString(1, "Leo Tolstoy"); + pstmt.addBatch(); + + pstmt.setString(1, "Fyodor Dostoevsky"); + pstmt.addBatch(); + + pstmt.executeBatch(); + + } finally { + if (!autoCommit) { + connection.commit(); + } + pstmt.clearBatch(); + } + + var realDocuments = ((MongoPreparedStatement) pstmt) + .getMongoDatabase() + .getCollection("books", BsonDocument.class) + .find() + .sort(Sorts.ascending("_id")) + .into(new ArrayList<>()); + assertEquals(Collections.emptyList(), realDocuments); + } + }); + } + } + } + + private void prepareData() { + session.doWork(connection -> { + connection.setAutoCommit(true); + var statement = connection.createStatement(); + statement.executeUpdate(INIT_INSERT_MQL); + }); + } + + private void clearData() { + session.doWork(conn -> { + conn.createStatement() + .executeUpdate( + """ + { + delete: "books", + deletes: [ + { q: {}, limit: 0 } + ] + }"""); + }); + } } diff --git a/src/main/java/com/mongodb/hibernate/jdbc/MongoPreparedStatement.java b/src/main/java/com/mongodb/hibernate/jdbc/MongoPreparedStatement.java index 3f0cd011..14f41fa2 100644 --- a/src/main/java/com/mongodb/hibernate/jdbc/MongoPreparedStatement.java +++ b/src/main/java/com/mongodb/hibernate/jdbc/MongoPreparedStatement.java @@ -16,11 +16,19 @@ package com.mongodb.hibernate.jdbc; +import static com.mongodb.hibernate.internal.MongoAssertions.assertNotNull; +import static com.mongodb.hibernate.internal.MongoAssertions.assertTrue; import static com.mongodb.hibernate.internal.MongoAssertions.fail; import static java.lang.String.format; import com.mongodb.client.ClientSession; import com.mongodb.client.MongoDatabase; +import com.mongodb.client.model.DeleteManyModel; +import com.mongodb.client.model.DeleteOneModel; +import com.mongodb.client.model.InsertOneModel; +import com.mongodb.client.model.UpdateManyModel; +import com.mongodb.client.model.UpdateOneModel; +import com.mongodb.client.model.WriteModel; import com.mongodb.hibernate.internal.FeatureNotSupportedException; import java.io.InputStream; import java.math.BigDecimal; @@ -31,10 +39,12 @@ import java.sql.SQLException; import java.sql.SQLFeatureNotSupportedException; import java.sql.SQLSyntaxErrorException; +import java.sql.Statement; import java.sql.Time; import java.sql.Timestamp; import java.sql.Types; import java.util.ArrayList; +import java.util.Arrays; import java.util.Calendar; import java.util.List; import java.util.function.Consumer; @@ -49,12 +59,16 @@ import org.bson.BsonNull; import org.bson.BsonString; import org.bson.BsonType; +import org.bson.BsonUndefined; import org.bson.BsonValue; import org.bson.types.Decimal128; final class MongoPreparedStatement extends MongoStatement implements PreparedStatementAdapter { + private static final BsonUndefined PARAMETER_PLACEHOLDER = new BsonUndefined(); + private final BsonDocument command; + private final List commandBatch; private final List> parameterValueSetters; @@ -62,8 +76,9 @@ final class MongoPreparedStatement extends MongoStatement implements PreparedSta MongoDatabase mongoDatabase, ClientSession clientSession, MongoConnection mongoConnection, String mql) throws SQLSyntaxErrorException { super(mongoDatabase, clientSession, mongoConnection); - this.command = MongoStatement.parse(mql); - this.parameterValueSetters = new ArrayList<>(); + command = MongoStatement.parse(mql); + commandBatch = new ArrayList<>(); + parameterValueSetters = new ArrayList<>(); parseParameters(command, parameterValueSetters); } @@ -205,20 +220,14 @@ public void setBinaryStream(int parameterIndex, InputStream x, int length) throw public void setObject(int parameterIndex, Object x, int targetSqlType) throws SQLException { checkClosed(); checkParameterIndex(parameterIndex); - throw new FeatureNotSupportedException("To be implemented during Array / Struct tickets"); + throw new FeatureNotSupportedException("To be implemented in scope of Array / Struct tickets"); } @Override public void setObject(int parameterIndex, Object x) throws SQLException { checkClosed(); checkParameterIndex(parameterIndex); - throw new FeatureNotSupportedException("To be implemented during Array / Struct tickets"); - } - - @Override - public void addBatch() throws SQLException { - checkClosed(); - throw new FeatureNotSupportedException("TODO-HIBERNATE-35 https://jira.mongodb.org/browse/HIBERNATE-35"); + throw new FeatureNotSupportedException("To be implemented in scope of Array / Struct tickets"); } @Override @@ -253,7 +262,97 @@ public void setTimestamp(int parameterIndex, Timestamp x, Calendar cal) throws S public void setNull(int parameterIndex, int sqlType, String typeName) throws SQLException { checkClosed(); checkParameterIndex(parameterIndex); - throw new FeatureNotSupportedException("To be implemented during Array / Struct tickets"); + throw new FeatureNotSupportedException("To be implemented in scope of Array / Struct tickets"); + } + + @Override + public void addBatch() throws SQLException { + checkClosed(); + commandBatch.add(command.clone()); + parameterValueSetters.forEach(setter -> setter.accept(PARAMETER_PLACEHOLDER)); + } + + @Override + public void clearBatch() throws SQLException { + checkClosed(); + commandBatch.clear(); + } + + @Override + public int[] executeBatch() throws SQLException { + checkClosed(); + startTransactionIfNeeded(); + + if (commandBatch.isEmpty()) { + return new int[0]; + } + + try { + var writeModels = new ArrayList>(commandBatch.size()); + + // Hibernate will group PreparedStatement by both table and mutation type + var commandName = assertNotNull(commandBatch.get(0).getFirstKey()); + var collectionName = + assertNotNull(commandBatch.get(0).getString(commandName).getValue()); + + for (var command : commandBatch) { + + assertTrue(commandName.equals(command.getFirstKey())); + assertTrue(collectionName.equals(command.getString(commandName).getValue())); + + switch (commandName) { + case "insert": + var documents = command.getArray("documents"); + for (var document : documents) { + writeModels.add(new InsertOneModel<>((BsonDocument) document)); + } + break; + case "update": + var updates = command.getArray("updates").getValues(); + for (var update : updates) { + var updateDocument = (BsonDocument) update; + WriteModel updateModel = + !updateDocument.getBoolean("multi").getValue() + ? new UpdateOneModel<>( + updateDocument.getDocument("q"), updateDocument.getDocument("u")) + : new UpdateManyModel<>( + updateDocument.getDocument("q"), updateDocument.getDocument("u")); + writeModels.add(updateModel); + } + break; + case "delete": + var deletes = command.getArray("deletes"); + for (var delete : deletes) { + var deleteDocument = (BsonDocument) delete; + writeModels.add( + deleteDocument.getNumber("limit").intValue() == 1 + ? new DeleteOneModel<>(deleteDocument.getDocument("q")) + : new DeleteManyModel<>(deleteDocument.getDocument("q"))); + } + break; + default: + throw new FeatureNotSupportedException(); + } + } + getMongoDatabase() + .getCollection(collectionName, BsonDocument.class) + .bulkWrite(getClientSession(), writeModels); + + // TODO-HIBERNATE-43 log bulk write result in debug level + + var rowCounts = new int[commandBatch.size()]; + + // MongoDB bulk write API returns row counts grouped by mutation types, not by each command in the batch, + // so returns 'SUCCESS_NO_INFO' to work around + Arrays.fill(rowCounts, Statement.SUCCESS_NO_INFO); + + return rowCounts; + + } catch (RuntimeException e) { + throw new SQLException("Failed to execute batch: " + e.getMessage(), e); + } finally { + clearBatch(); + } } private void setParameter(int parameterIndex, BsonValue parameterValue) { diff --git a/src/main/java/com/mongodb/hibernate/jdbc/MongoStatement.java b/src/main/java/com/mongodb/hibernate/jdbc/MongoStatement.java index b88b983d..fe068053 100644 --- a/src/main/java/com/mongodb/hibernate/jdbc/MongoStatement.java +++ b/src/main/java/com/mongodb/hibernate/jdbc/MongoStatement.java @@ -43,6 +43,14 @@ class MongoStatement implements StatementAdapter { this.clientSession = clientSession; } + MongoDatabase getMongoDatabase() { + return mongoDatabase; + } + + ClientSession getClientSession() { + return clientSession; + } + @Override public ResultSet executeQuery(String mql) throws SQLException { checkClosed(); @@ -152,24 +160,6 @@ public int getFetchSize() throws SQLException { throw new FeatureNotSupportedException("TODO-HIBERNATE-21 https://jira.mongodb.org/browse/HIBERNATE-21"); } - @Override - public void addBatch(String mql) throws SQLException { - checkClosed(); - throw new FeatureNotSupportedException("TODO-HIBERNATE-35 https://jira.mongodb.org/browse/HIBERNATE-35"); - } - - @Override - public void clearBatch() throws SQLException { - checkClosed(); - throw new FeatureNotSupportedException("TODO-HIBERNATE-35 https://jira.mongodb.org/browse/HIBERNATE-35"); - } - - @Override - public int[] executeBatch() throws SQLException { - checkClosed(); - throw new FeatureNotSupportedException("TODO-HIBERNATE-35 https://jira.mongodb.org/browse/HIBERNATE-35"); - } - @Override public Connection getConnection() throws SQLException { checkClosed(); @@ -205,7 +195,7 @@ static BsonDocument parse(String mql) throws SQLSyntaxErrorException { * Starts transaction for the first {@link java.sql.Statement} executing if * {@linkplain MongoConnection#getAutoCommit() auto-commit} is disabled. */ - private void startTransactionIfNeeded() throws SQLException { + void startTransactionIfNeeded() throws SQLException { if (!mongoConnection.getAutoCommit() && !clientSession.hasActiveTransaction()) { clientSession.startTransaction(); } diff --git a/src/test/java/com/mongodb/hibernate/jdbc/MongoStatementTests.java b/src/test/java/com/mongodb/hibernate/jdbc/MongoStatementTests.java index 554f7541..589d6853 100644 --- a/src/test/java/com/mongodb/hibernate/jdbc/MongoStatementTests.java +++ b/src/test/java/com/mongodb/hibernate/jdbc/MongoStatementTests.java @@ -157,9 +157,6 @@ private static Stream getMongoStatementMethodInvocationsImpactedByClo Map.entry("getUpdateCount()", MongoStatement::getUpdateCount), Map.entry("setFetchSize(int)", stmt -> stmt.setFetchSize(1)), Map.entry("getFetchSize()", MongoStatement::getFetchSize), - Map.entry("addBatch(String)", stmt -> stmt.addBatch(exampleUpdateMql)), - Map.entry("clearBatch()", MongoStatement::clearBatch), - Map.entry("executeBatch()", MongoStatement::executeBatch), Map.entry("getConnection()", MongoStatement::getConnection), Map.entry("isWrapperFor(Class)", stmt -> stmt.isWrapperFor(Statement.class))) .entrySet()