From 68ee54d06b2dfad682848a0bf23135bd889fa415 Mon Sep 17 00:00:00 2001 From: htt <641571835@qq.com> Date: Fri, 18 Feb 2022 08:18:07 +0800 Subject: [PATCH 01/17] upgrade kotlin version to 1.6.10 upgrade kotlin-coroutine version to 1.6.0 upgrade r2dbc version to 0.9.1.RELEASE delete cachedRow delete cacheRowMetadata update test sql file update sqlDialect update database add ktorm-r2dbc dsl add table support add entity support add entitySequence support --- build.gradle | 7 +- ktorm-r2dbc-core/ktorm-r2dbc-core.gradle | 2 + .../org/ktorm/r2dbc/database/CachedRow.kt | 345 ---- .../ktorm/r2dbc/database/CachedRowMetadata.kt | 86 - .../database/CoroutinesTransactionManager.kt | 18 +- .../org/ktorm/r2dbc/database/Database.kt | 95 +- .../org/ktorm/r2dbc/database/Keywords.kt | 268 +++ .../ktorm/r2dbc/database/R2JdbcExtensions.kt | 37 + .../org/ktorm/r2dbc/database/SqlDialect.kt | 115 +- .../kotlin/org/ktorm/r2dbc/dsl/Aggregation.kt | 11 +- .../main/kotlin/org/ktorm/r2dbc/dsl/Dml.kt | 387 +++++ .../kotlin/org/ktorm/r2dbc/dsl/Operators.kt | 546 ++++++ .../main/kotlin/org/ktorm/r2dbc/dsl/Query.kt | 788 +++++++++ .../kotlin/org/ktorm/r2dbc/dsl/QueryRow.kt | 62 + .../kotlin/org/ktorm/r2dbc/dsl/QuerySource.kt | 145 ++ .../kotlin/org/ktorm/r2dbc/entity/Entity.kt | 292 ++++ .../ktorm/r2dbc/entity/EntityExtensions.kt | 197 +++ .../r2dbc/entity/EntityImplementation.kt | 289 ++++ .../org/ktorm/r2dbc/entity/EntitySequence.kt | 1503 +++++++++++++++++ .../ktorm/r2dbc/expression/SqlFormatter.kt | 139 +- .../org/ktorm/r2dbc/schema/BaseTable.kt | 70 +- .../kotlin/org/ktorm/r2dbc/schema/Column.kt | 1 + .../r2dbc/schema/ColumnBindingHandler.kt | 126 ++ .../org/ktorm/r2dbc/schema/EntityDml.kt | 30 + .../kotlin/org/ktorm/r2dbc/schema/SqlType.kt | 36 +- .../kotlin/org/ktorm/r2dbc/schema/SqlTypes.kt | 262 +++ .../kotlin/org/ktorm/r2dbc/schema/Table.kt | 174 ++ .../org/ktorm/r2dbc/schema/TypeReference.kt | 2 +- .../src/test/kotlin/org/ktorm/BaseTest.kt | 134 ++ .../kotlin/org/ktorm/database/DatabaseTest.kt | 191 +++ .../src/test/resources/drop-data.sql | 6 + .../src/test/resources/init-data.sql | 60 + 32 files changed, 5865 insertions(+), 559 deletions(-) delete mode 100644 ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/database/CachedRow.kt delete mode 100644 ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/database/CachedRowMetadata.kt create mode 100644 ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/database/Keywords.kt create mode 100644 ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/database/R2JdbcExtensions.kt create mode 100644 ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/dsl/Dml.kt create mode 100644 ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/dsl/Operators.kt create mode 100644 ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/dsl/Query.kt create mode 100644 ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/dsl/QueryRow.kt create mode 100644 ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/dsl/QuerySource.kt create mode 100644 ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/entity/Entity.kt create mode 100644 ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/entity/EntityExtensions.kt create mode 100644 ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/entity/EntityImplementation.kt create mode 100644 ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/entity/EntitySequence.kt create mode 100644 ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/schema/ColumnBindingHandler.kt create mode 100644 ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/schema/EntityDml.kt create mode 100644 
ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/schema/SqlTypes.kt create mode 100644 ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/schema/Table.kt create mode 100644 ktorm-r2dbc-core/src/test/kotlin/org/ktorm/BaseTest.kt create mode 100644 ktorm-r2dbc-core/src/test/kotlin/org/ktorm/database/DatabaseTest.kt create mode 100644 ktorm-r2dbc-core/src/test/resources/drop-data.sql create mode 100644 ktorm-r2dbc-core/src/test/resources/init-data.sql diff --git a/build.gradle b/build.gradle index 7e0a195..0efbe9d 100644 --- a/build.gradle +++ b/build.gradle @@ -1,9 +1,9 @@ buildscript { ext { - kotlinVersion = "1.4.21" - coroutinesVersion = "1.4.2" - r2dbcVersion = "0.8.3.RELEASE" + kotlinVersion = "1.6.10" + coroutinesVersion = "1.6.0" + r2dbcVersion = "0.9.1.RELEASE" detektVersion = "1.12.0-RC1" } repositories { @@ -14,6 +14,7 @@ buildscript { classpath "org.jetbrains.kotlin:kotlin-gradle-plugin:${kotlinVersion}" classpath "com.jfrog.bintray.gradle:gradle-bintray-plugin:1.8.4" classpath "io.gitlab.arturbosch.detekt:detekt-gradle-plugin:${detektVersion}" + classpath "org.jetbrains.kotlinx:kotlinx-coroutines-reactive:${coroutinesVersion}" } } diff --git a/ktorm-r2dbc-core/ktorm-r2dbc-core.gradle b/ktorm-r2dbc-core/ktorm-r2dbc-core.gradle index 7119b05..d4d6c80 100644 --- a/ktorm-r2dbc-core/ktorm-r2dbc-core.gradle +++ b/ktorm-r2dbc-core/ktorm-r2dbc-core.gradle @@ -3,6 +3,8 @@ dependencies { compileOnly "org.slf4j:slf4j-api:1.7.25" compileOnly "commons-logging:commons-logging:1.2" compileOnly "com.google.android:android:1.5_r4" + + testImplementation 'io.r2dbc:r2dbc-h2:0.9.1.RELEASE' } configurations { diff --git a/ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/database/CachedRow.kt b/ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/database/CachedRow.kt deleted file mode 100644 index ccc2f43..0000000 --- a/ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/database/CachedRow.kt +++ /dev/null @@ -1,345 +0,0 @@ -package org.ktorm.r2dbc.database - -import io.r2dbc.spi.Blob -import io.r2dbc.spi.Clob -import io.r2dbc.spi.Row -import io.r2dbc.spi.RowMetadata -import kotlinx.coroutines.runBlocking -import java.math.BigDecimal -import java.math.BigInteger -import java.nio.ByteBuffer -import java.time.* - -/** - * Created by vince on Feb 10, 2021. - */ -public open class CachedRow(row: Row, metadata: RowMetadata): Row { - private val _values = readValues(row, metadata) - private val _metadata = readMetadata(row, metadata) - - public val metadata: RowMetadata get() = _metadata - - private fun readValues(row: Row, metadata: RowMetadata): Map { - if (row is CachedRow) { - return row._values - } else { - return metadata.columnMetadatas.reversed().associate { column -> - val value = row.get(column.name) - column.name.toUpperCase() to when (value) { - is Clob -> Clob.from(CachedPublisher(value.stream())) - is Blob -> Blob.from(CachedPublisher(value.stream())) - else -> value - } - } - } - } - - private fun readMetadata(row: Row, metadata: RowMetadata): CachedRowMetadata { - return when { - row is CachedRow -> row._metadata - metadata is CachedRowMetadata -> metadata - else -> CachedRowMetadata(metadata) - } - } - - override fun get(index: Int, type: Class): T? { - val column = metadata.getColumnMetadata(index) - return get(column.name, type) - } - - override fun get(name: String, type: Class): T? 
{ - val result = when (type.kotlin) { - String::class -> getString(name) - Clob::class -> getClob(name) - Boolean::class -> getBoolean(name) - Byte::class -> getByte(name) - Short::class -> getShort(name) - Int::class -> getInt(name) - Long::class -> getLong(name) - Float::class -> getFloat(name) - Double::class -> getDouble(name) - BigDecimal::class -> getBigDecimal(name) - BigInteger::class -> getBigInteger(name) - ByteBuffer::class -> getByteBuffer(name) - ByteArray::class -> getBytes(name) - Blob::class -> getBlob(name) - LocalDate::class -> getDate(name) - LocalTime::class -> getTime(name) - LocalDateTime::class -> getDateTime(name) - ZonedDateTime::class -> getZonedDateTime(name) - OffsetDateTime::class -> getOffsetDateTime(name) - Instant::class -> getInstant(name) - else -> getColumnValue(name) - } - - return type.cast(result) - } - - private fun getColumnValue(name: String): Any? { - return _values[name.toUpperCase()] - } - - private fun getString(name: String): String? { - return when (val value = getColumnValue(name)) { - is String -> value - is Clob -> runBlocking { value.toText() } // Won't block if data was pre-fetched by CachedPublisher. - else -> value?.toString() - } - } - - private fun getClob(name: String): Clob? { - return when (val value = getColumnValue(name)) { - null -> null - is Clob -> value - is String -> Clob.from(IterableAsPublisher(value)) - else -> throw IllegalArgumentException("Cannot convert ${value.javaClass.name} value to Clob.") - } - } - - private fun getBoolean(name: String): Boolean? { - return when (val value = getColumnValue(name)) { - null -> null - is Boolean -> value - is Number -> value.toDouble().toBits() != 0.0.toBits() - else -> value.toString().toDouble().toBits() != 0.0.toBits() - } - } - - private fun getByte(name: String): Byte? { - return when (val value = getColumnValue(name)) { - is Byte -> value - is Number -> value.toByte() - is Boolean -> if (value) 1 else 0 - else -> value?.toString()?.toByte() - } - } - - private fun getShort(name: String): Short? { - return when (val value = getColumnValue(name)) { - is Short -> value - is Number -> value.toShort() - is Boolean -> if (value) 1 else 0 - else -> value?.toString()?.toShort() - } - } - - private fun getInt(name: String): Int? { - return when (val value = getColumnValue(name)) { - is Int -> value - is Number -> value.toInt() - is Boolean -> if (value) 1 else 0 - else -> value?.toString()?.toInt() - } - } - - private fun getLong(name: String): Long? { - return when (val value = getColumnValue(name)) { - is Long -> value - is Number -> value.toLong() - is Boolean -> if (value) 1 else 0 - else -> value?.toString()?.toLong() - } - } - - private fun getFloat(name: String): Float? { - return when (val value = getColumnValue(name)) { - is Float -> value - is Number -> value.toFloat() - is Boolean -> if (value) 1.0F else 0.0F - else -> value?.toString()?.toFloat() - } - } - - private fun getDouble(name: String): Double? { - return when (val value = getColumnValue(name)) { - is Double -> value - is Number -> value.toDouble() - is Boolean -> if (value) 1.0 else 0.0 - else -> value?.toString()?.toDouble() - } - } - - private fun getBigDecimal(name: String): BigDecimal? { - return when (val value = getColumnValue(name)) { - is BigDecimal -> value - is Boolean -> if (value) BigDecimal.ONE else BigDecimal.ZERO - else -> value?.toString()?.toBigDecimal() - } - } - - private fun getBigInteger(name: String): BigInteger? 
{ - return when (val value = getColumnValue(name)) { - is BigInteger -> value - is Boolean -> if (value) BigInteger.ONE else BigInteger.ZERO - else -> value?.toString()?.toBigInteger() - } - } - - private fun getByteBuffer(name: String): ByteBuffer? { - return when (val value = getColumnValue(name)) { - null -> null - is ByteBuffer -> value - is ByteArray -> ByteBuffer.wrap(value) - is Blob -> runBlocking { value.toByteBuffer() } // Won't block if data was pre-fetched by CachedPublisher. - else -> throw IllegalArgumentException("Cannot convert ${value.javaClass.name} value to ByteBuffer.") - } - } - - private fun getBytes(name: String): ByteArray? { - return when (val value = getColumnValue(name)) { - null -> null - is ByteArray -> value - is ByteBuffer -> value.toBytes() - is Blob -> runBlocking { value.toBytes() } // Won't block if data was pre-fetched by CachedPublisher. - else -> throw IllegalArgumentException("Cannot convert ${value.javaClass.name} value to byte[].") - } - } - - private fun getBlob(name: String): Blob? { - return when (val value = getColumnValue(name)) { - null -> null - is Blob -> value - is ByteBuffer -> Blob.from(IterableAsPublisher(value)) - is ByteArray -> Blob.from(IterableAsPublisher(ByteBuffer.wrap(value))) - else -> throw IllegalArgumentException("Cannot convert ${value.javaClass.name} value to Blob.") - } - } - - private fun getDate(name: String): LocalDate? { - return when (val value = getColumnValue(name)) { - null -> null - is LocalDate -> value - is LocalDateTime -> value.toLocalDate() - is ZonedDateTime -> value.toLocalDate() - is OffsetDateTime -> value.toLocalDate() - is Instant -> value.atZone(ZoneId.systemDefault()).toLocalDate() - is Number -> Instant.ofEpochMilli(value.toLong()).atZone(ZoneId.systemDefault()).toLocalDate() - is String -> { - val number = value.toLongOrNull() - if (number != null) { - Instant.ofEpochMilli(number).atZone(ZoneId.systemDefault()).toLocalDate() - } else { - LocalDate.parse(value) - } - } - else -> { - throw IllegalArgumentException("Cannot convert ${value.javaClass.name} value to LocalDate.") - } - } - } - - private fun getTime(name: String): LocalTime? { - return when (val value = getColumnValue(name)) { - null -> null - is LocalTime -> value - is LocalDateTime -> value.toLocalTime() - is ZonedDateTime -> value.toLocalTime() - is OffsetDateTime -> value.toLocalTime() - is Instant -> value.atZone(ZoneId.systemDefault()).toLocalTime() - is Number -> Instant.ofEpochMilli(value.toLong()).atZone(ZoneId.systemDefault()).toLocalTime() - is String -> { - val number = value.toLongOrNull() - if (number != null) { - Instant.ofEpochMilli(number).atZone(ZoneId.systemDefault()).toLocalTime() - } else { - LocalTime.parse(value) - } - } - else -> { - throw IllegalArgumentException("Cannot convert ${value.javaClass.name} value to LocalTime.") - } - } - } - - private fun getDateTime(name: String): LocalDateTime? 
{ - return when (val value = getColumnValue(name)) { - null -> null - is LocalDateTime -> value - is LocalDate -> value.atStartOfDay() - is ZonedDateTime -> value.toLocalDateTime() - is OffsetDateTime -> value.toLocalDateTime() - is Instant -> value.atZone(ZoneId.systemDefault()).toLocalDateTime() - is Number -> Instant.ofEpochMilli(value.toLong()).atZone(ZoneId.systemDefault()).toLocalDateTime() - is String -> { - val number = value.toLongOrNull() - if (number != null) { - Instant.ofEpochMilli(number).atZone(ZoneId.systemDefault()).toLocalDateTime() - } else { - LocalDateTime.parse(value) - } - } - else -> { - throw IllegalArgumentException("Cannot convert ${value.javaClass.name} value to LocalDateTime.") - } - } - } - - private fun getZonedDateTime(name: String): ZonedDateTime? { - return when (val value = getColumnValue(name)) { - null -> null - is ZonedDateTime -> value - is LocalDate -> value.atStartOfDay(ZoneId.systemDefault()) - is LocalDateTime -> value.atZone(ZoneId.systemDefault()) - is OffsetDateTime -> value.toZonedDateTime() - is Instant -> value.atZone(ZoneId.systemDefault()) - is Number -> Instant.ofEpochMilli(value.toLong()).atZone(ZoneId.systemDefault()) - is String -> { - val number = value.toLongOrNull() - if (number != null) { - Instant.ofEpochMilli(number).atZone(ZoneId.systemDefault()) - } else { - ZonedDateTime.parse(value) - } - } - else -> { - throw IllegalArgumentException("Cannot convert ${value.javaClass.name} value to LocalDateTime.") - } - } - } - - private fun getOffsetDateTime(name: String): OffsetDateTime? { - return when (val value = getColumnValue(name)) { - null -> null - is OffsetDateTime -> value - is LocalDate -> value.atStartOfDay(ZoneId.systemDefault()).toOffsetDateTime() - is LocalDateTime -> value.atZone(ZoneId.systemDefault()).toOffsetDateTime() - is ZonedDateTime -> value.toOffsetDateTime() - is Instant -> value.atZone(ZoneId.systemDefault()).toOffsetDateTime() - is Number -> Instant.ofEpochMilli(value.toLong()).atZone(ZoneId.systemDefault()).toOffsetDateTime() - is String -> { - val number = value.toLongOrNull() - if (number != null) { - Instant.ofEpochMilli(number).atZone(ZoneId.systemDefault()).toOffsetDateTime() - } else { - OffsetDateTime.parse(value) - } - } - else -> { - throw IllegalArgumentException("Cannot convert ${value.javaClass.name} value to LocalDateTime.") - } - } - } - - private fun getInstant(name: String): Instant? 
{ - return when (val value = getColumnValue(name)) { - null -> null - is Instant -> value - is LocalDate -> value.atStartOfDay(ZoneId.systemDefault()).toInstant() - is LocalDateTime -> value.atZone(ZoneId.systemDefault()).toInstant() - is ZonedDateTime -> value.toInstant() - is OffsetDateTime -> value.toInstant() - is Number -> Instant.ofEpochMilli(value.toLong()) - is String -> { - val number = value.toLongOrNull() - if (number != null) { - Instant.ofEpochMilli(number) - } else { - Instant.parse(value) - } - } - else -> { - throw IllegalArgumentException("Cannot convert ${value.javaClass.name} value to LocalDateTime.") - } - } - } -} diff --git a/ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/database/CachedRowMetadata.kt b/ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/database/CachedRowMetadata.kt deleted file mode 100644 index 5b9f9d8..0000000 --- a/ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/database/CachedRowMetadata.kt +++ /dev/null @@ -1,86 +0,0 @@ -package org.ktorm.r2dbc.database - -import io.r2dbc.spi.ColumnMetadata -import io.r2dbc.spi.Nullability -import io.r2dbc.spi.RowMetadata -import java.util.* -import kotlin.collections.AbstractCollection - -/** - * Created by vince on Feb 10, 2021. - */ -internal class CachedRowMetadata(metadata: RowMetadata) : RowMetadata { - private val columns = metadata.columnMetadatas.map { CachedColumnMetadata(it) } - - override fun getColumnMetadata(index: Int): ColumnMetadata { - return columns[index] - } - - override fun getColumnMetadata(name: String): ColumnMetadata { - return columns.first { it.name.equals(name, ignoreCase = true) } - } - - override fun getColumnMetadatas(): Iterable { - return Collections.unmodifiableList(columns) - } - - override fun getColumnNames(): Collection { - return object : AbstractCollection() { - override val size: Int = columns.size - - override fun iterator(): Iterator { - return TransformedIterator(columns.iterator()) { it.name } - } - - override fun contains(element: String): Boolean { - return columns.any { it.name.equals(element, ignoreCase = true) } - } - } - } - - private class TransformedIterator( - private val sourceIterator: Iterator, - private val transform: (T) -> R - ) : Iterator { - override fun hasNext(): Boolean { - return sourceIterator.hasNext() - } - - override fun next(): R { - return transform(sourceIterator.next()) - } - } - - private class CachedColumnMetadata(metadata: ColumnMetadata) : ColumnMetadata { - private val _javaType = metadata.javaType - private val _name = metadata.name - private val _nativeTypeMetadata = metadata.nativeTypeMetadata - private val _nullability = metadata.nullability - private val _precision = metadata.precision - private val _scale = metadata.scale - - override fun getJavaType(): Class<*>? { - return _javaType - } - - override fun getName(): String { - return _name - } - - override fun getNativeTypeMetadata(): Any? { - return _nativeTypeMetadata - } - - override fun getNullability(): Nullability { - return _nullability - } - - override fun getPrecision(): Int? { - return _precision - } - - override fun getScale(): Int? 
{ - return _scale - } - } -} diff --git a/ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/database/CoroutinesTransactionManager.kt b/ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/database/CoroutinesTransactionManager.kt index dd328b2..915e10f 100644 --- a/ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/database/CoroutinesTransactionManager.kt +++ b/ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/database/CoroutinesTransactionManager.kt @@ -3,8 +3,8 @@ package org.ktorm.r2dbc.database import io.r2dbc.spi.Connection import io.r2dbc.spi.ConnectionFactory import io.r2dbc.spi.IsolationLevel +import kotlinx.coroutines.reactive.awaitFirstOrNull import kotlinx.coroutines.reactive.awaitSingle -import kotlinx.coroutines.reactive.awaitSingleOrNull /** * Created by vince on Jan 30, 2021. @@ -41,13 +41,13 @@ public class CoroutinesTransactionManager( suspend fun begin() { try { if (desiredIsolation != null && desiredIsolation != originIsolation) { - connection.setTransactionIsolationLevel(desiredIsolation).awaitSingleOrNull() + connection.setTransactionIsolationLevel(desiredIsolation).awaitFirstOrNull() } if (originAutoCommit) { - connection.setAutoCommit(false).awaitSingleOrNull() + connection.setAutoCommit(false).awaitFirstOrNull() } - connection.beginTransaction().awaitSingleOrNull() + connection.beginTransaction().awaitFirstOrNull() } catch (e: Throwable) { close() throw e @@ -55,25 +55,25 @@ public class CoroutinesTransactionManager( } override suspend fun commit() { - connection.commitTransaction().awaitSingleOrNull() + connection.commitTransaction().awaitFirstOrNull() } override suspend fun rollback() { - connection.rollbackTransaction().awaitSingleOrNull() + connection.rollbackTransaction().awaitFirstOrNull() } override suspend fun close() { try { if (desiredIsolation != null && desiredIsolation != originIsolation) { - connection.setTransactionIsolationLevel(originIsolation).awaitSingleOrNull() + connection.setTransactionIsolationLevel(originIsolation).awaitFirstOrNull() } if (originAutoCommit) { - connection.setAutoCommit(true).awaitSingleOrNull() + connection.setAutoCommit(true).awaitFirstOrNull() } } catch (_: Throwable) { } finally { try { - connection.close().awaitSingleOrNull() + connection.close().awaitFirstOrNull() } catch (_: Throwable) { } finally { currentTransaction.remove() diff --git a/ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/database/Database.kt b/ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/database/Database.kt index f787a33..cfebf11 100644 --- a/ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/database/Database.kt +++ b/ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/database/Database.kt @@ -1,8 +1,10 @@ package org.ktorm.r2dbc.database import io.r2dbc.spi.* +import kotlinx.coroutines.reactive.awaitFirst +import kotlinx.coroutines.reactive.awaitFirstOrNull import kotlinx.coroutines.reactive.awaitSingle -import kotlinx.coroutines.reactive.awaitSingleOrNull +import kotlinx.coroutines.runBlocking import org.ktorm.r2dbc.expression.ArgumentExpression import org.ktorm.r2dbc.expression.SqlExpression import org.ktorm.r2dbc.logging.Logger @@ -20,9 +22,48 @@ public class Database( public val transactionManager: TransactionManager = CoroutinesTransactionManager(connectionFactory), public val dialect: SqlDialect = detectDialectImplementation(), public val logger: Logger = detectLoggerImplementation(), - public val exceptionTranslator: ((R2dbcException) -> Throwable)? = null + public val exceptionTranslator: ((R2dbcException) -> Throwable)? 
= null, + public val alwaysQuoteIdentifiers: Boolean = false, + public val generateSqlInUpperCase: Boolean? = null ) { + /** + * The name of the connected database product, eg. MySQL, H2. + */ + public val productName: String + + /** + * The version of the connected database product. + */ + public val productVersion: String + + /** + * A set of all of this database's SQL keywords (including SQL:2003 keywords), all in uppercase. + */ + public val keywords: Set + + + + + init { + fun kotlin.Result.orEmpty() = getOrNull().orEmpty() + fun kotlin.Result.orFalse() = getOrDefault(false) + + runBlocking { + useConnection { conn -> + val metadata = conn.metadata + productName = metadata.runCatching { databaseProductName }.orEmpty() + productVersion = metadata.runCatching { databaseVersion }.orEmpty() + keywords = ANSI_SQL_2003_KEYWORDS + dialect.sqlKeywords + } + + if (logger.isInfoEnabled()) { + val msg = "Connected to productName: %s, productVersion: %s, logger: %s, dialect: %s" + logger.info(msg.format(productName, productVersion, logger, dialect)) + } + } + } + @OptIn(ExperimentalContracts::class) public suspend inline fun useConnection(func: (Connection) -> T): T { contract { @@ -36,7 +77,7 @@ public class Database( try { return func(connection) } finally { - if (transaction == null) connection.close().awaitSingleOrNull() + if (transaction == null) connection.close().awaitFirstOrNull() } } catch (e: R2dbcException) { throw exceptionTranslator?.invoke(e) ?: e @@ -109,9 +150,9 @@ public class Database( } } - public suspend fun executeQuery(expression: SqlExpression): List { + public suspend fun executeQuery(expression: SqlExpression): List { executeExpression(expression) { result -> - val rows = result.map { row, metadata -> CachedRow(row, metadata) }.toList() + val rows = result.map { row, _ -> row }.toList() if (logger.isDebugEnabled()) { logger.debug("Results: ${rows.size}") @@ -132,6 +173,50 @@ public class Database( return effects } } + /** + * Batch execute the given SQL expressions and return the effected row counts for each expression. + * + * Note that this function is implemented based on [Statement.add] and [Statement.execute], + * and any item in a batch operation must have the same structure, otherwise an exception will be thrown. + * + * @since 2.7 + * @param expressions the SQL expressions to be executed. + * @return the effected row counts for each sub-operation. 
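+     *
+     * A minimal usage sketch (the expression list is assumed to be built elsewhere, e.g. by the
+     * batch DSL builders in Dml.kt; `prepareBatchExpressions` is a hypothetical helper):
+     *
+     * ```kotlin
+     * val expressions: List<SqlExpression> = prepareBatchExpressions() // hypothetical helper
+     * val rowCounts: IntArray = database.executeBatch(expressions)
+     * ```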
+ */ + public suspend fun executeBatch(expressions: List): IntArray { + val (sql, _) = formatExpression(expressions[0]) + + if (logger.isDebugEnabled()) { + logger.debug("SQL: $sql") + } + + useConnection { conn -> + val statement = conn.createStatement(sql) + for (expr in expressions) { + val (subSql, args) = formatExpression(expr) + + if (subSql != sql) { + throw IllegalArgumentException( + "Every item in a batch operation must generate the same SQL: \n\n$subSql" + ) + } + if (logger.isDebugEnabled()) { + logger.debug("Parameters: " + args.map { "${it.value}(${it.sqlType.javaType.simpleName})" }) + } + + statement.bindParameters(args) + statement.add() + } + + val results = statement.execute().toList() + + /* if (logaddBatchger.isDebugEnabled()) { + logger.debug("Effects: ${results?.contentToString()}") + }*/ + + return results.map { result -> result.rowsUpdated.awaitFirst() }.toIntArray() + } + } public companion object { diff --git a/ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/database/Keywords.kt b/ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/database/Keywords.kt new file mode 100644 index 0000000..7a4dc00 --- /dev/null +++ b/ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/database/Keywords.kt @@ -0,0 +1,268 @@ +/* + * Copyright 2018-2021 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.ktorm.r2dbc.database + +/** + * Keywords in SQL:2003 standard, all in uppercase. 
+ */ +internal val ANSI_SQL_2003_KEYWORDS = setOf( + "ADD", + "ALL", + "ALLOCATE", + "ALTER", + "AND", + "ANY", + "ARE", + "ARRAY", + "AS", + "ASENSITIVE", + "ASYMMETRIC", + "AT", + "ATOMIC", + "AUTHORIZATION", + "BEGIN", + "BETWEEN", + "BIGINT", + "BINARY", + "BLOB", + "BINARY", + "BOTH", + "BY", + "CALL", + "CALLED", + "CASCADED", + "CASE", + "CAST", + "CHAR", + "CHARACTER", + "CHECK", + "CLOB", + "CLOB", + "CLOSE", + "COLLATE", + "COLUMN", + "COMMIT", + "CONDITION", + "CONNECT", + "CONSTRAINT", + "CONTINUE", + "CORRESPONDING", + "CREATE", + "CROSS", + "CUBE", + "CURRENT", + "CURRENT_DATE", + "CURRENT_PATH", + "CURRENT_ROLE", + "CURRENT_TIME", + "CURRENT_TIMESTAMP", + "CURRENT_USER", + "CURSOR", + "CYCLE", + "DATE", + "DAY", + "DEALLOCATE", + "DEC", + "DECIMAL", + "DECLARE", + "DEFAULT", + "DELETE", + "DEREF", + "DESCRIBE", + "DETERMINISTIC", + "DISCONNECT", + "DISTINCT", + "DO", + "DOUBLE", + "DROP", + "DYNAMIC", + "EACH", + "ELEMENT", + "ELSE", + "ELSIF", + "END", + "ESCAPE", + "EXCEPT", + "EXEC", + "EXECUTE", + "EXISTS", + "EXIT", + "EXTERNAL", + "FALSE", + "FETCH", + "FILTER", + "FLOAT", + "FOR", + "FOREIGN", + "FREE", + "FROM", + "FULL", + "FUNCTION", + "GET", + "GLOBAL", + "GRANT", + "GROUP", + "GROUPING", + "HANDLER", + "HAVING", + "HOLD", + "HOUR", + "IDENTITY", + "IF", + "IMMEDIATE", + "IN", + "INDICATOR", + "INNER", + "INOUT", + "INPUT", + "INSENSITIVE", + "INSERT", + "INT", + "INTEGER", + "INTERSECT", + "INTERVAL", + "INTO", + "IS", + "ITERATE", + "JOIN", + "LANGUAGE", + "LARGE", + "LATERAL", + "LEADING", + "LEAVE", + "LEFT", + "LIKE", + "LOCAL", + "LOCALTIME", + "LOCALTIMESTAMP", + "LOOP", + "MATCH", + "MEMBER", + "MERGE", + "METHOD", + "MINUTE", + "MODIFIES", + "MODULE", + "MONTH", + "MULTISET", + "NATIONAL", + "NATURAL", + "NCHAR", + "NCLOB", + "NEW", + "NO", + "NONE", + "NOT", + "NULL", + "NUMERIC", + "OF", + "OLD", + "ON", + "ONLY", + "OPEN", + "OR", + "ORDER", + "OUT", + "OUTER", + "OUTPUT", + "OVER", + "OVERLAPS", + "PARAMETER", + "PARTITION", + "PRECISION", + "PREPARE", + "PRIMARY", + "PROCEDURE", + "RANGE", + "READS", + "REAL", + "RECURSIVE", + "REF", + "REFERENCES", + "REFERENCING", + "RELEASE", + "REPEAT", + "RESIGNAL", + "RESULT", + "RETURN", + "RETURNS", + "REVOKE", + "RIGHT", + "ROLLBACK", + "ROLLUP", + "ROW", + "ROWS", + "SAVEPOINT", + "SCROLL", + "SEARCH", + "SECOND", + "SELECT", + "SENSITIVE", + "SESSION_USE", + "SET", + "SIGNAL", + "SIMILAR", + "SMALLINT", + "SOME", + "SPECIFIC", + "SPECIFICTYPE", + "SQL", + "SQLEXCEPTION", + "SQLSTATE", + "SQLWARNING", + "START", + "STATIC", + "SUBMULTISET", + "SYMMETRIC", + "SYSTEM", + "SYSTEM_USER", + "TABLE", + "TABLESAMPLE", + "THEN", + "TIME", + "TIMESTAMP", + "TIMEZONE_HOUR", + "TIMEZONE_MINUTE", + "TO", + "TRAILING", + "TRANSLATION", + "TREAT", + "TRIGGER", + "TRUE", + "UNDO", + "UNION", + "UNIQUE", + "UNKNOWN", + "UNNEST", + "UNTIL", + "UPDATE", + "USER", + "USING", + "VALUE", + "VALUES", + "VARCHAR", + "VARYING", + "WHEN", + "WHENEVER", + "WHERE", + "WHILE", + "WINDOW", + "WITH", + "WITHIN", + "WITHOUT", + "YEAR" +) diff --git a/ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/database/R2JdbcExtensions.kt b/ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/database/R2JdbcExtensions.kt new file mode 100644 index 0000000..0f62ba7 --- /dev/null +++ b/ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/database/R2JdbcExtensions.kt @@ -0,0 +1,37 @@ +/* + * Copyright 2018-2021 the original author or authors. 
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.ktorm.r2dbc.database + +import io.r2dbc.spi.Statement +import org.ktorm.r2dbc.expression.ArgumentExpression +import org.ktorm.r2dbc.schema.SqlType + + +/** + * Set the arguments for this [Statement]. + * + * @since 2.7 + * @param args the arguments to bind into the statement. + */ +public fun Statement.bindParameters(args: List>) { + for ((i, expr) in args.withIndex()) { + @Suppress("UNCHECKED_CAST") + val sqlType = expr.sqlType as SqlType + sqlType.bindParameter(this, i + 1, expr.value) + } +} + diff --git a/ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/database/SqlDialect.kt b/ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/database/SqlDialect.kt index 284bdfe..2c51d6f 100644 --- a/ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/database/SqlDialect.kt +++ b/ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/database/SqlDialect.kt @@ -24,6 +24,100 @@ public interface SqlDialect { } } } + + /** + * What is the string used to quote SQL identifiers? This returns a space if identifier quoting + * isn't supported. A JDBC Compliant driver will always use a double quote character. + */ + public val identifierQuoteString: String + + /** + * All the "extra" characters that can be used in unquoted identifier names (those beyond a-z, A-Z, 0-9 and _). + */ + public val extraNameCharacters: String + + /** + * Whether this database treats mixed case unquoted SQL identifiers as case sensitive and as a result + * stores them in mixed case. + * + * @since 3.1.0 + */ + public val supportsMixedCaseIdentifiers: Boolean + + /** + * Whether this database treats mixed case unquoted SQL identifiers as case insensitive and + * stores them in mixed case. + * + * @since 3.1.0 + */ + public val storesMixedCaseIdentifiers: Boolean + + /** + * Whether this database treats mixed case unquoted SQL identifiers as case insensitive and + * stores them in upper case. + * + * @since 3.1.0 + */ + public val storesUpperCaseIdentifiers: Boolean + + /** + * Whether this database treats mixed case unquoted SQL identifiers as case insensitive and + * stores them in lower case. + * + * @since 3.1.0 + */ + public val storesLowerCaseIdentifiers: Boolean + + /** + * Whether this database treats mixed case quoted SQL identifiers as case sensitive and as a result + * stores them in mixed case. + * + * @since 3.1.0 + */ + public val supportsMixedCaseQuotedIdentifiers: Boolean + + /** + * Whether this database treats mixed case quoted SQL identifiers as case insensitive and + * stores them in mixed case. + * + * @since 3.1.0 + */ + public val storesMixedCaseQuotedIdentifiers: Boolean + + /** + * Whether this database treats mixed case quoted SQL identifiers as case insensitive and + * stores them in upper case. + * + * @since 3.1.0 + */ + public val storesUpperCaseQuotedIdentifiers: Boolean + + /** + * Whether this database treats mixed case quoted SQL identifiers as case insensitive and + * stores them in lower case. 
+ * + * @since 3.1.0 + */ + public val storesLowerCaseQuotedIdentifiers: Boolean + + /** + * Retrieves a comma-separated list of all of this database's SQL keywords + * that are NOT also SQL:2003 keywords. + * + * @return the list of this database's keywords that are not also + * SQL:2003 keywords + * @since 3.1.0 + */ + public val sqlKeywords: Set + + /** + * The maximum number of characters this database allows for a column name. Zero means that there is no limit + * or the limit is not known. + * + * @since 3.1.0 + */ + public val maxColumnNameLength: Int + } /** @@ -48,9 +142,24 @@ public class DialectFeatureNotSupportedException( public fun detectDialectImplementation(): SqlDialect { val dialects = ServiceLoader.load(SqlDialect::class.java).toList() return when (dialects.size) { - 0 -> object : SqlDialect { } + 0 -> object : SqlDialect { + override val identifierQuoteString: String = "" + override val extraNameCharacters: String = "" + override val supportsMixedCaseIdentifiers: Boolean = false + override val storesMixedCaseIdentifiers: Boolean = false + override val storesUpperCaseIdentifiers: Boolean = false + override val storesLowerCaseIdentifiers: Boolean = false + override val supportsMixedCaseQuotedIdentifiers: Boolean = false + override val storesMixedCaseQuotedIdentifiers: Boolean = false + override val storesUpperCaseQuotedIdentifiers: Boolean = false + override val storesLowerCaseQuotedIdentifiers: Boolean = false + override val sqlKeywords: Set = emptySet() + override val maxColumnNameLength: Int = 0 + } 1 -> dialects[0] - else -> error("More than one dialect implementations found in the classpath, " + - "please choose one manually, they are: $dialects") + else -> error( + "More than one dialect implementations found in the classpath, " + + "please choose one manually, they are: $dialects" + ) } } diff --git a/ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/dsl/Aggregation.kt b/ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/dsl/Aggregation.kt index f7b8a0a..30f06c6 100644 --- a/ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/dsl/Aggregation.kt +++ b/ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/dsl/Aggregation.kt @@ -18,7 +18,10 @@ package org.ktorm.r2dbc.dsl import org.ktorm.r2dbc.expression.AggregateExpression import org.ktorm.r2dbc.expression.AggregateType -import org.ktorm.r2dbc.schema.* +import org.ktorm.r2dbc.schema.ColumnDeclaring +import org.ktorm.r2dbc.schema.IntSqlType +import org.ktorm.r2dbc.schema.LongSqlType +import org.ktorm.r2dbc.schema.SimpleSqlType /** * The min function, translated to `min(column)` in SQL. @@ -79,13 +82,13 @@ public fun sumDistinct(column: ColumnDeclaring): AggregateExpres /** * The count function, translated to `count(column)` in SQL. */ -public fun count(column: ColumnDeclaring<*>? = null): AggregateExpression { - return AggregateExpression(AggregateType.COUNT, column?.asExpression(), false, SimpleSqlType(Int::class)) +public fun count(column: ColumnDeclaring<*>? = null): AggregateExpression { + return AggregateExpression(AggregateType.COUNT, column?.asExpression(), false, LongSqlType) } /** * The count function with distinct, translated to `count(distinct column)` in SQL. */ public fun countDistinct(column: ColumnDeclaring<*>? 
= null): AggregateExpression { - return AggregateExpression(AggregateType.COUNT, column?.asExpression(), true, SimpleSqlType(Int::class)) + return AggregateExpression(AggregateType.COUNT, column?.asExpression(), true, IntSqlType) } diff --git a/ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/dsl/Dml.kt b/ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/dsl/Dml.kt new file mode 100644 index 0000000..986fdc6 --- /dev/null +++ b/ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/dsl/Dml.kt @@ -0,0 +1,387 @@ +/* + * Copyright 2018-2021 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.ktorm.r2dbc.dsl + +import io.r2dbc.spi.Statement +import org.ktorm.r2dbc.database.Database +import org.ktorm.r2dbc.expression.* +import org.ktorm.r2dbc.schema.BaseTable +import org.ktorm.r2dbc.schema.Column +import org.ktorm.r2dbc.schema.ColumnDeclaring +import org.ktorm.r2dbc.schema.defaultValue +import java.lang.reflect.InvocationHandler +import java.lang.reflect.Proxy + +/** + * Construct an update expression in the given closure, then execute it and return the effected row count. + * + * Usage: + * + * ```kotlin + * database.update(Employees) { + * set(it.job, "engineer") + * set(it.managerId, null) + * set(it.salary, 100) + * where { + * it.id eq 2 + * } + * } + * ``` + * + * @since 2.7 + * @param table the table to be updated. + * @param block the DSL block, an extension function of [UpdateStatementBuilder], used to construct the expression. + * @return the effected row count. + */ +public suspend fun > Database.update(table: T, block: UpdateStatementBuilder.(T) -> Unit): Int { + val builder = UpdateStatementBuilder().apply { block(table) } + + val expression = AliasRemover.visit( + UpdateExpression(table.asExpression(), builder.assignments, builder.where?.asExpression()) + ) + + return executeUpdate(expression) +} + +/** + * Construct update expressions in the given closure, then batch execute them and return the effected + * row counts for each expression. + * + * Note that this function is implemented based on [Statement.addBatch] and [Statement.executeBatch], + * and any item in a batch operation must have the same structure, otherwise an exception will be thrown. + * + * Usage: + * + * ```kotlin + * database.batchUpdate(Departments) { + * for (i in 1..2) { + * item { + * set(it.location, "Hong Kong") + * where { + * it.id eq i + * } + * } + * } + * } + * ``` + * + * @since 2.7 + * @param table the table to be updated. + * @param block the DSL block, extension function of [BatchUpdateStatementBuilder], used to construct the expressions. + * @return the effected row counts for each sub-operation. 
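+ *
+ * The returned array holds one count per `item` added in the block, in the same order, so the
+ * two-item example above yields an `IntArray` with two elements.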
+ */ +public suspend fun > Database.batchUpdate( + table: T, + block: BatchUpdateStatementBuilder.() -> Unit +): IntArray { + val builder = BatchUpdateStatementBuilder(table).apply(block) + val expressions = builder.expressions.map { AliasRemover.visit(it) } + + if (expressions.isEmpty()) { + return IntArray(0) + } else { + return executeBatch(expressions) + } +} + +/** + * Construct an insert expression in the given closure, then execute it and return the effected row count. + * + * Usage: + * + * ```kotlin + * database.insert(Employees) { + * set(it.name, "jerry") + * set(it.job, "trainee") + * set(it.managerId, 1) + * set(it.hireDate, LocalDate.now()) + * set(it.salary, 50) + * set(it.departmentId, 1) + * } + * ``` + * + * @since 2.7 + * @param table the table to be inserted. + * @param block the DSL block, an extension function of [AssignmentsBuilder], used to construct the expression. + * @return the effected row count. + */ +public suspend fun > Database.insert(table: T, block: AssignmentsBuilder.(T) -> Unit): Int { + val builder = AssignmentsBuilder().apply { block(table) } + val expression = AliasRemover.visit(InsertExpression(table.asExpression(), builder.assignments)) + return executeUpdate(expression) +} + +/** + * Construct an insert expression in the given closure, then execute it and return the auto-generated key. + * + * This function assumes that at least one auto-generated key will be returned, and that the first key in + * the result set will be the primary key for the row. + * + * Usage: + * + * ```kotlin + * val id = database.insertAndGenerateKey(Employees) { + * set(it.name, "jerry") + * set(it.job, "trainee") + * set(it.managerId, 1) + * set(it.hireDate, LocalDate.now()) + * set(it.salary, 50) + * set(it.departmentId, 1) + * } + * ``` + * + * @since 2.7 + * @param table the table to be inserted. + * @param block the DSL block, an extension function of [AssignmentsBuilder], used to construct the expression. + * @return the first auto-generated key. + */ +/* +TODO +public fun > Database.insertAndGenerateKey(table: T, block: AssignmentsBuilder.(T) -> Unit): Any { + val builder = AssignmentsBuilder().apply { block(table) } + val expression = AliasRemover.visit(InsertExpression(table.asExpression(), builder.assignments)) + val (_, rowSet) = executeUpdateAndRetrieveKeys(expression) + + if (rowSet.next()) { + val pk = table.singlePrimaryKey { "Key retrieval is not supported for compound primary keys." } + val generatedKey = pk.sqlType.getResult(rowSet, 1) ?: error("Generated key is null.") + + if (logger.isDebugEnabled()) { + logger.debug("Generated Key: $generatedKey") + } + + return generatedKey + } else { + error("No generated key returns by database.") + } +}*/ + +/** + * Construct insert expressions in the given closure, then batch execute them and return the effected + * row counts for each expression. + * + * Note that this function is implemented based on [Statement.add] and [Statement.executeBatch], + * and any item in a batch operation must have the same structure, otherwise an exception will be thrown. 
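+ * In practice this means every `item` should assign the same set of columns; if two items
+ * generate different SQL, an [IllegalArgumentException] is thrown and the batch is not executed.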
+ * + * Usage: + * + * ```kotlin + * database.batchInsert(Employees) { + * item { + * set(it.name, "jerry") + * set(it.job, "trainee") + * set(it.managerId, 1) + * set(it.hireDate, LocalDate.now()) + * set(it.salary, 50) + * set(it.departmentId, 1) + * } + * item { + * set(it.name, "linda") + * set(it.job, "assistant") + * set(it.managerId, 3) + * set(it.hireDate, LocalDate.now()) + * set(it.salary, 100) + * set(it.departmentId, 2) + * } + * } + * ``` + * + * @since 2.7 + * @param table the table to be inserted. + * @param block the DSL block, extension function of [BatchInsertStatementBuilder], used to construct the expressions. + * @return the effected row counts for each sub-operation. + */ +public suspend fun > Database.batchInsert( + table: T, + block: BatchInsertStatementBuilder.() -> Unit +): IntArray { + val builder = BatchInsertStatementBuilder(table).apply(block) + val expressions = builder.expressions.map { AliasRemover.visit(it) } + + if (expressions.isEmpty()) { + return IntArray(0) + } else { + return executeBatch(expressions) + } +} + +/** + * Insert the current [Query]'s results into the given table, useful when transfer data from a table to another table. + */ +public suspend fun Query.insertTo(table: BaseTable<*>, vararg columns: Column<*>): Int { + val expression = InsertFromQueryExpression( + table = table.asExpression(), + columns = columns.map { it.asExpression() }, + query = this.expression + ) + + return database.executeUpdate(expression) +} + +/** + * Delete the records in the [table] that matches the given [predicate]. + * + * @since 2.7 + */ +public suspend fun > Database.delete(table: T, predicate: (T) -> ColumnDeclaring): Int { + val expression = AliasRemover.visit(DeleteExpression(table.asExpression(), predicate(table).asExpression())) + return executeUpdate(expression) +} + +/** + * Delete all the records in the table. + * + * @since 2.7 + */ +public suspend fun Database.deleteAll(table: BaseTable<*>): Int { + val expression = AliasRemover.visit(DeleteExpression(table.asExpression(), where = null)) + return executeUpdate(expression) +} + +/** + * Marker annotation for Ktorm DSL builder classes. + */ +@DslMarker +public annotation class KtormDsl + +/** + * Base class of DSL builders, provide basic functions used to build assignments for insert or update DSL. + */ +@KtormDsl +public open class AssignmentsBuilder { + @Suppress("VariableNaming") + protected val _assignments: ArrayList> = ArrayList() + + /** + * A getter that returns the readonly view of the built assignments list. + */ + internal val assignments: List> get() = _assignments + + /** + * Assign the specific column's value to another column or an expression's result. + * + * @since 3.1.0 + */ + public fun set(column: Column, expr: ColumnDeclaring) { + _assignments += ColumnAssignmentExpression(column.asExpression(), expr.asExpression()) + } + + /** + * Assign the specific column to a value. + * + * @since 3.1.0 + */ + public fun set(column: Column, value: C?) { + _assignments += ColumnAssignmentExpression(column.asExpression(), column.wrapArgument(value)) + } + + private fun Column.checkAssignableFrom(value: Any?) { + if (value == null) return + + val handler = InvocationHandler { _, method, _ -> + // Do nothing... 
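+                // Return a harmless default for whatever bindParameter invokes on the dummy
+                // Statement below: null for void/reference return types, or the primitive
+                // default value otherwise.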
+ @Suppress("ForbiddenVoid") + if (method.returnType == Void.TYPE || !method.returnType.isPrimitive) { + null + } else { + method.returnType.defaultValue + } + } + + + val proxy = Proxy.newProxyInstance(javaClass.classLoader, arrayOf(Statement::class.java), handler) + + try { + sqlType.bindParameter(proxy as Statement, 1, value) + } catch (e: ClassCastException) { + throw IllegalArgumentException("Argument type doesn't match the column's type, column: $this", e) + } + } +} + +/** + * DSL builder for update statements. + */ +@KtormDsl +public class UpdateStatementBuilder : AssignmentsBuilder() { + internal var where: ColumnDeclaring? = null + + /** + * Specify the where clause for this update statement. + */ + public fun where(block: () -> ColumnDeclaring) { + this.where = block() + } +} + +/** + * DSL builder for batch update statements. + */ +@KtormDsl +public class BatchUpdateStatementBuilder>(internal val table: T) { + internal val expressions = ArrayList() + + /** + * Add an update statement to the current batch operation. + */ + public fun item(block: UpdateStatementBuilder.(T) -> Unit) { + val builder = UpdateStatementBuilder() + builder.block(table) + + expressions += UpdateExpression(table.asExpression(), builder.assignments, builder.where?.asExpression()) + } +} + +/** + * DSL builder for batch insert statements. + */ +@KtormDsl +public class BatchInsertStatementBuilder>(internal val table: T) { + internal val expressions = ArrayList() + + /** + * Add an insert statement to the current batch operation. + */ + public fun item(block: AssignmentsBuilder.(T) -> Unit) { + val builder = AssignmentsBuilder() + builder.block(table) + + expressions += InsertExpression(table.asExpression(), builder.assignments) + } +} + +/** + * [SqlExpressionVisitor] implementation used to removed table aliases, used by Ktorm internal. + */ +internal object AliasRemover : SqlExpressionVisitor() { + + override fun visitTable(expr: TableExpression): TableExpression { + if (expr.tableAlias == null) { + return expr + } else { + return expr.copy(tableAlias = null) + } + } + + override fun visitColumn(expr: ColumnExpression): ColumnExpression { + if (expr.table == null) { + return expr + } else { + return expr.copy(table = null) + } + } +} diff --git a/ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/dsl/Operators.kt b/ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/dsl/Operators.kt new file mode 100644 index 0000000..04f77c9 --- /dev/null +++ b/ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/dsl/Operators.kt @@ -0,0 +1,546 @@ +/* + * Copyright 2018-2021 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.ktorm.r2dbc.dsl + +import org.ktorm.r2dbc.expression.* +import org.ktorm.r2dbc.schema.* + +// ---- Unary operators... ---- + +/** + * Check if the current column or expression is null, translated to `is null` in SQL. 
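+ *
+ * A minimal usage sketch (reusing the `Employees` table from the other examples and assuming
+ * the `select`/`where` extensions from Query.kt):
+ *
+ * ```kotlin
+ * database.from(Employees).select(Employees.name).where { Employees.managerId.isNull() }
+ * ```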
+ */ +public fun ColumnDeclaring<*>.isNull(): UnaryExpression { + return UnaryExpression(UnaryExpressionType.IS_NULL, asExpression(), BooleanSqlType) +} + +/** + * Check if the current column or expression is not null, translated to `is not null` in SQL. + */ +public fun ColumnDeclaring<*>.isNotNull(): UnaryExpression { + return UnaryExpression(UnaryExpressionType.IS_NOT_NULL, asExpression(), BooleanSqlType) +} + +/** + * Unary minus operator, translated to `-` in SQL. + */ +public operator fun ColumnDeclaring.unaryMinus(): UnaryExpression { + return UnaryExpression(UnaryExpressionType.UNARY_MINUS, asExpression(), sqlType) +} + +/** + * Unary plus operator, translated to `+` in SQL. + */ +public operator fun ColumnDeclaring.unaryPlus(): UnaryExpression { + return UnaryExpression(UnaryExpressionType.UNARY_PLUS, asExpression(), sqlType) +} + +/** + * Negative operator, translated to the `not` keyword in SQL. + */ +public operator fun ColumnDeclaring.not(): UnaryExpression { + return UnaryExpression(UnaryExpressionType.NOT, asExpression(), BooleanSqlType) +} + +// ---- Plus(+) ---- + +/** + * Plus operator, translated to `+` in SQL. + */ +public infix operator fun ColumnDeclaring.plus(expr: ColumnDeclaring): BinaryExpression { + return BinaryExpression(BinaryExpressionType.PLUS, asExpression(), expr.asExpression(), sqlType) +} + +/** + * Plus operator, translated to `+` in SQL. + */ +public infix operator fun ColumnDeclaring.plus(value: T): BinaryExpression { + return this + wrapArgument(value) +} + +/** + * Plus operator, translated to `+` in SQL. + */ +public infix operator fun T.plus(expr: ColumnDeclaring): BinaryExpression { + return expr.wrapArgument(this) + expr +} + +// ------- Minus(-) ----------- + +/** + * Minus operator, translated to `-` in SQL. + */ +public infix operator fun ColumnDeclaring.minus(expr: ColumnDeclaring): BinaryExpression { + return BinaryExpression(BinaryExpressionType.MINUS, asExpression(), expr.asExpression(), sqlType) +} + +/** + * Minus operator, translated to `-` in SQL. + */ +public infix operator fun ColumnDeclaring.minus(value: T): BinaryExpression { + return this - wrapArgument(value) +} + +/** + * Minus operator, translated to `-` in SQL. + */ +public infix operator fun T.minus(expr: ColumnDeclaring): BinaryExpression { + return expr.wrapArgument(this) - expr +} + +// -------- Times(*) ----------- + +/** + * Multiply operator, translated to `*` in SQL. + */ +public infix operator fun ColumnDeclaring.times(expr: ColumnDeclaring): BinaryExpression { + return BinaryExpression(BinaryExpressionType.TIMES, asExpression(), expr.asExpression(), sqlType) +} + +/** + * Multiply operator, translated to `*` in SQL. + */ +public infix operator fun ColumnDeclaring.times(value: T): BinaryExpression { + return this * wrapArgument(value) +} + +/** + * Multiply operator, translated to `*` in SQL. + */ +public infix operator fun T.times(expr: ColumnDeclaring): BinaryExpression { + return expr.wrapArgument(this) * expr +} + +// -------- Div(/) ---------- + +/** + * Divide operator, translated to `/` in SQL. + */ +public infix operator fun ColumnDeclaring.div(expr: ColumnDeclaring): BinaryExpression { + return BinaryExpression(BinaryExpressionType.DIV, asExpression(), expr.asExpression(), sqlType) +} + +/** + * Divide operator, translated to `/` in SQL. + */ +public infix operator fun ColumnDeclaring.div(value: T): BinaryExpression { + return this / wrapArgument(value) +} + +/** + * Divide operator, translated to `/` in SQL. 
+ */ +public infix operator fun T.div(expr: ColumnDeclaring): BinaryExpression { + return expr.wrapArgument(this) / expr +} + +// -------- Rem(%) ---------- + +/** + * Mod operator, translated to `%` in SQL. + */ +public infix operator fun ColumnDeclaring.rem(expr: ColumnDeclaring): BinaryExpression { + return BinaryExpression(BinaryExpressionType.REM, asExpression(), expr.asExpression(), sqlType) +} + +/** + * Mod operator, translated to `%` in SQL. + */ +public infix operator fun ColumnDeclaring.rem(value: T): BinaryExpression { + return this % wrapArgument(value) +} + +/** + * Mod operator, translated to `%` in SQL. + */ +public infix operator fun T.rem(expr: ColumnDeclaring): BinaryExpression { + return expr.wrapArgument(this) % expr +} + +// -------- Like ---------- + +/** + * Like operator, translated to the `like` keyword in SQL. + */ +public infix fun ColumnDeclaring<*>.like(expr: ColumnDeclaring): BinaryExpression { + return BinaryExpression(BinaryExpressionType.LIKE, asExpression(), expr.asExpression(), BooleanSqlType) +} + +/** + * Like operator, translated to the `like` keyword in SQL. + */ +public infix fun ColumnDeclaring<*>.like(value: String): BinaryExpression { + return this like ArgumentExpression(value, VarcharSqlType) +} + +/** + * Not like operator, translated to the `not like` keyword in SQL. + */ +public infix fun ColumnDeclaring<*>.notLike(expr: ColumnDeclaring): BinaryExpression { + return BinaryExpression(BinaryExpressionType.NOT_LIKE, asExpression(), expr.asExpression(), BooleanSqlType) +} + +/** + * Not like operator, translated to the `not like` keyword in SQL. + */ +public infix fun ColumnDeclaring<*>.notLike(value: String): BinaryExpression { + return this notLike ArgumentExpression(value, VarcharSqlType) +} + +// --------- And ------------ + +/** + * And operator, translated to the `and` keyword in SQL. + */ +public infix fun ColumnDeclaring.and(expr: ColumnDeclaring): BinaryExpression { + return BinaryExpression(BinaryExpressionType.AND, asExpression(), expr.asExpression(), BooleanSqlType) +} + +/** + * And operator, translated to the `and` keyword in SQL. + */ +public infix fun ColumnDeclaring.and(value: Boolean): BinaryExpression { + return this and wrapArgument(value) +} + +/** + * And operator, translated to the `and` keyword in SQL. + */ +public infix fun Boolean.and(expr: ColumnDeclaring): BinaryExpression { + return expr.wrapArgument(this) and expr +} + +// --------- Or ---------- + +/** + * Or operator, translated to the `or` keyword in SQL. + */ +public infix fun ColumnDeclaring.or(expr: ColumnDeclaring): BinaryExpression { + return BinaryExpression(BinaryExpressionType.OR, asExpression(), expr.asExpression(), BooleanSqlType) +} + +/** + * Or operator, translated to the `or` keyword in SQL. + */ +public infix fun ColumnDeclaring.or(value: Boolean): BinaryExpression { + return this or wrapArgument(value) +} + +/** + * Or operator, translated to the `or` keyword in SQL. + */ +public infix fun Boolean.or(expr: ColumnDeclaring): BinaryExpression { + return expr.wrapArgument(this) or expr +} + +// -------- Xor --------- + +/** + * Xor operator, translated to the `xor` keyword in SQL. + */ +public infix fun ColumnDeclaring.xor(expr: ColumnDeclaring): BinaryExpression { + return BinaryExpression(BinaryExpressionType.XOR, asExpression(), expr.asExpression(), BooleanSqlType) +} + +/** + * Xor operator, translated to the `xor` keyword in SQL. 
+ */ +public infix fun ColumnDeclaring.xor(value: Boolean): BinaryExpression { + return this xor wrapArgument(value) +} + +/** + * Xor operator, translated to the `xor` keyword in SQL. + */ +public infix fun Boolean.xor(expr: ColumnDeclaring): BinaryExpression { + return expr.wrapArgument(this) xor expr +} + +// ------- Less -------- + +/** + * Less operator, translated to `<` in SQL. + */ +public infix fun > ColumnDeclaring.less(expr: ColumnDeclaring): BinaryExpression { + return BinaryExpression(BinaryExpressionType.LESS_THAN, asExpression(), expr.asExpression(), BooleanSqlType) +} + +/** + * Less operator, translated to `<` in SQL. + */ +public infix fun > ColumnDeclaring.less(value: T): BinaryExpression { + return this less wrapArgument(value) +} + +/** + * Less operator, translated to `<` in SQL. + */ +public infix fun > T.less(expr: ColumnDeclaring): BinaryExpression { + return expr.wrapArgument(this) less expr +} + +// ------- LessEq --------- + +/** + * Less-eq operator, translated to `<=` in SQL. + */ +public infix fun > ColumnDeclaring.lessEq(expr: ColumnDeclaring): BinaryExpression { + return BinaryExpression( + type = BinaryExpressionType.LESS_THAN_OR_EQUAL, + left = asExpression(), + right = expr.asExpression(), + sqlType = BooleanSqlType + ) +} + +/** + * Less-eq operator, translated to `<=` in SQL. + */ +public infix fun > ColumnDeclaring.lessEq(value: T): BinaryExpression { + return this lessEq wrapArgument(value) +} + +/** + * Less-eq operator, translated to `<=` in SQL. + */ +public infix fun > T.lessEq(expr: ColumnDeclaring): BinaryExpression { + return expr.wrapArgument(this) lessEq expr +} + +// ------- Greater --------- + +/** + * Greater operator, translated to `>` in SQL. + */ +public infix fun > ColumnDeclaring.greater(expr: ColumnDeclaring): BinaryExpression { + return BinaryExpression(BinaryExpressionType.GREATER_THAN, asExpression(), expr.asExpression(), BooleanSqlType) +} + +/** + * Greater operator, translated to `>` in SQL. + */ +public infix fun > ColumnDeclaring.greater(value: T): BinaryExpression { + return this greater wrapArgument(value) +} + +/** + * Greater operator, translated to `>` in SQL. + */ +public infix fun > T.greater(expr: ColumnDeclaring): BinaryExpression { + return expr.wrapArgument(this) greater expr +} + +// -------- GreaterEq --------- + +/** + * Greater-eq operator, translated to `>=` in SQL. + */ +public infix fun > ColumnDeclaring.greaterEq(expr: ColumnDeclaring): BinaryExpression { + return BinaryExpression( + type = BinaryExpressionType.GREATER_THAN_OR_EQUAL, + left = asExpression(), + right = expr.asExpression(), + sqlType = BooleanSqlType + ) +} + +/** + * Greater-eq operator, translated to `>=` in SQL. + */ +public infix fun > ColumnDeclaring.greaterEq(value: T): BinaryExpression { + return this greaterEq wrapArgument(value) +} + +/** + * Greater-eq operator, translated to `>=` in SQL. + */ +public infix fun > T.greaterEq(expr: ColumnDeclaring): BinaryExpression { + return expr.wrapArgument(this) greaterEq expr +} + +// -------- Eq --------- + +/** + * Equal operator, translated to `=` in SQL. + */ +public infix fun ColumnDeclaring.eq(expr: ColumnDeclaring): BinaryExpression { + return BinaryExpression(BinaryExpressionType.EQUAL, asExpression(), expr.asExpression(), BooleanSqlType) +} + +/** + * Equal operator, translated to `=` in SQL. 
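+ *
+ * A minimal usage sketch, combined with the delete DSL from Dml.kt:
+ *
+ * ```kotlin
+ * database.delete(Employees) { it.id eq 2 }
+ * ```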
+ */ +public infix fun ColumnDeclaring.eq(value: T): BinaryExpression { + return this eq wrapArgument(value) +} + +// infix fun T.eq(expr: ColumnDeclaring): BinaryExpression { +// return expr.wrapArgument(this) eq expr +// } + +// ------- NotEq ------- + +/** + * Not-equal operator, translated to `<>` in SQL. + */ +public infix fun ColumnDeclaring.notEq(expr: ColumnDeclaring): BinaryExpression { + return BinaryExpression(BinaryExpressionType.NOT_EQUAL, asExpression(), expr.asExpression(), BooleanSqlType) +} + +/** + * Not-equal operator, translated to `<>` in SQL. + */ +public infix fun ColumnDeclaring.notEq(value: T): BinaryExpression { + return this notEq wrapArgument(value) +} + +// infix fun T.notEq(expr: ColumnDeclaring): BinaryExpression { +// return expr.wrapArgument(this) notEq expr +// } + +// ---- Between ---- + +/** + * Between operator, translated to `between .. and ..` in SQL. + */ +public infix fun > ColumnDeclaring.between(range: ClosedRange): BetweenExpression { + return BetweenExpression(asExpression(), wrapArgument(range.start), wrapArgument(range.endInclusive)) +} + +/** + * Not-between operator, translated to `not between .. and ..` in SQL. + */ +public infix fun > ColumnDeclaring.notBetween(range: ClosedRange): BetweenExpression { + return BetweenExpression( + expression = asExpression(), + lower = wrapArgument(range.start), + upper = wrapArgument(range.endInclusive), + notBetween = true + ) +} + +// ----- InList ------ + +/** + * In-list operator, translated to the `in` keyword in SQL. + */ +public fun ColumnDeclaring.inList(vararg list: T): InListExpression { + return InListExpression(left = asExpression(), values = list.map { wrapArgument(it) }) +} + +/** + * In-list operator, translated to the `in` keyword in SQL. + */ +public infix fun ColumnDeclaring.inList(list: Collection): InListExpression { + return InListExpression(left = asExpression(), values = list.map { wrapArgument(it) }) +} + +/** + * In-list operator, translated to the `in` keyword in SQL. + */ +public infix fun ColumnDeclaring.inList(query: Query): InListExpression { + return InListExpression(left = asExpression(), query = query.expression) +} + +/** + * Not-in-list operator, translated to the `not in` keyword in SQL. + */ +public fun ColumnDeclaring.notInList(vararg list: T): InListExpression { + return InListExpression(left = asExpression(), values = list.map { wrapArgument(it) }, notInList = true) +} + +/** + * Not-in-list operator, translated to the `not in` keyword in SQL. + */ +public infix fun ColumnDeclaring.notInList(list: Collection): InListExpression { + return InListExpression(left = asExpression(), values = list.map { wrapArgument(it) }, notInList = true) +} + +/** + * Not-in-list operator, translated to the `not in` keyword in SQL. + */ +public infix fun ColumnDeclaring.notInList(query: Query): InListExpression { + return InListExpression(left = asExpression(), query = query.expression, notInList = true) +} + +// ---- Exists ------ + +/** + * Check if the given query has at least one result, translated to the `exists` keyword in SQL. + */ +public fun exists(query: Query): ExistsExpression { + return ExistsExpression(query.expression) +} + +/** + * Check if the given query doesn't have any results, translated to the `not exists` keyword in SQL. + */ +public fun notExists(query: Query): ExistsExpression { + return ExistsExpression(query.expression, notExists = true) +} + +// ---- Type casting... ---- + +/** + * Cast the current column or expression's type to [Double]. 
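+ *
+ * A small sketch of using a cast inside an expression (the `Employees` table is an assumed example):
+ *
+ * ```kotlin
+ * // cast the integer salary column to double before dividing; Employees is an assumed example table
+ * val monthly = Employees.salary.toDouble() / 12.0
+ * ```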
+ */ +public fun ColumnDeclaring.toDouble(): CastingExpression { + return this.cast(DoubleSqlType) +} + +/** + * Cast the current column or expression's type to [Float]. + */ +public fun ColumnDeclaring.toFloat(): CastingExpression { + return this.cast(FloatSqlType) +} + +/** + * Cast the current column or expression's type to [Int]. + */ +public fun ColumnDeclaring.toInt(): CastingExpression { + return this.cast(IntSqlType) +} + +/** + * Cast the current column or expression's type to [Short]. + */ +public fun ColumnDeclaring.toShort(): CastingExpression { + return this.cast(ShortSqlType) +} + +/** + * Cast the current column or expression's type to [Long]. + */ +public fun ColumnDeclaring.toLong(): CastingExpression { + return this.cast(LongSqlType) +} + +/** + * Cast the current column or expression's type to [Int]. + */ +@JvmName("booleanToInt") +public fun ColumnDeclaring.toInt(): CastingExpression { + return this.cast(IntSqlType) +} + +/** + * Cast the current column or expression to the given [SqlType]. + */ +public fun ColumnDeclaring<*>.cast(sqlType: SqlType): CastingExpression { + return CastingExpression(asExpression(), sqlType) +} diff --git a/ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/dsl/Query.kt b/ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/dsl/Query.kt new file mode 100644 index 0000000..d1d414a --- /dev/null +++ b/ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/dsl/Query.kt @@ -0,0 +1,788 @@ +/* + * Copyright 2018-2021 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.ktorm.r2dbc.dsl + +import org.ktorm.r2dbc.database.Database +import org.ktorm.r2dbc.expression.* +import org.ktorm.r2dbc.schema.BooleanSqlType +import org.ktorm.r2dbc.schema.Column +import org.ktorm.r2dbc.schema.ColumnDeclaring +import java.sql.ResultSet + +/** + * [Query] is an abstraction of query operations and the core class of Ktorm's query DSL. + * + * The constructor of this class accepts two parameters: [database] is the database instance that this query + * is running on; [expression] is the abstract representation of the executing SQL statement. Usually, we don't + * use the constructor to create [Query] objects but use the `database.from(..).select(..)` syntax instead. + * + * [Query] provides a built-in [iterator], so we can iterate the results by a for-each loop: + * + * ```kotlin + * for (row in database.from(Employees).select()) { + * println(row[Employees.name]) + * } + * ``` + * + * Moreover, there are many extension functions that can help us easily process the query results, such as + * [Query.map], [Query.flatMap], [Query.associate], [Query.fold], etc. With the help of these functions, we can + * obtain rows from a query just like it's a common Kotlin collection. + * + * Query objects are immutable. Query DSL functions are provided as its extension functions normally. We can + * chaining call these functions to modify them and create new query objects. 
Here is a simple example: + * + * ```kotlin + * val query = database + * .from(Employees) + * .select(Employees.salary) + * .where { (Employees.departmentId eq 1) and (Employees.name like "%vince%") } + * ``` + * + * Easy to know that the query obtains the salary of an employee named vince in department 1. The generated + * SQL is obviously: + * + * ```sql + * select t_employee.salary as t_employee_salary + * from t_employee + * where (t_employee.department_id = ?) and (t_employee.name like ?) + * ``` + * + * More usages can be found in the documentations of those DSL functions. + * + * @property database the [Database] instance that this query is running on. + * @property expression the underlying SQL expression of this query object. + */ +public class Query(public val database: Database, public val expression: QueryExpression) { + + /** + * The executable SQL string of this query. + * + * Useful when we want to ensure if the generated SQL is expected while debugging. + */ + public val sql: String by lazy(LazyThreadSafetyMode.NONE) { + database.formatExpression(expression, beautifySql = true).first + } + + public suspend fun doQuery(expression: QueryExpression = this.expression): List { + return database.executeQuery(expression).map { QueryRow(this@Query, it) } + } + + /** + * The [ResultSet] object of this query, lazy initialized after first access, obtained from the database by + * executing the generated SQL. + * + * Note that the return type of this property is not a normal [ResultSet], but a [QueryRow] instead. That's + * a special implementation provided by Ktorm, different from normal result sets, it is available offline and + * overrides the indexed access operator. More details can be found in the documentation of [QueryRow]. + */ + /* public val rowSet: QueryRow by lazy(LazyThreadSafetyMode.NONE) { + QueryRow(this, database.executeQuery(expression)) + }*/ + + /** + * The total record count of this query ignoring the pagination params. + * + * If the query doesn't limits the results via [Query.limit] function, return the size of the result set. Or if + * it does, return the total record count of the query ignoring the offset and limit parameters. This property + * is provided to support pagination, we can calculate the page count through dividing it by our page size. + */ + public suspend fun totalRecords(): Int { + return if (expression.offset == null && expression.limit == null) { + doQuery().size + } else { + val countExpr = expression.toCountExpression() + val count = doQuery(countExpr) + .map { it.get(0, Int::class.java) } + .firstOrNull() + val (sql, _) = database.formatExpression(countExpr, beautifySql = true) + count ?: throw IllegalStateException("No result return for sql: $sql") + } + } + + + /** + * Return a copy of this [Query] with the [expression] modified. + */ + public fun withExpression(expression: QueryExpression): Query { + return Query(database, expression) + } + + /** + * Return an iterator over the rows of this query. + * + * Note that this function is simply implemented as `rowSet.iterator()`, so every element returned by the iterator + * exactly shares the same instance as the [rowSet] property. + * + * @see rowSet + * @see ResultSet.iterator + */ + public suspend operator fun iterator(): Iterator { + return doQuery().iterator() + } +} + +/** + * Create a query object, selecting the specific columns or expressions from this [QuerySource]. + * + * Note that the specific columns can be empty, that means `select *` in SQL. 
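+ *
+ * For example (`database` and `Employees` are assumed, as in the docs of [Query]):
+ *
+ * ```kotlin
+ * // select specific columns...
+ * val query = database.from(Employees).select(Employees.id, Employees.name)
+ * // ...or pass no columns at all to generate `select *`
+ * val all = database.from(Employees).select()
+ * ```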
+ * + * @since 2.7 + */ +public fun QuerySource.select(columns: Collection>): Query { + val declarations = columns.map { it.asDeclaringExpression() } + return Query(database, SelectExpression(columns = declarations, from = expression)) +} + +/** + * Create a query object, selecting the specific columns or expressions from this [QuerySource]. + * + * Note that the specific columns can be empty, that means `select *` in SQL. + * + * @since 2.7 + */ +public fun QuerySource.select(vararg columns: ColumnDeclaring<*>): Query { + return select(columns.asList()) +} + +/** + * Create a query object, selecting the specific columns or expressions from this [QuerySource] distinctly. + * + * Note that the specific columns can be empty, that means `select distinct *` in SQL. + * + * @since 2.7 + */ +public fun QuerySource.selectDistinct(columns: Collection>): Query { + val declarations = columns.map { it.asDeclaringExpression() } + return Query(database, SelectExpression(columns = declarations, from = expression, isDistinct = true)) +} + +/** + * Create a query object, selecting the specific columns or expressions from this [QuerySource] distinctly. + * + * Note that the specific columns can be empty, that means `select distinct *` in SQL. + * + * @since 2.7 + */ +public fun QuerySource.selectDistinct(vararg columns: ColumnDeclaring<*>): Query { + return selectDistinct(columns.asList()) +} + +/** + * Wrap this expression as a [ColumnDeclaringExpression]. + */ +internal fun ColumnDeclaring.asDeclaringExpression(): ColumnDeclaringExpression { + return when (this) { + is ColumnDeclaringExpression -> this + is Column -> this.aliased(label) + else -> this.aliased(null) + } +} + +/** + * Specify the `where` clause of this query using the given condition expression. + */ +public fun Query.where(condition: ColumnDeclaring): Query { + return this.withExpression( + when (expression) { + is SelectExpression -> expression.copy(where = condition.asExpression()) + is UnionExpression -> throw IllegalStateException("Where clause is not supported in a union expression.") + } + ) +} + +/** + * Specify the `where` clause of this query using the expression returned by the given callback function. + */ +public inline fun Query.where(condition: () -> ColumnDeclaring): Query { + return where(condition()) +} + +/** + * Create a mutable list, then add filter conditions to the list in the given callback function, finally combine + * them with the [and] operator and set the combined condition as the `where` clause of this query. + * + * Note that if we don't add any conditions to the list, the `where` clause would not be set. + */ +public inline fun Query.whereWithConditions(block: (MutableList>) -> Unit): Query { + val conditions = ArrayList>().apply(block) + + if (conditions.isEmpty()) { + return this + } else { + return this.where { conditions.reduce { a, b -> a and b } } + } +} + +/** + * Create a mutable list, then add filter conditions to the list in the given callback function, finally combine + * them with the [or] operator and set the combined condition as the `where` clause of this query. + * + * Note that if we don't add any conditions to the list, the `where` clause would not be set. 
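+ *
+ * A hypothetical sketch (`database`, `Employees` and the nullable parameters are assumed examples):
+ *
+ * ```kotlin
+ * // departmentId and name are assumed nullable request parameters
+ * fun findEmployees(departmentId: Int?, name: String?) = database
+ *     .from(Employees)
+ *     .select()
+ *     .whereWithConditions { conditions ->
+ *         if (departmentId != null) conditions += Employees.departmentId eq departmentId
+ *         if (name != null) conditions += Employees.name like "%$name%"
+ *     }
+ * ```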
+ */ +public inline fun Query.whereWithOrConditions(block: (MutableList>) -> Unit): Query { + val conditions = ArrayList>().apply(block) + + if (conditions.isEmpty()) { + return this + } else { + return this.where { conditions.reduce { a, b -> a or b } } + } +} + +/** + * Combine this iterable of boolean expressions with the [and] operator. + * + * If the iterable is empty, the param [ifEmpty] will be returned. + */ +public fun Iterable>.combineConditions(ifEmpty: Boolean = true): ColumnDeclaring { + return this.reduceOrNull { a, b -> a and b } ?: ArgumentExpression(ifEmpty, BooleanSqlType) +} + +/** + * Specify the `group by` clause of this query using the given columns or expressions. + */ +public fun Query.groupBy(columns: Collection>): Query { + return this.withExpression( + when (expression) { + is SelectExpression -> expression.copy(groupBy = columns.map { it.asExpression() }) + is UnionExpression -> throw IllegalStateException("Group by clause is not supported in a union expression.") + } + ) +} + +/** + * Specify the `group by` clause of this query using the given columns or expressions. + */ +public fun Query.groupBy(vararg columns: ColumnDeclaring<*>): Query { + return groupBy(columns.asList()) +} + +/** + * Specify the `having` clause of this query using the given condition expression. + */ +public fun Query.having(condition: ColumnDeclaring): Query { + return this.withExpression( + when (expression) { + is SelectExpression -> expression.copy(having = condition.asExpression()) + is UnionExpression -> throw IllegalStateException("Having clause is not supported in a union expression.") + } + ) +} + +/** + * Specify the `having` clause of this query using the expression returned by the given callback function. + */ +public inline fun Query.having(condition: () -> ColumnDeclaring): Query { + return having(condition()) +} + +/** + * Specify the `order by` clause of this query using the given order-by expressions. + */ +public fun Query.orderBy(orders: Collection): Query { + return this.withExpression( + when (expression) { + is SelectExpression -> expression.copy(orderBy = orders.toList()) + is UnionExpression -> { + val replacer = OrderByReplacer(expression) + expression.copy(orderBy = orders.map { replacer.visit(it) as OrderByExpression }) + } + } + ) +} + +/** + * Specify the `order by` clause of this query using the given order-by expressions. + */ +public fun Query.orderBy(vararg orders: OrderByExpression): Query { + return orderBy(orders.asList()) +} + +private class OrderByReplacer(query: UnionExpression) : SqlExpressionVisitor() { + val declaringColumns = query.findDeclaringColumns() + + override fun visitOrderBy(expr: OrderByExpression): OrderByExpression { + val declaring = declaringColumns.find { it.declaredName != null && it.expression == expr.expression } + + if (declaring == null) { + throw IllegalArgumentException("Could not find the ordering column in the union expression, column: $expr") + } else { + return OrderByExpression( + expression = ColumnExpression( + table = null, + name = declaring.declaredName!!, + sqlType = declaring.expression.sqlType + ), + orderType = expr.orderType + ) + } + } +} + +internal tailrec fun QueryExpression.findDeclaringColumns(): List> { + return when (this) { + is SelectExpression -> columns + is UnionExpression -> left.findDeclaringColumns() + } +} + +/** + * Order this column or expression in ascending order. 
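+ *
+ * For example (`database` and `Employees` are assumed, as in the docs of [Query]):
+ *
+ * ```kotlin
+ * // Employees is an assumed example table
+ * val query = database
+ *     .from(Employees)
+ *     .select(Employees.name)
+ *     .orderBy(Employees.salary.desc(), Employees.name.asc())
+ * ```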
+ */ +public fun ColumnDeclaring<*>.asc(): OrderByExpression { + return OrderByExpression(asExpression(), OrderType.ASCENDING) +} + +/** + * Order this column or expression in descending order, corresponding to the `desc` keyword in SQL. + */ +public fun ColumnDeclaring<*>.desc(): OrderByExpression { + return OrderByExpression(asExpression(), OrderType.DESCENDING) +} + +/** + * Specify the pagination offset parameter of this query. + * + * This function requires a dialect enabled, different SQLs will be generated with different dialects. + * + * Note that if the number isn't positive then it will be ignored. + */ +public fun Query.offset(n: Int): Query { + return limit(offset = n, limit = null) +} + +/** + * Specify the pagination limit parameter of this query. + * + * This function requires a dialect enabled, different SQLs will be generated with different dialects. + * + * Note that if the number isn't positive then it will be ignored. + */ +public fun Query.limit(n: Int): Query { + return limit(offset = null, limit = n) +} + +/** + * Specify the pagination parameters of this query. + * + * This function requires a dialect enabled, different SQLs will be generated with different dialects. For example, + * `limit ?, ?` by MySQL, `limit m offset n` by PostgreSQL. + * + * Note that if the numbers aren't positive, they will be ignored. + */ +public fun Query.limit(offset: Int?, limit: Int?): Query { + return this.withExpression( + when (expression) { + is SelectExpression -> expression.copy( + offset = offset?.takeIf { it > 0 } ?: expression.offset, + limit = limit?.takeIf { it > 0 } ?: expression.limit + ) + is UnionExpression -> expression.copy( + offset = offset?.takeIf { it > 0 } ?: expression.offset, + limit = limit?.takeIf { it > 0 } ?: expression.limit + ) + } + ) +} + +/** + * Union this query with the given one, corresponding to the `union` keyword in SQL. + */ +public fun Query.union(right: Query): Query { + return this.withExpression(UnionExpression(left = expression, right = right.expression, isUnionAll = false)) +} + +/** + * Union this query with the given one, corresponding to the `union all` keyword in SQL. + */ +public fun Query.unionAll(right: Query): Query { + return this.withExpression(UnionExpression(left = expression, right = right.expression, isUnionAll = true)) +} + +/** + * Wrap this query as [Iterable]. + * + * @since 3.0.0 + */ +public suspend fun Query.asIterable(): Iterable { + val iterator = iterator() + return Iterable { iterator } +} + +/** + * Perform the given [action] on each row of the query. + * + * @since 3.0.0 + */ +public suspend inline fun Query.forEach(action: (row: QueryRow) -> Unit) { + for (row in this) action(row) +} + +/** + * Perform the given [action] on each row of the query, providing sequential index with the row. + * + * The [action] function takes the index of a row and the row itself and performs the desired action on the row. + * + * @since 3.0.0 + */ +public suspend inline fun Query.forEachIndexed(action: (index: Int, row: QueryRow) -> Unit) { + var index = 0 + for (row in this) action(index++, row) +} + +/** + * Return a lazy [Iterable] that wraps each row of the query into an [IndexedValue] containing the index of + * that row and the row itself. + * + * @since 3.0.0 + */ + +public suspend fun Query.withIndex(): Iterable> { + val iterator = IndexingIterator(iterator()) + return Iterable { iterator } +} + +/** + * Iterator transforming original [iterator] into iterator of [IndexedValue], counting index from zero. 
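+ *
+ * This iterator backs [Query.withIndex]; a hypothetical usage sketch, run inside a coroutine
+ * (`database` and `Employees` are assumed examples):
+ *
+ * ```kotlin
+ * // iterate rows together with their zero-based index
+ * database.from(Employees).select(Employees.name).withIndex().forEach { (i, row) ->
+ *     println("$i: ${row[Employees.name]}")
+ * }
+ * ```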
+ */ + +@Suppress("IteratorNotThrowingNoSuchElementException") +internal class IndexingIterator(private val iterator: Iterator) : Iterator> { + private var index = 0 + + override fun hasNext(): Boolean { + return iterator.hasNext() + } + + override fun next(): IndexedValue { + return IndexedValue(index++, iterator.next()) + } +} + +/** + * Return a list containing the results of applying the given [transform] function to each row of the query. + * + * @since 3.0.0 + */ + +public suspend inline fun Query.map(transform: (row: QueryRow) -> R): List { + return mapTo(ArrayList(), transform) +} + +/** + * Apply the given [transform] function to each row of the query and append the results to the given [destination]. + * + * @since 3.0.0 + */ + +public suspend inline fun > Query.mapTo( + destination: C, + transform: (row: QueryRow) -> R +): C { + for (row in this) destination += transform(row) + return destination +} + +/** + * Return a list containing only the non-null results of applying the given [transform] function to each row of + * the query. + * + * @since 3.0.0 + */ + +public suspend inline fun Query.mapNotNull(transform: (row: QueryRow) -> R?): List { + return mapNotNullTo(ArrayList(), transform) +} + +/** + * Apply the given [transform] function to each row of the query and append only the non-null results to + * the given [destination]. + * + * @since 3.0.0 + */ + +public suspend inline fun > Query.mapNotNullTo( + destination: C, + transform: (row: QueryRow) -> R? +): C { + forEach { row -> transform(row)?.let { destination += it } } + return destination +} + +/** + * Return a list containing the results of applying the given [transform] function to each row and its index. + * + * The [transform] function takes the index of a row and the row itself and returns the result of the transform + * applied to the row. + * + * @since 3.0.0 + */ + +public suspend inline fun Query.mapIndexed(transform: (index: Int, row: QueryRow) -> R): List { + return mapIndexedTo(ArrayList(), transform) +} + +/** + * Apply the given [transform] function the each row and its index and append the results to the given [destination]. + * + * The [transform] function takes the index of a row and the row itself and returns the result of the transform + * applied to the row. + * + * @since 3.0.0 + */ + +public suspend inline fun > Query.mapIndexedTo( + destination: C, + transform: (index: Int, row: QueryRow) -> R +): C { + var index = 0 + return mapTo(destination) { row -> transform(index++, row) } +} + +/** + * Return a list containing only the non-null results of applying the given [transform] function to each row + * and its index. + * + * The [transform] function takes the index of a row and the row itself and returns the result of the transform + * applied to the row. + * + * @since 3.0.0 + */ + +public suspend inline fun Query.mapIndexedNotNull(transform: (index: Int, row: QueryRow) -> R?): List { + return mapIndexedNotNullTo(ArrayList(), transform) +} + +/** + * Apply the given [transform] function the each row and its index and append only the non-null results to + * the given [destination]. + * + * The [transform] function takes the index of a row and the row itself and returns the result of the transform + * applied to the row. + * + * @since 3.0.0 + */ + +public suspend inline fun > Query.mapIndexedNotNullTo( + destination: C, + transform: (index: Int, row: QueryRow) -> R? 
+): C { + forEachIndexed { index, row -> transform(index, row)?.let { destination += it } } + return destination +} + +/** + * Return a single list of all elements yielded from results of [transform] function being invoked on each row + * of the query. + * + * @since 3.0.0 + */ + +public suspend inline fun Query.flatMap(transform: (row: QueryRow) -> Iterable): List { + return flatMapTo(ArrayList(), transform) +} + +/** + * Append all elements yielded from results of [transform] function being invoked on each row of the query, + * to the given [destination]. + * + * @since 3.0.0 + */ + +public suspend inline fun > Query.flatMapTo( + destination: C, + transform: (row: QueryRow) -> Iterable +): C { + for (row in this) destination += transform(row) + return destination +} + +/** + * Return a single list of all elements yielded from results of [transform] function being invoked on each row + * and its index in the query. + * + * @since 3.1.0 + */ + +public suspend inline fun Query.flatMapIndexed(transform: (index: Int, row: QueryRow) -> Iterable): List { + return flatMapIndexedTo(ArrayList(), transform) +} + +/** + * Append all elements yielded from results of [transform] function being invoked on each row and its index + * in the query, to the given [destination]. + * + * @since 3.1.0 + */ + +public suspend inline fun > Query.flatMapIndexedTo( + destination: C, + transform: (index: Int, row: QueryRow) -> Iterable +): C { + var index = 0 + return flatMapTo(destination) { transform(index++, it) } +} + +/** + * Return a [Map] containing key-value pairs provided by [transform] function applied to rows of the query. + * + * If any of two pairs would have the same key the last one gets added to the map. + * + * The returned map preserves the entry iteration order of the original query results. + * + * @since 3.0.0 + */ + +public suspend inline fun Query.associate(transform: (row: QueryRow) -> Pair): Map { + return associateTo(LinkedHashMap(), transform) +} + +/** + * Return a [Map] containing the values provided by [valueTransform] and indexed by [keySelector] functions applied to + * rows of the query. + * + * If any two rows would have the same key returned by [keySelector] the last one gets added to the map. + * + * The returned map preserves the entry iteration order of the original query results. + * + * @since 3.0.0 + */ + +public suspend inline fun Query.associateBy( + keySelector: (row: QueryRow) -> K, + valueTransform: (row: QueryRow) -> V +): Map { + return associateByTo(LinkedHashMap(), keySelector, valueTransform) +} + +/** + * Populate and return the [destination] mutable map with key-value pairs provided by [transform] function applied to + * each row of the query. + * + * If any of two pairs would have the same key the last one gets added to the map. + * + * @since 3.0.0 + */ + +public suspend inline fun > Query.associateTo( + destination: M, + transform: (row: QueryRow) -> Pair +): M { + for (row in this) destination += transform(row) + return destination +} + +/** + * Populate and return the [destination] mutable map with key-value pairs, where key is provided by the [keySelector] + * function and value is provided by the [valueTransform] function applied to rows of the query. + * + * If any two rows would have the same key returned by [keySelector] the last one gets added to the map. 
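+ *
+ * A hypothetical sketch, run inside a coroutine (`database` and `Employees` are assumed examples):
+ *
+ * ```kotlin
+ * // key and value types are nullable because row access may return null
+ * val namesById = database
+ *     .from(Employees)
+ *     .select(Employees.id, Employees.name)
+ *     .associateByTo(HashMap<Int?, String?>(), { it[Employees.id] }, { it[Employees.name] })
+ * ```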
+ * + * @since 3.0.0 + */ + +public suspend inline fun > Query.associateByTo( + destination: M, + keySelector: (row: QueryRow) -> K, + valueTransform: (row: QueryRow) -> V +): M { + for (row in this) destination.put(keySelector(row), valueTransform(row)) + return destination +} + +/** + * Accumulate value starting with [initial] value and applying [operation] to current accumulator value and each row. + * + * @since 3.0.0 + */ + +public suspend inline fun Query.fold(initial: R, operation: (acc: R, row: QueryRow) -> R): R { + var accumulator = initial + for (row in this) accumulator = operation(accumulator, row) + return accumulator +} + +/** + * Accumulate value starting with [initial] value and applying [operation] to current accumulator value and each row + * with its index in the original query results. + * + * The [operation] function takes the index of a row, current accumulator value and the row itself, + * and calculates the next accumulator value. + * + * @since 3.0.0 + */ + +public suspend inline fun Query.foldIndexed(initial: R, operation: (index: Int, acc: R, row: QueryRow) -> R): R { + var index = 0 + var accumulator = initial + for (row in this) accumulator = operation(index++, accumulator, row) + return accumulator +} + +/** + * Append the string from all rows separated using [separator] and using the given [prefix] and [postfix] if supplied. + * + * If the query result could be huge, you can specify a non-negative value of [limit], in which case only the first + * [limit] rows will be appended, followed by the [truncated] string (which defaults to "..."). + * + * @since 3.0.0 + */ + +public suspend fun Query.joinTo( + buffer: A, + separator: CharSequence = ", ", + prefix: CharSequence = "", + postfix: CharSequence = "", + limit: Int = -1, + truncated: CharSequence = "...", + transform: (row: QueryRow) -> CharSequence +): A { + buffer.append(prefix) + var count = 0 + for (row in this) { + if (++count > 1) buffer.append(separator) + if (limit < 0 || count <= limit) { + buffer.append(transform(row)) + } else { + buffer.append(truncated) + break + } + } + buffer.append(postfix) + return buffer +} + +/** + * Create a string from all rows separated using [separator] and using the given [prefix] and [postfix] if supplied. + * + * If the query result could be huge, you can specify a non-negative value of [limit], in which case only the first + * [limit] rows will be appended, followed by the [truncated] string (which defaults to "..."). + * + * @since 3.0.0 + */ + +public suspend fun Query.joinToString( + separator: CharSequence = ", ", + prefix: CharSequence = "", + postfix: CharSequence = "", + limit: Int = -1, + truncated: CharSequence = "...", + transform: (row: QueryRow) -> CharSequence +): String { + return joinTo(StringBuilder(), separator, prefix, postfix, limit, truncated, transform).toString() +} + diff --git a/ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/dsl/QueryRow.kt b/ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/dsl/QueryRow.kt new file mode 100644 index 0000000..ee70b1f --- /dev/null +++ b/ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/dsl/QueryRow.kt @@ -0,0 +1,62 @@ +package org.ktorm.r2dbc.dsl + +import io.r2dbc.spi.Row +import org.ktorm.r2dbc.expression.ColumnDeclaringExpression +import org.ktorm.r2dbc.schema.Column + +public class QueryRow internal constructor(public val query: Query, private val row: Row) : Row by row { + + public operator fun get(column: ColumnDeclaringExpression, columnClass: Class): C? 
{ + if (column.declaredName.isNullOrBlank()) { + throw IllegalArgumentException("Label of the specified column cannot be null or blank.") + } + val metadata = row.metadata + for (index in metadata.columnMetadatas.indices) { + if (metadata.getColumnMetadata(index).name eq column.declaredName) { + return row.get(index, columnClass) + } + } + return null + } + + /** + * Obtain the value of the specific [Column] instance. + * + * Note that if the column doesn't exist in the result set, this function will return null rather than + * throwing an exception. + */ + public operator fun get(column: Column): C? { + val metadata = row.metadata + if (query.expression.findDeclaringColumns().isNotEmpty()) { + // Try to find the column by label. + for (index in metadata.columnMetadatas.indices) { + if (metadata.getColumnMetadata(index).name eq column.label) { + return column.sqlType.getResult(row,metadata,index) + } + } + // Return null if the column doesn't exist in the result set. + return null + } else { + // Try to find the column by name and its table name (happens when we are using `select *`). + val indices = metadata.columnMetadatas.indices.filter { index -> + /*val tableName = metadata.getTableName(index) + val tableNameMatched = tableName.isBlank() || tableName eq table.alias || tableName eq table.tableName + val columnName = metaData.getColumnName(index)*/ + metadata.columnMetadatas[index].name eq column.name/* && tableNameMatched*/ + } + + return when (indices.size) { + 0 -> null // Return null if the column doesn't exist in the result set. + 1 -> return column.sqlType.getResult(row,metadata,indices.first()) + else -> throw IllegalArgumentException(warningConfusedColumnName(column.name)) + } + } + } + + + private infix fun String?.eq(other: String?) = this.equals(other, ignoreCase = true) + + private fun warningConfusedColumnName(name: String): String { + return "Confused column name, there are more than one column named '$name' in query: \n\n${query.sql}\n" + } +} \ No newline at end of file diff --git a/ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/dsl/QuerySource.kt b/ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/dsl/QuerySource.kt new file mode 100644 index 0000000..61eb7b7 --- /dev/null +++ b/ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/dsl/QuerySource.kt @@ -0,0 +1,145 @@ +/* + * Copyright 2018-2021 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.ktorm.r2dbc.dsl + +import org.ktorm.r2dbc.database.Database +import org.ktorm.r2dbc.expression.* +import org.ktorm.r2dbc.schema.BaseTable +import org.ktorm.r2dbc.schema.BooleanSqlType +import org.ktorm.r2dbc.schema.ColumnDeclaring +import org.ktorm.r2dbc.schema.ReferenceBinding + + +/** + * Represents a query source, used in the `from` clause of a query. + * + * @since 2.7 + * @property database the [Database] instance that the query is running on. + * @property sourceTable the origin source table. + * @property expression the underlying SQL expression. 
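+ *
+ * A hypothetical usage sketch (the `Employees` and `Departments` tables are assumed examples):
+ *
+ * ```kotlin
+ * // build a query source by joining two tables, then select from it
+ * val source = database
+ *     .from(Employees)
+ *     .innerJoin(Departments, on = Employees.departmentId eq Departments.id)
+ * val query = source.select(Employees.name, Departments.name)
+ * ```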
+ */ +public data class QuerySource( + val database: Database, + val sourceTable: BaseTable<*>, + val expression: QuerySourceExpression +) + +/** + * Wrap the specific table as a [QuerySource]. + * + * @since 2.7 + */ +public fun Database.from(table: BaseTable<*>): QuerySource { + return QuerySource(this, table, table.asExpression()) +} + +/** + * Join the right table and return a new [QuerySource], translated to `cross join` in SQL. + */ +public fun QuerySource.crossJoin(right: BaseTable<*>, on: ColumnDeclaring? = null): QuerySource { + return this.copy( + expression = JoinExpression( + type = JoinType.CROSS_JOIN, + left = expression, + right = right.asExpression(), + condition = on?.asExpression() + ) + ) +} + +/** + * Join the right table and return a new [QuerySource], translated to `inner join` in SQL. + */ +public fun QuerySource.innerJoin(right: BaseTable<*>, on: ColumnDeclaring? = null): QuerySource { + return this.copy( + expression = JoinExpression( + type = JoinType.INNER_JOIN, + left = expression, + right = right.asExpression(), + condition = on?.asExpression() + ) + ) +} + +/** + * Join the right table and return a new [QuerySource], translated to `left join` in SQL. + */ +public fun QuerySource.leftJoin(right: BaseTable<*>, on: ColumnDeclaring? = null): QuerySource { + return this.copy( + expression = JoinExpression( + type = JoinType.LEFT_JOIN, + left = expression, + right = right.asExpression(), + condition = on?.asExpression() + ) + ) +} + +/** + * Join the right table and return a new [QuerySource], translated to `right join` in SQL. + */ +public fun QuerySource.rightJoin(right: BaseTable<*>, on: ColumnDeclaring? = null): QuerySource { + return this.copy( + expression = JoinExpression( + type = JoinType.RIGHT_JOIN, + left = expression, + right = right.asExpression(), + condition = on?.asExpression() + ) + ) +} + +/** + * Return a new-created [Query] object, left joining all the reference tables, and selecting all columns of them. + */ +public fun QuerySource.joinReferencesAndSelect(): Query { + val joinedTables = ArrayList>() + + return sourceTable + .joinReferences(this, joinedTables) + .select(joinedTables.flatMap { it.columns }) +} + +private fun BaseTable<*>.joinReferences( + querySource: QuerySource, + joinedTables: MutableList> +): QuerySource { + + var curr = querySource + + joinedTables += this + + for (column in columns) { + for (binding in column.allBindings) { + if (binding is ReferenceBinding) { + val refTable = binding.referenceTable + val pk = refTable.singlePrimaryKey { + "Cannot reference the table '$refTable' as there is compound primary keys." + } + + curr = curr.leftJoin(refTable, on = column eq pk) + curr = refTable.joinReferences(curr, joinedTables) + } + } + } + + return curr +} + +private infix fun ColumnDeclaring<*>.eq(column: ColumnDeclaring<*>): BinaryExpression { + return BinaryExpression(BinaryExpressionType.EQUAL, asExpression(), column.asExpression(), BooleanSqlType) +} diff --git a/ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/entity/Entity.kt b/ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/entity/Entity.kt new file mode 100644 index 0000000..bfc251d --- /dev/null +++ b/ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/entity/Entity.kt @@ -0,0 +1,292 @@ +/* + * Copyright 2018-2021 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.ktorm.r2dbc.entity + +import org.ktorm.r2dbc.database.Database +import org.ktorm.r2dbc.schema.Table +import org.ktorm.schema.TypeReference +import java.io.ObjectInputStream +import java.io.ObjectOutputStream +import java.io.Serializable +import java.lang.reflect.Proxy +import java.sql.SQLException +import kotlin.reflect.KClass +import kotlin.reflect.full.isSubclassOf +import kotlin.reflect.jvm.jvmErasure + +/** + * The super interface of all entity classes in Ktorm. This interface injects many useful functions into entities. + * + * Ktorm requires us to define entity classes as interfaces extending from this interface. A simple example is + * given as follows: + * + * ```kotlin + * interface Department : Entity { + * val id: Int + * var name: String + * var location: String + * } + * ``` + * + * ### Creating Entity Objects + * + * As everyone knows, interfaces cannot be instantiated, so Ktorm provides a [Entity.create] function for us to + * create entity objects. This function will generate implementations for entity interfaces via JDK dynamic proxy + * and create their instances. + * + * In case you don't like creating objects by [Entity.create], Ktorm also provides an abstract factory class + * [Entity.Factory]. This class overloads the `invoke` operator of Kotlin, so we just need to add a companion + * object to our entity class extending from [Entity.Factory], then entity objects can be created just like there + * is a constructor: `val department = Department()`. + * + * ### Getting and Setting Properties + * + * Entity objects in Ktorm are proxies, that's why Ktorm can intercept all the invocations on entities and listen + * the status changes of them. Behind those entity objects, there is a value table that holds all the values of the + * properties for each entity. Any operation of getting or setting a property is actually operating the underlying + * value table. However, what if the value doesn't exist while we are getting a property? Ktorm defines a set of + * rules for this situation: + * + * - If the value doesn’t exist and the property’s type is nullable (eg. `var name: String?`), then we’ll return null. + * - If the value doesn’t exist and the property’s type is not nullable (eg. `var name: String`), then we can not + * return null anymore, because the null value here can cause an unexpected null pointer exception, we’ll return the + * type’s default value instead. + * + * The default values of different types are well-defined: + * + * - For [Boolean] type, the default value is `false`. + * - For [Char] type, the default value is `\u0000`. + * - For number types (such as [Int], [Long], [Double], etc), the default value is zero. + * - For [String] type, the default value is an empty string. + * - For entity types, the default value is a new-created entity object which is empty. + * - For enum types, the default value is the first value of the enum, whose ordinal is 0. + * - For array types, the default value is a new-created empty array. 
+ * - For collection types (such as [Set], [List], [Map], etc), the default value is a new created mutable collection + * of the concrete type. + * - For any other types, the default value is an instance created by its no-args constructor. If the constructor + * doesn’t exist, an exception is thrown. + * + * Moreover, there is a cache mechanism for default values, that ensures a property always returns the same default + * value instance even if it’s called twice or more. This can avoid some counterintuitive bugs. + * + * ### Non-abstract members + * + * If we are using domain driven design, then entities are not only data containers that hold property values, there + * are also some behaviors, so we need to add some business functions to our entities. Fortunately, Kotlin allows us + * to define non-abstract functions in interfaces, that’s why we don’t lose anything even if Ktorm’s entity classes + * are all interfaces. Here is an example: + * + * ```kotlin + * interface Foo : Entity { + * companion object : Entity.Factory() + * val name: String + * fun printName() { + * println(name) + * } + * } + * ``` + * + * Then if we call `Foo().printName()`, the value of the property `name` will be printed. + * + * Besides of non-abstract functions, Kotlin also allows us to define properties with custom getters or setters in + * interfaces. For example, in the following code, if we call `Foo().upperName`, then the value of the `name` property + * will be returned in upper case: + * + * ```kotlin + * interface Foo : Entity { + * companion object : Entity.Factory() + * val name: String + * val upperName get() = name.toUpperCase() + * } + * ``` + * + * More details can be found in our website: https://www.ktorm.org/en/entities-and-column-binding.html#More-About-Entities + * + * ### Serialization + * + * The [Entity] interface extends from [Serializable], so all entity objects are serializable by default. We can save + * them to our disks, or transfer them between systems through networks. + * + * Note that Ktorm only saves entities’ property values when serialization, any other data that used to track entity + * status are lost (marked as transient). So we can not obtain an entity object from one system, then flush its changes + * into the database in another system. + * + * Java uses [ObjectOutputStream] to serialize objects, and uses [ObjectInputStream] to deserialize them, you can + * refer to their documentation for more details. + * + * Besides of JDK serialization, the ktorm-jackson module also supports serializing entities in JSON format. This + * module provides an extension for Jackson, the famous JSON framework in Java word. It supports serializing entity + * objects into JSON format and parsing JSONs as entity objects. More details can be found in its documentation. + */ +public interface Entity> : Serializable { + + /** + * Return this entity's [KClass] instance, which must be an interface. + */ + public val entityClass: KClass + + /** + * Return the immutable view of this entity's all properties. + */ + public val properties: Map + + /** + * Update the property changes of this entity into the database and return the affected record number. + * + * Using this function, we need to note that: + * + * 1. This function requires a primary key specified in the table object via [Table.primaryKey], + * otherwise Ktorm doesn’t know how to identify entity objects and will throw an exception. + * + * 2. The entity object calling this function must be ATTACHED to the database first. 
In Ktorm’s implementation, + * every entity object holds a reference `fromDatabase`. For entity objects obtained by sequence APIs, their + * `fromDatabase` references point to the database they are obtained from. For entity objects created by + * [Entity.create] or [Entity.Factory], their `fromDatabase` references are `null` initially, so we can not call + * [flushChanges] on them. But once we use them with [add] or [update] function, `fromDatabase` will be modified + * to the current database, so we will be able to call [flushChanges] on them afterwards. + * + * @see add + * @see update + */ + @Throws(SQLException::class) + public fun flushChanges(): Int + + /** + * Clear the tracked property changes of this entity. + * + * After calling this function, the [flushChanges] doesn't do anything anymore because the property changes + * are discarded. + */ + public fun discardChanges() + + /** + * Delete this entity in the database and return the affected record number. + * + * Similar to [flushChanges], we need to note that: + * + * 1. The function requires a primary key specified in the table object via [Table.primaryKey], + * otherwise, Ktorm doesn’t know how to identify entity objects. + * + * 2. The entity object calling this function must be ATTACHED to the database first. + * + * @see add + * @see update + * @see flushChanges + */ + @Throws(SQLException::class) + public fun delete(): Int + + /** + * Obtain a property's value by its name. + * + * Note that this function doesn't follows the rules of default values discussed in the class level documentation. + * If the value doesn't exist, we will return `null` simply. + */ + public operator fun get(name: String): Any? + + /** + * Modify a property's value by its name. + */ + public operator fun set(name: String, value: Any?) + + /** + * Return a deep copy of this entity, which has the same property values and tracked statuses. + */ + public fun copy(): E + + /** + * Indicate whether some other object is "equal to" this entity. + * Two entities are equal only if they have the same [entityClass] and [properties]. + * + * @since 3.4.0 + */ + public override fun equals(other: Any?): Boolean + + /** + * Return a hash code value for this entity. + * + * @since 3.4.0 + */ + public override fun hashCode(): Int + + /** + * Return a string representation of this table. + * The format is like `Employee{id=1, name=Eric, job=contributor, hireDate=2021-05-05, salary=50}`. + */ + public override fun toString(): String + + /** + * Companion object provides functions to create entity instances. + */ + public companion object { + + /** + * Create an entity object. This functions is used by Ktorm internal. + */ + internal fun create( + entityClass: KClass<*>, + parent: EntityImplementation? = null, + fromDatabase: Database? = parent?.fromDatabase, + fromTable: Table<*>? = parent?.fromTable + ): Entity<*> { + if (!entityClass.isSubclassOf(Entity::class)) { + throw IllegalArgumentException("An entity class must be subclass of Entity.") + } + if (!entityClass.java.isInterface) { + throw IllegalArgumentException("An entity class must be defined as an interface.") + } + + val handler = EntityImplementation(entityClass, fromDatabase, fromTable, parent) + return Proxy.newProxyInstance(entityClass.java.classLoader, arrayOf(entityClass.java), handler) as Entity<*> + } + + /** + * Create an entity object by JDK dynamic proxy. 
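+ *
+ * For example, with the `Department` interface from the class-level docs:
+ *
+ * ```kotlin
+ * // the proxy instance implements Department, so the cast is safe
+ * val department = Entity.create(Department::class) as Department
+ * department.name = "tech"
+ * ```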
+ */ + public fun create(entityClass: KClass<*>): Entity<*> { + return create(entityClass, null, null, null) + } + + /** + * Create an entity object by JDK dynamic proxy. + */ + public inline fun > create(): E { + return create(E::class) as E + } + } + + /** + * Abstract factory used to create entity objects, typically declared as companion objects of entity classes. + */ + public abstract class Factory> : TypeReference() { + + /** + * Overload the `invoke` operator, creating an entity object just like there is a constructor. + */ + @Suppress("UNCHECKED_CAST") + public operator fun invoke(): E { + return create(referencedKotlinType.jvmErasure) as E + } + + /** + * Overload the `invoke` operator, creating an entity object and call the [init] function. + */ + public inline operator fun invoke(init: E.() -> Unit): E { + return invoke().apply(init) + } + } +} diff --git a/ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/entity/EntityExtensions.kt b/ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/entity/EntityExtensions.kt new file mode 100644 index 0000000..cb3ceea --- /dev/null +++ b/ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/entity/EntityExtensions.kt @@ -0,0 +1,197 @@ +/* + * Copyright 2018-2021 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.ktorm.r2dbc.entity + +import org.ktorm.r2dbc.schema.ColumnBinding +import org.ktorm.r2dbc.schema.NestedBinding +import org.ktorm.r2dbc.schema.ReferenceBinding +import org.ktorm.r2dbc.schema.Table +import java.lang.reflect.Proxy +import java.util.* +import kotlin.reflect.jvm.jvmErasure + +internal fun EntityImplementation.hasPrimaryKeyValue(fromTable: Table<*>): Boolean { + val pk = fromTable.singlePrimaryKey { "Table '$fromTable' has compound primary keys." } + if (pk.binding == null) { + error("Primary column $pk has no bindings to any entity field.") + } else { + return hasColumnValue(pk.binding) + } +} + +internal fun EntityImplementation.hasColumnValue(binding: ColumnBinding): Boolean { + when (binding) { + is ReferenceBinding -> { + if (!this.hasProperty(binding.onProperty)) { + return false + } + + val child = this.getProperty(binding.onProperty) as Entity<*>? + if (child == null) { + // null is also a legal column value. + return true + } + + return child.implementation.hasPrimaryKeyValue(binding.referenceTable as Table<*>) + } + is NestedBinding -> { + var curr: EntityImplementation = this + + for ((i, prop) in binding.properties.withIndex()) { + if (i != binding.properties.lastIndex) { + if (!curr.hasProperty(prop)) { + return false + } + + val child = curr.getProperty(prop) as Entity<*>? + if (child == null) { + // null is also a legal column value. + return true + } + + curr = child.implementation + } + } + + return curr.hasProperty(binding.properties.last()) + } + } +} + +internal fun EntityImplementation.getPrimaryKeyValue(fromTable: Table<*>): Any? { + val pk = fromTable.singlePrimaryKey { "Table '$fromTable' has compound primary keys." 
} + if (pk.binding == null) { + error("Primary column $pk has no bindings to any entity field.") + } else { + return getColumnValue(pk.binding) + } +} + +internal fun EntityImplementation.getColumnValue(binding: ColumnBinding): Any? { + when (binding) { + is ReferenceBinding -> { + val child = this.getProperty(binding.onProperty) as Entity<*>? + return child?.implementation?.getPrimaryKeyValue(binding.referenceTable as Table<*>) + } + is NestedBinding -> { + var curr: EntityImplementation? = this + for ((i, prop) in binding.properties.withIndex()) { + if (i != binding.properties.lastIndex) { + val child = curr?.getProperty(prop) as Entity<*>? + curr = child?.implementation + } + } + return curr?.getProperty(binding.properties.last()) + } + } +} + +internal fun EntityImplementation.setPrimaryKeyValue( + fromTable: Table<*>, + value: Any?, + forceSet: Boolean = false, + useExtraBindings: Boolean = false +) { + val pk = fromTable.singlePrimaryKey { "Table '$fromTable' has compound primary keys." } + if (pk.binding == null) { + error("Primary column $pk has no bindings to any entity field.") + } else { + setColumnValue(pk.binding, value, forceSet) + } + + if (useExtraBindings) { + for (extraBinding in pk.extraBindings) { + setColumnValue(extraBinding, value, forceSet) + } + } +} + +internal fun EntityImplementation.setColumnValue(binding: ColumnBinding, value: Any?, forceSet: Boolean = false) { + when (binding) { + is ReferenceBinding -> { + var child = this.getProperty(binding.onProperty) as Entity<*>? + if (child == null) { + child = Entity.create( + entityClass = binding.onProperty.returnType.jvmErasure, + fromDatabase = this.fromDatabase, + fromTable = binding.referenceTable as Table<*> + ) + this.setProperty(binding.onProperty, child, forceSet) + } + + val refTable = binding.referenceTable as Table<*> + child.implementation.setPrimaryKeyValue(refTable, value, forceSet, useExtraBindings = true) + } + is NestedBinding -> { + var curr: EntityImplementation = this + for ((i, prop) in binding.properties.withIndex()) { + if (i != binding.properties.lastIndex) { + var child = curr.getProperty(prop) as Entity<*>? 
+ if (child == null) { + child = Entity.create(prop.returnType.jvmErasure, parent = curr) + curr.setProperty(prop, child, forceSet) + } + + curr = child.implementation + } + } + + curr.setProperty(binding.properties.last(), value, forceSet) + } + } +} + +internal fun EntityImplementation.isPrimaryKey(name: String): Boolean { + for (pk in this.fromTable?.primaryKeys.orEmpty()) { + when (pk.binding) { + is ReferenceBinding -> { + if (parent == null && pk.binding.onProperty.name == name) { + return true + } + } + is NestedBinding -> { + val namesPath = LinkedList>() + namesPath.addFirst(setOf(name)) + + var curr: EntityImplementation = this + while (true) { + val parent = curr.parent ?: break + val children = parent.values.filterValues { it == curr } + + if (children.isEmpty()) { + break + } else { + namesPath.addFirst(children.keys) + curr = parent + } + } + + if (namesPath.withIndex().all { (i, names) -> pk.binding.properties[i].name in names }) { + return true + } + } + null -> return false + } + } + + return false +} + +internal val Entity<*>.implementation: EntityImplementation + get() { + return Proxy.getInvocationHandler(this) as EntityImplementation +} diff --git a/ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/entity/EntityImplementation.kt b/ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/entity/EntityImplementation.kt new file mode 100644 index 0000000..d66d918 --- /dev/null +++ b/ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/entity/EntityImplementation.kt @@ -0,0 +1,289 @@ +/* + * Copyright 2018-2021 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.ktorm.r2dbc.entity + +import org.ktorm.r2dbc.database.Database +import org.ktorm.r2dbc.schema.Table +import org.ktorm.r2dbc.schema.defaultValue +import org.ktorm.r2dbc.schema.kotlinProperty +import java.io.* +import java.lang.reflect.InvocationHandler +import java.lang.reflect.InvocationTargetException +import java.lang.reflect.Method +import java.util.* +import kotlin.collections.LinkedHashMap +import kotlin.collections.LinkedHashSet +import kotlin.reflect.KClass +import kotlin.reflect.KProperty1 +import kotlin.reflect.jvm.javaGetter +import kotlin.reflect.jvm.jvmErasure +import kotlin.reflect.jvm.jvmName +import kotlin.reflect.jvm.kotlinFunction + +internal class EntityImplementation( + var entityClass: KClass<*>, + @Transient var fromDatabase: Database?, + @Transient var fromTable: Table<*>?, + @Transient var parent: EntityImplementation? +) : InvocationHandler, Serializable { + + var values = LinkedHashMap() + @Transient var changedProperties = LinkedHashSet() + + companion object { + private const val serialVersionUID = 1L + private val defaultImplsCache: MutableMap = Collections.synchronizedMap(WeakHashMap()) + } + + override fun invoke(proxy: Any, method: Method, args: Array?): Any? 
{ + return when (method.declaringClass.kotlin) { + Any::class -> { + when (method.name) { + "equals" -> this == args!![0] + "hashCode" -> this.hashCode() + "toString" -> this.toString() + else -> throw IllegalStateException("Unrecognized method: $method") + } + } + Entity::class -> { + when (method.name) { + "getEntityClass" -> this.entityClass + "getProperties" -> Collections.unmodifiableMap(this.values) + /* "flushChanges" -> this.doFlushChanges() + "discardChanges" -> this.doDiscardChanges() + "delete" -> this.doDelete()*/ + "get" -> this.values[args!![0] as String] + "set" -> this.doSetProperty(args!![0] as String, args[1]) + "copy" -> this.copy() + else -> throw IllegalStateException("Unrecognized method: $method") + } + } + else -> { + handleMethodCall(proxy, method, args) + } + } + } + + private fun handleMethodCall(proxy: Any, method: Method, args: Array?): Any? { + val ktProp = method.kotlinProperty + if (ktProp != null) { + val (prop, isGetter) = ktProp + if (prop.isAbstract) { + if (isGetter) { + val result = this.getProperty(prop, unboxInlineValues = true) + if (result != null || prop.returnType.isMarkedNullable) { + return result + } else { + return prop.defaultValue.also { cacheDefaultValue(prop, it) } + } + } else { + this.setProperty(prop, args!![0]) + return null + } + } else { + return callDefaultImpl(proxy, method, args) + } + } else { + val func = method.kotlinFunction + if (func != null && !func.isAbstract) { + return callDefaultImpl(proxy, method, args) + } else { + throw IllegalStateException("Unrecognized method: $method") + } + } + } + + private val KProperty1<*, *>.defaultValue: Any get() { + try { + return javaGetter!!.returnType.defaultValue + } catch (e: Throwable) { + val msg = "" + + "The value of non-null property [$this] doesn't exist, " + + "an error occurred while trying to create a default one. " + + "Please ensure its value exists, or you can mark the return type nullable [${this.returnType}?]" + throw IllegalStateException(msg, e) + } + } + + private fun cacheDefaultValue(prop: KProperty1<*, *>, value: Any) { + val type = prop.javaGetter!!.returnType + + // Skip for primitive types, enums and string, because their default values always share the same instance. + if (type == Boolean::class.javaPrimitiveType) return + if (type == Char::class.javaPrimitiveType) return + if (type == Byte::class.javaPrimitiveType) return + if (type == Short::class.javaPrimitiveType) return + if (type == Int::class.javaPrimitiveType) return + if (type == Long::class.javaPrimitiveType) return + if (type == String::class.java) return + if (type.isEnum) return + + setProperty(prop, value) + } + + @Suppress("SwallowedException") + private fun callDefaultImpl(proxy: Any, method: Method, args: Array?): Any? { + val impl = defaultImplsCache.computeIfAbsent(method) { + val cls = Class.forName(method.declaringClass.name + "\$DefaultImpls") + cls.getMethod(method.name, method.declaringClass, *method.parameterTypes) + } + + try { + if (args == null) { + return impl.invoke(null, proxy) + } else { + return impl.invoke(null, proxy, *args) + } + } catch (e: InvocationTargetException) { + throw e.targetException + } + } + + fun hasProperty(prop: KProperty1<*, *>): Boolean { + return prop.name in values + } + + @OptIn(ExperimentalUnsignedTypes::class) + fun getProperty(prop: KProperty1<*, *>, unboxInlineValues: Boolean = false): Any? 
{ + if (!unboxInlineValues) { + return values[prop.name] + } + + val returnType = prop.javaGetter!!.returnType + val value = values[prop.name] + + // Unbox inline class values if necessary. + // In principle, we need to check for all inline classes, but kotlin-reflect is still unable to determine + // whether a class is inline, so as a workaround, we have to enumerate some common-used types here. + return when { + value is UByte && returnType == Byte::class.javaPrimitiveType -> value.toByte() + value is UShort && returnType == Short::class.javaPrimitiveType -> value.toShort() + value is UInt && returnType == Int::class.javaPrimitiveType -> value.toInt() + value is ULong && returnType == Long::class.javaPrimitiveType -> value.toLong() + value is UByteArray && returnType == ByteArray::class.java -> value.toByteArray() + value is UShortArray && returnType == ShortArray::class.java -> value.toShortArray() + value is UIntArray && returnType == IntArray::class.java -> value.toIntArray() + value is ULongArray && returnType == LongArray::class.java -> value.toLongArray() + else -> value + } + } + + @OptIn(ExperimentalUnsignedTypes::class) + fun setProperty(prop: KProperty1<*, *>, value: Any?, forceSet: Boolean = false) { + val propType = prop.returnType.jvmErasure + + // For inline classes, always box the underlying values as wrapper types. + // In principle, we need to check for all inline classes, but kotlin-reflect is still unable to determine + // whether a class is inline, so as a workaround, we have to enumerate some common-used types here. + val boxedValue = when { + propType == UByte::class && value is Byte -> value.toUByte() + propType == UShort::class && value is Short -> value.toUShort() + propType == UInt::class && value is Int -> value.toUInt() + propType == ULong::class && value is Long -> value.toULong() + propType == UByteArray::class && value is ByteArray -> value.toUByteArray() + propType == UShortArray::class && value is ShortArray -> value.toUShortArray() + propType == UIntArray::class && value is IntArray -> value.toUIntArray() + propType == ULongArray::class && value is LongArray -> value.toULongArray() + else -> value + } + + doSetProperty(prop.name, boxedValue, forceSet) + } + + private fun doSetProperty(name: String, value: Any?, forceSet: Boolean = false) { + if (!forceSet && isPrimaryKey(name) && name in values) { + val msg = "Cannot modify the primary key `$name` because it's already set to ${values[name]}" + throw UnsupportedOperationException(msg) + } + + values[name] = value + changedProperties.add(name) + } + + private fun copy(): Entity<*> { + val entity = Entity.create(entityClass, parent, fromDatabase, fromTable) + entity.implementation.changedProperties.addAll(changedProperties) + + for ((name, value) in values) { + if (value is Entity<*>) { + val valueCopy = value.copy() + + // Keep the parent relationship. 
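+                // Re-point the copied child's parent reference at the copied entity,
+                // so the copy's nested entities refer to the copy rather than the original.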
+ if (valueCopy.implementation.parent == this) { + valueCopy.implementation.parent = entity.implementation + } + + entity.implementation.values[name] = valueCopy + } else { + entity.implementation.values[name] = value?.let { deserialize(serialize(it)) } + } + } + + return entity + } + + private fun serialize(obj: Any): ByteArray { + ByteArrayOutputStream().use { buffer -> + ObjectOutputStream(buffer).use { output -> + output.writeObject(obj) + output.flush() + return buffer.toByteArray() + } + } + } + + private fun deserialize(bytes: ByteArray): Any { + ByteArrayInputStream(bytes).use { buffer -> + ObjectInputStream(buffer).use { input -> + return input.readObject() + } + } + } + + private fun writeObject(output: ObjectOutputStream) { + output.writeUTF(entityClass.jvmName) + output.writeObject(values) + } + + @Suppress("UNCHECKED_CAST") + private fun readObject(input: ObjectInputStream) { + entityClass = Class.forName(input.readUTF()).kotlin + values = input.readObject() as LinkedHashMap + changedProperties = LinkedHashSet() + } + + override fun equals(other: Any?): Boolean { + if (this === other) return true + + return when (other) { + is EntityImplementation -> entityClass == other.entityClass && values == other.values + is Entity<*> -> entityClass == other.implementation.entityClass && values == other.implementation.values + else -> false + } + } + + override fun hashCode(): Int { + var result = 1 + result = 31 * result + entityClass.hashCode() + result = 31 * result + values.hashCode() + return result + } + + override fun toString(): String { + return entityClass.simpleName + values + } +} diff --git a/ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/entity/EntitySequence.kt b/ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/entity/EntitySequence.kt new file mode 100644 index 0000000..317dd03 --- /dev/null +++ b/ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/entity/EntitySequence.kt @@ -0,0 +1,1503 @@ +/* + * Copyright 2018-2021 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.ktorm.r2dbc.entity + +import org.ktorm.r2dbc.database.Database +import org.ktorm.r2dbc.database.DialectFeatureNotSupportedException +import org.ktorm.r2dbc.dsl.* +import org.ktorm.r2dbc.expression.* +import org.ktorm.r2dbc.schema.BaseTable +import org.ktorm.r2dbc.schema.Column +import org.ktorm.r2dbc.schema.ColumnDeclaring +import java.util.* +import kotlin.experimental.ExperimentalTypeInference +import kotlin.math.min + +/** + * Represents a sequence of entity objects. As the name implies, the style and use pattern of Ktorm's entity sequence + * APIs are highly similar to [kotlin.sequences.Sequence] and the extension functions in Kotlin standard lib, as it + * provides many extension functions with the same names, such as [filter], [map], [reduce], etc. 
+ * + * To create an [EntitySequence], we can use the extension function [sequenceOf]: + * + * ```kotlin + * val sequence = database.sequenceOf(Employees) + * ``` + * + * Now we got a default sequence, which can obtain all employees from the table. Please know that Ktorm doesn't execute + * the query right now. The sequence provides an iterator of type `Iterator`, only when we iterate the + * sequence using the iterator, the query is executed. The following code prints all employees using a for-each loop: + * + * ```kotlin + * for (employee in sequence) { + * println(employee) + * } + * ``` + * + * This class wraps a [Query] object, and it’s iterator exactly wraps the query’s iterator. While an entity sequence is + * iterated, its internal query is executed, and the [entityExtractor] function is applied to create an entity object + * for each row. As for other properties in sequences (such as [sql], [rowSet], [totalRecords], etc), all of them + * delegates the callings to their internal query objects, and their usages are totally the same as the corresponding + * properties in [Query] class. + * + * Most of the entity sequence APIs are provided as extension functions, which can be divided into two groups: + * + * - **Intermediate operations:** these functions don’t execute the internal queries but return new-created sequence + * objects applying some modifications. For example, the [filter] function creates a new sequence object with the filter + * condition given by its parameter. The return types of intermediate operations are usually [EntitySequence], so we + * can chaining call other sequence functions continuously. + * + * - **Terminal operations:** the return types of these functions are usually a collection or a computed result, as + * they execute the queries right now, obtain their results and perform some calculations on them. Eg. [toList], + * [reduce], etc. + * + * For the list of sequence operations available, see the extension functions below. + */ +public class EntitySequence>( + + /** + * The [Database] instance that the internal query is running on. + */ + public val database: Database, + + /** + * The source table from which elements are obtained. + */ + public val sourceTable: T, + + /** + * The SQL expression to be executed by this sequence when obtaining elements. + */ + public val expression: SelectExpression, + + /** + * The function used to extract entity objects for each result row. + */ + public val entityExtractor: (row: QueryRow) -> E +) { + /** + * The internal query of this sequence to be executed, created by [database] and [expression]. + */ + public val query: Query = Query(database, expression) + + /** + * The executable SQL string of the internal query. + * + * This property is delegated to [Query.sql], more details can be found in its documentation. + */ + public val sql: String get() = query.sql + + /** + * The [ResultSet] object of the internal query, lazy initialized after first access, obtained from the database by + * executing the generated SQL. + * + * This property is delegated to [Query.rowSet], more details can be found in its documentation. + */ + public suspend fun getRowSet(): List = query.doQuery() + + /** + * The total records count of this query ignoring the pagination params. + * + * This property is delegated to [Query.totalRecords], more details can be found in its documentation. + */ + public suspend fun totalRecords(): Int = query.totalRecords() + + /** + * Return a copy of this [EntitySequence] with the [expression] modified. 
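+ *
+ * A minimal sketch (assuming we simply want to cap the underlying query at 10 rows):
+ *
+ * ```kotlin
+ * val limited = sequence.withExpression(sequence.expression.copy(limit = 10))
+ * ```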
+ */ + public fun withExpression(expression: SelectExpression): EntitySequence { + return EntitySequence(database, sourceTable, expression, entityExtractor) + } + + /** + * Create a [kotlin.sequences.Sequence] instance that wraps this original entity sequence returning all the + * elements when being iterated. + */ + public suspend fun asKotlinSequence(): Sequence { + val iterator = iterator() + return Sequence { iterator } + } + + /** + * Return an iterator over the elements of this sequence. + */ + @Suppress("IteratorNotThrowingNoSuchElementException") + public suspend operator fun iterator(): Iterator { + val iterator = query.iterator() + return object : Iterator { + override fun hasNext(): Boolean { + return iterator.hasNext() + } + + override fun next(): E { + return entityExtractor(iterator.next()) + } + } + } +} + +/** + * Create an [EntitySequence] from the specific table. + * + * @since 2.7 + */ +public fun > Database.sequenceOf( + table: T, + withReferences: Boolean = true +): EntitySequence { + val query = if (withReferences) from(table).joinReferencesAndSelect() else from(table).select(table.columns) + val entityExtractor = { row: QueryRow -> table.createEntity(row, withReferences) } + return EntitySequence(this, table, query.expression as SelectExpression, entityExtractor) +} + +/** + * Append all elements to the given [destination] collection. + * + * The operation is terminal. + */ +public suspend fun > EntitySequence.toCollection(destination: C): C { + for (element in this) destination += element + return destination +} + +/** + * Return a [List] containing all the elements of this sequence. + * + * The operation is terminal. + */ +public suspend fun EntitySequence.toList(): List { + return toCollection(ArrayList()) +} + +/** + * Return a [MutableList] containing all the elements of this sequence. + * + * The operation is terminal. + */ +public suspend fun EntitySequence.toMutableList(): MutableList { + return toCollection(ArrayList()) +} + +/** + * Return a [Set] containing all the elements of this sequence. + * + * The returned set preserves the element iteration order of the original sequence. + * + * The operation is terminal. + */ +public suspend fun EntitySequence.toSet(): Set { + return toCollection(LinkedHashSet()) +} + +/** + * Return a [MutableSet] containing all the elements of this sequence. + * + * The returned set preserves the element iteration order of the original sequence. + * + * The operation is terminal. + */ +public suspend fun EntitySequence.toMutableSet(): MutableSet { + return toCollection(LinkedHashSet()) +} + +/** + * Return a [HashSet] containing all the elements of this sequence. + * + * The operation is terminal. + */ +public suspend fun EntitySequence.toHashSet(): HashSet { + return toCollection(HashSet()) +} + +/** + * Return a [SortedSet] containing all the elements of this sequence. + * + * The operation is terminal. + */ +public suspend fun EntitySequence.toSortedSet(): SortedSet where E : Any, E : Comparable { + return toCollection(TreeSet()) +} + +/** + * Return a [SortedSet] containing all the elements of this sequence. + * + * Elements in the set returned are sorted according to the given [comparator]. + * + * The operation is terminal. + */ +public suspend fun EntitySequence.toSortedSet( + comparator: Comparator +): SortedSet where E : Any, E : Comparable { + return toCollection(TreeSet(comparator)) +} + +/** + * Return a sequence customizing the selected columns of the internal query. + * + * The operation is intermediate. 
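+ *
+ * For example (a sketch, assuming an `Employees` table that declares `id` and `name` columns):
+ *
+ * ```kotlin
+ * val sequence = database.sequenceOf(Employees).filterColumns { listOf(it.id, it.name) }
+ * ```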
+ */ +public inline fun > EntitySequence.filterColumns( + selector: (T) -> List> +): EntitySequence { + val columns = selector(sourceTable) + if (columns.isEmpty()) { + return this + } else { + return this.withExpression(expression.copy(columns = columns.map { it.aliased(it.label) })) + } +} + +/** + * Return a sequence containing only elements matching the given [predicate]. + * + * The operation is intermediate. + */ +public inline fun > EntitySequence.filter( + predicate: (T) -> ColumnDeclaring +): EntitySequence { + if (expression.where == null) { + return this.withExpression(expression.copy(where = predicate(sourceTable).asExpression())) + } else { + return this.withExpression(expression.copy(where = expression.where and predicate(sourceTable))) + } +} + +/** + * Return a sequence containing only elements not matching the given [predicate]. + * + * The operation is intermediate. + */ +public inline fun > EntitySequence.filterNot( + predicate: (T) -> ColumnDeclaring +): EntitySequence { + return filter { !predicate(it) } +} + +/** + * Append all elements matching the given [predicate] to the given [destination]. + * + * The operation is terminal. + */ +public suspend inline fun , C : MutableCollection> EntitySequence.filterTo( + destination: C, + predicate: (T) -> ColumnDeclaring +): C { + return filter(predicate).toCollection(destination) +} + +/** + * Append all elements not matching the given [predicate] to the given [destination]. + * + * The operation is terminal. + */ +public suspend inline fun , C : MutableCollection> EntitySequence.filterNotTo( + destination: C, + predicate: (T) -> ColumnDeclaring +): C { + return filterNot(predicate).toCollection(destination) +} + +/** + * Return a [List] containing the results of applying the given [transform] function + * to each element in the original sequence. + * + * The operation is terminal. + */ +public suspend inline fun EntitySequence.map(transform: (E) -> R): List { + return mapTo(ArrayList(), transform) +} + +/** + * Apply the given [transform] function to each element of the original sequence + * and append the results to the given [destination]. + * + * The operation is terminal. + */ +public suspend inline fun > EntitySequence.mapTo( + destination: C, + transform: (E) -> R +): C { + for (element in this) destination += transform(element) + return destination +} + +/** + * Return a [List] containing only the non-null results of applying the given [transform] function + * to each element in the original sequence. + * + * The operation is terminal. + * + * @since 3.0.0 + */ +public suspend inline fun EntitySequence.mapNotNull(transform: (E) -> R?): List { + return mapNotNullTo(ArrayList(), transform) +} + +/** + * Apply the given [transform] function to each element in the original sequence + * and append only the non-null results to the given [destination]. + * + * The operation is terminal. + * + * @since 3.0.0 + */ +public suspend inline fun > EntitySequence.mapNotNullTo( + destination: C, + transform: (E) -> R? +): C { + forEach { element -> transform(element)?.let { destination += it } } + return destination +} + +/** + * Return a [List] containing the results of applying the given [transform] function + * to each element and its index in the original sequence. + * + * The [transform] function takes the index of an element and the element itself and + * returns the result of the transform applied to the element. + * + * The operation is terminal. 
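+ *
+ * For example (a sketch, assuming the entities expose a `name` property):
+ *
+ * ```kotlin
+ * val lines = database.sequenceOf(Employees).mapIndexed { i, e -> "$i: ${e.name}" }
+ * ```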
+ */ +public suspend inline fun EntitySequence.mapIndexed(transform: (index: Int, E) -> R): List { + return mapIndexedTo(ArrayList(), transform) +} + +/** + * Apply the given [transform] function to each element and its index in the original sequence + * and append the results to the given [destination]. + * + * The [transform] function takes the index of an element and the element itself and + * returns the result of the transform applied to the element. + * + * The operation is terminal. + */ +public suspend inline fun > EntitySequence.mapIndexedTo( + destination: C, + transform: (index: Int, E) -> R +): C { + var index = 0 + return mapTo(destination) { transform(index++, it) } +} + +/** + * Return a [List] containing only the non-null results of applying the given [transform] function + * to each element and its index in the original sequence. + * + * The [transform] function takes the index of an element and the element itself and + * returns the result of the transform applied to the element. + * + * The operation is terminal. + * + * @since 3.0.0 + */ +public suspend inline fun EntitySequence.mapIndexedNotNull(transform: (index: Int, E) -> R?): List { + return mapIndexedNotNullTo(ArrayList(), transform) +} + +/** + * Apply the given [transform] function to each element and its index in the original sequence + * and append only the non-null results to the given [destination]. + * + * The [transform] function takes the index of an element and the element itself and + * returns the result of the transform applied to the element. + * + * The operation is terminal. + * + * @since 3.0.0 + */ +public suspend inline fun > EntitySequence.mapIndexedNotNullTo( + destination: C, + transform: (index: Int, E) -> R? +): C { + forEachIndexed { index, element -> transform(index, element)?.let { destination += it } } + return destination +} + +/** + * Return a single list of all elements yielded from results of [transform] function being invoked + * on each element of original sequence. + * + * The operation is terminal. + * + * @since 3.0.0 + */ +public suspend inline fun EntitySequence.flatMap(transform: (E) -> Iterable): List { + return flatMapTo(ArrayList(), transform) +} + +/** + * Append all elements yielded from results of [transform] function being invoked on each element + * of original sequence, to the given [destination]. + * + * The operation is terminal. + * + * @since 3.0.0 + */ +public suspend inline fun > EntitySequence.flatMapTo( + destination: C, + transform: (E) -> Iterable +): C { + for (element in this) destination += transform(element) + return destination +} + +/** + * Return a single list of all elements yielded from results of [transform] function being invoked + * on each element and its index in the original sequence. + * + * The operation is terminal. + * + * @since 3.1.0 + */ +public suspend inline fun EntitySequence.flatMapIndexed(transform: (index: Int, E) -> Iterable): List { + return flatMapIndexedTo(ArrayList(), transform) +} + +/** + * Append all elements yielded from results of [transform] function being invoked on each element + * and its index in the original sequence, to the given [destination]. + * + * The operation is terminal. 
+ * + * @since 3.1.0 + */ +public suspend inline fun > EntitySequence.flatMapIndexedTo( + destination: C, + transform: (index: Int, E) -> Iterable +): C { + var index = 0 + return flatMapTo(destination) { transform(index++, it) } +} + +/** + * Customize the selected columns of the internal query by the given [columnSelector] function, and return a [List] + * containing the query results. + * + * This function is similar to [EntitySequence.map], but the [columnSelector] closure accepts the current table + * object [T] as the parameter, so what we get in the closure by `it` is the table object instead of an entity + * element. Besides, the function’s return type is `ColumnDeclaring`, and we should return a column or expression + * to customize the `select` clause of the generated SQL. + * + * Ktorm also supports selecting two or more columns, we just need to wrap our selected columns by [tupleOf] + * in the closure, then the function’s return type becomes `List>`. + * + * The operation is terminal. + * + * @param isDistinct specify if the query is distinct, the generated SQL becomes `select distinct` if it's set to true. + * @param columnSelector a function in which we should return a column or expression to be selected. + * @return a list of the query results. + */ +@OptIn(ExperimentalTypeInference::class) +@OverloadResolutionByLambdaReturnType +public suspend inline fun , reified C : Any> EntitySequence.mapColumns( + isDistinct: Boolean = false, + columnSelector: (T) -> ColumnDeclaring +): List { + return mapColumnsTo(ArrayList(), isDistinct, columnSelector) +} + +/** + * Customize the selected columns of the internal query by the given [columnSelector] function, and append the query + * results to the given [destination]. + * + * This function is similar to [EntitySequence.mapTo], but the [columnSelector] closure accepts the current table + * object [T] as the parameter, so what we get in the closure by `it` is the table object instead of an entity + * element. Besides, the function’s return type is `ColumnDeclaring`, and we should return a column or expression + * to customize the `select` clause of the generated SQL. + * + * Ktorm also supports selecting two or more columns, we just need to wrap our selected columns by [tupleOf] + * in the closure, then the function’s return type becomes `List>`. + * + * The operation is terminal. + * + * @param destination a [MutableCollection] used to store the results. + * @param isDistinct specify if the query is distinct, the generated SQL becomes `select distinct` if it's set to true. + * @param columnSelector a function in which we should return a column or expression to be selected. + * @return the [destination] collection of the query results. + */ +@OptIn(ExperimentalTypeInference::class) +@OverloadResolutionByLambdaReturnType +public suspend inline fun , reified C, R> EntitySequence.mapColumnsTo( + destination: R, + isDistinct: Boolean = false, + columnSelector: (T) -> ColumnDeclaring +): R where C : Any, R : MutableCollection { + val column = columnSelector(sourceTable) + + val expr = expression.copy( + columns = listOf(column.aliased(null)), + isDistinct = isDistinct + ) + + return Query(database, expr).mapTo(destination) { row -> row[0, C::class.java] } +} + + +/** + * Customize the selected columns of the internal query by the given [columnSelector] function, and return a [List] + * containing the non-null results. 
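+ *
+ * For example (a sketch, assuming an `Employees` table with a nullable `managerId` column):
+ *
+ * ```kotlin
+ * val managerIds = database.sequenceOf(Employees).mapColumnsNotNull { it.managerId }
+ * ```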
+ * + * This function is similar to [EntitySequence.mapColumns], but null results are filtered, more details can be found + * in its documentation. + * + * The operation is terminal. + * + * @param isDistinct specify if the query is distinct, the generated SQL becomes `select distinct` if it's set to true. + * @param columnSelector a function in which we should return a column or expression to be selected. + */ +public suspend inline fun , reified C : Any> EntitySequence.mapColumnsNotNull( + isDistinct: Boolean = false, + columnSelector: (T) -> ColumnDeclaring +): List { + return mapColumnsNotNullTo(ArrayList(), isDistinct, columnSelector) +} + +/** + * Customize the selected columns of the internal query by the given [columnSelector] function, and append non-null + * results to the given [destination]. + * + * This function is similar to [EntitySequence.mapColumnsTo], but null results are filtered, more details can be found + * in its documentation. + * + * The operation is terminal. + * + * @param destination a [MutableCollection] used to store the results. + * @param isDistinct specify if the query is distinct, the generated SQL becomes `select distinct` if it's set to true. + * @param columnSelector a function in which we should return a column or expression to be selected. + */ +public suspend inline fun , reified C, R> EntitySequence.mapColumnsNotNullTo( + destination: R, + isDistinct: Boolean = false, + columnSelector: (T) -> ColumnDeclaring +): R where C : Any, R : MutableCollection { + val column = columnSelector(sourceTable) + + val expr = expression.copy( + columns = listOf(column.aliased(null)), + isDistinct = isDistinct + ) + + return Query(database, expr).mapNotNullTo(destination) { row -> row.get(0, C::class.java) } +} + +/** + * Return a sequence customizing the `order by` clause of the internal query. + * + * The operation is intermediate. + */ +@Deprecated( + message = "This function is deprecated, use sortedBy({ it.col1.asc() }, { it.col2.desc() }) instead.", + replaceWith = ReplaceWith("sortedBy") +) +public inline fun > EntitySequence.sorted( + selector: (T) -> List +): EntitySequence { + return this.withExpression(expression.copy(orderBy = selector(sourceTable))) +} + +/** + * Return a sequence sorting elements by multiple columns, in ascending or descending order. For example, + * `sortedBy({ it.col1.asc() }, { it.col2.desc() })`. + * + * The operation is intermediate. + */ +@OptIn(ExperimentalTypeInference::class) +@OverloadResolutionByLambdaReturnType +public fun > EntitySequence.sortedBy( + vararg selectors: (T) -> OrderByExpression +): EntitySequence { + return this.withExpression(expression.copy(orderBy = selectors.map { it(sourceTable) })) +} + +/** + * Return a sequence sorting elements by a column, in ascending or descending order. For example, + * `sortedBy { it.col.asc() }` + * + * The operation is intermediate. + */ +@OptIn(ExperimentalTypeInference::class) +@OverloadResolutionByLambdaReturnType +public inline fun > EntitySequence.sortedBy( + selector: (T) -> OrderByExpression +): EntitySequence { + return this.withExpression(expression.copy(orderBy = listOf(selector(sourceTable)))) +} + +/** + * Return a sequence sorting elements by the specific column in ascending order. + * + * The operation is intermediate. 
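+ *
+ * For example (a sketch, assuming an `Employees` table with a `salary` column):
+ *
+ * ```kotlin
+ * val sequence = database.sequenceOf(Employees).sortedBy { it.salary }
+ * ```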
+ */ +@JvmName("sortedByAscending") +@OptIn(ExperimentalTypeInference::class) +@OverloadResolutionByLambdaReturnType +public inline fun > EntitySequence.sortedBy( + selector: (T) -> ColumnDeclaring<*> +): EntitySequence { + return this.withExpression(expression.copy(orderBy = listOf(selector(sourceTable).asc()))) +} + +/** + * Return a sequence sorting elements by the specific column in descending order. + * + * The operation is intermediate. + */ +public inline fun > EntitySequence.sortedByDescending( + selector: (T) -> ColumnDeclaring<*> +): EntitySequence { + return this.withExpression(expression.copy(orderBy = listOf(selector(sourceTable).desc()))) +} + +/** + * Returns a sequence containing all elements except first [n] elements. + * + * Note that this function is implemented based on the pagination feature of the specific databases. It's known that + * there is a uniform standard for SQL language, but the SQL standard doesn’t say how to implement paging queries, + * different databases provide different implementations on that. So we have to enable a dialect if we need to use this + * function, otherwise an exception will be thrown. + * + * The operation is intermediate. + */ +public fun > EntitySequence.drop(n: Int): EntitySequence { + if (n == 0) { + return this + } else { + val offset = expression.offset ?: 0 + return this.withExpression(expression.copy(offset = offset + n)) + } +} + +/** + * Returns a sequence containing first [n] elements. + * + * Note that this function is implemented based on the pagination feature of the specific databases. It's known that + * there is a uniform standard for SQL language, but the SQL standard doesn’t say how to implement paging queries, + * different databases provide different implementations on that. So we have to enable a dialect if we need to use this + * function, otherwise an exception will be thrown. + * + * The operation is intermediate. + */ +public fun > EntitySequence.take(n: Int): EntitySequence { + val limit = expression.limit ?: Int.MAX_VALUE + return this.withExpression(expression.copy(limit = min(limit, n))) +} + +/** + * Perform an aggregation given by [aggregationSelector] for all elements in the sequence, + * and return the aggregate result. + * + * Ktorm also supports aggregating two or more columns, we just need to wrap our aggregate expressions by + * [tupleOf] in the closure, then the function’s return type becomes `TupleN`. + * + * The operation is terminal. + * + * @param aggregationSelector a function that accepts the source table and returns the aggregate expression. + * @return the aggregate result. + */ +@OptIn(ExperimentalTypeInference::class) +@OverloadResolutionByLambdaReturnType +public suspend inline fun , reified C : Any> EntitySequence.aggregateColumns( + aggregationSelector: (T) -> ColumnDeclaring +): C? { + val aggregation = aggregationSelector(sourceTable) + + val expr = expression.copy( + columns = listOf(aggregation.aliased(null)) + ) + + val rowSet = Query(database, expr).doQuery() + if (rowSet.size == 1) { + val row = rowSet.first() + return aggregation.sqlType.getResult(row, row.metadata,0) + } else { + val (sql, _) = database.formatExpression(expr, beautifySql = true) + throw IllegalStateException("Expected 1 row but ${rowSet.size} returned from sql: \n\n$sql") + } +} + +/** + * Return the number of elements in this sequence. + * + * The operation is terminal. 
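+ *
+ * For example (a sketch; the function is suspending, so it must be called from a coroutine):
+ *
+ * ```kotlin
+ * val total = database.sequenceOf(Employees).count()
+ * ```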
+ */ +public suspend fun > EntitySequence.count(): Int { + val count = aggregateColumns { org.ktorm.r2dbc.dsl.count() }?.toInt() + return count ?: error("Count expression returns null, which should never happens.") +} + +/** + * Return the number of elements matching the given [predicate]. + * + * The operation is terminal. + */ +public suspend inline fun > EntitySequence.count( + predicate: (T) -> ColumnDeclaring +): Int { + return filter(predicate).count() +} + +/** + * Return `true` if the sequence has no elements. + * + * The operation is terminal. + * + * @since 2.7 + */ +public suspend fun > EntitySequence.isEmpty(): Boolean { + return count() == 0 +} + +/** + * Return `true` if the sequence has at lease one element. + * + * The operation is terminal. + * + * @since 2.7 + */ +public suspend fun > EntitySequence.isNotEmpty(): Boolean { + return count() > 0 +} + +/** + * Return `true` if the sequence has no elements. + * + * The operation is terminal. + */ +public suspend fun > EntitySequence.none(): Boolean { + return count() == 0 +} + +/** + * Return `true` if no elements match the given [predicate]. + * + * The operation is terminal. + */ +public suspend inline fun > EntitySequence.none( + predicate: (T) -> ColumnDeclaring +): Boolean { + return count(predicate) == 0 +} + +/** + * Return `true` if the sequence has at lease one element. + * + * The operation is terminal. + */ +public suspend fun > EntitySequence.any(): Boolean { + return count() > 0 +} + +/** + * Return `true` if at least one element matches the given [predicate]. + * + * The operation is terminal. + */ +public suspend inline fun > EntitySequence.any( + predicate: (T) -> ColumnDeclaring +): Boolean { + return count(predicate) > 0 +} + +/** + * Return `true` if all elements match the given [predicate]. + * + * The operation is terminal. + */ +public suspend inline fun > EntitySequence.all( + predicate: (T) -> ColumnDeclaring +): Boolean { + return none { !predicate(it) } +} + +/** + * Return the sum of the column given by [selector] in this sequence. + * + * The operation is terminal. + */ +public suspend inline fun ,reified C : Number> EntitySequence.sumBy( + selector: (T) -> ColumnDeclaring +): C? { + return aggregateColumns { sum(selector(it)) } +} + +/** + * Return the max value of the column given by [selector] in this sequence. + * + * The operation is terminal. + */ +public suspend inline fun ,reified C : Comparable> EntitySequence.maxBy( + selector: (T) -> ColumnDeclaring +): C? { + return aggregateColumns { max(selector(it)) } +} + +/** + * Return the min value of the column given by [selector] in this sequence. + * + * The operation is terminal. + */ +public suspend inline fun , reified C : Comparable> EntitySequence.minBy( + selector: (T) -> ColumnDeclaring +): C? { + return aggregateColumns { min(selector(it)) } +} + +/** + * Return the average value of the column given by [selector] in this sequence. + * + * The operation is terminal. + */ +public suspend inline fun > EntitySequence.averageBy( + selector: (T) -> ColumnDeclaring +): Double? { + return aggregateColumns { avg(selector(it)) } +} + +/** + * Return a [Map] containing key-value pairs provided by [transform] function applied to elements of the given sequence. + * + * If any of two pairs would have the same key the last one gets added to the map. + * + * The returned map preserves the entry iteration order of the original sequence. + * + * The operation is terminal. 
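+ *
+ * For example (a sketch, assuming the entities expose `id` and `name` properties):
+ *
+ * ```kotlin
+ * val namesById = database.sequenceOf(Employees).associate { it.id to it.name }
+ * ```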
+ */ +public suspend inline fun EntitySequence.associate(transform: (E) -> Pair): Map { + return associateTo(LinkedHashMap(), transform) +} + +/** + * Return a [Map] containing the elements from the given sequence indexed by the key returned from [keySelector] + * function applied to each element. + * + * If any two elements would have the same key returned by [keySelector] the last one gets added to the map. + * + * The returned map preserves the entry iteration order of the original sequence. + * + * The operation is terminal. + */ +public suspend inline fun EntitySequence.associateBy(keySelector: (E) -> K): Map { + return associateByTo(LinkedHashMap(), keySelector) +} + +/** + * Return a [Map] containing the values provided by [valueTransform] and indexed by [keySelector] functions + * applied to elements of the given sequence. + * + * If any two elements would have the same key returned by [keySelector] the last one gets added to the map. + * + * The returned map preserves the entry iteration order of the original sequence. + * + * The operation is terminal. + */ +public suspend inline fun EntitySequence.associateBy( + keySelector: (E) -> K, + valueTransform: (E) -> V +): Map { + return associateByTo(LinkedHashMap(), keySelector, valueTransform) +} + +/** + * Return a [Map] where keys are elements from the given sequence and values are produced by the [valueSelector] + * function applied to each element. + * + * If any two elements are equal, the last one gets added to the map. + * + * The returned map preserves the entry iteration order of the original sequence. + * + * The operation is terminal. + */ +public suspend inline fun , V> EntitySequence.associateWith(valueSelector: (K) -> V): Map { + return associateWithTo(LinkedHashMap(), valueSelector) +} + +/** + * Populate and return the [destination] mutable map with key-value pairs provided by [transform] function applied + * to each element of the given sequence. + * + * If any of two pairs would have the same key the last one gets added to the map. + * + * The operation is terminal. + */ +public suspend inline fun > EntitySequence.associateTo( + destination: M, + transform: (E) -> Pair +): M { + for (element in this) destination += transform(element) + return destination +} + +/** + * Populate and return the [destination] mutable map with key-value pairs, where key is provided by the [keySelector] + * function applied to each element of the given sequence and value is the element itself. + * + * If any two elements would have the same key returned by [keySelector] the last one gets added to the map. + * + * The operation is terminal. + */ +public suspend inline fun > EntitySequence.associateByTo( + destination: M, + keySelector: (E) -> K +): M { + for (element in this) destination.put(keySelector(element), element) + return destination +} + +/** + * Populate and return the [destination] mutable map with key-value pairs, where key is provided by the [keySelector] + * function and and value is provided by the [valueTransform] function applied to elements of the given sequence. + * + * If any two elements would have the same key returned by [keySelector] the last one gets added to the map. + * + * The operation is terminal. 
+ */ +public suspend inline fun > EntitySequence.associateByTo( + destination: M, + keySelector: (E) -> K, + valueTransform: (E) -> V +): M { + for (element in this) destination.put(keySelector(element), valueTransform(element)) + return destination +} + +/** + * Populate and return the [destination] mutable map with key-value pairs for each element of the given sequence, + * where key is the element itself and value is provided by the [valueSelector] function applied to that key. + * + * If any two elements are equal, the last one overwrites the former value in the map. + * + * The operation is terminal. + */ +public suspend inline fun , V, M : MutableMap> EntitySequence.associateWithTo( + destination: M, + valueSelector: (K) -> V +): M { + for (element in this) destination.put(element, valueSelector(element)) + return destination +} + +/** + * Return an element at the given [index] or `null` if the [index] is out of bounds of this sequence. + * + * Especially, if a dialect is enabled, this function will use the pagination feature to obtain the very record only. + * Assuming we are using MySQL and calling this function with an index 10, a SQL containing `limit 10, 1` will be + * generated. But if there are no dialects enabled, then all records in the sequence will be obtained to ensure the + * function just works. + * + * The operation is terminal. + */ +public suspend fun > EntitySequence.elementAtOrNull(index: Int): E? { + try { + @Suppress("UnconditionalJumpStatementInLoop") + for (element in this.drop(index).take(1)) return element + return null + } catch (e: DialectFeatureNotSupportedException) { + if (database.logger.isTraceEnabled()) { + database.logger.trace("Pagination is not supported, retrieving all records instead: ", e) + } + + var count = 0 + for (element in this) { + if (index == count++) return element + } + + return null + } +} + +/** + * Return an element at the given [index] or the result of calling the [defaultValue] function if the [index] is out + * of bounds of this sequence. + * + * Especially, if a dialect is enabled, this function will use the pagination feature to obtain the very record only. + * Assuming we are using MySQL and calling this function with an index 10, a SQL containing `limit 10, 1` will be + * generated. But if there are no dialects enabled, then all records in the sequence will be obtained to ensure the + * function just works. + * + * The operation is terminal. + */ +public suspend inline fun > EntitySequence.elementAtOrElse( + index: Int, + defaultValue: (Int) -> E +): E { + return elementAtOrNull(index) ?: defaultValue(index) +} + +/** + * Return an element at the given [index] or throws an [IndexOutOfBoundsException] if the [index] is out of bounds + * of this sequence. + * + * Especially, if a dialect is enabled, this function will use the pagination feature to obtain the very record only. + * Assuming we are using MySQL and calling this function with an index 10, a SQL containing `limit 10, 1` will be + * generated. But if there are no dialects enabled, then all records in the sequence will be obtained to ensure the + * function just works. + * + * The operation is terminal. + */ +public suspend fun > EntitySequence.elementAt(index: Int): E { + val result = elementAtOrNull(index) + return result ?: throw IndexOutOfBoundsException("Sequence doesn't contain element at index $index.") +} + +/** + * Return the first element, or `null` if the sequence is empty. 
+ * + * Especially, if a dialect is enabled, this function will use the pagination feature to obtain the very record only. + * Assuming we are using MySQL, a SQL containing `limit 0, 1` will be generated. But if there are no dialects enabled, + * then all records in the sequence will be obtained to ensure the function just works. + * + * The operation is terminal. + */ +public suspend fun > EntitySequence.firstOrNull(): E? { + return elementAtOrNull(0) +} + +/** + * Return the first element matching the given [predicate], or `null` if element was not found. + * + * Especially, if a dialect is enabled, this function will use the pagination feature to obtain the very record only. + * Assuming we are using MySQL, a SQL containing `limit 0, 1` will be generated. But if there are no dialects enabled, + * then all records in the sequence matching the given [predicate] will be obtained to ensure the function just works. + * + * The operation is terminal. + */ +public suspend inline fun > EntitySequence.firstOrNull( + predicate: (T) -> ColumnDeclaring +): E? { + return filter(predicate).elementAtOrNull(0) +} + +/** + * Return the first element, or throws [NoSuchElementException] if the sequence is empty. + * + * Especially, if a dialect is enabled, this function will use the pagination feature to obtain the very record only. + * Assuming we are using MySQL, a SQL containing `limit 0, 1` will be generated. But if there are no dialects enabled, + * then all records in the sequence will be obtained to ensure the function just works. + * + * The operation is terminal. + */ +public suspend fun > EntitySequence.first(): E { + return firstOrNull() ?: throw NoSuchElementException("Sequence is empty.") +} + +/** + * Return the first element matching the given [predicate], or throws [NoSuchElementException] if element was not found. + * + * Especially, if a dialect is enabled, this function will use the pagination feature to obtain the very record only. + * Assuming we are using MySQL, a SQL containing `limit 0, 1` will be generated. But if there are no dialects enabled, + * then all records in the sequence matching the given [predicate] will be obtained to ensure the function just works. + * + * The operation is terminal. + */ +public suspend inline fun > EntitySequence.first( + predicate: (T) -> ColumnDeclaring +): E { + val result = firstOrNull(predicate) + return result ?: throw NoSuchElementException("Sequence contains no elements matching the predicate") +} + +/** + * Return the last element, or `null` if the sequence is empty. + * + * The operation is terminal. + */ +public suspend fun EntitySequence.lastOrNull(): E? { + var last: E? = null + for (element in this) last = element + return last +} + +/** + * Return the last element matching the given [predicate], or `null` if no such element was found. + * + * The operation is terminal. + */ +public suspend inline fun > EntitySequence.lastOrNull( + predicate: (T) -> ColumnDeclaring +): E? { + return filter(predicate).lastOrNull() +} + +/** + * Return the last element, or throws [NoSuchElementException] if the sequence is empty. + * + * The operation is terminal. + */ +public suspend fun EntitySequence.last(): E { + return lastOrNull() ?: throw NoSuchElementException("Sequence is empty.") +} + +/** + * Return the last element matching the given [predicate], or throws [NoSuchElementException] if no such element found. + * + * The operation is terminal. 
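+ *
+ * For example (a sketch, assuming an `Employees` table with a `departmentId` column and the `eq` operator from the DSL):
+ *
+ * ```kotlin
+ * val employee = database.sequenceOf(Employees).last { it.departmentId eq 1 }
+ * ```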
+ */ +public suspend inline fun > EntitySequence.last( + predicate: (T) -> ColumnDeclaring +): E { + val result = lastOrNull(predicate) + return result ?: throw NoSuchElementException("Sequence contains no elements matching the predicate") +} + +/** + * Return the first element matching the given [predicate], or `null` if no such element was found. + * + * Especially, if a dialect is enabled, this function will use the pagination feature to obtain the very record only. + * Assuming we are using MySQL, a SQL containing `limit 0, 1` will be generated. But if there are no dialects enabled, + * then all records in the sequence matching the given [predicate] will be obtained to ensure the function just works. + * + * The operation is terminal. + */ +public suspend inline fun > EntitySequence.find( + predicate: (T) -> ColumnDeclaring +): E? { + return firstOrNull(predicate) +} + +/** + * Return the last element matching the given [predicate], or `null` if no such element was found. + * + * The operation is terminal. + */ +public suspend inline fun > EntitySequence.findLast( + predicate: (T) -> ColumnDeclaring +): E? { + return lastOrNull(predicate) +} + +/** + * Return single element, or `null` if the sequence is empty or has more than one element. + * + * The operation is terminal. + */ +public suspend fun > EntitySequence.singleOrNull(): E? { + val iterator = iterator() + if (!iterator.hasNext()) return null + + val single = iterator.next() + return if (iterator.hasNext()) null else single +} + +/** + * Return the single element matching the given [predicate], or `null` if element was not found or more than one + * element was found. + * + * The operation is terminal. + */ +public suspend inline fun > EntitySequence.singleOrNull( + predicate: (T) -> ColumnDeclaring +): E? { + return filter(predicate).singleOrNull() +} + +/** + * Return the single element, or throws an exception if the sequence is empty or has more than one element. + * + * The operation is terminal. + */ +public suspend fun > EntitySequence.single(): E { + val iterator = iterator() + if (!iterator.hasNext()) throw NoSuchElementException("Sequence is empty.") + + val single = iterator.next() + if (iterator.hasNext()) throw IllegalArgumentException("Sequence has more than one element.") + return single +} + +/** + * Return the single element matching the given [predicate], or throws exception if there is no or more than one + * matching element. + * + * The operation is terminal. + */ +public suspend inline fun > EntitySequence.single( + predicate: (T) -> ColumnDeclaring +): E { + return filter(predicate).single() +} + +/** + * Accumulate value starting with [initial] value and applying [operation] from left to right to current accumulator + * value and each element. + * + * The operation is terminal. + */ +public suspend inline fun EntitySequence.fold(initial: R, operation: (acc: R, E) -> R): R { + var accumulator = initial + for (element in this) accumulator = operation(accumulator, element) + return accumulator +} + +/** + * Accumulate value starting with [initial] value and applying [operation] from left to right to current accumulator + * value and each element with its index in the original sequence. + * + * The [operation] function takes the index of an element, current accumulator value and the element itself, and + * calculates the next accumulator value. + * + * The operation is terminal. 
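+ *
+ * For example (a sketch that builds an indexed dump of the sequence):
+ *
+ * ```kotlin
+ * val dump = sequence.foldIndexed(StringBuilder()) { i, acc, e -> acc.append(i).append(": ").append(e).append('\n') }
+ * ```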
+ */ +public suspend inline fun EntitySequence.foldIndexed( + initial: R, + operation: (index: Int, acc: R, E) -> R +): R { + var index = 0 + var accumulator = initial + for (element in this) accumulator = operation(index++, accumulator, element) + return accumulator +} + +/** + * Accumulate value starting with the first element and applying [operation] from left to right to current accumulator + * value and each element. + * + * Throws an exception if this sequence is empty. If the sequence can be empty in an expected way, please use + * [reduceOrNull] instead. It returns `null` when its receiver is empty. + * + * The [operation] function takes the current accumulator value and an element, and calculates the next + * accumulator value. + * + * The operation is terminal. + */ +public suspend inline fun EntitySequence.reduce(operation: (acc: E, E) -> E): E { + return reduceOrNull(operation) ?: throw UnsupportedOperationException("Empty sequence can't be reduced.") +} + +/** + * Accumulate value starting with the first element and applying [operation] from left to right to current accumulator + * value and each element with its index in the original sequence. + * + * Throws an exception if this sequence is empty. If the sequence can be empty in an expected way, please use + * [reduceIndexedOrNull] instead. It returns `null` when its receiver is empty. + * + * The [operation] function takes the index of an element, current accumulator value and the element itself and + * calculates the next accumulator value. + * + * The operation is terminal. + */ +public suspend inline fun EntitySequence.reduceIndexed(operation: (index: Int, acc: E, E) -> E): E { + return reduceIndexedOrNull(operation) ?: throw UnsupportedOperationException("Empty sequence can't be reduced.") +} + +/** + * Accumulate value starting with the first element and applying [operation] from left to right to current accumulator + * value and each element. + * + * Returns `null` if the sequence is empty. + * + * The [operation] function takes the current accumulator value and an element, and calculates the next + * accumulator value. + * + * The operation is terminal. + * + * @since 3.1.0 + */ +public suspend inline fun EntitySequence.reduceOrNull(operation: (acc: E, E) -> E): E? { + val iterator = iterator() + if (!iterator.hasNext()) return null + + var accumulator = iterator.next() + while (iterator.hasNext()) { + accumulator = operation(accumulator, iterator.next()) + } + + return accumulator +} + +/** + * Accumulate value starting with the first element and applying [operation] from left to right to current accumulator + * value and each element with its index in the original sequence. + * + * Returns `null` if the sequence is empty. + * + * The [operation] function takes the index of an element, current accumulator value and the element itself and + * calculates the next accumulator value. + * + * The operation is terminal. + * + * @since 3.1.0 + */ +public suspend inline fun EntitySequence.reduceIndexedOrNull(operation: (index: Int, acc: E, E) -> E): E? { + var index = 1 + return reduceOrNull { acc, e -> operation(index++, acc, e) } +} + +/** + * Perform the given [action] on each element. + * + * The operation is terminal. + */ +public suspend inline fun EntitySequence.forEach(action: (E) -> Unit) { + for (element in this) action(element) +} + +/** + * Perform the given [action] on each element, providing sequential index with the element. 
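+ *
+ * For example:
+ *
+ * ```kotlin
+ * database.sequenceOf(Employees).forEachIndexed { i, e -> println("$i -> $e") }
+ * ```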
+ * + * The [action] function takes the index of an element and the element itself and perform on the element. + * + * The operation is terminal. + */ +public suspend inline fun EntitySequence.forEachIndexed(action: (index: Int, E) -> Unit) { + var index = 0 + for (element in this) action(index++, element) +} + +/** + * Return a lazy [Sequence] that wraps each element of the original sequence into an [IndexedValue] containing + * the index of that element and the element itself. + * + * @since 3.0.0 + */ +public suspend fun EntitySequence.withIndex(): Sequence> { + val iterator = iterator() + return Sequence { IndexingIterator(iterator) } +} + +/** + * Group elements of the original sequence by the key returned by the given [keySelector] function applied to each + * element and return a map where each group key is associated with a list of corresponding elements. + * + * The returned map preserves the entry iteration order of the keys produced from the original sequence. + * + * The operation is terminal. + */ +public suspend inline fun EntitySequence.groupBy(keySelector: (E) -> K): Map> { + return groupByTo(LinkedHashMap(), keySelector) +} + +/** + * Group values returned by the [valueTransform] function applied to each element of the original sequence by the key + * returned by the given [keySelector] function applied to the element and returns a map where each group key is + * associated with a list of corresponding values. + * + * The returned map preserves the entry iteration order of the keys produced from the original sequence. + * + * The operation is terminal. + */ +public suspend inline fun EntitySequence.groupBy( + keySelector: (E) -> K, + valueTransform: (E) -> V +): Map> { + return groupByTo(LinkedHashMap(), keySelector, valueTransform) +} + +/** + * Group elements of the original sequence by the key returned by the given [keySelector] function applied to each + * element and put to the [destination] map each group key associated with a list of corresponding elements. + * + * The operation is terminal. + */ +public suspend inline fun >> EntitySequence.groupByTo( + destination: M, + keySelector: (E) -> K +): M { + for (element in this) { + val key = keySelector(element) + val list = destination.getOrPut(key) { ArrayList() } + list += element + } + + return destination +} + +/** + * Group values returned by the [valueTransform] function applied to each element of the original sequence by the key + * returned by the given [keySelector] function applied to the element and put to the [destination] map each group key + * associated with a list of corresponding values. + * + * The operation is terminal. + */ +public suspend inline fun >> EntitySequence.groupByTo( + destination: M, + keySelector: (E) -> K, + valueTransform: (E) -> V +): M { + for (element in this) { + val key = keySelector(element) + val list = destination.getOrPut(key) { ArrayList() } + list += valueTransform(element) + } + + return destination +} + +/** + * Create an [EntityGrouping] from the sequence to be used later with one of group-and-fold operations. + * + * The [keySelector] can be applied to each record to get its key, or used as the `group by` clause of generated SQLs. + * + * The operation is intermediate. 
+ */ +/* +TODO grouping +public suspend fun , K : Any> EntitySequence.groupingBy( + keySelector: (T) -> ColumnDeclaring +): EntityGrouping { + return EntityGrouping(this, keySelector) +}*/ + +/** + * Append the string from all the elements separated using [separator] and using the given [prefix] and [postfix]. + * + * If the collection could be huge, you can specify a non-negative value of [limit], in which case only the first + * [limit] elements will be appended, followed by the [truncated] string (which defaults to "..."). + * + * The operation is terminal. + */ +public suspend fun EntitySequence.joinTo( + buffer: A, + separator: CharSequence = ", ", + prefix: CharSequence = "", + postfix: CharSequence = "", + limit: Int = -1, + truncated: CharSequence = "...", + transform: ((E) -> CharSequence)? = null +): A { + buffer.append(prefix) + var count = 0 + for (element in this) { + if (++count > 1) buffer.append(separator) + if (limit < 0 || count <= limit) { + if (transform != null) buffer.append(transform(element)) else buffer.append(element.toString()) + } else { + buffer.append(truncated) + break + } + } + buffer.append(postfix) + return buffer +} + +/** + * Create a string from all the elements separated using [separator] and using the given [prefix] and [postfix]. + * + * If the collection could be huge, you can specify a non-negative value of [limit], in which case only the first + * [limit] elements will be appended, followed by the [truncated] string (which defaults to "..."). + * + * The operation is terminal. + */ +public suspend fun EntitySequence.joinToString( + separator: CharSequence = ", ", + prefix: CharSequence = "", + postfix: CharSequence = "", + limit: Int = -1, + truncated: CharSequence = "...", + transform: ((E) -> CharSequence)? = null +): String { + return joinTo(StringBuilder(), separator, prefix, postfix, limit, truncated, transform).toString() +} diff --git a/ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/expression/SqlFormatter.kt b/ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/expression/SqlFormatter.kt index 3b53356..2c767df 100644 --- a/ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/expression/SqlFormatter.kt +++ b/ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/expression/SqlFormatter.kt @@ -18,7 +18,6 @@ package org.ktorm.r2dbc.expression import org.ktorm.r2dbc.database.Database import org.ktorm.r2dbc.database.DialectFeatureNotSupportedException -import java.util.* /** * Subclass of [SqlExpressionVisitor], visiting SQL expression trees using visitor pattern. 
After the visit completes, @@ -75,78 +74,79 @@ public abstract class SqlFormatter( } protected fun writeKeyword(keyword: String) { - _builder.append(keyword) -// when (database.generateSqlInUpperCase) { -// true -> { -// _builder.append(keyword.toUpperCase()) -// } -// false -> { -// _builder.append(keyword.toLowerCase()) -// } -// null -> { -// if (database.supportsMixedCaseIdentifiers || !database.storesLowerCaseIdentifiers) { -// _builder.append(keyword.toUpperCase()) -// } else { -// _builder.append(keyword.toLowerCase()) -// } -// } -// } + when (database.generateSqlInUpperCase) { + true -> { + _builder.append(keyword.uppercase()) + } + false -> { + _builder.append(keyword.lowercase()) + } + null -> { + if (database.dialect.supportsMixedCaseIdentifiers || !database.dialect.storesLowerCaseIdentifiers) { + _builder.append(keyword.uppercase()) + } else { + _builder.append(keyword.lowercase()) + } + } + } } protected open fun checkColumnName(name: String) { } protected open fun shouldQuote(identifier: String): Boolean { -// if (database.alwaysQuoteIdentifiers) { -// return true -// } + if (database.alwaysQuoteIdentifiers) { + return true + } if (!identifier.isIdentifier) { return true } -// if (identifier.toUpperCase() in database.keywords) { -// return true -// } -// if (identifier.isMixedCase -// && !database.supportsMixedCaseIdentifiers && database.supportsMixedCaseQuotedIdentifiers) { -// return true -// } + if (identifier.uppercase() in database.keywords) { + return true + } + if (identifier.isMixedCase + && !database.dialect.supportsMixedCaseIdentifiers + && database.dialect.supportsMixedCaseQuotedIdentifiers + ) { + return true + } return false } protected val String.quoted: String get() { - return this -// if (shouldQuote(this)) { -// if (database.supportsMixedCaseQuotedIdentifiers) { -// return "${database.identifierQuoteString}${this}${database.identifierQuoteString}" -// } else { -// if (database.storesUpperCaseQuotedIdentifiers) { -// return "${database.identifierQuoteString}${this.toUpperCase()}${database.identifierQuoteString}" -// } -// if (database.storesLowerCaseQuotedIdentifiers) { -// return "${database.identifierQuoteString}${this.toLowerCase()}${database.identifierQuoteString}" -// } -// if (database.storesMixedCaseQuotedIdentifiers) { -// return "${database.identifierQuoteString}${this}${database.identifierQuoteString}" -// } -// // Should never happen, but it's still needed as some database drivers are not implemented correctly. -// return "${database.identifierQuoteString}${this}${database.identifierQuoteString}" -// } -// } else { -// if (database.supportsMixedCaseIdentifiers) { -// return this -// } else { -// if (database.storesUpperCaseIdentifiers) { -// return this.toUpperCase() -// } -// if (database.storesLowerCaseIdentifiers) { -// return this.toLowerCase() -// } -// if (database.storesMixedCaseIdentifiers) { -// return this -// } -// // Should never happen, but it's still needed as some database drivers are not implemented correctly. 
-// return this -// } -// } + val dialect = database.dialect + if (shouldQuote(this)) { + if (dialect.supportsMixedCaseQuotedIdentifiers) { + return "${dialect.identifierQuoteString}${this}${dialect.identifierQuoteString}" + } else { + if (dialect.storesUpperCaseQuotedIdentifiers) { + return "${dialect.identifierQuoteString}${this.uppercase()}${dialect.identifierQuoteString}" + } + if (dialect.storesLowerCaseQuotedIdentifiers) { + return "${dialect.identifierQuoteString}${this.lowercase()}${dialect.identifierQuoteString}" + } + if (dialect.storesMixedCaseQuotedIdentifiers) { + return "${dialect.identifierQuoteString}${this}${dialect.identifierQuoteString}" + } + // Should never happen, but it's still needed as some dialect drivers are not implemented correctly. + return "${dialect.identifierQuoteString}${this}${dialect.identifierQuoteString}" + } + } else { + if (dialect.supportsMixedCaseIdentifiers) { + return this + } else { + if (dialect.storesUpperCaseIdentifiers) { + return this.uppercase() + } + if (dialect.storesLowerCaseIdentifiers) { + return this.lowercase() + } + if (dialect.storesMixedCaseIdentifiers) { + return this + } + // Should never happen, but it's still needed as some dialect drivers are not implemented correctly. + return this + } + } } protected val String.isMixedCase: Boolean get() { @@ -171,19 +171,19 @@ public abstract class SqlFormatter( if (this == '_') { return true } -// if (this in database.extraNameCharacters) { -// return true -// } + if (this in database.dialect.extraNameCharacters) { + return true + } return false } protected val SqlExpression.removeBrackets: Boolean get() { return isLeafNode - || this is ColumnExpression<*> - || this is FunctionExpression<*> - || this is AggregateExpression<*> - || this is ExistsExpression - || this is ColumnDeclaringExpression<*> + || this is ColumnExpression<*> + || this is FunctionExpression<*> + || this is AggregateExpression<*> + || this is ExistsExpression + || this is ColumnDeclaringExpression<*> } override fun visit(expr: SqlExpression): SqlExpression { @@ -391,6 +391,7 @@ public abstract class SqlFormatter( if (expr.offset != null || expr.limit != null) { writePagination(expr) } + @Suppress("DEPRECATION") if (expr.forUpdate) { writeKeyword("for update ") } diff --git a/ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/schema/BaseTable.kt b/ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/schema/BaseTable.kt index b13220f..a5906bd 100644 --- a/ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/schema/BaseTable.kt +++ b/ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/schema/BaseTable.kt @@ -16,8 +16,9 @@ package org.ktorm.r2dbc.schema -import io.r2dbc.spi.Row +import org.ktorm.r2dbc.dsl.QueryRow import org.ktorm.r2dbc.expression.TableExpression +import org.ktorm.schema.* import java.util.* import kotlin.reflect.KClass import kotlin.reflect.jvm.jvmErasure @@ -174,6 +175,32 @@ public abstract class BaseTable( } } + /** + * Transform the registered column's [SqlType] to another. The transformed [SqlType] has the same `typeCode` and + * `typeName` as the underlying one, and performs the specific transformations on column values. + * + * This enables a user-friendly syntax to extend more data types. 
For example, the following code defines a column + * of type `Column`, based on the existing column definition function [int]: + * + * ```kotlin + * val role = int("role").transform({ UserRole.fromCode(it) }, { it.code }) + * ``` + * + * Note: Since [Column] is immutable, this function will create a new [Column] instance and replace the origin + * registered one. + * + * @param fromUnderlyingValue a function that transforms a value of underlying type to the user's type. + * @param toUnderlyingValue a function that transforms a value of user's type the to the underlying type. + * @return the new [Column] instance with its type changed to [R]. + * @see SqlType.transform + */ + public inline fun Column.transform( + noinline fromUnderlyingValue: (C) -> R, + noinline toUnderlyingValue: (R) -> C, + ): Column { + return transform(fromUnderlyingValue, toUnderlyingValue, R::class.java) + } + /** * Transform the registered column's [SqlType] to another. The transformed [SqlType] has the same `typeCode` and * `typeName` as the underlying one, and performs the specific transformations on column values. @@ -195,12 +222,13 @@ public abstract class BaseTable( */ public fun Column.transform( fromUnderlyingValue: (C) -> R, - toUnderlyingValue: (R) -> C + toUnderlyingValue: (R) -> C, + javaType: Class ): Column { checkRegistered() checkTransformable() - val result = Column(table, name, sqlType = sqlType.transform(fromUnderlyingValue, toUnderlyingValue)) + val result = Column(table, name, sqlType = sqlType.transform(fromUnderlyingValue, toUnderlyingValue, javaType)) _columns[name] = result return result } @@ -231,17 +259,19 @@ public abstract class BaseTable( private fun Column.checkConflictBinding(binding: ColumnBinding) { for (column in _columns.values) { val hasConflict = when (binding) { - is NestedBinding -> column.allBindings - .filterIsInstance() - .filter { it.properties == binding.properties } - .any() - is ReferenceBinding -> column.allBindings - .filterIsInstance() - .filter { it.referenceTable.tableName == binding.referenceTable.tableName } - .filter { it.referenceTable.catalog == binding.referenceTable.catalog } - .filter { it.referenceTable.schema == binding.referenceTable.schema } - .filter { it.onProperty == binding.onProperty } - .any() + is NestedBinding -> + column.allBindings + .filterIsInstance() + .filter { it.properties == binding.properties } + .any() + is ReferenceBinding -> + column.allBindings + .filterIsInstance() + .filter { it.referenceTable.tableName == binding.referenceTable.tableName } + .filter { it.referenceTable.catalog == binding.referenceTable.catalog } + .filter { it.referenceTable.schema == binding.referenceTable.schema } + .filter { it.onProperty == binding.onProperty } + .any() } if (hasConflict) { @@ -345,13 +375,13 @@ public abstract class BaseTable( * it is equivalent to `c.bindTo { it.department.id }` in this case, that avoids unnecessary object creations * and some exceptions raised by conflict column names. */ - public fun createEntity(row: Row, withReferences: Boolean = true): E { + public fun createEntity(row: QueryRow, withReferences: Boolean = true): E { val entity = doCreateEntity(row, withReferences) -// val logger = row.query.database.logger -// if (logger.isTraceEnabled()) { -// logger.trace("Entity: $entity") -// } + val logger = row.query.database.logger + if (logger.isTraceEnabled()) { + logger.trace("Entity: $entity") + } return entity } @@ -362,7 +392,7 @@ public abstract class BaseTable( * This function is called by [createEntity]. 
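The reified `transform` overload above supplies `R::class.java` for the new `javaType` parameter automatically, so column wrappers keep the familiar two-lambda syntax. The test schema later in this patch uses it exactly this way (snippet from the `Departments` table definition; `LocationWrapper` is a simple data class holding the underlying string):

```kotlin
// Inside a Table<Department> definition (see BaseTest in this patch):
val location = varchar("location")
    .transform({ LocationWrapper(it) }, { it.underlying })
    .bindTo { it.location }
```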
Subclasses should override it and implement the actual logic of * retrieving an entity object from the query results. */ - protected abstract fun doCreateEntity(row: Row, withReferences: Boolean): E + protected abstract fun doCreateEntity(row: QueryRow, withReferences: Boolean): E /** * Convert this table to a [TableExpression]. diff --git a/ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/schema/Column.kt b/ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/schema/Column.kt index 38b7f38..3bf19eb 100644 --- a/ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/schema/Column.kt +++ b/ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/schema/Column.kt @@ -73,6 +73,7 @@ public interface ColumnDeclaring { * Wrap the given [argument] as an [ArgumentExpression] using the [sqlType]. */ public fun wrapArgument(argument: T?): ArgumentExpression + } /** diff --git a/ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/schema/ColumnBindingHandler.kt b/ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/schema/ColumnBindingHandler.kt new file mode 100644 index 0000000..5a75a3b --- /dev/null +++ b/ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/schema/ColumnBindingHandler.kt @@ -0,0 +1,126 @@ +/* + * Copyright 2018-2021 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.ktorm.r2dbc.schema + +import org.ktorm.r2dbc.entity.Entity +import java.lang.reflect.InvocationHandler +import java.lang.reflect.Method +import java.lang.reflect.Proxy +import java.util.* +import kotlin.reflect.KClass +import kotlin.reflect.KMutableProperty +import kotlin.reflect.KProperty1 +import kotlin.reflect.full.createInstance +import kotlin.reflect.full.declaredMemberProperties +import kotlin.reflect.full.isSubclassOf +import kotlin.reflect.jvm.javaGetter +import kotlin.reflect.jvm.javaSetter + +@PublishedApi +internal class ColumnBindingHandler(val properties: MutableList>) : InvocationHandler { + + override fun invoke(proxy: Any, method: Method, args: Array?): Any? 
{ + when (method.declaringClass.kotlin) { + Any::class, Entity::class -> { + error("Unsupported method: $method") + } + else -> { + val (prop, isGetter) = method.kotlinProperty ?: error("Unsupported method: $method") + if (!prop.isAbstract) { + error("Cannot bind a column to a non-abstract property: $prop") + } + if (!isGetter) { + error("Cannot modify a property while we are binding a column to it, property: $prop") + } + + properties += prop + + val returnType = method.returnType + return when { + returnType.kotlin.isSubclassOf(Entity::class) -> createProxy(returnType.kotlin, properties) + returnType.isPrimitive -> returnType.defaultValue + else -> null + } + } + } + } + + private fun error(msg: String): Nothing { + throw UnsupportedOperationException(msg) + } + + companion object { + + fun createProxy(entityClass: KClass<*>, properties: MutableList>): Entity<*> { + val handler = ColumnBindingHandler(properties) + return Proxy.newProxyInstance(entityClass.java.classLoader, arrayOf(entityClass.java), handler) as Entity<*> + } + } +} + +internal val Method.kotlinProperty: Pair, Boolean>? get() { + for (prop in declaringClass.kotlin.declaredMemberProperties) { + if (prop.javaGetter == this) { + return Pair(prop, true) + } + if (prop is KMutableProperty<*> && prop.javaSetter == this) { + return Pair(prop, false) + } + } + return null +} + +@OptIn(ExperimentalUnsignedTypes::class) +internal val Class<*>.defaultValue: Any get() { + val value = when { + this == Boolean::class.javaPrimitiveType -> false + this == Char::class.javaPrimitiveType -> 0.toChar() + this == Byte::class.javaPrimitiveType -> 0.toByte() + this == Short::class.javaPrimitiveType -> 0.toShort() + this == Int::class.javaPrimitiveType -> 0 + this == Long::class.javaPrimitiveType -> 0L + this == Float::class.javaPrimitiveType -> 0.0F + this == Double::class.javaPrimitiveType -> 0.0 + this == String::class.java -> "" + this == UByte::class.java -> 0.toUByte() + this == UShort::class.java -> 0.toUShort() + this == UInt::class.java -> 0U + this == ULong::class.java -> 0UL + this == UByteArray::class.java -> ubyteArrayOf() + this == UShortArray::class.java -> ushortArrayOf() + this == UIntArray::class.java -> uintArrayOf() + this == ULongArray::class.java -> ulongArrayOf() + this == Set::class.java -> LinkedHashSet() + this == List::class.java -> ArrayList() + this == Collection::class.java -> ArrayList() + this == Map::class.java -> LinkedHashMap() + this == Queue::class.java || this == Deque::class.java -> LinkedList() + this == SortedSet::class.java || this == NavigableSet::class.java -> TreeSet() + this == SortedMap::class.java || this == NavigableMap::class.java -> TreeMap() + this.isEnum -> this.enumConstants[0] + this.isArray -> java.lang.reflect.Array.newInstance(this.componentType, 0) + this.kotlin.isSubclassOf(Entity::class) -> Entity.create(this.kotlin) + else -> this.kotlin.createInstance() + } + + if (this.kotlin.isInstance(value)) { + return value + } else { + // never happens... + throw AssertionError("$value must be instance of $this") + } +} diff --git a/ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/schema/EntityDml.kt b/ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/schema/EntityDml.kt new file mode 100644 index 0000000..4a2fa88 --- /dev/null +++ b/ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/schema/EntityDml.kt @@ -0,0 +1,30 @@ +/* + * Copyright 2018-2021 the original author or authors. 
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.ktorm.r2dbc.schema + +import org.ktorm.r2dbc.entity.Entity +import org.ktorm.r2dbc.entity.implementation + +internal fun Entity<*>.clearChangesRecursively() { + implementation.changedProperties.clear() + + for ((_, value) in properties) { + if (value is Entity<*>) { + value.clearChangesRecursively() + } + } +} \ No newline at end of file diff --git a/ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/schema/SqlType.kt b/ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/schema/SqlType.kt index af245d7..a2e4a50 100644 --- a/ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/schema/SqlType.kt +++ b/ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/schema/SqlType.kt @@ -7,24 +7,31 @@ import kotlin.reflect.KClass public interface SqlType { + public val javaType: Class + public fun bindParameter(statement: Statement, index: Int, value: T?) public fun bindParameter(statement: Statement, name: String, value: T?) - + public fun getResult(row: Row, metadata: RowMetadata, index: Int): T? - + public fun getResult(row: Row, metadata: RowMetadata, name: String): T? - public fun transform(fromUnderlyingValue: (T) -> R, toUnderlyingValue: (R) -> T): SqlType { - return TransformedSqlType(this, fromUnderlyingValue, toUnderlyingValue) - } } -public class SimpleSqlType(public val kotlinType: KClass) : SqlType { +public fun SqlType.transform( + fromUnderlyingValue: (T) -> R, + toUnderlyingValue: (R) -> T, + javaType: Class +): SqlType { + return TransformedSqlType(this, fromUnderlyingValue, toUnderlyingValue, javaType) +} + +public open class SimpleSqlType(public val kotlinType: KClass) : SqlType { override fun bindParameter(statement: Statement, index: Int, value: T?) { if (value == null) { - statement.bindNull(index, kotlinType.java) + statement.bindNull(index, kotlinType.javaObjectType) } else { statement.bind(index, value) } @@ -32,29 +39,30 @@ public class SimpleSqlType(public val kotlinType: KClass) : SqlType< override fun bindParameter(statement: Statement, name: String, value: T?) { if (value == null) { - statement.bindNull(name, kotlinType.java) + statement.bindNull(name, kotlinType.javaObjectType) } else { statement.bind(name, value) } } override fun getResult(row: Row, metadata: RowMetadata, index: Int): T? { - return row.get(index, kotlinType.java) + return row.get(index, kotlinType.javaObjectType) } override fun getResult(row: Row, metadata: RowMetadata, name: String): T? 
{ - return row.get(name, kotlinType.java) + return row.get(name, kotlinType.javaObjectType) } -// public companion object { -// public inline operator fun invoke(): SimpleSqlType = SimpleSqlType(T::class) -// } + override val javaType: Class + get() = kotlinType.javaObjectType + } public class TransformedSqlType( public val underlyingType: SqlType, public val fromUnderlyingValue: (T) -> R, - public val toUnderlyingValue: (R) -> T + public val toUnderlyingValue: (R) -> T, + public override val javaType: Class ) : SqlType { override fun bindParameter(statement: Statement, index: Int, value: R?) { diff --git a/ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/schema/SqlTypes.kt b/ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/schema/SqlTypes.kt new file mode 100644 index 0000000..9659eb4 --- /dev/null +++ b/ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/schema/SqlTypes.kt @@ -0,0 +1,262 @@ +/* + * Copyright 2018-2021 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.ktorm.r2dbc.schema + +import java.math.BigDecimal +import java.sql.Timestamp +import java.time.* +import java.util.* + +/** + * Define a column typed of [BooleanSqlType]. + */ +public fun BaseTable<*>.boolean(name: String): Column { + return registerColumn(name, BooleanSqlType) +} + +/** + * [SqlType] implementation represents `boolean` SQL type. + */ +public object BooleanSqlType : SimpleSqlType(Boolean::class) + +/** + * Define a column typed of [IntSqlType]. + */ +public fun BaseTable<*>.int(name: String): Column { + return registerColumn(name, IntSqlType) +} + +/** + * [SqlType] implementation represents `int` SQL type. + */ +public object IntSqlType : SimpleSqlType(Int::class) + +/** + * Define a column typed of [ShortSqlType]. + * + * @since 3.1.0 + */ +public fun BaseTable<*>.short(name: String): Column { + return registerColumn(name, ShortSqlType) +} + +/** + * [SqlType] implementation represents `smallint` SQL type. + * + * @since 3.1.0 + */ +public object ShortSqlType : SimpleSqlType(Short::class) + +/** + * Define a column typed of [LongSqlType]. + */ +public fun BaseTable<*>.long(name: String): Column { + return registerColumn(name, LongSqlType) +} + +/** + * [SqlType] implementation represents `long` SQL type. + */ +public object LongSqlType : SimpleSqlType(Long::class) +/** + * Define a column typed of [FloatSqlType]. + */ +public fun BaseTable<*>.float(name: String): Column { + return registerColumn(name, FloatSqlType) +} + +/** + * [SqlType] implementation represents `float` SQL type. + */ +public object FloatSqlType : SimpleSqlType(Float::class) + +/** + * Define a column typed of [DoubleSqlType]. + */ +public fun BaseTable<*>.double(name: String): Column { + return registerColumn(name, DoubleSqlType) +} + +/** + * [SqlType] implementation represents `double` SQL type. + */ +public object DoubleSqlType : SimpleSqlType(Double::class) + +/** + * Define a column typed of [DecimalSqlType]. 
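The column functions in this file all follow the same pattern: each `BaseTable<*>` extension registers a column backed by a `SimpleSqlType`. A sketch of a DSL-only table built from them; the table and column names are illustrative and not part of this patch:

```kotlin
import org.ktorm.r2dbc.schema.*

// Table<Nothing> gives a DSL-only table with no entity binding (same idiom as the tests in this patch).
object Products : Table<Nothing>("t_product") {
    val id = int("id").primaryKey()
    val name = varchar("name")
    val price = decimal("price")
    val available = boolean("available")
    val createdAt = timestamp("created_at")
}
```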
+ */ +public fun BaseTable<*>.decimal(name: String): Column { + return registerColumn(name, DecimalSqlType) +} + +/** + * [SqlType] implementation represents `decimal` SQL type. + */ +public object DecimalSqlType : SimpleSqlType(BigDecimal::class) + +/** + * Define a column typed of [VarcharSqlType]. + */ +public fun BaseTable<*>.varchar(name: String): Column { + return registerColumn(name, VarcharSqlType) +} + +/** + * [SqlType] implementation represents `varchar` SQL type. + */ +public object VarcharSqlType : SimpleSqlType(String::class) + +/** + * Define a column typed of [TextSqlType]. + */ +public fun BaseTable<*>.text(name: String): Column { + return registerColumn(name, TextSqlType) +} + +/** + * [SqlType] implementation represents `text` SQL type. + */ +public object TextSqlType : SimpleSqlType(String::class) + +/** + * Define a column typed of [BlobSqlType]. + */ +public fun BaseTable<*>.blob(name: String): Column { + return registerColumn(name, BlobSqlType) +} + +/** + * [SqlType] implementation represents `blob` SQL type. + */ +public object BlobSqlType : SimpleSqlType(ByteArray::class) +/** + * Define a column typed of [BytesSqlType]. + */ +public fun BaseTable<*>.bytes(name: String): Column { + return registerColumn(name, BytesSqlType) +} + +/** + * [SqlType] implementation represents `bytes` SQL type. + */ +public object BytesSqlType : SimpleSqlType(ByteArray::class) +/** + * Define a column typed of [TimestampSqlType]. + */ +public fun BaseTable<*>.jdbcTimestamp(name: String): Column { + return registerColumn(name, TimestampSqlType) +} + +/** + * [SqlType] implementation represents `timestamp` SQL type. + */ +public object TimestampSqlType : SimpleSqlType(Timestamp::class) + +/** + * Define a column typed of [InstantSqlType]. + */ +public fun BaseTable<*>.timestamp(name: String): Column { + return registerColumn(name, InstantSqlType) +} + +/** + * [SqlType] implementation represents `timestamp` SQL type. + */ +public object InstantSqlType : SimpleSqlType(Instant::class) +/** + * Define a column typed of [LocalDateTimeSqlType]. + */ +public fun BaseTable<*>.datetime(name: String): Column { + return registerColumn(name, LocalDateTimeSqlType) +} + +/** + * [SqlType] implementation represents `datetime` SQL type. + */ +public object LocalDateTimeSqlType : SimpleSqlType(LocalDateTime::class) + +/** + * Define a column typed of [LocalDateSqlType]. + */ +public fun BaseTable<*>.date(name: String): Column { + return registerColumn(name, LocalDateSqlType) +} + +/** + * [SqlType] implementation represents `date` SQL type. + */ +public object LocalDateSqlType : SimpleSqlType(LocalDate::class) + +/** + * Define a column typed of [LocalTimeSqlType]. + */ +public fun BaseTable<*>.time(name: String): Column { + return registerColumn(name, LocalTimeSqlType) +} + +/** + * [SqlType] implementation represents `time` SQL type. + */ +public object LocalTimeSqlType : SimpleSqlType(LocalTime::class) +/** + * Define a column typed of [MonthDaySqlType], instances of [MonthDay] are saved as strings in format `MM-dd`. + */ +public fun BaseTable<*>.monthDay(name: String): Column { + return registerColumn(name, MonthDaySqlType) +} + +/** + * [SqlType] implementation used to save [MonthDay] instances, formating them to strings with pattern `MM-dd`. + */ +public object MonthDaySqlType : SimpleSqlType(MonthDay::class) + +/** + * Define a column typed of [YearMonthSqlType], instances of [YearMonth] are saved as strings in format `yyyy-MM`. 
+ */ +public fun BaseTable<*>.yearMonth(name: String): Column { + return registerColumn(name, YearMonthSqlType) +} + +/** + * [SqlType] implementation used to save [YearMonth] instances, formating them to strings with pattern `yyyy-MM`. + */ +@Suppress("MagicNumber") +public object YearMonthSqlType : SimpleSqlType(YearMonth::class) + +/** + * Define a column typed of [YearSqlType], instances of [Year] are saved as integers. + */ +public fun BaseTable<*>.year(name: String): Column { + return registerColumn(name, YearSqlType) +} + +/** + * [SqlType] implementation used to save [Year] instances as integers. + */ +public object YearSqlType : SimpleSqlType(Year::class) + +/** + * Define a column typed of [UuidSqlType]. + */ +public fun BaseTable<*>.uuid(name: String): Column { + return registerColumn(name, UuidSqlType) +} + +/** + * [SqlType] implementation represents `uuid` SQL type. + */ +public object UuidSqlType : SimpleSqlType(UUID::class) diff --git a/ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/schema/Table.kt b/ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/schema/Table.kt new file mode 100644 index 0000000..79d0ba6 --- /dev/null +++ b/ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/schema/Table.kt @@ -0,0 +1,174 @@ +/* + * Copyright 2018-2021 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.ktorm.r2dbc.schema + +import org.ktorm.r2dbc.dsl.QueryRow +import org.ktorm.r2dbc.entity.Entity +import org.ktorm.r2dbc.entity.EntityImplementation +import org.ktorm.r2dbc.entity.implementation +import org.ktorm.r2dbc.entity.setColumnValue +import kotlin.reflect.KClass +import kotlin.reflect.KProperty1 +import kotlin.reflect.jvm.jvmErasure + +/** + * Base class of Ktorm's table objects. This class extends from [BaseTable], additionally providing a binding mechanism + * with [Entity] interfaces based on functions such as [bindTo], [references]. + * + * [Table] implements the [doCreateEntity] function from the parent class. The function automatically creates an + * entity object using the binding configuration specified by [bindTo] and [references], reading columns’ values from + * the result set and filling them into corresponding entity properties. + * + * To use this class, we need to define our entities as interfaces extending from [Entity]. Here is an example. More + * documents can be found at https://www.ktorm.org/en/entities-and-column-binding.html + * + * ```kotlin + * interface Department : Entity { + * val id: Int + * var name: String + * var location: String + * } + * object Departments : Table("t_department") { + * val id = int("id").primaryKey().bindTo { it.id } + * val name = varchar("name").bindTo { it.name } + * val location = varchar("location").bindTo { it.location } + * } + * ``` + */ +@Suppress("UNCHECKED_CAST") +public open class Table>( + tableName: String, + alias: String? = null, + catalog: String? = null, + schema: String? = null, + entityClass: KClass? 
= null +) : BaseTable(tableName, alias, catalog, schema, entityClass) { + + /** + * Bind the column to nested properties, eg. `employee.manager.department.id`. + * + * Note: Since [Column] is immutable, this function will create a new [Column] instance and replace the origin + * registered one. + * + * @param selector a lambda in which we should return the property we want to bind. + * For example: `val name = varchar("name").bindTo { it.name }`. + * + * @return the new [Column] instance. + */ + public inline fun Column.bindTo(selector: (E) -> C?): Column { + val properties = detectBindingProperties(selector) + return doBindInternal(NestedBinding(properties)) + } + + /** + * Bind the column to a reference table, equivalent to a foreign key in relational databases. + * Entity sequence APIs would automatically left join all references (recursively) by default. + * + * Note: Since [Column] is immutable, this function will create a new [Column] instance and replace the origin + * registered one. + * + * @param referenceTable the reference table, will be copied by calling its [aliased] function with + * an alias like `_refN`. + * + * @param selector a lambda in which we should return the property used to hold the referenced entities. + * For example: `val departmentId = int("department_id").references(Departments) { it.department }`. + * + * @return the new [Column] instance. + * + * @see org.ktorm.entity.sequenceOf + * @see createEntity + */ + public inline fun > Column.references( + referenceTable: Table, + selector: (E) -> R? + ): Column { + val properties = detectBindingProperties(selector) + + if (properties.size > 1) { + throw IllegalArgumentException("Reference binding doesn't support nested properties.") + } else { + return doBindInternal(ReferenceBinding(referenceTable, properties[0])) + } + } + + @PublishedApi + internal inline fun detectBindingProperties(selector: (E) -> Any?): List> { + val entityClass = this.entityClass ?: error("No entity class configured for table: '$this'") + val properties = ArrayList>() + + val proxy = ColumnBindingHandler.createProxy(entityClass, properties) + selector(proxy as E) + + if (properties.isEmpty()) { + throw IllegalArgumentException("No binding properties found.") + } else { + return properties + } + } + + override fun aliased(alias: String): Table { + val result = Table(tableName, alias, catalog, schema, entityClass) + result.copyDefinitionsFrom(this) + return result + } + + final override fun doCreateEntity(row: QueryRow, withReferences: Boolean): E { + val entityClass = this.entityClass ?: error("No entity class configured for table: '$this'") + val entity = Entity.create(entityClass, fromDatabase = row.query.database, fromTable = this) as E + + for (column in columns) { + row.retrieveColumn(column, intoEntity = entity, withReferences = withReferences) + } + + return entity.apply { clearChangesRecursively() } + } + + private fun EntityImplementation.setColumnValue(column: Column<*>, value: Any?, forceSet: Boolean = false) { + for (binding in column.allBindings) { + this.setColumnValue(binding, value, forceSet) + } + } + + private fun QueryRow.retrieveColumn(column: Column<*>, intoEntity: E, withReferences: Boolean) { + val columnValue = this[column] + + for (binding in column.allBindings) { + when (binding) { + is ReferenceBinding -> { + val refTable = binding.referenceTable as Table<*> + val pk = refTable.singlePrimaryKey { + "Cannot reference the table '$refTable' as there is compound primary keys." 
+ } + + if (withReferences) { + val child = refTable.doCreateEntity(this, withReferences = true) + child.implementation.setColumnValue(pk, columnValue, forceSet = true) + intoEntity[binding.onProperty.name] = child + } else { + val entityClass = binding.onProperty.returnType.jvmErasure + val child = Entity.create(entityClass, fromDatabase = query.database, fromTable = refTable) + child.implementation.setColumnValue(pk, columnValue) + intoEntity[binding.onProperty.name] = child + } + } + is NestedBinding -> { + intoEntity.implementation.setColumnValue(binding, columnValue) + } + } + } + } +} diff --git a/ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/schema/TypeReference.kt b/ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/schema/TypeReference.kt index 25a75c9..8cdf291 100644 --- a/ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/schema/TypeReference.kt +++ b/ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/schema/TypeReference.kt @@ -14,7 +14,7 @@ * limitations under the License. */ -package org.ktorm.r2dbc.schema +package org.ktorm.schema import java.lang.reflect.ParameterizedType import java.lang.reflect.Type diff --git a/ktorm-r2dbc-core/src/test/kotlin/org/ktorm/BaseTest.kt b/ktorm-r2dbc-core/src/test/kotlin/org/ktorm/BaseTest.kt new file mode 100644 index 0000000..424faf6 --- /dev/null +++ b/ktorm-r2dbc-core/src/test/kotlin/org/ktorm/BaseTest.kt @@ -0,0 +1,134 @@ +package org.ktorm + +import kotlinx.coroutines.reactive.awaitFirstOrNull +import kotlinx.coroutines.runBlocking +import org.junit.After +import org.junit.Before +import org.ktorm.r2dbc.database.Database +import org.ktorm.r2dbc.entity.Entity +import org.ktorm.r2dbc.entity.sequenceOf +import org.ktorm.r2dbc.logging.ConsoleLogger +import org.ktorm.r2dbc.logging.LogLevel +import org.ktorm.r2dbc.schema.* +import java.io.Serializable +import java.time.LocalDate + +/** + * Created by vince on Dec 07, 2018. + */ +open class BaseTest { + lateinit var database: Database + + @Before + open fun init() { + runBlocking { + database = Database.connect( + url = "r2dbc:h2:mem:///testdb?DB_CLOSE_DELAY=-1", + logger = ConsoleLogger(threshold = LogLevel.TRACE), + ) + + execSqlScript("init-data.sql") + } + } + + @After + open fun destroy() { + runBlocking { + execSqlScript("drop-data.sql") + } + } + + protected suspend fun execSqlScript(filename: String) { + database.useConnection { conn -> + javaClass.classLoader + ?.getResourceAsStream(filename) + ?.bufferedReader() + ?.use { reader -> + for (sql in reader.readText().split(";")) { + if (sql.any { it.isLetterOrDigit() }) { + val statement = conn.createStatement(sql) + statement.execute().awaitFirstOrNull() + } + } + } + } + } + + data class LocationWrapper(val underlying: String = "") : Serializable + + interface Department : Entity { + companion object : Entity.Factory() + + val id: Int + var name: String + var location: LocationWrapper + var mixedCase: String? + } + + interface Employee : Entity { + companion object : Entity.Factory() + + var id: Int + var name: String + var job: String + var manager: Employee? + var hireDate: LocalDate + var salary: Long + var department: Department + + val upperName get() = name.toUpperCase() + fun upperName() = name.toUpperCase() + } + + interface Customer : Entity { + companion object : Entity.Factory() + + var id: Int + var name: String + var email: String + var phoneNumber: String + } + + open class Departments(alias: String?) 
: Table("t_department", alias) { + companion object : Departments(null) + + override fun aliased(alias: String) = Departments(alias) + + val id = int("id").primaryKey().bindTo { it.id } + val name = varchar("name").bindTo { it.name } + val location = varchar("location").transform({ LocationWrapper(it) }, { it.underlying }).bindTo { it.location } + val mixedCase = varchar("mixedCase").bindTo { it.mixedCase } + } + + open class Employees(alias: String?) : Table("t_employee", alias) { + companion object : Employees(null) + + override fun aliased(alias: String) = Employees(alias) + + val id = int("id").primaryKey().bindTo { it.id } + val name = varchar("name").bindTo { it.name } + val job = varchar("job").bindTo { it.job } + val managerId = int("manager_id").bindTo { it.manager?.id } + val hireDate = date("hire_date").bindTo { it.hireDate } + val salary = long("salary").bindTo { it.salary } + val departmentId = int("department_id").references(Departments) { it.department } + val department = departmentId.referenceTable as Departments + } + + open class Customers(alias: String?) : Table("t_customer", alias, schema = "company") { + companion object : Customers(null) + + override fun aliased(alias: String) = Customers(alias) + + val id = int("id").primaryKey().bindTo { it.id } + val name = varchar("name").bindTo { it.name } + val email = varchar("email").bindTo { it.email } + val phoneNumber = varchar("phone_number").bindTo { it.phoneNumber } + } + + val Database.departments get() = this.sequenceOf(Departments) + + val Database.employees get() = this.sequenceOf(Employees) + + val Database.customers get() = this.sequenceOf(Customers) +} \ No newline at end of file diff --git a/ktorm-r2dbc-core/src/test/kotlin/org/ktorm/database/DatabaseTest.kt b/ktorm-r2dbc-core/src/test/kotlin/org/ktorm/database/DatabaseTest.kt new file mode 100644 index 0000000..7ec2ca9 --- /dev/null +++ b/ktorm-r2dbc-core/src/test/kotlin/org/ktorm/database/DatabaseTest.kt @@ -0,0 +1,191 @@ +package org.ktorm.database + +import kotlinx.coroutines.reactive.awaitFirstOrNull +import kotlinx.coroutines.runBlocking +import org.junit.Test +import org.ktorm.BaseTest +import org.ktorm.r2dbc.database.toList +import org.ktorm.r2dbc.dsl.insert +import org.ktorm.r2dbc.entity.count +import org.ktorm.r2dbc.entity.forEach +import org.ktorm.r2dbc.entity.sequenceOf + +/** + * Created by vince on Dec 02, 2018. 
+ */ +@ExperimentalUnsignedTypes +class DatabaseTest : BaseTest() { + + + /*@Test + fun testKeywordWrapping(): Unit = runBlocking { + val configs = object : Table("t_config") { + val key = varchar("key").primaryKey() + val value = varchar("value") + } + + database.useConnection { + val sql = """CREATE TABLE T_CONFIG(KEY VARCHAR(128) PRIMARY KEY, VALUE VARCHAR(128))""" + it.createStatement(sql).execute().awaitFirstOrNull() + } + + database.insert(configs) { + set(it.key, "test") + set(it.value, "test value") + } + + assert(database.sequenceOf(configs).count { it.key eq "test" } == 1) + + database.delete(configs) { it.key eq "test" } + }*/ + + @Test + fun testTransaction() = runBlocking { + class DummyException : Exception() + + try { + database.useTransaction { + database.insert(Departments) { + set(it.name, "administration") + set(it.location, LocationWrapper("Hong Kong")) + } + + assert(database.departments.count() == 3) + + throw DummyException() + } + + } catch (e: DummyException) { + assert(database.departments.count() == 2) + } + } + + @Test + fun testRawSql() = runBlocking { + val names = database.useConnection { conn -> + val sql = """ + SELECT "NAME" FROM "T_EMPLOYEE" + WHERE "DEPARTMENT_ID" = ? + ORDER BY "ID" + """ + + val statement = conn.createStatement(sql) + statement.bind(0, 1) + statement.execute().awaitFirstOrNull()?.map { row, _ -> + row.get(0) + }?.toList() ?: emptyList() + } + + assert(names.size == 2) + assert(names[0] == "VINCE") + assert(names[1] == "MARRY") + } + + @Test + fun tableTest() = runBlocking { + val employees = database.sequenceOf(Employees) + employees.forEach { + println(it) + } + assert(true) + } + + /*fun BaseTable<*>.ulong(name: String): Column { + return registerColumn(name, object : SqlType(Types.BIGINT, "bigint unsigned") { + override fun doSetParameter(ps: PreparedStatement, index: Int, parameter: ULong) { + ps.setLong(index, parameter.toLong()) + } + + override fun doGetResult(rs: ResultSet, index: Int): ULong? { + return rs.getLong(index).toULong() + } + }) + } + + interface TestUnsigned : Entity { + companion object : Entity.Factory() + var id: ULong + } + + @Test + fun testUnsigned() { + val t = object : Table("T_TEST_UNSIGNED") { + val id = ulong("ID").primaryKey().bindTo { it.id } + } + + database.useConnection { conn -> + conn.createStatement().use { statement -> + val sql = """CREATE TABLE T_TEST_UNSIGNED(ID BIGINT UNSIGNED NOT NULL PRIMARY KEY)""" + statement.executeUpdate(sql) + } + } + + val unsigned = TestUnsigned { id = 5UL } + assert(unsigned.id == 5UL) + database.sequenceOf(t).add(unsigned) + + val ids = database.sequenceOf(t).toList().map { it.id } + println(ids) + assert(ids == listOf(5UL)) + + database.insert(t) { + set(it.id, 6UL) + } + + val ids2 = database.from(t).select(t.id).map { row -> row[t.id] } + println(ids2) + assert(ids2 == listOf(5UL, 6UL)) + + assert(TestUnsigned().id == 0UL) + } + + interface TestUnsignedNullable : Entity { + companion object : Entity.Factory() + var id: ULong? 
+ } + + @Test + fun testUnsignedNullable() { + val t = object : Table("T_TEST_UNSIGNED_NULLABLE") { + val id = ulong("ID").primaryKey().bindTo { it.id } + } + + database.useConnection { conn -> + conn.createStatement().use { statement -> + val sql = """CREATE TABLE T_TEST_UNSIGNED_NULLABLE(ID BIGINT UNSIGNED NOT NULL PRIMARY KEY)""" + statement.executeUpdate(sql) + } + } + + val unsigned = TestUnsignedNullable { id = 5UL } + assert(unsigned.id == 5UL) + database.sequenceOf(t).add(unsigned) + + val ids = database.sequenceOf(t).toList().map { it.id } + println(ids) + assert(ids == listOf(5UL)) + + assert(TestUnsignedNullable().id == null) + } + + @Test + fun testDefaultValueReferenceEquality() { + assert(Boolean::class.javaPrimitiveType!!.defaultValue === Boolean::class.javaPrimitiveType!!.defaultValue) + assert(Char::class.javaPrimitiveType!!.defaultValue === Char::class.javaPrimitiveType!!.defaultValue) + assert(Byte::class.javaPrimitiveType!!.defaultValue === Byte::class.javaPrimitiveType!!.defaultValue) + assert(Short::class.javaPrimitiveType!!.defaultValue === Short::class.javaPrimitiveType!!.defaultValue) + assert(Int::class.javaPrimitiveType!!.defaultValue === Int::class.javaPrimitiveType!!.defaultValue) + assert(Long::class.javaPrimitiveType!!.defaultValue === Long::class.javaPrimitiveType!!.defaultValue) + assert(Float::class.javaPrimitiveType!!.defaultValue !== Float::class.javaPrimitiveType!!.defaultValue) + assert(Double::class.javaPrimitiveType!!.defaultValue !== Double::class.javaPrimitiveType!!.defaultValue) + assert(String::class.java.defaultValue === String::class.java.defaultValue) + assert(UByte::class.java.defaultValue !== UByte::class.java.defaultValue) + assert(UShort::class.java.defaultValue !== UShort::class.java.defaultValue) + assert(UInt::class.java.defaultValue !== UInt::class.java.defaultValue) + assert(ULong::class.java.defaultValue !== ULong::class.java.defaultValue) + assert(UByteArray::class.java.defaultValue !== UByteArray::class.java.defaultValue) + assert(UShortArray::class.java.defaultValue !== UShortArray::class.java.defaultValue) + assert(UIntArray::class.java.defaultValue !== UIntArray::class.java.defaultValue) + assert(ULongArray::class.java.defaultValue !== ULongArray::class.java.defaultValue) + }*/ +} \ No newline at end of file diff --git a/ktorm-r2dbc-core/src/test/resources/drop-data.sql b/ktorm-r2dbc-core/src/test/resources/drop-data.sql new file mode 100644 index 0000000..dcfef45 --- /dev/null +++ b/ktorm-r2dbc-core/src/test/resources/drop-data.sql @@ -0,0 +1,6 @@ + +DROP TABLE IF EXISTS "T_DEPARTMENT"; +DROP TABLE IF EXISTS "T_EMPLOYEE"; +DROP TABLE IF EXISTS "T_EMPLOYEE0"; +DROP TABLE IF EXISTS "COMPANY"."T_CUSTOMER"; +DROP SCHEMA IF EXISTS "COMPANY"; diff --git a/ktorm-r2dbc-core/src/test/resources/init-data.sql b/ktorm-r2dbc-core/src/test/resources/init-data.sql new file mode 100644 index 0000000..9d1cc8f --- /dev/null +++ b/ktorm-r2dbc-core/src/test/resources/init-data.sql @@ -0,0 +1,60 @@ + +CREATE TABLE "T_DEPARTMENT"( + "ID" INT NOT NULL PRIMARY KEY AUTO_INCREMENT, + "NAME" VARCHAR(128) NOT NULL, + "LOCATION" VARCHAR(128) NOT NULL, + "MIXEDCASE" VARCHAR(128) +); + +CREATE TABLE "T_EMPLOYEE"( + "ID" INT NOT NULL PRIMARY KEY AUTO_INCREMENT, + "NAME" VARCHAR(128) NOT NULL, + "JOB" VARCHAR(128) NOT NULL, + "MANAGER_ID" INT NULL, + "HIRE_DATE" DATE NOT NULL, + "SALARY" BIGINT NOT NULL, + "DEPARTMENT_ID" INT NOT NULL +); + +CREATE SCHEMA "COMPANY"; +CREATE TABLE "COMPANY"."T_CUSTOMER" ( + "ID" INT NOT NULL PRIMARY KEY AUTO_INCREMENT, + "NAME" 
VARCHAR(128) NOT NULL, + "EMAIL" VARCHAR(128) NOT NULL, + "PHONE_NUMBER" VARCHAR(128) NOT NULL +); + +INSERT INTO "T_DEPARTMENT"("NAME", "LOCATION") VALUES ('TECH', 'GUANGZHOU'); +INSERT INTO "T_DEPARTMENT"("NAME", "LOCATION") VALUES ('FINANCE', 'BEIJING'); + +INSERT INTO "T_EMPLOYEE"("NAME", "JOB", "MANAGER_ID", "HIRE_DATE", "SALARY", "DEPARTMENT_ID") + VALUES ('VINCE', 'ENGINEER', NULL, '2018-01-01', 100, 1); +INSERT INTO "T_EMPLOYEE"("NAME", "JOB", "MANAGER_ID", "HIRE_DATE", "SALARY", "DEPARTMENT_ID") + VALUES ('MARRY', 'TRAINEE', 1, '2019-01-01', 50, 1); + +INSERT INTO "T_EMPLOYEE"("NAME", "JOB", "MANAGER_ID", "HIRE_DATE", "SALARY", "DEPARTMENT_ID") + VALUES ('TOM', 'DIRECTOR', NULL, '2018-01-01', 200, 2); +INSERT INTO "T_EMPLOYEE"("NAME", "JOB", "MANAGER_ID", "HIRE_DATE", "SALARY", "DEPARTMENT_ID") + VALUES ('PENNY', 'ASSISTANT', 3, '2019-01-01', 100, 2); + + + +CREATE TABLE "T_EMPLOYEE0"( + "ID" INT NOT NULL PRIMARY KEY AUTO_INCREMENT, + "NAME" VARCHAR(128) NOT NULL, + "JOB" VARCHAR(128) NOT NULL, + "MANAGER_ID" INT NULL, + "HIRE_DATE" DATE NOT NULL, + "SALARY" BIGINT NOT NULL, + "DEPARTMENT_ID" INT NOT NULL +); + +INSERT INTO "T_EMPLOYEE0"("NAME", "JOB", "MANAGER_ID", "HIRE_DATE", "SALARY", "DEPARTMENT_ID") + VALUES ('VINCE', 'ENGINEER', NULL, '2018-01-01', 100, 1); +INSERT INTO "T_EMPLOYEE0"("NAME", "JOB", "MANAGER_ID", "HIRE_DATE", "SALARY", "DEPARTMENT_ID") + VALUES ('MARRY', 'TRAINEE', 1, '2019-01-01', 50, 1); + +INSERT INTO "T_EMPLOYEE0"("NAME", "JOB", "MANAGER_ID", "HIRE_DATE", "SALARY", "DEPARTMENT_ID") + VALUES ('TOM', 'DIRECTOR', NULL, '2018-01-01', 200, 2); +INSERT INTO "T_EMPLOYEE0"("NAME", "JOB", "MANAGER_ID", "HIRE_DATE", "SALARY", "DEPARTMENT_ID") + VALUES ('PENNY', 'ASSISTANT', 3, '2019-01-01', 100, 2); \ No newline at end of file From 37eb21e0860584a2e98ce03a9a8fba82d651dcb0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=AB=98=E9=94=90=E9=9B=84?= <641571835@qq.com> Date: Fri, 18 Feb 2022 09:42:07 +0800 Subject: [PATCH 02/17] update database metadata --- .../org/ktorm/r2dbc/database/Database.kt | 96 ++++++++++++++++++- .../ktorm/r2dbc/expression/SqlFormatter.kt | 39 ++++---- 2 files changed, 112 insertions(+), 23 deletions(-) diff --git a/ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/database/Database.kt b/ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/database/Database.kt index cfebf11..64ea007 100644 --- a/ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/database/Database.kt +++ b/ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/database/Database.kt @@ -31,7 +31,6 @@ public class Database( * The name of the connected database product, eg. MySQL, H2. */ public val productName: String - /** * The version of the connected database product. */ @@ -42,12 +41,92 @@ public class Database( */ public val keywords: Set + /** + * The string used to quote SQL identifiers, returns an empty string if identifier quoting is not supported. + */ + public val identifierQuoteString: String + + /** + * All the "extra" characters that can be used in unquoted identifier names (those beyond a-z, A-Z, 0-9 and _). + */ + public val extraNameCharacters: String + + /** + * Whether this database treats mixed case unquoted SQL identifiers as case sensitive and as a result + * stores them in mixed case. + * + * @since 3.1.0 + */ + public val supportsMixedCaseIdentifiers: Boolean + + /** + * Whether this database treats mixed case unquoted SQL identifiers as case insensitive and + * stores them in mixed case. 
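These metadata properties are captured once from the dialect when the `Database` instance is constructed (see the `init` block below) and are what `SqlFormatter` consults for keyword casing and identifier quoting. A trivial sketch of reading them:

```kotlin
import org.ktorm.r2dbc.database.Database

// Sketch only: prints a few of the metadata values exposed by this patch.
fun describeIdentifierHandling(database: Database) {
    println("identifier quote string: '${database.identifierQuoteString}'")
    println("extra name characters: '${database.extraNameCharacters}'")
    println("supports mixed-case identifiers: ${database.supportsMixedCaseIdentifiers}")
    println("max column name length: ${database.maxColumnNameLength}")
}
```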
+ * + * @since 3.1.0 + */ + public val storesMixedCaseIdentifiers: Boolean + + /** + * Whether this database treats mixed case unquoted SQL identifiers as case insensitive and + * stores them in upper case. + * + * @since 3.1.0 + */ + public val storesUpperCaseIdentifiers: Boolean + + /** + * Whether this database treats mixed case unquoted SQL identifiers as case insensitive and + * stores them in lower case. + * + * @since 3.1.0 + */ + public val storesLowerCaseIdentifiers: Boolean + + /** + * Whether this database treats mixed case quoted SQL identifiers as case sensitive and as a result + * stores them in mixed case. + * + * @since 3.1.0 + */ + public val supportsMixedCaseQuotedIdentifiers: Boolean + + /** + * Whether this database treats mixed case quoted SQL identifiers as case insensitive and + * stores them in mixed case. + * + * @since 3.1.0 + */ + public val storesMixedCaseQuotedIdentifiers: Boolean + + /** + * Whether this database treats mixed case quoted SQL identifiers as case insensitive and + * stores them in upper case. + * + * @since 3.1.0 + */ + public val storesUpperCaseQuotedIdentifiers: Boolean + + /** + * Whether this database treats mixed case quoted SQL identifiers as case insensitive and + * stores them in lower case. + * + * @since 3.1.0 + */ + public val storesLowerCaseQuotedIdentifiers: Boolean + + /** + * The maximum number of characters this database allows for a column name. Zero means that there is no limit + * or the limit is not known. + * + * @since 3.1.0 + */ + public val maxColumnNameLength: Int init { fun kotlin.Result.orEmpty() = getOrNull().orEmpty() - fun kotlin.Result.orFalse() = getOrDefault(false) runBlocking { useConnection { conn -> @@ -55,6 +134,17 @@ public class Database( productName = metadata.runCatching { databaseProductName }.orEmpty() productVersion = metadata.runCatching { databaseVersion }.orEmpty() keywords = ANSI_SQL_2003_KEYWORDS + dialect.sqlKeywords + identifierQuoteString = dialect.identifierQuoteString + extraNameCharacters = dialect.extraNameCharacters + supportsMixedCaseIdentifiers = dialect.supportsMixedCaseIdentifiers + storesMixedCaseIdentifiers = dialect.storesMixedCaseIdentifiers + storesUpperCaseIdentifiers = dialect.storesUpperCaseIdentifiers + storesLowerCaseIdentifiers = dialect.storesLowerCaseIdentifiers + supportsMixedCaseQuotedIdentifiers = dialect.supportsMixedCaseQuotedIdentifiers + storesMixedCaseQuotedIdentifiers = dialect.storesMixedCaseQuotedIdentifiers + storesUpperCaseQuotedIdentifiers = dialect.storesUpperCaseQuotedIdentifiers + storesLowerCaseQuotedIdentifiers = dialect.storesLowerCaseQuotedIdentifiers + maxColumnNameLength = dialect.maxColumnNameLength } if (logger.isInfoEnabled()) { @@ -248,4 +338,4 @@ public class Database( ) } } -} \ No newline at end of file +} diff --git a/ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/expression/SqlFormatter.kt b/ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/expression/SqlFormatter.kt index 2c767df..b0400c7 100644 --- a/ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/expression/SqlFormatter.kt +++ b/ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/expression/SqlFormatter.kt @@ -82,7 +82,7 @@ public abstract class SqlFormatter( _builder.append(keyword.lowercase()) } null -> { - if (database.dialect.supportsMixedCaseIdentifiers || !database.dialect.storesLowerCaseIdentifiers) { + if (database.supportsMixedCaseIdentifiers || !database.storesLowerCaseIdentifiers) { _builder.append(keyword.uppercase()) } else { _builder.append(keyword.lowercase()) @@ -104,8 
+104,8 @@ public abstract class SqlFormatter( return true } if (identifier.isMixedCase - && !database.dialect.supportsMixedCaseIdentifiers - && database.dialect.supportsMixedCaseQuotedIdentifiers + && !database.supportsMixedCaseIdentifiers + && database.supportsMixedCaseQuotedIdentifiers ) { return true } @@ -113,37 +113,36 @@ public abstract class SqlFormatter( } protected val String.quoted: String get() { - val dialect = database.dialect if (shouldQuote(this)) { - if (dialect.supportsMixedCaseQuotedIdentifiers) { - return "${dialect.identifierQuoteString}${this}${dialect.identifierQuoteString}" + if (database.supportsMixedCaseQuotedIdentifiers) { + return "${database.identifierQuoteString}${this}${database.identifierQuoteString}" } else { - if (dialect.storesUpperCaseQuotedIdentifiers) { - return "${dialect.identifierQuoteString}${this.uppercase()}${dialect.identifierQuoteString}" + if (database.storesUpperCaseQuotedIdentifiers) { + return "${database.identifierQuoteString}${this.uppercase()}${database.identifierQuoteString}" } - if (dialect.storesLowerCaseQuotedIdentifiers) { - return "${dialect.identifierQuoteString}${this.lowercase()}${dialect.identifierQuoteString}" + if (database.storesLowerCaseQuotedIdentifiers) { + return "${database.identifierQuoteString}${this.lowercase()}${database.identifierQuoteString}" } - if (dialect.storesMixedCaseQuotedIdentifiers) { - return "${dialect.identifierQuoteString}${this}${dialect.identifierQuoteString}" + if (database.storesMixedCaseQuotedIdentifiers) { + return "${database.identifierQuoteString}${this}${database.identifierQuoteString}" } - // Should never happen, but it's still needed as some dialect drivers are not implemented correctly. - return "${dialect.identifierQuoteString}${this}${dialect.identifierQuoteString}" + // Should never happen, but it's still needed as some database drivers are not implemented correctly. + return "${database.identifierQuoteString}${this}${database.identifierQuoteString}" } } else { - if (dialect.supportsMixedCaseIdentifiers) { + if (database.supportsMixedCaseIdentifiers) { return this } else { - if (dialect.storesUpperCaseIdentifiers) { + if (database.storesUpperCaseIdentifiers) { return this.uppercase() } - if (dialect.storesLowerCaseIdentifiers) { + if (database.storesLowerCaseIdentifiers) { return this.lowercase() } - if (dialect.storesMixedCaseIdentifiers) { + if (database.storesMixedCaseIdentifiers) { return this } - // Should never happen, but it's still needed as some dialect drivers are not implemented correctly. + // Should never happen, but it's still needed as some database drivers are not implemented correctly. 
return this } } @@ -171,7 +170,7 @@ public abstract class SqlFormatter( if (this == '_') { return true } - if (this in database.dialect.extraNameCharacters) { + if (this in database.extraNameCharacters) { return true } return false From 962f16af19abe8b28bd785274a77e6504d7f6786 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=AB=98=E9=94=90=E9=9B=84?= <641571835@qq.com> Date: Fri, 18 Feb 2022 10:47:03 +0800 Subject: [PATCH 03/17] update coroutine transaction --- .../database/CoroutinesTransactionManager.kt | 31 ++++++++---- .../org/ktorm/r2dbc/database/Database.kt | 47 +++++++++++-------- .../r2dbc/database/TransactionManager.kt | 12 ++--- .../kotlin/org/ktorm/database/DatabaseTest.kt | 12 +++-- 4 files changed, 60 insertions(+), 42 deletions(-) diff --git a/ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/database/CoroutinesTransactionManager.kt b/ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/database/CoroutinesTransactionManager.kt index 915e10f..3c929b0 100644 --- a/ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/database/CoroutinesTransactionManager.kt +++ b/ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/database/CoroutinesTransactionManager.kt @@ -5,6 +5,10 @@ import io.r2dbc.spi.ConnectionFactory import io.r2dbc.spi.IsolationLevel import kotlinx.coroutines.reactive.awaitFirstOrNull import kotlinx.coroutines.reactive.awaitSingle +import kotlinx.coroutines.withContext +import kotlin.coroutines.AbstractCoroutineContextElement +import kotlin.coroutines.CoroutineContext +import kotlin.coroutines.coroutineContext /** * Created by vince on Jan 30, 2021. @@ -12,29 +16,38 @@ import kotlinx.coroutines.reactive.awaitSingle public class CoroutinesTransactionManager( public val connectionFactory: ConnectionFactory ) : TransactionManager { - private val currentTransaction = CoroutineLocal() override val defaultIsolation: IsolationLevel? = null override suspend fun getCurrentTransaction(): Transaction? { - return currentTransaction.get() + return coroutineContext[TransactionKey] } - override suspend fun newTransaction(isolation: IsolationLevel?): Transaction { - if (currentTransaction.get() != null) { + public override suspend fun useTransaction( + isolation: IsolationLevel?, + func: suspend (Transaction) -> T + ): T { + val currentTransaction = coroutineContext[TransactionKey] + if (currentTransaction != null) { throw IllegalStateException("There is already a transaction in the current context.") } - val transaction = TransactionImpl(connectionFactory.create().awaitSingle(), isolation) - currentTransaction.set(transaction) transaction.begin() - return transaction + return withContext(coroutineContext + transaction) { + func(transaction) + } } + /** + * Key of [TransactionImpl] in [CoroutineContext]. + */ + public companion object TransactionKey : CoroutineContext.Key + private inner class TransactionImpl( override val connection: Connection, private val desiredIsolation: IsolationLevel? 
- ) : Transaction { + ) : Transaction, AbstractCoroutineContextElement(TransactionKey) { + private val originIsolation = connection.transactionIsolationLevel private val originAutoCommit = connection.isAutoCommit @@ -75,8 +88,6 @@ public class CoroutinesTransactionManager( try { connection.close().awaitFirstOrNull() } catch (_: Throwable) { - } finally { - currentTransaction.remove() } } } diff --git a/ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/database/Database.kt b/ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/database/Database.kt index 64ea007..435c3e9 100644 --- a/ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/database/Database.kt +++ b/ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/database/Database.kt @@ -1,6 +1,7 @@ package org.ktorm.r2dbc.database import io.r2dbc.spi.* +import kotlinx.coroutines.GlobalScope import kotlinx.coroutines.reactive.awaitFirst import kotlinx.coroutines.reactive.awaitFirstOrNull import kotlinx.coroutines.reactive.awaitSingle @@ -31,6 +32,7 @@ public class Database( * The name of the connected database product, eg. MySQL, H2. */ public val productName: String + /** * The version of the connected database product. */ @@ -124,7 +126,6 @@ public class Database( public val maxColumnNameLength: Int - init { fun kotlin.Result.orEmpty() = getOrNull().orEmpty() @@ -175,30 +176,35 @@ public class Database( } @OptIn(ExperimentalContracts::class) - public suspend inline fun useTransaction(isolation: IsolationLevel? = null, func: (Transaction) -> T): T { + public suspend fun useTransaction( + isolation: IsolationLevel? = null, + func: suspend (Transaction) -> T + ): T { contract { callsInPlace(func, InvocationKind.EXACTLY_ONCE) } val current = transactionManager.getCurrentTransaction() - val isOuter = current == null - val transaction = current ?: transactionManager.newTransaction(isolation) - var throwable: Throwable? = null - try { - return func(transaction) - } catch (e: R2dbcException) { - throwable = exceptionTranslator?.invoke(e) ?: e - throw throwable - } catch (e: Throwable) { - throwable = e - throw throwable - } finally { - if (isOuter) { + if (current != null) { + return func(current) + } else { + return transactionManager.useTransaction(isolation) { + var throwable: Throwable? = null try { - if (throwable == null) transaction.commit() else transaction.rollback() + func(it) + } catch (e: R2dbcException) { + throwable = exceptionTranslator?.invoke(e) ?: e + throw throwable + } catch (e: Throwable) { + throwable = e + throw throwable } finally { - transaction.close() + try { + if (throwable == null) it.commit() else it.rollback() + } finally { + it.close() + } } } } @@ -263,6 +269,7 @@ public class Database( return effects } } + /** * Batch execute the given SQL expressions and return the effected row counts for each expression. 
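
For context, a batch like the one documented here is normally built on a single R2DBC Statement: each argument list is bound with zero-based indexes, entries are separated with Statement.add(), and one update count is read per emitted Result. The following stand-alone sketch shows that pattern against the raw SPI; the function name and the List-of-arguments shape are illustrative assumptions, not part of the patch.

    import io.r2dbc.spi.Connection
    import kotlinx.coroutines.flow.map
    import kotlinx.coroutines.flow.toList
    import kotlinx.coroutines.reactive.asFlow
    import kotlinx.coroutines.reactive.awaitFirst

    // Bind several argument lists to one statement and collect the per-entry row counts.
    suspend fun executeBatchSketch(connection: Connection, sql: String, batches: List<List<Any>>): IntArray {
        val statement = connection.createStatement(sql)
        batches.forEachIndexed { batchIndex, args ->
            args.forEachIndexed { i, arg -> statement.bind(i, arg) }
            if (batchIndex < batches.lastIndex) statement.add() // add() seals the current binding
        }
        return statement.execute().asFlow()
            .map { result -> result.rowsUpdated.awaitFirst() }
            .toList()
            .toIntArray()
    }
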
* @@ -300,9 +307,9 @@ public class Database( val results = statement.execute().toList() - /* if (logaddBatchger.isDebugEnabled()) { - logger.debug("Effects: ${results?.contentToString()}") - }*/ + /* if (logaddBatchger.isDebugEnabled()) { + logger.debug("Effects: ${results?.contentToString()}") + }*/ return results.map { result -> result.rowsUpdated.awaitFirst() }.toIntArray() } diff --git a/ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/database/TransactionManager.kt b/ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/database/TransactionManager.kt index 98bc360..54ea24f 100644 --- a/ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/database/TransactionManager.kt +++ b/ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/database/TransactionManager.kt @@ -38,14 +38,10 @@ public interface TransactionManager { */ public suspend fun getCurrentTransaction(): Transaction? - /** - * Open a new transaction for the current thread using the specific isolation if there is no transaction opened. - * - * @param isolation the transaction isolation, by default, [defaultIsolation] is used. - * @return the new-created transaction. - * @throws [IllegalStateException] if there is already a transaction opened. - */ - public suspend fun newTransaction(isolation: IsolationLevel? = defaultIsolation): Transaction + public suspend fun useTransaction( + isolation: IsolationLevel? = defaultIsolation, + func: suspend (Transaction) -> T + ): T } /** diff --git a/ktorm-r2dbc-core/src/test/kotlin/org/ktorm/database/DatabaseTest.kt b/ktorm-r2dbc-core/src/test/kotlin/org/ktorm/database/DatabaseTest.kt index 7ec2ca9..a37c0bb 100644 --- a/ktorm-r2dbc-core/src/test/kotlin/org/ktorm/database/DatabaseTest.kt +++ b/ktorm-r2dbc-core/src/test/kotlin/org/ktorm/database/DatabaseTest.kt @@ -83,9 +83,13 @@ class DatabaseTest : BaseTest() { @Test fun tableTest() = runBlocking { - val employees = database.sequenceOf(Employees) - employees.forEach { - println(it) + database.useTransaction { + database.useTransaction { + for (employee in database.employees) { + println(it) + } + throw RuntimeException() + } } assert(true) } @@ -188,4 +192,4 @@ class DatabaseTest : BaseTest() { assert(UIntArray::class.java.defaultValue !== UIntArray::class.java.defaultValue) assert(ULongArray::class.java.defaultValue !== ULongArray::class.java.defaultValue) }*/ -} \ No newline at end of file +} From 91432501b1c593d7ddc31e39c4ee655c0de35659 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=AB=98=E9=94=90=E9=9B=84?= <641571835@qq.com> Date: Fri, 18 Feb 2022 10:48:53 +0800 Subject: [PATCH 04/17] delete coroutineLocal --- .../ktorm/r2dbc/database/CoroutineLocal.kt | 50 ------------------- 1 file changed, 50 deletions(-) delete mode 100644 ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/database/CoroutineLocal.kt diff --git a/ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/database/CoroutineLocal.kt b/ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/database/CoroutineLocal.kt deleted file mode 100644 index 7ad53b8..0000000 --- a/ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/database/CoroutineLocal.kt +++ /dev/null @@ -1,50 +0,0 @@ -package org.ktorm.r2dbc.database - -import kotlinx.coroutines.DisposableHandle -import kotlinx.coroutines.Job -import kotlinx.coroutines.sync.Mutex -import kotlinx.coroutines.sync.withLock -import java.util.* -import kotlin.coroutines.coroutineContext - -/** - * Created by vince on Jan 31, 2021. 
- */ -public class CoroutineLocal { - private val mutex = Mutex() - private val locals = IdentityHashMap>() - - private class Entry(var value: T, val disposable: DisposableHandle) { - protected fun finalize() { - disposable.dispose() - } - } - - public suspend fun get(): T? { - mutex.withLock { - val job = coroutineContext[Job] ?: error("Coroutine Job doesn't exist in the current context.") - return locals[job]?.value - } - } - - public suspend fun set(value: T) { - mutex.withLock { - val job = coroutineContext[Job] ?: error("Coroutine Job doesn't exist in the current context.") - - val entry = locals[job] - if (entry != null) { - entry.value = value - } else { - locals[job] = Entry(value, job.invokeOnCompletion { locals.remove(job) }) - } - } - } - - public suspend fun remove() { - mutex.withLock { - val job = coroutineContext[Job] ?: error("Coroutine Job doesn't exist in the current context.") - val existing = locals.remove(job) - existing?.disposable?.dispose() - } - } -} From 4581a0aff51a43502f5edfa9814d3eb373293681 Mon Sep 17 00:00:00 2001 From: htt <641571835@qq.com> Date: Fri, 18 Feb 2022 13:20:29 +0800 Subject: [PATCH 05/17] add kotlin flow support (undone) --- .../org/ktorm/r2dbc/database/Database.kt | 13 +-- .../main/kotlin/org/ktorm/r2dbc/dsl/Query.kt | 105 ++++++++---------- .../org/ktorm/r2dbc/entity/EntitySequence.kt | 25 ++--- 3 files changed, 64 insertions(+), 79 deletions(-) diff --git a/ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/database/Database.kt b/ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/database/Database.kt index 435c3e9..466c38e 100644 --- a/ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/database/Database.kt +++ b/ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/database/Database.kt @@ -1,7 +1,8 @@ package org.ktorm.r2dbc.database import io.r2dbc.spi.* -import kotlinx.coroutines.GlobalScope +import kotlinx.coroutines.flow.Flow +import kotlinx.coroutines.reactive.asFlow import kotlinx.coroutines.reactive.awaitFirst import kotlinx.coroutines.reactive.awaitFirstOrNull import kotlinx.coroutines.reactive.awaitSingle @@ -246,15 +247,9 @@ public class Database( } } - public suspend fun executeQuery(expression: SqlExpression): List { + public suspend fun executeQuery(expression: SqlExpression): Flow { executeExpression(expression) { result -> - val rows = result.map { row, _ -> row }.toList() - - if (logger.isDebugEnabled()) { - logger.debug("Results: ${rows.size}") - } - - return rows + return result.map { row, _ -> row }.asFlow() } } diff --git a/ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/dsl/Query.kt b/ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/dsl/Query.kt index d1d414a..a917bcb 100644 --- a/ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/dsl/Query.kt +++ b/ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/dsl/Query.kt @@ -16,11 +16,13 @@ package org.ktorm.r2dbc.dsl +import kotlinx.coroutines.flow.* import org.ktorm.r2dbc.database.Database import org.ktorm.r2dbc.expression.* import org.ktorm.r2dbc.schema.BooleanSqlType import org.ktorm.r2dbc.schema.Column import org.ktorm.r2dbc.schema.ColumnDeclaring +import org.ktorm.r2dbc.schema.LongSqlType import java.sql.ResultSet /** @@ -77,7 +79,7 @@ public class Query(public val database: Database, public val expression: QueryEx database.formatExpression(expression, beautifySql = true).first } - public suspend fun doQuery(expression: QueryExpression = this.expression): List { + public suspend fun doQuery(expression: QueryExpression = this.expression): Flow { return database.executeQuery(expression).map 
{ QueryRow(this@Query, it) } } @@ -100,13 +102,13 @@ public class Query(public val database: Database, public val expression: QueryEx * it does, return the total record count of the query ignoring the offset and limit parameters. This property * is provided to support pagination, we can calculate the page count through dividing it by our page size. */ - public suspend fun totalRecords(): Int { + public suspend fun totalRecords(): Long { return if (expression.offset == null && expression.limit == null) { - doQuery().size + doQuery().count().toLong() } else { val countExpr = expression.toCountExpression() val count = doQuery(countExpr) - .map { it.get(0, Int::class.java) } + .map { LongSqlType.getResult(it, it.metadata, 0) } .firstOrNull() val (sql, _) = database.formatExpression(countExpr, beautifySql = true) count ?: throw IllegalStateException("No result return for sql: $sql") @@ -121,18 +123,6 @@ public class Query(public val database: Database, public val expression: QueryEx return Query(database, expression) } - /** - * Return an iterator over the rows of this query. - * - * Note that this function is simply implemented as `rowSet.iterator()`, so every element returned by the iterator - * exactly shares the same instance as the [rowSet] property. - * - * @see rowSet - * @see ResultSet.iterator - */ - public suspend operator fun iterator(): Iterator { - return doQuery().iterator() - } } /** @@ -419,8 +409,7 @@ public fun Query.unionAll(right: Query): Query { * @since 3.0.0 */ public suspend fun Query.asIterable(): Iterable { - val iterator = iterator() - return Iterable { iterator } + return doQuery().toList() } /** @@ -428,8 +417,8 @@ public suspend fun Query.asIterable(): Iterable { * * @since 3.0.0 */ -public suspend inline fun Query.forEach(action: (row: QueryRow) -> Unit) { - for (row in this) action(row) +public suspend inline fun Query.forEach(crossinline action: (row: QueryRow) -> Unit) { + doQuery().collect { action(it) } } /** @@ -439,9 +428,9 @@ public suspend inline fun Query.forEach(action: (row: QueryRow) -> Unit) { * * @since 3.0.0 */ -public suspend inline fun Query.forEachIndexed(action: (index: Int, row: QueryRow) -> Unit) { +public suspend inline fun Query.forEachIndexed(crossinline action: (index: Int, row: QueryRow) -> Unit) { var index = 0 - for (row in this) action(index++, row) + doQuery().collect { action(index++, it) } } /** @@ -452,7 +441,7 @@ public suspend inline fun Query.forEachIndexed(action: (index: Int, row: QueryRo */ public suspend fun Query.withIndex(): Iterable> { - val iterator = IndexingIterator(iterator()) + val iterator = IndexingIterator(doQuery().toList().iterator()) return Iterable { iterator } } @@ -479,7 +468,7 @@ internal class IndexingIterator(private val iterator: Iterator) : Iter * @since 3.0.0 */ -public suspend inline fun Query.map(transform: (row: QueryRow) -> R): List { +public suspend inline fun Query.map(crossinline transform: (row: QueryRow) -> R): List { return mapTo(ArrayList(), transform) } @@ -491,9 +480,10 @@ public suspend inline fun Query.map(transform: (row: QueryRow) -> R): List> Query.mapTo( destination: C, - transform: (row: QueryRow) -> R + crossinline transform: (row: QueryRow) -> R ): C { - for (row in this) destination += transform(row) + + doQuery().collect { destination += transform(it) } return destination } @@ -504,7 +494,7 @@ public suspend inline fun > Query.mapTo( * @since 3.0.0 */ -public suspend inline fun Query.mapNotNull(transform: (row: QueryRow) -> R?): List { +public suspend inline fun 
Query.mapNotNull(crossinline transform: (row: QueryRow) -> R?): List { return mapNotNullTo(ArrayList(), transform) } @@ -517,9 +507,11 @@ public suspend inline fun Query.mapNotNull(transform: (row: QueryRow) public suspend inline fun > Query.mapNotNullTo( destination: C, - transform: (row: QueryRow) -> R? + crossinline transform: (row: QueryRow) -> R? ): C { - forEach { row -> transform(row)?.let { destination += it } } + doQuery().collect { row -> + transform(row)?.let { destination += it } + } return destination } @@ -532,7 +524,7 @@ public suspend inline fun > Query.mapNotNul * @since 3.0.0 */ -public suspend inline fun Query.mapIndexed(transform: (index: Int, row: QueryRow) -> R): List { +public suspend inline fun Query.mapIndexed(crossinline transform: (index: Int, row: QueryRow) -> R): List { return mapIndexedTo(ArrayList(), transform) } @@ -547,7 +539,7 @@ public suspend inline fun Query.mapIndexed(transform: (index: Int, row: Quer public suspend inline fun > Query.mapIndexedTo( destination: C, - transform: (index: Int, row: QueryRow) -> R + crossinline transform: (index: Int, row: QueryRow) -> R ): C { var index = 0 return mapTo(destination) { row -> transform(index++, row) } @@ -563,7 +555,7 @@ public suspend inline fun > Query.mapIndexedTo( * @since 3.0.0 */ -public suspend inline fun Query.mapIndexedNotNull(transform: (index: Int, row: QueryRow) -> R?): List { +public suspend inline fun Query.mapIndexedNotNull(crossinline transform: (index: Int, row: QueryRow) -> R?): List { return mapIndexedNotNullTo(ArrayList(), transform) } @@ -579,7 +571,7 @@ public suspend inline fun Query.mapIndexedNotNull(transform: (index: I public suspend inline fun > Query.mapIndexedNotNullTo( destination: C, - transform: (index: Int, row: QueryRow) -> R? + crossinline transform: (index: Int, row: QueryRow) -> R? 
): C { forEachIndexed { index, row -> transform(index, row)?.let { destination += it } } return destination @@ -592,7 +584,7 @@ public suspend inline fun > Query.mapIndexe * @since 3.0.0 */ -public suspend inline fun Query.flatMap(transform: (row: QueryRow) -> Iterable): List { +public suspend inline fun Query.flatMap(crossinline transform: (row: QueryRow) -> Iterable): List { return flatMapTo(ArrayList(), transform) } @@ -605,9 +597,9 @@ public suspend inline fun Query.flatMap(transform: (row: QueryRow) -> Iterab public suspend inline fun > Query.flatMapTo( destination: C, - transform: (row: QueryRow) -> Iterable + crossinline transform: (row: QueryRow) -> Iterable ): C { - for (row in this) destination += transform(row) + doQuery().collect { destination += transform(it) } return destination } @@ -618,7 +610,7 @@ public suspend inline fun > Query.flatMapTo( * @since 3.1.0 */ -public suspend inline fun Query.flatMapIndexed(transform: (index: Int, row: QueryRow) -> Iterable): List { +public suspend inline fun Query.flatMapIndexed(crossinline transform: (index: Int, row: QueryRow) -> Iterable): List { return flatMapIndexedTo(ArrayList(), transform) } @@ -631,7 +623,7 @@ public suspend inline fun Query.flatMapIndexed(transform: (index: Int, row: public suspend inline fun > Query.flatMapIndexedTo( destination: C, - transform: (index: Int, row: QueryRow) -> Iterable + crossinline transform: (index: Int, row: QueryRow) -> Iterable ): C { var index = 0 return flatMapTo(destination) { transform(index++, it) } @@ -647,7 +639,7 @@ public suspend inline fun > Query.flatMapIndexedT * @since 3.0.0 */ -public suspend inline fun Query.associate(transform: (row: QueryRow) -> Pair): Map { +public suspend inline fun Query.associate(crossinline transform: (row: QueryRow) -> Pair): Map { return associateTo(LinkedHashMap(), transform) } @@ -663,8 +655,8 @@ public suspend inline fun Query.associate(transform: (row: QueryRow) -> P */ public suspend inline fun Query.associateBy( - keySelector: (row: QueryRow) -> K, - valueTransform: (row: QueryRow) -> V + crossinline keySelector: (row: QueryRow) -> K, + crossinline valueTransform: (row: QueryRow) -> V ): Map { return associateByTo(LinkedHashMap(), keySelector, valueTransform) } @@ -680,9 +672,9 @@ public suspend inline fun Query.associateBy( public suspend inline fun > Query.associateTo( destination: M, - transform: (row: QueryRow) -> Pair + crossinline transform: (row: QueryRow) -> Pair ): M { - for (row in this) destination += transform(row) + doQuery().collect { destination += transform(it) } return destination } @@ -697,10 +689,10 @@ public suspend inline fun > Query.associateTo( public suspend inline fun > Query.associateByTo( destination: M, - keySelector: (row: QueryRow) -> K, - valueTransform: (row: QueryRow) -> V + crossinline keySelector: (row: QueryRow) -> K, + crossinline valueTransform: (row: QueryRow) -> V ): M { - for (row in this) destination.put(keySelector(row), valueTransform(row)) + doQuery().collect { destination.put(keySelector(it), valueTransform(it)) } return destination } @@ -710,9 +702,9 @@ public suspend inline fun > Query.associateByTo * @since 3.0.0 */ -public suspend inline fun Query.fold(initial: R, operation: (acc: R, row: QueryRow) -> R): R { +public suspend inline fun Query.fold(initial: R, crossinline operation: (acc: R, row: QueryRow) -> R): R { var accumulator = initial - for (row in this) accumulator = operation(accumulator, row) + doQuery().collect { accumulator = operation(accumulator, it) } return accumulator } @@ -726,10 
+718,10 @@ public suspend inline fun Query.fold(initial: R, operation: (acc: R, row: Qu * @since 3.0.0 */ -public suspend inline fun Query.foldIndexed(initial: R, operation: (index: Int, acc: R, row: QueryRow) -> R): R { +public suspend inline fun Query.foldIndexed(initial: R, crossinline operation: (index: Int, acc: R, row: QueryRow) -> R): R { var index = 0 var accumulator = initial - for (row in this) accumulator = operation(index++, accumulator, row) + doQuery().collect { accumulator = operation(index++, accumulator, it) } return accumulator } @@ -742,24 +734,23 @@ public suspend inline fun Query.foldIndexed(initial: R, operation: (index: I * @since 3.0.0 */ -public suspend fun Query.joinTo( +public suspend inline fun Query.joinTo( buffer: A, separator: CharSequence = ", ", prefix: CharSequence = "", postfix: CharSequence = "", limit: Int = -1, truncated: CharSequence = "...", - transform: (row: QueryRow) -> CharSequence + crossinline transform: (row: QueryRow) -> CharSequence ): A { buffer.append(prefix) var count = 0 - for (row in this) { + doQuery().collect { if (++count > 1) buffer.append(separator) if (limit < 0 || count <= limit) { - buffer.append(transform(row)) + buffer.append(transform(it)) } else { buffer.append(truncated) - break } } buffer.append(postfix) @@ -775,13 +766,13 @@ public suspend fun Query.joinTo( * @since 3.0.0 */ -public suspend fun Query.joinToString( +public suspend inline fun Query.joinToString( separator: CharSequence = ", ", prefix: CharSequence = "", postfix: CharSequence = "", limit: Int = -1, truncated: CharSequence = "...", - transform: (row: QueryRow) -> CharSequence + crossinline transform: (row: QueryRow) -> CharSequence ): String { return joinTo(StringBuilder(), separator, prefix, postfix, limit, truncated, transform).toString() } diff --git a/ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/entity/EntitySequence.kt b/ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/entity/EntitySequence.kt index 317dd03..73961da 100644 --- a/ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/entity/EntitySequence.kt +++ b/ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/entity/EntitySequence.kt @@ -16,6 +16,10 @@ package org.ktorm.r2dbc.entity +import kotlinx.coroutines.flow.Flow +import kotlinx.coroutines.flow.flow +import kotlinx.coroutines.flow.map +import kotlinx.coroutines.flow.toList import org.ktorm.r2dbc.database.Database import org.ktorm.r2dbc.database.DialectFeatureNotSupportedException import org.ktorm.r2dbc.dsl.* @@ -107,14 +111,14 @@ public class EntitySequence>( * * This property is delegated to [Query.rowSet], more details can be found in its documentation. */ - public suspend fun getRowSet(): List = query.doQuery() + public suspend fun getRowSet(): Flow = query.doQuery() /** * The total records count of this query ignoring the pagination params. * * This property is delegated to [Query.totalRecords], more details can be found in its documentation. */ - public suspend fun totalRecords(): Int = query.totalRecords() + public suspend fun totalRecords(): Long = query.totalRecords() /** * Return a copy of this [EntitySequence] with the [expression] modified. @@ -136,17 +140,12 @@ public class EntitySequence>( * Return an iterator over the elements of this sequence. 
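
At this point the iterator gives way to a Flow: flow() maps the row Flow through the entity extractor, so a sequence is consumed with the ordinary kotlinx.coroutines.flow terminal operators. A minimal sketch, assuming the Employees table object and Employee entity from the test sources:

    import kotlinx.coroutines.flow.collect
    import org.ktorm.r2dbc.database.Database
    import org.ktorm.r2dbc.entity.sequenceOf

    suspend fun printAllEmployees(database: Database) {
        database.sequenceOf(Employees)
            .flow()
            .collect { employee -> println(employee) }
    }
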
*/ @Suppress("IteratorNotThrowingNoSuchElementException") - public suspend operator fun iterator(): Iterator { - val iterator = query.iterator() - return object : Iterator { - override fun hasNext(): Boolean { - return iterator.hasNext() - } - - override fun next(): E { - return entityExtractor(iterator.next()) - } - } + private suspend operator fun iterator(): Iterator { + return flow().toList().iterator() + } + + public suspend fun flow(): Flow { + return getRowSet().map(entityExtractor) } } From 57a29dba516b4df3ad5955959cecf8aa811e5ebd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=AB=98=E9=94=90=E9=9B=84?= <641571835@qq.com> Date: Fri, 18 Feb 2022 14:12:58 +0800 Subject: [PATCH 06/17] add kotlin flow support --- .../kotlin/org/ktorm/r2dbc/database/Utils.kt | 4 + .../org/ktorm/r2dbc/entity/EntitySequence.kt | 180 ++++++++---------- .../kotlin/org/ktorm/database/DatabaseTest.kt | 2 +- 3 files changed, 80 insertions(+), 106 deletions(-) diff --git a/ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/database/Utils.kt b/ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/database/Utils.kt index 38ae593..9ac8678 100644 --- a/ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/database/Utils.kt +++ b/ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/database/Utils.kt @@ -2,7 +2,11 @@ package org.ktorm.r2dbc.database import io.r2dbc.spi.Blob import io.r2dbc.spi.Clob +import kotlinx.coroutines.channels.Channel +import kotlinx.coroutines.flow.Flow +import kotlinx.coroutines.flow.collect import kotlinx.coroutines.reactive.collect +import kotlinx.coroutines.sync.Semaphore import org.reactivestreams.Publisher import org.reactivestreams.Subscriber import org.reactivestreams.Subscription diff --git a/ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/entity/EntitySequence.kt b/ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/entity/EntitySequence.kt index 73961da..75d8e06 100644 --- a/ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/entity/EntitySequence.kt +++ b/ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/entity/EntitySequence.kt @@ -16,10 +16,7 @@ package org.ktorm.r2dbc.entity -import kotlinx.coroutines.flow.Flow -import kotlinx.coroutines.flow.flow -import kotlinx.coroutines.flow.map -import kotlinx.coroutines.flow.toList +import kotlinx.coroutines.flow.* import org.ktorm.r2dbc.database.Database import org.ktorm.r2dbc.database.DialectFeatureNotSupportedException import org.ktorm.r2dbc.dsl.* @@ -132,21 +129,13 @@ public class EntitySequence>( * elements when being iterated. */ public suspend fun asKotlinSequence(): Sequence { - val iterator = iterator() - return Sequence { iterator } + return flow().toList().asSequence() } - /** - * Return an iterator over the elements of this sequence. - */ - @Suppress("IteratorNotThrowingNoSuchElementException") - private suspend operator fun iterator(): Iterator { - return flow().toList().iterator() - } - - public suspend fun flow(): Flow { + public suspend fun flow():Flow { return getRowSet().map(entityExtractor) } + } /** @@ -169,7 +158,7 @@ public fun > Database.sequenceOf( * The operation is terminal. 
*/ public suspend fun > EntitySequence.toCollection(destination: C): C { - for (element in this) destination += element + flow().collect { destination += it } return destination } @@ -316,7 +305,7 @@ public suspend inline fun , C : MutableCollection EntitySequence.map(transform: (E) -> R): List { +public suspend inline fun EntitySequence.map(crossinline transform: (E) -> R): List { return mapTo(ArrayList(), transform) } @@ -328,9 +317,9 @@ public suspend inline fun EntitySequence.map(transform: (E) - */ public suspend inline fun > EntitySequence.mapTo( destination: C, - transform: (E) -> R + crossinline transform: (E) -> R ): C { - for (element in this) destination += transform(element) + flow().collect { destination += transform(it) } return destination } @@ -342,7 +331,7 @@ public suspend inline fun > EntitySequen * * @since 3.0.0 */ -public suspend inline fun EntitySequence.mapNotNull(transform: (E) -> R?): List { +public suspend inline fun EntitySequence.mapNotNull(crossinline transform: (E) -> R?): List { return mapNotNullTo(ArrayList(), transform) } @@ -356,9 +345,9 @@ public suspend inline fun EntitySequence.mapNotNull(tra */ public suspend inline fun > EntitySequence.mapNotNullTo( destination: C, - transform: (E) -> R? + crossinline transform: (E) -> R? ): C { - forEach { element -> transform(element)?.let { destination += it } } + flow().collect { element -> transform(element)?.let { destination += it } } return destination } @@ -371,7 +360,7 @@ public suspend inline fun > Entity * * The operation is terminal. */ -public suspend inline fun EntitySequence.mapIndexed(transform: (index: Int, E) -> R): List { +public suspend inline fun EntitySequence.mapIndexed(crossinline transform: (index: Int, E) -> R): List { return mapIndexedTo(ArrayList(), transform) } @@ -386,7 +375,7 @@ public suspend inline fun EntitySequence.mapIndexed(transform */ public suspend inline fun > EntitySequence.mapIndexedTo( destination: C, - transform: (index: Int, E) -> R + crossinline transform: (index: Int, E) -> R ): C { var index = 0 return mapTo(destination) { transform(index++, it) } @@ -403,7 +392,7 @@ public suspend inline fun > EntitySequen * * @since 3.0.0 */ -public suspend inline fun EntitySequence.mapIndexedNotNull(transform: (index: Int, E) -> R?): List { +public suspend inline fun EntitySequence.mapIndexedNotNull(crossinline transform: (index: Int, E) -> R?): List { return mapIndexedNotNullTo(ArrayList(), transform) } @@ -420,9 +409,9 @@ public suspend inline fun EntitySequence.mapIndexedNotN */ public suspend inline fun > EntitySequence.mapIndexedNotNullTo( destination: C, - transform: (index: Int, E) -> R? + crossinline transform: (index: Int, E) -> R? 
): C { - forEachIndexed { index, element -> transform(index, element)?.let { destination += it } } + flow().collectIndexed { index, element -> transform(index, element)?.let { destination += it } } return destination } @@ -435,7 +424,7 @@ public suspend inline fun > Entity * @since 3.0.0 */ public suspend inline fun EntitySequence.flatMap(transform: (E) -> Iterable): List { - return flatMapTo(ArrayList(), transform) + return flow().toList().flatMapTo(ArrayList(), transform) } /** @@ -448,9 +437,9 @@ public suspend inline fun EntitySequence.flatMap(transform: ( */ public suspend inline fun > EntitySequence.flatMapTo( destination: C, - transform: (E) -> Iterable + crossinline transform: (E) -> Iterable ): C { - for (element in this) destination += transform(element) + flow().collect { destination += transform(it) } return destination } @@ -462,7 +451,7 @@ public suspend inline fun > EntitySequen * * @since 3.1.0 */ -public suspend inline fun EntitySequence.flatMapIndexed(transform: (index: Int, E) -> Iterable): List { +public suspend inline fun EntitySequence.flatMapIndexed(crossinline transform: (index: Int, E) -> Iterable): List { return flatMapIndexedTo(ArrayList(), transform) } @@ -476,7 +465,7 @@ public suspend inline fun EntitySequence.flatMapIndexed(trans */ public suspend inline fun > EntitySequence.flatMapIndexedTo( destination: C, - transform: (index: Int, E) -> Iterable + crossinline transform: (index: Int, E) -> Iterable ): C { var index = 0 return flatMapTo(destination) { transform(index++, it) } @@ -719,12 +708,13 @@ public suspend inline fun , reified C : Any> EntitySeq ) val rowSet = Query(database, expr).doQuery() - if (rowSet.size == 1) { + val count = rowSet.count() + if (count == 1) { val row = rowSet.first() return aggregation.sqlType.getResult(row, row.metadata,0) } else { val (sql, _) = database.formatExpression(expr, beautifySql = true) - throw IllegalStateException("Expected 1 row but ${rowSet.size} returned from sql: \n\n$sql") + throw IllegalStateException("Expected 1 row but $count returned from sql: \n\n$sql") } } @@ -875,7 +865,7 @@ public suspend inline fun > EntitySequence.avera * * The operation is terminal. */ -public suspend inline fun EntitySequence.associate(transform: (E) -> Pair): Map { +public suspend inline fun EntitySequence.associate(crossinline transform: (E) -> Pair): Map { return associateTo(LinkedHashMap(), transform) } @@ -890,7 +880,7 @@ public suspend inline fun EntitySequence.associate(transfo * The operation is terminal. */ public suspend inline fun EntitySequence.associateBy(keySelector: (E) -> K): Map { - return associateByTo(LinkedHashMap(), keySelector) + return flow().toList().associateByTo(LinkedHashMap(), keySelector) } /** @@ -907,7 +897,7 @@ public suspend inline fun EntitySequence.associateBy( keySelector: (E) -> K, valueTransform: (E) -> V ): Map { - return associateByTo(LinkedHashMap(), keySelector, valueTransform) + return flow().toList().associateByTo(LinkedHashMap(), keySelector, valueTransform) } /** @@ -921,7 +911,7 @@ public suspend inline fun EntitySequence.associateBy( * The operation is terminal. 
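
The associate* family collects the Flow into a LinkedHashMap in iteration order. For example, indexing entities by primary key could look like the following sketch; the id property is an assumption taken from the test model:

    import org.ktorm.r2dbc.database.Database
    import org.ktorm.r2dbc.entity.*

    suspend fun employeesById(database: Database): Map<Int, Employee> {
        return database.sequenceOf(Employees).associateBy { it.id }
    }
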
*/ public suspend inline fun , V> EntitySequence.associateWith(valueSelector: (K) -> V): Map { - return associateWithTo(LinkedHashMap(), valueSelector) + return flow().toList().associateWithTo(LinkedHashMap(), valueSelector) } /** @@ -934,9 +924,9 @@ public suspend inline fun , V> EntitySequence.associateWith( */ public suspend inline fun > EntitySequence.associateTo( destination: M, - transform: (E) -> Pair + crossinline transform: (E) -> Pair ): M { - for (element in this) destination += transform(element) + flow().collect { destination += transform(it) } return destination } @@ -950,9 +940,9 @@ public suspend inline fun > EntitySequ */ public suspend inline fun > EntitySequence.associateByTo( destination: M, - keySelector: (E) -> K + crossinline keySelector: (E) -> K ): M { - for (element in this) destination.put(keySelector(element), element) + flow().collect { destination.put(keySelector(it), it) } return destination } @@ -966,10 +956,10 @@ public suspend inline fun > EntitySequenc */ public suspend inline fun > EntitySequence.associateByTo( destination: M, - keySelector: (E) -> K, - valueTransform: (E) -> V + crossinline keySelector: (E) -> K, + crossinline valueTransform: (E) -> V ): M { - for (element in this) destination.put(keySelector(element), valueTransform(element)) + flow().collect { destination.put(keySelector(it), valueTransform(it)) } return destination } @@ -983,9 +973,9 @@ public suspend inline fun > EntitySequ */ public suspend inline fun , V, M : MutableMap> EntitySequence.associateWithTo( destination: M, - valueSelector: (K) -> V + crossinline valueSelector: (K) -> V ): M { - for (element in this) destination.put(element, valueSelector(element)) + flow().collect { destination.put(it, valueSelector(it)) } return destination } @@ -1002,18 +992,11 @@ public suspend inline fun , V, M : MutableMap> EntityS public suspend fun > EntitySequence.elementAtOrNull(index: Int): E? { try { @Suppress("UnconditionalJumpStatementInLoop") - for (element in this.drop(index).take(1)) return element - return null + return this.drop(index).take(1).flow().firstOrNull() } catch (e: DialectFeatureNotSupportedException) { if (database.logger.isTraceEnabled()) { database.logger.trace("Pagination is not supported, retrieving all records instead: ", e) } - - var count = 0 - for (element in this) { - if (index == count++) return element - } - return null } } @@ -1115,9 +1098,7 @@ public suspend inline fun > EntitySequence.first * The operation is terminal. */ public suspend fun EntitySequence.lastOrNull(): E? { - var last: E? = null - for (element in this) last = element - return last + return flow().lastOrNull() } /** @@ -1184,11 +1165,7 @@ public suspend inline fun > EntitySequence.findL * The operation is terminal. */ public suspend fun > EntitySequence.singleOrNull(): E? { - val iterator = iterator() - if (!iterator.hasNext()) return null - - val single = iterator.next() - return if (iterator.hasNext()) null else single + return flow().firstOrNull() } /** @@ -1209,12 +1186,7 @@ public suspend inline fun > EntitySequence.singl * The operation is terminal. */ public suspend fun > EntitySequence.single(): E { - val iterator = iterator() - if (!iterator.hasNext()) throw NoSuchElementException("Sequence is empty.") - - val single = iterator.next() - if (iterator.hasNext()) throw IllegalArgumentException("Sequence has more than one element.") - return single + return flow().first() } /** @@ -1235,9 +1207,9 @@ public suspend inline fun > EntitySequence.singl * * The operation is terminal. 
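
fold accumulates over the Flow with the supplied binary operation, starting from the given initial value. A sketch that sums a numeric field, assuming the Employee entity exposes a salary: Long property:

    import org.ktorm.r2dbc.database.Database
    import org.ktorm.r2dbc.entity.*

    suspend fun totalSalary(database: Database): Long {
        return database.sequenceOf(Employees).fold(0L) { acc, employee -> acc + employee.salary }
    }
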
*/ -public suspend inline fun EntitySequence.fold(initial: R, operation: (acc: R, E) -> R): R { +public suspend inline fun EntitySequence.fold(initial: R, crossinline operation: (acc: R, E) -> R): R { var accumulator = initial - for (element in this) accumulator = operation(accumulator, element) + flow().collect { accumulator = operation(accumulator, it) } return accumulator } @@ -1252,11 +1224,11 @@ public suspend inline fun EntitySequence.fold(initial: R, ope */ public suspend inline fun EntitySequence.foldIndexed( initial: R, - operation: (index: Int, acc: R, E) -> R + crossinline operation: (index: Int, acc: R, E) -> R ): R { var index = 0 var accumulator = initial - for (element in this) accumulator = operation(index++, accumulator, element) + flow().collect { accumulator = operation(index++, accumulator, it) } return accumulator } @@ -1272,7 +1244,7 @@ public suspend inline fun EntitySequence.foldIndexed( * * The operation is terminal. */ -public suspend inline fun EntitySequence.reduce(operation: (acc: E, E) -> E): E { +public suspend inline fun EntitySequence.reduce(crossinline operation: (acc: E, E) -> E): E { return reduceOrNull(operation) ?: throw UnsupportedOperationException("Empty sequence can't be reduced.") } @@ -1288,7 +1260,7 @@ public suspend inline fun EntitySequence.reduce(operation: (acc: * * The operation is terminal. */ -public suspend inline fun EntitySequence.reduceIndexed(operation: (index: Int, acc: E, E) -> E): E { +public suspend inline fun EntitySequence.reduceIndexed(crossinline operation: (index: Int, acc: E, E) -> E): E { return reduceIndexedOrNull(operation) ?: throw UnsupportedOperationException("Empty sequence can't be reduced.") } @@ -1305,15 +1277,15 @@ public suspend inline fun EntitySequence.reduceIndexed(operation * * @since 3.1.0 */ -public suspend inline fun EntitySequence.reduceOrNull(operation: (acc: E, E) -> E): E? { - val iterator = iterator() - if (!iterator.hasNext()) return null - - var accumulator = iterator.next() - while (iterator.hasNext()) { - accumulator = operation(accumulator, iterator.next()) +public suspend inline fun EntitySequence.reduceOrNull(crossinline operation: (acc: E, E) -> E): E? { + var accumulator: E? = null + flow().collect { + if (accumulator == null) { + accumulator = it + } else { + accumulator = operation(accumulator!!,it) + } } - return accumulator } @@ -1330,7 +1302,7 @@ public suspend inline fun EntitySequence.reduceOrNull(operation: * * @since 3.1.0 */ -public suspend inline fun EntitySequence.reduceIndexedOrNull(operation: (index: Int, acc: E, E) -> E): E? { +public suspend inline fun EntitySequence.reduceIndexedOrNull(crossinline operation: (index: Int, acc: E, E) -> E): E? { var index = 1 return reduceOrNull { acc, e -> operation(index++, acc, e) } } @@ -1340,8 +1312,8 @@ public suspend inline fun EntitySequence.reduceIndexedOrNull(ope * * The operation is terminal. */ -public suspend inline fun EntitySequence.forEach(action: (E) -> Unit) { - for (element in this) action(element) +public suspend inline fun EntitySequence.forEach(crossinline action: (E) -> Unit) { + flow().collect { action(it) } } /** @@ -1351,9 +1323,9 @@ public suspend inline fun EntitySequence.forEach(action: (E) -> * * The operation is terminal. 
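
forEachIndexed keeps its own counter while collecting the Flow, so callers still receive a zero-based index alongside each entity. A usage sketch under the same test-model assumptions:

    import org.ktorm.r2dbc.database.Database
    import org.ktorm.r2dbc.entity.*

    suspend fun printNumberedEmployees(database: Database) {
        database.sequenceOf(Employees).forEachIndexed { index, employee ->
            println("$index: $employee")
        }
    }
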
*/ -public suspend inline fun EntitySequence.forEachIndexed(action: (index: Int, E) -> Unit) { +public suspend inline fun EntitySequence.forEachIndexed(crossinline action: (index: Int, E) -> Unit) { var index = 0 - for (element in this) action(index++, element) + flow().collect { action(index++, it) } } /** @@ -1363,7 +1335,7 @@ public suspend inline fun EntitySequence.forEachIndexed(action: * @since 3.0.0 */ public suspend fun EntitySequence.withIndex(): Sequence> { - val iterator = iterator() + val iterator = flow().toList().iterator() return Sequence { IndexingIterator(iterator) } } @@ -1375,7 +1347,7 @@ public suspend fun EntitySequence.withIndex(): Sequence EntitySequence.groupBy(keySelector: (E) -> K): Map> { +public suspend inline fun EntitySequence.groupBy(crossinline keySelector: (E) -> K): Map> { return groupByTo(LinkedHashMap(), keySelector) } @@ -1389,8 +1361,8 @@ public suspend inline fun EntitySequence.groupBy(keySelector: * The operation is terminal. */ public suspend inline fun EntitySequence.groupBy( - keySelector: (E) -> K, - valueTransform: (E) -> V + crossinline keySelector: (E) -> K, + crossinline valueTransform: (E) -> V ): Map> { return groupByTo(LinkedHashMap(), keySelector, valueTransform) } @@ -1403,14 +1375,13 @@ public suspend inline fun EntitySequence.groupBy( */ public suspend inline fun >> EntitySequence.groupByTo( destination: M, - keySelector: (E) -> K + crossinline keySelector: (E) -> K ): M { - for (element in this) { - val key = keySelector(element) + flow().collect { + val key = keySelector(it) val list = destination.getOrPut(key) { ArrayList() } - list += element + list += it } - return destination } @@ -1423,13 +1394,13 @@ public suspend inline fun >> Ent */ public suspend inline fun >> EntitySequence.groupByTo( destination: M, - keySelector: (E) -> K, - valueTransform: (E) -> V + crossinline keySelector: (E) -> K, + crossinline valueTransform: (E) -> V ): M { - for (element in this) { - val key = keySelector(element) + flow().collect { + val key = keySelector(it) val list = destination.getOrPut(key) { ArrayList() } - list += valueTransform(element) + list += valueTransform(it) } return destination @@ -1469,13 +1440,12 @@ public suspend fun EntitySequence.joinTo( ): A { buffer.append(prefix) var count = 0 - for (element in this) { + flow().collect { if (++count > 1) buffer.append(separator) if (limit < 0 || count <= limit) { - if (transform != null) buffer.append(transform(element)) else buffer.append(element.toString()) + if (transform != null) buffer.append(transform(it)) else buffer.append(it.toString()) } else { buffer.append(truncated) - break } } buffer.append(postfix) diff --git a/ktorm-r2dbc-core/src/test/kotlin/org/ktorm/database/DatabaseTest.kt b/ktorm-r2dbc-core/src/test/kotlin/org/ktorm/database/DatabaseTest.kt index a37c0bb..0e85d25 100644 --- a/ktorm-r2dbc-core/src/test/kotlin/org/ktorm/database/DatabaseTest.kt +++ b/ktorm-r2dbc-core/src/test/kotlin/org/ktorm/database/DatabaseTest.kt @@ -85,7 +85,7 @@ class DatabaseTest : BaseTest() { fun tableTest() = runBlocking { database.useTransaction { database.useTransaction { - for (employee in database.employees) { + database.employees.forEach { println(it) } throw RuntimeException() From 9617d7fb663529794699890ba404e7af54ca186b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=AB=98=E9=94=90=E9=9B=84?= <641571835@qq.com> Date: Fri, 18 Feb 2022 14:39:55 +0800 Subject: [PATCH 07/17] update query dsl move EntityDml.kt --- build.gradle | 5 +++++ .../main/kotlin/org/ktorm/r2dbc/dsl/Query.kt | 18 
++---------------- .../r2dbc/{schema => entity}/EntityDml.kt | 7 ++----- 3 files changed, 9 insertions(+), 21 deletions(-) rename ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/{schema => entity}/EntityDml.kt (88%) diff --git a/build.gradle b/build.gradle index 0efbe9d..2b00570 100644 --- a/build.gradle +++ b/build.gradle @@ -98,6 +98,11 @@ subprojects { project -> name = "vince" email = "me@liuwj.me" } + developer { + id = "lookup-cat" + name = "夜里的向日葵" + email = "641571835@qq.com" + } } scm { url = "https://github.com/kotlin-orm/ktorm-r2dbc.git" diff --git a/ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/dsl/Query.kt b/ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/dsl/Query.kt index a917bcb..8f2ea5e 100644 --- a/ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/dsl/Query.kt +++ b/ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/dsl/Query.kt @@ -83,18 +83,6 @@ public class Query(public val database: Database, public val expression: QueryEx return database.executeQuery(expression).map { QueryRow(this@Query, it) } } - /** - * The [ResultSet] object of this query, lazy initialized after first access, obtained from the database by - * executing the generated SQL. - * - * Note that the return type of this property is not a normal [ResultSet], but a [QueryRow] instead. That's - * a special implementation provided by Ktorm, different from normal result sets, it is available offline and - * overrides the indexed access operator. More details can be found in the documentation of [QueryRow]. - */ - /* public val rowSet: QueryRow by lazy(LazyThreadSafetyMode.NONE) { - QueryRow(this, database.executeQuery(expression)) - }*/ - /** * The total record count of this query ignoring the pagination params. * @@ -429,8 +417,7 @@ public suspend inline fun Query.forEach(crossinline action: (row: QueryRow) -> U * @since 3.0.0 */ public suspend inline fun Query.forEachIndexed(crossinline action: (index: Int, row: QueryRow) -> Unit) { - var index = 0 - doQuery().collect { action(index++, it) } + doQuery().collectIndexed { index, it -> action(index, it) } } /** @@ -441,8 +428,7 @@ public suspend inline fun Query.forEachIndexed(crossinline action: (index: Int, */ public suspend fun Query.withIndex(): Iterable> { - val iterator = IndexingIterator(doQuery().toList().iterator()) - return Iterable { iterator } + return doQuery().toList().withIndex() } /** diff --git a/ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/schema/EntityDml.kt b/ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/entity/EntityDml.kt similarity index 88% rename from ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/schema/EntityDml.kt rename to ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/entity/EntityDml.kt index 4a2fa88..f053367 100644 --- a/ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/schema/EntityDml.kt +++ b/ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/entity/EntityDml.kt @@ -14,10 +14,7 @@ * limitations under the License. 
*/ -package org.ktorm.r2dbc.schema - -import org.ktorm.r2dbc.entity.Entity -import org.ktorm.r2dbc.entity.implementation +package org.ktorm.r2dbc.entity internal fun Entity<*>.clearChangesRecursively() { implementation.changedProperties.clear() @@ -27,4 +24,4 @@ internal fun Entity<*>.clearChangesRecursively() { value.clearChangesRecursively() } } -} \ No newline at end of file +} From d9085c5ba9d8eb8900569ee736cd0d652d713fac Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=AB=98=E9=94=90=E9=9B=84?= <641571835@qq.com> Date: Fri, 18 Feb 2022 16:19:41 +0800 Subject: [PATCH 08/17] update EntityDml update argsBinding method remove SqlType metadata field --- .../org/ktorm/r2dbc/database/Database.kt | 33 +- .../ktorm/r2dbc/database/R2JdbcExtensions.kt | 2 +- .../main/kotlin/org/ktorm/r2dbc/dsl/Query.kt | 2 +- .../kotlin/org/ktorm/r2dbc/dsl/QueryRow.kt | 6 +- .../org/ktorm/r2dbc/entity/EntityDml.kt | 350 ++++++++++++++++++ .../org/ktorm/r2dbc/entity/EntitySequence.kt | 2 +- .../kotlin/org/ktorm/r2dbc/schema/SqlType.kt | 16 +- .../kotlin/org/ktorm/r2dbc/schema/Table.kt | 2 +- .../kotlin/org/ktorm/database/DatabaseTest.kt | 24 +- 9 files changed, 414 insertions(+), 23 deletions(-) diff --git a/ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/database/Database.kt b/ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/database/Database.kt index 466c38e..ee8379c 100644 --- a/ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/database/Database.kt +++ b/ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/database/Database.kt @@ -2,6 +2,7 @@ package org.ktorm.r2dbc.database import io.r2dbc.spi.* import kotlinx.coroutines.flow.Flow +import kotlinx.coroutines.flow.map import kotlinx.coroutines.reactive.asFlow import kotlinx.coroutines.reactive.awaitFirst import kotlinx.coroutines.reactive.awaitFirstOrNull @@ -11,7 +12,9 @@ import org.ktorm.r2dbc.expression.ArgumentExpression import org.ktorm.r2dbc.expression.SqlExpression import org.ktorm.r2dbc.logging.Logger import org.ktorm.r2dbc.logging.detectLoggerImplementation +import org.ktorm.r2dbc.schema.IntSqlType import org.ktorm.r2dbc.schema.SqlType +import java.sql.PreparedStatement import kotlin.contracts.ExperimentalContracts import kotlin.contracts.InvocationKind import kotlin.contracts.contract @@ -293,7 +296,7 @@ public class Database( ) } if (logger.isDebugEnabled()) { - logger.debug("Parameters: " + args.map { "${it.value}(${it.sqlType.javaType.simpleName})" }) + logger.debug("Parameters: " + args.map { it.value.toString() }) } statement.bindParameters(args) @@ -301,12 +304,32 @@ public class Database( } val results = statement.execute().toList() + return results.map { result -> result.rowsUpdated.awaitFirst() }.toIntArray() + } + } - /* if (logaddBatchger.isDebugEnabled()) { - logger.debug("Effects: ${results?.contentToString()}") - }*/ + /** + * Format the given [expression] to a SQL string with its execution arguments, execute it via + * [Statement.execute], then return the effected row count along with the generated keys. + * + * @since 2.7 + * @param expression the SQL expression to be executed. + * @return a [Pair] combines the effected row count and the generated keys. 
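
executeUpdateAndRetrieveKeys pairs the update count with a Flow of generated-key rows; EntitySequence.add, added later in this patch, reads the first row of that Flow to write the auto-generated primary key back into the entity. The condensed sketch below shows the same idea; the helper name and the SqlType<Int> key type are assumptions for illustration:

    import kotlinx.coroutines.flow.firstOrNull
    import org.ktorm.r2dbc.database.Database
    import org.ktorm.r2dbc.expression.SqlExpression
    import org.ktorm.r2dbc.schema.SqlType

    // Execute an insert expression and read back the generated key from the first returned row.
    suspend fun insertAndGetKey(database: Database, expression: SqlExpression, keyType: SqlType<Int>): Pair<Int, Int?> {
        val (effects, rows) = database.executeUpdateAndRetrieveKeys(expression)
        val generatedKey = rows.firstOrNull()?.let { row -> keyType.getResult(row, 0) }
        return effects to generatedKey
    }
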
+ */ + public suspend fun executeUpdateAndRetrieveKeys(expression: SqlExpression): Pair> { + val (sql, args) = formatExpression(expression) - return results.map { result -> result.rowsUpdated.awaitFirst() }.toIntArray() + if (logger.isDebugEnabled()) { + logger.debug("SQL: $sql") + logger.debug("Parameters: " + args.map { "${it.value}(${it.sqlType.javaType.simpleName})" }) + } + + useConnection { + val statement = it.createStatement(sql) + statement.bindParameters(args) + val rowsUpdated = statement.execute().awaitFirst().rowsUpdated.awaitFirst() + val rows = statement.returnGeneratedValues().execute().awaitFirst().map { row, _ -> row }.asFlow() + return Pair(rowsUpdated,rows) } } diff --git a/ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/database/R2JdbcExtensions.kt b/ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/database/R2JdbcExtensions.kt index 0f62ba7..b0b0f48 100644 --- a/ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/database/R2JdbcExtensions.kt +++ b/ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/database/R2JdbcExtensions.kt @@ -31,7 +31,7 @@ public fun Statement.bindParameters(args: List>) { for ((i, expr) in args.withIndex()) { @Suppress("UNCHECKED_CAST") val sqlType = expr.sqlType as SqlType - sqlType.bindParameter(this, i + 1, expr.value) + sqlType.bindParameter(this, i, expr.value) } } diff --git a/ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/dsl/Query.kt b/ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/dsl/Query.kt index 8f2ea5e..f3074d1 100644 --- a/ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/dsl/Query.kt +++ b/ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/dsl/Query.kt @@ -96,7 +96,7 @@ public class Query(public val database: Database, public val expression: QueryEx } else { val countExpr = expression.toCountExpression() val count = doQuery(countExpr) - .map { LongSqlType.getResult(it, it.metadata, 0) } + .map { LongSqlType.getResult(it, 0) } .firstOrNull() val (sql, _) = database.formatExpression(countExpr, beautifySql = true) count ?: throw IllegalStateException("No result return for sql: $sql") diff --git a/ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/dsl/QueryRow.kt b/ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/dsl/QueryRow.kt index ee70b1f..31e8088 100644 --- a/ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/dsl/QueryRow.kt +++ b/ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/dsl/QueryRow.kt @@ -31,7 +31,7 @@ public class QueryRow internal constructor(public val query: Query, private val // Try to find the column by label. for (index in metadata.columnMetadatas.indices) { if (metadata.getColumnMetadata(index).name eq column.label) { - return column.sqlType.getResult(row,metadata,index) + return column.sqlType.getResult(row,index) } } // Return null if the column doesn't exist in the result set. @@ -47,7 +47,7 @@ public class QueryRow internal constructor(public val query: Query, private val return when (indices.size) { 0 -> null // Return null if the column doesn't exist in the result set. 
- 1 -> return column.sqlType.getResult(row,metadata,indices.first()) + 1 -> return column.sqlType.getResult(row,indices.first()) else -> throw IllegalArgumentException(warningConfusedColumnName(column.name)) } } @@ -59,4 +59,4 @@ public class QueryRow internal constructor(public val query: Query, private val private fun warningConfusedColumnName(name: String): String { return "Confused column name, there are more than one column named '$name' in query: \n\n${query.sql}\n" } -} \ No newline at end of file +} diff --git a/ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/entity/EntityDml.kt b/ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/entity/EntityDml.kt index f053367..e150d99 100644 --- a/ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/entity/EntityDml.kt +++ b/ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/entity/EntityDml.kt @@ -16,6 +16,331 @@ package org.ktorm.r2dbc.entity +import kotlinx.coroutines.flow.firstOrNull +import org.ktorm.r2dbc.dsl.* +import org.ktorm.r2dbc.dsl.AliasRemover +import org.ktorm.r2dbc.expression.* +import org.ktorm.r2dbc.schema.* + +/** + * Insert the given entity into this sequence and return the affected record number. + * + * If we use an auto-increment key in our table, we need to tell Ktorm which is the primary key by calling + * [Table.primaryKey] while registering columns, then this function will obtain the generated key from the + * database and fill it into the corresponding property after the insertion completes. But this requires us + * not to set the primary key’s value beforehand, otherwise, if you do that, the given value will be inserted + * into the database, and no keys generated. + * + * Note that after calling this function, the [entity] will be ATTACHED to the current database. + * + * @see Entity.flushChanges + * @see Entity.delete + * @since 2.7 + */ +@Suppress("UNCHECKED_CAST") +public suspend fun , T : Table> EntitySequence.add(entity: E): Int { + checkIfSequenceModified() + entity.implementation.checkUnexpectedDiscarding(sourceTable) + + val assignments = entity.findInsertColumns(sourceTable).takeIf { it.isNotEmpty() } ?: return 0 + + val expression = AliasRemover.visit( + expr = InsertExpression( + table = sourceTable.asExpression(), + assignments = assignments.map { (col, argument) -> + ColumnAssignmentExpression( + column = col.asExpression() as ColumnExpression, + expression = ArgumentExpression(argument, col.sqlType as SqlType) + ) + } + ) + ) + + val primaryKeys = sourceTable.primaryKeys + + val ignoreGeneratedKeys = primaryKeys.size != 1 + || primaryKeys[0].binding == null + || entity.implementation.hasColumnValue(primaryKeys[0].binding!!) 
+ + if (ignoreGeneratedKeys) { + val effects = database.executeUpdate(expression) + entity.implementation.fromDatabase = database + entity.implementation.fromTable = sourceTable + entity.implementation.doDiscardChanges() + return effects + } else { + val (effects, rowSet) = database.executeUpdateAndRetrieveKeys(expression) + rowSet.firstOrNull()?.let { row -> + val generatedKey = primaryKeys[0].sqlType.getResult(row, 0) + if (generatedKey != null) { + if (database.logger.isDebugEnabled()) { + database.logger.debug("Generated Key: $generatedKey") + } + + entity.implementation.setColumnValue(primaryKeys[0].binding!!, generatedKey) + } + } + + entity.implementation.fromDatabase = database + entity.implementation.fromTable = sourceTable + entity.implementation.doDiscardChanges() + return effects + } +} + +/** + * Update properties of the given entity to the database and return the affected record number. + * + * Note that after calling this function, the [entity] will be ATTACHED to the current database. + * + * @see Entity.flushChanges + * @see Entity.delete + * @since 3.1.0 + */ +@Suppress("UNCHECKED_CAST") +public suspend fun , T : Table> EntitySequence.update(entity: E): Int { + checkIfSequenceModified() + entity.implementation.checkUnexpectedDiscarding(sourceTable) + + val assignments = entity.findUpdateColumns(sourceTable).takeIf { it.isNotEmpty() } ?: return 0 + + val expression = AliasRemover.visit( + expr = UpdateExpression( + table = sourceTable.asExpression(), + assignments = assignments.map { (col, argument) -> + ColumnAssignmentExpression( + column = col.asExpression() as ColumnExpression, + expression = ArgumentExpression(argument, col.sqlType as SqlType) + ) + }, + where = entity.implementation.constructIdentityCondition(sourceTable) + ) + ) + + val effects = database.executeUpdate(expression) + entity.implementation.fromDatabase = database + entity.implementation.fromTable = sourceTable + entity.implementation.doDiscardChanges() + return effects +} + +/** + * Remove all of the elements of this sequence that satisfy the given [predicate]. + * + * @since 2.7 + */ +public suspend fun > EntitySequence.removeIf( + predicate: (T) -> ColumnDeclaring +): Int { + checkIfSequenceModified() + return database.delete(sourceTable, predicate) +} + +/** + * Remove all of the elements of this sequence. The sequence will be empty after this function returns. + * + * @since 2.7 + */ +public suspend fun > EntitySequence.clear(): Int { + checkIfSequenceModified() + return database.deleteAll(sourceTable) +} + +@Suppress("UNCHECKED_CAST") +internal suspend fun EntityImplementation.doFlushChanges(): Int { + check(parent == null) { "The entity is not attached to any database yet." 
} + + val fromDatabase = fromDatabase ?: error("The entity is not attached to any database yet.") + val fromTable = fromTable ?: error("The entity is not attached to any database yet.") + checkUnexpectedDiscarding(fromTable) + + val assignments = findChangedColumns(fromTable).takeIf { it.isNotEmpty() } ?: return 0 + + val expression = AliasRemover.visit( + expr = UpdateExpression( + table = fromTable.asExpression(), + assignments = assignments.map { (col, argument) -> + ColumnAssignmentExpression( + column = col.asExpression() as ColumnExpression, + expression = ArgumentExpression(argument, col.sqlType as SqlType) + ) + }, + where = constructIdentityCondition(fromTable) + ) + ) + + return fromDatabase.executeUpdate(expression).also { doDiscardChanges() } +} + +@Suppress("UNCHECKED_CAST") +internal suspend fun EntityImplementation.doDelete(): Int { + check(parent == null) { "The entity is not attached to any database yet." } + + val fromDatabase = fromDatabase ?: error("The entity is not attached to any database yet.") + val fromTable = fromTable ?: error("The entity is not attached to any database yet.") + + val expression = AliasRemover.visit( + expr = DeleteExpression( + table = fromTable.asExpression(), + where = constructIdentityCondition(fromTable) + ) + ) + + return fromDatabase.executeUpdate(expression) +} + +private fun EntitySequence<*, *>.checkIfSequenceModified() { + val isModified = expression.where != null + || expression.groupBy.isNotEmpty() + || expression.having != null + || expression.isDistinct + || expression.orderBy.isNotEmpty() + || expression.offset != null + || expression.limit != null + + if (isModified) { + throw UnsupportedOperationException( + "Entity manipulation functions are not supported by this sequence object. " + + "Please call on the origin sequence returned from database.sequenceOf(table)" + ) + } +} + +private fun Entity<*>.findInsertColumns(table: Table<*>): Map, Any?> { + val assignments = LinkedHashMap, Any?>() + + for (column in table.columns) { + if (column.binding != null && implementation.hasColumnValue(column.binding)) { + assignments[column] = implementation.getColumnValue(column.binding) + } + } + + return assignments +} + +private fun Entity<*>.findUpdateColumns(table: Table<*>): Map, Any?> { + val assignments = LinkedHashMap, Any?>() + + for (column in table.columns - table.primaryKeys) { + if (column.binding != null && implementation.hasColumnValue(column.binding)) { + assignments[column] = implementation.getColumnValue(column.binding) + } + } + + return assignments +} + +private fun EntityImplementation.findChangedColumns(fromTable: Table<*>): Map, Any?> { + val assignments = LinkedHashMap, Any?>() + + for (column in fromTable.columns) { + val binding = column.binding ?: continue + + when (binding) { + is ReferenceBinding -> { + if (binding.onProperty.name in changedProperties) { + val child = this.getProperty(binding.onProperty) as Entity<*>? + assignments[column] = child?.implementation?.getPrimaryKeyValue(binding.referenceTable as Table<*>) + } + } + is NestedBinding -> { + var anyChanged = false + var curr: Any? = this + + for (prop in binding.properties) { + if (curr is Entity<*>) { + curr = curr.implementation + } + + check(curr is EntityImplementation?) 
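A short, hypothetical sketch of how the flush/remove helpers above are meant to be used at the call site (fixture names and the `eq` infix operator from the new Operators.kt are assumed):

suspend fun demoFlushAndRemove(database: Database) {
    // Flush only the changed properties of an attached entity back to the database.
    val employee = database.employees.first { it.id eq 1 }
    employee.job = "senior engineer"
    employee.flushChanges()

    // Delete rows matching a predicate without loading entities.
    database.employees.removeIf { it.job eq "temp" }
}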
+ + if (curr != null && prop.name in curr.changedProperties) { + anyChanged = true + } + + curr = curr?.getProperty(prop) + } + + if (anyChanged) { + assignments[column] = curr + } + } + } + } + + return assignments +} + +internal fun EntityImplementation.doDiscardChanges() { + check(parent == null) { "The entity is not attached to any database yet." } + val fromTable = fromTable ?: error("The entity is not attached to any database yet.") + + for (column in fromTable.columns) { + val binding = column.binding ?: continue + + when (binding) { + is ReferenceBinding -> { + changedProperties.remove(binding.onProperty.name) + } + is NestedBinding -> { + var curr: Any? = this + + for (prop in binding.properties) { + if (curr == null) { + break + } + if (curr is Entity<*>) { + curr = curr.implementation + } + + check(curr is EntityImplementation) + curr.changedProperties.remove(prop.name) + curr = curr.getProperty(prop) + } + } + } + } +} + +// Add check to avoid bug #10 +private fun EntityImplementation.checkUnexpectedDiscarding(fromTable: Table<*>) { + for (column in fromTable.columns) { + if (column.binding !is NestedBinding) continue + + var curr: Any? = this + for ((i, prop) in column.binding.properties.withIndex()) { + if (curr == null) { + break + } + if (curr is Entity<*>) { + curr = curr.implementation + } + + check(curr is EntityImplementation) + + if (i > 0 && prop.name in curr.changedProperties) { + val isExternalEntity = curr.fromTable != null && curr.getRoot() != this + if (isExternalEntity) { + val propPath = column.binding.properties.subList(0, i + 1).joinToString(separator = ".") { it.name } + val msg = "this.$propPath may be unexpectedly discarded, please save it to database first." + throw IllegalStateException(msg) + } + } + + curr = curr.getProperty(prop) + } + } +} + +private tailrec fun EntityImplementation.getRoot(): EntityImplementation { + val parent = this.parent + if (parent == null) { + return this + } else { + return parent.getRoot() + } +} + internal fun Entity<*>.clearChangesRecursively() { implementation.changedProperties.clear() @@ -25,3 +350,28 @@ internal fun Entity<*>.clearChangesRecursively() { } } } + +@Suppress("UNCHECKED_CAST") +private fun EntityImplementation.constructIdentityCondition(fromTable: Table<*>): ScalarExpression { + val primaryKeys = fromTable.primaryKeys + if (primaryKeys.isEmpty()) { + error("Table '$fromTable' doesn't have a primary key.") + } + + val conditions = primaryKeys.map { pk -> + if (pk.binding == null) { + error("Primary column $pk has no bindings to any entity field.") + } + + val pkValue = getColumnValue(pk.binding) ?: error("The value of primary key column $pk is null.") + + BinaryExpression( + type = BinaryExpressionType.EQUAL, + left = pk.asExpression(), + right = ArgumentExpression(pkValue, pk.sqlType as SqlType), + sqlType = BooleanSqlType + ) + } + + return conditions.combineConditions().asExpression() +} diff --git a/ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/entity/EntitySequence.kt b/ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/entity/EntitySequence.kt index 75d8e06..bb4f8e7 100644 --- a/ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/entity/EntitySequence.kt +++ b/ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/entity/EntitySequence.kt @@ -711,7 +711,7 @@ public suspend inline fun , reified C : Any> EntitySeq val count = rowSet.count() if (count == 1) { val row = rowSet.first() - return aggregation.sqlType.getResult(row, row.metadata,0) + return aggregation.sqlType.getResult(row,0) } else { val (sql, _) = 
database.formatExpression(expr, beautifySql = true) throw IllegalStateException("Expected 1 row but $count returned from sql: \n\n$sql") diff --git a/ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/schema/SqlType.kt b/ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/schema/SqlType.kt index a2e4a50..d3bde4e 100644 --- a/ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/schema/SqlType.kt +++ b/ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/schema/SqlType.kt @@ -13,9 +13,9 @@ public interface SqlType { public fun bindParameter(statement: Statement, name: String, value: T?) - public fun getResult(row: Row, metadata: RowMetadata, index: Int): T? + public fun getResult(row: Row, index: Int): T? - public fun getResult(row: Row, metadata: RowMetadata, name: String): T? + public fun getResult(row: Row, name: String): T? } @@ -45,11 +45,11 @@ public open class SimpleSqlType(public val kotlinType: KClass) : Sql } } - override fun getResult(row: Row, metadata: RowMetadata, index: Int): T? { + override fun getResult(row: Row, index: Int): T? { return row.get(index, kotlinType.javaObjectType) } - override fun getResult(row: Row, metadata: RowMetadata, name: String): T? { + override fun getResult(row: Row, name: String): T? { return row.get(name, kotlinType.javaObjectType) } @@ -73,11 +73,11 @@ public class TransformedSqlType( underlyingType.bindParameter(statement, name, value?.let(toUnderlyingValue)) } - override fun getResult(row: Row, metadata: RowMetadata, index: Int): R? { - return underlyingType.getResult(row, metadata, index)?.let(fromUnderlyingValue) + override fun getResult(row: Row, index: Int): R? { + return underlyingType.getResult(row, index)?.let(fromUnderlyingValue) } - override fun getResult(row: Row, metadata: RowMetadata, name: String): R? { - return underlyingType.getResult(row, metadata, name)?.let(fromUnderlyingValue) + override fun getResult(row: Row, name: String): R? { + return underlyingType.getResult(row, name)?.let(fromUnderlyingValue) } } diff --git a/ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/schema/Table.kt b/ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/schema/Table.kt index 79d0ba6..8382707 100644 --- a/ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/schema/Table.kt +++ b/ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/schema/Table.kt @@ -17,7 +17,7 @@ package org.ktorm.r2dbc.schema import org.ktorm.r2dbc.dsl.QueryRow -import org.ktorm.r2dbc.entity.Entity +import org.ktorm.r2dbc.entity.* import org.ktorm.r2dbc.entity.EntityImplementation import org.ktorm.r2dbc.entity.implementation import org.ktorm.r2dbc.entity.setColumnValue diff --git a/ktorm-r2dbc-core/src/test/kotlin/org/ktorm/database/DatabaseTest.kt b/ktorm-r2dbc-core/src/test/kotlin/org/ktorm/database/DatabaseTest.kt index 0e85d25..a2f7d8a 100644 --- a/ktorm-r2dbc-core/src/test/kotlin/org/ktorm/database/DatabaseTest.kt +++ b/ktorm-r2dbc-core/src/test/kotlin/org/ktorm/database/DatabaseTest.kt @@ -1,14 +1,15 @@ package org.ktorm.database +import kotlinx.coroutines.reactive.awaitFirst import kotlinx.coroutines.reactive.awaitFirstOrNull +import kotlinx.coroutines.reactive.awaitSingle import kotlinx.coroutines.runBlocking import org.junit.Test import org.ktorm.BaseTest import org.ktorm.r2dbc.database.toList import org.ktorm.r2dbc.dsl.insert -import org.ktorm.r2dbc.entity.count -import org.ktorm.r2dbc.entity.forEach -import org.ktorm.r2dbc.entity.sequenceOf +import org.ktorm.r2dbc.entity.* +import java.time.LocalDate /** * Created by vince on Dec 02, 2018. 
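To illustrate the simplified `SqlType` contract (no `RowMetadata` parameter any more), a column type can be declared on top of the open `SimpleSqlType` base and values read with just a row and an index; the `LocalDateSqlType` name below is illustrative only:

// Both getResult overloads are inherited and delegate to row.get(..., kotlinType.javaObjectType).
object LocalDateSqlType : SimpleSqlType<LocalDate>(LocalDate::class)

// Call sites in dialects and aggregations now pass only the row and the index:
fun readHireDate(row: Row): LocalDate? {
    return LocalDateSqlType.getResult(row, 0)   // was: getResult(row, metadata, 0)
}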
@@ -94,6 +95,23 @@ class DatabaseTest : BaseTest() { assert(true) } + @Test + fun insertTest() = runBlocking { + database.useTransaction { + val department = database.departments.toList().first() + val employee = Employee { + this.name = "vince" + this.job = "engineer" + this.manager = null + this.hireDate = LocalDate.now() + this.salary = 100 + this.department = department + } + val add = database.employees.add(employee) + println(employee) + } + } + /*fun BaseTable<*>.ulong(name: String): Column { return registerColumn(name, object : SqlType(Types.BIGINT, "bigint unsigned") { override fun doSetParameter(ps: PreparedStatement, index: Int, parameter: ULong) { From 7def08ba0ca2bf02c2652c5635d85260d1626b69 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=AB=98=E9=94=90=E9=9B=84?= <641571835@qq.com> Date: Fri, 18 Feb 2022 16:31:31 +0800 Subject: [PATCH 09/17] update log println add entityGroup update QueryRow update EntitySequence --- .../org/ktorm/r2dbc/database/Database.kt | 2 +- .../main/kotlin/org/ktorm/r2dbc/dsl/Query.kt | 4 + .../org/ktorm/r2dbc/entity/EntityGrouping.kt | 426 ++++++++++++++++++ .../org/ktorm/r2dbc/entity/EntitySequence.kt | 60 ++- 4 files changed, 460 insertions(+), 32 deletions(-) create mode 100644 ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/entity/EntityGrouping.kt diff --git a/ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/database/Database.kt b/ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/database/Database.kt index ee8379c..de8a3b6 100644 --- a/ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/database/Database.kt +++ b/ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/database/Database.kt @@ -321,7 +321,7 @@ public class Database( if (logger.isDebugEnabled()) { logger.debug("SQL: $sql") - logger.debug("Parameters: " + args.map { "${it.value}(${it.sqlType.javaType.simpleName})" }) + logger.debug("Parameters: " + args.map { it.value }) } useConnection { diff --git a/ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/dsl/Query.kt b/ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/dsl/Query.kt index f3074d1..466bb27 100644 --- a/ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/dsl/Query.kt +++ b/ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/dsl/Query.kt @@ -83,6 +83,10 @@ public class Query(public val database: Database, public val expression: QueryEx return database.executeQuery(expression).map { QueryRow(this@Query, it) } } + public suspend fun asFlow(): Flow { + return this.doQuery() + } + /** * The total record count of this query ignoring the pagination params. * diff --git a/ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/entity/EntityGrouping.kt b/ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/entity/EntityGrouping.kt new file mode 100644 index 0000000..e017643 --- /dev/null +++ b/ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/entity/EntityGrouping.kt @@ -0,0 +1,426 @@ +/* + * Copyright 2018-2021 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.ktorm.r2dbc.entity + +import kotlinx.coroutines.flow.collect +import org.ktorm.r2dbc.dsl.* +import org.ktorm.r2dbc.schema.BaseTable +import org.ktorm.r2dbc.schema.ColumnDeclaring +import java.util.* +import kotlin.collections.ArrayList +import kotlin.collections.LinkedHashMap +import kotlin.experimental.ExperimentalTypeInference + +/** + * Wraps an [EntitySequence] with a [keySelector] function, which can be applied to each record to get its key, + * or used as the `group by` clause of the generated SQL. + * + * An [EntityGrouping] structure serves as an intermediate step in group-and-fold operations: they group elements + * by their keys and then fold each group with some aggregation operation. + * + * Entity groups are created by attaching `keySelector: (T) -> ColumnDeclaring` function to an entity sequence. + * To get an instance of [EntityGrouping], use the extension function [EntitySequence.groupingBy]. + * + * For the list of group-and-fold operations available, see the extension functions below. + * + * @property sequence the source entity sequence of this grouping. + * @property keySelector a function used to extract the key of a record in the source table. + */ +public class EntityGrouping, K : Any>( + public val sequence: EntitySequence, + public val keySelector: (T) -> ColumnDeclaring +) { + /** + * Create a [kotlin.collections.Grouping] instance that wraps this original entity grouping returning all the + * elements in the source sequence when being iterated. + */ + public suspend fun asKotlinGrouping(): Grouping { + val keyColumn = keySelector(sequence.sourceTable) + val expr = sequence.expression.copy( + columns = sequence.expression.columns + keyColumn.aliased("_group_key") + ) + val allEntities = ArrayList() + val allEntitiesWithKeys = IdentityHashMap() + Query(sequence.database,expr).asFlow().collect { + val entity = sequence.sourceTable.createEntity(it) + val groupKey = keyColumn.sqlType.getResult(it, expr.columns.size - 1) + allEntities += entity + allEntitiesWithKeys[entity] = groupKey + } + return object : Grouping { + override fun sourceIterator(): Iterator { + return allEntities.iterator() + } + override fun keyOf(element: E): K? { + return allEntitiesWithKeys[element] + } + } + } +} + +/** + * Group elements from the source sequence by key and perform the given aggregation for elements in each group, + * then store the results in a new [Map]. + * + * The key for each group is provided by the [EntityGrouping.keySelector] function, and the generated SQL is like: + * `select key, aggregation from source group by key`. + * + * Ktorm also supports aggregating two or more columns, we just need to wrap our aggregate expressions by [tupleOf] + * in the closure, then the function’s return type becomes `Map>`. + * + * @param aggregationSelector a function that accepts the source table and returns the aggregate expression. + * @return a [Map] associating the key of each group with the result of aggregation of the group elements. + */ +@OptIn(ExperimentalTypeInference::class) +@OverloadResolutionByLambdaReturnType +public suspend inline fun , K, C> EntityGrouping.aggregateColumns( + aggregationSelector: (T) -> ColumnDeclaring +): Map where K : Any, C : Any { + return aggregateColumnsTo(LinkedHashMap(), aggregationSelector) +} + +/** + * Group elements from the source sequence by key and perform the given aggregation for elements in each group, + * then store the results in the [destination] map. 
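As a minimal sketch of the single-column `aggregateColumns` form documented above (the `employees` sequence and a `departmentId` column are assumed from the test schema):

// Generates roughly: select department_id, max(salary) from ... group by department_id
suspend fun highestSalaryByDepartment(database: Database) =
    database.employees
        .groupingBy { it.departmentId }
        .aggregateColumns { max(it.salary) }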
+ * + * The key for each group is provided by the [EntityGrouping.keySelector] function, and the generated SQL is like: + * `select key, aggregation from source group by key`. + * + * Ktorm also supports aggregating two or more columns, we just need to wrap our aggregate expressions by [tupleOf] + * in the closure, then the function’s return type becomes `Map>`. + * + * @param destination a [MutableMap] used to store the results. + * @param aggregationSelector a function that accepts the source table and returns the aggregate expression. + * @return the [destination] map associating the key of each group with the result of aggregation of the group elements. + */ +@OptIn(ExperimentalTypeInference::class) +@OverloadResolutionByLambdaReturnType +public suspend inline fun , K, C, M> EntityGrouping.aggregateColumnsTo( + destination: M, + aggregationSelector: (T) -> ColumnDeclaring +): M where K : Any, C : Any, M : MutableMap { + val keyColumn = keySelector(sequence.sourceTable) + val aggregation = aggregationSelector(sequence.sourceTable) + + val expr = sequence.expression.copy( + columns = listOf(keyColumn, aggregation).map { it.aliased(null) }, + groupBy = listOf(keyColumn.asExpression()) + ) + + Query(sequence.database,expr).asFlow().collect { + val key = keyColumn.sqlType.getResult(it, 0) + val value = aggregation.sqlType.getResult(it, 1) + destination[key] = value + } + return destination +} + +/** + * Group elements from the source sequence by key and count elements in each group. + * + * The key for each group is provided by the [EntityGrouping.keySelector] function, and the generated SQL is like: + * `select key, count(*) from source group by key`. + * + * @return a [Map] associating the key of each group with the count of elements in the group. + */ +public suspend fun , K> EntityGrouping.eachCount(): Map where K : Any { + return eachCountTo(LinkedHashMap()) +} + +/** + * Group elements from the source sequence by key and count elements in each group, + * then store the results in the [destination] map. + * + * The key for each group is provided by the [EntityGrouping.keySelector] function, and the generated SQL is like: + * `select key, count(*) from source group by key`. + * + * @param destination a [MutableMap] used to store the results. + * @return the [destination] map associating the key of each group with the count of elements in the group. + */ +@Suppress("UNCHECKED_CAST") +public suspend fun , K, M> EntityGrouping.eachCountTo( + destination: M +): M where K : Any, M : MutableMap { + return aggregateColumnsTo(destination as MutableMap) { count() } as M +} + +/** + * Group elements from the source sequence by key and sum the columns or expressions provided by the [columnSelector] + * function for elements in each group. + * + * The key for each group is provided by the [EntityGrouping.keySelector] function, and the generated SQL is like: + * `select key, sum(column) from source group by key`. + * + * @param columnSelector a function that accepts the source table and returns the column or expression for summing. + * @return a [Map] associating the key of each group with the summing result in the group. 
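The convenience aggregations read naturally at the call site; an illustrative sketch with the assumed fixture columns:

suspend fun departmentStatistics(database: Database) {
    // select department_id, count(*) from ... group by department_id
    val headcount = database.employees.groupingBy { it.departmentId }.eachCount()

    // select department_id, sum(salary) from ... group by department_id
    val totalSalary = database.employees.groupingBy { it.departmentId }.eachSumBy { it.salary }

    println(headcount)
    println(totalSalary)
}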
+ */ +public suspend inline fun , K, C> EntityGrouping.eachSumBy( + columnSelector: (T) -> ColumnDeclaring +): Map where K : Any, C : Number { + return eachSumByTo(LinkedHashMap(), columnSelector) +} + +/** + * Group elements from the source sequence by key and sum the columns or expressions provided by the [columnSelector] + * function for elements in each group, then store the results in the [destination] map. + * + * The key for each group is provided by the [EntityGrouping.keySelector] function, and the generated SQL is like: + * `select key, sum(column) from source group by key`. + * + * @param destination a [MutableMap] used to store the results. + * @param columnSelector a function that accepts the source table and returns the column or expression for summing. + * @return the [destination] map associating the key of each group with the summing result in the group. + */ +public suspend inline fun , K, C, M> EntityGrouping.eachSumByTo( + destination: M, + columnSelector: (T) -> ColumnDeclaring +): M where K : Any, C : Number, M : MutableMap { + return aggregateColumnsTo(destination) { sum(columnSelector(it)) } +} + +/** + * Group elements from the source sequence by key and get the max value of the columns or expressions provided by the + * [columnSelector] function for elements in each group. + * + * The key for each group is provided by the [EntityGrouping.keySelector] function, and the generated SQL is like: + * `select key, max(column) from source group by key`. + * + * @param columnSelector a function that accepts the source table and returns a column or expression. + * @return a [Map] associating the key of each group with the max value in the group. + */ +public suspend inline fun , K, C> EntityGrouping.eachMaxBy( + columnSelector: (T) -> ColumnDeclaring +): Map where K : Any, C : Comparable { + return eachMaxByTo(LinkedHashMap(), columnSelector) +} + +/** + * Group elements from the source sequence by key and get the max value of the columns or expressions provided by the + * [columnSelector] function for elements in each group, then store the results in the [destination] map. + * + * The key for each group is provided by the [EntityGrouping.keySelector] function, and the generated SQL is like: + * `select key, max(column) from source group by key`. + * + * @param destination a [MutableMap] used to store the results. + * @param columnSelector a function that accepts the source table and returns a column or expression. + * @return a [destination] map associating the key of each group with the max value in the group. + */ +public suspend inline fun , K, C, M> EntityGrouping.eachMaxByTo( + destination: M, + columnSelector: (T) -> ColumnDeclaring +): M where K : Any, C : Comparable, M : MutableMap { + return aggregateColumnsTo(destination) { max(columnSelector(it)) } +} + +/** + * Group elements from the source sequence by key and get the min value of the columns or expressions provided by the + * [columnSelector] function for elements in each group. + * + * The key for each group is provided by the [EntityGrouping.keySelector] function, and the generated SQL is like: + * `select key, min(column) from source group by key`. + * + * @param columnSelector a function that accepts the source table and returns a column or expression. + * @return a [Map] associating the key of each group with the min value in the group. 
+ */ +public suspend inline fun , K, C> EntityGrouping.eachMinBy( + columnSelector: (T) -> ColumnDeclaring +): Map where K : Any, C : Comparable { + return eachMinByTo(LinkedHashMap(), columnSelector) +} + +/** + * Group elements from the source sequence by key and get the min value of the columns or expressions provided by the + * [columnSelector] function for elements in each group, then store the results in the [destination] map. + * + * The key for each group is provided by the [EntityGrouping.keySelector] function, and the generated SQL is like: + * `select key, min(column) from source group by key`. + * + * @param destination a [MutableMap] used to store the results. + * @param columnSelector a function that accepts the source table and returns a column or expression. + * @return a [destination] map associating the key of each group with the min value in the group. + */ +public suspend inline fun , K, C, M> EntityGrouping.eachMinByTo( + destination: M, + columnSelector: (T) -> ColumnDeclaring +): M where K : Any, C : Comparable, M : MutableMap { + return aggregateColumnsTo(destination) { min(columnSelector(it)) } +} + +/** + * Group elements from the source sequence by key and average the columns or expressions provided by the + * [columnSelector] function for elements in each group. + * + * The key for each group is provided by the [EntityGrouping.keySelector] function, and the generated SQL is like: + * `select key, avg(column) from source group by key`. + * + * @param columnSelector a function that accepts the source table and returns the column or expression for averaging. + * @return a [Map] associating the key of each group with the averaging result in the group. + */ +public suspend inline fun , K> EntityGrouping.eachAverageBy( + columnSelector: (T) -> ColumnDeclaring +): Map where K : Any { + return eachAverageByTo(LinkedHashMap(), columnSelector) +} + +/** + * Group elements from the source sequence by key and average the columns or expressions provided by the + * [columnSelector] function for elements in each group, then store the results in the [destination] map. + * + * The key for each group is provided by the [EntityGrouping.keySelector] function, and the generated SQL is like: + * `select key, avg(column) from source group by key`. + * + * @param destination a [MutableMap] used to store the results. + * @param columnSelector a function that accepts the source table and returns the column or expression for averaging. + * @return the [destination] map associating the key of each group with the averaging result in the group. + */ +public suspend inline fun , K, M> EntityGrouping.eachAverageByTo( + destination: M, + columnSelector: (T) -> ColumnDeclaring +): M where K : Any, M : MutableMap { + return aggregateColumnsTo(destination) { avg(columnSelector(it)) } +} + +/** + * Groups elements from the source sequence by key and applies [operation] to the elements of each group sequentially, + * passing the previously accumulated value and the current element as arguments, and stores the results in a new map. + * + * This function is delegated to [Grouping.aggregate], more details can be found in its documentation. 
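Unlike the SQL-level aggregations above, `aggregate` folds each group in memory after loading the rows via `asKotlinGrouping()`; a small, hypothetical sketch:

suspend fun headcountInMemory(database: Database) =
    database.employees
        .groupingBy { it.departmentId }
        .aggregate { _, accumulator: Int?, _, first ->
            // rows are loaded first, then each group is folded in memory
            if (first) 1 else accumulator!! + 1
        }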
+ */ +public suspend inline fun EntityGrouping.aggregate( + operation: (key: K?, accumulator: R?, element: E, first: Boolean) -> R +): Map { + return aggregateTo(LinkedHashMap(), operation) +} + +/** + * Groups elements from the source sequence by key and applies [operation] to the elements of each group sequentially, + * passing the previously accumulated value and the current element as arguments, and stores the results in the given + * [destination] map. + * + * This function is delegated to [Grouping.aggregateTo], more details can be found in its documentation. + */ +public suspend inline fun > EntityGrouping.aggregateTo( + destination: M, + operation: (key: K?, accumulator: R?, element: E, first: Boolean) -> R +): M { + val grouping = asKotlinGrouping() + + for (element in grouping.sourceIterator()) { + val key = grouping.keyOf(element) + val accumulator = destination[key] + destination[key] = operation(key, accumulator, element, accumulator == null && !destination.containsKey(key)) + } + + return destination +} + +/** + * Groups elements from the source sequence by key and applies [operation] to the elements of each group sequentially, + * passing the previously accumulated value and the current element as arguments, and stores the results in a new map. + * An initial value of accumulator is provided by [initialValueSelector] function. + * + * This function is delegated to [Grouping.fold], more details can be found in its documentation. + */ +public suspend inline fun EntityGrouping.fold( + initialValueSelector: (key: K?, element: E) -> R, + operation: (key: K?, accumulator: R, element: E) -> R +): Map { + return foldTo(LinkedHashMap(), initialValueSelector, operation) +} + +/** + * Groups elements from the source sequence by key and applies [operation] to the elements of each group sequentially, + * passing the previously accumulated value and the current element as arguments, and stores the results in the given + * [destination] map. An initial value of accumulator is provided by [initialValueSelector] function. + * + * This function is delegated to [Grouping.foldTo], more details can be found in its documentation. + */ +@Suppress("UNCHECKED_CAST") +public suspend inline fun > EntityGrouping.foldTo( + destination: M, + initialValueSelector: (key: K?, element: E) -> R, + operation: (key: K?, accumulator: R, element: E) -> R +): M { + return aggregateTo(destination) { key, accumulator, element, first -> + val acc = if (first) initialValueSelector(key, element) else accumulator as R + operation(key, acc, element) + } +} + +/** + * Groups elements from the source sequence by key and applies [operation] to the elements of each group sequentially, + * passing the previously accumulated value and the current element as arguments, and stores the results in a new map. + * An initial value of accumulator is the same [initialValue] for each group. + * + * This function is delegated to [Grouping.fold], more details can be found in its documentation. + */ +public suspend inline fun EntityGrouping.fold( + initialValue: R, + operation: (accumulator: R, element: E) -> R +): Map { + return foldTo(LinkedHashMap(), initialValue, operation) +} + +/** + * Groups elements from the source sequence by key and applies [operation] to the elements of each group sequentially, + * passing the previously accumulated value and the current element as arguments, and stores the results in the given + * [destination] map. An initial value of accumulator is the same [initialValue] for each group. 
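A sketch of the `fold` overload that takes a shared initial value (fixture names assumed):

// Collect employee names per department, starting each group from an empty list.
suspend fun namesByDepartment(database: Database) =
    database.employees
        .groupingBy { it.departmentId }
        .fold(emptyList<String>()) { names, employee -> names + employee.name }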
+ * + * This function is delegated to [Grouping.foldTo], more details can be found in its documentation. + */ +@Suppress("UNCHECKED_CAST") +public suspend inline fun > EntityGrouping.foldTo( + destination: M, + initialValue: R, + operation: (accumulator: R, element: E) -> R +): M { + return aggregateTo(destination) { _, accumulator, element, first -> + val acc = if (first) initialValue else accumulator as R + operation(acc, element) + } +} + +/** + * Groups elements from the source sequence by key and applies the reducing [operation] to the elements of each group + * sequentially starting from the second element of the group, passing the previously accumulated value and the current + * element as arguments, and stores the results in a new map. An initial value of accumulator is the first element of + * the group. + * + * This function is delegated to [Grouping.reduce], more details can be found in its documentation. + */ +public suspend inline fun EntityGrouping.reduce( + operation: (key: K?, accumulator: E, element: E) -> E +): Map { + return reduceTo(LinkedHashMap(), operation) +} + +/** + * Groups elements from the source sequence by key and applies the reducing [operation] to the elements of each group + * sequentially starting from the second element of the group, passing the previously accumulated value and the current + * element as arguments, and stores the results in the given [destination] map. An initial value of accumulator is the + * first element of the group. + * + * This function is delegated to [Grouping.reduceTo], more details can be found in its documentation. + */ +public suspend inline fun > EntityGrouping.reduceTo( + destination: M, + operation: (key: K?, accumulator: E, element: E) -> E +): M { + return aggregateTo(destination) { key, accumulator, element, first -> + if (first) element else operation(key, accumulator as E, element) + } +} diff --git a/ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/entity/EntitySequence.kt b/ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/entity/EntitySequence.kt index bb4f8e7..8c53e82 100644 --- a/ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/entity/EntitySequence.kt +++ b/ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/entity/EntitySequence.kt @@ -129,10 +129,10 @@ public class EntitySequence>( * elements when being iterated. */ public suspend fun asKotlinSequence(): Sequence { - return flow().toList().asSequence() + return asFlow().toList().asSequence() } - public suspend fun flow():Flow { + public suspend fun asFlow():Flow { return getRowSet().map(entityExtractor) } @@ -158,7 +158,7 @@ public fun > Database.sequenceOf( * The operation is terminal. */ public suspend fun > EntitySequence.toCollection(destination: C): C { - flow().collect { destination += it } + asFlow().collect { destination += it } return destination } @@ -319,7 +319,7 @@ public suspend inline fun > EntitySequen destination: C, crossinline transform: (E) -> R ): C { - flow().collect { destination += transform(it) } + asFlow().collect { destination += transform(it) } return destination } @@ -347,7 +347,7 @@ public suspend inline fun > Entity destination: C, crossinline transform: (E) -> R? ): C { - flow().collect { element -> transform(element)?.let { destination += it } } + asFlow().collect { element -> transform(element)?.let { destination += it } } return destination } @@ -411,7 +411,7 @@ public suspend inline fun > Entity destination: C, crossinline transform: (index: Int, E) -> R? 
): C { - flow().collectIndexed { index, element -> transform(index, element)?.let { destination += it } } + asFlow().collectIndexed { index, element -> transform(index, element)?.let { destination += it } } return destination } @@ -424,7 +424,7 @@ public suspend inline fun > Entity * @since 3.0.0 */ public suspend inline fun EntitySequence.flatMap(transform: (E) -> Iterable): List { - return flow().toList().flatMapTo(ArrayList(), transform) + return asFlow().toList().flatMapTo(ArrayList(), transform) } /** @@ -439,7 +439,7 @@ public suspend inline fun > EntitySequen destination: C, crossinline transform: (E) -> Iterable ): C { - flow().collect { destination += transform(it) } + asFlow().collect { destination += transform(it) } return destination } @@ -880,7 +880,7 @@ public suspend inline fun EntitySequence.associate(crossin * The operation is terminal. */ public suspend inline fun EntitySequence.associateBy(keySelector: (E) -> K): Map { - return flow().toList().associateByTo(LinkedHashMap(), keySelector) + return asFlow().toList().associateByTo(LinkedHashMap(), keySelector) } /** @@ -897,7 +897,7 @@ public suspend inline fun EntitySequence.associateBy( keySelector: (E) -> K, valueTransform: (E) -> V ): Map { - return flow().toList().associateByTo(LinkedHashMap(), keySelector, valueTransform) + return asFlow().toList().associateByTo(LinkedHashMap(), keySelector, valueTransform) } /** @@ -911,7 +911,7 @@ public suspend inline fun EntitySequence.associateBy( * The operation is terminal. */ public suspend inline fun , V> EntitySequence.associateWith(valueSelector: (K) -> V): Map { - return flow().toList().associateWithTo(LinkedHashMap(), valueSelector) + return asFlow().toList().associateWithTo(LinkedHashMap(), valueSelector) } /** @@ -926,7 +926,7 @@ public suspend inline fun > EntitySequ destination: M, crossinline transform: (E) -> Pair ): M { - flow().collect { destination += transform(it) } + asFlow().collect { destination += transform(it) } return destination } @@ -942,7 +942,7 @@ public suspend inline fun > EntitySequenc destination: M, crossinline keySelector: (E) -> K ): M { - flow().collect { destination.put(keySelector(it), it) } + asFlow().collect { destination.put(keySelector(it), it) } return destination } @@ -959,7 +959,7 @@ public suspend inline fun > EntitySequ crossinline keySelector: (E) -> K, crossinline valueTransform: (E) -> V ): M { - flow().collect { destination.put(keySelector(it), valueTransform(it)) } + asFlow().collect { destination.put(keySelector(it), valueTransform(it)) } return destination } @@ -975,7 +975,7 @@ public suspend inline fun , V, M : MutableMap> EntityS destination: M, crossinline valueSelector: (K) -> V ): M { - flow().collect { destination.put(it, valueSelector(it)) } + asFlow().collect { destination.put(it, valueSelector(it)) } return destination } @@ -992,7 +992,7 @@ public suspend inline fun , V, M : MutableMap> EntityS public suspend fun > EntitySequence.elementAtOrNull(index: Int): E? { try { @Suppress("UnconditionalJumpStatementInLoop") - return this.drop(index).take(1).flow().firstOrNull() + return this.drop(index).take(1).asFlow().firstOrNull() } catch (e: DialectFeatureNotSupportedException) { if (database.logger.isTraceEnabled()) { database.logger.trace("Pagination is not supported, retrieving all records instead: ", e) @@ -1098,7 +1098,7 @@ public suspend inline fun > EntitySequence.first * The operation is terminal. */ public suspend fun EntitySequence.lastOrNull(): E? 
{ - return flow().lastOrNull() + return asFlow().lastOrNull() } /** @@ -1165,7 +1165,7 @@ public suspend inline fun > EntitySequence.findL * The operation is terminal. */ public suspend fun > EntitySequence.singleOrNull(): E? { - return flow().firstOrNull() + return asFlow().firstOrNull() } /** @@ -1186,7 +1186,7 @@ public suspend inline fun > EntitySequence.singl * The operation is terminal. */ public suspend fun > EntitySequence.single(): E { - return flow().first() + return asFlow().first() } /** @@ -1209,7 +1209,7 @@ public suspend inline fun > EntitySequence.singl */ public suspend inline fun EntitySequence.fold(initial: R, crossinline operation: (acc: R, E) -> R): R { var accumulator = initial - flow().collect { accumulator = operation(accumulator, it) } + asFlow().collect { accumulator = operation(accumulator, it) } return accumulator } @@ -1228,7 +1228,7 @@ public suspend inline fun EntitySequence.foldIndexed( ): R { var index = 0 var accumulator = initial - flow().collect { accumulator = operation(index++, accumulator, it) } + asFlow().collect { accumulator = operation(index++, accumulator, it) } return accumulator } @@ -1279,7 +1279,7 @@ public suspend inline fun EntitySequence.reduceIndexed(crossinli */ public suspend inline fun EntitySequence.reduceOrNull(crossinline operation: (acc: E, E) -> E): E? { var accumulator: E? = null - flow().collect { + asFlow().collect { if (accumulator == null) { accumulator = it } else { @@ -1313,7 +1313,7 @@ public suspend inline fun EntitySequence.reduceIndexedOrNull(cro * The operation is terminal. */ public suspend inline fun EntitySequence.forEach(crossinline action: (E) -> Unit) { - flow().collect { action(it) } + asFlow().collect { action(it) } } /** @@ -1325,7 +1325,7 @@ public suspend inline fun EntitySequence.forEach(crossinline act */ public suspend inline fun EntitySequence.forEachIndexed(crossinline action: (index: Int, E) -> Unit) { var index = 0 - flow().collect { action(index++, it) } + asFlow().collect { action(index++, it) } } /** @@ -1335,7 +1335,7 @@ public suspend inline fun EntitySequence.forEachIndexed(crossinl * @since 3.0.0 */ public suspend fun EntitySequence.withIndex(): Sequence> { - val iterator = flow().toList().iterator() + val iterator = asFlow().toList().iterator() return Sequence { IndexingIterator(iterator) } } @@ -1377,7 +1377,7 @@ public suspend inline fun >> Ent destination: M, crossinline keySelector: (E) -> K ): M { - flow().collect { + asFlow().collect { val key = keySelector(it) val list = destination.getOrPut(key) { ArrayList() } list += it @@ -1397,7 +1397,7 @@ public suspend inline fun >> crossinline keySelector: (E) -> K, crossinline valueTransform: (E) -> V ): M { - flow().collect { + asFlow().collect { val key = keySelector(it) val list = destination.getOrPut(key) { ArrayList() } list += valueTransform(it) @@ -1413,13 +1413,11 @@ public suspend inline fun >> * * The operation is intermediate. */ -/* -TODO grouping public suspend fun , K : Any> EntitySequence.groupingBy( keySelector: (T) -> ColumnDeclaring ): EntityGrouping { return EntityGrouping(this, keySelector) -}*/ +} /** * Append the string from all the elements separated using [separator] and using the given [prefix] and [postfix]. 
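The rename from `flow()` to `asFlow()` also changes call sites outside the library; a minimal sketch of consuming a sequence as a coroutine Flow (requires kotlinx.coroutines.flow imports):

suspend fun printAllEmployees(database: Database) {
    // The terminal operators above now all delegate to asFlow().
    database.employees.asFlow().collect { employee -> println(employee.name) }
}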
@@ -1440,7 +1438,7 @@ public suspend fun EntitySequence.joinTo( ): A { buffer.append(prefix) var count = 0 - flow().collect { + asFlow().collect { if (++count > 1) buffer.append(separator) if (limit < 0 || count <= limit) { if (transform != null) buffer.append(transform(it)) else buffer.append(it.toString()) From 8c29c126144f2db59d94d4840d03f9392f168bc8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=AB=98=E9=94=90=E9=9B=84?= <641571835@qq.com> Date: Fri, 18 Feb 2022 18:11:23 +0800 Subject: [PATCH 10/17] update EntityDml add EntityTest (undone) --- build.gradle | 2 +- check-source-header.gradle | 49 ++ ktorm-r2dbc-core/generate-tuples.gradle | 433 +++++++++++++ ktorm-r2dbc-core/ktorm-r2dbc-core.gradle | 2 + .../org/ktorm/r2dbc/database/Database.kt | 16 +- .../kotlin/org/ktorm/r2dbc/entity/Entity.kt | 4 +- .../r2dbc/entity/EntityImplementation.kt | 34 +- .../org/ktorm/r2dbc/entity/EntitySequence.kt | 2 +- .../kotlin/org/ktorm/{ => r2dbc}/BaseTest.kt | 45 +- .../{ => r2dbc}/database/DatabaseTest.kt | 40 +- .../org/ktorm/r2dbc/entity/DataClassTest.kt | 178 ++++++ .../ktorm/r2dbc/entity/EntitySequenceTest.kt | 211 ++++++ .../org/ktorm/r2dbc/entity/EntityTest.kt | 603 ++++++++++++++++++ .../src/test/resources/drop-data.sql | 10 +- .../src/test/resources/init-data.sql | 90 +-- 15 files changed, 1610 insertions(+), 109 deletions(-) create mode 100644 check-source-header.gradle create mode 100644 ktorm-r2dbc-core/generate-tuples.gradle rename ktorm-r2dbc-core/src/test/kotlin/org/ktorm/{ => r2dbc}/BaseTest.kt (68%) rename ktorm-r2dbc-core/src/test/kotlin/org/ktorm/{ => r2dbc}/database/DatabaseTest.kt (86%) create mode 100644 ktorm-r2dbc-core/src/test/kotlin/org/ktorm/r2dbc/entity/DataClassTest.kt create mode 100644 ktorm-r2dbc-core/src/test/kotlin/org/ktorm/r2dbc/entity/EntitySequenceTest.kt create mode 100644 ktorm-r2dbc-core/src/test/kotlin/org/ktorm/r2dbc/entity/EntityTest.kt diff --git a/build.gradle b/build.gradle index 2b00570..3ea80f5 100644 --- a/build.gradle +++ b/build.gradle @@ -28,7 +28,7 @@ subprojects { project -> apply plugin: "maven-publish" apply plugin: "com.jfrog.bintray" apply plugin: "io.gitlab.arturbosch.detekt" - // apply from: "${project.rootDir}/check-source-header.gradle" + apply from: "${project.rootDir}/check-source-header.gradle" repositories { jcenter() diff --git a/check-source-header.gradle b/check-source-header.gradle new file mode 100644 index 0000000..60d3fd8 --- /dev/null +++ b/check-source-header.gradle @@ -0,0 +1,49 @@ + +project.ext.licenseHeaderText = """/* + * Copyright 2018-2021 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +""" + +task checkCopyrightHeader { + doLast { + def headerLines = project.licenseHeaderText.readLines() + + sourceSets.main.kotlin.srcDirs.each { dir -> + def tree = fileTree(dir) + tree.include("**/*.kt") + + tree.visit { + if (!it.isDirectory()) { + def failed = false + + it.file.withReader { reader -> + for (line in headerLines) { + if (line != reader.readLine()) { + failed = true + break + } + } + } + + if (failed) { + throw new IllegalStateException("Copyright header not found in file: " + it.file) + } + } + } + } + } +} + +check.dependsOn(checkCopyrightHeader) \ No newline at end of file diff --git a/ktorm-r2dbc-core/generate-tuples.gradle b/ktorm-r2dbc-core/generate-tuples.gradle new file mode 100644 index 0000000..45a638d --- /dev/null +++ b/ktorm-r2dbc-core/generate-tuples.gradle @@ -0,0 +1,433 @@ + +def generatedSourceDir = "${project.buildDir.absolutePath}/generated/source/main/kotlin" +def maxTupleNumber = 4 + +def generateTuple(Writer writer, int tupleNumber) { + def typeParams = (1..tupleNumber).collect { "out E$it" }.join(", ") + def propertyDefinitions = (1..tupleNumber).collect { "val element$it: E$it" }.join(",\n ") + def toStringTemplate = (1..tupleNumber).collect { "\$element$it" }.join(", ") + + writer.write(""" + /** + * Represents a tuple of $tupleNumber values. + * + * There is no meaning attached to values in this class, it can be used for any purpose. + * Two tuples are equal if all the components are equal. + */ + public data class Tuple$tupleNumber<$typeParams>( + $propertyDefinitions + ) : Serializable { + + override fun toString(): String { + return \"($toStringTemplate)\" + } + + private companion object { + private const val serialVersionUID = 1L + } + } + """.stripIndent()) +} + +def generateTupleOf(Writer writer, int tupleNumber) { + def typeParams = (1..tupleNumber).collect { "E$it" }.join(", ") + def params = (1..tupleNumber).collect { "element$it: E$it" }.join(",\n ") + def elements = (1..tupleNumber).collect { "element$it" }.join(", ") + + writer.write(""" + /** + * Create a tuple of $tupleNumber values. + * + * @since 2.7 + */ + public fun <$typeParams> tupleOf( + $params + ): Tuple$tupleNumber<$typeParams> { + return Tuple$tupleNumber($elements) + } + """.stripIndent()) +} + +def generateToList(Writer writer, int tupleNumber) { + def typeParams = (1..tupleNumber).collect { "E" }.join(", ") + def elements = (1..tupleNumber).collect { "element$it" }.join(", ") + + writer.write(""" + /** + * Convert this tuple into a list. + * + * @since 2.7 + */ + public fun Tuple$tupleNumber<$typeParams>.toList(): List { + return listOf($elements) + } + """.stripIndent()) +} + +def generateMapColumns(Writer writer, int tupleNumber) { + def typeParams = (1..tupleNumber).collect { "C$it : Any" }.join(", ") + def columnDeclarings = (1..tupleNumber).collect { "ColumnDeclaring" }.join(", ") + def resultTypes = (1..tupleNumber).collect { "C$it?" }.join(", ") + def variableNames = (1..tupleNumber).collect { "c$it" }.join(", ") + def resultExtractors = (1..tupleNumber).collect { "c${it}.sqlType.getResult(row, ${it-1})" }.join(", ") + + writer.write(""" + /** + * Customize the selected columns of the internal query by the given [columnSelector] function, and return a [List] + * containing the query results. + * + * See [EntitySequence.mapColumns] for more details. + * + * The operation is terminal. + * + * @param isDistinct specify if the query is distinct, the generated SQL becomes `select distinct` if it's set to true. 
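A hypothetical call site for the generated tuple-based `mapColumns` overloads (fixture columns assumed):

// Select two columns in a single query; the result is a List of Tuple2 (a typealias for Pair).
suspend fun namesAndSalaries(database: Database) =
    database.employees.mapColumns { tupleOf(it.name, it.salary) }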
+ * @param columnSelector a function in which we should return a tuple of columns or expressions to be selected. + * @return a list of the query results. + */ + @Deprecated( + message = "This function will be removed in the future. Please use mapColumns { .. } instead.", + replaceWith = ReplaceWith("mapColumns(isDistinct, columnSelector)") + ) + public suspend inline fun , $typeParams> EntitySequence.mapColumns$tupleNumber( + isDistinct: Boolean = false, + columnSelector: (T) -> Tuple$tupleNumber<$columnDeclarings> + ): List> { + return mapColumns(isDistinct, columnSelector) + } + + /** + * Customize the selected columns of the internal query by the given [columnSelector] function, and append the query + * results to the given [destination]. + * + * See [EntitySequence.mapColumnsTo] for more details. + * + * The operation is terminal. + * + * @param destination a [MutableCollection] used to store the results. + * @param isDistinct specify if the query is distinct, the generated SQL becomes `select distinct` if it's set to true. + * @param columnSelector a function in which we should return a tuple of columns or expressions to be selected. + * @return the [destination] collection of the query results. + */ + @Deprecated( + message = "This function will be removed in the future. Please use mapColumnsTo(destination) { .. } instead.", + replaceWith = ReplaceWith("mapColumnsTo(destination, isDistinct, columnSelector)") + ) + public suspend inline fun , $typeParams, R> EntitySequence.mapColumns${tupleNumber}To( + destination: R, + isDistinct: Boolean = false, + columnSelector: (T) -> Tuple$tupleNumber<$columnDeclarings> + ): R where R : MutableCollection> { + return mapColumnsTo(destination, isDistinct, columnSelector) + } + + /** + * Customize the selected columns of the internal query by the given [columnSelector] function, and return a [List] + * containing the query results. + * + * This function is similar to [EntitySequence.map], but the [columnSelector] closure accepts the current table + * object [T] as the parameter, so what we get in the closure by `it` is the table object instead of an entity + * element. Besides, the function’s return type is a tuple of `ColumnDeclaring`s, and we should return some + * columns or expressions to customize the `select` clause of the generated SQL. + * + * Ktorm supports selecting two or more columns, we just need to wrap our selected columns by [tupleOf] + * in the closure, then the function’s return type becomes `List>`. + * + * The operation is terminal. + * + * @param isDistinct specify if the query is distinct, the generated SQL becomes `select distinct` if it's set to true. + * @param columnSelector a function in which we should return a tuple of columns or expressions to be selected. + * @return a list of the query results. + * @since 3.1.0 + */ + @JvmName("_mapColumns$tupleNumber") + @OptIn(ExperimentalTypeInference::class) + @OverloadResolutionByLambdaReturnType + public suspend inline fun , $typeParams> EntitySequence.mapColumns( + isDistinct: Boolean = false, + columnSelector: (T) -> Tuple$tupleNumber<$columnDeclarings> + ): List> { + return mapColumnsTo(ArrayList(), isDistinct, columnSelector) + } + + /** + * Customize the selected columns of the internal query by the given [columnSelector] function, and append the query + * results to the given [destination]. 
+ * + * This function is similar to [EntitySequence.mapTo], but the [columnSelector] closure accepts the current table + * object [T] as the parameter, so what we get in the closure by `it` is the table object instead of an entity + * element. Besides, the function’s return type is a tuple of `ColumnDeclaring`s, and we should return some + * columns or expressions to customize the `select` clause of the generated SQL. + * + * Ktorm supports selecting two or more columns, we just need to wrap our selected columns by [tupleOf] + * in the closure, then the function’s return type becomes `List>`. + * + * The operation is terminal. + * + * @param destination a [MutableCollection] used to store the results. + * @param isDistinct specify if the query is distinct, the generated SQL becomes `select distinct` if it's set to true. + * @param columnSelector a function in which we should return a tuple of columns or expressions to be selected. + * @return the [destination] collection of the query results. + * @since 3.1.0 + */ + @JvmName("_mapColumns${tupleNumber}To") + @OptIn(ExperimentalTypeInference::class) + @OverloadResolutionByLambdaReturnType + public suspend inline fun , $typeParams, R> EntitySequence.mapColumnsTo( + destination: R, + isDistinct: Boolean = false, + columnSelector: (T) -> Tuple$tupleNumber<$columnDeclarings> + ): R where R : MutableCollection> { + val ($variableNames) = columnSelector(sourceTable) + + val expr = expression.copy( + columns = listOf($variableNames).map { it.aliased(null) }, + isDistinct = isDistinct + ) + + return Query(database, expr).mapTo(destination) { row -> tupleOf($resultExtractors) } + } + """.stripIndent()) +} + +def generateAggregateColumns(Writer writer, int tupleNumber) { + def typeParams = (1..tupleNumber).collect { "C$it : Any" }.join(", ") + def columnDeclarings = (1..tupleNumber).collect { "ColumnDeclaring" }.join(", ") + def resultTypes = (1..tupleNumber).collect { "C$it?" }.join(", ") + def variableNames = (1..tupleNumber).collect { "c$it" }.join(", ") + def resultExtractors = (1..tupleNumber).collect { "c${it}.sqlType.getResult(row, ${it-1})" }.join(", ") + + writer.write(""" + /** + * Perform a tuple of aggregations given by [aggregationSelector] for all elements in the sequence, + * and return the aggregate results. + * + * The operation is terminal. + * + * @param aggregationSelector a function that accepts the source table and returns a tuple of aggregate expressions. + * @return a tuple of the aggregate results. + */ + @Deprecated( + message = "This function will be removed in the future. Please use aggregateColumns { .. } instead.", + replaceWith = ReplaceWith("aggregateColumns(aggregationSelector)") + ) + public suspend inline fun , $typeParams> EntitySequence.aggregateColumns$tupleNumber( + aggregationSelector: (T) -> Tuple$tupleNumber<$columnDeclarings> + ): Tuple$tupleNumber<$resultTypes> { + return aggregateColumns(aggregationSelector) + } + + /** + * Perform a tuple of aggregations given by [aggregationSelector] for all elements in the sequence, + * and return the aggregate results. + * + * Ktorm supports aggregating two or more columns, we just need to wrap our aggregate expressions by + * [tupleOf] in the closure, then the function’s return type becomes `TupleN`. + * + * The operation is terminal. + * + * @param aggregationSelector a function that accepts the source table and returns a tuple of aggregate expressions. + * @return a tuple of the aggregate results. 
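Likewise, the generated tuple overloads of `aggregateColumns` run several aggregates in one query; an illustrative sketch using the assumed `salary` column:

suspend fun salaryRange(database: Database) {
    // One query, two aggregate columns, returned as a Tuple2 and destructured here.
    val (maxSalary, minSalary) = database.employees
        .aggregateColumns { tupleOf(max(it.salary), min(it.salary)) }
    println("salary range: $minSalary..$maxSalary")
}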
+ * @since 3.1.0 + */ + @JvmName("_aggregateColumns$tupleNumber") + @OptIn(ExperimentalTypeInference::class) + @OverloadResolutionByLambdaReturnType + public suspend inline fun , $typeParams> EntitySequence.aggregateColumns( + aggregationSelector: (T) -> Tuple$tupleNumber<$columnDeclarings> + ): Tuple$tupleNumber<$resultTypes> { + val ($variableNames) = aggregationSelector(sourceTable) + + val expr = expression.copy( + columns = listOf($variableNames).map { it.aliased(null) } + ) + + val rowSet = Query(database, expr).asFlow().toList() + + if (rowSet.count() == 1) { + val row = rowSet.first() + return tupleOf($resultExtractors) + } else { + val (sql, _) = database.formatExpression(expr, beautifySql = true) + throw IllegalStateException("Expected 1 row but \${rowSet.count()} returned from sql: \\n\\n\$sql") + } + } + """.stripIndent()) +} + +def generateGroupingAggregateColumns(Writer writer, int tupleNumber) { + def typeParams = (1..tupleNumber).collect { "C$it : Any" }.join(", ") + def columnDeclarings = (1..tupleNumber).collect { "ColumnDeclaring" }.join(", ") + def resultTypes = (1..tupleNumber).collect { "C$it?" }.join(", ") + def variableNames = (1..tupleNumber).collect { "c$it" }.join(", ") + def resultExtractors = (1..tupleNumber).collect { "c${it}.sqlType.getResult(row, ${it})" }.join(", ") + + writer.write(""" + /** + * Group elements from the source sequence by key and perform the given aggregations for elements in each group, + * then store the results in a new [Map]. + * + * The key for each group is provided by the [EntityGrouping.keySelector] function, and the generated SQL is like: + * `select key, aggregation from source group by key`. + * + * @param aggregationSelector a function that accepts the source table and returns a tuple of aggregate expressions. + * @return a [Map] associating the key of each group with the results of aggregations of the group elements. + */ + @Deprecated( + message = "This function will be removed in the future. Please use aggregateColumns { .. } instead.", + replaceWith = ReplaceWith("aggregateColumns(aggregationSelector)") + ) + public suspend inline fun , K : Any, $typeParams> EntityGrouping.aggregateColumns$tupleNumber( + aggregationSelector: (T) -> Tuple$tupleNumber<$columnDeclarings> + ): Map> { + return aggregateColumns(aggregationSelector) + } + + /** + * Group elements from the source sequence by key and perform the given aggregations for elements in each group, + * then store the results in the [destination] map. + * + * The key for each group is provided by the [EntityGrouping.keySelector] function, and the generated SQL is like: + * `select key, aggregation from source group by key`. + * + * @param destination a [MutableMap] used to store the results. + * @param aggregationSelector a function that accepts the source table and returns a tuple of aggregate expressions. + * @return the [destination] map associating the key of each group with the result of aggregations of the group elements. + */ + @Deprecated( + message = "This function will be removed in the future. Please use aggregateColumns(destination) { .. 
} instead.", + replaceWith = ReplaceWith("aggregateColumns(destination, aggregationSelector)") + ) + public suspend inline fun , K : Any, $typeParams, M> EntityGrouping.aggregateColumns${tupleNumber}To( + destination: M, + aggregationSelector: (T) -> Tuple$tupleNumber<$columnDeclarings> + ): M where M : MutableMap> { + return aggregateColumnsTo(destination, aggregationSelector) + } + + /** + * Group elements from the source sequence by key and perform the given aggregations for elements in each group, + * then store the results in a new [Map]. + * + * The key for each group is provided by the [EntityGrouping.keySelector] function, and the generated SQL is like: + * `select key, aggregation from source group by key`. + * + * Ktorm supports aggregating two or more columns, we just need to wrap our aggregate expressions by [tupleOf] + * in the closure, then the function’s return type becomes `Map>`. + * + * @param aggregationSelector a function that accepts the source table and returns a tuple of aggregate expressions. + * @return a [Map] associating the key of each group with the results of aggregations of the group elements. + * @since 3.1.0 + */ + @JvmName("_aggregateColumns$tupleNumber") + @OptIn(ExperimentalTypeInference::class) + @OverloadResolutionByLambdaReturnType + public suspend inline fun , K : Any, $typeParams> EntityGrouping.aggregateColumns( + aggregationSelector: (T) -> Tuple$tupleNumber<$columnDeclarings> + ): Map> { + return aggregateColumnsTo(LinkedHashMap(), aggregationSelector) + } + + /** + * Group elements from the source sequence by key and perform the given aggregations for elements in each group, + * then store the results in the [destination] map. + * + * The key for each group is provided by the [EntityGrouping.keySelector] function, and the generated SQL is like: + * `select key, aggregation from source group by key`. + * + * Ktorm supports aggregating two or more columns, we just need to wrap our aggregate expressions by [tupleOf] + * in the closure, then the function’s return type becomes `Map>`. + * + * @param destination a [MutableMap] used to store the results. + * @param aggregationSelector a function that accepts the source table and returns a tuple of aggregate expressions. + * @return the [destination] map associating the key of each group with the result of aggregations of the group elements. + * @since 3.1.0 + */ + @JvmName("_aggregateColumns${tupleNumber}To") + @OptIn(ExperimentalTypeInference::class) + @OverloadResolutionByLambdaReturnType + public suspend inline fun , K : Any, $typeParams, M> EntityGrouping.aggregateColumnsTo( + destination: M, + aggregationSelector: (T) -> Tuple$tupleNumber<$columnDeclarings> + ): M where M : MutableMap> { + val keyColumn = keySelector(sequence.sourceTable) + val ($variableNames) = aggregationSelector(sequence.sourceTable) + + val expr = sequence.expression.copy( + columns = listOf(keyColumn, $variableNames).map { it.aliased(null) }, + groupBy = listOf(keyColumn.asExpression()) + ) + + Query(sequence.database,expr).forEach { row -> + val key = keyColumn.sqlType.getResult(row, 0) + destination[key] = tupleOf($resultExtractors) + } + return destination + } + """.stripIndent()) +} + +task generateTuples { + doLast { + def outputFile = file("$generatedSourceDir/org/ktorm/r2dbc/entity/Tuples.kt") + outputFile.parentFile.mkdirs() + + outputFile.withWriter { writer -> + writer.write(project.licenseHeaderText) + + writer.write(""" + // This file is auto-generated by generate-tuples.gradle, DO NOT EDIT! 
+ + package org.ktorm.r2dbc.entity + + import org.ktorm.r2dbc.dsl.Query + import kotlinx.coroutines.flow.toList + import org.ktorm.r2dbc.dsl.forEach + import org.ktorm.r2dbc.dsl.mapTo + import org.ktorm.r2dbc.schema.ColumnDeclaring + import org.ktorm.r2dbc.schema.BaseTable + import java.io.Serializable + import kotlin.experimental.ExperimentalTypeInference + + /** + * Set a typealias `Tuple2` for `Pair`. + */ + public typealias Tuple2 = Pair + + /** + * Set a typealias `Tuple3` for `Triple`. + */ + public typealias Tuple3 = Triple + """.stripIndent()) + + (4..maxTupleNumber).each { num -> + generateTuple(writer, num) + } + + (2..maxTupleNumber).each { num -> + generateTupleOf(writer, num) + } + + (4..maxTupleNumber).each { num -> + generateToList(writer, num) + } + + (2..maxTupleNumber).each { num -> + generateMapColumns(writer, num) + } + + (2..maxTupleNumber).each { num -> + generateAggregateColumns(writer, num) + } + + (2..maxTupleNumber).each { num -> + generateGroupingAggregateColumns(writer, num) + } + } + } +} + +sourceSets { + main.kotlin.srcDirs += generatedSourceDir +} + +compileKotlin.dependsOn(generateTuples) diff --git a/ktorm-r2dbc-core/ktorm-r2dbc-core.gradle b/ktorm-r2dbc-core/ktorm-r2dbc-core.gradle index d4d6c80..d10cd3a 100644 --- a/ktorm-r2dbc-core/ktorm-r2dbc-core.gradle +++ b/ktorm-r2dbc-core/ktorm-r2dbc-core.gradle @@ -1,4 +1,6 @@ +apply from: "generate-tuples.gradle" + dependencies { compileOnly "org.slf4j:slf4j-api:1.7.25" compileOnly "commons-logging:commons-logging:1.2" diff --git a/ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/database/Database.kt b/ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/database/Database.kt index de8a3b6..8a44044 100644 --- a/ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/database/Database.kt +++ b/ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/database/Database.kt @@ -338,20 +338,26 @@ public class Database( public fun connect( connectionFactory: ConnectionFactory, dialect: SqlDialect = detectDialectImplementation(), - logger: Logger = detectLoggerImplementation() + logger: Logger = detectLoggerImplementation(), + alwaysQuoteIdentifiers: Boolean = false, + generateSqlInUpperCase: Boolean? = null ): Database { return Database( connectionFactory = connectionFactory, transactionManager = CoroutinesTransactionManager(connectionFactory), dialect = dialect, - logger = logger + logger = logger, + alwaysQuoteIdentifiers = alwaysQuoteIdentifiers, + generateSqlInUpperCase = generateSqlInUpperCase ) } public fun connect( url: String, dialect: SqlDialect = detectDialectImplementation(), - logger: Logger = detectLoggerImplementation() + logger: Logger = detectLoggerImplementation(), + alwaysQuoteIdentifiers: Boolean = false, + generateSqlInUpperCase: Boolean? 
= null ): Database { val connectionFactory = ConnectionFactories.get(url) @@ -359,7 +365,9 @@ public class Database( connectionFactory = connectionFactory, transactionManager = CoroutinesTransactionManager(connectionFactory), dialect = dialect, - logger = logger + logger = logger, + alwaysQuoteIdentifiers = alwaysQuoteIdentifiers, + generateSqlInUpperCase = generateSqlInUpperCase ) } } diff --git a/ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/entity/Entity.kt b/ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/entity/Entity.kt index bfc251d..b3e1e11 100644 --- a/ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/entity/Entity.kt +++ b/ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/entity/Entity.kt @@ -163,7 +163,7 @@ public interface Entity> : Serializable { * @see update */ @Throws(SQLException::class) - public fun flushChanges(): Int + public suspend fun flushChanges(): Int /** * Clear the tracked property changes of this entity. @@ -188,7 +188,7 @@ public interface Entity> : Serializable { * @see flushChanges */ @Throws(SQLException::class) - public fun delete(): Int + public suspend fun delete(): Int /** * Obtain a property's value by its name. diff --git a/ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/entity/EntityImplementation.kt b/ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/entity/EntityImplementation.kt index d66d918..e85e337 100644 --- a/ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/entity/EntityImplementation.kt +++ b/ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/entity/EntityImplementation.kt @@ -27,8 +27,10 @@ import java.lang.reflect.Method import java.util.* import kotlin.collections.LinkedHashMap import kotlin.collections.LinkedHashSet +import kotlin.coroutines.Continuation import kotlin.reflect.KClass import kotlin.reflect.KProperty1 +import kotlin.reflect.full.functions import kotlin.reflect.jvm.javaGetter import kotlin.reflect.jvm.jvmErasure import kotlin.reflect.jvm.jvmName @@ -42,7 +44,12 @@ internal class EntityImplementation( ) : InvocationHandler, Serializable { var values = LinkedHashMap() - @Transient var changedProperties = LinkedHashSet() + @Transient + var changedProperties = LinkedHashSet() + + private val doDeleteFun = this::doDelete + private val doFlushChangeFun = this::doFlushChanges + companion object { private const val serialVersionUID = 1L @@ -63,9 +70,9 @@ internal class EntityImplementation( when (method.name) { "getEntityClass" -> this.entityClass "getProperties" -> Collections.unmodifiableMap(this.values) - /* "flushChanges" -> this.doFlushChanges() "discardChanges" -> this.doDiscardChanges() - "delete" -> this.doDelete()*/ + "flushChanges" -> this.doFlushChangeFun.call(args!!.first()) + "delete" -> this.doDeleteFun.call(args!!.first()) "get" -> this.values[args!![0] as String] "set" -> this.doSetProperty(args!![0] as String, args[1]) "copy" -> this.copy() @@ -107,17 +114,18 @@ internal class EntityImplementation( } } - private val KProperty1<*, *>.defaultValue: Any get() { - try { - return javaGetter!!.returnType.defaultValue - } catch (e: Throwable) { - val msg = "" + - "The value of non-null property [$this] doesn't exist, " + - "an error occurred while trying to create a default one. 
" + - "Please ensure its value exists, or you can mark the return type nullable [${this.returnType}?]" - throw IllegalStateException(msg, e) + private val KProperty1<*, *>.defaultValue: Any + get() { + try { + return javaGetter!!.returnType.defaultValue + } catch (e: Throwable) { + val msg = "" + + "The value of non-null property [$this] doesn't exist, " + + "an error occurred while trying to create a default one. " + + "Please ensure its value exists, or you can mark the return type nullable [${this.returnType}?]" + throw IllegalStateException(msg, e) + } } - } private fun cacheDefaultValue(prop: KProperty1<*, *>, value: Any) { val type = prop.javaGetter!!.returnType diff --git a/ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/entity/EntitySequence.kt b/ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/entity/EntitySequence.kt index 8c53e82..6d014bd 100644 --- a/ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/entity/EntitySequence.kt +++ b/ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/entity/EntitySequence.kt @@ -707,7 +707,7 @@ public suspend inline fun , reified C : Any> EntitySeq columns = listOf(aggregation.aliased(null)) ) - val rowSet = Query(database, expr).doQuery() + val rowSet = Query(database, expr).doQuery().toList() val count = rowSet.count() if (count == 1) { val row = rowSet.first() diff --git a/ktorm-r2dbc-core/src/test/kotlin/org/ktorm/BaseTest.kt b/ktorm-r2dbc-core/src/test/kotlin/org/ktorm/r2dbc/BaseTest.kt similarity index 68% rename from ktorm-r2dbc-core/src/test/kotlin/org/ktorm/BaseTest.kt rename to ktorm-r2dbc-core/src/test/kotlin/org/ktorm/r2dbc/BaseTest.kt index 424faf6..5af19a9 100644 --- a/ktorm-r2dbc-core/src/test/kotlin/org/ktorm/BaseTest.kt +++ b/ktorm-r2dbc-core/src/test/kotlin/org/ktorm/r2dbc/BaseTest.kt @@ -1,12 +1,16 @@ -package org.ktorm +package org.ktorm.r2dbc import kotlinx.coroutines.reactive.awaitFirstOrNull import kotlinx.coroutines.runBlocking import org.junit.After import org.junit.Before import org.ktorm.r2dbc.database.Database +import org.ktorm.r2dbc.database.SqlDialect import org.ktorm.r2dbc.entity.Entity import org.ktorm.r2dbc.entity.sequenceOf +import org.ktorm.r2dbc.expression.ArgumentExpression +import org.ktorm.r2dbc.expression.QueryExpression +import org.ktorm.r2dbc.expression.SqlFormatter import org.ktorm.r2dbc.logging.ConsoleLogger import org.ktorm.r2dbc.logging.LogLevel import org.ktorm.r2dbc.schema.* @@ -25,12 +29,45 @@ open class BaseTest { database = Database.connect( url = "r2dbc:h2:mem:///testdb?DB_CLOSE_DELAY=-1", logger = ConsoleLogger(threshold = LogLevel.TRACE), + dialect = getH2Dialect(), + alwaysQuoteIdentifiers = true ) execSqlScript("init-data.sql") } } + fun getH2Dialect() = object: SqlDialect { + override val identifierQuoteString: String = "\"" + override val extraNameCharacters: String = "" + override val supportsMixedCaseIdentifiers: Boolean = false + override val storesMixedCaseIdentifiers: Boolean = false + override val storesUpperCaseIdentifiers: Boolean = true + override val storesLowerCaseIdentifiers: Boolean = false + override val supportsMixedCaseQuotedIdentifiers: Boolean = true + override val storesMixedCaseQuotedIdentifiers: Boolean = true + override val storesUpperCaseQuotedIdentifiers: Boolean = false + override val storesLowerCaseQuotedIdentifiers: Boolean = false + override val sqlKeywords: Set = emptySet() + override val maxColumnNameLength: Int = 0 + override fun createSqlFormatter(database: Database, beautifySql: Boolean, indentSize: Int): SqlFormatter { + return object:SqlFormatter(database, 
beautifySql, indentSize) { + override fun writePagination(expr: QueryExpression) { + newLine(Indentation.SAME) + if (expr.limit != null) { + writeKeyword("limit ? ") + _parameters += ArgumentExpression(expr.limit, IntSqlType) + } + if (expr.offset != null) { + writeKeyword("offset ? ") + _parameters += ArgumentExpression(expr.offset, IntSqlType) + } + } + + } + } + } + @After open fun destroy() { runBlocking { @@ -76,8 +113,8 @@ open class BaseTest { var salary: Long var department: Department - val upperName get() = name.toUpperCase() - fun upperName() = name.toUpperCase() + val upperName get() = name.uppercase() + fun upperName() = name.uppercase() } interface Customer : Entity { @@ -131,4 +168,4 @@ open class BaseTest { val Database.employees get() = this.sequenceOf(Employees) val Database.customers get() = this.sequenceOf(Customers) -} \ No newline at end of file +} diff --git a/ktorm-r2dbc-core/src/test/kotlin/org/ktorm/database/DatabaseTest.kt b/ktorm-r2dbc-core/src/test/kotlin/org/ktorm/r2dbc/database/DatabaseTest.kt similarity index 86% rename from ktorm-r2dbc-core/src/test/kotlin/org/ktorm/database/DatabaseTest.kt rename to ktorm-r2dbc-core/src/test/kotlin/org/ktorm/r2dbc/database/DatabaseTest.kt index a2f7d8a..09a058b 100644 --- a/ktorm-r2dbc-core/src/test/kotlin/org/ktorm/database/DatabaseTest.kt +++ b/ktorm-r2dbc-core/src/test/kotlin/org/ktorm/r2dbc/database/DatabaseTest.kt @@ -1,12 +1,9 @@ -package org.ktorm.database +package org.ktorm.r2dbc.database -import kotlinx.coroutines.reactive.awaitFirst import kotlinx.coroutines.reactive.awaitFirstOrNull -import kotlinx.coroutines.reactive.awaitSingle import kotlinx.coroutines.runBlocking import org.junit.Test -import org.ktorm.BaseTest -import org.ktorm.r2dbc.database.toList +import org.ktorm.r2dbc.BaseTest import org.ktorm.r2dbc.dsl.insert import org.ktorm.r2dbc.entity.* import java.time.LocalDate @@ -39,6 +36,7 @@ class DatabaseTest : BaseTest() { database.delete(configs) { it.key eq "test" } }*/ +/* @Test fun testTransaction() = runBlocking { @@ -60,6 +58,8 @@ class DatabaseTest : BaseTest() { assert(database.departments.count() == 2) } } +*/ +/* @Test fun testRawSql() = runBlocking { @@ -81,36 +81,8 @@ class DatabaseTest : BaseTest() { assert(names[0] == "VINCE") assert(names[1] == "MARRY") } +*/ - @Test - fun tableTest() = runBlocking { - database.useTransaction { - database.useTransaction { - database.employees.forEach { - println(it) - } - throw RuntimeException() - } - } - assert(true) - } - - @Test - fun insertTest() = runBlocking { - database.useTransaction { - val department = database.departments.toList().first() - val employee = Employee { - this.name = "vince" - this.job = "engineer" - this.manager = null - this.hireDate = LocalDate.now() - this.salary = 100 - this.department = department - } - val add = database.employees.add(employee) - println(employee) - } - } /*fun BaseTable<*>.ulong(name: String): Column { return registerColumn(name, object : SqlType(Types.BIGINT, "bigint unsigned") { diff --git a/ktorm-r2dbc-core/src/test/kotlin/org/ktorm/r2dbc/entity/DataClassTest.kt b/ktorm-r2dbc-core/src/test/kotlin/org/ktorm/r2dbc/entity/DataClassTest.kt new file mode 100644 index 0000000..2a4e482 --- /dev/null +++ b/ktorm-r2dbc-core/src/test/kotlin/org/ktorm/r2dbc/entity/DataClassTest.kt @@ -0,0 +1,178 @@ +package org.ktorm.r2dbc.entity + +import io.r2dbc.spi.Row +import kotlinx.coroutines.runBlocking +import org.junit.Test +import org.ktorm.r2dbc.BaseTest +import org.ktorm.r2dbc.database.Database +import 
org.ktorm.r2dbc.dsl.* +import org.ktorm.r2dbc.schema.* +import org.ktorm.schema.* +import java.time.LocalDate + +/** + * Created by vince on Aug 10, 2019. + */ +class DataClassTest : BaseTest() { + + data class Section( + val id: Int, + val name: String, + val location: String + ) + + data class Staff( + val id: Int, + val name: String, + val job: String, + val managerId: Int, + val hireDate: LocalDate, + val salary: Long, + val sectionId: Int + ) + + object Sections : BaseTable
("t_department") { + val id = int("id").primaryKey() + val name = varchar("name") + val location = varchar("location") + + override fun doCreateEntity(row: QueryRow, withReferences: Boolean) = Section( + id = row[id] ?: 0, + name = row[name].orEmpty(), + location = row[location].orEmpty() + ) + } + + object Staffs : BaseTable("t_employee") { + val id = int("id").primaryKey() + val name = varchar("name") + val job = varchar("job") + val managerId = int("manager_id") + val hireDate = date("hire_date") + val salary = long("salary") + val sectionId = int("department_id") + + override fun doCreateEntity(row: QueryRow, withReferences: Boolean) = Staff( + id = row[id] ?: 0, + name = row[name].orEmpty(), + job = row[job].orEmpty(), + managerId = row[managerId] ?: 0, + hireDate = row[hireDate] ?: LocalDate.now(), + salary = row[salary] ?: 0, + sectionId = row[sectionId] ?: 0 + ) + } + + val Database.staffs get() = this.sequenceOf(Staffs) + + @Test + fun testFindById() = runBlocking { + val staff = database.staffs.find { it.id eq 1 } ?: throw AssertionError() + assert(staff.name == "vince") + assert(staff.job == "engineer") + } + + @Test + fun testFindList() = runBlocking { + val staffs = database.staffs.filter { it.sectionId eq 1 }.toList() + assert(staffs.size == 2) + assert(staffs.mapTo(HashSet()) { it.name } == setOf("vince", "marry")) + } + + @Test + fun testSelectName() = runBlocking { + val staffs = database + .from(Staffs) + .select(Staffs.name) + .where { Staffs.id eq 1 } + .map { Staffs.createEntity(it) } + assert(staffs[0].name == "vince") + } + + @Test + fun testJoin() = runBlocking { + val staffs = database + .from(Staffs) + .leftJoin(Sections, on = Staffs.sectionId eq Sections.id) + .select(Staffs.columns) + .where { Sections.location like "%Guangzhou%" } + .orderBy(Staffs.id.asc()) + .map { Staffs.createEntity(it) } + + assert(staffs.size == 2) + assert(staffs[0].name == "vince") + assert(staffs[1].name == "marry") + } + + @Test + fun testSequence() = runBlocking { + val staffs = database.staffs + .filter { it.sectionId eq 1 } + .sortedBy { it.id } + .toList() + + assert(staffs.size == 2) + assert(staffs[0].name == "vince") + assert(staffs[1].name == "marry") + } + + @Test + fun testCount() = runBlocking { + assert(database.staffs.count { it.sectionId eq 1 } == 2) + } + + @Test + fun testFold() = runBlocking { + val totalSalary = database.staffs.fold(0L) { acc, staff -> acc + staff.salary } + assert(totalSalary == 450L) + } + + @Test + fun testGroupingBy() = runBlocking { + val salaries = database.staffs + .groupingBy { it.sectionId * 2 } + .fold(0L) { acc, staff -> + acc + staff.salary + } + + println(salaries) + assert(salaries.size == 2) + assert(salaries[1] == 150L) + assert(salaries[3] == 300L) + } + + @Test + fun testEachCount() = runBlocking { + val counts = database.staffs + .filter { it.salary less 100000L } + .groupingBy { it.sectionId } + .eachCount() + + println(counts) + assert(counts.size == 2) + assert(counts[0] == 2L) + assert(counts[1] == 2L) + } + + @Test + fun testMapColumns() = runBlocking { + val (name, job) = database.staffs + .filter { it.sectionId eq 1 } + .filterNot { it.managerId.isNotNull() } + .mapColumns { tupleOf(it.name, it.job) } + .single() + + assert(name == "vince") + assert(job == "engineer") + } + + @Test + fun testGroupingAggregate() = runBlocking { + database.staffs + .groupingBy { it.sectionId } + .aggregateColumns { tupleOf(max(it.salary), min(it.salary)) } + .forEach { sectionId, (max, min) -> + println("$sectionId:$max:$min") + } + } +} diff 
--git a/ktorm-r2dbc-core/src/test/kotlin/org/ktorm/r2dbc/entity/EntitySequenceTest.kt b/ktorm-r2dbc-core/src/test/kotlin/org/ktorm/r2dbc/entity/EntitySequenceTest.kt new file mode 100644 index 0000000..7953685 --- /dev/null +++ b/ktorm-r2dbc-core/src/test/kotlin/org/ktorm/r2dbc/entity/EntitySequenceTest.kt @@ -0,0 +1,211 @@ +package org.ktorm.r2dbc.entity + +import kotlinx.coroutines.runBlocking +import org.junit.Test +import org.ktorm.r2dbc.BaseTest +import org.ktorm.r2dbc.dsl.* + +/** + * Created by vince on Mar 22, 2019. + */ +class EntitySequenceTest : BaseTest() { + + @Test + fun testSequenceOf() = runBlocking { + val employee = database + .sequenceOf(Employees, withReferences = false) + .filter { it.name eq "vince" } + .single() + + println(employee) + assert(employee.name == "vince") + assert(employee.department.name.isEmpty()) + } + + @Test + fun testToList() = runBlocking { + val employees = database.employees.toList() + assert(employees.size == 4) + assert(employees[0].name == "vince") + assert(employees[0].department.name == "tech") + } + + @Test + fun testFilter() = runBlocking { + val names = database.employees + .filter { it.departmentId eq 1 } + .filterNot { it.managerId.isNull() } + .toList() + .map { it.name } + + assert(names.size == 1) + assert(names[0] == "marry") + } + + @Test + fun testFilterTo() = runBlocking { + val names = database.employees + .filter { it.departmentId eq 1 } + .filterTo(ArrayList()) { it.managerId.isNull() } + .map { it.name } + + assert(names.size == 1) + assert(names[0] == "vince") + } + + @Test + fun testCount() = runBlocking { + assert(database.employees.filter { it.departmentId eq 1 }.count() == 2) + assert(database.employees.count { it.departmentId eq 1 } == 2) + } + + @Test + fun testAll() = runBlocking { + assert(database.employees.filter { it.departmentId eq 1 }.all { it.salary greater 49L }) + } + + @Test + fun testAssociate() = runBlocking { + val employees = database.employees.filter { it.departmentId eq 1 }.associateBy { it.id } + assert(employees.size == 2) + assert(employees[1]!!.name == "vince") + } + + @Test + fun testDrop() = runBlocking { + try { + val employees = database.employees.drop(3).toList() + assert(employees.size == 1) + assert(employees[0].name == "penny") + } catch (e: UnsupportedOperationException) { + // Expected, pagination should be provided by dialects... + } + } + + @Test + fun testTake() = runBlocking { + try { + val employees = database.employees.take(1).toList() + assert(employees.size == 1) + assert(employees[0].name == "vince") + } catch (e: UnsupportedOperationException) { + // Expected, pagination should be provided by dialects... 
+ } + } + + @Test + fun testFindLast() = runBlocking { + val employee = database.employees.elementAt(3) + assert(employee.name == "penny") + assert(database.employees.elementAtOrNull(4) == null) + } + + @Test + fun testFold() = runBlocking { + val totalSalary = database.employees.fold(0L) { acc, employee -> acc + employee.salary } + assert(totalSalary == 450L) + } + + @Test + fun testSorted() = runBlocking { + val employee = database.employees.sortedByDescending { it.salary }.first() + assert(employee.name == "tom") + } + + @Test + fun testFilterColumns() = runBlocking { + val employee = database.employees + .filterColumns { it.columns + it.department.columns - it.department.location } + .filter { it.department.id eq 1 } + .first() + + assert(employee.department.location.underlying.isEmpty()) + } + + @Test + fun testGroupBy() = runBlocking { + val employees = database.employees.groupBy { it.department.id } + println(employees) + assert(employees.size == 2) + assert(employees[1]!!.sumOf { it.salary.toInt() } == 150) + assert(employees[2]!!.sumOf { it.salary.toInt() } == 300) + } + + @Test + fun testGroupingBy() = runBlocking { + val salaries = database.employees + .groupingBy { it.departmentId * 2 } + .fold(0L) { acc, employee -> + acc + employee.salary + } + + println(salaries) + assert(salaries.size == 2) + assert(salaries[2] == 150L) + assert(salaries[4] == 300L) + } + + @Test + fun testEachCount() = runBlocking { + val counts = database.employees + .filter { it.salary less 100000L } + .groupingBy { it.departmentId } + .eachCount() + + println(counts) + assert(counts.size == 2) + assert(counts[0] == 2L) + assert(counts[1] == 2L) + } + + @Test + fun testEachSum() = runBlocking { + val sums = database.employees + .filter { it.salary lessEq 100000L } + .groupingBy { it.departmentId } + .eachSumBy { it.salary } + + println(sums) + assert(sums.size == 2) + assert(sums[1] == 150L) + assert(sums[2] == 300L) + } + + @Test + fun testJoinToString() = runBlocking { + val salaries = database.employees.joinToString { it.id.toString() } + assert(salaries == "1, 2, 3, 4") + } + + @Test + fun testReduce() = runBlocking { + val emp = database.employees.reduce { acc, employee -> acc.apply { salary += employee.salary } } + assert(emp.salary == 450L) + } + + @Test + fun testSingle() = runBlocking { + val employee = database.employees.singleOrNull { it.departmentId eq 1 } + assert(employee == null) + } + + @Test + fun testMapColumns() = runBlocking { + val names = database.employees.sortedBy { it.id }.mapColumns { it.name } + + println(names) + assert(names.size == 4) + assert(names[0] == "vince") + } + + @Test + fun testFlatMap() = runBlocking { + val names = database.employees + .sortedBy { it.id.asc() } + .flatMapIndexed { index, employee -> listOf("$index:${employee.name}") } + + println(names) + assert(names.size == 4) + assert(names[0] == "0:vince") + } +} diff --git a/ktorm-r2dbc-core/src/test/kotlin/org/ktorm/r2dbc/entity/EntityTest.kt b/ktorm-r2dbc-core/src/test/kotlin/org/ktorm/r2dbc/entity/EntityTest.kt new file mode 100644 index 0000000..8eb3a16 --- /dev/null +++ b/ktorm-r2dbc-core/src/test/kotlin/org/ktorm/r2dbc/entity/EntityTest.kt @@ -0,0 +1,603 @@ +package org.ktorm.r2dbc.entity + +import kotlinx.coroutines.runBlocking +import org.junit.Test +import org.ktorm.r2dbc.BaseTest +import org.ktorm.r2dbc.database.Database +import org.ktorm.r2dbc.dsl.* +import org.ktorm.r2dbc.schema.* +import java.io.ByteArrayInputStream +import java.io.ByteArrayOutputStream +import java.io.ObjectInputStream +import 
java.io.ObjectOutputStream +import java.time.LocalDate +import java.util.* +import kotlin.reflect.jvm.jvmErasure + +/** + * Created by vince on Dec 09, 2018. + */ +class EntityTest : BaseTest() { + + @Test + fun testTypeReference() { + println(Employee) + println(Employee.referencedKotlinType) + assert(Employee.referencedKotlinType.jvmErasure == Employee::class) + + println(Employees) + println(Employees.entityClass) + assert(Employees.entityClass == Employee::class) + + println(Employees.aliased("t")) + println(Employees.aliased("t").entityClass) + assert(Employees.aliased("t").entityClass == Employee::class) + } + + @Test + fun testEntityProperties() { + val employee = Employee { + name = "vince" + } + + println(employee) + + assert(employee["name"] == "vince") + assert(employee.name == "vince") + assert(employee.upperName == "VINCE") + assert(employee.upperName() == "VINCE") + + assert(employee["job"] == null) + assert(employee.job == "") + } + + @Test + fun testSerialize() = runBlocking { + val employee = Employee { + name = "jerry" + job = "trainee" + manager = database.employees.find { it.name eq "vince" } + hireDate = LocalDate.now() + salary = 50 + department = database.departments.find { it.name eq "tech" } ?: throw AssertionError() + } + + val bytes = serialize(employee) + println(Base64.getEncoder().encodeToString(bytes)) + } + + @Test + fun testDeserialize() = runBlocking { + Department { + name = "test" + println(this.javaClass) + println(this) + } + + Employee { + name = "test" + println(this.javaClass) + println(this) + } + + val str = + "rO0ABXN9AAAAAQAbb3JnLmt0b3JtLkJhc2VUZXN0JEVtcGxveWVleHIAF2phdmEubGFuZy5yZWZsZWN0LlByb3h54SfaIMwQQ8sCAAFMAAFodAAlTGphdmEvbGFuZy9yZWZsZWN0L0ludm9jYXRpb25IYW5kbGVyO3hwc3IAJW9yZy5rdG9ybS5lbnRpdHkuRW50aXR5SW1wbGVtZW50YXRpb24AAAAAAAAAAQMAAkwAC2VudGl0eUNsYXNzdAAXTGtvdGxpbi9yZWZsZWN0L0tDbGFzcztMAAZ2YWx1ZXN0ABlMamF2YS91dGlsL0xpbmtlZEhhc2hNYXA7eHB3HQAbb3JnLmt0b3JtLkJhc2VUZXN0JEVtcGxveWVlc3IAF2phdmEudXRpbC5MaW5rZWRIYXNoTWFwNMBOXBBswPsCAAFaAAthY2Nlc3NPcmRlcnhyABFqYXZhLnV0aWwuSGFzaE1hcAUH2sHDFmDRAwACRgAKbG9hZEZhY3RvckkACXRocmVzaG9sZHhwP0AAAAAAAAx3CAAAABAAAAAGdAAEbmFtZXQABWplcnJ5dAADam9idAAHdHJhaW5lZXQAB21hbmFnZXJzcQB+AABzcQB+AAR3HQAbb3JnLmt0b3JtLkJhc2VUZXN0JEVtcGxveWVlc3EAfgAIP0AAAAAAAAx3CAAAABAAAAAGdAACaWRzcgARamF2YS5sYW5nLkludGVnZXIS4qCk94GHOAIAAUkABXZhbHVleHIAEGphdmEubGFuZy5OdW1iZXKGrJUdC5TgiwIAAHhwAAAAAXEAfgALdAAFdmluY2VxAH4ADXQACGVuZ2luZWVydAAIaGlyZURhdGVzcgANamF2YS50aW1lLlNlcpVdhLobIkiyDAAAeHB3BwMAAAfiAQF4dAAGc2FsYXJ5c3IADmphdmEubGFuZy5Mb25nO4vkkMyPI98CAAFKAAV2YWx1ZXhxAH4AFQAAAAAAAABkdAAKZGVwYXJ0bWVudHN9AAAAAQAdb3JnLmt0b3JtLkJhc2VUZXN0JERlcGFydG1lbnR4cQB+AAFzcQB+AAR3HwAdb3JnLmt0b3JtLkJhc2VUZXN0JERlcGFydG1lbnRzcQB+AAg/QAAAAAAADHcIAAAAEAAAAAN0AAJpZHEAfgAWdAAEbmFtZXQABHRlY2h0AAhsb2NhdGlvbnNyACJvcmcua3Rvcm0uQmFzZVRlc3QkTG9jYXRpb25XcmFwcGVykiiIyygSeecCAAFMAAp1bmRlcmx5aW5ndAASTGphdmEvbGFuZy9TdHJpbmc7eHB0AAlHdWFuZ3pob3V4AHh4AHhxAH4AGXNxAH4AGncHAwAAB+QJGnhxAH4AHHNxAH4AHQAAAAAAAAAycQB+AB9zcQB+ACBzcQB+AAR3HwAdb3JnLmt0b3JtLkJhc2VUZXN0JERlcGFydG1lbnRzcQB+AAg/QAAAAAAADHcIAAAAEAAAAANxAH4AJHEAfgAWcQB+ACVxAH4AJnEAfgAnc3EAfgAocQB+ACt4AHh4AHg=" + val bytes = Base64.getDecoder().decode(str) + + val employee = deserialize(bytes) as Employee + println(employee.javaClass) + println(employee) + + assert(employee.name == "jerry") + assert(employee.job == "trainee") + assert(employee.manager?.name == "vince") + assert(employee.salary == 50L) + assert(employee.department.name == "tech") + } + + private fun serialize(obj: Any): ByteArray { + ByteArrayOutputStream().use { 
buffer -> + ObjectOutputStream(buffer).use { output -> + output.writeObject(obj) + output.flush() + return buffer.toByteArray() + } + } + } + + private fun deserialize(bytes: ByteArray): Any { + ByteArrayInputStream(bytes).use { buffer -> + ObjectInputStream(buffer).use { input -> + return input.readObject() + } + } + } + + @Test + fun testFind() = runBlocking { + val employee = database.employees.find { it.id eq 1 } ?: throw AssertionError() + println(employee) + + assert(employee.name == "vince") + assert(employee.job == "engineer") + } + + @Test + fun testFindWithReference() = runBlocking { + val employees = database.employees + .filter { it.department.location like "%Guangzhou%" } + .sortedBy { it.id } + .toList() + + assert(employees.size == 2) + assert(employees[0].name == "vince") + assert(employees[1].name == "marry") + } + + @Test + fun testCreateEntity() = runBlocking { + val employees = database + .from(Employees) + .joinReferencesAndSelect() + .where { + val dept = Employees.departmentId.referenceTable as Departments + dept.location like "%Guangzhou%" + } + .orderBy(Employees.id.asc()) + .map { Employees.createEntity(it) } + + assert(employees.size == 2) + assert(employees[0].name == "vince") + assert(employees[1].name == "marry") + } + + @Test + fun testUpdate() = runBlocking { + var employee = Employee() + employee.id = 2 + employee.job = "engineer" + employee.salary = 100 + // employee.manager = null + database.employees.update(employee) + + employee = database.employees.find { it.id eq 2 } ?: throw AssertionError() + assert(employee.job == "engineer") + assert(employee.salary == 100L) + assert(employee.manager?.id == 1) + } + + @Test + fun testFlushChanges() = runBlocking { + var employee = database.employees.find { it.id eq 2 } ?: throw AssertionError() + employee.job = "engineer" + employee.salary = 100 + employee.manager = null + employee.flushChanges() + employee.flushChanges() + + employee = database.employees.find { it.id eq 2 } ?: throw AssertionError() + assert(employee.job == "engineer") + assert(employee.salary == 100L) + assert(employee.manager == null) + } + + @Test + fun testDeleteEntity() = runBlocking { + val employee = database.employees.find { it.id eq 2 } ?: throw AssertionError() + employee.delete() + + assert(database.employees.count() == 3) + } + + @Test + fun testSaveEntity() = runBlocking { + var employee = Employee { + name = "jerry" + job = "trainee" + manager = null + hireDate = LocalDate.now() + salary = 50 + department = database.departments.find { it.name eq "tech" } ?: throw AssertionError() + } + + database.employees.add(employee) + println(employee) + + employee = database.employees.find { it.id eq 5 } ?: throw AssertionError() + assert(employee.name == "jerry") + assert(employee.department.name == "tech") + + employee.job = "engineer" + employee.salary = 100 + employee.flushChanges() + + employee = database.employees.find { it.id eq 5 } ?: throw AssertionError() + assert(employee.job == "engineer") + assert(employee.salary == 100L) + + employee.delete() + assert(database.employees.count() == 4) + } + + @Test + fun testFindMapById() = runBlocking { + val employees = database.employees.filter { it.id.inList(1, 2) }.associateBy { it.id } + assert(employees.size == 2) + assert(employees[1]?.name == "vince") + assert(employees[2]?.name == "marry") + } + + interface Parent : Entity { + companion object : Entity.Factory() + + var child: Child? + } + + interface Child : Entity { + companion object : Entity.Factory() + + var grandChild: GrandChild? 
+ } + + interface GrandChild : Entity { + companion object : Entity.Factory() + + var id: Int? + } + + object Parents : Table("t_employee") { + val id = int("id").primaryKey().bindTo { it.child?.grandChild?.id } + } + + @Test + fun testHasColumnValue() { + val p1 = Parent() + assert(!p1.implementation.hasColumnValue(Parents.id.binding!!)) + assert(p1.implementation.getColumnValue(Parents.id.binding!!) == null) + + val p2 = Parent { + child = null + } + assert(p2.implementation.hasColumnValue(Parents.id.binding!!)) + assert(p2.implementation.getColumnValue(Parents.id.binding!!) == null) + + val p3 = Parent { + child = Child() + } + assert(!p3.implementation.hasColumnValue(Parents.id.binding!!)) + assert(p3.implementation.getColumnValue(Parents.id.binding!!) == null) + + val p4 = Parent { + child = Child { + grandChild = null + } + } + assert(p4.implementation.hasColumnValue(Parents.id.binding!!)) + assert(p4.implementation.getColumnValue(Parents.id.binding!!) == null) + + val p5 = Parent { + child = Child { + grandChild = GrandChild() + } + } + assert(!p5.implementation.hasColumnValue(Parents.id.binding!!)) + assert(p5.implementation.getColumnValue(Parents.id.binding!!) == null) + + val p6 = Parent { + child = Child { + grandChild = GrandChild { + id = null + } + } + } + assert(p6.implementation.hasColumnValue(Parents.id.binding!!)) + assert(p6.implementation.getColumnValue(Parents.id.binding!!) == null) + + val p7 = Parent { + child = Child { + grandChild = GrandChild { + id = 6 + } + } + } + assert(p7.implementation.hasColumnValue(Parents.id.binding!!)) + assert(p7.implementation.getColumnValue(Parents.id.binding!!) == 6) + } + + @Test + fun testUpdatePrimaryKey() = runBlocking { + try { + val parent = database.sequenceOf(Parents).find { it.id eq 1 } ?: throw AssertionError() + assert(parent.child?.grandChild?.id == 1) + + parent.child?.grandChild?.id = 2 + throw AssertionError() + + } catch (e: UnsupportedOperationException) { + // expected + println(e.message) + } + } + + interface EmployeeTestForReferencePrimaryKey : Entity { + var employee: Employee + var manager: EmployeeManagerTestForReferencePrimaryKey + } + + interface EmployeeManagerTestForReferencePrimaryKey : Entity { + var employee: Employee + } + + object EmployeeTestForReferencePrimaryKeys : Table("t_employee0") { + val id = int("id").primaryKey().references(Employees) { it.employee } + val managerId = int("manager_id").bindTo { it.manager.employee.id } + } + + @Test + fun testUpdateReferencesPrimaryKey() = runBlocking { + val e = database.sequenceOf(EmployeeTestForReferencePrimaryKeys).find { it.id eq 2 } ?: return@runBlocking + e.manager.employee = database.sequenceOf(Employees).find { it.id eq 1 } ?: return@runBlocking + + try { + e.employee = database.sequenceOf(Employees).find { it.id eq 1 } ?: return@runBlocking + throw AssertionError() + } catch (e: UnsupportedOperationException) { + // expected + println(e.message) + } + + e.flushChanges() + } + + @Test + fun testForeignKeyValue() = runBlocking { + val employees = database + .from(Employees) + .select() + .orderBy(Employees.id.asc()) + .map { Employees.createEntity(it) } + + val vince = employees[0] + assert(vince.manager == null) + assert(vince.department.id == 1) + + val marry = employees[1] + assert(marry.manager?.id == 1) + assert(marry.department.id == 1) + + val tom = employees[2] + assert(tom.manager == null) + assert(tom.department.id == 2) + + val penny = employees[3] + assert(penny.manager?.id == 3) + assert(penny.department.id == 2) + } + + @Test + fun 
testCreateEntityWithoutReferences() = runBlocking { + val employees = database + .from(Employees) + .leftJoin(Departments, on = Employees.departmentId eq Departments.id) + .select(Employees.columns + Departments.columns) + .map { Employees.createEntity(it, withReferences = false) } + + employees.forEach { println(it) } + + assert(employees.size == 4) + assert(employees[0].department.id == 1) + assert(employees[1].department.id == 1) + assert(employees[2].department.id == 2) + assert(employees[3].department.id == 2) + } + + @Test + fun testAutoDiscardChanges() = runBlocking { + var department = database.departments.find { it.id eq 2 } ?: return@runBlocking + department.name = "tech" + + val employee = Employee() + employee.department = department + employee.name = "jerry" + employee.job = "trainee" + employee.manager = database.employees.find { it.name eq "vince" } + employee.hireDate = LocalDate.now() + employee.salary = 50 + database.employees.add(employee) + + department.location = LocationWrapper("Guangzhou") + department.flushChanges() + + department = database.departments.find { it.id eq 2 } ?: return@runBlocking + assert(department.name == "tech") + assert(department.location.underlying == "Guangzhou") + } + + interface Emp : Entity { + companion object : Entity.Factory() + + val id: Int + var employee: Employee + var manager: Employee + var hireDate: LocalDate + var salary: Long + var departmentId: Int + } + + object Emps : Table("t_employee") { + val id = int("id").primaryKey().bindTo { it.id } + val name = varchar("name").bindTo { it.employee.name } + val job = varchar("job").bindTo { it.employee.job } + val managerId = int("manager_id").bindTo { it.manager.id } + val hireDate = date("hire_date").bindTo { it.hireDate } + val salary = long("salary").bindTo { it.salary } + val departmentId = int("department_id").bindTo { it.departmentId } + } + + val Database.emps get() = this.sequenceOf(Emps) + + @Test + fun testCheckUnexpectedFlush() = runBlocking { + val emp1 = database.emps.find { it.id eq 1 } ?: return@runBlocking + emp1.employee.name = "jerry" + // emp1.flushChanges() + + val emp2 = Emp { + employee = emp1.employee + hireDate = LocalDate.now() + salary = 100 + departmentId = 1 + } + + try { + database.emps.add(emp2) + throw AssertionError("failed") + + } catch (e: IllegalStateException) { + assert(e.message == "this.employee.name may be unexpectedly discarded, please save it to database first.") + } + } + + @Test + fun testCheckUnexpectedFlush0() = runBlocking { + val emp1 = database.emps.find { it.id eq 1 } ?: return@runBlocking + emp1.employee.name = "jerry" + // emp1.flushChanges() + + val emp2 = database.emps.find { it.id eq 2 } ?: return@runBlocking + emp2.employee = emp1.employee + + try { + emp2.flushChanges() + throw AssertionError("failed") + + } catch (e: IllegalStateException) { + assert(e.message == "this.employee.name may be unexpectedly discarded, please save it to database first.") + } + } + + @Test + fun testCheckUnexpectedFlush1() = runBlocking { + val employee = database.employees.find { it.id eq 1 } ?: return@runBlocking + employee.name = "jerry" + // employee.flushChanges() + + val emp = database.emps.find { it.id eq 2 } ?: return@runBlocking + emp.employee = employee + + try { + emp.flushChanges() + throw AssertionError("failed") + + } catch (e: IllegalStateException) { + assert(e.message == "this.employee.name may be unexpectedly discarded, please save it to database first.") + } + } + + @Test + fun testFlushChangesForDefaultValues() = runBlocking { + var 
emp = database.emps.find { it.id eq 1 } ?: return@runBlocking + emp.manager.id = 2 + emp.flushChanges() + + emp = database.emps.find { it.id eq 1 } ?: return@runBlocking + assert(emp.manager.id == 2) + } + + @Test + fun testDefaultValuesCache() = runBlocking { + val department = Department() + assert(department.id == 0) + assert(department["id"] == null) + } + + @Test + fun testCopyStatus() = runBlocking { + var employee = database.employees.find { it.id eq 2 }?.copy() ?: return@runBlocking + employee.name = "jerry" + employee.manager?.id = 3 + employee.flushChanges() + + employee = database.employees.find { it.id eq 2 } ?: return@runBlocking + assert(employee.name == "jerry") + assert(employee.manager?.id == 3) + } + + @Test + fun testDeepCopy() = runBlocking { + val employee = database.employees.find { it.id eq 2 } ?: return@runBlocking + val copy = employee.copy() + + assert(employee == copy) + assert(employee !== copy) + assert(employee.hireDate !== copy.hireDate) // should not be the same instance because of deep copy. + assert(copy.manager?.implementation?.parent === copy.implementation) // should keep the parent relationship. + } + + @Test + fun testRemoveIf() = runBlocking { + database.employees.removeIf { it.departmentId eq 1 } + assert(database.employees.count() == 2) + } + + @Test + fun testClear() = runBlocking { + database.employees.clear() + assert(database.employees.isEmpty()) + } + + @Test + fun testAddAndFlushChanges() = runBlocking { + var employee = Employee { + name = "jerry" + job = "trainee" + manager = database.employees.find { it.name eq "vince" } + hireDate = LocalDate.now() + salary = 50 + department = database.departments.find { it.name eq "tech" } ?: throw AssertionError() + } + + database.employees.add(employee) + + employee.job = "engineer" + employee.flushChanges() + + employee = database.employees.find { it.id eq employee.id } ?: throw AssertionError() + assert(employee.job == "engineer") + } + + @Test + fun testValueEquality() = runBlocking { + val now = LocalDate.now() + val employee1 = Employee { + id = 1 + name = "Eric" + job = "contributor" + hireDate = now + salary = 50 + } + + val employee2 = Employee { + id = 1 + name = "Eric" + job = "contributor" + hireDate = now + salary = 50 + } + + assert(employee1 == employee2) + } + + @Test + fun testDifferentClassesSameValuesNotEqual() { + val employee = Employee { + name = "name" + } + + val department = Department { + name = "name" + } + + assert(employee != department) + } +} diff --git a/ktorm-r2dbc-core/src/test/resources/drop-data.sql b/ktorm-r2dbc-core/src/test/resources/drop-data.sql index dcfef45..942a78d 100644 --- a/ktorm-r2dbc-core/src/test/resources/drop-data.sql +++ b/ktorm-r2dbc-core/src/test/resources/drop-data.sql @@ -1,6 +1,6 @@ -DROP TABLE IF EXISTS "T_DEPARTMENT"; -DROP TABLE IF EXISTS "T_EMPLOYEE"; -DROP TABLE IF EXISTS "T_EMPLOYEE0"; -DROP TABLE IF EXISTS "COMPANY"."T_CUSTOMER"; -DROP SCHEMA IF EXISTS "COMPANY"; +drop table if exists "t_department"; +drop table if exists "t_employee"; +drop table if exists "t_employee0"; +drop table if exists "company"."t_customer"; +drop schema if exists "company"; diff --git a/ktorm-r2dbc-core/src/test/resources/init-data.sql b/ktorm-r2dbc-core/src/test/resources/init-data.sql index 9d1cc8f..6a16aed 100644 --- a/ktorm-r2dbc-core/src/test/resources/init-data.sql +++ b/ktorm-r2dbc-core/src/test/resources/init-data.sql @@ -1,60 +1,60 @@ -CREATE TABLE "T_DEPARTMENT"( - "ID" INT NOT NULL PRIMARY KEY AUTO_INCREMENT, - "NAME" VARCHAR(128) NOT NULL, - 
"LOCATION" VARCHAR(128) NOT NULL, - "MIXEDCASE" VARCHAR(128) +create table "t_department"( + "id" int not null primary key auto_increment, + "name" varchar(128) not null, + "location" varchar(128) not null, + "mixedCase" varchar(128) ); -CREATE TABLE "T_EMPLOYEE"( - "ID" INT NOT NULL PRIMARY KEY AUTO_INCREMENT, - "NAME" VARCHAR(128) NOT NULL, - "JOB" VARCHAR(128) NOT NULL, - "MANAGER_ID" INT NULL, - "HIRE_DATE" DATE NOT NULL, - "SALARY" BIGINT NOT NULL, - "DEPARTMENT_ID" INT NOT NULL +create table "t_employee"( + "id" int not null primary key auto_increment, + "name" varchar(128) not null, + "job" varchar(128) not null, + "manager_id" int null, + "hire_date" date not null, + "salary" bigint not null, + "department_id" int not null ); -CREATE SCHEMA "COMPANY"; -CREATE TABLE "COMPANY"."T_CUSTOMER" ( - "ID" INT NOT NULL PRIMARY KEY AUTO_INCREMENT, - "NAME" VARCHAR(128) NOT NULL, - "EMAIL" VARCHAR(128) NOT NULL, - "PHONE_NUMBER" VARCHAR(128) NOT NULL +create schema "company"; +create table "company"."t_customer" ( + "id" int not null primary key auto_increment, + "name" varchar(128) not null, + "email" varchar(128) not null, + "phone_number" varchar(128) not null ); -INSERT INTO "T_DEPARTMENT"("NAME", "LOCATION") VALUES ('TECH', 'GUANGZHOU'); -INSERT INTO "T_DEPARTMENT"("NAME", "LOCATION") VALUES ('FINANCE', 'BEIJING'); +insert into "t_department"("name", "location") values ('tech', 'Guangzhou'); +insert into "t_department"("name", "location") values ('finance', 'Beijing'); -INSERT INTO "T_EMPLOYEE"("NAME", "JOB", "MANAGER_ID", "HIRE_DATE", "SALARY", "DEPARTMENT_ID") - VALUES ('VINCE', 'ENGINEER', NULL, '2018-01-01', 100, 1); -INSERT INTO "T_EMPLOYEE"("NAME", "JOB", "MANAGER_ID", "HIRE_DATE", "SALARY", "DEPARTMENT_ID") - VALUES ('MARRY', 'TRAINEE', 1, '2019-01-01', 50, 1); +insert into "t_employee"("name", "job", "manager_id", "hire_date", "salary", "department_id") + values ('vince', 'engineer', null, '2018-01-01', 100, 1); +insert into "t_employee"("name", "job", "manager_id", "hire_date", "salary", "department_id") + values ('marry', 'trainee', 1, '2019-01-01', 50, 1); -INSERT INTO "T_EMPLOYEE"("NAME", "JOB", "MANAGER_ID", "HIRE_DATE", "SALARY", "DEPARTMENT_ID") - VALUES ('TOM', 'DIRECTOR', NULL, '2018-01-01', 200, 2); -INSERT INTO "T_EMPLOYEE"("NAME", "JOB", "MANAGER_ID", "HIRE_DATE", "SALARY", "DEPARTMENT_ID") - VALUES ('PENNY', 'ASSISTANT', 3, '2019-01-01', 100, 2); +insert into "t_employee"("name", "job", "manager_id", "hire_date", "salary", "department_id") + values ('tom', 'director', null, '2018-01-01', 200, 2); +insert into "t_employee"("name", "job", "manager_id", "hire_date", "salary", "department_id") + values ('penny', 'assistant', 3, '2019-01-01', 100, 2); -CREATE TABLE "T_EMPLOYEE0"( - "ID" INT NOT NULL PRIMARY KEY AUTO_INCREMENT, - "NAME" VARCHAR(128) NOT NULL, - "JOB" VARCHAR(128) NOT NULL, - "MANAGER_ID" INT NULL, - "HIRE_DATE" DATE NOT NULL, - "SALARY" BIGINT NOT NULL, - "DEPARTMENT_ID" INT NOT NULL +create table "t_employee0"( + "id" int not null primary key auto_increment, + "name" varchar(128) not null, + "job" varchar(128) not null, + "manager_id" int null, + "hire_date" date not null, + "salary" bigint not null, + "department_id" int not null ); -INSERT INTO "T_EMPLOYEE0"("NAME", "JOB", "MANAGER_ID", "HIRE_DATE", "SALARY", "DEPARTMENT_ID") - VALUES ('VINCE', 'ENGINEER', NULL, '2018-01-01', 100, 1); -INSERT INTO "T_EMPLOYEE0"("NAME", "JOB", "MANAGER_ID", "HIRE_DATE", "SALARY", "DEPARTMENT_ID") - VALUES ('MARRY', 'TRAINEE', 1, '2019-01-01', 50, 1); +insert into 
"t_employee0"("name", "job", "manager_id", "hire_date", "salary", "department_id") + values ('vince', 'engineer', null, '2018-01-01', 100, 1); +insert into "t_employee0"("name", "job", "manager_id", "hire_date", "salary", "department_id") + values ('marry', 'trainee', 1, '2019-01-01', 50, 1); -INSERT INTO "T_EMPLOYEE0"("NAME", "JOB", "MANAGER_ID", "HIRE_DATE", "SALARY", "DEPARTMENT_ID") - VALUES ('TOM', 'DIRECTOR', NULL, '2018-01-01', 200, 2); -INSERT INTO "T_EMPLOYEE0"("NAME", "JOB", "MANAGER_ID", "HIRE_DATE", "SALARY", "DEPARTMENT_ID") - VALUES ('PENNY', 'ASSISTANT', 3, '2019-01-01', 100, 2); \ No newline at end of file +insert into "t_employee0"("name", "job", "manager_id", "hire_date", "salary", "department_id") + values ('tom', 'director', null, '2018-01-01', 200, 2); +insert into "t_employee0"("name", "job", "manager_id", "hire_date", "salary", "department_id") + values ('penny', 'assistant', 3, '2019-01-01', 100, 2); From 3eb9f0864161bd45e6aa73a4114351cdf910e805 Mon Sep 17 00:00:00 2001 From: htt <641571835@qq.com> Date: Fri, 18 Feb 2022 19:06:43 +0800 Subject: [PATCH 11/17] add sqlType convert support --- .../org/ktorm/r2dbc/entity/EntitySequence.kt | 2 +- .../kotlin/org/ktorm/r2dbc/schema/SqlType.kt | 20 ++++++- .../kotlin/org/ktorm/r2dbc/schema/SqlTypes.kt | 60 +++++++++++++++++-- .../ktorm/r2dbc/entity/EntitySequenceTest.kt | 2 +- 4 files changed, 76 insertions(+), 8 deletions(-) diff --git a/ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/entity/EntitySequence.kt b/ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/entity/EntitySequence.kt index 6d014bd..d7e72c5 100644 --- a/ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/entity/EntitySequence.kt +++ b/ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/entity/EntitySequence.kt @@ -1413,7 +1413,7 @@ public suspend inline fun >> * * The operation is intermediate. */ -public suspend fun , K : Any> EntitySequence.groupingBy( +public fun , K : Any> EntitySequence.groupingBy( keySelector: (T) -> ColumnDeclaring ): EntityGrouping { return EntityGrouping(this, keySelector) diff --git a/ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/schema/SqlType.kt b/ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/schema/SqlType.kt index d3bde4e..10859f8 100644 --- a/ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/schema/SqlType.kt +++ b/ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/schema/SqlType.kt @@ -1,7 +1,6 @@ package org.ktorm.r2dbc.schema import io.r2dbc.spi.Row -import io.r2dbc.spi.RowMetadata import io.r2dbc.spi.Statement import kotlin.reflect.KClass @@ -58,6 +57,25 @@ public open class SimpleSqlType(public val kotlinType: KClass) : Sql } +public abstract class ConvertibleSqlType(kotlinType: KClass) : SimpleSqlType(kotlinType) { + + override val javaType: Class = kotlinType.javaObjectType + + public abstract fun convert(value: Any): R + + override fun getResult(row: Row, index: Int): R? { + val metadata = row.metadata.getColumnMetadata(index) + val value = row.get(index, metadata.javaType) ?: return null + return convert(value) + } + + override fun getResult(row: Row, name: String): R? 
{ + val metadata = row.metadata.getColumnMetadata(name) + val value = row.get(name, metadata.javaType) ?: return null + return convert(value) + } +} + public class TransformedSqlType( public val underlyingType: SqlType, public val fromUnderlyingValue: (T) -> R, diff --git a/ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/schema/SqlTypes.kt b/ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/schema/SqlTypes.kt index 9659eb4..9c1765d 100644 --- a/ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/schema/SqlTypes.kt +++ b/ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/schema/SqlTypes.kt @@ -43,7 +43,15 @@ public fun BaseTable<*>.int(name: String): Column { /** * [SqlType] implementation represents `int` SQL type. */ -public object IntSqlType : SimpleSqlType(Int::class) +public object IntSqlType : ConvertibleSqlType(Int::class) { + override fun convert(value: Any): Int { + return when (value) { + is Number -> value.toInt() + is String -> value.toInt() + else -> throw IllegalStateException("Converting type is not supported from value:$value") + } + } +} /** * Define a column typed of [ShortSqlType]. @@ -71,7 +79,16 @@ public fun BaseTable<*>.long(name: String): Column { /** * [SqlType] implementation represents `long` SQL type. */ -public object LongSqlType : SimpleSqlType(Long::class) +public object LongSqlType : ConvertibleSqlType(Long::class) { + override fun convert(value: Any): Long { + return when (value) { + is Number -> value.toLong() + is String -> value.toLong() + else -> throw IllegalStateException("Converting type is not supported from value:$value") + } + } +} + /** * Define a column typed of [FloatSqlType]. */ @@ -82,7 +99,15 @@ public fun BaseTable<*>.float(name: String): Column { /** * [SqlType] implementation represents `float` SQL type. */ -public object FloatSqlType : SimpleSqlType(Float::class) +public object FloatSqlType : ConvertibleSqlType(Float::class) { + override fun convert(value: Any): Float { + return when (value) { + is Number -> value.toFloat() + is String -> value.toFloat() + else -> throw IllegalStateException("Converting type is not supported from value:$value") + } + } +} /** * Define a column typed of [DoubleSqlType]. @@ -94,7 +119,15 @@ public fun BaseTable<*>.double(name: String): Column { /** * [SqlType] implementation represents `double` SQL type. */ -public object DoubleSqlType : SimpleSqlType(Double::class) +public object DoubleSqlType : ConvertibleSqlType(Double::class) { + override fun convert(value: Any): Double { + return when (value) { + is Number -> value.toDouble() + is String -> value.toDouble() + else -> throw IllegalStateException("Converting type is not supported from value:$value") + } + } +} /** * Define a column typed of [DecimalSqlType]. @@ -106,7 +139,20 @@ public fun BaseTable<*>.decimal(name: String): Column { /** * [SqlType] implementation represents `decimal` SQL type. */ -public object DecimalSqlType : SimpleSqlType(BigDecimal::class) +public object DecimalSqlType : ConvertibleSqlType(BigDecimal::class) { + override fun convert(value: Any): BigDecimal { + return when (value) { + is BigDecimal -> value + is Int -> BigDecimal(value) + is Long -> BigDecimal(value) + is Double -> BigDecimal(value) + is Float -> BigDecimal(value.toDouble()) + is String -> BigDecimal(value) + else -> throw IllegalStateException("Converting type is not supported from value:$value") + } + } + +} /** * Define a column typed of [VarcharSqlType]. 
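Note on the new ConvertibleSqlType introduced above: only convert() needs to be overridden, since getResult() fetches the raw column value via the driver-reported Java type and routes it through that conversion. A minimal sketch under that assumption follows; BooleanLikeSqlType and the booleanLike() helper are hypothetical illustrations, not part of this patch, and registerColumn() is assumed to be callable from extension functions as in the int()/long() helpers.

// Minimal sketch (hypothetical, not part of this patch): a user-defined type built on ConvertibleSqlType.
// Only convert() is overridden; ConvertibleSqlType.getResult() reads the raw value using the
// driver-reported Java type and delegates the narrowing/parsing to convert().
object BooleanLikeSqlType : ConvertibleSqlType<Boolean>(Boolean::class) {
    override fun convert(value: Any): Boolean {
        return when (value) {
            is Boolean -> value
            is Number -> value.toInt() != 0  // e.g. tinyint(1)-style columns
            is String -> value == "1" || value.equals("true", ignoreCase = true)
            else -> throw IllegalStateException("Converting type is not supported from value:$value")
        }
    }
}

// Hypothetical column helper following the same pattern as the int()/long() helpers above.
fun BaseTable<*>.booleanLike(name: String): Column<Boolean> {
    return registerColumn(name, BooleanLikeSqlType)
}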
@@ -143,6 +189,7 @@ public fun BaseTable<*>.blob(name: String): Column { * [SqlType] implementation represents `blob` SQL type. */ public object BlobSqlType : SimpleSqlType(ByteArray::class) + /** * Define a column typed of [BytesSqlType]. */ @@ -154,6 +201,7 @@ public fun BaseTable<*>.bytes(name: String): Column { * [SqlType] implementation represents `bytes` SQL type. */ public object BytesSqlType : SimpleSqlType(ByteArray::class) + /** * Define a column typed of [TimestampSqlType]. */ @@ -177,6 +225,7 @@ public fun BaseTable<*>.timestamp(name: String): Column { * [SqlType] implementation represents `timestamp` SQL type. */ public object InstantSqlType : SimpleSqlType(Instant::class) + /** * Define a column typed of [LocalDateTimeSqlType]. */ @@ -212,6 +261,7 @@ public fun BaseTable<*>.time(name: String): Column { * [SqlType] implementation represents `time` SQL type. */ public object LocalTimeSqlType : SimpleSqlType(LocalTime::class) + /** * Define a column typed of [MonthDaySqlType], instances of [MonthDay] are saved as strings in format `MM-dd`. */ diff --git a/ktorm-r2dbc-core/src/test/kotlin/org/ktorm/r2dbc/entity/EntitySequenceTest.kt b/ktorm-r2dbc-core/src/test/kotlin/org/ktorm/r2dbc/entity/EntitySequenceTest.kt index 7953685..2b64990 100644 --- a/ktorm-r2dbc-core/src/test/kotlin/org/ktorm/r2dbc/entity/EntitySequenceTest.kt +++ b/ktorm-r2dbc-core/src/test/kotlin/org/ktorm/r2dbc/entity/EntitySequenceTest.kt @@ -154,8 +154,8 @@ class EntitySequenceTest : BaseTest() { println(counts) assert(counts.size == 2) - assert(counts[0] == 2L) assert(counts[1] == 2L) + assert(counts[2] == 2L) } @Test From 4e0bdaa9fae942eeb77a6d3d675cf5b67940f393 Mon Sep 17 00:00:00 2001 From: htt <641571835@qq.com> Date: Fri, 18 Feb 2022 21:20:41 +0800 Subject: [PATCH 12/17] add test update SqlType update queryRow fix batchUpdate fix entity create remove deprecated function --- ktorm-r2dbc-core/generate-tuples.gradle | 109 ------- .../org/ktorm/r2dbc/database/Database.kt | 22 +- .../kotlin/org/ktorm/r2dbc/dsl/Aggregation.kt | 6 +- .../main/kotlin/org/ktorm/r2dbc/dsl/Dml.kt | 28 +- .../kotlin/org/ktorm/r2dbc/dsl/QueryRow.kt | 20 ++ .../r2dbc/entity/EntityImplementation.kt | 11 +- .../ktorm/r2dbc/expression/SqlExpressions.kt | 8 +- .../org/ktorm/r2dbc/schema/BaseTable.kt | 29 +- .../kotlin/org/ktorm/r2dbc/schema/SqlType.kt | 11 +- .../kotlin/org/ktorm/r2dbc/schema/Table.kt | 5 +- .../r2dbc/database/CircularReferenceTest.kt | 79 +++++ .../ktorm/r2dbc/database/CompoundKeysTest.kt | 111 +++++++ .../org/ktorm/r2dbc/database/DatabaseTest.kt | 89 +++--- .../org/ktorm/r2dbc/dsl/AggregationTest.kt | 87 ++++++ .../kotlin/org/ktorm/r2dbc/dsl/DmlTest.kt | 157 ++++++++++ .../kotlin/org/ktorm/r2dbc/dsl/JoinTest.kt | 86 ++++++ .../kotlin/org/ktorm/r2dbc/dsl/QueryTest.kt | 277 ++++++++++++++++++ .../org/ktorm/r2dbc/entity/DataClassTest.kt | 8 +- .../ktorm/r2dbc/entity/EntitySequenceTest.kt | 2 +- .../org/ktorm/r2dbc/entity/EntityTest.kt | 4 +- 20 files changed, 907 insertions(+), 242 deletions(-) create mode 100644 ktorm-r2dbc-core/src/test/kotlin/org/ktorm/r2dbc/database/CircularReferenceTest.kt create mode 100644 ktorm-r2dbc-core/src/test/kotlin/org/ktorm/r2dbc/database/CompoundKeysTest.kt create mode 100644 ktorm-r2dbc-core/src/test/kotlin/org/ktorm/r2dbc/dsl/AggregationTest.kt create mode 100644 ktorm-r2dbc-core/src/test/kotlin/org/ktorm/r2dbc/dsl/DmlTest.kt create mode 100644 ktorm-r2dbc-core/src/test/kotlin/org/ktorm/r2dbc/dsl/JoinTest.kt create mode 100644 
ktorm-r2dbc-core/src/test/kotlin/org/ktorm/r2dbc/dsl/QueryTest.kt diff --git a/ktorm-r2dbc-core/generate-tuples.gradle b/ktorm-r2dbc-core/generate-tuples.gradle index 45a638d..abaf011 100644 --- a/ktorm-r2dbc-core/generate-tuples.gradle +++ b/ktorm-r2dbc-core/generate-tuples.gradle @@ -72,54 +72,6 @@ def generateMapColumns(Writer writer, int tupleNumber) { def resultExtractors = (1..tupleNumber).collect { "c${it}.sqlType.getResult(row, ${it-1})" }.join(", ") writer.write(""" - /** - * Customize the selected columns of the internal query by the given [columnSelector] function, and return a [List] - * containing the query results. - * - * See [EntitySequence.mapColumns] for more details. - * - * The operation is terminal. - * - * @param isDistinct specify if the query is distinct, the generated SQL becomes `select distinct` if it's set to true. - * @param columnSelector a function in which we should return a tuple of columns or expressions to be selected. - * @return a list of the query results. - */ - @Deprecated( - message = "This function will be removed in the future. Please use mapColumns { .. } instead.", - replaceWith = ReplaceWith("mapColumns(isDistinct, columnSelector)") - ) - public suspend inline fun , $typeParams> EntitySequence.mapColumns$tupleNumber( - isDistinct: Boolean = false, - columnSelector: (T) -> Tuple$tupleNumber<$columnDeclarings> - ): List> { - return mapColumns(isDistinct, columnSelector) - } - - /** - * Customize the selected columns of the internal query by the given [columnSelector] function, and append the query - * results to the given [destination]. - * - * See [EntitySequence.mapColumnsTo] for more details. - * - * The operation is terminal. - * - * @param destination a [MutableCollection] used to store the results. - * @param isDistinct specify if the query is distinct, the generated SQL becomes `select distinct` if it's set to true. - * @param columnSelector a function in which we should return a tuple of columns or expressions to be selected. - * @return the [destination] collection of the query results. - */ - @Deprecated( - message = "This function will be removed in the future. Please use mapColumnsTo(destination) { .. } instead.", - replaceWith = ReplaceWith("mapColumnsTo(destination, isDistinct, columnSelector)") - ) - public suspend inline fun , $typeParams, R> EntitySequence.mapColumns${tupleNumber}To( - destination: R, - isDistinct: Boolean = false, - columnSelector: (T) -> Tuple$tupleNumber<$columnDeclarings> - ): R where R : MutableCollection> { - return mapColumnsTo(destination, isDistinct, columnSelector) - } - /** * Customize the selected columns of the internal query by the given [columnSelector] function, and return a [List] * containing the query results. @@ -197,25 +149,6 @@ def generateAggregateColumns(Writer writer, int tupleNumber) { def resultExtractors = (1..tupleNumber).collect { "c${it}.sqlType.getResult(row, ${it-1})" }.join(", ") writer.write(""" - /** - * Perform a tuple of aggregations given by [aggregationSelector] for all elements in the sequence, - * and return the aggregate results. - * - * The operation is terminal. - * - * @param aggregationSelector a function that accepts the source table and returns a tuple of aggregate expressions. - * @return a tuple of the aggregate results. - */ - @Deprecated( - message = "This function will be removed in the future. Please use aggregateColumns { .. 
} instead.", - replaceWith = ReplaceWith("aggregateColumns(aggregationSelector)") - ) - public suspend inline fun , $typeParams> EntitySequence.aggregateColumns$tupleNumber( - aggregationSelector: (T) -> Tuple$tupleNumber<$columnDeclarings> - ): Tuple$tupleNumber<$resultTypes> { - return aggregateColumns(aggregationSelector) - } - /** * Perform a tuple of aggregations given by [aggregationSelector] for all elements in the sequence, * and return the aggregate results. @@ -262,48 +195,6 @@ def generateGroupingAggregateColumns(Writer writer, int tupleNumber) { def resultExtractors = (1..tupleNumber).collect { "c${it}.sqlType.getResult(row, ${it})" }.join(", ") writer.write(""" - /** - * Group elements from the source sequence by key and perform the given aggregations for elements in each group, - * then store the results in a new [Map]. - * - * The key for each group is provided by the [EntityGrouping.keySelector] function, and the generated SQL is like: - * `select key, aggregation from source group by key`. - * - * @param aggregationSelector a function that accepts the source table and returns a tuple of aggregate expressions. - * @return a [Map] associating the key of each group with the results of aggregations of the group elements. - */ - @Deprecated( - message = "This function will be removed in the future. Please use aggregateColumns { .. } instead.", - replaceWith = ReplaceWith("aggregateColumns(aggregationSelector)") - ) - public suspend inline fun , K : Any, $typeParams> EntityGrouping.aggregateColumns$tupleNumber( - aggregationSelector: (T) -> Tuple$tupleNumber<$columnDeclarings> - ): Map> { - return aggregateColumns(aggregationSelector) - } - - /** - * Group elements from the source sequence by key and perform the given aggregations for elements in each group, - * then store the results in the [destination] map. - * - * The key for each group is provided by the [EntityGrouping.keySelector] function, and the generated SQL is like: - * `select key, aggregation from source group by key`. - * - * @param destination a [MutableMap] used to store the results. - * @param aggregationSelector a function that accepts the source table and returns a tuple of aggregate expressions. - * @return the [destination] map associating the key of each group with the result of aggregations of the group elements. - */ - @Deprecated( - message = "This function will be removed in the future. Please use aggregateColumns(destination) { .. } instead.", - replaceWith = ReplaceWith("aggregateColumns(destination, aggregationSelector)") - ) - public suspend inline fun , K : Any, $typeParams, M> EntityGrouping.aggregateColumns${tupleNumber}To( - destination: M, - aggregationSelector: (T) -> Tuple$tupleNumber<$columnDeclarings> - ): M where M : MutableMap> { - return aggregateColumnsTo(destination, aggregationSelector) - } - /** * Group elements from the source sequence by key and perform the given aggregations for elements in each group, * then store the results in a new [Map]. 
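The deprecated mapColumnsN / aggregateColumnsN variants removed above are covered by the retained tuple-based overloads. A usage sketch against the test schema from BaseTest (database.employees with its name and salary columns; treat the exact names as assumptions, and note both calls are suspending):

    // Select several columns at once: the result is a list of tuples.
    val nameAndSalary = database.employees.mapColumns { tupleOf(it.name, it.salary) }

    // Run several aggregations in a single query and destructure the resulting tuple,
    // as AggregationTest.testAggregate2 does.
    val (maxSalary, minSalary) = database.employees
        .aggregateColumns { tupleOf(max(it.salary), min(it.salary)) }
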
diff --git a/ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/database/Database.kt b/ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/database/Database.kt index 8a44044..f507261 100644 --- a/ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/database/Database.kt +++ b/ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/database/Database.kt @@ -2,7 +2,6 @@ package org.ktorm.r2dbc.database import io.r2dbc.spi.* import kotlinx.coroutines.flow.Flow -import kotlinx.coroutines.flow.map import kotlinx.coroutines.reactive.asFlow import kotlinx.coroutines.reactive.awaitFirst import kotlinx.coroutines.reactive.awaitFirstOrNull @@ -12,9 +11,7 @@ import org.ktorm.r2dbc.expression.ArgumentExpression import org.ktorm.r2dbc.expression.SqlExpression import org.ktorm.r2dbc.logging.Logger import org.ktorm.r2dbc.logging.detectLoggerImplementation -import org.ktorm.r2dbc.schema.IntSqlType import org.ktorm.r2dbc.schema.SqlType -import java.sql.PreparedStatement import kotlin.contracts.ExperimentalContracts import kotlin.contracts.InvocationKind import kotlin.contracts.contract @@ -287,7 +284,8 @@ public class Database( useConnection { conn -> val statement = conn.createStatement(sql) - for (expr in expressions) { + val size = expressions.size + expressions.forEachIndexed { index, expr -> val (subSql, args) = formatExpression(expr) if (subSql != sql) { @@ -300,9 +298,10 @@ public class Database( } statement.bindParameters(args) - statement.add() + if (index < size - 1) { + statement.add() + } } - val results = statement.execute().toList() return results.map { result -> result.rowsUpdated.awaitFirst() }.toIntArray() } @@ -327,9 +326,10 @@ public class Database( useConnection { val statement = it.createStatement(sql) statement.bindParameters(args) - val rowsUpdated = statement.execute().awaitFirst().rowsUpdated.awaitFirst() - val rows = statement.returnGeneratedValues().execute().awaitFirst().map { row, _ -> row }.asFlow() - return Pair(rowsUpdated,rows) + val result = statement.returnGeneratedValues().execute().awaitFirst() + val rowsUpdated = result.rowsUpdated.awaitFirst() + val rows = result.map { row, _ -> row }.asFlow() + return Pair(rowsUpdated, rows) } } @@ -348,7 +348,7 @@ public class Database( dialect = dialect, logger = logger, alwaysQuoteIdentifiers = alwaysQuoteIdentifiers, - generateSqlInUpperCase = generateSqlInUpperCase + generateSqlInUpperCase = generateSqlInUpperCase ) } @@ -367,7 +367,7 @@ public class Database( dialect = dialect, logger = logger, alwaysQuoteIdentifiers = alwaysQuoteIdentifiers, - generateSqlInUpperCase = generateSqlInUpperCase + generateSqlInUpperCase = generateSqlInUpperCase ) } } diff --git a/ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/dsl/Aggregation.kt b/ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/dsl/Aggregation.kt index 30f06c6..d304844 100644 --- a/ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/dsl/Aggregation.kt +++ b/ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/dsl/Aggregation.kt @@ -19,9 +19,9 @@ package org.ktorm.r2dbc.dsl import org.ktorm.r2dbc.expression.AggregateExpression import org.ktorm.r2dbc.expression.AggregateType import org.ktorm.r2dbc.schema.ColumnDeclaring +import org.ktorm.r2dbc.schema.DoubleSqlType import org.ktorm.r2dbc.schema.IntSqlType import org.ktorm.r2dbc.schema.LongSqlType -import org.ktorm.r2dbc.schema.SimpleSqlType /** * The min function, translated to `min(column)` in SQL. @@ -55,14 +55,14 @@ public fun > maxDistinct(column: ColumnDeclaring): Aggregat * The avg function, translated to `avg(column)` in SQL. 
*/ public fun avg(column: ColumnDeclaring): AggregateExpression { - return AggregateExpression(AggregateType.AVG, column.asExpression(), false, SimpleSqlType(Double::class)) + return AggregateExpression(AggregateType.AVG, column.asExpression(), false, DoubleSqlType) } /** * The avg function with distinct, translated to `avg(distinct column)` in SQL. */ public fun avgDistinct(column: ColumnDeclaring): AggregateExpression { - return AggregateExpression(AggregateType.AVG, column.asExpression(), true, SimpleSqlType(Double::class)) + return AggregateExpression(AggregateType.AVG, column.asExpression(), true, DoubleSqlType) } /** diff --git a/ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/dsl/Dml.kt b/ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/dsl/Dml.kt index 986fdc6..f368e60 100644 --- a/ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/dsl/Dml.kt +++ b/ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/dsl/Dml.kt @@ -17,6 +17,7 @@ package org.ktorm.r2dbc.dsl import io.r2dbc.spi.Statement +import kotlinx.coroutines.flow.firstOrNull import org.ktorm.r2dbc.database.Database import org.ktorm.r2dbc.expression.* import org.ktorm.r2dbc.schema.BaseTable @@ -149,26 +150,23 @@ public suspend fun > Database.insert(table: T, block: Assignmen * @param block the DSL block, an extension function of [AssignmentsBuilder], used to construct the expression. * @return the first auto-generated key. */ -/* -TODO -public fun > Database.insertAndGenerateKey(table: T, block: AssignmentsBuilder.(T) -> Unit): Any { +public suspend fun > Database.insertAndGenerateKey( + table: T, + block: AssignmentsBuilder.(T) -> Unit +): Any { val builder = AssignmentsBuilder().apply { block(table) } val expression = AliasRemover.visit(InsertExpression(table.asExpression(), builder.assignments)) val (_, rowSet) = executeUpdateAndRetrieveKeys(expression) + val row = rowSet.firstOrNull() ?: error("No generated key returns by database.") + val pk = table.singlePrimaryKey { "Key retrieval is not supported for compound primary keys." } + val generatedKey = pk.sqlType.getResult(row, 0) ?: error("Generated key is null.") - if (rowSet.next()) { - val pk = table.singlePrimaryKey { "Key retrieval is not supported for compound primary keys." } - val generatedKey = pk.sqlType.getResult(rowSet, 1) ?: error("Generated key is null.") - - if (logger.isDebugEnabled()) { - logger.debug("Generated Key: $generatedKey") - } - - return generatedKey - } else { - error("No generated key returns by database.") + if (logger.isDebugEnabled()) { + logger.debug("Generated Key: $generatedKey") } -}*/ + + return generatedKey +} /** * Construct insert expressions in the given closure, then batch execute them and return the effected diff --git a/ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/dsl/QueryRow.kt b/ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/dsl/QueryRow.kt index 31e8088..aedf43c 100644 --- a/ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/dsl/QueryRow.kt +++ b/ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/dsl/QueryRow.kt @@ -53,6 +53,26 @@ public class QueryRow internal constructor(public val query: Query, private val } } + /** + * Obtain the value of the specific [ColumnDeclaringExpression] instance. + * + * Note that if the column doesn't exist in the result set, this function will return null rather than + * throwing an exception. + */ + public operator fun get(column: ColumnDeclaringExpression): C? 
{ + if (column.declaredName.isNullOrBlank()) { + throw IllegalArgumentException("Label of the specified column cannot be null or blank.") + } + + for (index in row.metadata.columnMetadatas.indices) { + if (row.metadata.columnMetadatas[index].name eq column.declaredName) { + return column.sqlType.getResult(row,index) + } + } + + // Return null if the column doesn't exist in the result set. + return null + } private infix fun String?.eq(other: String?) = this.equals(other, ignoreCase = true) diff --git a/ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/entity/EntityImplementation.kt b/ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/entity/EntityImplementation.kt index e85e337..5b6758a 100644 --- a/ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/entity/EntityImplementation.kt +++ b/ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/entity/EntityImplementation.kt @@ -25,12 +25,8 @@ import java.lang.reflect.InvocationHandler import java.lang.reflect.InvocationTargetException import java.lang.reflect.Method import java.util.* -import kotlin.collections.LinkedHashMap -import kotlin.collections.LinkedHashSet -import kotlin.coroutines.Continuation import kotlin.reflect.KClass import kotlin.reflect.KProperty1 -import kotlin.reflect.full.functions import kotlin.reflect.jvm.javaGetter import kotlin.reflect.jvm.jvmErasure import kotlin.reflect.jvm.jvmName @@ -44,6 +40,7 @@ internal class EntityImplementation( ) : InvocationHandler, Serializable { var values = LinkedHashMap() + @Transient var changedProperties = LinkedHashSet() @@ -71,8 +68,10 @@ internal class EntityImplementation( "getEntityClass" -> this.entityClass "getProperties" -> Collections.unmodifiableMap(this.values) "discardChanges" -> this.doDiscardChanges() - "flushChanges" -> this.doFlushChangeFun.call(args!!.first()) - "delete" -> this.doDeleteFun.call(args!!.first()) + "flushChanges" -> kotlin.runCatching { this.doFlushChangeFun.call(args!!.first()) } + .onFailure { throw it.cause!! } + "delete" -> kotlin.runCatching { this.doDeleteFun.call(args!!.first()) } + .onFailure { throw it.cause!! } "get" -> this.values[args!![0] as String] "set" -> this.doSetProperty(args!![0] as String, args[1]) "copy" -> this.copy() diff --git a/ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/expression/SqlExpressions.kt b/ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/expression/SqlExpressions.kt index 22c47c6..861b005 100644 --- a/ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/expression/SqlExpressions.kt +++ b/ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/expression/SqlExpressions.kt @@ -16,8 +16,8 @@ package org.ktorm.r2dbc.expression +import org.ktorm.r2dbc.schema.BooleanSqlType import org.ktorm.r2dbc.schema.ColumnDeclaring -import org.ktorm.r2dbc.schema.SimpleSqlType import org.ktorm.r2dbc.schema.SqlType /** @@ -450,7 +450,7 @@ public data class InListExpression( val query: QueryExpression? = null, val values: List>? 
= null, val notInList: Boolean = false, - override val sqlType: SqlType = SimpleSqlType(Boolean::class), + override val sqlType: SqlType = BooleanSqlType, override val isLeafNode: Boolean = false, override val extraProperties: Map = emptyMap() ) : ScalarExpression() @@ -464,7 +464,7 @@ public data class InListExpression( public data class ExistsExpression( val query: QueryExpression, val notExists: Boolean = false, - override val sqlType: SqlType = SimpleSqlType(Boolean::class), + override val sqlType: SqlType = BooleanSqlType, override val isLeafNode: Boolean = false, override val extraProperties: Map = emptyMap() ) : ScalarExpression() @@ -533,7 +533,7 @@ public data class BetweenExpression( val lower: ScalarExpression, val upper: ScalarExpression, val notBetween: Boolean = false, - override val sqlType: SqlType = SimpleSqlType(Boolean::class), + override val sqlType: SqlType = BooleanSqlType, override val isLeafNode: Boolean = false, override val extraProperties: Map = emptyMap() ) : ScalarExpression() diff --git a/ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/schema/BaseTable.kt b/ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/schema/BaseTable.kt index a5906bd..103ba60 100644 --- a/ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/schema/BaseTable.kt +++ b/ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/schema/BaseTable.kt @@ -175,32 +175,6 @@ public abstract class BaseTable( } } - /** - * Transform the registered column's [SqlType] to another. The transformed [SqlType] has the same `typeCode` and - * `typeName` as the underlying one, and performs the specific transformations on column values. - * - * This enables a user-friendly syntax to extend more data types. For example, the following code defines a column - * of type `Column`, based on the existing column definition function [int]: - * - * ```kotlin - * val role = int("role").transform({ UserRole.fromCode(it) }, { it.code }) - * ``` - * - * Note: Since [Column] is immutable, this function will create a new [Column] instance and replace the origin - * registered one. - * - * @param fromUnderlyingValue a function that transforms a value of underlying type to the user's type. - * @param toUnderlyingValue a function that transforms a value of user's type the to the underlying type. - * @return the new [Column] instance with its type changed to [R]. - * @see SqlType.transform - */ - public inline fun Column.transform( - noinline fromUnderlyingValue: (C) -> R, - noinline toUnderlyingValue: (R) -> C, - ): Column { - return transform(fromUnderlyingValue, toUnderlyingValue, R::class.java) - } - /** * Transform the registered column's [SqlType] to another. The transformed [SqlType] has the same `typeCode` and * `typeName` as the underlying one, and performs the specific transformations on column values. 
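With the reified overload removed, the remaining transform below takes only the two conversion functions; no runtime Class argument is needed. A sketch of the resulting call site, mirroring the example in the KDoc above (UserRole and its fromCode/code members are illustrative):

    // Maps an int column to a user-defined type and back, using the two-argument transform.
    val role = int("role").transform({ UserRole.fromCode(it) }, { it.code })
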
@@ -223,12 +197,11 @@ public abstract class BaseTable( public fun Column.transform( fromUnderlyingValue: (C) -> R, toUnderlyingValue: (R) -> C, - javaType: Class ): Column { checkRegistered() checkTransformable() - val result = Column(table, name, sqlType = sqlType.transform(fromUnderlyingValue, toUnderlyingValue, javaType)) + val result = Column(table, name, sqlType = sqlType.transform(fromUnderlyingValue, toUnderlyingValue)) _columns[name] = result return result } diff --git a/ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/schema/SqlType.kt b/ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/schema/SqlType.kt index 10859f8..a1182d6 100644 --- a/ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/schema/SqlType.kt +++ b/ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/schema/SqlType.kt @@ -6,8 +6,6 @@ import kotlin.reflect.KClass public interface SqlType { - public val javaType: Class - public fun bindParameter(statement: Statement, index: Int, value: T?) public fun bindParameter(statement: Statement, name: String, value: T?) @@ -21,9 +19,8 @@ public interface SqlType { public fun SqlType.transform( fromUnderlyingValue: (T) -> R, toUnderlyingValue: (R) -> T, - javaType: Class ): SqlType { - return TransformedSqlType(this, fromUnderlyingValue, toUnderlyingValue, javaType) + return TransformedSqlType(this, fromUnderlyingValue, toUnderlyingValue) } public open class SimpleSqlType(public val kotlinType: KClass) : SqlType { @@ -52,15 +49,10 @@ public open class SimpleSqlType(public val kotlinType: KClass) : Sql return row.get(name, kotlinType.javaObjectType) } - override val javaType: Class - get() = kotlinType.javaObjectType - } public abstract class ConvertibleSqlType(kotlinType: KClass) : SimpleSqlType(kotlinType) { - override val javaType: Class = kotlinType.javaObjectType - public abstract fun convert(value: Any): R override fun getResult(row: Row, index: Int): R? { @@ -80,7 +72,6 @@ public class TransformedSqlType( public val underlyingType: SqlType, public val fromUnderlyingValue: (T) -> R, public val toUnderlyingValue: (R) -> T, - public override val javaType: Class ) : SqlType { override fun bindParameter(statement: Statement, index: Int, value: R?) 
{ diff --git a/ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/schema/Table.kt b/ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/schema/Table.kt index 8382707..7133b05 100644 --- a/ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/schema/Table.kt +++ b/ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/schema/Table.kt @@ -18,9 +18,6 @@ package org.ktorm.r2dbc.schema import org.ktorm.r2dbc.dsl.QueryRow import org.ktorm.r2dbc.entity.* -import org.ktorm.r2dbc.entity.EntityImplementation -import org.ktorm.r2dbc.entity.implementation -import org.ktorm.r2dbc.entity.setColumnValue import kotlin.reflect.KClass import kotlin.reflect.KProperty1 import kotlin.reflect.jvm.jvmErasure @@ -144,7 +141,7 @@ public open class Table>( } private fun QueryRow.retrieveColumn(column: Column<*>, intoEntity: E, withReferences: Boolean) { - val columnValue = this[column] + val columnValue = this[column] ?: return for (binding in column.allBindings) { when (binding) { diff --git a/ktorm-r2dbc-core/src/test/kotlin/org/ktorm/r2dbc/database/CircularReferenceTest.kt b/ktorm-r2dbc-core/src/test/kotlin/org/ktorm/r2dbc/database/CircularReferenceTest.kt new file mode 100644 index 0000000..a4ae58a --- /dev/null +++ b/ktorm-r2dbc-core/src/test/kotlin/org/ktorm/r2dbc/database/CircularReferenceTest.kt @@ -0,0 +1,79 @@ +package org.ktorm.r2dbc.database + +import org.junit.Test +import org.ktorm.r2dbc.BaseTest +import org.ktorm.r2dbc.dsl.from +import org.ktorm.r2dbc.dsl.joinReferencesAndSelect +import org.ktorm.r2dbc.entity.Entity +import org.ktorm.r2dbc.schema.Table +import org.ktorm.r2dbc.schema.int + +/** + * Created by vince on Dec 19, 2018. + */ +class CircularReferenceTest : BaseTest() { + + interface Foo1 : Entity { + val id: Int + val foo2: Foo2 + } + + interface Foo2 : Entity { + val id: Int + val foo3: Foo3 + } + + interface Foo3 : Entity { + val id: Int + val foo1: Foo1 + } + + object Foos1 : Table("foo1") { + val id = int("id").primaryKey().bindTo { it.id } + val r1 = int("r1").references(Foos2) { it.foo2 } + } + + object Foos2 : Table("foo2") { + val id = int("id").primaryKey().bindTo { it.id } + val r2 = int("r2").references(Foos3) { it.foo3 } + } + + object Foos3 : Table("foo3") { + val id = int("id").primaryKey().bindTo { it.id } + val r3 = int("r3").references(Foos1) { it.foo1 } + } + + @Test + fun testCircularReference() { + try { + database.from(Foos1).joinReferencesAndSelect() + throw AssertionError("unexpected") + + } catch (e: ExceptionInInitializerError) { + val ex = e.cause as IllegalStateException + println(ex.message) + } + } + + interface Bar : Entity { + val id: Int + val bar: Bar + } + + object Bars : Table("bar") { + val id = int("id").primaryKey().bindTo { it.id } + val r = int("r").references(Bars) { it.bar } + } + + @Test + fun test() { + try { + database.from(Bars).joinReferencesAndSelect() + throw AssertionError("unexpected") + + } catch (e: ExceptionInInitializerError) { + val ex = e.cause as IllegalStateException + println(ex.message) + } + } +} \ No newline at end of file diff --git a/ktorm-r2dbc-core/src/test/kotlin/org/ktorm/r2dbc/database/CompoundKeysTest.kt b/ktorm-r2dbc-core/src/test/kotlin/org/ktorm/r2dbc/database/CompoundKeysTest.kt new file mode 100644 index 0000000..e1858a3 --- /dev/null +++ b/ktorm-r2dbc-core/src/test/kotlin/org/ktorm/r2dbc/database/CompoundKeysTest.kt @@ -0,0 +1,111 @@ +package org.ktorm.r2dbc.database + +import kotlinx.coroutines.runBlocking +import org.junit.Test +import org.ktorm.r2dbc.BaseTest +import org.ktorm.r2dbc.dsl.eq +import org.ktorm.r2dbc.dsl.from 
+import org.ktorm.r2dbc.dsl.joinReferencesAndSelect +import org.ktorm.r2dbc.entity.* +import org.ktorm.r2dbc.schema.* +import java.time.LocalDate + +/** + * Created by vince at Apr 07, 2020. + */ +class CompoundKeysTest : BaseTest() { + + interface Staff : Entity { + var id: Int + var departmentId: Int + var name: String + var job: String + var managerId: Int + var hireDate: LocalDate + var salary: Long + } + + object Staffs : Table("t_employee") { + val id = int("id").primaryKey().bindTo { it.id } + val departmentId = int("department_id").primaryKey().bindTo { it.departmentId } + val name = varchar("name").bindTo { it.name } + val job = varchar("job").bindTo { it.job } + val managerId = int("manager_id").bindTo { it.managerId } + val hireDate = date("hire_date").bindTo { it.hireDate } + val salary = long("salary").bindTo { it.salary } + } + + interface StaffRef : Entity { + val id: Int + val staff: Staff + } + + object StaffRefs : Table("t_staff_ref") { + val id = int("id").primaryKey().bindTo { it.id } + val staffId = int("staff_id").references(Staffs) { it.staff } + } + + val Database.staffs get() = this.sequenceOf(Staffs) + + val Database.staffRefs get() = this.sequenceOf(StaffRefs) + + @Test + fun testAdd() = runBlocking { + val staff = Entity.create() + staff.departmentId = 1 + staff.name = "jerry" + staff.job = "engineer" + staff.managerId = 1 + staff.hireDate = LocalDate.now() + staff.salary = 100 + database.staffs.add(staff) + println(staff) + assert(staff.id == 0) + } + + @Test + fun testFlushChanges() = runBlocking { + var staff = database.staffs.find { it.id eq 2 } ?: throw AssertionError() + staff.job = "engineer" + staff.salary = 100 + staff.flushChanges() + staff.flushChanges() + + staff = database.staffs.find { it.id eq 2 } ?: throw AssertionError() + assert(staff.job == "engineer") + assert(staff.salary == 100L) + } + + @Test + fun testDeleteEntity() = runBlocking { + val staff = database.staffs.find { it.id eq 2 } ?: throw AssertionError() + staff.delete() + + assert(database.staffs.count() == 3) + } + + @Test + fun testUpdatePrimaryKey() = runBlocking { + try { + val staff = database.staffs.find { it.id eq 1 } ?: return@runBlocking + staff.departmentId = 2 + throw AssertionError() + + } catch (e: UnsupportedOperationException) { + // expected + println(e.message) + } + } + + @Test + fun testReferenceTableWithCompoundKeys() = runBlocking { + try { + database.from(StaffRefs).joinReferencesAndSelect() + throw AssertionError("unexpected") + + } catch (e: ExceptionInInitializerError) { + val ex = e.cause as IllegalStateException + println(ex.message) + } + } +} \ No newline at end of file diff --git a/ktorm-r2dbc-core/src/test/kotlin/org/ktorm/r2dbc/database/DatabaseTest.kt b/ktorm-r2dbc-core/src/test/kotlin/org/ktorm/r2dbc/database/DatabaseTest.kt index 09a058b..beaf9c7 100644 --- a/ktorm-r2dbc-core/src/test/kotlin/org/ktorm/r2dbc/database/DatabaseTest.kt +++ b/ktorm-r2dbc-core/src/test/kotlin/org/ktorm/r2dbc/database/DatabaseTest.kt @@ -1,30 +1,43 @@ package org.ktorm.r2dbc.database +import kotlinx.coroutines.reactive.awaitFirst import kotlinx.coroutines.reactive.awaitFirstOrNull import kotlinx.coroutines.runBlocking import org.junit.Test import org.ktorm.r2dbc.BaseTest -import org.ktorm.r2dbc.dsl.insert +import org.ktorm.r2dbc.dsl.* import org.ktorm.r2dbc.entity.* -import java.time.LocalDate +import org.ktorm.r2dbc.schema.* +import java.lang.IllegalStateException /** * Created by vince on Dec 02, 2018. 
*/ + @ExperimentalUnsignedTypes class DatabaseTest : BaseTest() { + @Test + fun testMetadata() { + with(database) { + println(productName) + println(productVersion) + println(keywords.toString()) + println(identifierQuoteString) + println(extraNameCharacters) + } + } - /*@Test - fun testKeywordWrapping(): Unit = runBlocking { - val configs = object : Table("t_config") { - val key = varchar("key").primaryKey() - val value = varchar("value") + @Test + fun testKeywordWrapping() = runBlocking { + val configs = object : Table("T_CONFIG") { + val key = varchar("KEY").primaryKey() + val value = varchar("VALUE") } - database.useConnection { - val sql = """CREATE TABLE T_CONFIG(KEY VARCHAR(128) PRIMARY KEY, VALUE VARCHAR(128))""" - it.createStatement(sql).execute().awaitFirstOrNull() + database.useConnection { conn -> + val sql = """CREATE TABLE T_CONFIG("KEY" VARCHAR(128) PRIMARY KEY, "VALUE" VARCHAR(128))""" + conn.createStatement(sql).execute().awaitFirst() } database.insert(configs) { @@ -35,8 +48,8 @@ class DatabaseTest : BaseTest() { assert(database.sequenceOf(configs).count { it.key eq "test" } == 1) database.delete(configs) { it.key eq "test" } - }*/ -/* + Unit + } @Test fun testTransaction() = runBlocking { @@ -58,60 +71,46 @@ class DatabaseTest : BaseTest() { assert(database.departments.count() == 2) } } -*/ -/* @Test fun testRawSql() = runBlocking { val names = database.useConnection { conn -> val sql = """ - SELECT "NAME" FROM "T_EMPLOYEE" - WHERE "DEPARTMENT_ID" = ? - ORDER BY "ID" + select "name" from "t_employee" + where "department_id" = ? + order by "id" """ val statement = conn.createStatement(sql) statement.bind(0, 1) - statement.execute().awaitFirstOrNull()?.map { row, _ -> - row.get(0) - }?.toList() ?: emptyList() + statement.execute().awaitFirst().map { row, _ -> row[0, String::class.java] }.toList() } assert(names.size == 2) - assert(names[0] == "VINCE") - assert(names[1] == "MARRY") + assert(names[0] == "vince") + assert(names[1] == "marry") } -*/ - - - /*fun BaseTable<*>.ulong(name: String): Column { - return registerColumn(name, object : SqlType(Types.BIGINT, "bigint unsigned") { - override fun doSetParameter(ps: PreparedStatement, index: Int, parameter: ULong) { - ps.setLong(index, parameter.toLong()) - } - override fun doGetResult(rs: ResultSet, index: Int): ULong? { - return rs.getLong(index).toULong() - } - }) + fun BaseTable<*>.ulong(name: String): Column { + return registerColumn(name, LongSqlType.transform({ it.toULong() }, { it.toLong() })) } interface TestUnsigned : Entity { companion object : Entity.Factory() + var id: ULong } @Test - fun testUnsigned() { + fun testUnsigned() = runBlocking { val t = object : Table("T_TEST_UNSIGNED") { val id = ulong("ID").primaryKey().bindTo { it.id } } database.useConnection { conn -> - conn.createStatement().use { statement -> - val sql = """CREATE TABLE T_TEST_UNSIGNED(ID BIGINT UNSIGNED NOT NULL PRIMARY KEY)""" - statement.executeUpdate(sql) - } + val sql = """CREATE TABLE T_TEST_UNSIGNED(ID BIGINT NOT NULL PRIMARY KEY)""" + val statement = conn.createStatement(sql) + statement.execute().awaitFirst() } val unsigned = TestUnsigned { id = 5UL } @@ -135,20 +134,20 @@ class DatabaseTest : BaseTest() { interface TestUnsignedNullable : Entity { companion object : Entity.Factory() + var id: ULong? 
} @Test - fun testUnsignedNullable() { + fun testUnsignedNullable() = runBlocking { val t = object : Table("T_TEST_UNSIGNED_NULLABLE") { val id = ulong("ID").primaryKey().bindTo { it.id } } database.useConnection { conn -> - conn.createStatement().use { statement -> - val sql = """CREATE TABLE T_TEST_UNSIGNED_NULLABLE(ID BIGINT UNSIGNED NOT NULL PRIMARY KEY)""" - statement.executeUpdate(sql) - } + val sql = """CREATE TABLE T_TEST_UNSIGNED_NULLABLE(ID BIGINT NOT NULL PRIMARY KEY)""" + val statement = conn.createStatement(sql) + statement.execute().awaitFirst() } val unsigned = TestUnsignedNullable { id = 5UL } @@ -181,5 +180,5 @@ class DatabaseTest : BaseTest() { assert(UShortArray::class.java.defaultValue !== UShortArray::class.java.defaultValue) assert(UIntArray::class.java.defaultValue !== UIntArray::class.java.defaultValue) assert(ULongArray::class.java.defaultValue !== ULongArray::class.java.defaultValue) - }*/ + } } diff --git a/ktorm-r2dbc-core/src/test/kotlin/org/ktorm/r2dbc/dsl/AggregationTest.kt b/ktorm-r2dbc-core/src/test/kotlin/org/ktorm/r2dbc/dsl/AggregationTest.kt new file mode 100644 index 0000000..69b00c6 --- /dev/null +++ b/ktorm-r2dbc-core/src/test/kotlin/org/ktorm/r2dbc/dsl/AggregationTest.kt @@ -0,0 +1,87 @@ +package org.ktorm.r2dbc.dsl + +import kotlinx.coroutines.runBlocking +import org.junit.Test +import org.ktorm.r2dbc.BaseTest +import org.ktorm.r2dbc.entity.* + +/** + * Created by vince on Dec 09, 2018. + */ +class AggregationTest : BaseTest() { + + @Test + fun testCount() = runBlocking { + val count = database.employees.count { it.departmentId eq 1 } + assert(count == 2) + } + + @Test + fun testCountAll() = runBlocking { + val count = database.employees.count() + assert(count == 4) + } + + @Test + fun testSum() = runBlocking { + val sum = database.employees.sumBy { it.salary + 1 } + assert(sum == 454L) + } + + @Test + fun testMax() = runBlocking { + val max = database.employees.maxBy { it.salary - 1 } + assert(max == 199L) + } + + @Test + fun testMin() = runBlocking { + val min = database.employees.minBy { it.salary } + assert(min == 50L) + } + + @Test + fun testAvg() = runBlocking { + val avg = database.employees.averageBy { it.salary } + println(avg) + } + + @Test + fun testNone() = runBlocking { + assert(database.employees.none { it.salary greater 200L }) + } + + @Test + fun testAny() = runBlocking { + assert(!database.employees.any { it.salary greater 200L }) + } + + @Test + fun testAll() = runBlocking { + assert(database.employees.all { it.salary greater 0L }) + } + + @Test + fun testAggregate() = runBlocking { + val result = database.employees.aggregateColumns { max(it.salary) - min(it.salary) } + println(result) + assert(result == 150L) + } + + @Test + fun testAggregate2() = runBlocking { + val (max, min) = database.employees.aggregateColumns { tupleOf(max(it.salary), min(it.salary)) } + assert(max == 200L) + assert(min == 50L) + } + + @Test + fun testGroupAggregate3() = runBlocking { + database.employees + .groupingBy { it.departmentId } + .aggregateColumns { tupleOf(max(it.salary), min(it.salary)) } + .forEach { departmentId, (max, min) -> + println("$departmentId:$max:$min") + } + } +} \ No newline at end of file diff --git a/ktorm-r2dbc-core/src/test/kotlin/org/ktorm/r2dbc/dsl/DmlTest.kt b/ktorm-r2dbc-core/src/test/kotlin/org/ktorm/r2dbc/dsl/DmlTest.kt new file mode 100644 index 0000000..bdc403b --- /dev/null +++ b/ktorm-r2dbc-core/src/test/kotlin/org/ktorm/r2dbc/dsl/DmlTest.kt @@ -0,0 +1,157 @@ +package org.ktorm.r2dbc.dsl + +import 
kotlinx.coroutines.runBlocking +import org.junit.Test +import org.ktorm.r2dbc.BaseTest +import org.ktorm.r2dbc.entity.count +import org.ktorm.r2dbc.entity.find +import org.ktorm.r2dbc.entity.toList +import org.ktorm.r2dbc.schema.LongSqlType +import java.time.LocalDate + +/** + * Created by vince on Dec 08, 2018. + */ +class DmlTest : BaseTest() { + + @Test + fun testUpdate() = runBlocking { + database.update(Employees) { + set(it.job, "engineer") + set(it.managerId, null) + set(it.salary, 100) + + where { + it.id eq 2 + } + } + + val employee = database.employees.find { it.id eq 2 } ?: throw AssertionError() + assert(employee.name == "marry") + assert(employee.job == "engineer") + assert(employee.manager == null) + assert(employee.salary == 100L) + } + + @Test + fun testBatchUpdate() = runBlocking { + database.batchUpdate(Departments) { + for (i in 1..2) { + item { + set(it.location, LocationWrapper("Hong Kong")) + where { + it.id eq i + } + } + } + } + + val departments = database.departments.toList() + assert(departments.size == 2) + + for (dept in departments) { + assert(dept.location.underlying == "Hong Kong") + } + } + + @Test + fun testSelfIncrement() = runBlocking { + database.update(Employees) { + set(it.salary, it.salary + 1) + where { it.id eq 1 } + } + + val salary = database + .from(Employees) + .select(Employees.salary) + .where(Employees.id.eq(1)) + .map { LongSqlType.getResult(it, 0) } + .first() + + assert(salary == 101L) + } + + @Test + fun testInsert() = runBlocking { + database.insert(Employees) { + set(it.name, "jerry") + set(it.job, "trainee") + set(it.managerId, 1) + set(it.hireDate, LocalDate.now()) + set(it.salary, 50) + set(it.departmentId, 1) + } + + assert(database.employees.count() == 5) + } + + @Test + fun testInsertWithSchema() = runBlocking { + database.insert(Customers) { + set(it.name, "steve") + set(it.email, "steve@job.com") + set(it.phoneNumber, "0123456789") + } + + assert(database.customers.count() == 1) + } + + @Test + fun testBatchInsert() = runBlocking { + database.batchInsert(Employees) { + item { + set(it.name, "jerry") + set(it.job, "trainee") + set(it.managerId, 1) + set(it.hireDate, LocalDate.now()) + set(it.salary, 50) + set(it.departmentId, 1) + } + item { + set(it.name, "linda") + set(it.job, "assistant") + set(it.managerId, 3) + set(it.hireDate, LocalDate.now()) + set(it.salary, 100) + set(it.departmentId, 2) + } + } + + assert(database.employees.count() == 6) + } + + @Test + fun testInsertAndGenerateKey() = runBlocking { + val today = LocalDate.now() + + val id = database.insertAndGenerateKey(Employees) { + set(it.name, "jerry") + set(it.job, "trainee") + set(it.managerId, 1) + set(it.hireDate, today) + set(it.salary, 50) + set(it.departmentId, 1) + } + + val employee = database.employees.find { it.id eq (id as Int) } ?: throw AssertionError() + assert(employee.name == "jerry") + assert(employee.hireDate == today) + } + + @Test + fun testInsertFromSelect() = runBlocking { + database + .from(Departments) + .select(Departments.name, Departments.location) + .where { Departments.id eq 1 } + .insertTo(Departments, Departments.name, Departments.location) + + assert(database.departments.count() == 3) + } + + @Test + fun testDelete() = runBlocking { + database.delete(Employees) { it.id eq 4 } + assert(database.employees.count() == 3) + } +} \ No newline at end of file diff --git a/ktorm-r2dbc-core/src/test/kotlin/org/ktorm/r2dbc/dsl/JoinTest.kt b/ktorm-r2dbc-core/src/test/kotlin/org/ktorm/r2dbc/dsl/JoinTest.kt new file mode 100644 index 
0000000..3c26bc8 --- /dev/null +++ b/ktorm-r2dbc-core/src/test/kotlin/org/ktorm/r2dbc/dsl/JoinTest.kt @@ -0,0 +1,86 @@ +package org.ktorm.r2dbc.dsl + +import kotlinx.coroutines.flow.count +import kotlinx.coroutines.runBlocking +import org.junit.Test +import org.ktorm.r2dbc.BaseTest +import org.ktorm.r2dbc.schema.VarcharSqlType + +/** + * Created by vince on Dec 08, 2018. + */ +class JoinTest : BaseTest() { + + @Test + fun testCrossJoin() = runBlocking { + val query = database.from(Employees).crossJoin(Departments).select() + assert(query.asFlow().count() == 8) + } + + @Test + fun testJoinWithConditions() = runBlocking { + val names = database + .from(Employees) + .leftJoin(Departments, on = Employees.departmentId eq Departments.id) + .select(Employees.name, Departments.name) + .where { Employees.managerId.isNull() } + .associate { VarcharSqlType.getResult(it, 0) to VarcharSqlType.getResult(it, 1) } + + assert(names.size == 2) + assert(names["vince"] == "tech") + assert(names["tom"] == "finance") + } + + @Test + fun testMultiJoin() = runBlocking { + data class Names(val name: String, val managerName: String, val departmentName: String) + + val emp = Employees.aliased("emp") + val mgr = Employees.aliased("mgr") + val dept = Departments.aliased("dept") + + val results = database + .from(emp) + .leftJoin(dept, on = emp.departmentId eq dept.id) + .leftJoin(mgr, on = emp.managerId eq mgr.id) + .select(emp.name, mgr.name, dept.name) + .orderBy(emp.id.asc()) + .map { row -> + Names( + name = row[emp.name].orEmpty(), + managerName = row[mgr.name].orEmpty(), + departmentName = row[dept.name].orEmpty() + ) + } + + assert(results.size == 4) + assert(results[0] == Names(name = "vince", managerName = "", departmentName = "tech")) + assert(results[1] == Names(name = "marry", managerName = "vince", departmentName = "tech")) + assert(results[2] == Names(name = "tom", managerName = "", departmentName = "finance")) + assert(results[3] == Names(name = "penny", managerName = "tom", departmentName = "finance")) + } + + @Test + fun testHasColumn() = runBlocking { + data class Names(val name: String, val managerName: String, val departmentName: String) + + val emp = Employees.aliased("emp") + val mgr = Employees.aliased("mgr") + val dept = Departments.aliased("dept") + + val results = database + .from(emp) + .select(emp.name) + .map { + Names( + name = it[emp.name].orEmpty(), + managerName = it[mgr.name].orEmpty(), + departmentName = it[dept.name].orEmpty() + ) + } + + results.forEach(::println) + assert(results.all { it.managerName == "" }) + assert(results.all { it.departmentName == "" }) + } +} diff --git a/ktorm-r2dbc-core/src/test/kotlin/org/ktorm/r2dbc/dsl/QueryTest.kt b/ktorm-r2dbc-core/src/test/kotlin/org/ktorm/r2dbc/dsl/QueryTest.kt new file mode 100644 index 0000000..81fb242 --- /dev/null +++ b/ktorm-r2dbc-core/src/test/kotlin/org/ktorm/r2dbc/dsl/QueryTest.kt @@ -0,0 +1,277 @@ +package org.ktorm.r2dbc.dsl + +import kotlinx.coroutines.flow.toList +import kotlinx.coroutines.runBlocking +import org.junit.Test +import org.ktorm.r2dbc.BaseTest +import org.ktorm.r2dbc.expression.ScalarExpression +import org.ktorm.r2dbc.schema.DoubleSqlType +import org.ktorm.r2dbc.schema.IntSqlType +import org.ktorm.r2dbc.schema.LongSqlType +import org.ktorm.r2dbc.schema.VarcharSqlType + +/** + * Created by vince on Dec 07, 2018. 
+ */ +class QueryTest : BaseTest() { + + @Test + fun testSelect() = runBlocking { + val query = database.from(Departments).select() + val rows = query.asFlow().toList() + assert(rows.size == 2) + + for (row in rows) { + println(row[Departments.name] + ": " + row[Departments.location]) + } + } + + @Test + fun testSelectDistinct() = runBlocking { + val ids = database + .from(Employees) + .selectDistinct(Employees.departmentId) + .map { + it[Employees.id] + IntSqlType.getResult(it,0)!! + } + .sortedDescending() + + assert(ids.size == 2) + assert(ids[0] == 2) + assert(ids[1] == 1) + } + + @Test + fun testWhere() = runBlocking { + val name = database + .from(Employees) + .select(Employees.name) + .where { Employees.managerId.isNull() and (Employees.departmentId eq 1) } + .map { VarcharSqlType.getResult(it,0) } + .first() + + assert(name == "vince") + } + + @Test + fun testWhereWithConditions() = runBlocking { + val t = Employees.aliased("t") + + val name = database + .from(t) + .select(t.name) + .whereWithConditions { + it += t.managerId.isNull() + it += t.departmentId eq 1 + } + .map { VarcharSqlType.getResult(it,0) } + .first() + + assert(name == "vince") + } + + @Test + fun testCombineConditions() = runBlocking { + val t = Employees.aliased("t") + + val names = database + .from(t) + .select(t.name) + .where { emptyList>().combineConditions() } + .orderBy(t.id.asc()) + .map { it.get(0, String::class.java) } + + assert(names.size == 4) + assert(names[0] == "vince") + assert(names[1] == "marry") + } + + @Test + fun testOrderBy() = runBlocking { + val names = database + .from(Employees) + .select(Employees.name) + .where { Employees.departmentId eq 1 } + .orderBy(Employees.salary.desc()) + .map { VarcharSqlType.getResult(it,0) } + + assert(names.size == 2) + assert(names[0] == "vince") + assert(names[1] == "marry") + } + + @Test + fun testAggregation() = runBlocking { + val t = Employees + + val salaries = database + .from(t) + .select(t.departmentId, sum(t.salary)) + .groupBy(t.departmentId) + .associate { IntSqlType.getResult(it,0) to LongSqlType.getResult(it,1) } + + assert(salaries.size == 2) + assert(salaries[1]!! == 150L) + assert(salaries[2]!! 
== 300L) + } + + @Test + fun testHaving() = runBlocking { + val t = Employees + + val salaries = database + .from(t) + .select(t.departmentId, avg(t.salary)) + .groupBy(t.departmentId) + .having(avg(t.salary).greater(100.0)) + .associate { IntSqlType.getResult(it,0) to DoubleSqlType.getResult(it,1) } + + println(salaries) + assert(salaries.size == 1) + assert(salaries.keys.first() == 2) + } + + @Test + fun testColumnAlias() = runBlocking { + val deptId = Employees.departmentId.aliased("dept_id") + val salaryAvg = avg(Employees.salary).aliased("salary_avg") + + val salaries = database + .from(Employees) + .select(deptId, salaryAvg) + .groupBy(deptId) + .having { salaryAvg greater 100.0 } + .associate { row -> + row[deptId] to row[salaryAvg] + } + + println(salaries) + assert(salaries.size == 1) + assert(salaries.keys.first() == 2) + assert(salaries.values.first() == 150.0) + } + + @Test + fun testColumnAlias1() = runBlocking { + val salary = (Employees.salary + 100).aliased(null) + + val salaries = database + .from(Employees) + .select(salary) + .where { salary greater 200L } + .map { LongSqlType.getResult(it,0) } + + println(salaries) + assert(salaries.size == 1) + assert(salaries.first() == 300L) + } + + @Test + fun testLimit() = runBlocking { + try { + val query = database.from(Employees).select().orderBy(Employees.id.desc()).limit(0, 2) + assert(query.totalRecords() == 4L) + val records = query.asFlow().toList() + val ids = records.map { it[Employees.id] } + assert(ids[0] == 4) + assert(ids[1] == 3) + + } catch (e: UnsupportedOperationException) { + // Expected, pagination should be provided by dialects... + } + } + + @Test + fun testBetween() = runBlocking { + val names = database + .from(Employees) + .select(Employees.name) + .where { Employees.salary between 100L..200L } + .map { VarcharSqlType.getResult(it,0) } + + assert(names.size == 3) + println(names) + } + + @Test + fun testInList() = runBlocking { + val query = database + .from(Employees) + .select() + .where { Employees.id.inList(1, 2, 3) } + + assert(query.totalRecords() == 3L) + } + + @Test + fun testInNestedQuery() = runBlocking { + val departmentIds = database.from(Departments).selectDistinct(Departments.id) + + val query = database + .from(Employees) + .select() + .where { Employees.departmentId inList departmentIds } + + assert(query.totalRecords() == 4L) + + println(query.sql) + } + + @Test + fun testExists() = runBlocking { + val query = database + .from(Employees) + .select() + .where { + Employees.id.isNotNull() and exists( + database + .from(Departments) + .select() + .where { Departments.id eq Employees.departmentId } + ) + } + + assert(query.totalRecords() == 4L) + println(query.sql) + } + + @Test + fun testUnion() = runBlocking { + val query = database + .from(Employees) + .select(Employees.id) + .unionAll( + database.from(Departments).select(Departments.id) + ) + .unionAll( + database.from(Departments).select(Departments.id) + ) + .orderBy(Employees.id.desc()) + + assert(query.totalRecords() == 8L) + + println(query.sql) + } + + @Test + fun testMod() = runBlocking { + val query = database.from(Employees).select().where { Employees.id % 2 eq 1 } + assert(query.totalRecords() == 2L) + println(query.sql) + } + + @Test + fun testFlatMap()= runBlocking { + val names = database + .from(Employees) + .select(Employees.name) + .where { Employees.departmentId eq 1 } + .orderBy(Employees.salary.desc()) + .flatMapIndexed { index, row -> listOf("$index:${VarcharSqlType.getResult(row,0)}") } + + assert(names.size == 2) + 
assert(names[0] == "0:vince") + assert(names[1] == "1:marry") + } +} \ No newline at end of file diff --git a/ktorm-r2dbc-core/src/test/kotlin/org/ktorm/r2dbc/entity/DataClassTest.kt b/ktorm-r2dbc-core/src/test/kotlin/org/ktorm/r2dbc/entity/DataClassTest.kt index 2a4e482..c7242c7 100644 --- a/ktorm-r2dbc-core/src/test/kotlin/org/ktorm/r2dbc/entity/DataClassTest.kt +++ b/ktorm-r2dbc-core/src/test/kotlin/org/ktorm/r2dbc/entity/DataClassTest.kt @@ -1,13 +1,11 @@ package org.ktorm.r2dbc.entity -import io.r2dbc.spi.Row import kotlinx.coroutines.runBlocking import org.junit.Test import org.ktorm.r2dbc.BaseTest import org.ktorm.r2dbc.database.Database import org.ktorm.r2dbc.dsl.* import org.ktorm.r2dbc.schema.* -import org.ktorm.schema.* import java.time.LocalDate /** @@ -137,8 +135,8 @@ class DataClassTest : BaseTest() { println(salaries) assert(salaries.size == 2) - assert(salaries[1] == 150L) - assert(salaries[3] == 300L) + assert(salaries[2] == 150L) + assert(salaries[4] == 300L) } @Test @@ -150,8 +148,8 @@ class DataClassTest : BaseTest() { println(counts) assert(counts.size == 2) - assert(counts[0] == 2L) assert(counts[1] == 2L) + assert(counts[2] == 2L) } @Test diff --git a/ktorm-r2dbc-core/src/test/kotlin/org/ktorm/r2dbc/entity/EntitySequenceTest.kt b/ktorm-r2dbc-core/src/test/kotlin/org/ktorm/r2dbc/entity/EntitySequenceTest.kt index 2b64990..4545866 100644 --- a/ktorm-r2dbc-core/src/test/kotlin/org/ktorm/r2dbc/entity/EntitySequenceTest.kt +++ b/ktorm-r2dbc-core/src/test/kotlin/org/ktorm/r2dbc/entity/EntitySequenceTest.kt @@ -185,7 +185,7 @@ class EntitySequenceTest : BaseTest() { @Test fun testSingle() = runBlocking { - val employee = database.employees.singleOrNull { it.departmentId eq 1 } + val employee = database.employees.singleOrNull { it.departmentId eq -1 } assert(employee == null) } diff --git a/ktorm-r2dbc-core/src/test/kotlin/org/ktorm/r2dbc/entity/EntityTest.kt b/ktorm-r2dbc-core/src/test/kotlin/org/ktorm/r2dbc/entity/EntityTest.kt index 8eb3a16..75650f1 100644 --- a/ktorm-r2dbc-core/src/test/kotlin/org/ktorm/r2dbc/entity/EntityTest.kt +++ b/ktorm-r2dbc-core/src/test/kotlin/org/ktorm/r2dbc/entity/EntityTest.kt @@ -81,7 +81,7 @@ class EntityTest : BaseTest() { } val str = - 
"rO0ABXN9AAAAAQAbb3JnLmt0b3JtLkJhc2VUZXN0JEVtcGxveWVleHIAF2phdmEubGFuZy5yZWZsZWN0LlByb3h54SfaIMwQQ8sCAAFMAAFodAAlTGphdmEvbGFuZy9yZWZsZWN0L0ludm9jYXRpb25IYW5kbGVyO3hwc3IAJW9yZy5rdG9ybS5lbnRpdHkuRW50aXR5SW1wbGVtZW50YXRpb24AAAAAAAAAAQMAAkwAC2VudGl0eUNsYXNzdAAXTGtvdGxpbi9yZWZsZWN0L0tDbGFzcztMAAZ2YWx1ZXN0ABlMamF2YS91dGlsL0xpbmtlZEhhc2hNYXA7eHB3HQAbb3JnLmt0b3JtLkJhc2VUZXN0JEVtcGxveWVlc3IAF2phdmEudXRpbC5MaW5rZWRIYXNoTWFwNMBOXBBswPsCAAFaAAthY2Nlc3NPcmRlcnhyABFqYXZhLnV0aWwuSGFzaE1hcAUH2sHDFmDRAwACRgAKbG9hZEZhY3RvckkACXRocmVzaG9sZHhwP0AAAAAAAAx3CAAAABAAAAAGdAAEbmFtZXQABWplcnJ5dAADam9idAAHdHJhaW5lZXQAB21hbmFnZXJzcQB+AABzcQB+AAR3HQAbb3JnLmt0b3JtLkJhc2VUZXN0JEVtcGxveWVlc3EAfgAIP0AAAAAAAAx3CAAAABAAAAAGdAACaWRzcgARamF2YS5sYW5nLkludGVnZXIS4qCk94GHOAIAAUkABXZhbHVleHIAEGphdmEubGFuZy5OdW1iZXKGrJUdC5TgiwIAAHhwAAAAAXEAfgALdAAFdmluY2VxAH4ADXQACGVuZ2luZWVydAAIaGlyZURhdGVzcgANamF2YS50aW1lLlNlcpVdhLobIkiyDAAAeHB3BwMAAAfiAQF4dAAGc2FsYXJ5c3IADmphdmEubGFuZy5Mb25nO4vkkMyPI98CAAFKAAV2YWx1ZXhxAH4AFQAAAAAAAABkdAAKZGVwYXJ0bWVudHN9AAAAAQAdb3JnLmt0b3JtLkJhc2VUZXN0JERlcGFydG1lbnR4cQB+AAFzcQB+AAR3HwAdb3JnLmt0b3JtLkJhc2VUZXN0JERlcGFydG1lbnRzcQB+AAg/QAAAAAAADHcIAAAAEAAAAAN0AAJpZHEAfgAWdAAEbmFtZXQABHRlY2h0AAhsb2NhdGlvbnNyACJvcmcua3Rvcm0uQmFzZVRlc3QkTG9jYXRpb25XcmFwcGVykiiIyygSeecCAAFMAAp1bmRlcmx5aW5ndAASTGphdmEvbGFuZy9TdHJpbmc7eHB0AAlHdWFuZ3pob3V4AHh4AHhxAH4AGXNxAH4AGncHAwAAB+QJGnhxAH4AHHNxAH4AHQAAAAAAAAAycQB+AB9zcQB+ACBzcQB+AAR3HwAdb3JnLmt0b3JtLkJhc2VUZXN0JERlcGFydG1lbnRzcQB+AAg/QAAAAAAADHcIAAAAEAAAAANxAH4AJHEAfgAWcQB+ACVxAH4AJnEAfgAnc3EAfgAocQB+ACt4AHh4AHg=" + "rO0ABXN9AAAAAQAhb3JnLmt0b3JtLnIyZGJjLkJhc2VUZXN0JEVtcGxveWVleHIAF2phdmEubGFuZy5yZWZsZWN0LlByb3h54SfaIMwQQ8sCAAFMAAFodAAlTGphdmEvbGFuZy9yZWZsZWN0L0ludm9jYXRpb25IYW5kbGVyO3hwc3IAK29yZy5rdG9ybS5yMmRiYy5lbnRpdHkuRW50aXR5SW1wbGVtZW50YXRpb24AAAAAAAAAAQMABEwAC2RvRGVsZXRlRnVudAAaTGtvdGxpbi9yZWZsZWN0L0tGdW5jdGlvbjtMABBkb0ZsdXNoQ2hhbmdlRnVucQB+AAVMAAtlbnRpdHlDbGFzc3QAF0xrb3RsaW4vcmVmbGVjdC9LQ2xhc3M7TAAGdmFsdWVzdAAZTGphdmEvdXRpbC9MaW5rZWRIYXNoTWFwO3hwdyMAIW9yZy5rdG9ybS5yMmRiYy5CYXNlVGVzdCRFbXBsb3llZXNyABdqYXZhLnV0aWwuTGlua2VkSGFzaE1hcDTATlwQbMD7AgABWgALYWNjZXNzT3JkZXJ4cgARamF2YS51dGlsLkhhc2hNYXAFB9rBwxZg0QMAAkYACmxvYWRGYWN0b3JJAAl0aHJlc2hvbGR4cD9AAAAAAAAMdwgAAAAQAAAABnQABG5hbWV0AAVqZXJyeXQAA2pvYnQAB3RyYWluZWV0AAdtYW5hZ2Vyc3EAfgAAc3EAfgAEdyMAIW9yZy5rdG9ybS5yMmRiYy5CYXNlVGVzdCRFbXBsb3llZXNxAH4ACT9AAAAAAAAMdwgAAAAQAAAAB3QAAmlkc3IAEWphdmEubGFuZy5JbnRlZ2VyEuKgpPeBhzgCAAFJAAV2YWx1ZXhyABBqYXZhLmxhbmcuTnVtYmVyhqyVHQuU4IsCAAB4cAAAAAFxAH4ADHQABXZpbmNlcQB+AA50AAhlbmdpbmVlcnEAfgAQc3EAfgAAc3EAfgAEdyMAIW9yZy5rdG9ybS5yMmRiYy5CYXNlVGVzdCRFbXBsb3llZXNxAH4ACT9AAAAAAAAMdwgAAAAQAAAAAXEAfgAUcHgAeHQACGhpcmVEYXRlc3IADWphdmEudGltZS5TZXKVXYS6GyJIsgwAAHhwdwcDAAAH4gEBeHQABnNhbGFyeXNyAA5qYXZhLmxhbmcuTG9uZzuL5JDMjyPfAgABSgAFdmFsdWV4cQB+ABYAAAAAAAAAZHQACmRlcGFydG1lbnRzfQAAAAEAI29yZy5rdG9ybS5yMmRiYy5CYXNlVGVzdCREZXBhcnRtZW50eHEAfgABc3EAfgAEdyUAI29yZy5rdG9ybS5yMmRiYy5CYXNlVGVzdCREZXBhcnRtZW50c3EAfgAJP0AAAAAAAAx3CAAAABAAAAAEdAACaWRxAH4AF3QABG5hbWV0AAR0ZWNodAAIbG9jYXRpb25zcgAob3JnLmt0b3JtLnIyZGJjLkJhc2VUZXN0JExvY2F0aW9uV3JhcHBlcuw3KJ5eUyi8AgABTAAKdW5kZXJseWluZ3QAEkxqYXZhL2xhbmcvU3RyaW5nO3hwdAAJR3Vhbmd6aG91dAAJbWl4ZWRDYXNlcHgAeHgAeHEAfgAdc3EAfgAedwcDAAAH5gISeHEAfgAgc3EAfgAhAAAAAAAAADJxAH4AI3NxAH4AJHNxAH4ABHclACNvcmcua3Rvcm0ucjJkYmMuQmFzZVRlc3QkRGVwYXJ0bWVudHNxAH4ACT9AAAAAAAAMdwgAAAAQAAAABHEAfgAocQB+ABdxAH4AKXEAfgAqcQB+ACtzcQB+ACxxAH4AL3EAfgAwcHgAeHgAeA==" val bytes = Base64.getDecoder().decode(str) val employee = deserialize(bytes) as Employee @@ -191,6 +191,8 @@ class EntityTest : BaseTest() { @Test fun 
testSaveEntity() = runBlocking { + val employees = database.employees.toList() + println(employees) var employee = Employee { name = "jerry" job = "trainee" From b9ddd2ef7b615ce9798d58c6979f3d1eef1a3c20 Mon Sep 17 00:00:00 2001 From: htt <641571835@qq.com> Date: Fri, 18 Feb 2022 21:36:13 +0800 Subject: [PATCH 13/17] update import update package name --- .../src/main/kotlin/org/ktorm/r2dbc/entity/Entity.kt | 2 +- .../src/main/kotlin/org/ktorm/r2dbc/entity/EntityGrouping.kt | 3 --- .../src/main/kotlin/org/ktorm/r2dbc/schema/BaseTable.kt | 2 +- .../src/main/kotlin/org/ktorm/r2dbc/schema/TypeReference.kt | 2 +- 4 files changed, 3 insertions(+), 6 deletions(-) diff --git a/ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/entity/Entity.kt b/ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/entity/Entity.kt index b3e1e11..05a32c3 100644 --- a/ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/entity/Entity.kt +++ b/ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/entity/Entity.kt @@ -18,7 +18,7 @@ package org.ktorm.r2dbc.entity import org.ktorm.r2dbc.database.Database import org.ktorm.r2dbc.schema.Table -import org.ktorm.schema.TypeReference +import org.ktorm.r2dbc.schema.TypeReference import java.io.ObjectInputStream import java.io.ObjectOutputStream import java.io.Serializable diff --git a/ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/entity/EntityGrouping.kt b/ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/entity/EntityGrouping.kt index e017643..d09f13c 100644 --- a/ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/entity/EntityGrouping.kt +++ b/ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/entity/EntityGrouping.kt @@ -16,13 +16,10 @@ package org.ktorm.r2dbc.entity -import kotlinx.coroutines.flow.collect import org.ktorm.r2dbc.dsl.* import org.ktorm.r2dbc.schema.BaseTable import org.ktorm.r2dbc.schema.ColumnDeclaring import java.util.* -import kotlin.collections.ArrayList -import kotlin.collections.LinkedHashMap import kotlin.experimental.ExperimentalTypeInference /** diff --git a/ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/schema/BaseTable.kt b/ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/schema/BaseTable.kt index 103ba60..d46e133 100644 --- a/ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/schema/BaseTable.kt +++ b/ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/schema/BaseTable.kt @@ -18,7 +18,7 @@ package org.ktorm.r2dbc.schema import org.ktorm.r2dbc.dsl.QueryRow import org.ktorm.r2dbc.expression.TableExpression -import org.ktorm.schema.* +import org.ktorm.r2dbc.schema.* import java.util.* import kotlin.reflect.KClass import kotlin.reflect.jvm.jvmErasure diff --git a/ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/schema/TypeReference.kt b/ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/schema/TypeReference.kt index 8cdf291..25a75c9 100644 --- a/ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/schema/TypeReference.kt +++ b/ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/schema/TypeReference.kt @@ -14,7 +14,7 @@ * limitations under the License. 
*/ -package org.ktorm.schema +package org.ktorm.r2dbc.schema import java.lang.reflect.ParameterizedType import java.lang.reflect.Type From a86326ddc25272d4641af99a889e8c36f62544e6 Mon Sep 17 00:00:00 2001 From: htt <641571835@qq.com> Date: Tue, 22 Feb 2022 19:19:24 +0800 Subject: [PATCH 14/17] add doc --- .../database/CoroutinesTransactionManager.kt | 11 +++- .../org/ktorm/r2dbc/database/Database.kt | 51 +++++++++++++++++++ .../org/ktorm/r2dbc/database/SqlDialect.kt | 15 +++++- .../r2dbc/database/TransactionManager.kt | 12 +++-- .../kotlin/org/ktorm/r2dbc/dsl/QueryRow.kt | 44 ++++++++-------- .../org/ktorm/r2dbc/entity/EntitySequence.kt | 1 + 6 files changed, 108 insertions(+), 26 deletions(-) diff --git a/ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/database/CoroutinesTransactionManager.kt b/ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/database/CoroutinesTransactionManager.kt index 3c929b0..e6c7295 100644 --- a/ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/database/CoroutinesTransactionManager.kt +++ b/ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/database/CoroutinesTransactionManager.kt @@ -6,12 +6,21 @@ import io.r2dbc.spi.IsolationLevel import kotlinx.coroutines.reactive.awaitFirstOrNull import kotlinx.coroutines.reactive.awaitSingle import kotlinx.coroutines.withContext +import java.sql.DriverManager +import javax.sql.DataSource import kotlin.coroutines.AbstractCoroutineContextElement import kotlin.coroutines.CoroutineContext import kotlin.coroutines.coroutineContext /** - * Created by vince on Jan 30, 2021. + * [TransactionManager] implementation based on R2DBC. + * + * This class is capable of working in any environment with any R2DBC driver. It accepts a [connectionFactory] + * used to obtain SQL connections. + * + * [Database] instances created by [Database.connect] functions use this implementation by default. + * + * @property connectionFactory A factory for creating [Connection] */ public class CoroutinesTransactionManager( public val connectionFactory: ConnectionFactory diff --git a/ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/database/Database.kt b/ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/database/Database.kt index f507261..c2eb418 100644 --- a/ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/database/Database.kt +++ b/ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/database/Database.kt @@ -12,6 +12,7 @@ import org.ktorm.r2dbc.expression.SqlExpression import org.ktorm.r2dbc.logging.Logger import org.ktorm.r2dbc.logging.detectLoggerImplementation import org.ktorm.r2dbc.schema.SqlType +import java.sql.PreparedStatement import kotlin.contracts.ExperimentalContracts import kotlin.contracts.InvocationKind import kotlin.contracts.contract @@ -156,6 +157,16 @@ public class Database( } } + /** + * Obtain a connection and invoke the callback function with it. + * + * If the current thread has opened a transaction, then this transaction's connection will be used. + * Otherwise, Ktorm will pass a new-created connection to the function and auto close it after it's + * not useful anymore. + * + * @param func the executed callback function. + * @return the result of the callback function. + */ @OptIn(ExperimentalContracts::class) public suspend inline fun useConnection(func: (Connection) -> T): T { contract { @@ -176,6 +187,20 @@ public class Database( } } + /** + * Execute the specific callback function in a transaction and returns its result if the execution succeeds, + * otherwise, if the execution fails, the transaction will be rollback. 
+ * + * Note: + * + * - Any exceptions thrown in the callback function can trigger a rollback. + * - This function is reentrant, so it can be called nested. However, the inner calls don’t open new transactions + * but share the same ones with outers. + * + * @param isolation transaction isolation, null for the default isolation level of the underlying datastore. + * @param func the executed callback function. + * @return the result of the callback function. + */ @OptIn(ExperimentalContracts::class) public suspend fun useTransaction( isolation: IsolationLevel? = null, @@ -211,6 +236,14 @@ public class Database( } } + /** + * Format the specific [SqlExpression] to an executable SQL string with execution arguments. + * + * @param expression the expression to be formatted. + * @param beautifySql output beautiful SQL strings with line-wrapping and indentation, default to `false`. + * @param indentSize the indent size, default to 2. + * @return a [Pair] combines the SQL string and its execution arguments. + */ public fun formatExpression( expression: SqlExpression, beautifySql: Boolean = false, @@ -221,6 +254,16 @@ public class Database( return Pair(formatter.sql, formatter.parameters) } + /** + * Format the given [expression] to a SQL string with its execution arguments, then create + * a [Statement] from the this database using the SQL string and execute the specific + * callback function with it. After the callback function completes. + * + * @since 2.7 + * @param expression the SQL expression to be executed. + * @param func the callback function. + * @return the result of the callback function. + */ @OptIn(ExperimentalContracts::class) public suspend inline fun executeExpression(expression: SqlExpression, func: (Result) -> T): T { contract { @@ -247,6 +290,14 @@ public class Database( } } + /** + * Format the given [expression] to a SQL string with its execution arguments, then execute it via + * [Statement.execute] and return the result [Flow]. + * + * @since 2.7 + * @param expression the SQL expression to be executed. + * @return the result [Flow]. + */ public suspend fun executeQuery(expression: SqlExpression): Flow { executeExpression(expression) { result -> return result.map { row, _ -> row }.asFlow() diff --git a/ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/database/SqlDialect.kt b/ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/database/SqlDialect.kt index 2c51d6f..26ded3f 100644 --- a/ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/database/SqlDialect.kt +++ b/ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/database/SqlDialect.kt @@ -5,7 +5,20 @@ import org.ktorm.r2dbc.expression.SqlFormatter import java.util.* /** - * Created by vince on Feb 08, 2021. + * Representation of a SQL dialect. + * + * It's known that there is a uniform standard for SQL language, but beyond the standard, many databases still have + * their special features. The interface provides an extension mechanism for Ktorm and its extension modules to support + * those dialect-specific SQL features. + * + * Implementations of this interface are recommended to be published as separated modules independent of ktorm-core. + * + * To enable a dialect, applications should add the dialect module to the classpath first, then configure the `dialect` + * parameter to the dialect implementation while creating database instances via [Database.connect] functions. 
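As a sketch of how `formatExpression` and `executeQuery` fit together (again, not taken from the patch itself): build a query with the DSL, inspect the generated SQL, then run the underlying expression and consume the resulting `Flow`. The `database` value and the `Employees` table are assumed to come from this module's test sources.

```kotlin
import kotlinx.coroutines.flow.count
import kotlinx.coroutines.runBlocking

fun main() = runBlocking {
    // Assumed: `database` and `Employees` as defined in the module's test sources.
    val query = database.from(Employees).select()

    // formatExpression renders the expression tree to SQL plus its ordered arguments.
    val (sql, arguments) = database.formatExpression(query.expression, beautifySql = true)
    println(sql)
    println("arguments: $arguments")

    // executeQuery runs the same expression and exposes the raw R2DBC rows as a Flow.
    val rows = database.executeQuery(query.expression)
    println("fetched ${rows.count()} rows")
}
```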
+ * + * Ktorm's dialect modules start following the convention of JDK [ServiceLoader] SPI, so we don't + * need to specify the `dialect` parameter explicitly anymore while creating [Database] instances. Ktorm auto detects + * one for us from the classpath. We just need to insure the dialect module exists in the dependencies. */ public interface SqlDialect { diff --git a/ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/database/TransactionManager.kt b/ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/database/TransactionManager.kt index 54ea24f..1ec888e 100644 --- a/ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/database/TransactionManager.kt +++ b/ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/database/TransactionManager.kt @@ -23,8 +23,7 @@ import io.r2dbc.spi.IsolationLevel * Transaction manager abstraction used to manage database connections and transactions. * * Applications can use this interface directly, but it is not primary meant as API: - * Typically, transactions are used by calling the [Database.useTransaction] function or - * Spring's [Transactional] annotation if the Spring support is enabled. + * Typically, transactions are used by calling the [Database.useTransaction] function. */ public interface TransactionManager { @@ -34,10 +33,17 @@ public interface TransactionManager { public val defaultIsolation: IsolationLevel? /** - * The opened transaction of the current thread, null if there is no transaction opened. + * The opened transaction of the current [CoroutineContext], null if there is no transaction opened. */ public suspend fun getCurrentTransaction(): Transaction? + /** + * Open a new transaction for the [CoroutineContext] using the specific isolation. + * + * @param isolation the transaction isolation, by default, [defaultIsolation] is used. + * @return the result of the callback function. + * @throws [IllegalStateException] if there is already a transaction opened. + */ public suspend fun useTransaction( isolation: IsolationLevel? = defaultIsolation, func: suspend (Transaction) -> T diff --git a/ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/dsl/QueryRow.kt b/ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/dsl/QueryRow.kt index aedf43c..4501f64 100644 --- a/ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/dsl/QueryRow.kt +++ b/ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/dsl/QueryRow.kt @@ -3,26 +3,31 @@ package org.ktorm.r2dbc.dsl import io.r2dbc.spi.Row import org.ktorm.r2dbc.expression.ColumnDeclaringExpression import org.ktorm.r2dbc.schema.Column +import java.sql.ResultSet +/** + * Special implementation of [Row], used to hold the [Query] results for Ktorm. + * + * Different from normal rows, this class provides additional features: + * + * - **Indexed access operator:** It overloads the indexed access operator, so we can use square brackets `[]` to + * obtain the value by giving a specific [Column] instance. It’s less error prone by the benefit of the compiler’s + * static checking. Also, we can still use get functions in the [Row] to obtain our results by labels or + * column indices. + * + * ```kotlin + * val query = database.from(Employees).select() + * for (row in query.rowSet) { + * println(row[Employees.name]) + * } + * ``` + */ public class QueryRow internal constructor(public val query: Query, private val row: Row) : Row by row { - public operator fun get(column: ColumnDeclaringExpression, columnClass: Class): C? 
{ - if (column.declaredName.isNullOrBlank()) { - throw IllegalArgumentException("Label of the specified column cannot be null or blank.") - } - val metadata = row.metadata - for (index in metadata.columnMetadatas.indices) { - if (metadata.getColumnMetadata(index).name eq column.declaredName) { - return row.get(index, columnClass) - } - } - return null - } - /** * Obtain the value of the specific [Column] instance. * - * Note that if the column doesn't exist in the result set, this function will return null rather than + * Note that if the column doesn't exist in the row, this function will return null rather than * throwing an exception. */ public operator fun get(column: Column): C? { @@ -34,19 +39,16 @@ public class QueryRow internal constructor(public val query: Query, private val return column.sqlType.getResult(row,index) } } - // Return null if the column doesn't exist in the result set. + // Return null if the column doesn't exist in the row. return null } else { // Try to find the column by name and its table name (happens when we are using `select *`). val indices = metadata.columnMetadatas.indices.filter { index -> - /*val tableName = metadata.getTableName(index) - val tableNameMatched = tableName.isBlank() || tableName eq table.alias || tableName eq table.tableName - val columnName = metaData.getColumnName(index)*/ metadata.columnMetadatas[index].name eq column.name/* && tableNameMatched*/ } return when (indices.size) { - 0 -> null // Return null if the column doesn't exist in the result set. + 0 -> null // Return null if the column doesn't exist in the row. 1 -> return column.sqlType.getResult(row,indices.first()) else -> throw IllegalArgumentException(warningConfusedColumnName(column.name)) } @@ -56,7 +58,7 @@ public class QueryRow internal constructor(public val query: Query, private val /** * Obtain the value of the specific [ColumnDeclaringExpression] instance. * - * Note that if the column doesn't exist in the result set, this function will return null rather than + * Note that if the column doesn't exist in the row, this function will return null rather than * throwing an exception. */ public operator fun get(column: ColumnDeclaringExpression): C? { @@ -70,7 +72,7 @@ public class QueryRow internal constructor(public val query: Query, private val } } - // Return null if the column doesn't exist in the result set. + // Return null if the column doesn't exist in the row. 
return null } diff --git a/ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/entity/EntitySequence.kt b/ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/entity/EntitySequence.kt index d7e72c5..41888a3 100644 --- a/ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/entity/EntitySequence.kt +++ b/ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/entity/EntitySequence.kt @@ -16,6 +16,7 @@ package org.ktorm.r2dbc.entity +import kotlinx.coroutines.channels.Channel import kotlinx.coroutines.flow.* import org.ktorm.r2dbc.database.Database import org.ktorm.r2dbc.database.DialectFeatureNotSupportedException From 7e15b6f9d539745b3e5859be58739e5d133a1e04 Mon Sep 17 00:00:00 2001 From: htt <641571835@qq.com> Date: Tue, 22 Feb 2022 19:19:43 +0800 Subject: [PATCH 15/17] fix error --- .../src/main/kotlin/org/ktorm/r2dbc/entity/EntitySequence.kt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/entity/EntitySequence.kt b/ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/entity/EntitySequence.kt index 41888a3..4590a86 100644 --- a/ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/entity/EntitySequence.kt +++ b/ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/entity/EntitySequence.kt @@ -134,7 +134,7 @@ public class EntitySequence>( } public suspend fun asFlow():Flow { - return getRowSet().map(entityExtractor) + return getRowSet().map { entityExtractor(it) } } } From a181abd107947aa478fdf0520f0550eed16e33d2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=AB=98=E9=94=90=E9=9B=84?= <641571835@qq.com> Date: Thu, 24 Feb 2022 18:13:15 +0800 Subject: [PATCH 16/17] update comment --- .../r2dbc/database/TransactionManager.kt | 1 + .../main/kotlin/org/ktorm/r2dbc/dsl/Query.kt | 7 ++++++- .../org/ktorm/r2dbc/entity/EntitySequence.kt | 20 ++++++++++--------- 3 files changed, 18 insertions(+), 10 deletions(-) diff --git a/ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/database/TransactionManager.kt b/ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/database/TransactionManager.kt index 1ec888e..1e1fffa 100644 --- a/ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/database/TransactionManager.kt +++ b/ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/database/TransactionManager.kt @@ -18,6 +18,7 @@ package org.ktorm.r2dbc.database import io.r2dbc.spi.Connection import io.r2dbc.spi.IsolationLevel +import kotlin.coroutines.CoroutineContext /** * Transaction manager abstraction used to manage database connections and transactions. diff --git a/ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/dsl/Query.kt b/ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/dsl/Query.kt index 466bb27..c713c43 100644 --- a/ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/dsl/Query.kt +++ b/ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/dsl/Query.kt @@ -23,7 +23,6 @@ import org.ktorm.r2dbc.schema.BooleanSqlType import org.ktorm.r2dbc.schema.Column import org.ktorm.r2dbc.schema.ColumnDeclaring import org.ktorm.r2dbc.schema.LongSqlType -import java.sql.ResultSet /** * [Query] is an abstraction of query operations and the core class of Ktorm's query DSL. 
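A small sketch of the indexed access behaviour described in the `QueryRow` documentation above; the column names (`name`, `job`, `salary`) are taken from the module's test schema and are assumptions, as is the `database` value.

```kotlin
import kotlinx.coroutines.flow.toList
import kotlinx.coroutines.runBlocking

fun main() = runBlocking {
    // Assumed: `database` and `Employees` as defined in the module's test sources.
    val query = database.from(Employees).select(Employees.name, Employees.job)

    for (row in query.asFlow().toList()) {
        // Indexed access by Column is statically typed.
        val name = row[Employees.name]
        // salary is not part of the selection, so the lookup yields null instead of throwing.
        val salary = row[Employees.salary]
        println("$name, ${row[Employees.job]}, salary=${salary ?: "n/a"}")
    }
}
```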
@@ -79,10 +78,16 @@ public class Query(public val database: Database, public val expression: QueryEx database.formatExpression(expression, beautifySql = true).first } + /** + * The [QueryRow] object flow of this query + */ public suspend fun doQuery(expression: QueryExpression = this.expression): Flow { return database.executeQuery(expression).map { QueryRow(this@Query, it) } } + /** + * The [QueryRow] object flow of this query + */ public suspend fun asFlow(): Flow { return this.doQuery() } diff --git a/ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/entity/EntitySequence.kt b/ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/entity/EntitySequence.kt index 4590a86..2635757 100644 --- a/ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/entity/EntitySequence.kt +++ b/ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/entity/EntitySequence.kt @@ -16,12 +16,12 @@ package org.ktorm.r2dbc.entity -import kotlinx.coroutines.channels.Channel import kotlinx.coroutines.flow.* import org.ktorm.r2dbc.database.Database import org.ktorm.r2dbc.database.DialectFeatureNotSupportedException import org.ktorm.r2dbc.dsl.* -import org.ktorm.r2dbc.expression.* +import org.ktorm.r2dbc.expression.OrderByExpression +import org.ktorm.r2dbc.expression.SelectExpression import org.ktorm.r2dbc.schema.BaseTable import org.ktorm.r2dbc.schema.Column import org.ktorm.r2dbc.schema.ColumnDeclaring @@ -41,18 +41,18 @@ import kotlin.math.min * ``` * * Now we got a default sequence, which can obtain all employees from the table. Please know that Ktorm doesn't execute - * the query right now. The sequence provides an iterator of type `Iterator`, only when we iterate the - * sequence using the iterator, the query is executed. The following code prints all employees using a for-each loop: + * the query right now. The sequence provides an flow of type `Flow`, only when we collect the + * sequence using the flow, the query is executed. The following code prints all employees using a forEach loop: * * ```kotlin - * for (employee in sequence) { + * sequence.forEach { employee -> * println(employee) * } * ``` * * This class wraps a [Query] object, and it’s iterator exactly wraps the query’s iterator. While an entity sequence is * iterated, its internal query is executed, and the [entityExtractor] function is applied to create an entity object - * for each row. As for other properties in sequences (such as [sql], [rowSet], [totalRecords], etc), all of them + * for each row. As for other properties in sequences (such as [sql], [getRowSet], [totalRecords], etc), all of them * delegates the callings to their internal query objects, and their usages are totally the same as the corresponding * properties in [Query] class. * @@ -104,10 +104,9 @@ public class EntitySequence>( public val sql: String get() = query.sql /** - * The [ResultSet] object of the internal query, lazy initialized after first access, obtained from the database by - * executing the generated SQL. + * The [QueryRow] object flow of the internal query * - * This property is delegated to [Query.rowSet], more details can be found in its documentation. + * This function is invoke [Query.doQuery], more details can be found in its documentation. */ public suspend fun getRowSet(): Flow = query.doQuery() @@ -133,6 +132,9 @@ public class EntitySequence>( return asFlow().toList().asSequence() } + /** + * Return an flow over the elements of this sequence. 
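A sketch of consuming an entity sequence through the Flow API described above; `database.employees` and the Employee properties are assumed from the module's test sources. Note that `filter` and `map` below are plain Flow operators applied in memory to the emitted entities, not SQL-level filtering.

```kotlin
import kotlinx.coroutines.flow.filter
import kotlinx.coroutines.flow.map
import kotlinx.coroutines.flow.toList
import kotlinx.coroutines.runBlocking

fun main() = runBlocking {
    // Assumed: the `database.employees` sequence from the module's test sources.
    val names = database.employees
        .asFlow()                         // the query executes when the flow is collected
        .filter { it.job == "engineer" }  // in-memory Flow operator on the emitted entities
        .map { it.name }
        .toList()

    println(names)
}
```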
+ */ public suspend fun asFlow():Flow { return getRowSet().map { entityExtractor(it) } } From 6ff3cc675492d0e9be7cdb2a842e68fc5a006db4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=AB=98=E9=94=90=E9=9B=84?= <641571835@qq.com> Date: Fri, 25 Feb 2022 18:07:29 +0800 Subject: [PATCH 17/17] update comment --- .../r2dbc/database/TransactionManager.kt | 1 + .../kotlin/org/ktorm/r2dbc/schema/SqlType.kt | 82 +++++++++++++++++++ 2 files changed, 83 insertions(+) diff --git a/ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/database/TransactionManager.kt b/ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/database/TransactionManager.kt index 1e1fffa..bf94b79 100644 --- a/ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/database/TransactionManager.kt +++ b/ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/database/TransactionManager.kt @@ -42,6 +42,7 @@ public interface TransactionManager { * Open a new transaction for the [CoroutineContext] using the specific isolation. * * @param isolation the transaction isolation, by default, [defaultIsolation] is used. + * @param func the executed callback function. * @return the result of the callback function. * @throws [IllegalStateException] if there is already a transaction opened. */ diff --git a/ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/schema/SqlType.kt b/ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/schema/SqlType.kt index a1182d6..62b305f 100644 --- a/ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/schema/SqlType.kt +++ b/ktorm-r2dbc-core/src/main/kotlin/org/ktorm/r2dbc/schema/SqlType.kt @@ -1,21 +1,55 @@ package org.ktorm.r2dbc.schema +import io.r2dbc.spi.ColumnMetadata import io.r2dbc.spi.Row import io.r2dbc.spi.Statement import kotlin.reflect.KClass +/** + * SQL data type interface. + * + * Based on R2DBC, [SqlType] and its subclasses encapsulate the common operations of obtaining data from a [Row] + * and setting parameters to a [Statement]. + * + */ public interface SqlType { + /** + * Binding the [value] to a given [Statement] + */ public fun bindParameter(statement: Statement, index: Int, value: T?) + /** + * Binding the [value] to a given [Statement] + */ public fun bindParameter(statement: Statement, name: String, value: T?) + /** + * Obtain a result from a given [Row] by [index], the result may be null. + */ public fun getResult(row: Row, index: Int): T? + /** + * Obtain a result from a given [Row] by [name], the result may be null. + */ public fun getResult(row: Row, name: String): T? } +/** + * Transform this [SqlType] to another. The returned [SqlType] performs a specific conversion on the column value. + * + * This function enables a user-friendly syntax to extend more data types. For example, the following code defines + * a column of type `Column`, based on the existing [IntSqlType]: + * + * ```kotlin + * val role by registerColumn("role", IntSqlType.transform({ UserRole.fromCode(it) }, { it.code })) + * ``` + * + * @param fromUnderlyingValue a function that transforms a value of underlying type to the user's type. + * @param toUnderlyingValue a function that transforms a value of user's type the to the underlying type. + * @return a [SqlType] instance based on this underlying type with specific transformations. 
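Expanding the `transform` example above into a complete column definition; the `UserRole` enum and the `Users` table are hypothetical, and the `Table`/`registerColumn` usage assumes the schema DSL added by this patch mirrors Ktorm's JDBC module.

```kotlin
import org.ktorm.r2dbc.schema.IntSqlType
import org.ktorm.r2dbc.schema.Table
import org.ktorm.r2dbc.schema.transform

// Hypothetical domain type stored as an INT column.
enum class UserRole(val code: Int) {
    ADMIN(1), MEMBER(2);

    companion object {
        fun fromCode(code: Int): UserRole = values().first { it.code == code }
    }
}

object Users : Table<Nothing>("t_user") {
    // Stored as INT in the database, but surfaced as UserRole in the DSL and in query results.
    val role = registerColumn("role", IntSqlType.transform({ UserRole.fromCode(it) }, { it.code }))
}
```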
+ */ public fun SqlType.transform( fromUnderlyingValue: (T) -> R, toUnderlyingValue: (R) -> T, @@ -23,6 +57,12 @@ public fun SqlType.transform( return TransformedSqlType(this, fromUnderlyingValue, toUnderlyingValue) } +/** + * Simple [SqlType] implementation, pass the specified type [kotlinType] to [Statement] for binding or acquisition, + * and r2dbc driver parses the specified Java type + * + * @param kotlinType Specify the associated kotlin type + */ public open class SimpleSqlType(public val kotlinType: KClass) : SqlType { override fun bindParameter(statement: Statement, index: Int, value: T?) { @@ -51,8 +91,35 @@ public open class SimpleSqlType(public val kotlinType: KClass) : Sql } +/** + * Convertible SqlType type abstraction + * + * In the r2dbc query result, the metadata information [ColumnMetadata] of the data result is included, which includes + * the Java type corresponding to each column. However, what Java type corresponds to the specific SQL type is + * determined by R2DBC driver implementation, maybe the type is exactly what we want, maybe not. [ConvertibleSqlType] + * and its subclasses can convert the object returned by the R2DBC driver to the specified type, E.g: + * + * ```kotlin + * public object IntSqlType : ConvertibleSqlType(Int::class) { + * override fun convert(value: Any): Int { + * return when (value) { + * is Number -> value.toInt() + * is String -> value.toInt() + * else -> throw IllegalStateException("Converting type is not supported from value:$value") + * } + * } + * } + * + * ``` + * + * @param kotlinType Specify the kotlin type + */ public abstract class ConvertibleSqlType(kotlinType: KClass) : SimpleSqlType(kotlinType) { + /** + * Convert the object returned by the R2DBC driver query to the specified kotlinType + * @param value Value returned from R2DBC query + */ public abstract fun convert(value: Any): R override fun getResult(row: Row, index: Int): R? { @@ -68,6 +135,21 @@ public abstract class ConvertibleSqlType(kotlinType: KClass) : Simpl } } +/** + * Transform [underlyingType] to another. this [SqlType] performs a specific conversion on the column value. + * + * This function enables a user-friendly syntax to extend more data types. For example, the following code defines + * a column of type `Column`, based on the existing [IntSqlType]: + * + * ```kotlin + * val role by registerColumn("role", IntSqlType.transform({ UserRole.fromCode(it) }, { it.code })) + * ``` + * + * @see [transform] + * @param underlyingType [SqlType] to be converted + * @param fromUnderlyingValue a function that transforms a value of underlying type to the user's type. + * @param toUnderlyingValue a function that transforms a value of user's type the to the underlying type. + */ public class TransformedSqlType( public val underlyingType: SqlType, public val fromUnderlyingValue: (T) -> R,
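For completeness, a sketch of what a custom `ConvertibleSqlType` subclass could look like, following the pattern of the `IntSqlType` example quoted above; the UUID handling (drivers returning either a `UUID` or its string form) is an assumption for illustration, not something this patch defines.

```kotlin
import org.ktorm.r2dbc.schema.ConvertibleSqlType
import java.util.UUID

// Hypothetical SqlType normalizing driver-specific UUID representations to java.util.UUID.
public object UuidSqlType : ConvertibleSqlType<UUID>(UUID::class) {
    override fun convert(value: Any): UUID = when (value) {
        is UUID -> value
        is String -> UUID.fromString(value)
        else -> throw IllegalStateException("Converting type is not supported from value: $value")
    }
}
```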