Skip to content
This repository has been archived by the owner on Aug 11, 2023. It is now read-only.

Commit

Permalink
Added join support for akka streams
Browse files Browse the repository at this point in the history
  • Loading branch information
vitaliihonta committed Feb 1, 2019
1 parent 3fb4dcf commit 4b49cd6
Show file tree
Hide file tree
Showing 6 changed files with 169 additions and 7 deletions.
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
---
Project: Trembita
Current version: 0.8.4-SNAPSHOT
Current version: 0.8.5-SNAPSHOT
Scala version: 2.11.12, 2.12.8
---

Expand All @@ -18,7 +18,7 @@ Trembita allows you to make complicated transformation pipelines where some of t
```scala
resolvers += "Sonatype OSS Snapshots" at "https://oss.sonatype.org/content/repositories/snapshots"
libraryDependencies ++= {
val trembitaV = "0.8.4-SNAPSHOT"
val trembitaV = "0.8.5-SNAPSHOT"
Seq(
"ua.pp.itkpi" %% "trembita-kernel" % trembitaV, // kernel,

Expand Down
2 changes: 1 addition & 1 deletion build.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ import Dependencies._

lazy val snapshot: Boolean = true
lazy val v: String = {
val vv = "0.8.4"
val vv = "0.8.5"
if (!snapshot) vv
else vv + "-SNAPSHOT"
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,11 +42,11 @@ class CanPause2AkkaF[F[_]: Effect: Timer: RunAkka, Mat](implicit mat: ActorMater
val out = Outlet[A]("CanPause2.out")
override val shape: FlowShape[A, A] = FlowShape(in, out)

var prevOpt: Option[A] = None
val inBuffer: mutable.Queue[A] = mutable.Queue.empty[A]
var paused: Boolean = false

override def createLogic(inheritedAttributes: Attributes): GraphStageLogic = new GraphStageLogic(shape) with StageLogging {
var prevOpt: Option[A] = None
val inBuffer: mutable.Queue[A] = mutable.Queue.empty[A]
var paused: Boolean = false

setHandlers(
in,
out,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
package trembita.akka_streams

import akka.NotUsed
import akka.stream.{Attributes, FlowShape, Inlet, Outlet}
import akka.stream.scaladsl._
import akka.stream.stage._

import scala.language.higherKinds

class JoinFlow[A, B](right: Stream[B], on: (A, B) => Boolean) extends GraphStage[FlowShape[A, (A, B)]] {
def createLogic(inheritedAttributes: Attributes): GraphStageLogic = new GraphStageLogicWithLogging(shape) {
setHandlers(
in,
out,
new InHandler with OutHandler {
def onPush(): Unit = {
val curr = grab(in)
log.debug(s"grabed $curr\n\tright=$right")
right.find(on(curr, _)) match {
case None =>
log.debug(s"Haven't found pair for $curr")
case Some(found) =>
push(out, (curr, found))
log.debug(s"Pushed pair: $curr -> $found")
}
}
def onPull(): Unit = {
log.debug("On pull...")
pull(in)
}
}
)
}
private val in = Inlet[A]("join.in")
private val out = Outlet[(A, B)]("join.out")
val shape: FlowShape[A, (A, B)] = FlowShape(in, out)
}

class JoinLeftFlow[A, B](right: Stream[B], on: (A, B) => Boolean) extends GraphStage[FlowShape[A, (A, Option[B])]] {
def createLogic(inheritedAttributes: Attributes): GraphStageLogic = new GraphStageLogicWithLogging(shape) {
setHandlers(
in,
out,
new InHandler with OutHandler {
def onPush(): Unit = {
val curr = grab(in)
log.debug(s"grabed $curr")
right.find(on(curr, _)) match {
case None =>
log.debug(s"Haven't found pair for $curr")
push(out, (curr, None))
case res @ Some(found) =>
push(out, (curr, res))
log.debug(s"Pushed pair: $curr -> $found")
}
}
def onPull(): Unit = {
log.debug("Pulling...")
pull(in)
}
}
)
}
private val in = Inlet[A]("joinLeft.in")
private val out = Outlet[(A, Option[B])]("joinLeft.out")
val shape: FlowShape[A, (A, Option[B])] = FlowShape(in, out)
}
Original file line number Diff line number Diff line change
Expand Up @@ -243,6 +243,34 @@ trait operations {
fa.runWith(Sink.foldAsync(zero)((b, a) => Effect[F].toIO(f(b, a)).unsafeToFuture()))
}))
}

implicit def canJoinSource(implicit mat: ActorMaterializer, ec: ExecutionContext): CanJoin[Source[?, NotUsed]] =
new CanJoin[Source[?, NotUsed]] {
def join[A, B](fa: Source[A, NotUsed], fb: Source[B, NotUsed])(
on: (A, B) => Boolean
): Source[(A, B), NotUsed] =
Source
.fromFutureSource(
fb.runWith(Sink.collection[B, Stream[B]]).map { allFb =>
val joinFlow = new JoinFlow[A, B](allFb, on)
fa.via(joinFlow)
}
)
.mapMaterializedValue(_ => NotUsed)

@inline def joinLeft[A, B](fa: Source[A, NotUsed], fb: Source[B, NotUsed])(on: (A, B) => Boolean): Source[(A, Option[B]), NotUsed] =
Source
.fromFutureSource(
fb.runWith(Sink.collection[B, Stream[B]]).map { allFb =>
val joinFlow = new JoinLeftFlow[A, B](allFb, on)
fa.via(joinFlow)
}
)
.mapMaterializedValue(_ => NotUsed)

@inline def joinRight[A, B](fa: Source[A, NotUsed], fb: Source[B, NotUsed])(on: (A, B) => Boolean): Source[(Option[A], B), NotUsed] =
joinLeft(fb, fa)((b, a) => on(a, b)).map(_.swap)
}
}

class AkkaCollectionOutput[Col[x] <: Iterable[x], F[_], Mat](implicit async: Async[F], mat: ActorMaterializer)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
package trembita.akka_streams

import akka.actor.ActorSystem
import akka.stream.scaladsl.Source
import akka.stream.{ActorMaterializer, DelayOverflowStrategy}
import akka.testkit.TestKit
import cats.effect.{IO, Timer}
import org.scalatest.{BeforeAndAfterAll, FlatSpecLike}
import trembita._
import scala.concurrent.ExecutionContext

class JoinSpec extends TestKit(ActorSystem("trembita-akka-join")) with FlatSpecLike with BeforeAndAfterAll {
implicit val _system: ActorSystem = system
implicit val mat: ActorMaterializer = ActorMaterializer()(system)
implicit val parallelism: Parallelism = Parallelism(4, ordered = false)
implicit val ec: ExecutionContext = system.dispatcher
implicit val delayOverflowStrategy: DelayOverflowStrategy = DelayOverflowStrategy.dropHead
implicit val ioTimer: Timer[IO] = IO.timer(ec)

override def afterAll(): Unit = {
mat.shutdown()
system.terminate()
}

"Akka pipelines" should "support join" in {
val ppln1 = Input.fromSourceF[IO](Source(1 to 10))
val ppln2 = Input.fromSourceF[IO](Source(0 :: 1 :: Nil))

val result = ppln1.join(ppln2)(on = _ % 2 == _).into(Output.vector).run.unsafeRunSync().sortBy(_._1)

assert(result == (1 to 10).toVector.map(x => x -> (x % 2)))
}

it should "support join left" in {
val ppln1 = Input.fromSourceF[IO](Source(1 to 10))
val ppln2 = Input.fromSourceF[IO](Source(0 :: 1 :: Nil))

val result = ppln1.joinLeft(ppln2)(on = _ % 3 == _).into(Output.vector).run.unsafeRunSync().sortBy(_._1)

assert(result == (1 to 10).toVector.map { x =>
x -> (x % 3 match {
case 0 => Some(0)
case 1 => Some(1)
case _ => None
})
})
}

it should "support join right" in {
val ppln1 = Input.fromSourceF[IO](Source(1 to 10))
val ppln2 = Input.fromSourceF[IO](Source(0 :: 1 :: Nil))

val result = ppln2.joinRight(ppln1)(on = _ == _ % 3).into(Output.vector).run.unsafeRunSync().sortBy(_._2)

assert(
result == (1 to 10).toVector
.map { x =>
x -> (x % 3 match {
case 0 => Some(0)
case 1 => Some(1)
case _ => None
})
}
.map(_.swap)
)
}
}

0 comments on commit 4b49cd6

Please sign in to comment.