Merge pull request #16 from apache/scala_compat
Scala version compatibility
jmalkin authored Jan 24, 2025
2 parents d8da04a + 28de4b1 commit eb7b4b2
Showing 5 changed files with 51 additions and 42 deletions.
16 changes: 10 additions & 6 deletions .github/workflows/ci.yaml
@@ -15,17 +15,19 @@ on:

jobs:
build:
name: Build and Test
name: JDK ${{ matrix.jdk }} - Scala ${{ matrix.scala }} - Spark ${{ matrix.spark }}
runs-on: ubuntu-latest

strategy:
fail-fast: false
matrix:
jdk: [ 8, 11, 17 ]
scala: [ 2.12.20, 2.13.16 ]
spark: [ 3.4.4, 3.5.4 ]

env:
JDK_VERSION: ${{ matrix.jdk }}
SCALA_VERSION: ${{ matrix.scala }}
SPARK_VERSION: ${{ matrix.spark }}

steps:
@@ -38,8 +40,8 @@ jobs:
uses: actions/cache@v4
with:
path: ~/.m2/repository
key: build-${{ runner.os }}-jdk-${{ matrix.jdk }}-spark-${{ matrix.spark }}-${{ hashFiles('**/pom.xml') }}
restore-keys: build-${{ runner.os }}-jdk-${{matrix.jdk}}-spark-${{ matrix.spark }}-maven-
key: build-${{ runner.os }}-jdk-${{ matrix.jdk }}-scala-${{ matrix.scala }}-spark-${{ matrix.spark }}-${{ hashFiles('**/pom.xml') }}
restore-keys: build-${{ runner.os }}-jdk-${{matrix.jdk}}-scala-${{ matrix.scala }}-spark-${{ matrix.spark }}-maven-

- name: Setup JDK
uses: actions/setup-java@v4
@@ -53,13 +55,15 @@ jobs:
- name: Setup SBT
uses: sbt/setup-sbt@v1

- name: Echo Java Version
run: >
- name: Echo config versions
run: |
java -version
echo Scala version: $SCALA_VERSION
echo Spark version: $SPARK_VERSION
- name: Build and test
run: >
sbt --batch clean test
sbt ++$SCALA_VERSION --batch clean test
# Architecture options: x86, x64, armv7, aarch64, ppc64le
# setup-java@v4 has a "with cache" option
16 changes: 11 additions & 5 deletions build.sbt
@@ -1,3 +1,4 @@
import scala.xml.dtd.DEFAULT
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
@@ -17,15 +18,20 @@

name := "datasketches-spark"
version := "1.0-SNAPSHOT"
scalaVersion := "2.12.20"

val DEFAULT_SCALA_VERSION = "2.12.20"
val DEFAULT_SPARK_VERSION = "3.5.4"
val DEFAULT_JDK_VERSION = "11"

organization := "org.apache.datasketches"
description := "The Apache DataSketches package for Spark"

licenses += ("Apache-2.0", url("http://www.apache.org/licenses/LICENSE-2.0"))

scalaVersion := sys.env.getOrElse("SCALA_VERSION", DEFAULT_SCALA_VERSION)

val sparkVersion = settingKey[String]("The version of Spark")
sparkVersion := sys.env.getOrElse("SPARK_VERSION", "3.5.4")
sparkVersion := sys.env.getOrElse("SPARK_VERSION", DEFAULT_SPARK_VERSION)

// determine our java version
val jvmVersionString = settingKey[String]("The JVM version")
@@ -45,7 +51,7 @@ val jvmVersionMap = Map(
val jvmVersion = settingKey[String]("The JVM major version")
jvmVersion := jvmVersionMap.collectFirst {
case (prefix, (major, _)) if jvmVersionString.value.startsWith(prefix) => major
}.getOrElse("11")
}.getOrElse(DEFAULT_JDK_VERSION)

// look up the associated datasketches-java version
val dsJavaVersion = settingKey[String]("The DataSketches Java version")
@@ -59,8 +65,8 @@ Test / scalacOptions ++= Seq("-encoding", "UTF-8", "-release", jvmVersion.value)

libraryDependencies ++= Seq(
"org.apache.datasketches" % "datasketches-java" % dsJavaVersion.value % "compile",
"org.scala-lang" % "scala-library" % "2.12.6",
"org.apache.spark" %% "spark-sql" % sparkVersion.value % "provided",
"org.scala-lang" % "scala-library" % scalaVersion.value, // scala3-library may need to use %%
("org.apache.spark" %% "spark-sql" % sparkVersion.value % "provided").cross(CrossVersion.for3Use2_13),
"org.scalatest" %% "scalatest" % "3.2.19" % "test",
"org.scalatestplus" %% "junit-4-13" % "3.2.19.0" % "test"
)
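The build.sbt changes above make the Scala and Spark versions environment-driven with defaults. For local runs of the same combinations the CI matrix exercises, this can be paired with sbt's built-in cross-building; the following is a minimal sketch under that assumption (crossScalaVersions and its values mirror the CI matrix and are not part of this commit):

// Minimal build.sbt sketch, not from this commit.
// SCALA_VERSION / SPARK_VERSION are the same variables the workflow above exports.
val DEFAULT_SCALA_VERSION = "2.12.20"
val DEFAULT_SPARK_VERSION = "3.5.4"

scalaVersion := sys.env.getOrElse("SCALA_VERSION", DEFAULT_SCALA_VERSION)
crossScalaVersions := Seq("2.12.20", "2.13.16") // "sbt +test" runs the suite once per listed version

val sparkVersion = settingKey[String]("The version of Spark")
sparkVersion := sys.env.getOrElse("SPARK_VERSION", DEFAULT_SPARK_VERSION)

libraryDependencies += "org.apache.spark" %% "spark-sql" % sparkVersion.value % "provided"

The workflow's "sbt ++$SCALA_VERSION --batch clean test" achieves the same per-version selection for a single run: "++<version>" switches scalaVersion for that sbt session without needing crossScalaVersions.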
@@ -260,7 +260,7 @@ case class KllDoublesSketchGetPmfCdf(sketchExpr: Expression,
if (!isInclusiveExpr.foldable) {
return TypeCheckResult.TypeCheckFailure(s"isInclusiveExpr must be foldable, but got: ${isInclusiveExpr}")
}
if (splitPointsExpr.eval().asInstanceOf[GenericArrayData].numElements == 0) {
if (splitPointsExpr.eval().asInstanceOf[GenericArrayData].numElements() == 0) {
return TypeCheckResult.TypeCheckFailure(s"splitPointsExpr must not be empty")
}

@@ -269,7 +269,7 @@ case class KllDoublesSketchGetPmfCdf(sketchExpr: Expression,

override def nullSafeEval(sketchInput: Any, splitPointsInput: Any, isInclusiveInput: Any): Any = {
val sketchBytes = sketchInput.asInstanceOf[Array[Byte]]
val splitPoints = splitPointsInput.asInstanceOf[GenericArrayData].toDoubleArray
val splitPoints = splitPointsInput.asInstanceOf[GenericArrayData].toDoubleArray()
val sketch = KllDoublesSketch.wrap(Memory.wrap(sketchBytes))

val result: Array[Double] =
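The two call-site changes above (numElements() and toDoubleArray()) add explicit parentheses because both methods are declared with empty parameter lists; Scala 2.13 deprecates calling such methods without parentheses ("auto-application") and Scala 3 rejects it. The Random.nextDouble() and .head() changes in the test files below follow the same rule. A small stand-alone illustration, not taken from the codebase:

// Illustration only: a hypothetical class, not part of datasketches-spark.
class Box {
  def size(): Int = 3 // declared with an empty parameter list
}

object AutoApplicationDemo extends App {
  val b = new Box
  println(b.size())  // fine on Scala 2.12, 2.13, and 3
  // println(b.size) // compiles on 2.12, deprecated on 2.13, an error on Scala 3
}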
45 changes: 22 additions & 23 deletions src/test/scala/org/apache/spark/sql/KllTest.scala
@@ -19,7 +19,6 @@ package org.apache.spark.sql

import scala.util.Random
import org.apache.spark.sql.functions._
import scala.collection.mutable.WrappedArray
import org.apache.spark.sql.types.{StructType, StructField, IntegerType, BinaryType}

import org.apache.spark.sql.functions_datasketches_kll._
@@ -31,7 +30,7 @@ class KllTest extends SparkSessionManager {
import spark.implicits._

// helper method to check if two arrays are equal
private def compareArrays(ref: Array[Double], tst: WrappedArray[Double]) {
private def compareArrays(ref: Array[Double], tst: Array[Double]): Unit = {
val tstArr = tst.toArray
if (ref.length != tstArr.length)
throw new AssertionError("Array lengths do not match: " + ref.length + " != " + tstArr.length)
@@ -48,7 +47,7 @@ class KllTest extends SparkSessionManager {
// produce a List[Row] of (id, sk)
for (i <- 1 to numClass) yield {
val sk = KllDoublesSketch.newHeapInstance(200)
for (j <- 0 until numSamples) sk.update(Random.nextDouble)
for (j <- 0 until numSamples) sk.update(Random.nextDouble())
dataList.add(Row(i, sk))
}

@@ -68,7 +67,7 @@ class KllTest extends SparkSessionManager {
// produce a Seq(Array(id, sk))
val data = for (i <- 1 to numClass) yield {
val sk = KllDoublesSketch.newHeapInstance(200)
for (j <- 0 until numSamples) sk.update(Random.nextDouble)
for (j <- 0 until numSamples) sk.update(Random.nextDouble())
Row(i, sk.toByteArray)
}

@@ -90,7 +89,7 @@ class KllTest extends SparkSessionManager {
val sketchDf = data.agg(kll_sketch_double_agg_build("value").as("sketch"))
val result: Row = sketchDf.select(kll_sketch_double_get_min($"sketch").as("min"),
kll_sketch_double_get_max($"sketch").as("max")
).head
).head()

val minValue = result.getAs[Double]("min")
val maxValue = result.getAs[Double]("max")
@@ -103,19 +102,19 @@ class KllTest extends SparkSessionManager {
kll_sketch_double_get_pmf($"sketch", splitPoints, false).as("pmf_exclusive"),
kll_sketch_double_get_cdf($"sketch", splitPoints).as("cdf_inclusive"),
kll_sketch_double_get_cdf($"sketch", splitPoints, false).as("cdf_exclusive")
).head
).head()

val pmf_incl = Array[Double](0.2, 0.3, 0.5, 0.0)
compareArrays(pmf_incl, pmfCdfResult.getAs[WrappedArray[Double]]("pmf_inclusive"))
compareArrays(pmf_incl, pmfCdfResult.getAs[Seq[Double]]("pmf_inclusive").toArray)

val pmf_excl = Array[Double](0.2, 0.29, 0.51, 0.0)
compareArrays(pmf_excl, pmfCdfResult.getAs[WrappedArray[Double]]("pmf_exclusive"))
compareArrays(pmf_excl, pmfCdfResult.getAs[Seq[Double]]("pmf_exclusive").toArray)

val cdf_incl = Array[Double](0.2, 0.5, 1.0, 1.0)
compareArrays(cdf_incl, pmfCdfResult.getAs[WrappedArray[Double]]("cdf_inclusive"))
compareArrays(cdf_incl, pmfCdfResult.getAs[Seq[Double]]("cdf_inclusive").toArray)

val cdf_excl = Array[Double](0.2, 0.49, 1.0, 1.0)
compareArrays(cdf_excl, pmfCdfResult.getAs[WrappedArray[Double]]("cdf_exclusive"))
compareArrays(cdf_excl, pmfCdfResult.getAs[Seq[Double]]("cdf_exclusive").toArray)
}

test("Kll Doubles Sketch via SQL") {
@@ -135,8 +134,8 @@ class KllTest extends SparkSessionManager {
| data_table
""".stripMargin
)
val minValue = kllDf.head.getAs[Double]("min")
val maxValue = kllDf.head.getAs[Double]("max")
val minValue = kllDf.head().getAs[Double]("min")
val maxValue = kllDf.head().getAs[Double]("max")
assert(minValue == 1.0)
assert(maxValue == n.toDouble)

@@ -154,26 +153,26 @@ class KllTest extends SparkSessionManager {
| FROM
| data_table) t
""".stripMargin
).head
).head()

val pmf_incl = Array[Double](0.2, 0.3, 0.5, 0.0)
compareArrays(pmf_incl, pmfCdfResult.getAs[WrappedArray[Double]]("pmf_inclusive"))
compareArrays(pmf_incl, pmfCdfResult.getAs[Seq[Double]]("pmf_inclusive").toArray)

val pmf_excl = Array[Double](0.2, 0.29, 0.51, 0.0)
compareArrays(pmf_excl, pmfCdfResult.getAs[WrappedArray[Double]]("pmf_exclusive"))
compareArrays(pmf_excl, pmfCdfResult.getAs[Seq[Double]]("pmf_exclusive").toArray)

val cdf_incl = Array[Double](0.2, 0.5, 1.0, 1.0)
compareArrays(cdf_incl, pmfCdfResult.getAs[WrappedArray[Double]]("cdf_inclusive"))
compareArrays(cdf_incl, pmfCdfResult.getAs[Seq[Double]]("cdf_inclusive").toArray)

val cdf_excl = Array[Double](0.2, 0.49, 1.0, 1.0)
compareArrays(cdf_excl, pmfCdfResult.getAs[WrappedArray[Double]]("cdf_exclusive"))
compareArrays(cdf_excl, pmfCdfResult.getAs[Seq[Double]]("cdf_exclusive").toArray)
}

test("KLL Doubles Merge via Scala") {
val data = generateData().toDF("id", "value")

// compute global min and max
val minMax: Row = data.agg(min("value").as("min"), max("value").as("max")).collect.head
val minMax: Row = data.agg(min("value").as("min"), max("value").as("max")).collect().head
val globalMin = minMax.getAs[Double]("min")
val globalMax = minMax.getAs[Double]("max")

@@ -187,7 +186,7 @@ class KllTest extends SparkSessionManager {
// check min and max
var result: Row = mergedSketchDf.select(kll_sketch_double_get_min($"sketch").as("min"),
kll_sketch_double_get_max($"sketch").as("max"))
.head
.head()

var sketchMin = result.getAs[Double]("min")
var sketchMax = result.getAs[Double]("max")
@@ -202,7 +201,7 @@ class KllTest extends SparkSessionManager {
// check min and max
result = mergedSketchDf.select(kll_sketch_double_get_min($"sketch").as("min"),
kll_sketch_double_get_max($"sketch").as("max"))
.head
.head()

sketchMin = result.getAs[Double]("min")
sketchMax = result.getAs[Double]("max")
@@ -219,7 +218,7 @@ class KllTest extends SparkSessionManager {
data.createOrReplaceTempView("data_table")

// compute global min and max from dataframe
val minMax: Row = data.agg(min("value").as("min"), max("value").as("max")).head
val minMax: Row = data.agg(min("value").as("min"), max("value").as("max")).head()
val globalMin = minMax.getAs[Double]("min")
val globalMax = minMax.getAs[Double]("max")

@@ -255,7 +254,7 @@ class KllTest extends SparkSessionManager {
)

// check min and max
var result: Row = mergedSketchDf.head
var result: Row = mergedSketchDf.head()
var sketchMin = result.getAs[Double]("min")
var sketchMax = result.getAs[Double]("max")

Expand All @@ -279,7 +278,7 @@ class KllTest extends SparkSessionManager {
)

// check min and max
result = mergedSketchDf.head
result = mergedSketchDf.head()
sketchMin = result.getAs[Double]("min")
sketchMax = result.getAs[Double]("max")

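The other 2.13-facing change in KllTest.scala is dropping scala.collection.mutable.WrappedArray, which Scala 2.13 deprecates in favour of ArraySeq. Reading the result as a Seq[Double] and converting with toArray sidesteps the wrapper class entirely and works on both Scala versions, as the assertions above now do. A version-agnostic sketch of the pattern (hypothetical helper, not part of the test suite):

// Hypothetical helper, not in KllTest: read an array column from a Row without
// naming the concrete wrapper class Spark uses for the active Scala version.
import org.apache.spark.sql.Row

object RowArrays {
  def doubleArrayColumn(row: Row, name: String): Array[Double] =
    row.getAs[Seq[Double]](name).toArray
}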
12 changes: 6 additions & 6 deletions src/test/scala/org/apache/spark/sql/ThetaTest.scala
@@ -28,7 +28,7 @@ class ThetaTest extends SparkSessionManager {
val data = (for (i <- 1 to n) yield i).toDF("value")

val sketchDf = data.agg(theta_sketch_agg_build("value").as("sketch"))
val result: Row = sketchDf.select(theta_sketch_get_estimate("sketch").as("estimate")).head
val result: Row = sketchDf.select(theta_sketch_get_estimate("sketch").as("estimate")).head()

assert(result.getAs[Double]("estimate") == 100.0)
}
@@ -46,7 +46,7 @@ class ThetaTest extends SparkSessionManager {
FROM
theta_input_table
""")
assert(df.head.getAs[Double]("estimate") == 100.0)
assert(df.head().getAs[Double]("estimate") == 100.0)
}

test("Theta Sketch build via SQL with lgk") {
Expand All @@ -62,7 +62,7 @@ class ThetaTest extends SparkSessionManager {
FROM
theta_input_table
""")
assert(df.head.getAs[Double]("estimate") == 100.0)
assert(df.head().getAs[Double]("estimate") == 100.0)
}

test("Theta Union via Scala") {
Expand All @@ -72,7 +72,7 @@ class ThetaTest extends SparkSessionManager {

val groupedDf = data.groupBy("group").agg(theta_sketch_agg_build("value").as("sketch"))
val mergedDf = groupedDf.agg(theta_sketch_agg_union("sketch").as("merged"))
val result: Row = mergedDf.select(theta_sketch_get_estimate("merged").as("estimate")).head
val result: Row = mergedDf.select(theta_sketch_get_estimate("merged").as("estimate")).head()
assert(result.getAs[Double]("estimate") == numDistinct)
}

Expand Down Expand Up @@ -100,7 +100,7 @@ class ThetaTest extends SparkSessionManager {
FROM
theta_sketch_table
""")
assert(mergedDf.head.getAs[Double]("estimate") == numDistinct)
assert(mergedDf.head().getAs[Double]("estimate") == numDistinct)
}

test("Theta Union via SQL with lgk") {
Expand All @@ -126,7 +126,7 @@ class ThetaTest extends SparkSessionManager {
FROM
theta_sketch_table
""")
assert(mergedDf.head.getAs[Double]("estimate") == numDistinct)
assert(mergedDf.head().getAs[Double]("estimate") == numDistinct)
}

}
