Skip to content

Commit e2e6fa8

Browse files
authored
implement parallel finalizers mask (zio#8204)
1 parent 2763016 commit e2e6fa8

File tree

6 files changed

+176
-42
lines changed

6 files changed

+176
-42
lines changed

core-tests/shared/src/test/scala/zio/ZIOSpec.scala

+19-1
Original file line numberDiff line numberDiff line change
@@ -345,7 +345,16 @@ object ZIOSpec extends ZIOBaseSpec {
345345
right2 = (promise1.succeed(()) *> ZIO.never).ensuring(promise2.interrupt *> ZIO.never.interruptible)
346346
exit <- ZIO.collectAllPar(List(left, ZIO.collectAllPar(List(right1, right2)))).exit
347347
} yield assert(exit)(failsCause(containsCause(Cause.fail("fail"))))
348-
} @@ nonFlaky
348+
} @@ nonFlaky,
349+
test("runs finalizers in parallel") {
350+
for {
351+
promise1 <- Promise.make[Nothing, Unit]
352+
promise2 <- Promise.make[Nothing, Unit]
353+
left = ZIO.addFinalizer(promise1.succeed(())) *> promise2.succeed(())
354+
right = promise2.await *> ZIO.addFinalizer(promise1.await)
355+
_ <- ZIO.collectAllPar(List(left, right))
356+
} yield assertCompletes
357+
}
349358
),
350359
suite("collectAllParN")(
351360
test("returns results in the same order") {
@@ -4203,6 +4212,15 @@ object ZIOSpec extends ZIOBaseSpec {
42034212
_ <- fiber.join
42044213
value <- fiberRef.get
42054214
} yield assertTrue(value == 10)
4215+
},
4216+
test("runs finalizers in parallel") {
4217+
for {
4218+
promise1 <- Promise.make[Nothing, Unit]
4219+
promise2 <- Promise.make[Nothing, Unit]
4220+
left = ZIO.addFinalizer(promise1.succeed(())) *> promise2.succeed(())
4221+
right = promise2.await *> ZIO.addFinalizer(promise1.await)
4222+
_ <- left.zipPar(right)
4223+
} yield assertCompletes
42064224
}
42074225
),
42084226
suite("toFuture")(

core-tests/shared/src/test/scala/zio/ZLayerSpec.scala

+10
Original file line numberDiff line numberDiff line change
@@ -555,6 +555,16 @@ object ZLayerSpec extends ZIOBaseSpec {
555555
_ <- layer.build
556556
_ <- layer.build
557557
} yield assertTrue(n == 2)
558+
},
559+
test("layers acquired in parallel are released in parallel") {
560+
for {
561+
promise1 <- Promise.make[Nothing, Unit]
562+
promise2 <- Promise.make[Nothing, Unit]
563+
layer1 = ZLayer.scoped(ZIO.addFinalizer(promise1.succeed(())) *> promise2.succeed(()))
564+
layer2 = ZLayer.scoped(promise2.await *> ZIO.addFinalizer(promise1.await))
565+
layer3 = layer1 ++ layer2
566+
_ <- layer3.build
567+
} yield assertCompletes
558568
}
559569
)
560570
}

core/shared/src/main/scala/zio/Scope.scala

+11-2
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,13 @@ trait Scope extends Serializable { self =>
4848
final def addFinalizer(finalizer: => UIO[Any])(implicit trace: Trace): UIO[Unit] =
4949
addFinalizerExit(_ => finalizer)
5050

51+
/**
52+
* The execution strategy finalizers associated with this scope will be run
53+
* with.
54+
*/
55+
def executionStrategy: ExecutionStrategy =
56+
ExecutionStrategy.Sequential
57+
5158
/**
5259
* Extends the scope of a `ZIO` workflow that needs a scope into this scope by
5360
* providing it to the workflow but not closing the scope when the workflow
@@ -64,7 +71,7 @@ trait Scope extends Serializable { self =>
6471
* closed when this scope is closed.
6572
*/
6673
final def fork(implicit trace: Trace): UIO[Scope.Closeable] =
67-
forkWith(ExecutionStrategy.Sequential)
74+
forkWith(executionStrategy)
6875
}
6976

7077
object Scope {
@@ -142,13 +149,15 @@ object Scope {
142149
* Makes a scope. Finalizers added to this scope will be run according to the
143150
* specified `ExecutionStrategy`.
144151
*/
145-
def makeWith(executionStrategy: => ExecutionStrategy)(implicit trace: Trace): UIO[Scope.Closeable] =
152+
def makeWith(executionStrategy0: => ExecutionStrategy)(implicit trace: Trace): UIO[Scope.Closeable] =
146153
ReleaseMap.make.map { releaseMap =>
147154
new Scope.Closeable { self =>
148155
def addFinalizerExit(finalizer: Exit[Any, Any] => UIO[Any])(implicit trace: Trace): UIO[Unit] =
149156
releaseMap.add(finalizer).unit
150157
def close(exit: => Exit[Any, Any])(implicit trace: Trace): UIO[Unit] =
151158
ZIO.suspendSucceed(releaseMap.releaseAll(exit, executionStrategy).unit)
159+
override val executionStrategy: ExecutionStrategy =
160+
executionStrategy0
152161
def forkWith(executionStrategy: => ExecutionStrategy)(implicit trace: Trace): UIO[Scope.Closeable] =
153162
ZIO.uninterruptible {
154163
for {

core/shared/src/main/scala/zio/ZEnvironment.scala

+16-3
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,12 @@ final class ZEnvironment[+R] private (
5353
def getAt[K, V](k: K)(implicit ev: R <:< Map[K, V], tagged: EnvironmentTag[Map[K, V]]): Option[V] =
5454
unsafe.get[Map[K, V]](taggedTagType(tagged))(Unsafe.unsafe).get(k)
5555

56+
/**
57+
* Retrieves a service from the environment if it exists in the environment.
58+
*/
59+
def getDynamic[A](implicit tag: Tag[A]): Option[A] =
60+
Option(unsafe.getOrElse(tag.tag, null.asInstanceOf[A])(Unsafe.unsafe))
61+
5662
override def hashCode: Int =
5763
map.hashCode
5864

@@ -143,14 +149,21 @@ final class ZEnvironment[+R] private (
143149
): ZEnvironment[R]
144150
}
145151

146-
val unsafe: UnsafeAPI =
147-
new UnsafeAPI {
152+
trait UnsafeAPI2 {
153+
private[ZEnvironment] def getOrElse[A](tag: LightTypeTag, default: => A)(implicit unsafe: Unsafe): A
154+
}
155+
156+
val unsafe: UnsafeAPI with UnsafeAPI2 =
157+
new UnsafeAPI with UnsafeAPI2 {
148158
private[ZEnvironment] def add[A](tag: LightTypeTag, a: A)(implicit unsafe: Unsafe): ZEnvironment[R with A] = {
149159
val self0 = if (index == Int.MaxValue) self.clean else self
150160
new ZEnvironment(self0.map.updated(tag, a -> self0.index), self0.index + 1)
151161
}
152162

153163
def get[A](tag: LightTypeTag)(implicit unsafe: Unsafe): A =
164+
getOrElse(tag, throw new Error(s"Defect in zio.ZEnvironment: Could not find ${tag} inside ${self}"))
165+
166+
private[ZEnvironment] def getOrElse[A](tag: LightTypeTag, default: => A)(implicit unsafe: Unsafe): A =
154167
self.cache.get(tag) match {
155168
case Some(a) => a.asInstanceOf[A]
156169
case None =>
@@ -164,7 +177,7 @@ final class ZEnvironment[+R] private (
164177
service = curService.asInstanceOf[A]
165178
}
166179
}
167-
if (service == null) throw new Error(s"Defect in zio.ZEnvironment: Could not find ${tag} inside ${self}")
180+
if (service == null) default
168181
else {
169182
self.cache = self.cache.updated(tag, service)
170183
service

core/shared/src/main/scala/zio/ZIO.scala

+112-33
Original file line numberDiff line numberDiff line change
@@ -1228,7 +1228,7 @@ sealed trait ZIO[-R, +E, +A]
12281228
* Returns a new scoped workflow that runs finalizers added to the scope of
12291229
* this workflow in parallel.
12301230
*/
1231-
final def parallelFinalizers(implicit trace: Trace): ZIO[R with Scope, E, A] =
1231+
final def parallelFinalizers(implicit trace: Trace): ZIO[R, E, A] =
12321232
ZIO.parallelFinalizers(self)
12331233

12341234
/**
@@ -1858,7 +1858,7 @@ sealed trait ZIO[-R, +E, +A]
18581858
* has meaning if used within a scope where finalizers are being run in
18591859
* parallel.
18601860
*/
1861-
final def sequentialFinalizers(implicit trace: Trace): ZIO[R with Scope, E, A] =
1861+
final def sequentialFinalizers(implicit trace: Trace): ZIO[R, E, A] =
18621862
ZIO.sequentialFinalizers(self)
18631863

18641864
/**
@@ -2538,20 +2538,21 @@ sealed trait ZIO[-R, +E, +A]
25382538
)
25392539
.forkDaemon
25402540

2541-
fork(self, false).zip(fork(that, true)).flatMap { case (left, right) =>
2542-
restore(promise.await).foldCauseZIO(
2543-
cause =>
2544-
left.interruptFork *> right.interruptFork *>
2545-
left.await.zip(right.await).flatMap { case (left, right) =>
2546-
left.zipWith(right)(f, _ && _) match {
2547-
case Exit.Failure(causes) => ZIO.refailCause(cause.stripFailures && causes)
2548-
case _ => ZIO.refailCause(cause.stripFailures)
2549-
}
2550-
},
2551-
leftWins =>
2552-
if (leftWins) left.join.zipWith(right.join)((a, b) => f(a, b))
2553-
else right.join.zipWith(left.join)((b, a) => f(a, b))
2554-
)
2541+
ZIO.parallelFinalizersMask(restore => fork(restore(self), false).zip(fork(restore(that), true))).flatMap {
2542+
case (left, right) =>
2543+
restore(promise.await).foldCauseZIO(
2544+
cause =>
2545+
left.interruptFork *> right.interruptFork *>
2546+
left.await.zip(right.await).flatMap { case (left, right) =>
2547+
left.zipWith(right)(f, _ && _) match {
2548+
case Exit.Failure(causes) => ZIO.refailCause(cause.stripFailures && causes)
2549+
case _ => ZIO.refailCause(cause.stripFailures)
2550+
}
2551+
},
2552+
leftWins =>
2553+
if (leftWins) left.join.zipWith(right.join)((a, b) => f(a, b))
2554+
else right.join.zipWith(left.join)((b, a) => f(a, b))
2555+
)
25552556
}
25562557
}
25572558
}
@@ -4254,8 +4255,30 @@ object ZIO extends ZIOCompanionPlatformSpecific with ZIOCompanionVersionSpecific
42544255
* Returns a new scoped workflow that runs finalizers added to the scope of
42554256
* this workflow in parallel.
42564257
*/
4257-
def parallelFinalizers[R, E, A](zio: => ZIO[R, E, A])(implicit trace: Trace): ZIO[R with Scope, E, A] =
4258-
ZIO.scopeWith(_.forkWith(ExecutionStrategy.Parallel).flatMap(_.extend[R](zio)))
4258+
def parallelFinalizers[R, E, A](zio: => ZIO[R, E, A])(implicit trace: Trace): ZIO[R, E, A] =
4259+
ZIO.environmentWithZIO[R] { environment =>
4260+
environment.getDynamic[Scope] match {
4261+
case None => zio
4262+
case Some(scope) =>
4263+
scope.executionStrategy match {
4264+
case ExecutionStrategy.Parallel => zio
4265+
case _ => scope.forkWith(ExecutionStrategy.Parallel).flatMap(_.extend[R](zio))
4266+
}
4267+
}
4268+
}
4269+
4270+
def parallelFinalizersMask[R, E, A](f: ZIO.FinalizersRestorer => ZIO[R, E, A])(implicit trace: Trace): ZIO[R, E, A] =
4271+
ZIO.environmentWithZIO[R] { environment =>
4272+
environment.getDynamic[Scope] match {
4273+
case None => f(ZIO.FinalizersRestorer.Identity)
4274+
case Some(scope) =>
4275+
scope.executionStrategy match {
4276+
case ExecutionStrategy.Parallel => ZIO.parallelFinalizers(f(ZIO.FinalizersRestorer.MakeParallel))
4277+
case ExecutionStrategy.ParallelN(n) => ZIO.parallelFinalizers(f(ZIO.FinalizersRestorer.MakeParallelN(n)))
4278+
case ExecutionStrategy.Sequential => ZIO.parallelFinalizers(f(ZIO.FinalizersRestorer.MakeSequential))
4279+
}
4280+
}
4281+
}
42594282

42604283
/**
42614284
* Retrieves the maximum number of fibers for parallel operators or `None` if
@@ -4459,8 +4482,17 @@ object ZIO extends ZIOCompanionPlatformSpecific with ZIOCompanionVersionSpecific
44594482
* has meaning if used within a scope where finalizers are being run in
44604483
* parallel.
44614484
*/
4462-
def sequentialFinalizers[R, E, A](zio: => ZIO[R, E, A])(implicit trace: Trace): ZIO[R with Scope, E, A] =
4463-
ZIO.scopeWith(_.fork.flatMap(_.extend[R](zio)))
4485+
def sequentialFinalizers[R, E, A](zio: => ZIO[R, E, A])(implicit trace: Trace): ZIO[R, E, A] =
4486+
ZIO.environmentWithZIO[R] { environment =>
4487+
environment.getDynamic[Scope] match {
4488+
case None => zio
4489+
case Some(scope) =>
4490+
scope.executionStrategy match {
4491+
case ExecutionStrategy.Sequential => zio
4492+
case _ => scope.forkWith(ExecutionStrategy.Sequential).flatMap(_.extend[R](zio))
4493+
}
4494+
}
4495+
}
44644496

44654497
/**
44664498
* Sets the `FiberRef` values for the fiber running this effect to the values
@@ -5898,6 +5930,51 @@ object ZIO extends ZIOCompanionPlatformSpecific with ZIOCompanionVersionSpecific
58985930
}
58995931
}
59005932

5933+
sealed trait FinalizersRestorer {
5934+
def apply[R, E, A](zio: ZIO[R, E, A])(implicit trace: Trace): ZIO[R, E, A]
5935+
}
5936+
5937+
object FinalizersRestorer {
5938+
def apply(executionStrategy: ExecutionStrategy): FinalizersRestorer =
5939+
executionStrategy match {
5940+
case ExecutionStrategy.Sequential =>
5941+
FinalizersRestorer.MakeSequential
5942+
case ExecutionStrategy.Parallel =>
5943+
FinalizersRestorer.MakeParallel
5944+
case ExecutionStrategy.ParallelN(n) =>
5945+
FinalizersRestorer.MakeParallelN(n)
5946+
}
5947+
5948+
case object Identity extends FinalizersRestorer {
5949+
def apply[R, E, A](zio: ZIO[R, E, A])(implicit trace: Trace): ZIO[R, E, A] =
5950+
zio
5951+
}
5952+
5953+
case object MakeParallel extends FinalizersRestorer {
5954+
def apply[R, E, A](zio: ZIO[R, E, A])(implicit trace: Trace): ZIO[R, E, A] =
5955+
zio.parallelFinalizers
5956+
}
5957+
5958+
final case class MakeParallelN(n: Int) extends FinalizersRestorer {
5959+
def apply[R, E, A](zio: ZIO[R, E, A])(implicit trace: Trace): ZIO[R, E, A] =
5960+
ZIO.environmentWithZIO[R] { environment =>
5961+
environment.getDynamic[Scope] match {
5962+
case None => zio
5963+
case Some(scope) =>
5964+
scope.executionStrategy match {
5965+
case ExecutionStrategy.ParallelN(n) => zio
5966+
case _ => scope.forkWith(ExecutionStrategy.ParallelN(n)).flatMap(_.extend[R](zio))
5967+
}
5968+
}
5969+
}
5970+
}
5971+
5972+
case object MakeSequential extends FinalizersRestorer {
5973+
def apply[R, E, A](zio: ZIO[R, E, A])(implicit trace: Trace): ZIO[R, E, A] =
5974+
zio.sequentialFinalizers
5975+
}
5976+
}
5977+
59015978
private[zio] val someFatal = Some(LogLevel.Fatal)
59025979
private[zio] val someError = Some(LogLevel.Error)
59035980
private[zio] val someWarning = Some(LogLevel.Warning)
@@ -5979,19 +6056,21 @@ object ZIO extends ZIOCompanionPlatformSpecific with ZIOCompanionVersionSpecific
59796056
val promise = Promise.unsafe.make[Unit, Unit](FiberId.None)(Unsafe.unsafe)
59806057
val ref = new java.util.concurrent.atomic.AtomicInteger(0)
59816058
ZIO.transplant { graft =>
5982-
ZIO.foreach(as) { a =>
5983-
graft {
5984-
restore(ZIO.suspendSucceed(f(a))).foldCauseZIO(
5985-
cause => promise.fail(()) *> ZIO.refailCause(cause),
5986-
_ =>
5987-
if (ref.incrementAndGet == size) {
5988-
promise.unsafe.done(ZIO.unit)(Unsafe.unsafe)
5989-
ZIO.unit
5990-
} else {
5991-
ZIO.unit
5992-
}
5993-
)
5994-
}.forkDaemon
6059+
ZIO.parallelFinalizersMask { restoreFinalizers =>
6060+
ZIO.foreach(as) { a =>
6061+
graft {
6062+
restore(restoreFinalizers(ZIO.suspendSucceed(f(a)))).foldCauseZIO(
6063+
cause => promise.fail(()) *> ZIO.refailCause(cause),
6064+
_ =>
6065+
if (ref.incrementAndGet == size) {
6066+
promise.unsafe.done(ZIO.unit)(Unsafe.unsafe)
6067+
ZIO.unit
6068+
} else {
6069+
ZIO.unit
6070+
}
6071+
)
6072+
}.forkDaemon
6073+
}
59956074
}
59966075
}.flatMap { fibers =>
59976076
restore(promise.await).foldCauseZIO(

core/shared/src/main/scala/zio/ZLayer.scala

+8-3
Original file line numberDiff line numberDiff line change
@@ -419,9 +419,14 @@ sealed abstract class ZLayer[-RIn, +E, +ROut] { self =>
419419
case ZLayer.ZipWith(self, that, f) =>
420420
ZIO.succeed(memoMap => memoMap.getOrElseMemoize(scope)(self).zipWith(memoMap.getOrElseMemoize(scope)(that))(f))
421421
case ZLayer.ZipWithPar(self, that, f) =>
422-
ZIO.succeed(memoMap =>
423-
memoMap.getOrElseMemoize(scope)(self).zipWithPar(memoMap.getOrElseMemoize(scope)(that))(f)
424-
)
422+
ZIO.succeed { memoMap =>
423+
for {
424+
parallel <- scope.forkWith(ExecutionStrategy.Parallel)
425+
left <- parallel.forkWith(scope.executionStrategy)
426+
right <- parallel.forkWith(scope.executionStrategy)
427+
out <- memoMap.getOrElseMemoize(left)(self).zipWithPar(memoMap.getOrElseMemoize(right)(that))(f)
428+
} yield out
429+
}
425430
}
426431
}
427432

0 commit comments

Comments
 (0)