#314: Add support for sum of truncated values #324

Merged · 4 commits · Mar 12, 2025
44 changes: 36 additions & 8 deletions agent/src/main/scala/za/co/absa/atum/agent/model/Measure.scala
@@ -39,14 +39,6 @@ final case class UnknownMeasure(measureName: String, measuredColumns: Seq[String

object AtumMeasure {

val supportedMeasureNames: Seq[String] = Seq(
RecordCount.measureName,
DistinctRecordCount.measureName,
SumOfValuesOfColumn.measureName,
AbsSumOfValuesOfColumn.measureName,
SumOfHashesOfColumn.measureName
)

case class RecordCount private (measureName: String) extends AtumMeasure {
private val columnExpression = count("*")

@@ -117,6 +109,42 @@ object AtumMeasure {
def apply(measuredCol: String): AbsSumOfValuesOfColumn = AbsSumOfValuesOfColumn(measureName, measuredCol)
}

case class SumOfTruncatedValuesOfColumn private (measureName: String, measuredCol: String) extends AtumMeasure {
Collaborator:

Why have you decided to use a double-casting approach instead of the standard functions from the org.apache.spark.sql package to round the values before summing them?

  /**
   * Returns the value of the column `e` rounded to 0 decimal places with HALF_UP round mode.
   *
   * @group math_funcs
   * @since 1.5.0
   */
  def round(e: Column): Column = round(e, 0)

  /**
   * Round the value of `e` to `scale` decimal places with HALF_UP round mode
   * if `scale` is greater than or equal to 0 or at integral part when `scale` is less than 0.
   *
   * @group math_funcs
   * @since 1.5.0
   */
  def round(e: Column, scale: Int): Column = withExpr { Round(e.expr, Literal(scale)) }

  /**
   * Returns the value of the column `e` rounded to 0 decimal places with HALF_EVEN round mode.
   *
   * @group math_funcs
   * @since 2.0.0
   */
  def bround(e: Column): Column = bround(e, 0)

  /**
   * Round the value of `e` to `scale` decimal places with HALF_EVEN round mode
   * if `scale` is greater than or equal to 0 or at integral part when `scale` is less than 0.
   *
   * @group math_funcs
   * @since 2.0.0
   */
  def bround(e: Column, scale: Int): Column = withExpr { BRound(e.expr, Literal(scale)) }
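For reference, a minimal spark-shell sketch (not from the PR thread) of how the two modes differ on ties:

scala> import org.apache.spark.sql.functions.{round, bround, col}
scala> Seq(0.5, 1.5, 2.5).toDF("v").select(round(col("v")).as("half_up"), bround(col("v")).as("half_even")).show()
// half_up   (round):  1, 2, 3  -- ties are rounded away from zero
// half_even (bround): 0, 2, 2  -- ties are rounded to the even neighbour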

Contributor Author:

Thank you. It was mentioned in the issue that this method was used in ATUM, but let me change it accordingly, since I think the method you mentioned is more correct.

Contributor Author:

Unfortunately, the round function does not do what we need for negative numbers: it rounds to the nearest integer rather than truncating. We need a truncation function that simply removes the decimals. But I have found another way that gives the correct behaviour without resorting to a double cast.
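To illustrate the point (a minimal spark-shell sketch, not from the thread): HALF_UP rounding carries -1.7 to -2, whereas truncation should simply drop the decimals and yield -1.

scala> import org.apache.spark.sql.functions.{round, floor, ceil, when, col}
scala> Seq(-1.7, 1.7).toDF("v").select(
     |   round(col("v")).as("rounded"),                                                  // -2 and 2: rounding, not truncation
     |   when(col("v") >= 0, floor(col("v"))).otherwise(ceil(col("v"))).as("truncated")  // -1 and 1
     | ).show()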

Collaborator:

Yup, a combination of floor and ceil should work fine:

scala> df.show
+-----+
|value|
+-----+
| -1.1|
|  1.1|
| -1.7|
|  1.7|
| -0.8|
|  0.8|
+-----+


scala> df.select(when(col("value") >= 0, floor(col("value"))).otherwise(ceil(col("value")))).show
+-------------------------------------------------------------+
|CASE WHEN (value >= 0) THEN FLOOR(value) ELSE CEIL(value) END|
+-------------------------------------------------------------+
|                                                           -1|
|                                                            1|
|                                                           -1|
|                                                            1|
|                                                            0|
|                                                            0|
+-------------------------------------------------------------+


private val columnAggFn: Column => Column = column => sum(when(column >= 0, floor(column)).otherwise(ceil(column)))

override def function: MeasurementFunction = (ds: DataFrame) => {
val dataType = ds.select(measuredCol).schema.fields(0).dataType
val resultValue = ds.select(columnAggFn(castForAggregation(dataType, col(measuredCol)))).collect()
MeasureResult(handleAggregationResult(dataType, resultValue(0)(0)), resultValueType)
}

override def measuredColumns: Seq[String] = Seq(measuredCol)
override val resultValueType: ResultValueType = ResultValueType.LongValue
}
object SumOfTruncatedValuesOfColumn {
private[agent] val measureName: String = "aggregatedTruncTotal"
def apply(measuredCol: String): SumOfTruncatedValuesOfColumn = SumOfTruncatedValuesOfColumn(measureName, measuredCol)
}

case class AbsSumOfTruncatedValuesOfColumn private (measureName: String, measuredCol: String) extends AtumMeasure {

private val columnAggFn: Column => Column = column => sum(abs(when(column >= 0, floor(column)).otherwise(ceil(column))))

override def function: MeasurementFunction = (ds: DataFrame) => {
val dataType = ds.select(measuredCol).schema.fields(0).dataType
val resultValue = ds.select(columnAggFn(castForAggregation(dataType, col(measuredCol)))).collect()
MeasureResult(handleAggregationResult(dataType, resultValue(0)(0)), resultValueType)
}

override def measuredColumns: Seq[String] = Seq(measuredCol)
override val resultValueType: ResultValueType = ResultValueType.LongValue
}
object AbsSumOfTruncatedValuesOfColumn {
private[agent] val measureName: String = "absAggregatedTruncTotal"
def apply(measuredCol: String): AbsSumOfTruncatedValuesOfColumn = AbsSumOfTruncatedValuesOfColumn(measureName, measuredCol)
}
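For a sense of how the two new measures are used, a minimal sketch (the DataFrame df and its salary column are assumed; everything else is declared above):

import za.co.absa.atum.agent.model.AtumMeasure.{SumOfTruncatedValuesOfColumn, AbsSumOfTruncatedValuesOfColumn}

val truncSum    = SumOfTruncatedValuesOfColumn("salary")     // measureName "aggregatedTruncTotal"
val absTruncSum = AbsSumOfTruncatedValuesOfColumn("salary")  // measureName "absAggregatedTruncTotal"

val result = truncSum.function(df)  // MeasureResult carrying ResultValueType.LongValue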

case class SumOfHashesOfColumn private (measureName: String, measuredCol: String) extends AtumMeasure {
private val columnExpression: Column = sum(crc32(col(measuredCol).cast("String")))
override def function: MeasurementFunction = (ds: DataFrame) => {
@@ -49,6 +49,8 @@ private [agent] object MeasuresBuilder extends Logging {
case DistinctRecordCount.measureName => DistinctRecordCount(measuredColumns)
case SumOfValuesOfColumn.measureName => SumOfValuesOfColumn(measuredColumns.head)
case AbsSumOfValuesOfColumn.measureName => AbsSumOfValuesOfColumn(measuredColumns.head)
case SumOfTruncatedValuesOfColumn.measureName => SumOfTruncatedValuesOfColumn(measuredColumns.head)
case AbsSumOfTruncatedValuesOfColumn.measureName => AbsSumOfTruncatedValuesOfColumn(measuredColumns.head)
case SumOfHashesOfColumn.measureName => SumOfHashesOfColumn(measuredColumns.head)
}
}.toOption
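With the two new cases in place, DTOs carrying the new measure names resolve to the truncating measures (a sketch; the column name is illustrative, and the mapping itself is exercised in the builder tests below):

MeasureDTO("aggregatedTruncTotal", Seq("salary"))     // -> SumOfTruncatedValuesOfColumn("salary")
MeasureDTO("absAggregatedTruncTotal", Seq("salary"))  // -> AbsSumOfTruncatedValuesOfColumn("salary")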
@@ -32,12 +32,12 @@ class AtumMeasureUnitTests extends AnyFlatSpec with Matchers with SparkTestBase
"Measure" should "be based on the dataframe" in {

// Measures
val measureIds: AtumMeasure = RecordCount()
val salaryAbsSum: AtumMeasure = AbsSumOfValuesOfColumn(
measuredCol = "salary"
)
val salarySum = SumOfValuesOfColumn(measuredCol = "salary")
val sumOfHashes: AtumMeasure = SumOfHashesOfColumn(measuredCol = "id")
val measureIds: AtumMeasure = RecordCount()
val salaryAbsSum: AtumMeasure = AbsSumOfValuesOfColumn(measuredCol = "salary")
val salarySum = SumOfValuesOfColumn(measuredCol = "salary")
val salaryTruncSum = SumOfTruncatedValuesOfColumn(measuredCol = "salary")
val salaryAbsTruncSum = AbsSumOfTruncatedValuesOfColumn(measuredCol = "salary")
val sumOfHashes: AtumMeasure = SumOfHashesOfColumn(measuredCol = "id")

// AtumContext contains `Measurement`
val atumContextInstanceWithRecordCount = AtumAgent
@@ -86,12 +86,34 @@ class AtumMeasureUnitTests extends AnyFlatSpec with Matchers with SparkTestBase
.removeMeasure(salaryAbsSum)
)

val dfExtraPersonWithDecimalSalary = spark
.createDataFrame(
Seq(
("id", "firstName", "lastName", "email", "email2", "profession", "3000.98"),
("id", "firstName", "lastName", "email", "email2", "profession", "-1000.76")
)
)
.toDF("id", "firstName", "lastName", "email", "email2", "profession", "salary")

val dfExtraDecimalPerson = dfExtraPersonWithDecimalSalary.union(dfPersons)

dfExtraDecimalPerson.createCheckpoint("a checkpoint name")(
atumContextWithSalaryAbsMeasure
.removeMeasure(measureIds)
.removeMeasure(salaryAbsSum)
)


val dfPersonCntResult = measureIds.function(dfPersons)
val dfFullCntResult = measureIds.function(dfFull)
val dfFullSalaryAbsSumResult = salaryAbsSum.function(dfFull)
val dfFullHashResult = sumOfHashes.function(dfFull)
val dfExtraPersonSalarySumResult = salarySum.function(dfExtraPerson)
val dfFullSalarySumResult = salarySum.function(dfFull)
val dfExtraPersonSalarySumTruncResult = salaryTruncSum.function(dfExtraDecimalPerson)
val dfFullSalarySumTruncResult = salaryTruncSum.function(dfFull)
val dfExtraPersonSalaryAbsSumTruncResult = salaryAbsTruncSum.function(dfExtraDecimalPerson)
val dfFullSalaryAbsSumTruncResult = salaryAbsTruncSum.function(dfFull)

// Assertions
assert(dfPersonCntResult.resultValue == "1000")
@@ -106,6 +128,14 @@
assert(dfExtraPersonSalarySumResult.resultValueType == ResultValueType.BigDecimalValue)
assert(dfFullSalarySumResult.resultValue == "2987144")
assert(dfFullSalarySumResult.resultValueType == ResultValueType.BigDecimalValue)
assert(dfExtraPersonSalarySumTruncResult.resultValue == "2989144")
assert(dfExtraPersonSalarySumTruncResult.resultValueType == ResultValueType.LongValue)
assert(dfFullSalarySumTruncResult.resultValue == "2987144")
assert(dfFullSalarySumTruncResult.resultValueType == ResultValueType.LongValue)
assert(dfExtraPersonSalaryAbsSumTruncResult.resultValue == "2991144")
assert(dfExtraPersonSalaryAbsSumTruncResult.resultValueType == ResultValueType.LongValue)
assert(dfFullSalaryAbsSumTruncResult.resultValue == "2987144")
assert(dfFullSalaryAbsSumTruncResult.resultValueType == ResultValueType.LongValue)
}

"AbsSumOfValuesOfColumn" should "return expected value" in {
@@ -187,4 +217,33 @@
assert(result.resultValueType == ResultValueType.BigDecimalValue)
}

"SumTruncOfValuesOfColumn" should "return expected value" in {
val distinctCount = SumOfTruncatedValuesOfColumn("colA")

val data = List(Row("1.98", "b1"), Row("-1.76", "b2"), Row("1.54", "b2"), Row("1.32", "b2"))
val rdd = spark.sparkContext.parallelize(data)

val schema = StructType(Array(StructField("colA", StringType), StructField("colB", StringType)))
val df = spark.createDataFrame(rdd, schema)

val result = truncSum.function(df)

assert(result.resultValue == "2")
assert(result.resultValueType == ResultValueType.LongValue)
}

"AbsSumTruncOfValuesOfColumn" should "return expected value" in {
val distinctCount = AbsSumOfTruncatedValuesOfColumn("colA")

val data = List(Row("1.98", "b1"), Row("-1.76", "b2"), Row("1.54", "b2"), Row("-1.32", "b2"))
val rdd = spark.sparkContext.parallelize(data)

val schema = StructType(Array(StructField("colA", StringType), StructField("colB", StringType)))
val df = spark.createDataFrame(rdd, schema)

val result = absTruncSum.function(df)

assert(result.resultValue == "4")
assert(result.resultValueType == ResultValueType.LongValue)
}
}
@@ -19,7 +19,7 @@ package za.co.absa.atum.agent.model
import org.scalatest.flatspec.AnyFlatSpec
import org.scalatest.matchers.should.Matchers
import za.co.absa.atum.agent.AtumAgent
import za.co.absa.atum.agent.model.AtumMeasure.{AbsSumOfValuesOfColumn, RecordCount, SumOfHashesOfColumn, SumOfValuesOfColumn}
import za.co.absa.atum.agent.model.AtumMeasure.{AbsSumOfValuesOfColumn, RecordCount, SumOfHashesOfColumn, SumOfValuesOfColumn, SumOfTruncatedValuesOfColumn, AbsSumOfTruncatedValuesOfColumn}
import za.co.absa.spark.commons.test.SparkTestBase
import za.co.absa.atum.agent.AtumContext._
import za.co.absa.atum.model.ResultValueType
@@ -30,11 +30,13 @@ class MeasureUnitTests extends AnyFlatSpec with Matchers with SparkTestBase { se
"Measure" should "be based on the dataframe" in {

// Measures
val measureIds: AtumMeasure = RecordCount()
val salaryAbsSum: AtumMeasure = AbsSumOfValuesOfColumn("salary")
val measureIds: AtumMeasure = RecordCount()
val salaryAbsSum: AtumMeasure = AbsSumOfValuesOfColumn("salary")
val sumOfHashes: AtumMeasure = SumOfHashesOfColumn("id")

val salarySum = SumOfValuesOfColumn("salary")
val sumOfHashes: AtumMeasure = SumOfHashesOfColumn("id")
val salarySum = SumOfValuesOfColumn("salary")
val salaryAbsTruncSum = AbsSumOfTruncatedValuesOfColumn("salary")
val salaryTruncSum = SumOfTruncatedValuesOfColumn("salary")

// AtumContext contains `Measurement`
val atumContextInstanceWithRecordCount = AtumAgent
@@ -83,12 +85,33 @@
.removeMeasure(salaryAbsSum)
)

val dfPersonCntResult = measureIds.function(dfPersons)
val dfFullCntResult = measureIds.function(dfFull)
val dfFullSalaryAbsSumResult = salaryAbsSum.function(dfFull)
val dfFullHashResult = sumOfHashes.function(dfFull)
val dfExtraPersonSalarySumResult = salarySum.function(dfExtraPerson)
val dfFullSalarySumResult = salarySum.function(dfFull)
val dfExtraPersonWithDecimalSalary = spark
.createDataFrame(
Seq(
("id", "firstName", "lastName", "email", "email2", "profession", "3000.98"),
("id", "firstName", "lastName", "email", "email2", "profession", "-1000.76")
)
)
.toDF("id", "firstName", "lastName", "email", "email2", "profession", "salary")

val dfExtraDecimalPerson = dfExtraPersonWithDecimalSalary.union(dfPersons)

dfExtraDecimalPerson.createCheckpoint("a checkpoint name")(
atumContextWithSalaryAbsMeasure
.removeMeasure(measureIds)
.removeMeasure(salaryAbsSum)
)

val dfPersonCntResult = measureIds.function(dfPersons)
val dfFullCntResult = measureIds.function(dfFull)
val dfFullSalaryAbsSumResult = salaryAbsSum.function(dfFull)
val dfFullHashResult = sumOfHashes.function(dfFull)
val dfExtraPersonSalarySumResult = salarySum.function(dfExtraPerson)
val dfFullSalarySumResult = salarySum.function(dfFull)
val dfExtraPersonSalarySumTruncResult = salaryTruncSum.function(dfExtraDecimalPerson)
val dfFullSalarySumTruncResult = salaryTruncSum.function(dfFull)
val dfExtraPersonSalaryAbsSumTruncResult = salaryAbsTruncSum.function(dfExtraDecimalPerson)
val dfFullSalaryAbsSumTruncResult = salaryAbsTruncSum.function(dfFull)

// Assertions
assert(dfPersonCntResult.resultValue == "1000")
@@ -103,6 +126,14 @@
assert(dfExtraPersonSalarySumResult.resultValueType == ResultValueType.BigDecimalValue)
assert(dfFullSalarySumResult.resultValue == "2987144")
assert(dfFullSalarySumResult.resultValueType == ResultValueType.BigDecimalValue)
assert(dfExtraPersonSalarySumTruncResult.resultValue == "2989144")
assert(dfExtraPersonSalarySumTruncResult.resultValueType == ResultValueType.LongValue)
assert(dfFullSalarySumTruncResult.resultValue == "2987144")
assert(dfFullSalarySumTruncResult.resultValueType == ResultValueType.LongValue)
assert(dfExtraPersonSalaryAbsSumTruncResult.resultValue == "2991144")
assert(dfExtraPersonSalaryAbsSumTruncResult.resultValueType == ResultValueType.LongValue)
assert(dfFullSalaryAbsSumTruncResult.resultValue == "2987144")
assert(dfFullSalaryAbsSumTruncResult.resultValueType == ResultValueType.LongValue)
}

}
@@ -28,6 +28,8 @@ class MeasuresBuilderUnitTests extends AnyFlatSpecLike {
MeasureDTO("distinctCount", Seq("distinctCountCol")),
MeasureDTO("aggregatedTotal", Seq("aggregatedTotalCol")),
MeasureDTO("absAggregatedTotal", Seq("absAggregatedTotalCol")),
MeasureDTO("aggregatedTruncTotal", Seq("aggregatedTruncTotalCol")),
MeasureDTO("absAggregatedTruncTotal", Seq("absAggregatedTruncTotalCol")),
MeasureDTO("hashCrc32", Seq("hashCrc32Col"))
)

@@ -36,6 +38,8 @@
DistinctRecordCount(Seq("distinctCountCol")),
SumOfValuesOfColumn("aggregatedTotalCol"),
AbsSumOfValuesOfColumn("absAggregatedTotalCol"),
SumOfTruncatedValuesOfColumn("aggregatedTruncTotalCol"),
AbsSumOfTruncatedValuesOfColumn("absAggregatedTruncTotalCol"),
SumOfHashesOfColumn("hashCrc32Col")
)
