diff --git a/agent/src/main/scala/za/co/absa/atum/agent/model/Measure.scala b/agent/src/main/scala/za/co/absa/atum/agent/model/Measure.scala
index 1a5e634e2..1fc157240 100644
--- a/agent/src/main/scala/za/co/absa/atum/agent/model/Measure.scala
+++ b/agent/src/main/scala/za/co/absa/atum/agent/model/Measure.scala
@@ -39,14 +39,6 @@ final case class UnknownMeasure(measureName: String, measuredColumns: Seq[String
 
 object AtumMeasure {
 
-  val supportedMeasureNames: Seq[String] = Seq(
-    RecordCount.measureName,
-    DistinctRecordCount.measureName,
-    SumOfValuesOfColumn.measureName,
-    AbsSumOfValuesOfColumn.measureName,
-    SumOfHashesOfColumn.measureName
-  )
-
   case class RecordCount private (measureName: String) extends AtumMeasure {
 
     private val columnExpression = count("*")
@@ -117,6 +109,42 @@ object AtumMeasure {
     def apply(measuredCol: String): AbsSumOfValuesOfColumn = AbsSumOfValuesOfColumn(measureName, measuredCol)
   }
 
+  case class SumOfTruncatedValuesOfColumn private (measureName: String, measuredCol: String) extends AtumMeasure {
+
+    private val columnAggFn: Column => Column = column => sum(when(column >= 0, floor(column)).otherwise(ceil(column)))
+
+    override def function: MeasurementFunction = (ds: DataFrame) => {
+      val dataType = ds.select(measuredCol).schema.fields(0).dataType
+      val resultValue = ds.select(columnAggFn(castForAggregation(dataType, col(measuredCol)))).collect()
+      MeasureResult(handleAggregationResult(dataType, resultValue(0)(0)), resultValueType)
+    }
+
+    override def measuredColumns: Seq[String] = Seq(measuredCol)
+    override val resultValueType: ResultValueType = ResultValueType.LongValue
+  }
+  object SumOfTruncatedValuesOfColumn {
+    private[agent] val measureName: String = "aggregatedTruncTotal"
+    def apply(measuredCol: String): SumOfTruncatedValuesOfColumn = SumOfTruncatedValuesOfColumn(measureName, measuredCol)
+  }
+
+  case class AbsSumOfTruncatedValuesOfColumn private (measureName: String, measuredCol: String) extends AtumMeasure {
+
+    private val columnAggFn: Column => Column = column => sum(abs(when(column >= 0, floor(column)).otherwise(ceil(column))))
+
+    override def function: MeasurementFunction = (ds: DataFrame) => {
+      val dataType = ds.select(measuredCol).schema.fields(0).dataType
+      val resultValue = ds.select(columnAggFn(castForAggregation(dataType, col(measuredCol)))).collect()
+      MeasureResult(handleAggregationResult(dataType, resultValue(0)(0)), resultValueType)
+    }
+
+    override def measuredColumns: Seq[String] = Seq(measuredCol)
+    override val resultValueType: ResultValueType = ResultValueType.LongValue
+  }
+  object AbsSumOfTruncatedValuesOfColumn {
+    private[agent] val measureName: String = "absAggregatedTruncTotal"
+    def apply(measuredCol: String): AbsSumOfTruncatedValuesOfColumn = AbsSumOfTruncatedValuesOfColumn(measureName, measuredCol)
+  }
+
   case class SumOfHashesOfColumn private (measureName: String, measuredCol: String) extends AtumMeasure {
     private val columnExpression: Column = sum(crc32(col(measuredCol).cast("String")))
     override def function: MeasurementFunction = (ds: DataFrame) => {
diff --git a/agent/src/main/scala/za/co/absa/atum/agent/model/MeasuresBuilder.scala b/agent/src/main/scala/za/co/absa/atum/agent/model/MeasuresBuilder.scala
index 2fefe36e2..739757c23 100644
--- a/agent/src/main/scala/za/co/absa/atum/agent/model/MeasuresBuilder.scala
+++ b/agent/src/main/scala/za/co/absa/atum/agent/model/MeasuresBuilder.scala
@@ -49,6 +49,8 @@ private [agent] object MeasuresBuilder extends Logging {
         case DistinctRecordCount.measureName => DistinctRecordCount(measuredColumns)
         case SumOfValuesOfColumn.measureName => SumOfValuesOfColumn(measuredColumns.head)
         case AbsSumOfValuesOfColumn.measureName => AbsSumOfValuesOfColumn(measuredColumns.head)
+        case SumOfTruncatedValuesOfColumn.measureName => SumOfTruncatedValuesOfColumn(measuredColumns.head)
+        case AbsSumOfTruncatedValuesOfColumn.measureName => AbsSumOfTruncatedValuesOfColumn(measuredColumns.head)
         case SumOfHashesOfColumn.measureName => SumOfHashesOfColumn(measuredColumns.head)
       }
     }.toOption
diff --git a/agent/src/test/scala/za/co/absa/atum/agent/model/AtumMeasureUnitTests.scala b/agent/src/test/scala/za/co/absa/atum/agent/model/AtumMeasureUnitTests.scala
index 5c3ff2b88..685e05de6 100644
--- a/agent/src/test/scala/za/co/absa/atum/agent/model/AtumMeasureUnitTests.scala
+++ b/agent/src/test/scala/za/co/absa/atum/agent/model/AtumMeasureUnitTests.scala
@@ -32,12 +32,12 @@ class AtumMeasureUnitTests extends AnyFlatSpec with Matchers with SparkTestBase
 
   "Measure" should "be based on the dataframe" in {
     // Measures
-    val measureIds: AtumMeasure = RecordCount()
-    val salaryAbsSum: AtumMeasure = AbsSumOfValuesOfColumn(
-      measuredCol = "salary"
-    )
-    val salarySum = SumOfValuesOfColumn(measuredCol = "salary")
-    val sumOfHashes: AtumMeasure = SumOfHashesOfColumn(measuredCol = "id")
+    val measureIds: AtumMeasure   = RecordCount()
+    val salaryAbsSum: AtumMeasure = AbsSumOfValuesOfColumn(measuredCol = "salary")
+    val salarySum                 = SumOfValuesOfColumn(measuredCol = "salary")
+    val salaryTruncSum            = SumOfTruncatedValuesOfColumn(measuredCol = "salary")
+    val salaryAbsTruncSum         = AbsSumOfTruncatedValuesOfColumn(measuredCol = "salary")
+    val sumOfHashes: AtumMeasure  = SumOfHashesOfColumn(measuredCol = "id")
 
     // AtumContext contains `Measurement`
     val atumContextInstanceWithRecordCount = AtumAgent
@@ -86,12 +86,34 @@ class AtumMeasureUnitTests extends AnyFlatSpec with Matchers with SparkTestBase
         .removeMeasure(salaryAbsSum)
     )
 
+    val dfExtraPersonWithDecimalSalary = spark
+      .createDataFrame(
+        Seq(
+          ("id", "firstName", "lastName", "email", "email2", "profession", "3000.98"),
+          ("id", "firstName", "lastName", "email", "email2", "profession", "-1000.76")
+        )
+      )
+      .toDF("id", "firstName", "lastName", "email", "email2", "profession", "salary")
+
+    val dfExtraDecimalPerson = dfExtraPersonWithDecimalSalary.union(dfPersons)
+
+    dfExtraDecimalPerson.createCheckpoint("a checkpoint name")(
+      atumContextWithSalaryAbsMeasure
+        .removeMeasure(measureIds)
+        .removeMeasure(salaryAbsSum)
+    )
+
+
     val dfPersonCntResult = measureIds.function(dfPersons)
     val dfFullCntResult = measureIds.function(dfFull)
     val dfFullSalaryAbsSumResult = salaryAbsSum.function(dfFull)
     val dfFullHashResult = sumOfHashes.function(dfFull)
     val dfExtraPersonSalarySumResult = salarySum.function(dfExtraPerson)
     val dfFullSalarySumResult = salarySum.function(dfFull)
+    val dfExtraPersonSalarySumTruncResult = salaryTruncSum.function(dfExtraDecimalPerson)
+    val dfFullSalarySumTruncResult = salaryTruncSum.function(dfFull)
+    val dfExtraPersonSalaryAbsSumTruncResult = salaryAbsTruncSum.function(dfExtraDecimalPerson)
+    val dfFullSalaryAbsSumTruncResult = salaryAbsTruncSum.function(dfFull)
 
     // Assertions
     assert(dfPersonCntResult.resultValue == "1000")
@@ -106,6 +128,14 @@ class AtumMeasureUnitTests extends AnyFlatSpec with Matchers with SparkTestBase
     assert(dfExtraPersonSalarySumResult.resultValueType == ResultValueType.BigDecimalValue)
     assert(dfFullSalarySumResult.resultValue == "2987144")
     assert(dfFullSalarySumResult.resultValueType == ResultValueType.BigDecimalValue)
+    assert(dfExtraPersonSalarySumTruncResult.resultValue == "2989144")
+    assert(dfExtraPersonSalarySumTruncResult.resultValueType == ResultValueType.LongValue)
+    assert(dfFullSalarySumTruncResult.resultValue == "2987144")
+    assert(dfFullSalarySumTruncResult.resultValueType == ResultValueType.LongValue)
+    assert(dfExtraPersonSalaryAbsSumTruncResult.resultValue == "2991144")
+    assert(dfExtraPersonSalaryAbsSumTruncResult.resultValueType == ResultValueType.LongValue)
+    assert(dfFullSalaryAbsSumTruncResult.resultValue == "2987144")
+    assert(dfFullSalaryAbsSumTruncResult.resultValueType == ResultValueType.LongValue)
   }
 
   "AbsSumOfValuesOfColumn" should "return expected value" in {
@@ -187,4 +217,33 @@ class AtumMeasureUnitTests extends AnyFlatSpec with Matchers with SparkTestBase
     assert(result.resultValueType == ResultValueType.BigDecimalValue)
   }
 
+  "SumOfTruncatedValuesOfColumn" should "return expected value" in {
+    val sumOfTruncatedValues = SumOfTruncatedValuesOfColumn("colA")
+
+    val data = List(Row("1.98", "b1"), Row("-1.76", "b2"), Row("1.54", "b2"), Row("1.32", "b2"))
+    val rdd = spark.sparkContext.parallelize(data)
+
+    val schema = StructType(Array(StructField("colA", StringType), StructField("colB", StringType)))
+    val df = spark.createDataFrame(rdd, schema)
+
+    val result = sumOfTruncatedValues.function(df)
+
+    assert(result.resultValue == "2")
+    assert(result.resultValueType == ResultValueType.LongValue)
+  }
+
+  "AbsSumOfTruncatedValuesOfColumn" should "return expected value" in {
+    val absSumOfTruncatedValues = AbsSumOfTruncatedValuesOfColumn("colA")
+
+    val data = List(Row("1.98", "b1"), Row("-1.76", "b2"), Row("1.54", "b2"), Row("-1.32", "b2"))
+    val rdd = spark.sparkContext.parallelize(data)
+
+    val schema = StructType(Array(StructField("colA", StringType), StructField("colB", StringType)))
+    val df = spark.createDataFrame(rdd, schema)
+
+    val result = absSumOfTruncatedValues.function(df)
+
+    assert(result.resultValue == "4")
+    assert(result.resultValueType == ResultValueType.LongValue)
+  }
 }
diff --git a/agent/src/test/scala/za/co/absa/atum/agent/model/MeasureUnitTests.scala b/agent/src/test/scala/za/co/absa/atum/agent/model/MeasureUnitTests.scala
index fea11c9f9..f00f5c299 100644
--- a/agent/src/test/scala/za/co/absa/atum/agent/model/MeasureUnitTests.scala
+++ b/agent/src/test/scala/za/co/absa/atum/agent/model/MeasureUnitTests.scala
@@ -19,7 +19,7 @@ package za.co.absa.atum.agent.model
 import org.scalatest.flatspec.AnyFlatSpec
 import org.scalatest.matchers.should.Matchers
 import za.co.absa.atum.agent.AtumAgent
-import za.co.absa.atum.agent.model.AtumMeasure.{AbsSumOfValuesOfColumn, RecordCount, SumOfHashesOfColumn, SumOfValuesOfColumn}
+import za.co.absa.atum.agent.model.AtumMeasure.{AbsSumOfTruncatedValuesOfColumn, AbsSumOfValuesOfColumn, RecordCount, SumOfHashesOfColumn, SumOfTruncatedValuesOfColumn, SumOfValuesOfColumn}
 import za.co.absa.spark.commons.test.SparkTestBase
 import za.co.absa.atum.agent.AtumContext._
 import za.co.absa.atum.model.ResultValueType
@@ -30,11 +30,13 @@ class MeasureUnitTests extends AnyFlatSpec with Matchers with SparkTestBase { se
 
   "Measure" should "be based on the dataframe" in {
     // Measures
-    val measureIds: AtumMeasure = RecordCount()
-    val salaryAbsSum: AtumMeasure = AbsSumOfValuesOfColumn("salary")
+    val measureIds: AtumMeasure   = RecordCount()
+    val salaryAbsSum: AtumMeasure = AbsSumOfValuesOfColumn("salary")
+    val sumOfHashes: AtumMeasure  = SumOfHashesOfColumn("id")
 
-    val salarySum = SumOfValuesOfColumn("salary")
-    val sumOfHashes: AtumMeasure = SumOfHashesOfColumn("id")
+    val salarySum         = SumOfValuesOfColumn("salary")
+    val salaryAbsTruncSum = AbsSumOfTruncatedValuesOfColumn("salary")
+    val salaryTruncSum    = SumOfTruncatedValuesOfColumn("salary")
 
     // AtumContext contains `Measurement`
     val atumContextInstanceWithRecordCount = AtumAgent
@@ -83,12 +85,33 @@ class MeasureUnitTests extends AnyFlatSpec with Matchers with SparkTestBase { se
         .removeMeasure(salaryAbsSum)
     )
 
-    val dfPersonCntResult = measureIds.function(dfPersons)
-    val dfFullCntResult = measureIds.function(dfFull)
-    val dfFullSalaryAbsSumResult = salaryAbsSum.function(dfFull)
-    val dfFullHashResult = sumOfHashes.function(dfFull)
-    val dfExtraPersonSalarySumResult = salarySum.function(dfExtraPerson)
-    val dfFullSalarySumResult = salarySum.function(dfFull)
+    val dfExtraPersonWithDecimalSalary = spark
+      .createDataFrame(
+        Seq(
+          ("id", "firstName", "lastName", "email", "email2", "profession", "3000.98"),
+          ("id", "firstName", "lastName", "email", "email2", "profession", "-1000.76")
+        )
+      )
+      .toDF("id", "firstName", "lastName", "email", "email2", "profession", "salary")
+
+    val dfExtraDecimalPerson = dfExtraPersonWithDecimalSalary.union(dfPersons)
+
+    dfExtraDecimalPerson.createCheckpoint("a checkpoint name")(
+      atumContextWithSalaryAbsMeasure
+        .removeMeasure(measureIds)
+        .removeMeasure(salaryAbsSum)
+    )
+
+    val dfPersonCntResult                    = measureIds.function(dfPersons)
+    val dfFullCntResult                      = measureIds.function(dfFull)
+    val dfFullSalaryAbsSumResult             = salaryAbsSum.function(dfFull)
+    val dfFullHashResult                     = sumOfHashes.function(dfFull)
+    val dfExtraPersonSalarySumResult         = salarySum.function(dfExtraPerson)
+    val dfFullSalarySumResult                = salarySum.function(dfFull)
+    val dfExtraPersonSalarySumTruncResult    = salaryTruncSum.function(dfExtraDecimalPerson)
+    val dfFullSalarySumTruncResult           = salaryTruncSum.function(dfFull)
+    val dfExtraPersonSalaryAbsSumTruncResult = salaryAbsTruncSum.function(dfExtraDecimalPerson)
+    val dfFullSalaryAbsSumTruncResult        = salaryAbsTruncSum.function(dfFull)
 
     // Assertions
     assert(dfPersonCntResult.resultValue == "1000")
@@ -103,6 +126,14 @@ class MeasureUnitTests extends AnyFlatSpec with Matchers with SparkTestBase { se
     assert(dfExtraPersonSalarySumResult.resultValueType == ResultValueType.BigDecimalValue)
     assert(dfFullSalarySumResult.resultValue == "2987144")
     assert(dfFullSalarySumResult.resultValueType == ResultValueType.BigDecimalValue)
+    assert(dfExtraPersonSalarySumTruncResult.resultValue == "2989144")
+    assert(dfExtraPersonSalarySumTruncResult.resultValueType == ResultValueType.LongValue)
+    assert(dfFullSalarySumTruncResult.resultValue == "2987144")
+    assert(dfFullSalarySumTruncResult.resultValueType == ResultValueType.LongValue)
+    assert(dfExtraPersonSalaryAbsSumTruncResult.resultValue == "2991144")
+    assert(dfExtraPersonSalaryAbsSumTruncResult.resultValueType == ResultValueType.LongValue)
+    assert(dfFullSalaryAbsSumTruncResult.resultValue == "2987144")
+    assert(dfFullSalaryAbsSumTruncResult.resultValueType == ResultValueType.LongValue)
   }
 
 }
diff --git a/agent/src/test/scala/za/co/absa/atum/agent/model/MeasuresBuilderUnitTests.scala b/agent/src/test/scala/za/co/absa/atum/agent/model/MeasuresBuilderUnitTests.scala
index da74c9ba3..bb03cff15 100644
--- a/agent/src/test/scala/za/co/absa/atum/agent/model/MeasuresBuilderUnitTests.scala
+++ b/agent/src/test/scala/za/co/absa/atum/agent/model/MeasuresBuilderUnitTests.scala
@@ -28,6 +28,8 @@ class MeasuresBuilderUnitTests extends AnyFlatSpecLike {
       MeasureDTO("distinctCount", Seq("distinctCountCol")),
MeasureDTO("aggregatedTotal", Seq("aggregatedTotalCol")), MeasureDTO("absAggregatedTotal", Seq("absAggregatedTotalCol")), + MeasureDTO("aggregatedTruncTotal", Seq("aggregatedTruncTotalCol")), + MeasureDTO("absAggregatedTruncTotal", Seq("absAggregatedTruncTotalCol")), MeasureDTO("hashCrc32", Seq("hashCrc32Col")) ) @@ -36,6 +38,8 @@ class MeasuresBuilderUnitTests extends AnyFlatSpecLike { DistinctRecordCount(Seq("distinctCountCol")), SumOfValuesOfColumn("aggregatedTotalCol"), AbsSumOfValuesOfColumn("absAggregatedTotalCol"), + SumOfTruncatedValuesOfColumn("aggregatedTruncTotalCol"), + AbsSumOfTruncatedValuesOfColumn("absAggregatedTruncTotalCol"), SumOfHashesOfColumn("hashCrc32Col") )