Skip to content

Commit

Permalink
Fix for OkHttp
Browse files Browse the repository at this point in the history
  • Loading branch information
adamw committed Jan 24, 2025
1 parent b95564d commit e363385
Show file tree
Hide file tree
Showing 4 changed files with 69 additions and 34 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -110,10 +110,12 @@ abstract class HttpClientAsyncBackend[F[_], S <: Streams[S], BH, B](
readResponse(jResponse, Left(limitedBody), request)
}
} {
// the request might have been interrupted during sending (no publisher is available then), or any time
// after that, including right after the sending effect completed, but before the response was read
val llb = lowLevelBody.get()
if (llb != null) monad.eval(cancelLowLevelBody(llb)) else monad.unit(())
monad.eval {
// the request might have been interrupted during sending (no publisher is available then), or any time
// after that, including right after the sending effect completed, but before the response was read
val llb = lowLevelBody.get()
if (llb != null) cancelLowLevelBody(llb)
}
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import scala.concurrent.Future
import sttp.client4.compression.CompressionHandlers
import sttp.client4.compression.Compressor
import sttp.client4.compression.Decompressor
import cats.effect.ExitCase

class OkHttpMonixBackend private (
client: OkHttpClient,
Expand Down Expand Up @@ -112,6 +113,11 @@ class OkHttpMonixBackend private (

override protected def createSimpleQueue[T]: Task[SimpleQueue[Task, T]] =
Task.eval(new MonixSimpleQueue[T](webSocketBufferCapacity))

override protected def ensureOnAbnormal[T](effect: Task[T])(finalizer: => Task[Unit]): Task[T] =
effect.guaranteeCase { exit =>
if (exit == ExitCase.Completed) Task.unit else finalizer.onErrorHandleWith(t => Task.eval(t.printStackTrace()))
}
}

object OkHttpMonixBackend {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import sttp.client4.{ignore, GenericRequest, Response}
import sttp.monad.{Canceler, MonadAsyncError}
import sttp.client4.compression.CompressionHandlers
import java.io.InputStream
import java.util.concurrent.atomic.AtomicReference

abstract class OkHttpAsyncBackend[F[_], S <: Streams[S], P](
client: OkHttpClient,
Expand All @@ -19,51 +20,72 @@ abstract class OkHttpAsyncBackend[F[_], S <: Streams[S], P](
compressionHandlers: CompressionHandlers[P, InputStream]
) extends OkHttpBackend[F, S, P](client, closeClient, compressionHandlers) {

// #1987: see the comments in HttpClientAsyncBackend
protected def ensureOnAbnormal[T](effect: F[T])(finalizer: => F[Unit]): F[T]

override protected def sendRegular[T](request: GenericRequest[T, R]): F[Response[T]] = {
val nativeRequest = convertRequest(request)
monad.flatten(monad.async[F[Response[T]]] { cb =>
def success(r: F[Response[T]]): Unit = cb(Right(r))
val okHttpResponse = new AtomicReference[OkHttpResponse]()
ensureOnAbnormal {
monad.flatten(monad.async[F[Response[T]]] { cb =>
def success(r: F[Response[T]]): Unit = cb(Right(r))

def error(t: Throwable): Unit = cb(Left(t))
def error(t: Throwable): Unit = cb(Left(t))

val call = OkHttpBackend
.updateClientIfCustomReadTimeout(request, client)
.newCall(nativeRequest)
val call = OkHttpBackend
.updateClientIfCustomReadTimeout(request, client)
.newCall(nativeRequest)

call.enqueue(new Callback {
override def onFailure(call: Call, e: IOException): Unit =
error(e)
call.enqueue(new Callback {
override def onFailure(call: Call, e: IOException): Unit =
error(e)

override def onResponse(call: Call, response: OkHttpResponse): Unit =
try success(readResponse(response, request, request.response))
catch {
case e: Exception =>
response.close()
error(e)
override def onResponse(call: Call, response: OkHttpResponse): Unit = {
okHttpResponse.set(response)
try success(readResponse(response, request, request.response))
catch {
case e: Exception =>
try response.close()
finally error(e)
}
}
})
})

Canceler(() => call.cancel())
})
Canceler(() => call.cancel())
})
} {
monad.eval {
val response = okHttpResponse.get()
if (response != null) response.close()
}
}
}

override protected def sendWebSocket[T](
request: GenericRequest[T, R]
): F[Response[T]] = {
val nativeRequest = convertRequest(request)
monad.flatten(
createSimpleQueue[WebSocketEvent]
.flatMap { queue =>
monad.async[F[Response[T]]] { cb =>
val listener = createListener(queue, cb, request)
val ws = OkHttpBackend
.updateClientIfCustomReadTimeout(request, client)
.newWebSocket(nativeRequest, listener)

Canceler(() => ws.cancel())
val okHttpWS = new AtomicReference[okhttp3.WebSocket]()
ensureOnAbnormal {
monad.flatten(
createSimpleQueue[WebSocketEvent]
.flatMap { queue =>
monad.async[F[Response[T]]] { cb =>
val listener = createListener(queue, cb, request)
val ws = OkHttpBackend
.updateClientIfCustomReadTimeout(request, client)
.newWebSocket(nativeRequest, listener)

Canceler(() => ws.cancel())
}
}
}
)
)
} {
monad.eval {
val ws = okHttpWS.get()
if (ws != null) ws.cancel()
}
}
}

private def createListener[T](
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,11 @@ class OkHttpFutureBackend private (
override val streams: NoStreams = NoStreams
override def streamToRequestBody(stream: Nothing, mt: MediaType, cl: Option[Long]): OkHttpRequestBody = stream
}

override protected def ensureOnAbnormal[T](effect: Future[T])(finalizer: => Future[Unit]): Future[T] =
effect.recoverWith { case e =>
finalizer.recoverWith { case e2 => e.addSuppressed(e2); Future.failed(e) }.flatMap(_ => Future.failed(e))
}
}

object OkHttpFutureBackend {
Expand Down

0 comments on commit e363385

Please sign in to comment.