diff --git a/src/integrationTest/java/com/mongodb/hibernate/jdbc/MongoPreparedStatementIntegrationTests.java b/src/integrationTest/java/com/mongodb/hibernate/jdbc/MongoPreparedStatementIntegrationTests.java index c4b37c9f..84148316 100644 --- a/src/integrationTest/java/com/mongodb/hibernate/jdbc/MongoPreparedStatementIntegrationTests.java +++ b/src/integrationTest/java/com/mongodb/hibernate/jdbc/MongoPreparedStatementIntegrationTests.java @@ -17,28 +17,39 @@ package com.mongodb.hibernate.jdbc; import static com.mongodb.hibernate.internal.MongoConstants.ID_FIELD_NAME; +import static com.mongodb.hibernate.jdbc.MongoStatementIntegrationTests.doAndTerminateTransaction; +import static com.mongodb.hibernate.jdbc.MongoStatementIntegrationTests.doWithSpecifiedAutoCommit; +import static com.mongodb.hibernate.jdbc.MongoStatementIntegrationTests.insertTestData; import static org.assertj.core.api.Assertions.assertThat; +import static org.junit.jupiter.api.Assertions.assertAll; +import static org.junit.jupiter.api.Assertions.assertArrayEquals; import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertTrue; import com.mongodb.client.MongoCollection; import com.mongodb.client.model.Sorts; +import com.mongodb.hibernate.jdbc.MongoStatementIntegrationTests.SqlExecutable; import com.mongodb.hibernate.junit.InjectMongoCollection; import com.mongodb.hibernate.junit.MongoExtension; +import java.math.BigDecimal; import java.sql.Connection; import java.sql.SQLException; import java.util.List; +import java.util.Random; import java.util.function.Function; import org.bson.BsonDocument; import org.hibernate.Session; import org.hibernate.SessionFactory; import org.hibernate.cfg.Configuration; +import org.hibernate.jdbc.Work; import org.junit.jupiter.api.AutoClose; import org.junit.jupiter.api.BeforeAll; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Nested; +import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; -import org.junit.jupiter.params.ParameterizedTest; -import org.junit.jupiter.params.provider.ValueSource; @ExtendWith(MongoExtension.class) class MongoPreparedStatementIntegrationTests { @@ -62,23 +73,145 @@ void beforeEach() { session = sessionFactory.openSession(); } - @Nested - class ExecuteUpdateTests { + @Test + void testExecuteQuery() { + + insertTestData( + session, + """ + { + insert: "books", + documents: [ + { _id: 1, publishYear: 1867, title: "War and Peace", author: "Leo Tolstoy", comment: "reference only", vintage: false }, + { _id: 2, publishYear: 1878, title: "Anna Karenina", author: "Leo Tolstoy", vintage: true, comment: null}, + { _id: 3, publishYear: 1866, author: "Fyodor Dostoevsky", title: "Crime and Punishment", vintage: false, comment: null }, + ] + }"""); + + doWorkAwareOfAutoCommit(connection -> { + try (var pstmt = connection.prepareStatement( + """ + { + aggregate: "books", + pipeline: [ + { $match: { author: { $eq: { $undefined: true } } } }, + { $project: { author: 1, _id: 0, vintage: 1, publishYear: 1, comment: 1, title: 1 } } + ] + }""")) { + + pstmt.setString(1, "Leo Tolstoy"); + + try (var rs = pstmt.executeQuery()) { + + assertTrue(rs.next()); + assertAll( + () -> assertEquals("Leo Tolstoy", rs.getString(1)), + () -> assertFalse(rs.getBoolean(2)), + () -> assertEquals(1867, rs.getInt(3)), + () -> assertEquals("reference only", rs.getString(4)), + () -> assertEquals("War and Peace", rs.getString(5))); + assertTrue(rs.next()); + assertAll( + () -> assertEquals("Leo Tolstoy", rs.getString(1)), + () -> assertTrue(rs.getBoolean(2)), + () -> assertEquals(1878, rs.getInt(3)), + () -> assertNull(rs.getString(4)), + () -> assertEquals("Anna Karenina", rs.getString(5))); + assertFalse(rs.next()); + } + } + }); + } + + @Test + void testPreparedStatementAndResultSetRoundTrip() { + + var random = new Random(); + + boolean booleanValue = random.nextBoolean(); + double doubleValue = random.nextDouble(); + int intValue = random.nextInt(); + long longValue = random.nextLong(); + + byte[] bytes = new byte[64]; + random.nextBytes(bytes); + String stringValue = new String(bytes); + + BigDecimal bigDecimalValue = new BigDecimal(random.nextInt()); + + doWorkAwareOfAutoCommit(connection -> { + try (var pstmt = connection.prepareStatement( + """ + { + insert: "books", + documents: [ + { + _id: 1, + booleanField: { $undefined: true }, + doubleField: { $undefined: true }, + intField: { $undefined: true }, + longField: { $undefined: true }, + stringField: { $undefined: true }, + bigDecimalField: { $undefined: true }, + bytesField: { $undefined: true } + } + ] + }""")) { + + pstmt.setBoolean(1, booleanValue); + pstmt.setDouble(2, doubleValue); + pstmt.setInt(3, intValue); + pstmt.setLong(4, longValue); + pstmt.setString(5, stringValue); + pstmt.setBigDecimal(6, bigDecimalValue); + pstmt.setBytes(7, bytes); + + pstmt.executeUpdate(); + } + }); - @BeforeEach - void beforeEach() { - session.doWork(conn -> { - conn.createStatement() - .executeUpdate( - """ + doWorkAwareOfAutoCommit(connection -> { + try (var pstmt = connection.prepareStatement( + """ + { + aggregate: "books", + pipeline: [ + { $match: { _id: { $eq: { $undefined: true } } } }, + { $project: { - delete: "books", - deletes: [ - { q: {}, limit: 0 } - ] - }"""); - }); - } + _id: 0, + booleanField: 1, + doubleField: 1, + intField: 1, + longField: 1, + stringField: 1, + bigDecimalField: 1, + bytesField: 1 + } + } + ] + }""")) { + + pstmt.setInt(1, 1); + try (var rs = pstmt.executeQuery()) { + + assertTrue(rs.next()); + assertAll( + () -> assertEquals(booleanValue, rs.getBoolean(1)), + () -> assertEquals(doubleValue, rs.getDouble(2)), + () -> assertEquals(intValue, rs.getInt(3)), + () -> assertEquals(longValue, rs.getLong(4)), + () -> assertEquals(stringValue, rs.getString(5)), + () -> assertEquals(bigDecimalValue, rs.getBigDecimal(6)), + () -> assertArrayEquals(bytes, rs.getBytes(7))); + assertFalse(rs.next()); + } + } + }); + } + + @Nested + class ExecuteUpdateTests { private static final String INSERT_MQL = """ @@ -109,11 +242,51 @@ void beforeEach() { ] }"""; - @ParameterizedTest - @ValueSource(booleans = {true, false}) - void testUpdate(boolean autoCommit) { + @Test + void testInsert() { + Function pstmtProvider = connection -> { + try { + var pstmt = (MongoPreparedStatement) + connection.prepareStatement( + """ + { + insert: "books", + documents: [ + { + _id: 1, + title: {$undefined: true}, + author: {$undefined: true}, + outOfStock: false, + tags: [ "classic", "tolstoy" ] + } + ] + }"""); + pstmt.setString(1, "War and Peace"); + pstmt.setString(2, "Leo Tolstoy"); + return pstmt; + } catch (SQLException e) { + throw new RuntimeException(e); + } + }; + assertExecuteUpdate( + pstmtProvider, + 1, + List.of( + BsonDocument.parse( + """ + { + _id: 1, + title: "War and Peace", + author: "Leo Tolstoy", + outOfStock: false, + tags: [ "classic", "tolstoy" ] + }"""))); + } + + @Test + void testUpdate() { - prepareData(); + insertTestData(session, INSERT_MQL); var expectedDocs = List.of( BsonDocument.parse( @@ -171,36 +344,67 @@ void testUpdate(boolean autoCommit) { throw new RuntimeException(e); } }; - assertExecuteUpdate(pstmtProvider, autoCommit, 2, expectedDocs); + assertExecuteUpdate(pstmtProvider, 2, expectedDocs); } - private void prepareData() { - session.doWork(connection -> { - connection.setAutoCommit(true); - var statement = connection.createStatement(); - statement.executeUpdate(INSERT_MQL); - }); + @Test + void testDelete() { + insertTestData(session, INSERT_MQL); + + Function pstmtProvider = connection -> { + try { + var pstmt = (MongoPreparedStatement) + connection.prepareStatement( + """ + { + delete: "books", + deletes: [ + { + q: { author: { $undefined: true } }, + limit: 0 + } + ] + }"""); + pstmt.setString(1, "Leo Tolstoy"); + return pstmt; + } catch (SQLException e) { + throw new RuntimeException(e); + } + }; + assertExecuteUpdate( + pstmtProvider, + 2, + List.of( + BsonDocument.parse( + """ + { + _id: 3, + title: "Crime and Punishment", + author: "Fyodor Dostoevsky", + outOfStock: false, + tags: [ "classic", "dostoevsky", "literature" ] + }"""))); } private void assertExecuteUpdate( Function pstmtProvider, - boolean autoCommit, int expectedUpdatedRowCount, List expectedDocuments) { - session.doWork(connection -> { - connection.setAutoCommit(autoCommit); + doWorkAwareOfAutoCommit(connection -> { try (var pstmt = pstmtProvider.apply(connection)) { - try { - assertEquals(expectedUpdatedRowCount, pstmt.executeUpdate()); - } finally { - if (!autoCommit) { - connection.commit(); - } - } - assertThat(mongoCollection.find().sort(Sorts.ascending(ID_FIELD_NAME))) - .containsExactlyElementsOf(expectedDocuments); + assertEquals(expectedUpdatedRowCount, pstmt.executeUpdate()); } }); + assertThat(mongoCollection.find().sort(Sorts.ascending(ID_FIELD_NAME))) + .containsExactlyElementsOf(expectedDocuments); } } + + private void doWorkAwareOfAutoCommit(Work work) { + session.doWork(connection -> doAwareOfAutoCommit(connection, () -> work.execute(connection))); + } + + void doAwareOfAutoCommit(Connection connection, SqlExecutable work) throws SQLException { + doWithSpecifiedAutoCommit(false, connection, () -> doAndTerminateTransaction(connection, work)); + } } diff --git a/src/integrationTest/java/com/mongodb/hibernate/jdbc/MongoPreparedStatementWithAutoCommitIntegrationTests.java b/src/integrationTest/java/com/mongodb/hibernate/jdbc/MongoPreparedStatementWithAutoCommitIntegrationTests.java new file mode 100644 index 00000000..a636b015 --- /dev/null +++ b/src/integrationTest/java/com/mongodb/hibernate/jdbc/MongoPreparedStatementWithAutoCommitIntegrationTests.java @@ -0,0 +1,30 @@ +/* + * Copyright 2024-present MongoDB, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.mongodb.hibernate.jdbc; + +import static com.mongodb.hibernate.jdbc.MongoStatementIntegrationTests.doWithSpecifiedAutoCommit; + +import com.mongodb.hibernate.jdbc.MongoStatementIntegrationTests.SqlExecutable; +import java.sql.Connection; +import java.sql.SQLException; + +class MongoPreparedStatementWithAutoCommitIntegrationTests extends MongoPreparedStatementIntegrationTests { + @Override + void doAwareOfAutoCommit(Connection connection, SqlExecutable work) throws SQLException { + doWithSpecifiedAutoCommit(true, connection, work); + } +} diff --git a/src/integrationTest/java/com/mongodb/hibernate/jdbc/MongoStatementIntegrationTests.java b/src/integrationTest/java/com/mongodb/hibernate/jdbc/MongoStatementIntegrationTests.java index 8b5e5931..fc220a61 100644 --- a/src/integrationTest/java/com/mongodb/hibernate/jdbc/MongoStatementIntegrationTests.java +++ b/src/integrationTest/java/com/mongodb/hibernate/jdbc/MongoStatementIntegrationTests.java @@ -18,24 +18,29 @@ import static com.mongodb.hibernate.internal.MongoConstants.ID_FIELD_NAME; import static org.assertj.core.api.Assertions.assertThat; +import static org.junit.jupiter.api.Assertions.assertAll; import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertTrue; import com.mongodb.client.MongoCollection; import com.mongodb.client.model.Sorts; import com.mongodb.hibernate.junit.InjectMongoCollection; import com.mongodb.hibernate.junit.MongoExtension; +import java.sql.Connection; +import java.sql.SQLException; import java.util.List; import org.bson.BsonDocument; import org.hibernate.Session; import org.hibernate.SessionFactory; import org.hibernate.cfg.Configuration; +import org.hibernate.jdbc.Work; import org.junit.jupiter.api.AutoClose; import org.junit.jupiter.api.BeforeAll; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Nested; +import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; -import org.junit.jupiter.params.ParameterizedTest; -import org.junit.jupiter.params.provider.ValueSource; @ExtendWith(MongoExtension.class) class MongoStatementIntegrationTests { @@ -59,24 +64,53 @@ void beforeEach() { session = sessionFactory.openSession(); } + @Test + void testExecuteQuery() { + + insertTestData( + session, + """ + { + insert: "books", + documents: [ + { _id: 1, publishYear: 1867, title: "War and Peace", author: "Leo Tolstoy" }, + { _id: 2, publishYear: 1878, author: "Leo Tolstoy", title: "Anna Karenina" }, + { _id: 3, publishYear: 1866, title: "Crime and Punishment", author: "Fyodor Dostoevsky" } + ] + }"""); + + doWorkAwareOfAutoCommit(connection -> { + try (var stmt = connection.createStatement()) { + try (var rs = stmt.executeQuery( + """ + { + aggregate: "books", + pipeline: [ + { $match: { author: { $eq: "Leo Tolstoy" } } }, + { $project: { author: 1, _id: 0, publishYear: 1, title: 1 } } + ] + }""")) { + assertTrue(rs.next()); + assertAll( + () -> assertEquals("Leo Tolstoy", rs.getString(1)), + () -> assertEquals(1867, rs.getInt(2)), + () -> assertEquals("War and Peace", rs.getString(3))); + + assertTrue(rs.next()); + assertAll( + () -> assertEquals("Leo Tolstoy", rs.getString(1)), + () -> assertEquals(1878, rs.getInt(2)), + () -> assertEquals("Anna Karenina", rs.getString(3))); + + assertFalse(rs.next()); + } + } + }); + } + @Nested class ExecuteUpdateTests { - @BeforeEach - void beforeEach() { - session.doWork(conn -> { - conn.createStatement() - .executeUpdate( - """ - { - delete: "books", - deletes: [ - { q: {}, limit: 0 } - ] - }"""); - }); - } - private static final String INSERT_MQL = """ { @@ -103,9 +137,8 @@ void beforeEach() { ] }"""; - @ParameterizedTest - @ValueSource(booleans = {true, false}) - void testInsert(boolean autoCommit) { + @Test + void testInsert() { var expectedDocs = List.of( BsonDocument.parse( """ @@ -131,14 +164,13 @@ void testInsert(boolean autoCommit) { author: "Fyodor Dostoevsky", outOfStock: false }""")); - assertExecuteUpdate(INSERT_MQL, autoCommit, 3, expectedDocs); + assertExecuteUpdate(INSERT_MQL, 3, expectedDocs); } - @ParameterizedTest - @ValueSource(booleans = {true, false}) - void testUpdate(boolean autoCommit) { + @Test + void testUpdate() { - prepareData(); + insertTestData(session, INSERT_MQL); var updateMql = """ @@ -179,14 +211,13 @@ void testUpdate(boolean autoCommit) { author: "Fyodor Dostoevsky", outOfStock: false }""")); - assertExecuteUpdate(updateMql, autoCommit, 2, expectedDocs); + assertExecuteUpdate(updateMql, 2, expectedDocs); } - @ParameterizedTest - @ValueSource(booleans = {true, false}) - void testDelete(boolean autoCommit) { + @Test + void testDelete() { - prepareData(); + insertTestData(session, INSERT_MQL); var deleteMql = """ @@ -216,33 +247,71 @@ void testDelete(boolean autoCommit) { author: "Fyodor Dostoevsky", outOfStock: false }""")); - assertExecuteUpdate(deleteMql, autoCommit, 1, expectedDocs); + assertExecuteUpdate(deleteMql, 1, expectedDocs); } - private void prepareData() { - session.doWork(connection -> { - connection.setAutoCommit(true); - var statement = connection.createStatement(); - statement.executeUpdate(INSERT_MQL); + private void assertExecuteUpdate( + String mql, int expectedRowCount, List expectedDocuments) { + doWorkAwareOfAutoCommit(connection -> { + try (var stmt = (MongoStatement) connection.createStatement()) { + assertEquals(expectedRowCount, stmt.executeUpdate(mql)); + } }); + assertThat(mongoCollection.find().sort(Sorts.ascending(ID_FIELD_NAME))) + .containsExactlyElementsOf(expectedDocuments); } + } - private void assertExecuteUpdate( - String mql, boolean autoCommit, int expectedRowCount, List expectedDocuments) { - session.doWork(connection -> { - connection.setAutoCommit(autoCommit); - try (var stmt = (MongoStatement) connection.createStatement()) { - try { - assertEquals(expectedRowCount, stmt.executeUpdate(mql)); - } finally { - if (!autoCommit) { - connection.commit(); - } + static void insertTestData(Session session, String insertMql) { + session.doWork(connection -> doWithSpecifiedAutoCommit( + false, + connection, + () -> doAndTerminateTransaction(connection, () -> { + try (var statement = connection.createStatement()) { + statement.executeUpdate(insertMql); } - assertThat(mongoCollection.find().sort(Sorts.ascending(ID_FIELD_NAME))) - .containsExactlyElementsOf(expectedDocuments); + }))); + } + + private void doWorkAwareOfAutoCommit(Work work) { + session.doWork(connection -> doAwareOfAutoCommit(connection, () -> work.execute(connection))); + } + + void doAwareOfAutoCommit(Connection connection, SqlExecutable work) throws SQLException { + doWithSpecifiedAutoCommit(false, connection, () -> doAndTerminateTransaction(connection, work)); + } + + static void doWithSpecifiedAutoCommit(boolean autoCommit, Connection connection, SqlExecutable work) + throws SQLException { + var originalAutoCommit = connection.getAutoCommit(); + connection.setAutoCommit(autoCommit); + try { + work.execute(); + } finally { + connection.setAutoCommit(originalAutoCommit); + } + } + + static void doAndTerminateTransaction(Connection connectionNoAutoCommit, SqlExecutable work) throws SQLException { + Throwable primaryException = null; + try { + work.execute(); + connectionNoAutoCommit.commit(); + } catch (Throwable e) { + primaryException = e; + throw e; + } finally { + if (primaryException != null) { + try { + connectionNoAutoCommit.rollback(); + } catch (Throwable suppressedException) { + primaryException.addSuppressed(suppressedException); } - }); + } } } + + interface SqlExecutable { + void execute() throws SQLException; + } } diff --git a/src/integrationTest/java/com/mongodb/hibernate/jdbc/MongoStatementWithAutoCommitIntegrationTests.java b/src/integrationTest/java/com/mongodb/hibernate/jdbc/MongoStatementWithAutoCommitIntegrationTests.java new file mode 100644 index 00000000..10bae994 --- /dev/null +++ b/src/integrationTest/java/com/mongodb/hibernate/jdbc/MongoStatementWithAutoCommitIntegrationTests.java @@ -0,0 +1,27 @@ +/* + * Copyright 2025-present MongoDB, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.mongodb.hibernate.jdbc; + +import java.sql.Connection; +import java.sql.SQLException; + +class MongoStatementWithAutoCommitIntegrationTests extends MongoStatementIntegrationTests { + @Override + void doAwareOfAutoCommit(Connection connection, SqlExecutable work) throws SQLException { + doWithSpecifiedAutoCommit(true, connection, work); + } +} diff --git a/src/main/java/com/mongodb/hibernate/internal/extension/MongoAdditionalMappingContributor.java b/src/main/java/com/mongodb/hibernate/internal/extension/MongoAdditionalMappingContributor.java index 15e39cc7..668a394a 100644 --- a/src/main/java/com/mongodb/hibernate/internal/extension/MongoAdditionalMappingContributor.java +++ b/src/main/java/com/mongodb/hibernate/internal/extension/MongoAdditionalMappingContributor.java @@ -22,6 +22,7 @@ import static java.lang.String.format; import com.mongodb.hibernate.internal.FeatureNotSupportedException; +import org.hibernate.annotations.DynamicInsert; import org.hibernate.boot.ResourceStreamLocator; import org.hibernate.boot.spi.AdditionalMappingContributions; import org.hibernate.boot.spi.AdditionalMappingContributor; @@ -44,7 +45,13 @@ public void contribute( InFlightMetadataCollector metadata, ResourceStreamLocator resourceStreamLocator, MetadataBuildingContext buildingContext) { - metadata.getEntityBindings().forEach(MongoAdditionalMappingContributor::setIdentifierColumnName); + metadata.getEntityBindings().forEach(persistentClass -> { + if (persistentClass.useDynamicInsert()) { + throw new FeatureNotSupportedException( + format("%s is not supported", DynamicInsert.class.getSimpleName())); + } + setIdentifierColumnName(persistentClass); + }); } private static void setIdentifierColumnName(PersistentClass persistentClass) { diff --git a/src/main/java/com/mongodb/hibernate/jdbc/MongoConnection.java b/src/main/java/com/mongodb/hibernate/jdbc/MongoConnection.java index 0cd77191..6fb5a4d4 100644 --- a/src/main/java/com/mongodb/hibernate/jdbc/MongoConnection.java +++ b/src/main/java/com/mongodb/hibernate/jdbc/MongoConnection.java @@ -28,7 +28,9 @@ import java.sql.Array; import java.sql.DatabaseMetaData; import java.sql.PreparedStatement; +import java.sql.ResultSet; import java.sql.SQLException; +import java.sql.SQLFeatureNotSupportedException; import java.sql.SQLWarning; import java.sql.Statement; import java.sql.Struct; @@ -52,9 +54,6 @@ final class MongoConnection implements ConnectionAdapter { autoCommit = true; } - // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - // transaction - @Override public void setAutoCommit(boolean autoCommit) throws SQLException { checkClosed(); @@ -107,9 +106,6 @@ public void rollback() throws SQLException { } } - // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - // close() and isClosed() - @Override public void close() throws SQLException { if (!closed) { @@ -127,9 +123,6 @@ public boolean isClosed() { return closed; } - // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - // Statement and PreparedStatement - @Override public Statement createStatement() throws SQLException { checkClosed(); @@ -139,19 +132,24 @@ public Statement createStatement() throws SQLException { @Override public PreparedStatement prepareStatement(String mql) throws SQLException { checkClosed(); - return new MongoPreparedStatement(mongoDatabase, clientSession, this, mql); + return prepareStatement(mql, ResultSet.TYPE_FORWARD_ONLY, ResultSet.CONCUR_READ_ONLY); } @Override public PreparedStatement prepareStatement(String mql, int resultSetType, int resultSetConcurrency) throws SQLException { checkClosed(); - throw new FeatureNotSupportedException("TODO-HIBERNATE-21 https://jira.mongodb.org/browse/HIBERNATE-21"); + if (resultSetType != ResultSet.TYPE_FORWARD_ONLY) { + throw new SQLFeatureNotSupportedException( + "Unsupported result set type (only TYPE_FORWARD_ONLY is supported): " + resultSetType); + } + if (resultSetConcurrency != ResultSet.CONCUR_READ_ONLY) { + throw new SQLFeatureNotSupportedException( + "Unsupported result set concurrency (only CONCUR_READ_ONLY is supported): " + resultSetConcurrency); + } + return new MongoPreparedStatement(mongoDatabase, clientSession, this, mql); } - // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - // SQL99 data types - @Override public Array createArrayOf(String typeName, Object[] elements) throws SQLException { checkClosed(); @@ -164,9 +162,6 @@ public Struct createStruct(String typeName, Object[] attributes) throws SQLExcep throw new FeatureNotSupportedException(); } - // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - // Database meta data - @Override public DatabaseMetaData getMetaData() throws SQLException { checkClosed(); @@ -203,21 +198,12 @@ public DatabaseMetaData getMetaData() throws SQLException { return null; } - // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - // dummy implementations - - /** - * Only used in {@link org.hibernate.engine.jdbc.spi.SqlExceptionHelper}. - * - *

Currently no need arises to record warning in this connection class. - */ @Override public @Nullable SQLWarning getWarnings() throws SQLException { checkClosed(); return null; } - /** Only used in {@link org.hibernate.engine.jdbc.spi.SqlExceptionHelper}. */ @Override public void clearWarnings() throws SQLException { checkClosed(); diff --git a/src/main/java/com/mongodb/hibernate/jdbc/MongoPreparedStatement.java b/src/main/java/com/mongodb/hibernate/jdbc/MongoPreparedStatement.java index 3f0cd011..98125c2d 100644 --- a/src/main/java/com/mongodb/hibernate/jdbc/MongoPreparedStatement.java +++ b/src/main/java/com/mongodb/hibernate/jdbc/MongoPreparedStatement.java @@ -22,11 +22,11 @@ import com.mongodb.client.ClientSession; import com.mongodb.client.MongoDatabase; import com.mongodb.hibernate.internal.FeatureNotSupportedException; -import java.io.InputStream; import java.math.BigDecimal; import java.sql.Array; import java.sql.Date; import java.sql.JDBCType; +import java.sql.PreparedStatement; import java.sql.ResultSet; import java.sql.SQLException; import java.sql.SQLFeatureNotSupportedException; @@ -56,7 +56,7 @@ final class MongoPreparedStatement extends MongoStatement implements PreparedSta private final BsonDocument command; - private final List> parameterValueSetters; + private final List parameterValueSetters; MongoPreparedStatement( MongoDatabase mongoDatabase, ClientSession clientSession, MongoConnection mongoConnection, String mql) @@ -70,15 +70,27 @@ final class MongoPreparedStatement extends MongoStatement implements PreparedSta @Override public ResultSet executeQuery() throws SQLException { checkClosed(); - throw new FeatureNotSupportedException(); + closeLastOpenResultSet(); + checkAllParametersSet(); + return executeQueryCommand(command); } @Override public int executeUpdate() throws SQLException { checkClosed(); + closeLastOpenResultSet(); + checkAllParametersSet(); return executeUpdateCommand(command); } + private void checkAllParametersSet() throws SQLException { + for (var i = 0; i < parameterValueSetters.size(); i++) { + if (!parameterValueSetters.get(i).isUsed()) { + throw new SQLException(format("Parameter with index [%d] is not set", i + 1)); + } + } + } + @Override public void setNull(int parameterIndex, int sqlType) throws SQLException { checkClosed(); @@ -110,20 +122,6 @@ public void setBoolean(int parameterIndex, boolean x) throws SQLException { setParameter(parameterIndex, BsonBoolean.valueOf(x)); } - @Override - public void setByte(int parameterIndex, byte x) throws SQLException { - checkClosed(); - checkParameterIndex(parameterIndex); - setInt(parameterIndex, x); - } - - @Override - public void setShort(int parameterIndex, short x) throws SQLException { - checkClosed(); - checkParameterIndex(parameterIndex); - setInt(parameterIndex, x); - } - @Override public void setInt(int parameterIndex, int x) throws SQLException { checkClosed(); @@ -138,13 +136,6 @@ public void setLong(int parameterIndex, long x) throws SQLException { setParameter(parameterIndex, new BsonInt64(x)); } - @Override - public void setFloat(int parameterIndex, float x) throws SQLException { - checkClosed(); - checkParameterIndex(parameterIndex); - setDouble(parameterIndex, x); - } - @Override public void setDouble(int parameterIndex, double x) throws SQLException { checkClosed(); @@ -194,25 +185,18 @@ public void setTimestamp(int parameterIndex, Timestamp x) throws SQLException { throw new FeatureNotSupportedException("TODO-HIBERNATE-42 https://jira.mongodb.org/browse/HIBERNATE-42"); } - @Override - public void setBinaryStream(int parameterIndex, InputStream x, int length) throws SQLException { - checkClosed(); - checkParameterIndex(parameterIndex); - throw new FeatureNotSupportedException(); - } - @Override public void setObject(int parameterIndex, Object x, int targetSqlType) throws SQLException { checkClosed(); checkParameterIndex(parameterIndex); - throw new FeatureNotSupportedException("To be implemented during Array / Struct tickets"); + throw new FeatureNotSupportedException("To be implemented in scope of Array / Struct tickets"); } @Override public void setObject(int parameterIndex, Object x) throws SQLException { checkClosed(); checkParameterIndex(parameterIndex); - throw new FeatureNotSupportedException("To be implemented during Array / Struct tickets"); + throw new FeatureNotSupportedException("To be implemented in scope of Array / Struct tickets"); } @Override @@ -253,7 +237,39 @@ public void setTimestamp(int parameterIndex, Timestamp x, Calendar cal) throws S public void setNull(int parameterIndex, int sqlType, String typeName) throws SQLException { checkClosed(); checkParameterIndex(parameterIndex); - throw new FeatureNotSupportedException("To be implemented during Array / Struct tickets"); + throw new FeatureNotSupportedException("To be implemented in scope of Array / Struct tickets"); + } + + @Override + public void setQueryTimeout(int seconds) throws SQLException { + checkClosed(); + throw new FeatureNotSupportedException("TODO-HIBERNATE-55 https://jira.mongodb.org/browse/HIBERNATE-55"); + } + + @Override + public void setFetchSize(int rows) throws SQLException { + checkClosed(); + throw new FeatureNotSupportedException("TODO-HIBERNATE-54 https://jira.mongodb.org/browse/HIBERNATE-54"); + } + + @Override + public ResultSet executeQuery(String mql) throws SQLException { + throw new SQLException(format("Must not be called on %s", PreparedStatement.class.getSimpleName())); + } + + @Override + public int executeUpdate(String mql) throws SQLException { + throw new SQLException(format("Must not be called on %s", PreparedStatement.class.getSimpleName())); + } + + @Override + public boolean execute(String mql) throws SQLException { + throw new SQLException(format("Must not be called on %s", PreparedStatement.class.getSimpleName())); + } + + @Override + public void addBatch(String mql) throws SQLException { + throw new SQLException(format("Must not be called on %s", PreparedStatement.class.getSimpleName())); } private void setParameter(int parameterIndex, BsonValue parameterValue) { @@ -261,29 +277,29 @@ private void setParameter(int parameterIndex, BsonValue parameterValue) { parameterValueSetter.accept(parameterValue); } - private static void parseParameters(BsonDocument command, List> parameterValueSetters) { + private static void parseParameters(BsonDocument command, List parameterValueSetters) { for (var entry : command.entrySet()) { if (isParameterMarker(entry.getValue())) { - parameterValueSetters.add(entry::setValue); + parameterValueSetters.add(new ParameterValueSetter(entry::setValue)); } else if (entry.getValue().getBsonType().isContainer()) { parseParameters(entry.getValue(), parameterValueSetters); } } } - private static void parseParameters(BsonArray array, List> parameterValueSetters) { + private static void parseParameters(BsonArray array, List parameterValueSetters) { for (var i = 0; i < array.size(); i++) { var value = array.get(i); if (isParameterMarker(value)) { var idx = i; - parameterValueSetters.add(v -> array.set(idx, v)); + parameterValueSetters.add(new ParameterValueSetter(v -> array.set(idx, v))); } else if (value.getBsonType().isContainer()) { parseParameters(value, parameterValueSetters); } } } - private static void parseParameters(BsonValue value, List> parameterValueSetters) { + private static void parseParameters(BsonValue value, List parameterValueSetters) { if (value.isDocument()) { parseParameters(value.asDocument(), parameterValueSetters); } else if (value.isArray()) { @@ -304,8 +320,27 @@ private void checkParameterIndex(int parameterIndex) throws SQLException { } if (parameterIndex < 1 || parameterIndex > parameterValueSetters.size()) { throw new SQLException(format( - "Parameter index invalid: %d; should be within [1, %d]", + "Invalid parameter index [%d]; cannot be under 1 or over the current number of parameters [%d]", parameterIndex, parameterValueSetters.size())); } } + + private static final class ParameterValueSetter implements Consumer { + private final Consumer setter; + private boolean used; + + ParameterValueSetter(Consumer setter) { + this.setter = setter; + } + + @Override + public void accept(BsonValue bsonValue) { + used = true; + setter.accept(bsonValue); + } + + boolean isUsed() { + return used; + } + } } diff --git a/src/main/java/com/mongodb/hibernate/jdbc/MongoResultSet.java b/src/main/java/com/mongodb/hibernate/jdbc/MongoResultSet.java new file mode 100644 index 00000000..44767c28 --- /dev/null +++ b/src/main/java/com/mongodb/hibernate/jdbc/MongoResultSet.java @@ -0,0 +1,267 @@ +/* + * Copyright 2025-present MongoDB, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.mongodb.hibernate.jdbc; + +/* + * Copyright 2025-present MongoDB, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import static com.mongodb.hibernate.internal.MongoAssertions.assertFalse; +import static com.mongodb.hibernate.internal.MongoAssertions.assertNotNull; +import static java.lang.String.format; + +import com.mongodb.client.MongoCursor; +import com.mongodb.hibernate.internal.FeatureNotSupportedException; +import java.math.BigDecimal; +import java.sql.Array; +import java.sql.Date; +import java.sql.ResultSetMetaData; +import java.sql.SQLException; +import java.sql.Time; +import java.sql.Timestamp; +import java.util.Calendar; +import java.util.List; +import java.util.Objects; +import java.util.function.Function; +import org.bson.BsonDocument; +import org.bson.BsonValue; +import org.jspecify.annotations.Nullable; + +final class MongoResultSet implements ResultSetAdapter { + + private final MongoCursor mongoCursor; + + private final List fieldNames; + + private @Nullable BsonDocument currentDocument; + + private boolean lastReadColumnValueWasNull; + + private boolean closed; + + MongoResultSet(MongoCursor mongoCursor, List fieldNames) { + assertFalse(fieldNames.isEmpty()); + this.mongoCursor = mongoCursor; + this.fieldNames = fieldNames; + } + + @Override + public boolean next() throws SQLException { + checkClosed(); + if (mongoCursor.hasNext()) { + currentDocument = mongoCursor.next(); + return true; + } else { + return false; + } + } + + @Override + public void close() throws SQLException { + if (!closed) { + closed = true; + try { + mongoCursor.close(); + } catch (RuntimeException e) { + throw new SQLException( + format("Failed to close %s", mongoCursor.getClass().getSimpleName()), e); + } + } + } + + @Override + public boolean isClosed() { + return closed; + } + + @Override + public boolean wasNull() throws SQLException { + checkClosed(); + return lastReadColumnValueWasNull; + } + + @Override + public @Nullable String getString(int columnIndex) throws SQLException { + checkClosed(); + checkColumnIndex(columnIndex); + return getValue(columnIndex, bsonValue -> bsonValue.asString().getValue()); + } + + @Override + public boolean getBoolean(int columnIndex) throws SQLException { + checkClosed(); + checkColumnIndex(columnIndex); + return getValue(columnIndex, bsonValue -> bsonValue.asBoolean().getValue(), false); + } + + @Override + public int getInt(int columnIndex) throws SQLException { + checkClosed(); + checkColumnIndex(columnIndex); + return getValue(columnIndex, bsonValue -> bsonValue.asInt32().intValue(), 0); + } + + @Override + public long getLong(int columnIndex) throws SQLException { + checkClosed(); + checkColumnIndex(columnIndex); + return getValue(columnIndex, bsonValue -> bsonValue.asInt64().longValue(), 0L); + } + + @Override + public double getDouble(int columnIndex) throws SQLException { + checkClosed(); + checkColumnIndex(columnIndex); + return getValue(columnIndex, bsonValue -> bsonValue.asDouble().getValue(), 0) + .doubleValue(); + } + + @Override + public byte @Nullable [] getBytes(int columnIndex) throws SQLException { + checkClosed(); + checkColumnIndex(columnIndex); + return getValue(columnIndex, bsonValue -> bsonValue.asBinary().getData()); + } + + @Override + public @Nullable Date getDate(int columnIndex) throws SQLException { + checkClosed(); + checkColumnIndex(columnIndex); + throw new FeatureNotSupportedException("TODO-HIBERNATE-42 https://jira.mongodb.org/browse/HIBERNATE-42"); + } + + @Override + public @Nullable Time getTime(int columnIndex) throws SQLException { + checkClosed(); + checkColumnIndex(columnIndex); + throw new FeatureNotSupportedException("TODO-HIBERNATE-42 https://jira.mongodb.org/browse/HIBERNATE-42"); + } + + @Override + public @Nullable Time getTime(int columnIndex, Calendar cal) throws SQLException { + checkClosed(); + checkColumnIndex(columnIndex); + throw new FeatureNotSupportedException("TODO-HIBERNATE-42 https://jira.mongodb.org/browse/HIBERNATE-42"); + } + + @Override + public @Nullable Timestamp getTimestamp(int columnIndex) throws SQLException { + checkClosed(); + checkColumnIndex(columnIndex); + throw new FeatureNotSupportedException("TODO-HIBERNATE-42 https://jira.mongodb.org/browse/HIBERNATE-42"); + } + + @Override + public @Nullable Timestamp getTimestamp(int columnIndex, Calendar cal) throws SQLException { + checkClosed(); + checkColumnIndex(columnIndex); + throw new FeatureNotSupportedException("TODO-HIBERNATE-42 https://jira.mongodb.org/browse/HIBERNATE-42"); + } + + @Override + public @Nullable BigDecimal getBigDecimal(int columnIndex) throws SQLException { + checkClosed(); + checkColumnIndex(columnIndex); + return getValue( + columnIndex, + bsonValue -> bsonValue.asDecimal128().decimal128Value().bigDecimalValue()); + } + + @Override + public @Nullable Array getArray(int columnIndex) throws SQLException { + checkClosed(); + checkColumnIndex(columnIndex); + throw new FeatureNotSupportedException(); + } + + @Override + public @Nullable T getObject(int columnIndex, Class type) throws SQLException { + checkClosed(); + checkColumnIndex(columnIndex); + throw new FeatureNotSupportedException("To be implemented in scope of Array / Struct tickets"); + } + + @Override + public ResultSetMetaData getMetaData() throws SQLException { + checkClosed(); + return new MongoResultSetMetadata(); + } + + @Override + public int findColumn(String columnLabel) throws SQLException { + checkClosed(); + throw new FeatureNotSupportedException("To be implemented in scope of native query tickets"); + } + + @Override + public boolean isWrapperFor(Class iface) throws SQLException { + checkClosed(); + return false; + } + + private String getKey(int columnIndex) { + return fieldNames.get(columnIndex - 1); + } + + private void checkClosed() throws SQLException { + if (closed) { + throw new SQLException(format("%s has been closed", getClass().getSimpleName())); + } + } + + private T getValue(int columnIndex, Function toJavaConverter, T defaultValue) + throws SQLException { + return Objects.requireNonNullElse(getValue(columnIndex, toJavaConverter), defaultValue); + } + + private @Nullable T getValue(int columnIndex, Function toJavaConverter) throws SQLException { + try { + var key = getKey(columnIndex); + var bsonValue = assertNotNull(currentDocument).get(key); + if (bsonValue == null) { + throw new RuntimeException(format("The BSON document field with the name [%s] is missing", key)); + } + T value = bsonValue.isNull() ? null : toJavaConverter.apply(bsonValue); + lastReadColumnValueWasNull = value == null; + return value; + } catch (RuntimeException e) { + throw new SQLException(format("Failed to get value from column [index: %d]", columnIndex), e); + } + } + + private void checkColumnIndex(int columnIndex) throws SQLException { + if (columnIndex < 1 || columnIndex > fieldNames.size()) { + throw new SQLException(format( + "Invalid column index [%d]; cannot be under 1 or over the current number of fields [%d]", + columnIndex, fieldNames.size())); + } + } + + private static final class MongoResultSetMetadata implements ResultSetMetaDataAdapter {} +} diff --git a/src/main/java/com/mongodb/hibernate/jdbc/MongoStatement.java b/src/main/java/com/mongodb/hibernate/jdbc/MongoStatement.java index b88b983d..26d07381 100644 --- a/src/main/java/com/mongodb/hibernate/jdbc/MongoStatement.java +++ b/src/main/java/com/mongodb/hibernate/jdbc/MongoStatement.java @@ -16,17 +16,25 @@ package com.mongodb.hibernate.jdbc; +import static com.mongodb.hibernate.internal.MongoConstants.ID_FIELD_NAME; +import static com.mongodb.hibernate.internal.VisibleForTesting.AccessModifier.PRIVATE; import static java.lang.String.format; +import static java.util.stream.Collectors.toCollection; import com.mongodb.client.ClientSession; import com.mongodb.client.MongoDatabase; import com.mongodb.hibernate.internal.FeatureNotSupportedException; +import com.mongodb.hibernate.internal.VisibleForTesting; import java.sql.Connection; import java.sql.ResultSet; import java.sql.SQLException; import java.sql.SQLSyntaxErrorException; import java.sql.SQLWarning; +import java.util.ArrayList; +import java.util.List; +import java.util.Map; import org.bson.BsonDocument; +import org.bson.BsonValue; import org.jspecify.annotations.Nullable; class MongoStatement implements StatementAdapter { @@ -35,6 +43,7 @@ class MongoStatement implements StatementAdapter { private final MongoConnection mongoConnection; private final ClientSession clientSession; + private @Nullable ResultSet resultSet; private boolean closed; MongoStatement(MongoDatabase mongoDatabase, ClientSession clientSession, MongoConnection mongoConnection) { @@ -46,54 +55,80 @@ class MongoStatement implements StatementAdapter { @Override public ResultSet executeQuery(String mql) throws SQLException { checkClosed(); - throw new FeatureNotSupportedException("TODO-HIBERNATE-21 https://jira.mongodb.org/browse/HIBERNATE-21"); + closeLastOpenResultSet(); + var command = parse(mql); + return executeQueryCommand(command); } - @Override - public int executeUpdate(String mql) throws SQLException { - checkClosed(); - var command = parse(mql); - return executeUpdateCommand(command); + void closeLastOpenResultSet() throws SQLException { + if (resultSet != null && !resultSet.isClosed()) { + resultSet.close(); + } } - int executeUpdateCommand(BsonDocument command) throws SQLException { - startTransactionIfNeeded(); + ResultSet executeQueryCommand(BsonDocument command) throws SQLException { try { - return mongoDatabase.runCommand(clientSession, command).getInteger("n"); - } catch (Exception e) { - throw new SQLException("Failed to execute update command", e); + startTransactionIfNeeded(); + + var collectionName = command.getString("aggregate").getValue(); + var collection = mongoDatabase.getCollection(collectionName, BsonDocument.class); + + var pipeline = command.getArray("pipeline").stream() + .map(BsonValue::asDocument) + .toList(); + var fieldNames = getFieldNamesFromProjectStage( + pipeline.get(pipeline.size() - 1).getDocument("$project")); + + return resultSet = new MongoResultSet( + collection.aggregate(clientSession, pipeline).cursor(), fieldNames); + } catch (RuntimeException e) { + throw new SQLException("Failed to execute query", e); } } - @Override - public void close() { - if (!closed) { - closed = true; + @VisibleForTesting(otherwise = PRIVATE) + static List getFieldNamesFromProjectStage(BsonDocument projectStage) { + var fieldNames = projectStage.entrySet().stream() + .filter(field -> trueOrOne(field.getValue())) + .map(Map.Entry::getKey) + .collect(toCollection(ArrayList::new)); + if (!projectStage.containsKey(ID_FIELD_NAME)) { + // MongoDB includes this field unless it is explicitly excluded + fieldNames.add(ID_FIELD_NAME); } + return fieldNames; } - @Override - public int getMaxRows() throws SQLException { - checkClosed(); - throw new FeatureNotSupportedException("TODO-HIBERNATE-21 https://jira.mongodb.org/browse/HIBERNATE-21"); + private static boolean trueOrOne(BsonValue value) { + return (value.isBoolean() && value.asBoolean().getValue()) + || (value.isNumber() && value.asNumber().intValue() == 1); } @Override - public void setMaxRows(int max) throws SQLException { + public int executeUpdate(String mql) throws SQLException { checkClosed(); - throw new FeatureNotSupportedException("TODO-HIBERNATE-21 https://jira.mongodb.org/browse/HIBERNATE-21"); + closeLastOpenResultSet(); + var command = parse(mql); + return executeUpdateCommand(command); } - @Override - public int getQueryTimeout() throws SQLException { - checkClosed(); - throw new FeatureNotSupportedException("TODO-HIBERNATE-21 https://jira.mongodb.org/browse/HIBERNATE-21"); + int executeUpdateCommand(BsonDocument command) throws SQLException { + try { + startTransactionIfNeeded(); + return mongoDatabase.runCommand(clientSession, command).getInteger("n"); + } catch (RuntimeException e) { + throw new SQLException("Failed to execute update command", e); + } } @Override - public void setQueryTimeout(int seconds) throws SQLException { - checkClosed(); - throw new FeatureNotSupportedException("TODO-HIBERNATE-21 https://jira.mongodb.org/browse/HIBERNATE-21"); + public void close() throws SQLException { + if (!closed) { + closed = true; + if (resultSet != null) { + resultSet.close(); + } + } } @Override @@ -108,17 +143,15 @@ public void cancel() throws SQLException { return null; } - /** Only used in {@link org.hibernate.engine.jdbc.spi.SqlExceptionHelper}. */ @Override public void clearWarnings() throws SQLException { checkClosed(); } - // ----------------------- Multiple Results -------------------------- - @Override public boolean execute(String mql) throws SQLException { checkClosed(); + closeLastOpenResultSet(); throw new FeatureNotSupportedException("To be implemented in scope of index and unique constraint creation"); } @@ -140,18 +173,6 @@ public int getUpdateCount() throws SQLException { throw new FeatureNotSupportedException("To be implemented in scope of index and unique constraint creation"); } - @Override - public void setFetchSize(int rows) throws SQLException { - checkClosed(); - throw new FeatureNotSupportedException("TODO-HIBERNATE-21 https://jira.mongodb.org/browse/HIBERNATE-21"); - } - - @Override - public int getFetchSize() throws SQLException { - checkClosed(); - throw new FeatureNotSupportedException("TODO-HIBERNATE-21 https://jira.mongodb.org/browse/HIBERNATE-21"); - } - @Override public void addBatch(String mql) throws SQLException { checkClosed(); @@ -167,6 +188,7 @@ public void clearBatch() throws SQLException { @Override public int[] executeBatch() throws SQLException { checkClosed(); + closeLastOpenResultSet(); throw new FeatureNotSupportedException("TODO-HIBERNATE-35 https://jira.mongodb.org/browse/HIBERNATE-35"); } diff --git a/src/main/java/com/mongodb/hibernate/jdbc/ResultSetAdapter.java b/src/main/java/com/mongodb/hibernate/jdbc/ResultSetAdapter.java new file mode 100644 index 00000000..dcd7b912 --- /dev/null +++ b/src/main/java/com/mongodb/hibernate/jdbc/ResultSetAdapter.java @@ -0,0 +1,1002 @@ +/* + * Copyright 2024-present MongoDB, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.mongodb.hibernate.jdbc; + +import java.io.InputStream; +import java.io.Reader; +import java.math.BigDecimal; +import java.net.URL; +import java.sql.Array; +import java.sql.Blob; +import java.sql.Clob; +import java.sql.Date; +import java.sql.NClob; +import java.sql.Ref; +import java.sql.ResultSet; +import java.sql.ResultSetMetaData; +import java.sql.RowId; +import java.sql.SQLException; +import java.sql.SQLFeatureNotSupportedException; +import java.sql.SQLWarning; +import java.sql.SQLXML; +import java.sql.Statement; +import java.sql.Time; +import java.sql.Timestamp; +import java.util.Calendar; +import java.util.Map; +import org.jspecify.annotations.Nullable; + +interface ResultSetAdapter extends ResultSet { + @Override + default boolean next() throws SQLException { + throw new SQLFeatureNotSupportedException("next not implemented"); + } + + @Override + default void close() throws SQLException { + throw new SQLFeatureNotSupportedException("close not implemented"); + } + + @Override + default boolean wasNull() throws SQLException { + throw new SQLFeatureNotSupportedException("wasNull not implemented"); + } + + @Override + default @Nullable String getString(int columnIndex) throws SQLException { + throw new SQLFeatureNotSupportedException("getString not implemented"); + } + + @Override + default boolean getBoolean(int columnIndex) throws SQLException { + throw new SQLFeatureNotSupportedException("getBoolean not implemented"); + } + + @Override + default byte getByte(int columnIndex) throws SQLException { + throw new SQLFeatureNotSupportedException("getByte not implemented"); + } + + @Override + default short getShort(int columnIndex) throws SQLException { + throw new SQLFeatureNotSupportedException("getShort not implemented"); + } + + @Override + default int getInt(int columnIndex) throws SQLException { + throw new SQLFeatureNotSupportedException("getInt not implemented"); + } + + @Override + default long getLong(int columnIndex) throws SQLException { + throw new SQLFeatureNotSupportedException("getLong not implemented"); + } + + @Override + default float getFloat(int columnIndex) throws SQLException { + throw new SQLFeatureNotSupportedException("getFloat not implemented"); + } + + @Override + default double getDouble(int columnIndex) throws SQLException { + throw new SQLFeatureNotSupportedException("getDouble not implemented"); + } + + @Override + @SuppressWarnings("deprecation") + default @Nullable BigDecimal getBigDecimal(int columnIndex, int scale) throws SQLException { + throw new SQLFeatureNotSupportedException("getBigDecimal not implemented"); + } + + @Override + default byte @Nullable [] getBytes(int columnIndex) throws SQLException { + throw new SQLFeatureNotSupportedException("getBytes not implemented"); + } + + @Override + default @Nullable Date getDate(int columnIndex) throws SQLException { + throw new SQLFeatureNotSupportedException("getDate not implemented"); + } + + @Override + default @Nullable Time getTime(int columnIndex) throws SQLException { + throw new SQLFeatureNotSupportedException("getTime not implemented"); + } + + @Override + default @Nullable Timestamp getTimestamp(int columnIndex) throws SQLException { + throw new SQLFeatureNotSupportedException("getTimestamp not implemented"); + } + + @Override + default @Nullable InputStream getAsciiStream(int columnIndex) throws SQLException { + throw new SQLFeatureNotSupportedException("getAsciiStream not implemented"); + } + + @Override + @SuppressWarnings("deprecation") + default @Nullable InputStream getUnicodeStream(int columnIndex) throws SQLException { + throw new SQLFeatureNotSupportedException("getUnicodeStream not implemented"); + } + + @Override + default @Nullable InputStream getBinaryStream(int columnIndex) throws SQLException { + throw new SQLFeatureNotSupportedException("getBinaryStream not implemented"); + } + + @Override + default @Nullable String getString(String columnLabel) throws SQLException { + throw new SQLFeatureNotSupportedException("getString not implemented"); + } + + @Override + default boolean getBoolean(String columnLabel) throws SQLException { + throw new SQLFeatureNotSupportedException("getBoolean not implemented"); + } + + @Override + default byte getByte(String columnLabel) throws SQLException { + throw new SQLFeatureNotSupportedException("getByte not implemented"); + } + + @Override + default short getShort(String columnLabel) throws SQLException { + throw new SQLFeatureNotSupportedException("getShort not implemented"); + } + + @Override + default int getInt(String columnLabel) throws SQLException { + throw new SQLFeatureNotSupportedException("getInt not implemented"); + } + + @Override + default long getLong(String columnLabel) throws SQLException { + throw new SQLFeatureNotSupportedException("getLong not implemented"); + } + + @Override + default float getFloat(String columnLabel) throws SQLException { + throw new SQLFeatureNotSupportedException("getFloat not implemented"); + } + + @Override + default double getDouble(String columnLabel) throws SQLException { + throw new SQLFeatureNotSupportedException("getDouble not implemented"); + } + + @Override + @SuppressWarnings("deprecation") + default BigDecimal getBigDecimal(String columnLabel, int scale) throws SQLException { + throw new SQLFeatureNotSupportedException("getBigDecimal not implemented"); + } + + @Override + default byte[] getBytes(String columnLabel) throws SQLException { + throw new SQLFeatureNotSupportedException("getBytes not implemented"); + } + + @Override + default Date getDate(String columnLabel) throws SQLException { + throw new SQLFeatureNotSupportedException("getDate not implemented"); + } + + @Override + default Time getTime(String columnLabel) throws SQLException { + throw new SQLFeatureNotSupportedException("getTime not implemented"); + } + + @Override + default Timestamp getTimestamp(String columnLabel) throws SQLException { + throw new SQLFeatureNotSupportedException("getTimestamp not implemented"); + } + + @Override + default InputStream getAsciiStream(String columnLabel) throws SQLException { + throw new SQLFeatureNotSupportedException("getAsciiStream not implemented"); + } + + @Override + @SuppressWarnings("deprecation") + default InputStream getUnicodeStream(String columnLabel) throws SQLException { + throw new SQLFeatureNotSupportedException("getUnicodeStream not implemented"); + } + + @Override + default InputStream getBinaryStream(String columnLabel) throws SQLException { + throw new SQLFeatureNotSupportedException("getBinaryStream not implemented"); + } + + @Override + default SQLWarning getWarnings() throws SQLException { + throw new SQLFeatureNotSupportedException("getWarnings not implemented"); + } + + @Override + default void clearWarnings() throws SQLException { + throw new SQLFeatureNotSupportedException("clearWarnings not implemented"); + } + + @Override + default String getCursorName() throws SQLException { + throw new SQLFeatureNotSupportedException("getCursorName not implemented"); + } + + @Override + default ResultSetMetaData getMetaData() throws SQLException { + throw new SQLFeatureNotSupportedException("getMetaData not implemented"); + } + + @Override + default Object getObject(int columnIndex) throws SQLException { + throw new SQLFeatureNotSupportedException("getObject not implemented"); + } + + @Override + default Object getObject(String columnLabel) throws SQLException { + throw new SQLFeatureNotSupportedException("getObject not implemented"); + } + + @Override + default int findColumn(String columnLabel) throws SQLException { + throw new SQLFeatureNotSupportedException("findColumn not implemented"); + } + + @Override + default Reader getCharacterStream(int columnIndex) throws SQLException { + throw new SQLFeatureNotSupportedException("getCharacterStream not implemented"); + } + + @Override + default Reader getCharacterStream(String columnLabel) throws SQLException { + throw new SQLFeatureNotSupportedException("getCharacterStream not implemented"); + } + + @Override + default @Nullable BigDecimal getBigDecimal(int columnIndex) throws SQLException { + throw new SQLFeatureNotSupportedException("getBigDecimal not implemented"); + } + + @Override + default BigDecimal getBigDecimal(String columnLabel) throws SQLException { + throw new SQLFeatureNotSupportedException("getBigDecimal not implemented"); + } + + @Override + default boolean isBeforeFirst() throws SQLException { + throw new SQLFeatureNotSupportedException("isBeforeFirst not implemented"); + } + + @Override + default boolean isAfterLast() throws SQLException { + throw new SQLFeatureNotSupportedException("isAfterLast not implemented"); + } + + @Override + default boolean isFirst() throws SQLException { + throw new SQLFeatureNotSupportedException("isFirst not implemented"); + } + + @Override + default boolean isLast() throws SQLException { + throw new SQLFeatureNotSupportedException("isLast not implemented"); + } + + @Override + default void beforeFirst() throws SQLException { + throw new SQLFeatureNotSupportedException("beforeFirst not implemented"); + } + + @Override + default void afterLast() throws SQLException { + throw new SQLFeatureNotSupportedException("afterLast not implemented"); + } + + @Override + default boolean first() throws SQLException { + throw new SQLFeatureNotSupportedException("first not implemented"); + } + + @Override + default boolean last() throws SQLException { + throw new SQLFeatureNotSupportedException("last not implemented"); + } + + @Override + default int getRow() throws SQLException { + throw new SQLFeatureNotSupportedException("getRow not implemented"); + } + + @Override + default boolean absolute(int row) throws SQLException { + throw new SQLFeatureNotSupportedException("absolute not implemented"); + } + + @Override + default boolean relative(int rows) throws SQLException { + throw new SQLFeatureNotSupportedException("relative not implemented"); + } + + @Override + default boolean previous() throws SQLException { + throw new SQLFeatureNotSupportedException("previous not implemented"); + } + + @Override + default void setFetchDirection(int direction) throws SQLException { + throw new SQLFeatureNotSupportedException("setFetchDirection not implemented"); + } + + @Override + default int getFetchDirection() throws SQLException { + throw new SQLFeatureNotSupportedException("getFetchDirection not implemented"); + } + + @Override + default void setFetchSize(int rows) throws SQLException { + throw new SQLFeatureNotSupportedException("setFetchSize not implemented"); + } + + @Override + default int getFetchSize() throws SQLException { + throw new SQLFeatureNotSupportedException("getFetchSize not implemented"); + } + + @Override + default int getType() throws SQLException { + throw new SQLFeatureNotSupportedException("getType not implemented"); + } + + @Override + default int getConcurrency() throws SQLException { + throw new SQLFeatureNotSupportedException("getConcurrency not implemented"); + } + + @Override + default boolean rowUpdated() throws SQLException { + throw new SQLFeatureNotSupportedException("rowUpdated not implemented"); + } + + @Override + default boolean rowInserted() throws SQLException { + throw new SQLFeatureNotSupportedException("rowInserted not implemented"); + } + + @Override + default boolean rowDeleted() throws SQLException { + throw new SQLFeatureNotSupportedException("rowDeleted not implemented"); + } + + @Override + default void updateNull(int columnIndex) throws SQLException { + throw new SQLFeatureNotSupportedException("updateNull not implemented"); + } + + @Override + default void updateBoolean(int columnIndex, boolean x) throws SQLException { + throw new SQLFeatureNotSupportedException("updateBoolean not implemented"); + } + + @Override + default void updateByte(int columnIndex, byte x) throws SQLException { + throw new SQLFeatureNotSupportedException("updateByte not implemented"); + } + + @Override + default void updateShort(int columnIndex, short x) throws SQLException { + throw new SQLFeatureNotSupportedException("updateShort not implemented"); + } + + @Override + default void updateInt(int columnIndex, int x) throws SQLException { + throw new SQLFeatureNotSupportedException("updateInt not implemented"); + } + + @Override + default void updateLong(int columnIndex, long x) throws SQLException { + throw new SQLFeatureNotSupportedException("updateLong not implemented"); + } + + @Override + default void updateFloat(int columnIndex, float x) throws SQLException { + throw new SQLFeatureNotSupportedException("updateFloat not implemented"); + } + + @Override + default void updateDouble(int columnIndex, double x) throws SQLException { + throw new SQLFeatureNotSupportedException("updateDouble not implemented"); + } + + @Override + default void updateBigDecimal(int columnIndex, BigDecimal x) throws SQLException { + throw new SQLFeatureNotSupportedException("updateBigDecimal not implemented"); + } + + @Override + default void updateString(int columnIndex, String x) throws SQLException { + throw new SQLFeatureNotSupportedException("updateString not implemented"); + } + + @Override + default void updateBytes(int columnIndex, byte[] x) throws SQLException { + throw new SQLFeatureNotSupportedException("updateBytes not implemented"); + } + + @Override + default void updateDate(int columnIndex, Date x) throws SQLException { + throw new SQLFeatureNotSupportedException("updateDate not implemented"); + } + + @Override + default void updateTime(int columnIndex, Time x) throws SQLException { + throw new SQLFeatureNotSupportedException("updateTime not implemented"); + } + + @Override + default void updateTimestamp(int columnIndex, Timestamp x) throws SQLException { + throw new SQLFeatureNotSupportedException("updateTimestamp not implemented"); + } + + @Override + default void updateAsciiStream(int columnIndex, InputStream x, int length) throws SQLException { + throw new SQLFeatureNotSupportedException("updateAsciiStream not implemented"); + } + + @Override + default void updateBinaryStream(int columnIndex, InputStream x, int length) throws SQLException { + throw new SQLFeatureNotSupportedException("updateBinaryStream not implemented"); + } + + @Override + default void updateCharacterStream(int columnIndex, Reader x, int length) throws SQLException { + throw new SQLFeatureNotSupportedException("updateCharacterStream not implemented"); + } + + @Override + default void updateObject(int columnIndex, Object x, int scaleOrLength) throws SQLException { + throw new SQLFeatureNotSupportedException("updateObject not implemented"); + } + + @Override + default void updateObject(int columnIndex, Object x) throws SQLException { + throw new SQLFeatureNotSupportedException("updateObject not implemented"); + } + + @Override + default void updateNull(String columnLabel) throws SQLException { + throw new SQLFeatureNotSupportedException("updateNull not implemented"); + } + + @Override + default void updateBoolean(String columnLabel, boolean x) throws SQLException { + throw new SQLFeatureNotSupportedException("updateBoolean not implemented"); + } + + @Override + default void updateByte(String columnLabel, byte x) throws SQLException { + throw new SQLFeatureNotSupportedException("updateByte not implemented"); + } + + @Override + default void updateShort(String columnLabel, short x) throws SQLException { + throw new SQLFeatureNotSupportedException("updateShort not implemented"); + } + + @Override + default void updateInt(String columnLabel, int x) throws SQLException { + throw new SQLFeatureNotSupportedException("updateInt not implemented"); + } + + @Override + default void updateLong(String columnLabel, long x) throws SQLException { + throw new SQLFeatureNotSupportedException("updateLong not implemented"); + } + + @Override + default void updateFloat(String columnLabel, float x) throws SQLException { + throw new SQLFeatureNotSupportedException("updateFloat not implemented"); + } + + @Override + default void updateDouble(String columnLabel, double x) throws SQLException { + throw new SQLFeatureNotSupportedException("updateDouble not implemented"); + } + + @Override + default void updateBigDecimal(String columnLabel, BigDecimal x) throws SQLException { + throw new SQLFeatureNotSupportedException("updateBigDecimal not implemented"); + } + + @Override + default void updateString(String columnLabel, String x) throws SQLException { + throw new SQLFeatureNotSupportedException("updateString not implemented"); + } + + @Override + default void updateBytes(String columnLabel, byte[] x) throws SQLException { + throw new SQLFeatureNotSupportedException("updateBytes not implemented"); + } + + @Override + default void updateDate(String columnLabel, Date x) throws SQLException { + throw new SQLFeatureNotSupportedException("updateDate not implemented"); + } + + @Override + default void updateTime(String columnLabel, Time x) throws SQLException { + throw new SQLFeatureNotSupportedException("updateTime not implemented"); + } + + @Override + default void updateTimestamp(String columnLabel, Timestamp x) throws SQLException { + throw new SQLFeatureNotSupportedException("updateTimestamp not implemented"); + } + + @Override + default void updateAsciiStream(String columnLabel, InputStream x, int length) throws SQLException { + throw new SQLFeatureNotSupportedException("updateAsciiStream not implemented"); + } + + @Override + default void updateBinaryStream(String columnLabel, InputStream x, int length) throws SQLException { + throw new SQLFeatureNotSupportedException("updateBinaryStream not implemented"); + } + + @Override + default void updateCharacterStream(String columnLabel, Reader reader, int length) throws SQLException { + throw new SQLFeatureNotSupportedException("updateCharacterStream not implemented"); + } + + @Override + default void updateObject(String columnLabel, Object x, int scaleOrLength) throws SQLException { + throw new SQLFeatureNotSupportedException("updateObject not implemented"); + } + + @Override + default void updateObject(String columnLabel, Object x) throws SQLException { + throw new SQLFeatureNotSupportedException("updateObject not implemented"); + } + + @Override + default void insertRow() throws SQLException { + throw new SQLFeatureNotSupportedException("insertRow not implemented"); + } + + @Override + default void updateRow() throws SQLException { + throw new SQLFeatureNotSupportedException("updateRow not implemented"); + } + + @Override + default void deleteRow() throws SQLException { + throw new SQLFeatureNotSupportedException("deleteRow not implemented"); + } + + @Override + default void refreshRow() throws SQLException { + throw new SQLFeatureNotSupportedException("refreshRow not implemented"); + } + + @Override + default void cancelRowUpdates() throws SQLException { + throw new SQLFeatureNotSupportedException("cancelRowUpdates not implemented"); + } + + @Override + default void moveToInsertRow() throws SQLException { + throw new SQLFeatureNotSupportedException("moveToInsertRow not implemented"); + } + + @Override + default void moveToCurrentRow() throws SQLException { + throw new SQLFeatureNotSupportedException("moveToCurrentRow not implemented"); + } + + @Override + default Statement getStatement() throws SQLException { + throw new SQLFeatureNotSupportedException("getStatement not implemented"); + } + + @Override + default Object getObject(int columnIndex, Map> map) throws SQLException { + throw new SQLFeatureNotSupportedException("getObject not implemented"); + } + + @Override + default Ref getRef(int columnIndex) throws SQLException { + throw new SQLFeatureNotSupportedException("getRef not implemented"); + } + + @Override + default Blob getBlob(int columnIndex) throws SQLException { + throw new SQLFeatureNotSupportedException("getBlob not implemented"); + } + + @Override + default Clob getClob(int columnIndex) throws SQLException { + throw new SQLFeatureNotSupportedException("getClob not implemented"); + } + + @Override + default @Nullable Array getArray(int columnIndex) throws SQLException { + throw new SQLFeatureNotSupportedException("getArray not implemented"); + } + + @Override + default Object getObject(String columnLabel, Map> map) throws SQLException { + throw new SQLFeatureNotSupportedException("getObject not implemented"); + } + + @Override + default Ref getRef(String columnLabel) throws SQLException { + throw new SQLFeatureNotSupportedException("getRef not implemented"); + } + + @Override + default Blob getBlob(String columnLabel) throws SQLException { + throw new SQLFeatureNotSupportedException("getBlob not implemented"); + } + + @Override + default Clob getClob(String columnLabel) throws SQLException { + throw new SQLFeatureNotSupportedException("getClob not implemented"); + } + + @Override + default Array getArray(String columnLabel) throws SQLException { + throw new SQLFeatureNotSupportedException("getArray not implemented"); + } + + @Override + default Date getDate(int columnIndex, Calendar cal) throws SQLException { + throw new SQLFeatureNotSupportedException("getDate not implemented"); + } + + @Override + default Date getDate(String columnLabel, Calendar cal) throws SQLException { + throw new SQLFeatureNotSupportedException("getDate not implemented"); + } + + @Override + default @Nullable Time getTime(int columnIndex, Calendar cal) throws SQLException { + throw new SQLFeatureNotSupportedException("getTime not implemented"); + } + + @Override + default @Nullable Time getTime(String columnLabel, Calendar cal) throws SQLException { + throw new SQLFeatureNotSupportedException("getTime not implemented"); + } + + @Override + default @Nullable Timestamp getTimestamp(int columnIndex, Calendar cal) throws SQLException { + throw new SQLFeatureNotSupportedException("getTimestamp not implemented"); + } + + @Override + default @Nullable Timestamp getTimestamp(String columnLabel, Calendar cal) throws SQLException { + throw new SQLFeatureNotSupportedException("getTimestamp not implemented"); + } + + @Override + default URL getURL(int columnIndex) throws SQLException { + throw new SQLFeatureNotSupportedException("getURL not implemented"); + } + + @Override + default URL getURL(String columnLabel) throws SQLException { + throw new SQLFeatureNotSupportedException("getURL not implemented"); + } + + @Override + default void updateRef(int columnIndex, Ref x) throws SQLException { + throw new SQLFeatureNotSupportedException("updateRef not implemented"); + } + + @Override + default void updateRef(String columnLabel, Ref x) throws SQLException { + throw new SQLFeatureNotSupportedException("updateRef not implemented"); + } + + @Override + default void updateBlob(int columnIndex, Blob x) throws SQLException { + throw new SQLFeatureNotSupportedException("updateBlob not implemented"); + } + + @Override + default void updateBlob(String columnLabel, Blob x) throws SQLException { + throw new SQLFeatureNotSupportedException("updateBlob not implemented"); + } + + @Override + default void updateClob(int columnIndex, Clob x) throws SQLException { + throw new SQLFeatureNotSupportedException("updateClob not implemented"); + } + + @Override + default void updateClob(String columnLabel, Clob x) throws SQLException { + throw new SQLFeatureNotSupportedException("updateClob not implemented"); + } + + @Override + default void updateArray(int columnIndex, Array x) throws SQLException { + throw new SQLFeatureNotSupportedException("updateArray not implemented"); + } + + @Override + default void updateArray(String columnLabel, Array x) throws SQLException { + throw new SQLFeatureNotSupportedException("updateArray not implemented"); + } + + @Override + default RowId getRowId(int columnIndex) throws SQLException { + throw new SQLFeatureNotSupportedException("getRowId not implemented"); + } + + @Override + default RowId getRowId(String columnLabel) throws SQLException { + throw new SQLFeatureNotSupportedException("getRowId not implemented"); + } + + @Override + default void updateRowId(int columnIndex, RowId x) throws SQLException { + throw new SQLFeatureNotSupportedException("updateRowId not implemented"); + } + + @Override + default void updateRowId(String columnLabel, RowId x) throws SQLException { + throw new SQLFeatureNotSupportedException("updateRowId not implemented"); + } + + @Override + default int getHoldability() throws SQLException { + throw new SQLFeatureNotSupportedException("getHoldability not implemented"); + } + + @Override + default boolean isClosed() throws SQLException { + throw new SQLFeatureNotSupportedException("isClosed not implemented"); + } + + @Override + default void updateNString(int columnIndex, String nString) throws SQLException { + throw new SQLFeatureNotSupportedException("updateNString not implemented"); + } + + @Override + default void updateNString(String columnLabel, String nString) throws SQLException { + throw new SQLFeatureNotSupportedException("updateNString not implemented"); + } + + @Override + default void updateNClob(int columnIndex, NClob nClob) throws SQLException { + throw new SQLFeatureNotSupportedException("updateNClob not implemented"); + } + + @Override + default void updateNClob(String columnLabel, NClob nClob) throws SQLException { + throw new SQLFeatureNotSupportedException("updateNClob not implemented"); + } + + @Override + default NClob getNClob(int columnIndex) throws SQLException { + throw new SQLFeatureNotSupportedException("getNClob not implemented"); + } + + @Override + default NClob getNClob(String columnLabel) throws SQLException { + throw new SQLFeatureNotSupportedException("getNClob not implemented"); + } + + @Override + default SQLXML getSQLXML(int columnIndex) throws SQLException { + throw new SQLFeatureNotSupportedException("getSQLXML not implemented"); + } + + @Override + default SQLXML getSQLXML(String columnLabel) throws SQLException { + throw new SQLFeatureNotSupportedException("getSQLXML not implemented"); + } + + @Override + default void updateSQLXML(int columnIndex, SQLXML xmlObject) throws SQLException { + throw new SQLFeatureNotSupportedException("updateSQLXML not implemented"); + } + + @Override + default void updateSQLXML(String columnLabel, SQLXML xmlObject) throws SQLException { + throw new SQLFeatureNotSupportedException("updateSQLXML not implemented"); + } + + @Override + default String getNString(int columnIndex) throws SQLException { + throw new SQLFeatureNotSupportedException("getNString not implemented"); + } + + @Override + default String getNString(String columnLabel) throws SQLException { + throw new SQLFeatureNotSupportedException("getNString not implemented"); + } + + @Override + default Reader getNCharacterStream(int columnIndex) throws SQLException { + throw new SQLFeatureNotSupportedException("getNCharacterStream not implemented"); + } + + @Override + default Reader getNCharacterStream(String columnLabel) throws SQLException { + throw new SQLFeatureNotSupportedException("getNCharacterStream not implemented"); + } + + @Override + default void updateNCharacterStream(int columnIndex, Reader x, long length) throws SQLException { + throw new SQLFeatureNotSupportedException("updateNCharacterStream not implemented"); + } + + @Override + default void updateNCharacterStream(String columnLabel, Reader reader, long length) throws SQLException { + throw new SQLFeatureNotSupportedException("updateNCharacterStream not implemented"); + } + + @Override + default void updateAsciiStream(int columnIndex, InputStream x, long length) throws SQLException { + throw new SQLFeatureNotSupportedException("updateAsciiStream not implemented"); + } + + @Override + default void updateBinaryStream(int columnIndex, InputStream x, long length) throws SQLException { + throw new SQLFeatureNotSupportedException("updateBinaryStream not implemented"); + } + + @Override + default void updateCharacterStream(int columnIndex, Reader x, long length) throws SQLException { + throw new SQLFeatureNotSupportedException("updateCharacterStream not implemented"); + } + + @Override + default void updateAsciiStream(String columnLabel, InputStream x, long length) throws SQLException { + throw new SQLFeatureNotSupportedException("updateAsciiStream not implemented"); + } + + @Override + default void updateBinaryStream(String columnLabel, InputStream x, long length) throws SQLException { + throw new SQLFeatureNotSupportedException("updateBinaryStream not implemented"); + } + + @Override + default void updateCharacterStream(String columnLabel, Reader reader, long length) throws SQLException { + throw new SQLFeatureNotSupportedException("updateCharacterStream not implemented"); + } + + @Override + default void updateBlob(int columnIndex, InputStream inputStream, long length) throws SQLException { + throw new SQLFeatureNotSupportedException("updateBlob not implemented"); + } + + @Override + default void updateBlob(String columnLabel, InputStream inputStream, long length) throws SQLException { + throw new SQLFeatureNotSupportedException("updateBlob not implemented"); + } + + @Override + default void updateClob(int columnIndex, Reader reader, long length) throws SQLException { + throw new SQLFeatureNotSupportedException("updateClob not implemented"); + } + + @Override + default void updateClob(String columnLabel, Reader reader, long length) throws SQLException { + throw new SQLFeatureNotSupportedException("updateClob not implemented"); + } + + @Override + default void updateNClob(int columnIndex, Reader reader, long length) throws SQLException { + throw new SQLFeatureNotSupportedException("updateNClob not implemented"); + } + + @Override + default void updateNClob(String columnLabel, Reader reader, long length) throws SQLException { + throw new SQLFeatureNotSupportedException("updateNClob not implemented"); + } + + @Override + default void updateNCharacterStream(int columnIndex, Reader x) throws SQLException { + throw new SQLFeatureNotSupportedException("updateNCharacterStream not implemented"); + } + + @Override + default void updateNCharacterStream(String columnLabel, Reader reader) throws SQLException { + throw new SQLFeatureNotSupportedException("updateNCharacterStream not implemented"); + } + + @Override + default void updateAsciiStream(int columnIndex, InputStream x) throws SQLException { + throw new SQLFeatureNotSupportedException("updateAsciiStream not implemented"); + } + + @Override + default void updateBinaryStream(int columnIndex, InputStream x) throws SQLException { + throw new SQLFeatureNotSupportedException("updateBinaryStream not implemented"); + } + + @Override + default void updateCharacterStream(int columnIndex, Reader x) throws SQLException { + throw new SQLFeatureNotSupportedException("updateCharacterStream not implemented"); + } + + @Override + default void updateAsciiStream(String columnLabel, InputStream x) throws SQLException { + throw new SQLFeatureNotSupportedException("updateAsciiStream not implemented"); + } + + @Override + default void updateBinaryStream(String columnLabel, InputStream x) throws SQLException { + throw new SQLFeatureNotSupportedException("updateBinaryStream not implemented"); + } + + @Override + default void updateCharacterStream(String columnLabel, Reader reader) throws SQLException { + throw new SQLFeatureNotSupportedException("updateCharacterStream not implemented"); + } + + @Override + default void updateBlob(int columnIndex, InputStream inputStream) throws SQLException { + throw new SQLFeatureNotSupportedException("updateBlob not implemented"); + } + + @Override + default void updateBlob(String columnLabel, InputStream inputStream) throws SQLException { + throw new SQLFeatureNotSupportedException("updateBlob not implemented"); + } + + @Override + default void updateClob(int columnIndex, Reader reader) throws SQLException { + throw new SQLFeatureNotSupportedException("updateClob not implemented"); + } + + @Override + default void updateClob(String columnLabel, Reader reader) throws SQLException { + throw new SQLFeatureNotSupportedException("updateClob not implemented"); + } + + @Override + default void updateNClob(int columnIndex, Reader reader) throws SQLException { + throw new SQLFeatureNotSupportedException("updateNClob not implemented"); + } + + @Override + default void updateNClob(String columnLabel, Reader reader) throws SQLException { + throw new SQLFeatureNotSupportedException("updateNClob not implemented"); + } + + @Override + default @Nullable T getObject(int columnIndex, Class type) throws SQLException { + throw new SQLFeatureNotSupportedException("getObject not implemented"); + } + + @Override + default T getObject(String columnLabel, Class type) throws SQLException { + throw new SQLFeatureNotSupportedException("getObject not implemented"); + } + + @Override + default T unwrap(Class iface) throws SQLException { + throw new SQLFeatureNotSupportedException("unwrap not implemented"); + } + + @Override + default boolean isWrapperFor(Class iface) throws SQLException { + throw new SQLFeatureNotSupportedException("isWrapperFor not implemented"); + } +} diff --git a/src/main/java/com/mongodb/hibernate/jdbc/ResultSetMetaDataAdapter.java b/src/main/java/com/mongodb/hibernate/jdbc/ResultSetMetaDataAdapter.java new file mode 100644 index 00000000..5f434069 --- /dev/null +++ b/src/main/java/com/mongodb/hibernate/jdbc/ResultSetMetaDataAdapter.java @@ -0,0 +1,138 @@ +/* + * Copyright 2025-present MongoDB, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.mongodb.hibernate.jdbc; + +import java.sql.ResultSetMetaData; +import java.sql.SQLException; +import java.sql.SQLFeatureNotSupportedException; + +interface ResultSetMetaDataAdapter extends ResultSetMetaData { + @Override + default int getColumnCount() throws SQLException { + throw new SQLFeatureNotSupportedException("getColumnCount not implemented"); + } + + @Override + default boolean isAutoIncrement(int column) throws SQLException { + throw new SQLFeatureNotSupportedException("isAutoIncrement not implemented"); + } + + @Override + default boolean isCaseSensitive(int column) throws SQLException { + throw new SQLFeatureNotSupportedException("isCaseSensitive not implemented"); + } + + @Override + default boolean isSearchable(int column) throws SQLException { + throw new SQLFeatureNotSupportedException("isSearchable not implemented"); + } + + @Override + default boolean isCurrency(int column) throws SQLException { + throw new SQLFeatureNotSupportedException("isCurrency not implemented"); + } + + @Override + default int isNullable(int column) throws SQLException { + throw new SQLFeatureNotSupportedException("isNullable not implemented"); + } + + @Override + default boolean isSigned(int column) throws SQLException { + throw new SQLFeatureNotSupportedException("isSigned not implemented"); + } + + @Override + default int getColumnDisplaySize(int column) throws SQLException { + throw new SQLFeatureNotSupportedException("getColumnDisplaySize not implemented"); + } + + @Override + default String getColumnLabel(int column) throws SQLException { + throw new SQLFeatureNotSupportedException("getColumnLabel not implemented"); + } + + @Override + default String getColumnName(int column) throws SQLException { + throw new SQLFeatureNotSupportedException("getColumnName not implemented"); + } + + @Override + default String getSchemaName(int column) throws SQLException { + throw new SQLFeatureNotSupportedException("getSchemaName not implemented"); + } + + @Override + default int getPrecision(int column) throws SQLException { + throw new SQLFeatureNotSupportedException("getPrecision not implemented"); + } + + @Override + default int getScale(int column) throws SQLException { + throw new SQLFeatureNotSupportedException("getScale not implemented"); + } + + @Override + default String getTableName(int column) throws SQLException { + throw new SQLFeatureNotSupportedException("getTableName not implemented"); + } + + @Override + default String getCatalogName(int column) throws SQLException { + throw new SQLFeatureNotSupportedException("getCatalogName not implemented"); + } + + @Override + default int getColumnType(int column) throws SQLException { + throw new SQLFeatureNotSupportedException("getColumnType not implemented"); + } + + @Override + default String getColumnTypeName(int column) throws SQLException { + throw new SQLFeatureNotSupportedException("getColumnTypeName not implemented"); + } + + @Override + default boolean isReadOnly(int column) throws SQLException { + throw new SQLFeatureNotSupportedException("isReadOnly not implemented"); + } + + @Override + default boolean isWritable(int column) throws SQLException { + throw new SQLFeatureNotSupportedException("isWritable not implemented"); + } + + @Override + default boolean isDefinitelyWritable(int column) throws SQLException { + throw new SQLFeatureNotSupportedException("isDefinitelyWritable not implemented"); + } + + @Override + default String getColumnClassName(int column) throws SQLException { + throw new SQLFeatureNotSupportedException("getColumnClassName not implemented"); + } + + @Override + default T unwrap(Class iface) throws SQLException { + throw new SQLFeatureNotSupportedException("unwrap not implemented"); + } + + @Override + default boolean isWrapperFor(Class iface) throws SQLException { + throw new SQLFeatureNotSupportedException("isWrapperFor not implemented"); + } +} diff --git a/src/test/java/com/mongodb/hibernate/jdbc/MongoConnectionTests.java b/src/test/java/com/mongodb/hibernate/jdbc/MongoConnectionTests.java index c45c3a6f..bed04c3e 100644 --- a/src/test/java/com/mongodb/hibernate/jdbc/MongoConnectionTests.java +++ b/src/test/java/com/mongodb/hibernate/jdbc/MongoConnectionTests.java @@ -18,7 +18,11 @@ import static com.mongodb.hibernate.jdbc.MongoDatabaseMetaData.MONGO_DATABASE_PRODUCT_NAME; import static com.mongodb.hibernate.jdbc.MongoDatabaseMetaData.MONGO_JDBC_DRIVER_NAME; -import static org.assertj.core.api.Assertions.assertThat; +import static java.sql.ResultSet.CONCUR_READ_ONLY; +import static java.sql.ResultSet.CONCUR_UPDATABLE; +import static java.sql.ResultSet.TYPE_FORWARD_ONLY; +import static java.sql.ResultSet.TYPE_SCROLL_INSENSITIVE; +import static java.sql.ResultSet.TYPE_SCROLL_SENSITIVE; import static org.hibernate.cfg.AvailableSettings.JAKARTA_JDBC_URL; import static org.junit.jupiter.api.Assertions.assertAll; import static org.junit.jupiter.api.Assertions.assertDoesNotThrow; @@ -40,19 +44,16 @@ import com.mongodb.client.MongoDatabase; import com.mongodb.hibernate.internal.cfg.MongoConfigurationBuilder; import java.sql.Connection; -import java.sql.ResultSet; import java.sql.SQLException; import java.util.Map; -import java.util.stream.Stream; import org.bson.Document; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.DisplayName; import org.junit.jupiter.api.Nested; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; +import org.junit.jupiter.api.function.Executable; import org.junit.jupiter.params.ParameterizedTest; -import org.junit.jupiter.params.provider.Arguments; -import org.junit.jupiter.params.provider.MethodSource; import org.junit.jupiter.params.provider.ValueSource; import org.mockito.Mock; import org.mockito.junit.jupiter.MockitoExtension; @@ -102,72 +103,57 @@ void testClosedWhenSessionClosingThrowsException() { } } - @Nested - class ClosedTests { - - interface ConnectionMethodInvocation { - void runOn(MongoConnection conn) throws SQLException; - } - - @ParameterizedTest(name = "SQLException is thrown when \"{0}\" is called on a closed MongoConnection") - @MethodSource("getMongoConnectionMethodInvocationsImpactedByClosing") - void testCheckClosed(String label, ConnectionMethodInvocation methodInvocation) throws SQLException { - - mongoConnection.close(); - - var sqlException = assertThrows(SQLException.class, () -> methodInvocation.runOn(mongoConnection)); - assertThat(sqlException.getMessage()).matches("MongoConnection has been closed"); - } + @Test + void testCheckClosed() throws SQLException { + mongoConnection.close(); + checkMethodsWithOpenPrecondition(); + } - private static Stream getMongoConnectionMethodInvocationsImpactedByClosing() { - var exampleQueryMql = - """ - { - find: "restaurants", - filter: { rating: { $gte: 9 }, cuisine: "italian" }, - projection: { name: 1, rating: 1, address: 1 }, - sort: { name: 1 }, - limit: 5 - }"""; - var exampleUpdateMql = - """ + private void checkMethodsWithOpenPrecondition() { + var exampleQueryMql = + """ + { + find: "restaurants", + filter: { rating: { $gte: 9 }, cuisine: "italian" }, + projection: { name: 1, rating: 1, address: 1 }, + sort: { name: 1 }, + limit: 5 + }"""; + var exampleUpdateMql = + """ + { + update: "members", + updates: [ { - update: "members", - updates: [ - { - q: {}, - u: { $inc: { points: 1 } }, - multi: true - } - ] - }"""; - return Map.ofEntries( - Map.entry("setAutoCommit(boolean)", conn -> conn.setAutoCommit(false)), - Map.entry("getAutoCommit()", MongoConnection::getAutoCommit), - Map.entry("commit()", MongoConnection::commit), - Map.entry("rollback()", MongoConnection::rollback), - Map.entry("createStatement()", MongoConnection::createStatement), - Map.entry("prepareStatement(String)", conn -> conn.prepareStatement(exampleUpdateMql)), - Map.entry( - "prepareStatement(String,int,int)", - conn -> conn.prepareStatement( - exampleQueryMql, ResultSet.TYPE_FORWARD_ONLY, ResultSet.CONCUR_READ_ONLY)), - Map.entry( - "createArrayOf(String,Object[])", - conn -> conn.createArrayOf("myArrayType", new String[] {"value1", "value2"})), - Map.entry( - "createStruct(String,Object[])", - conn -> conn.createStruct("myStructType", new Object[] {1, "Toronto"})), - Map.entry("getMetaData()", MongoConnection::getMetaData), - Map.entry("getCatalog()", MongoConnection::getCatalog), - Map.entry("getSchema()", MongoConnection::getSchema), - Map.entry("getWarnings()", MongoConnection::getWarnings), - Map.entry("clearWarnings()", MongoConnection::clearWarnings), - Map.entry("isWrapperFor()", conn -> conn.isWrapperFor(Connection.class))) - .entrySet() - .stream() - .map(entry -> Arguments.of(entry.getKey(), entry.getValue())); - } + q: {}, + u: { $inc: { points: 1 } }, + multi: true + } + ] + }"""; + assertAll( + () -> assertThrowsClosedException(() -> mongoConnection.setAutoCommit(false)), + () -> assertThrowsClosedException(mongoConnection::getAutoCommit), + () -> assertThrowsClosedException(mongoConnection::rollback), + () -> assertThrowsClosedException(mongoConnection::createStatement), + () -> assertThrowsClosedException(() -> mongoConnection.prepareStatement(exampleUpdateMql)), + () -> assertThrowsClosedException( + () -> mongoConnection.prepareStatement(exampleQueryMql, TYPE_FORWARD_ONLY, CONCUR_READ_ONLY)), + () -> assertThrowsClosedException( + () -> mongoConnection.createArrayOf("myArrayType", new String[] {"value1", "value2"})), + () -> assertThrowsClosedException( + () -> mongoConnection.createStruct("myStructType", new Object[] {1, "Toronto"})), + () -> assertThrowsClosedException(mongoConnection::getMetaData), + () -> assertThrowsClosedException(mongoConnection::getCatalog), + () -> assertThrowsClosedException(mongoConnection::getSchema), + () -> assertThrowsClosedException(mongoConnection::getWarnings), + () -> assertThrowsClosedException(mongoConnection::clearWarnings), + () -> assertThrowsClosedException(() -> mongoConnection.isWrapperFor(Connection.class))); + } + + private static void assertThrowsClosedException(Executable executable) { + var exception = assertThrows(SQLException.class, executable); + assertEquals("MongoConnection has been closed", exception.getMessage()); } @Nested @@ -356,4 +342,36 @@ void testSQLExceptionThrownWhenMetaDataFetchingFailed() { assertThrows(SQLException.class, () -> mongoConnection.getMetaData()); } } + + @Nested + class ResultSetSupportTests { + + private static final String EXAMPLE_MQL = "{}"; + + @ParameterizedTest(name = "ResultSet type: {0}") + @ValueSource(ints = {TYPE_FORWARD_ONLY, TYPE_SCROLL_SENSITIVE, TYPE_SCROLL_INSENSITIVE}) + void testType(int resultSetType) { + Executable executable = () -> mongoConnection + .prepareStatement(EXAMPLE_MQL, resultSetType, CONCUR_READ_ONLY) + .close(); + if (resultSetType == TYPE_FORWARD_ONLY) { + assertDoesNotThrow(executable); + } else { + assertThrows(SQLException.class, executable); + } + } + + @ParameterizedTest(name = "ResultSet concurrency: {0}") + @ValueSource(ints = {CONCUR_READ_ONLY, CONCUR_UPDATABLE}) + void testConcurrency(int resultSetConcurrency) { + Executable executable = () -> mongoConnection + .prepareStatement(EXAMPLE_MQL, TYPE_FORWARD_ONLY, resultSetConcurrency) + .close(); + if (resultSetConcurrency == CONCUR_READ_ONLY) { + assertDoesNotThrow(executable); + } else { + assertThrows(SQLException.class, executable); + } + } + } } diff --git a/src/test/java/com/mongodb/hibernate/jdbc/MongoPreparedStatementTests.java b/src/test/java/com/mongodb/hibernate/jdbc/MongoPreparedStatementTests.java index b1964d5c..34c268ea 100644 --- a/src/test/java/com/mongodb/hibernate/jdbc/MongoPreparedStatementTests.java +++ b/src/test/java/com/mongodb/hibernate/jdbc/MongoPreparedStatementTests.java @@ -16,40 +16,44 @@ package com.mongodb.hibernate.jdbc; -import static java.lang.String.format; import static org.assertj.core.api.Assertions.assertThat; +import static org.junit.jupiter.api.Assertions.assertAll; import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyList; +import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.ArgumentMatchers.same; import static org.mockito.Mockito.doReturn; import static org.mockito.Mockito.verify; +import com.mongodb.client.AggregateIterable; import com.mongodb.client.ClientSession; +import com.mongodb.client.MongoCollection; +import com.mongodb.client.MongoCursor; import com.mongodb.client.MongoDatabase; -import java.io.ByteArrayInputStream; import java.math.BigDecimal; import java.sql.Array; import java.sql.Date; +import java.sql.ResultSet; import java.sql.SQLException; import java.sql.SQLSyntaxErrorException; import java.sql.Time; import java.sql.Timestamp; import java.sql.Types; import java.util.Calendar; -import java.util.Map; -import java.util.stream.Stream; +import java.util.function.Consumer; import org.bson.BsonDocument; import org.bson.Document; -import org.junit.jupiter.api.AutoClose; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.DisplayName; import org.junit.jupiter.api.Nested; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; -import org.junit.jupiter.params.ParameterizedTest; -import org.junit.jupiter.params.provider.Arguments; -import org.junit.jupiter.params.provider.MethodSource; +import org.junit.jupiter.api.function.Executable; import org.mockito.ArgumentCaptor; import org.mockito.Captor; import org.mockito.Mock; @@ -136,180 +140,142 @@ void testSuccess() throws SQLException { assertEquals(expectedDoc, command); } } - - @Test - @DisplayName("SQLException is thrown when parameter index is invalid") - void testParameterIndexInvalid() throws SQLSyntaxErrorException { - try (var preparedStatement = createMongoPreparedStatement(EXAMPLE_MQL)) { - var sqlException = - assertThrows(SQLException.class, () -> preparedStatement.setString(0, "War and Peace")); - assertEquals( - format("Parameter index invalid: %d; should be within [1, %d]", 0, 5), - sqlException.getMessage()); - } - } } - @Nested - class CloseTests { + @Test + void testParameterIndexUnderflow() throws SQLSyntaxErrorException { + var mongoPreparedStatement = createMongoPreparedStatement(EXAMPLE_MQL); + checkSetterMethods(mongoPreparedStatement, 0, MongoPreparedStatementTests::assertThrowsOutOfRangeException); + } - @ParameterizedTest(name = "SQLException is thrown when \"{0}\" is called on a closed MongoPreparedStatement") - @MethodSource("getMongoPreparedStatementMethodInvocationsImpactedByClosing") - void testCheckClosed(String label, PreparedStatementMethodInvocation methodInvocation) - throws SQLSyntaxErrorException { + @Test + void testParameterIndexOverflow() throws SQLSyntaxErrorException { + var mongoPreparedStatement = createMongoPreparedStatement(EXAMPLE_MQL); + checkSetterMethods(mongoPreparedStatement, 6, MongoPreparedStatementTests::assertThrowsOutOfRangeException); + } - var mql = - """ - { - insert: "books", - documents: [ - { - title: "War and Peace", - author: "Leo Tolstoy", - outOfStock: false, - values: [ - { $undefined: true } - ] - } - ] - } - """; + @Nested + class ExecuteMethodClosesLastOpenResultSetTests { - var preparedStatement = createMongoPreparedStatement(mql); - preparedStatement.close(); + @Mock + MongoCollection mongoCollection; - var sqlException = assertThrows(SQLException.class, () -> methodInvocation.runOn(preparedStatement)); - assertThat(sqlException.getMessage()).matches("MongoPreparedStatement has been closed"); - } + @Mock + AggregateIterable aggregateIterable; - private static Stream getMongoPreparedStatementMethodInvocationsImpactedByClosing() { - var now = System.currentTimeMillis(); - var calendar = Calendar.getInstance(); - return Map.ofEntries( - Map.entry("executeQuery()", MongoPreparedStatement::executeQuery), - Map.entry("executeUpdate()", MongoPreparedStatement::executeUpdate), - Map.entry("setNull(int,int)", pstmt -> pstmt.setNull(1, Types.INTEGER)), - Map.entry("setBoolean(int,boolean)", pstmt -> pstmt.setBoolean(1, true)), - Map.entry("setByte(int,byte)", pstmt -> pstmt.setByte(1, (byte) 10)), - Map.entry("setShort(int,short)", pstmt -> pstmt.setShort(1, (short) 10)), - Map.entry("setInt(int,int)", pstmt -> pstmt.setInt(1, 1)), - Map.entry("setLong(int,long)", pstmt -> pstmt.setLong(1, 1L)), - Map.entry("setFloat(int,float)", pstmt -> pstmt.setFloat(1, 1.0F)), - Map.entry("setDouble(int,double)", pstmt -> pstmt.setDouble(1, 1.0)), - Map.entry( - "setBigDecimal(int,BigDecimal)", - pstmt -> pstmt.setBigDecimal(1, new BigDecimal(1))), - Map.entry("setString(int,String)", pstmt -> pstmt.setString(1, "")), - Map.entry("setBytes(int,byte[])", pstmt -> pstmt.setBytes(1, "".getBytes())), - Map.entry("setDate(int,Date)", pstmt -> pstmt.setDate(1, new Date(now))), - Map.entry("setTime(int,Time)", pstmt -> pstmt.setTime(1, new Time(now))), - Map.entry( - "setTimestamp(int,Timestamp)", pstmt -> pstmt.setTimestamp(1, new Timestamp(now))), - Map.entry( - "setBinaryStream(int,InputStream,int)", - pstmt -> pstmt.setBinaryStream(1, new ByteArrayInputStream("".getBytes()), 0)), - Map.entry( - "setObject(int,Object,int)", - pstmt -> pstmt.setObject(1, Mockito.mock(Array.class), Types.OTHER)), - Map.entry("addBatch()", MongoPreparedStatement::addBatch), - Map.entry("setArray(int,Array)", pstmt -> pstmt.setArray(1, Mockito.mock(Array.class))), - Map.entry("setDate(int,Date,Calendar)", pstmt -> pstmt.setDate(1, new Date(now), calendar)), - Map.entry("setTime(int,Time,Calendar)", pstmt -> pstmt.setTime(1, new Time(now), calendar)), - Map.entry( - "setTimestamp(int,Timestamp,Calendar)", - pstmt -> pstmt.setTimestamp(1, new Timestamp(now), calendar)), - Map.entry("setNull(int,Object,String)", pstmt -> pstmt.setNull(1, Types.STRUCT, "BOOK"))) - .entrySet() - .stream() - .map(entry -> Arguments.of(entry.getKey(), entry.getValue())); - } - } + @Mock + MongoCursor mongoCursor; - @Nested - class ParameterIndexCheckingTests { + private ResultSet lastOpenResultSet; - @AutoClose - private MongoPreparedStatement preparedStatement; + private MongoPreparedStatement mongoPreparedStatement; @BeforeEach - void beforeEach() throws SQLSyntaxErrorException { - preparedStatement = createMongoPreparedStatement(EXAMPLE_MQL); + void beforeEach() throws SQLException { + String exampleQueryMql = + """ + { + aggregate: "books", + pipeline: [ + { $match: { _id: { $eq: 1 } } }, + { $project: { _id: 0, title: 1, publishYear: 1 } } + ] + }"""; + mongoPreparedStatement = createMongoPreparedStatement(exampleQueryMql); + doReturn(mongoCollection).when(mongoDatabase).getCollection(anyString(), eq(BsonDocument.class)); + doReturn(aggregateIterable).when(mongoCollection).aggregate(same(clientSession), anyList()); + doReturn(mongoCursor).when(aggregateIterable).cursor(); + + lastOpenResultSet = mongoPreparedStatement.executeQuery(); + assertFalse(lastOpenResultSet.isClosed()); } - @ParameterizedTest(name = "SQLException is thrown when \"{0}\" is called with parameter index being too low") - @MethodSource("getMongoPreparedStatementMethodInvocationsWithParameterIndexUnderflow") - void testParameterIndexUnderflow(String label, PreparedStatementMethodInvocation methodInvocation) { - var sqlException = assertThrows(SQLException.class, () -> methodInvocation.runOn(preparedStatement)); - assertThat(sqlException.getMessage()).startsWith("Parameter index invalid"); + @Test + void testExecuteQuery() throws SQLException { + mongoPreparedStatement.executeQuery(); + assertTrue(lastOpenResultSet.isClosed()); } - @ParameterizedTest(name = "SQLException is thrown when \"{0}\" is called with parameter index being too high") - @MethodSource("getMongoPreparedStatementMethodInvocationsWithParameterIndexOverflow") - void testParameterIndexOverflow(String label, PreparedStatementMethodInvocation methodInvocation) { - var sqlException = assertThrows(SQLException.class, () -> methodInvocation.runOn(preparedStatement)); - assertThat(sqlException.getMessage()).startsWith("Parameter index invalid"); + @Test + void testExecuteUpdate() throws SQLException { + doReturn(Document.parse("{n: 10}")) + .when(mongoDatabase) + .runCommand(eq(clientSession), any(BsonDocument.class)); + mongoPreparedStatement.executeUpdate(); + assertTrue(lastOpenResultSet.isClosed()); } + } - private static Stream getMongoPreparedStatementMethodInvocationsWithParameterIndexUnderflow() { - return doGetMongoPreparedStatementMethodInvocationsWithParameterIndex(0); - } + @Test + void testCheckClosed() throws SQLException { + var mql = + """ + { + insert: "books", + documents: [ + { + title: "War and Peace", + author: "Leo Tolstoy", + outOfStock: false, + values: [ + { $undefined: true } + ] + } + ] + } + """; + var mongoPreparedStatement = createMongoPreparedStatement(mql); + mongoPreparedStatement.close(); + checkMethodsWithOpenPrecondition( + mongoPreparedStatement, MongoPreparedStatementTests::assertThrowsClosedException); + } - private static Stream getMongoPreparedStatementMethodInvocationsWithParameterIndexOverflow() { - return doGetMongoPreparedStatementMethodInvocationsWithParameterIndex(6); - } + private static void checkSetterMethods( + MongoPreparedStatement mongoPreparedStatement, int parameterIndex, Consumer asserter) { + var now = System.currentTimeMillis(); + var calendar = Calendar.getInstance(); + assertAll( + () -> asserter.accept(() -> mongoPreparedStatement.setNull(parameterIndex, Types.INTEGER)), + () -> asserter.accept(() -> mongoPreparedStatement.setBoolean(parameterIndex, true)), + () -> asserter.accept(() -> mongoPreparedStatement.setInt(parameterIndex, 1)), + () -> asserter.accept(() -> mongoPreparedStatement.setLong(parameterIndex, 1L)), + () -> asserter.accept(() -> mongoPreparedStatement.setDouble(parameterIndex, 1.0)), + () -> asserter.accept(() -> mongoPreparedStatement.setBigDecimal(parameterIndex, new BigDecimal(1))), + () -> asserter.accept(() -> mongoPreparedStatement.setString(parameterIndex, "")), + () -> asserter.accept(() -> mongoPreparedStatement.setBytes(parameterIndex, "".getBytes())), + () -> asserter.accept(() -> mongoPreparedStatement.setDate(parameterIndex, new Date(now))), + () -> asserter.accept(() -> mongoPreparedStatement.setTime(parameterIndex, new Time(now))), + () -> asserter.accept(() -> mongoPreparedStatement.setTimestamp(parameterIndex, new Timestamp(now))), + () -> asserter.accept( + () -> mongoPreparedStatement.setObject(parameterIndex, Mockito.mock(Array.class), Types.OTHER)), + () -> asserter.accept( + () -> mongoPreparedStatement.setObject(parameterIndex, Mockito.mock(Array.class))), + () -> asserter.accept(() -> mongoPreparedStatement.setArray(parameterIndex, Mockito.mock(Array.class))), + () -> asserter.accept(() -> mongoPreparedStatement.setDate(parameterIndex, new Date(now), calendar)), + () -> asserter.accept(() -> mongoPreparedStatement.setTime(parameterIndex, new Time(now), calendar)), + () -> asserter.accept( + () -> mongoPreparedStatement.setTimestamp(parameterIndex, new Timestamp(now), calendar)), + () -> asserter.accept(() -> mongoPreparedStatement.setNull(parameterIndex, Types.STRUCT, "BOOK"))); + } - private static Stream doGetMongoPreparedStatementMethodInvocationsWithParameterIndex( - int parameterIndex) { - var now = System.currentTimeMillis(); - var calendar = Calendar.getInstance(); - return Map.ofEntries( - Map.entry("setNull(int,int)", pstmt -> pstmt.setNull(parameterIndex, Types.INTEGER)), - Map.entry("setBoolean(int,boolean)", pstmt -> pstmt.setBoolean(parameterIndex, true)), - Map.entry("setByte(int,byte)", pstmt -> pstmt.setByte(parameterIndex, (byte) 10)), - Map.entry("setShort(int,short)", pstmt -> pstmt.setShort(parameterIndex, (short) 10)), - Map.entry("setInt(int,int)", pstmt -> pstmt.setInt(parameterIndex, 1)), - Map.entry("setLong(int,long)", pstmt -> pstmt.setLong(parameterIndex, 1L)), - Map.entry("setFloat(int,float)", pstmt -> pstmt.setFloat(parameterIndex, 1.0F)), - Map.entry("setDouble(int,double)", pstmt -> pstmt.setDouble(parameterIndex, 1.0)), - Map.entry( - "setBigDecimal(int,BigDecimal)", - pstmt -> pstmt.setBigDecimal(parameterIndex, new BigDecimal(1))), - Map.entry("setString(int,String)", pstmt -> pstmt.setString(parameterIndex, "")), - Map.entry("setBytes(int,byte[])", pstmt -> pstmt.setBytes(parameterIndex, "".getBytes())), - Map.entry("setDate(int,Date)", pstmt -> pstmt.setDate(parameterIndex, new Date(now))), - Map.entry("setTime(int,Time)", pstmt -> pstmt.setTime(parameterIndex, new Time(now))), - Map.entry( - "setTimestamp(int,Timestamp)", - pstmt -> pstmt.setTimestamp(parameterIndex, new Timestamp(now))), - Map.entry( - "setBinaryStream(int,InputStream,int)", - pstmt -> pstmt.setBinaryStream( - parameterIndex, new ByteArrayInputStream("".getBytes()), 0)), - Map.entry( - "setObject(int,Object,int)", - pstmt -> pstmt.setObject(parameterIndex, Mockito.mock(Array.class), Types.OTHER)), - Map.entry( - "setArray(int,Array)", - pstmt -> pstmt.setArray(parameterIndex, Mockito.mock(Array.class))), - Map.entry( - "setDate(int,Date,Calendar)", - pstmt -> pstmt.setDate(parameterIndex, new Date(now), calendar)), - Map.entry( - "setTime(int,Time,Calendar)", - pstmt -> pstmt.setTime(parameterIndex, new Time(now), calendar)), - Map.entry( - "setTimestamp(int,Timestamp,Calendar)", - pstmt -> pstmt.setTimestamp(parameterIndex, new Timestamp(now), calendar)), - Map.entry( - "setNull(int,Object,String)", - pstmt -> pstmt.setNull(parameterIndex, Types.STRUCT, "BOOK"))) - .entrySet() - .stream() - .map(entry -> Arguments.of(entry.getKey(), entry.getValue())); - } + private static void checkMethodsWithOpenPrecondition( + MongoPreparedStatement mongoPreparedStatement, Consumer asserter) { + checkSetterMethods(mongoPreparedStatement, 1, asserter); + assertAll( + () -> asserter.accept(mongoPreparedStatement::executeQuery), + () -> asserter.accept(mongoPreparedStatement::executeUpdate), + () -> asserter.accept(mongoPreparedStatement::addBatch), + () -> asserter.accept(() -> mongoPreparedStatement.setQueryTimeout(20_000)), + () -> asserter.accept(() -> mongoPreparedStatement.setFetchSize(10))); + } + + private static void assertThrowsOutOfRangeException(Executable executable) { + var e = assertThrows(SQLException.class, executable); + assertThat(e.getMessage()).startsWith("Invalid parameter index"); } - interface PreparedStatementMethodInvocation { - void runOn(MongoPreparedStatement pstmt) throws SQLException; + private static void assertThrowsClosedException(Executable executable) { + var exception = assertThrows(SQLException.class, executable); + assertThat(exception.getMessage()).isEqualTo("MongoPreparedStatement has been closed"); } } diff --git a/src/test/java/com/mongodb/hibernate/jdbc/MongoResultSetTests.java b/src/test/java/com/mongodb/hibernate/jdbc/MongoResultSetTests.java new file mode 100644 index 00000000..a23369cd --- /dev/null +++ b/src/test/java/com/mongodb/hibernate/jdbc/MongoResultSetTests.java @@ -0,0 +1,245 @@ +/* + * Copyright 2025-present MongoDB, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.mongodb.hibernate.jdbc; + +import static java.util.Collections.singletonList; +import static org.assertj.core.api.Assertions.assertThat; +import static org.junit.jupiter.api.Assertions.assertAll; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.mockito.Mockito.doReturn; + +import com.mongodb.client.MongoCursor; +import java.math.BigDecimal; +import java.sql.SQLException; +import java.util.Calendar; +import java.util.List; +import java.util.UUID; +import java.util.function.Consumer; +import org.bson.BsonBinary; +import org.bson.BsonBoolean; +import org.bson.BsonDecimal128; +import org.bson.BsonDocument; +import org.bson.BsonDouble; +import org.bson.BsonInt32; +import org.bson.BsonInt64; +import org.bson.BsonNull; +import org.bson.BsonString; +import org.bson.BsonValue; +import org.bson.types.Decimal128; +import org.junit.jupiter.api.AutoClose; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Nested; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.junit.jupiter.api.function.Executable; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; + +@ExtendWith(MockitoExtension.class) +class MongoResultSetTests { + + private static final List FIELDS = List.of("id", "title", "publishYear"); + + @Mock + private MongoCursor mongoCursor; + + @AutoClose + private MongoResultSet mongoResultSet; + + @BeforeEach + void beforeEach() { + mongoResultSet = new MongoResultSet(mongoCursor, FIELDS); + } + + @Test + void testColumnIndexUnderflow() { + checkGetterMethods(0, MongoResultSetTests::assertThrowsOutOfRangeException); + } + + @Test + void testColumnIndexOverflow() { + checkGetterMethods(FIELDS.size() + 1, MongoResultSetTests::assertThrowsOutOfRangeException); + } + + @Test + void testCheckClosed() throws SQLException { + mongoResultSet.close(); + checkMethodsWithOpenPrecondition(MongoResultSetTests::assertThrowsClosedException); + } + + @Nested + class GettersTests { + + private void createResultSetWith(BsonValue value) throws SQLException { + var bsonDocument = new BsonDocument().append("field", value); + + doReturn(true).when(mongoCursor).hasNext(); + doReturn(bsonDocument).when(mongoCursor).next(); + mongoResultSet = new MongoResultSet(mongoCursor, singletonList("field")); + assertTrue(mongoResultSet.next()); + } + + @Test + void testGettersForNull() throws SQLException { + createResultSetWith(new BsonNull()); + assertAll( + () -> assertNull(mongoResultSet.getString(1)), + () -> assertFalse(mongoResultSet.getBoolean(1)), + () -> assertEquals(0, mongoResultSet.getInt(1)), + () -> assertEquals(0L, mongoResultSet.getLong(1)), + () -> assertEquals(0D, mongoResultSet.getDouble(1)), + () -> assertNull(mongoResultSet.getBigDecimal(1))); + } + + @Test + void testGettersForBoolean() throws SQLException { + createResultSetWith(new BsonBoolean(true)); + assertAll( + () -> assertThrowsTypeMismatchException(() -> mongoResultSet.getString(1)), + () -> assertTrue(mongoResultSet.getBoolean(1)), + () -> assertThrowsTypeMismatchException(() -> mongoResultSet.getInt(1)), + () -> assertThrowsTypeMismatchException(() -> mongoResultSet.getLong(1)), + () -> assertThrowsTypeMismatchException(() -> mongoResultSet.getDouble(1)), + () -> assertThrowsTypeMismatchException(() -> mongoResultSet.getBigDecimal(1))); + } + + @Test + void testGettersForDouble() throws SQLException { + createResultSetWith(new BsonDouble(3.1415)); + assertAll( + () -> assertThrowsTypeMismatchException(() -> mongoResultSet.getString(1)), + () -> assertThrowsTypeMismatchException(() -> mongoResultSet.getBoolean(1)), + () -> assertThrowsTypeMismatchException(() -> mongoResultSet.getInt(1)), + () -> assertThrowsTypeMismatchException(() -> mongoResultSet.getLong(1)), + () -> assertEquals(3.1415, mongoResultSet.getDouble(1)), + () -> assertThrowsTypeMismatchException(() -> mongoResultSet.getBigDecimal(1))); + } + + @Test + void testGettersForInt() throws SQLException { + createResultSetWith(new BsonInt32(120)); + assertAll( + () -> assertThrowsTypeMismatchException(() -> mongoResultSet.getString(1)), + () -> assertThrowsTypeMismatchException(() -> mongoResultSet.getBoolean(1)), + () -> assertEquals(120, mongoResultSet.getInt(1)), + () -> assertThrowsTypeMismatchException(() -> mongoResultSet.getLong(1)), + () -> assertThrowsTypeMismatchException(() -> mongoResultSet.getDouble(1)), + () -> assertThrowsTypeMismatchException(() -> mongoResultSet.getBigDecimal(1))); + } + + @Test + void testGettersForLong() throws SQLException { + createResultSetWith(new BsonInt64(12345678)); + assertAll( + () -> assertThrowsTypeMismatchException(() -> mongoResultSet.getString(1)), + () -> assertThrowsTypeMismatchException(() -> mongoResultSet.getBoolean(1)), + () -> assertThrowsTypeMismatchException(() -> mongoResultSet.getInt(1)), + () -> assertEquals(12345678L, mongoResultSet.getLong(1)), + () -> assertThrowsTypeMismatchException(() -> mongoResultSet.getDouble(1)), + () -> assertThrowsTypeMismatchException(() -> mongoResultSet.getBigDecimal(1))); + } + + @Test + void testGettersForString() throws SQLException { + createResultSetWith(new BsonString("Hello World")); + assertAll( + () -> assertEquals("Hello World", mongoResultSet.getString(1)), + () -> assertThrowsTypeMismatchException(() -> mongoResultSet.getBoolean(1)), + () -> assertThrowsTypeMismatchException(() -> mongoResultSet.getInt(1)), + () -> assertThrowsTypeMismatchException(() -> mongoResultSet.getLong(1)), + () -> assertThrowsTypeMismatchException(() -> mongoResultSet.getDouble(1)), + () -> assertThrowsTypeMismatchException(() -> mongoResultSet.getBigDecimal(1))); + } + + @Test + void testGettersForBigDecimal() throws SQLException { + var bigDecimalValue = new BigDecimal("10692467440017.111"); + var value = new BsonDecimal128(new Decimal128(bigDecimalValue)); + createResultSetWith(value); + assertAll( + () -> assertThrowsTypeMismatchException(() -> mongoResultSet.getString(1)), + () -> assertThrowsTypeMismatchException(() -> mongoResultSet.getBoolean(1)), + () -> assertThrowsTypeMismatchException(() -> mongoResultSet.getInt(1)), + () -> assertThrowsTypeMismatchException(() -> mongoResultSet.getLong(1)), + () -> assertThrowsTypeMismatchException(() -> mongoResultSet.getDouble(1)), + () -> assertEquals(bigDecimalValue, mongoResultSet.getBigDecimal(1))); + } + + @Test + void testGettersForBinary() throws SQLException { + var bytes = UUID.randomUUID().toString().getBytes(); + var value = new BsonBinary(bytes); + createResultSetWith(value); + assertAll( + () -> assertThrowsTypeMismatchException(() -> mongoResultSet.getString(1)), + () -> assertThrowsTypeMismatchException(() -> mongoResultSet.getBoolean(1)), + () -> assertThrowsTypeMismatchException(() -> mongoResultSet.getInt(1)), + () -> assertThrowsTypeMismatchException(() -> mongoResultSet.getLong(1)), + () -> assertThrowsTypeMismatchException(() -> mongoResultSet.getDouble(1)), + () -> assertThrowsTypeMismatchException(() -> mongoResultSet.getBigDecimal(1)), + () -> assertEquals(bytes, mongoResultSet.getBytes(1))); + } + } + + private void checkMethodsWithOpenPrecondition(Consumer asserter) { + checkGetterMethods(1, asserter); + assertAll( + () -> asserter.accept(() -> mongoResultSet.next()), + () -> asserter.accept(() -> mongoResultSet.wasNull()), + () -> asserter.accept(() -> mongoResultSet.getMetaData()), + () -> asserter.accept(() -> mongoResultSet.findColumn("id")), + () -> asserter.accept(() -> mongoResultSet.getMetaData()), + () -> asserter.accept(() -> mongoResultSet.isWrapperFor(MongoResultSet.class))); + } + + private void checkGetterMethods(int columnIndex, Consumer asserter) { + assertAll( + () -> asserter.accept(() -> mongoResultSet.getString(columnIndex)), + () -> asserter.accept(() -> mongoResultSet.getBoolean(columnIndex)), + () -> asserter.accept(() -> mongoResultSet.getInt(columnIndex)), + () -> asserter.accept(() -> mongoResultSet.getLong(columnIndex)), + () -> asserter.accept(() -> mongoResultSet.getDouble(columnIndex)), + () -> asserter.accept(() -> mongoResultSet.getBytes(columnIndex)), + () -> asserter.accept(() -> mongoResultSet.getDate(columnIndex)), + () -> asserter.accept(() -> mongoResultSet.getTime(columnIndex)), + () -> asserter.accept(() -> mongoResultSet.getTime(columnIndex, Calendar.getInstance())), + () -> asserter.accept(() -> mongoResultSet.getTimestamp(columnIndex)), + () -> asserter.accept(() -> mongoResultSet.getTimestamp(columnIndex, Calendar.getInstance())), + () -> asserter.accept(() -> mongoResultSet.getBigDecimal(columnIndex)), + () -> asserter.accept(() -> mongoResultSet.getArray(columnIndex)), + () -> asserter.accept(() -> mongoResultSet.getObject(columnIndex, UUID.class))); + } + + private static void assertThrowsOutOfRangeException(Executable executable) { + var e = assertThrows(SQLException.class, executable); + assertThat(e.getMessage()).startsWith("Invalid column index"); + } + + private static void assertThrowsClosedException(Executable executable) { + var exception = assertThrows(SQLException.class, executable); + assertThat(exception.getMessage()).isEqualTo("MongoResultSet has been closed"); + } + + private static void assertThrowsTypeMismatchException(Executable executable) { + var exception = assertThrows(SQLException.class, executable); + assertThat(exception.getMessage()).startsWith("Failed to get value from column"); + } +} diff --git a/src/test/java/com/mongodb/hibernate/jdbc/MongoStatementTests.java b/src/test/java/com/mongodb/hibernate/jdbc/MongoStatementTests.java index 554f7541..7e48a389 100644 --- a/src/test/java/com/mongodb/hibernate/jdbc/MongoStatementTests.java +++ b/src/test/java/com/mongodb/hibernate/jdbc/MongoStatementTests.java @@ -17,30 +17,38 @@ package com.mongodb.hibernate.jdbc; import static org.assertj.core.api.Assertions.assertThat; +import static org.junit.jupiter.api.Assertions.assertAll; import static org.junit.jupiter.api.Assertions.assertDoesNotThrow; import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; import static org.junit.jupiter.api.Assertions.assertThrows; import static org.junit.jupiter.api.Assertions.assertTrue; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyList; +import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.ArgumentMatchers.eq; import static org.mockito.ArgumentMatchers.same; +import static org.mockito.Mockito.doReturn; import static org.mockito.Mockito.doThrow; +import com.mongodb.client.AggregateIterable; import com.mongodb.client.ClientSession; +import com.mongodb.client.MongoCollection; +import com.mongodb.client.MongoCursor; import com.mongodb.client.MongoDatabase; +import com.mongodb.hibernate.internal.FeatureNotSupportedException; +import java.sql.ResultSet; import java.sql.SQLException; import java.sql.SQLSyntaxErrorException; -import java.sql.Statement; -import java.util.Map; -import java.util.stream.Stream; +import java.util.List; +import java.util.function.BiConsumer; import org.bson.BsonDocument; -import org.junit.jupiter.api.DisplayName; +import org.bson.Document; +import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Nested; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; -import org.junit.jupiter.params.ParameterizedTest; -import org.junit.jupiter.params.provider.Arguments; -import org.junit.jupiter.params.provider.MethodSource; -import org.mockito.InjectMocks; +import org.junit.jupiter.api.function.Executable; import org.mockito.Mock; import org.mockito.junit.jupiter.MockitoExtension; @@ -56,12 +64,15 @@ class MongoStatementTests { @Mock private MongoConnection mongoConnection; - @InjectMocks private MongoStatement mongoStatement; + @BeforeEach + void beforeEach() { + mongoStatement = new MongoStatement(mongoDatabase, clientSession, mongoConnection); + } + @Test - @DisplayName("No-op when 'close()' is called on a closed MongoStatement") - void testNoopWhenCloseStatementClosed() { + void testNoopWhenCloseStatementClosed() throws SQLException { mongoStatement.close(); assertTrue(mongoStatement.isClosed()); @@ -69,11 +80,124 @@ void testNoopWhenCloseStatementClosed() { assertDoesNotThrow(() -> mongoStatement.close()); } + @Test + void testResultSetClosedWhenStatementClosed( + @Mock MongoCollection mongoCollection, + @Mock AggregateIterable aggregateIterable, + @Mock MongoCursor mongoCursor) + throws SQLException { + + doReturn(mongoCollection).when(mongoDatabase).getCollection(anyString(), eq(BsonDocument.class)); + doReturn(aggregateIterable).when(mongoCollection).aggregate(same(clientSession), anyList()); + doReturn(mongoCursor).when(aggregateIterable).cursor(); + + var query = + """ + { + aggregate: "books", + pipeline: [ + { $match: { _id: { $eq: 1 } } }, + { $project: { _id: 0, title: 1, publishYear: 1 } } + ] + }"""; + + var resultSet = mongoStatement.executeQuery(query); + mongoStatement.close(); + + assertTrue(resultSet.isClosed()); + } + + @Nested + class ExecuteMethodClosesLastOpenResultSetTests { + + private final String exampleQueryMql = + """ + { + aggregate: "books", + pipeline: [ + { $match: { _id: { $eq: 1 } } }, + { $project: { _id: 0, title: 1, publishYear: 1 } } + ] + }"""; + private final String exampleUpdateMql = + """ + { + update: "members", + updates: [ + { + q: {}, + u: { $inc: { points: 1 } }, + multi: true + } + ] + }"""; + + @Mock + MongoCollection mongoCollection; + + @Mock + AggregateIterable aggregateIterable; + + @Mock + MongoCursor mongoCursor; + + private ResultSet lastOpenResultSet; + + @BeforeEach + void beforeEach() throws SQLException { + doReturn(mongoCollection).when(mongoDatabase).getCollection(anyString(), eq(BsonDocument.class)); + doReturn(aggregateIterable).when(mongoCollection).aggregate(same(clientSession), anyList()); + doReturn(mongoCursor).when(aggregateIterable).cursor(); + + lastOpenResultSet = mongoStatement.executeQuery(exampleQueryMql); + assertFalse(lastOpenResultSet.isClosed()); + } + + @Test + void testExecuteQuery() throws SQLException { + mongoStatement.executeQuery(exampleQueryMql); + assertTrue(lastOpenResultSet.isClosed()); + } + + @Test + void testExecuteUpdate() throws SQLException { + doReturn(Document.parse("{n: 10}")) + .when(mongoDatabase) + .runCommand(eq(clientSession), any(BsonDocument.class)); + mongoStatement.executeUpdate(exampleUpdateMql); + assertTrue(lastOpenResultSet.isClosed()); + } + + @Test + void testExecute() throws SQLException { + assertThrows(FeatureNotSupportedException.class, () -> mongoStatement.execute(exampleUpdateMql)); + assertTrue(lastOpenResultSet.isClosed()); + } + + @Test + void testExecuteBatch() throws SQLException { + assertThrows(FeatureNotSupportedException.class, () -> mongoStatement.executeBatch()); + assertTrue(lastOpenResultSet.isClosed()); + } + } + + @Test + void testGetProjectStageFieldNames() { + BiConsumer> asserter = (projectStage, expectedFieldNames) -> assertEquals( + expectedFieldNames, MongoStatement.getFieldNamesFromProjectStage(BsonDocument.parse(projectStage))); + assertAll( + () -> asserter.accept("{title: 1, publishYear: 1}", List.of("title", "publishYear", "_id")), + () -> asserter.accept("{title: 1, publishYear: 0}", List.of("title", "_id")), + () -> asserter.accept("{title: 1, publishYear: false}", List.of("title", "_id")), + () -> asserter.accept("{title: 1, _id: 0}", List.of("title")), + () -> asserter.accept("{title: 1, _id: false}", List.of("title")), + () -> asserter.accept("{_id: 1, title: 1}", List.of("_id", "title"))); + } + @Nested class ExecuteUpdateTests { @Test - @DisplayName("SQLException is thrown when 'mql' is invalid") void testSQLExceptionThrownWhenCalledWithInvalidMql() { String invalidMql = @@ -85,7 +209,6 @@ void testSQLExceptionThrownWhenCalledWithInvalidMql() { } @Test - @DisplayName("SQLException is thrown when database access error occurs") void testSQLExceptionThrownWhenDBAccessFailed() { var dbAccessException = new RuntimeException(); @@ -103,68 +226,53 @@ void testSQLExceptionThrownWhenDBAccessFailed() { } } - @Nested - class ClosedTests { - - interface StatementMethodInvocation { - void runOn(MongoStatement stmt) throws SQLException; - } - - @ParameterizedTest(name = "SQLException is thrown when \"{0}\" is called on a closed MongoStatement") - @MethodSource("getMongoStatementMethodInvocationsImpactedByClosing") - void testCheckClosed(String label, StatementMethodInvocation methodInvocation) { - mongoStatement.close(); - - var sqlException = assertThrows(SQLException.class, () -> methodInvocation.runOn(mongoStatement)); - assertThat(sqlException.getMessage()).matches("MongoStatement has been closed"); - } + @Test + void testCheckClosed() throws SQLException { + mongoStatement.close(); + checkMethodsWithOpenPrecondition(); + } - private static Stream getMongoStatementMethodInvocationsImpactedByClosing() { - var exampleQueryMql = - """ + private void checkMethodsWithOpenPrecondition() { + var exampleQueryMql = + """ + { + find: "restaurants", + filter: { rating: { $gte: 9 }, cuisine: "italian" }, + projection: { name: 1, rating: 1, address: 1 }, + sort: { name: 1 }, + limit: 5 + }"""; + var exampleUpdateMql = + """ + { + update: "members", + updates: [ { - find: "restaurants", - filter: { rating: { $gte: 9 }, cuisine: "italian" }, - projection: { name: 1, rating: 1, address: 1 }, - sort: { name: 1 }, - limit: 5 - }"""; - var exampleUpdateMql = - """ - { - update: "members", - updates: [ - { - q: {}, - u: { $inc: { points: 1 } }, - multi: true - } - ] - }"""; - return Map.ofEntries( - Map.entry("executeQuery(String)", stmt -> stmt.executeQuery(exampleQueryMql)), - Map.entry("executeUpdate(String)", stmt -> stmt.executeUpdate(exampleUpdateMql)), - Map.entry("getMaxRows()", MongoStatement::getMaxRows), - Map.entry("setMaxRows(int)", stmt -> stmt.setMaxRows(10)), - Map.entry("getQueryTimeout()", MongoStatement::getQueryTimeout), - Map.entry("setQueryTimeout(int)", stmt -> stmt.setQueryTimeout(1)), - Map.entry("cancel()", MongoStatement::cancel), - Map.entry("getWarnings()", MongoStatement::getWarnings), - Map.entry("clearWarnings()", MongoStatement::clearWarnings), - Map.entry("execute(String)", stmt -> stmt.execute(exampleQueryMql)), - Map.entry("getResultSet()", MongoStatement::getResultSet), - Map.entry("getMoreResultSet()", MongoStatement::getMoreResults), - Map.entry("getUpdateCount()", MongoStatement::getUpdateCount), - Map.entry("setFetchSize(int)", stmt -> stmt.setFetchSize(1)), - Map.entry("getFetchSize()", MongoStatement::getFetchSize), - Map.entry("addBatch(String)", stmt -> stmt.addBatch(exampleUpdateMql)), - Map.entry("clearBatch()", MongoStatement::clearBatch), - Map.entry("executeBatch()", MongoStatement::executeBatch), - Map.entry("getConnection()", MongoStatement::getConnection), - Map.entry("isWrapperFor(Class)", stmt -> stmt.isWrapperFor(Statement.class))) - .entrySet() - .stream() - .map(entry -> Arguments.of(entry.getKey(), entry.getValue())); - } + q: {}, + u: { $inc: { points: 1 } }, + multi: true + } + ] + }"""; + assertAll( + () -> assertThrowsClosedException(() -> mongoStatement.executeQuery(exampleQueryMql)), + () -> assertThrowsClosedException(() -> mongoStatement.executeUpdate(exampleUpdateMql)), + () -> assertThrowsClosedException(mongoStatement::cancel), + () -> assertThrowsClosedException(mongoStatement::getWarnings), + () -> assertThrowsClosedException(mongoStatement::clearWarnings), + () -> assertThrowsClosedException(() -> mongoStatement.execute(exampleUpdateMql)), + () -> assertThrowsClosedException(mongoStatement::getResultSet), + () -> assertThrowsClosedException(mongoStatement::getMoreResults), + () -> assertThrowsClosedException(mongoStatement::getUpdateCount), + () -> assertThrowsClosedException(() -> mongoStatement.addBatch(exampleUpdateMql)), + () -> assertThrowsClosedException(mongoStatement::clearBatch), + () -> assertThrowsClosedException(mongoStatement::executeBatch), + () -> assertThrowsClosedException(mongoStatement::getConnection), + () -> assertThrowsClosedException(() -> mongoStatement.isWrapperFor(MongoStatement.class))); + } + + private static void assertThrowsClosedException(Executable executable) { + var exception = assertThrows(SQLException.class, executable); + assertThat(exception.getMessage()).isEqualTo("MongoStatement has been closed"); } }