@@ -2,6 +2,8 @@ import NIOCore
22import Logging
33
44final class PSQLRowStream {
5+ private typealias AsyncSequenceSource = NIOThrowingAsyncSequenceProducer < DataRow , Error , AdaptiveRowBuffer , PSQLRowStream > . Source
6+
57 enum RowSource {
68 case stream( PSQLRowsDataSource )
79 case noRows( Result < String , Error > )
@@ -23,7 +25,7 @@ final class PSQLRowStream {
2325 case consumed( Result < String , Error > )
2426
2527 #if canImport(_Concurrency)
26- case asyncSequence( AsyncStreamConsumer , PSQLRowsDataSource )
28+ case asyncSequence( AsyncSequenceSource , PSQLRowsDataSource )
2729 #endif
2830 }
2931
@@ -71,26 +73,35 @@ final class PSQLRowStream {
7173 preconditionFailure ( " Invalid state: \( self . downstreamState) " )
7274 }
7375
74- let consumer = AsyncStreamConsumer (
75- lookupTable: self . lookupTable,
76- columns: self . rowDescription
76+ let producer = NIOThrowingAsyncSequenceProducer . makeSequence (
77+ elementType: DataRow . self,
78+ failureType: Error . self,
79+ backPressureStrategy: AdaptiveRowBuffer ( ) ,
80+ delegate: self
7781 )
82+
83+ let source = producer. source
7884
7985 switch bufferState {
8086 case . streaming( let bufferedRows, let dataSource) :
81- consumer. startStreaming ( bufferedRows, upstream: self )
82- self . downstreamState = . asyncSequence( consumer, dataSource)
87+ let yieldResult = source. yield ( contentsOf: bufferedRows)
88+ self . downstreamState = . asyncSequence( source, dataSource)
89+
90+ self . eventLoop. execute {
91+ self . executeActionBasedOnYieldResult ( yieldResult, source: dataSource)
92+ }
8393
8494 case . finished( let buffer, let commandTag) :
85- consumer. startCompleted ( buffer, commandTag: commandTag)
95+ _ = source. yield ( contentsOf: buffer)
96+ source. finish ( )
8697 self . downstreamState = . consumed( . success( commandTag) )
8798
8899 case . failure( let error) :
89- consumer . startFailed ( error)
100+ source . finish ( error)
90101 self . downstreamState = . consumed( . failure( error) )
91102 }
92103
93- return PostgresRowSequence ( consumer )
104+ return PostgresRowSequence ( producer . sequence , lookupTable : self . lookupTable , columns : self . rowDescription )
94105 }
95106
96107 func demand( ) {
@@ -128,10 +139,8 @@ final class PSQLRowStream {
128139
129140 private func cancel0( ) {
130141 switch self . downstreamState {
131- case . asyncSequence( let consumer, let dataSource) :
132- let error = PSQLError . connectionClosed
133- self . downstreamState = . consumed( . failure( error) )
134- consumer. receive ( completion: . failure( error) )
142+ case . asyncSequence( _, let dataSource) :
143+ self . downstreamState = . consumed( . failure( CancellationError ( ) ) )
135144 dataSource. cancel ( for: self )
136145
137146 case . consumed:
@@ -305,8 +314,9 @@ final class PSQLRowStream {
305314 dataSource. request ( for: self )
306315
307316 #if canImport(_Concurrency)
308- case . asyncSequence( let consumer, _) :
309- consumer. receive ( newRows)
317+ case . asyncSequence( let consumer, let source) :
318+ let yieldResult = consumer. yield ( contentsOf: newRows)
319+ self . executeActionBasedOnYieldResult ( yieldResult, source: source)
310320 #endif
311321
312322 case . consumed( . success) :
@@ -345,8 +355,8 @@ final class PSQLRowStream {
345355 promise. succeed ( rows)
346356
347357 #if canImport(_Concurrency)
348- case . asyncSequence( let consumer , _) :
349- consumer . receive ( completion : . success ( commandTag ) )
358+ case . asyncSequence( let source , _) :
359+ source . finish ( )
350360 self . downstreamState = . consumed( . success( commandTag) )
351361 #endif
352362
@@ -373,14 +383,30 @@ final class PSQLRowStream {
373383
374384 #if canImport(_Concurrency)
375385 case . asyncSequence( let consumer, _) :
376- consumer. receive ( completion : . failure ( error) )
386+ consumer. finish ( error)
377387 self . downstreamState = . consumed( . failure( error) )
378388 #endif
379389
380390 case . consumed:
381391 break
382392 }
383393 }
394+
395+ private func executeActionBasedOnYieldResult( _ yieldResult: AsyncSequenceSource . YieldResult , source: PSQLRowsDataSource ) {
396+ self . eventLoop. preconditionInEventLoop ( )
397+ switch yieldResult {
398+ case . dropped:
399+ // ignore
400+ break
401+
402+ case . produceMore:
403+ source. request ( for: self )
404+
405+ case . stopProducing:
406+ // ignore
407+ break
408+ }
409+ }
384410
385411 var commandTag : String {
386412 guard case . consumed( . success( let commandTag) ) = self . downstreamState else {
@@ -390,14 +416,24 @@ final class PSQLRowStream {
390416 }
391417}
392418
419+ extension PSQLRowStream : NIOAsyncSequenceProducerDelegate {
420+ func produceMore( ) {
421+ self . demand ( )
422+ }
423+
424+ func didTerminate( ) {
425+ self . cancel ( )
426+ }
427+ }
428+
393429protocol PSQLRowsDataSource {
394430
395431 func request( for stream: PSQLRowStream )
396432 func cancel( for stream: PSQLRowStream )
397433
398434}
399435
400- #if swift(>=5.6 )
436+ #if swift(>=5.5 )
401437// Thread safety is guaranteed in the RowStream through dispatching onto the NIO EventLoop.
402438extension PSQLRowStream : @unchecked Sendable { }
403439#endif
0 commit comments