Skip to content

Commit 8b4fbb8

Browse files
committed
Lift assertions out of withContext calls
1 parent d2d3b05 commit 8b4fbb8

File tree

6 files changed

+58
-48
lines changed

6 files changed

+58
-48
lines changed

src/main/kotlin/com/github/michaelbull/jdbc/Transaction.kt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package com.github.michaelbull.jdbc
22

33
import com.github.michaelbull.jdbc.context.CoroutineTransaction
44
import com.github.michaelbull.jdbc.context.connection
5+
import com.github.michaelbull.jdbc.context.transaction
56
import kotlinx.coroutines.CoroutineScope
67
import kotlinx.coroutines.currentCoroutineContext
78
import kotlinx.coroutines.withContext
@@ -28,7 +29,7 @@ suspend inline fun <T> transaction(crossinline block: suspend CoroutineScope.()
2829
}
2930

3031
val ctx = currentCoroutineContext()
31-
val existingTransaction = ctx[CoroutineTransaction]
32+
val existingTransaction = ctx.transaction
3233

3334
return when {
3435
existingTransaction == null -> {

src/main/kotlin/com/github/michaelbull/jdbc/context/CoroutineTransaction.kt

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,13 @@ package com.github.michaelbull.jdbc.context
33
import kotlin.coroutines.AbstractCoroutineContextElement
44
import kotlin.coroutines.CoroutineContext
55

6+
@PublishedApi
7+
internal val CoroutineContext.transaction: CoroutineTransaction?
8+
get() = get(CoroutineTransaction)
9+
610
@PublishedApi
711
internal class CoroutineTransaction(
8-
private var completed: Boolean = false
12+
private var completed: Boolean = false,
913
) : AbstractCoroutineContextElement(CoroutineTransaction) {
1014

1115
companion object Key : CoroutineContext.Key<CoroutineTransaction>

src/test/kotlin/com/github/michaelbull/jdbc/ConnectionTest.kt

Lines changed: 24 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -39,18 +39,18 @@ class ConnectionTest {
3939
every { connection } returns newConnection
4040
}
4141

42-
withContext(CoroutineDataSource(dataSource)) {
43-
val actual = withConnection {
42+
val actual = withContext(CoroutineDataSource(dataSource)) {
43+
withConnection {
4444
coroutineContext.connection
4545
}
46-
47-
assertEquals(newConnection, actual)
4846
}
47+
48+
assertEquals(newConnection, actual)
4949
}
5050

5151
@Test
5252
fun `withConnection adds new connection to context if existing connection isClosed returns true`() = runTest {
53-
val existingConnection = mockk<Connection> {
53+
val closedConnection = mockk<Connection> {
5454
every { isClosed } returns true
5555
}
5656

@@ -62,18 +62,18 @@ class ConnectionTest {
6262
every { connection } returns newConnection
6363
}
6464

65-
withContext(CoroutineDataSource(dataSource) + CoroutineConnection(existingConnection)) {
66-
val actual = withConnection {
65+
val actual = withContext(CoroutineDataSource(dataSource) + CoroutineConnection(closedConnection)) {
66+
withConnection {
6767
coroutineContext.connection
6868
}
69-
70-
assertEquals(newConnection, actual)
7169
}
70+
71+
assertEquals(newConnection, actual)
7272
}
7373

7474
@Test
7575
fun `withConnection adds new connection to context if existing connection isClosed throws exception`() = runTest {
76-
val existingConnection = mockk<Connection> {
76+
val brokenConnection = mockk<Connection> {
7777
every { isClosed } throws SQLException()
7878
}
7979

@@ -85,30 +85,30 @@ class ConnectionTest {
8585
every { connection } returns newConnection
8686
}
8787

88-
withContext(CoroutineDataSource(dataSource) + CoroutineConnection(existingConnection)) {
89-
val actual = withConnection {
88+
val actual = withContext(CoroutineDataSource(dataSource) + CoroutineConnection(brokenConnection)) {
89+
withConnection {
9090
coroutineContext.connection
9191
}
92-
93-
assertEquals(newConnection, actual)
9492
}
93+
94+
assertEquals(newConnection, actual)
9595
}
9696

9797
@Test
98-
fun `withConnection reuses existing connection in context if not closed`() = runTest {
99-
val existing = mockk<Connection> {
98+
fun `withConnection reuses open connection`() = runTest {
99+
val openConnection = mockk<Connection> {
100100
every { isClosed } returns false
101101
}
102102

103103
val dataSource = mockk<DataSource>()
104104

105-
withContext(CoroutineDataSource(dataSource) + CoroutineConnection(existing)) {
106-
val actual = withConnection {
105+
val actual = withContext(CoroutineDataSource(dataSource) + CoroutineConnection(openConnection)) {
106+
withConnection {
107107
coroutineContext.connection
108108
}
109-
110-
assertEquals(existing, actual)
111109
}
110+
111+
assertEquals(openConnection, actual)
112112
}
113113

114114
@Test
@@ -153,20 +153,20 @@ class ConnectionTest {
153153

154154
@Test
155155
fun `withConnection does not close connection if connection was not added to context`() = runTest {
156-
val existing = mockk<Connection> {
156+
val openConnection = mockk<Connection> {
157157
every { isClosed } returns false
158158
}
159159

160160
val dataSource = mockk<DataSource> {
161-
every { connection } returns existing
161+
every { connection } returns openConnection
162162
}
163163

164-
withContext(CoroutineDataSource(dataSource) + CoroutineConnection(existing)) {
164+
withContext(CoroutineDataSource(dataSource) + CoroutineConnection(openConnection)) {
165165
withConnection {
166166
/* empty */
167167
}
168168
}
169169

170-
verify(exactly = 0) { existing.close() }
170+
verify(exactly = 0) { openConnection.close() }
171171
}
172172
}

src/test/kotlin/com/github/michaelbull/jdbc/TransactionTest.kt

Lines changed: 19 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package com.github.michaelbull.jdbc
22

33
import com.github.michaelbull.jdbc.context.CoroutineConnection
44
import com.github.michaelbull.jdbc.context.CoroutineTransaction
5+
import com.github.michaelbull.jdbc.context.transaction
56
import io.mockk.every
67
import io.mockk.just
78
import io.mockk.mockk
@@ -66,13 +67,13 @@ class TransactionTest {
6667
fun `transaction reuses existing transaction in context if incomplete`() = runTest {
6768
val incompleteTransaction = CoroutineTransaction(completed = false)
6869

69-
withContext(incompleteTransaction) {
70-
val actual = transaction {
71-
coroutineContext[CoroutineTransaction]
70+
val actual = withContext(incompleteTransaction) {
71+
transaction {
72+
coroutineContext.transaction
7273
}
73-
74-
assertEquals(incompleteTransaction, actual)
7574
}
75+
76+
assertEquals(incompleteTransaction, actual)
7677
}
7778

7879
@Test
@@ -104,13 +105,13 @@ class TransactionTest {
104105
every { autoCommit } returns true
105106
}
106107

107-
withContext(CoroutineConnection(connection)) {
108-
val transaction = transaction {
109-
coroutineContext[CoroutineTransaction]
108+
val actual = withContext(CoroutineConnection(connection)) {
109+
transaction {
110+
coroutineContext.transaction
110111
}
111-
112-
assertNotNull(transaction)
113112
}
113+
114+
assertNotNull(actual)
114115
}
115116

116117
@Test
@@ -119,13 +120,13 @@ class TransactionTest {
119120
every { autoCommit } returns true
120121
}
121122

122-
withContext(CoroutineConnection(connection)) {
123-
val transaction = runTransactionally {
124-
coroutineContext[CoroutineTransaction]
123+
val actual = withContext(CoroutineConnection(connection)) {
124+
runTransactionally {
125+
coroutineContext.transaction
125126
}
126-
127-
assertNotNull(transaction)
128127
}
128+
129+
assertNotNull(actual)
129130
}
130131

131132
@Test
@@ -135,7 +136,9 @@ class TransactionTest {
135136
}
136137

137138
withContext(CoroutineConnection(connection)) {
138-
runTransactionally {}
139+
runTransactionally {
140+
/* empty */
141+
}
139142
}
140143

141144
verify(exactly = 1) { connection.commit() }

src/test/kotlin/com/github/michaelbull/jdbc/context/CoroutineConnectionTest.kt

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,10 @@ class CoroutineConnectionTest {
2121
fun `connection returns connection if in context`() = runTest {
2222
val expected = mockk<Connection>()
2323

24-
withContext(CoroutineConnection(expected)) {
25-
val actual = coroutineContext.connection
26-
assertEquals(expected, actual)
24+
val actual = withContext(CoroutineConnection(expected)) {
25+
coroutineContext.connection
2726
}
27+
28+
assertEquals(expected, actual)
2829
}
2930
}

src/test/kotlin/com/github/michaelbull/jdbc/context/CoroutineDataSourceTest.kt

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,10 @@ class CoroutineDataSourceTest {
2121
fun `dataSource returns connection if in context`() = runTest {
2222
val expected = mockk<DataSource>()
2323

24-
withContext(CoroutineDataSource(expected)) {
25-
val actual = coroutineContext.dataSource
26-
assertEquals(expected, actual)
24+
val actual = withContext(CoroutineDataSource(expected)) {
25+
coroutineContext.dataSource
2726
}
27+
28+
assertEquals(expected, actual)
2829
}
2930
}

0 commit comments

Comments
 (0)