From 007f45c7ce0541835d4738a86a55553a53bd5323 Mon Sep 17 00:00:00 2001
From: ABLL526
Date: Tue, 4 Mar 2025 18:17:05 +0200
Subject: [PATCH 1/3] Changes Made:

- Added the aggregatedTruncTotal Measure and the absAggregatedTruncTotal Measure.
- Added the tests for these Measures.
---
 .../za/co/absa/atum/agent/model/Measure.scala | 38 ++++++++++++
 .../atum/agent/model/MeasuresBuilder.scala    |  2 +
 .../agent/model/AtumMeasureUnitTests.scala    | 61 +++++++++++++++++++
 .../atum/agent/model/MeasureUnitTests.scala   | 53 ++++++++++++----
 .../model/MeasuresBuilderUnitTests.scala      |  4 ++
 5 files changed, 147 insertions(+), 11 deletions(-)

diff --git a/agent/src/main/scala/za/co/absa/atum/agent/model/Measure.scala b/agent/src/main/scala/za/co/absa/atum/agent/model/Measure.scala
index 1a5e634e2..90e330537 100644
--- a/agent/src/main/scala/za/co/absa/atum/agent/model/Measure.scala
+++ b/agent/src/main/scala/za/co/absa/atum/agent/model/Measure.scala
@@ -44,6 +44,8 @@ object AtumMeasure {
     DistinctRecordCount.measureName,
     SumOfValuesOfColumn.measureName,
     AbsSumOfValuesOfColumn.measureName,
+    SumOfTruncatedValuesOfColumn.measureName,
+    AbsSumOfTruncatedValuesOfColumn.measureName,
     SumOfHashesOfColumn.measureName
   )
 
@@ -117,6 +119,42 @@ object AtumMeasure {
     def apply(measuredCol: String): AbsSumOfValuesOfColumn = AbsSumOfValuesOfColumn(measureName, measuredCol)
   }
 
+  case class SumOfTruncatedValuesOfColumn private (measureName: String, measuredCol: String) extends AtumMeasure {
+    //Cast to LongType to remove decimal points then cast back to decimal to ensure compatibility
+    private val columnAggFn: Column => Column = column => sum(column.cast(LongType).cast(DecimalType(38, 0)))
+
+    override def function: MeasurementFunction = (ds: DataFrame) => {
+      val dataType = ds.select(measuredCol).schema.fields(0).dataType
+      val resultValue = ds.select(columnAggFn(castForAggregation(dataType, col(measuredCol)))).collect()
+      MeasureResult(handleAggregationResult(dataType, resultValue(0)(0)), resultValueType)
+    }
+
+    override def measuredColumns: Seq[String] = Seq(measuredCol)
+    override val resultValueType: ResultValueType = ResultValueType.BigDecimalValue
+  }
+  object SumOfTruncatedValuesOfColumn {
+    private[agent] val measureName: String = "aggregatedTruncTotal"
+    def apply(measuredCol: String): SumOfTruncatedValuesOfColumn = SumOfTruncatedValuesOfColumn(measureName, measuredCol)
+  }
+
+  case class AbsSumOfTruncatedValuesOfColumn private (measureName: String, measuredCol: String) extends AtumMeasure {
+    //Cast to LongType to remove decimal points then cast back to decimal to ensure compatibility
+    private val columnAggFn: Column => Column = column => sum(abs(column.cast(LongType).cast(DecimalType(38, 0))))
+
+    override def function: MeasurementFunction = (ds: DataFrame) => {
+      val dataType = ds.select(measuredCol).schema.fields(0).dataType
+      val resultValue = ds.select(columnAggFn(castForAggregation(dataType, col(measuredCol)))).collect()
+      MeasureResult(handleAggregationResult(dataType, resultValue(0)(0)), resultValueType)
+    }
+
+    override def measuredColumns: Seq[String] = Seq(measuredCol)
+    override val resultValueType: ResultValueType = ResultValueType.BigDecimalValue
+  }
+  object AbsSumOfTruncatedValuesOfColumn {
+    private[agent] val measureName: String = "absAggregatedTruncTotal"
+    def apply(measuredCol: String): AbsSumOfTruncatedValuesOfColumn = AbsSumOfTruncatedValuesOfColumn(measureName, measuredCol)
+  }
+
   case class SumOfHashesOfColumn private (measureName: String, measuredCol: String) extends AtumMeasure {
     private val columnExpression: Column = sum(crc32(col(measuredCol).cast("String")))
     override def function: MeasurementFunction = (ds: DataFrame) => {
diff --git a/agent/src/main/scala/za/co/absa/atum/agent/model/MeasuresBuilder.scala b/agent/src/main/scala/za/co/absa/atum/agent/model/MeasuresBuilder.scala
index 2fefe36e2..739757c23 100644
--- a/agent/src/main/scala/za/co/absa/atum/agent/model/MeasuresBuilder.scala
+++ b/agent/src/main/scala/za/co/absa/atum/agent/model/MeasuresBuilder.scala
@@ -49,6 +49,8 @@ private [agent] object MeasuresBuilder extends Logging {
         case DistinctRecordCount.measureName => DistinctRecordCount(measuredColumns)
         case SumOfValuesOfColumn.measureName => SumOfValuesOfColumn(measuredColumns.head)
         case AbsSumOfValuesOfColumn.measureName => AbsSumOfValuesOfColumn(measuredColumns.head)
+        case SumOfTruncatedValuesOfColumn.measureName => SumOfTruncatedValuesOfColumn(measuredColumns.head)
+        case AbsSumOfTruncatedValuesOfColumn.measureName => AbsSumOfTruncatedValuesOfColumn(measuredColumns.head)
         case SumOfHashesOfColumn.measureName => SumOfHashesOfColumn(measuredColumns.head)
       }
     }.toOption
diff --git a/agent/src/test/scala/za/co/absa/atum/agent/model/AtumMeasureUnitTests.scala b/agent/src/test/scala/za/co/absa/atum/agent/model/AtumMeasureUnitTests.scala
index 5c3ff2b88..fea3ae1ed 100644
--- a/agent/src/test/scala/za/co/absa/atum/agent/model/AtumMeasureUnitTests.scala
+++ b/agent/src/test/scala/za/co/absa/atum/agent/model/AtumMeasureUnitTests.scala
@@ -37,6 +37,8 @@ class AtumMeasureUnitTests extends AnyFlatSpec with Matchers with SparkTestBase
       measuredCol = "salary"
     )
     val salarySum = SumOfValuesOfColumn(measuredCol = "salary")
+    val salaryTruncSum = SumOfTruncatedValuesOfColumn(measuredCol = "salary")
+    val salaryAbsTruncSum = AbsSumOfTruncatedValuesOfColumn(measuredCol = "salary")
     val sumOfHashes: AtumMeasure = SumOfHashesOfColumn(measuredCol = "id")
 
     // AtumContext contains `Measurement`
@@ -86,12 +88,34 @@ class AtumMeasureUnitTests extends AnyFlatSpec with Matchers with SparkTestBase
         .removeMeasure(salaryAbsSum)
     )
 
+    val dfExtraPersonWithDecimalSalary = spark
+      .createDataFrame(
+        Seq(
+          ("id", "firstName", "lastName", "email", "email2", "profession", "3000.98"),
+          ("id", "firstName", "lastName", "email", "email2", "profession", "-1000.76")
+        )
+      )
+      .toDF("id", "firstName", "lastName", "email", "email2", "profession", "salary")
+
+    val dfExtraDecimalPerson = dfExtraPersonWithDecimalSalary.union(dfPersons)
+
+    dfExtraDecimalPerson.createCheckpoint("a checkpoint name")(
+      atumContextWithSalaryAbsMeasure
+        .removeMeasure(measureIds)
+        .removeMeasure(salaryAbsSum)
+    )
+
+
     val dfPersonCntResult = measureIds.function(dfPersons)
     val dfFullCntResult = measureIds.function(dfFull)
     val dfFullSalaryAbsSumResult = salaryAbsSum.function(dfFull)
     val dfFullHashResult = sumOfHashes.function(dfFull)
     val dfExtraPersonSalarySumResult = salarySum.function(dfExtraPerson)
     val dfFullSalarySumResult = salarySum.function(dfFull)
+    val dfExtraPersonSalarySumTruncResult = salaryTruncSum.function(dfExtraDecimalPerson)
+    val dfFullSalarySumTruncResult = salaryTruncSum.function(dfFull)
+    val dfExtraPersonSalaryAbsSumTruncResult = salaryAbsTruncSum.function(dfExtraDecimalPerson)
+    val dfFullSalaryAbsSumTruncResult = salaryAbsTruncSum.function(dfFull)
 
     // Assertions
     assert(dfPersonCntResult.resultValue == "1000")
@@ -106,6 +130,14 @@ class AtumMeasureUnitTests extends AnyFlatSpec with Matchers with SparkTestBase
     assert(dfExtraPersonSalarySumResult.resultValueType == ResultValueType.BigDecimalValue)
     assert(dfFullSalarySumResult.resultValue == "2987144")
     assert(dfFullSalarySumResult.resultValueType == ResultValueType.BigDecimalValue)
+    assert(dfExtraPersonSalarySumTruncResult.resultValue == "2989144")
+    assert(dfExtraPersonSalarySumTruncResult.resultValueType == ResultValueType.BigDecimalValue)
+    assert(dfFullSalarySumTruncResult.resultValue == "2987144")
+    assert(dfFullSalarySumTruncResult.resultValueType == ResultValueType.BigDecimalValue)
+    assert(dfExtraPersonSalaryAbsSumTruncResult.resultValue == "2991144")
+    assert(dfExtraPersonSalaryAbsSumTruncResult.resultValueType == ResultValueType.BigDecimalValue)
+    assert(dfFullSalaryAbsSumTruncResult.resultValue == "2987144")
+    assert(dfFullSalaryAbsSumTruncResult.resultValueType == ResultValueType.BigDecimalValue)
   }
 
   "AbsSumOfValuesOfColumn" should "return expected value" in {
@@ -187,4 +219,33 @@ class AtumMeasureUnitTests extends AnyFlatSpec with Matchers with SparkTestBase
     assert(result.resultValueType == ResultValueType.BigDecimalValue)
   }
 
+  "SumTruncOfValuesOfColumn" should "return expected value" in {
+    val distinctCount = SumOfTruncatedValuesOfColumn("colA")
+
+    val data = List(Row("1.98", "b1"), Row("-1.76", "b2"), Row("1.54", "b2"), Row("1.32", "b2"))
+    val rdd = spark.sparkContext.parallelize(data)
+
+    val schema = StructType(Array(StructField("colA", StringType), StructField("colB", StringType)))
+    val df = spark.createDataFrame(rdd, schema)
+
+    val result = distinctCount.function(df)
+
+    assert(result.resultValue == "2")
+    assert(result.resultValueType == ResultValueType.BigDecimalValue)
+  }
+
+  "AbsSumTruncOfValuesOfColumn" should "return expected value" in {
+    val distinctCount = AbsSumOfTruncatedValuesOfColumn("colA")
+
+    val data = List(Row("1.98", "b1"), Row("-1.76", "b2"), Row("1.54", "b2"), Row("-1.32", "b2"))
+    val rdd = spark.sparkContext.parallelize(data)
+
+    val schema = StructType(Array(StructField("colA", StringType), StructField("colB", StringType)))
+    val df = spark.createDataFrame(rdd, schema)
+
+    val result = distinctCount.function(df)
+
+    assert(result.resultValue == "4")
+    assert(result.resultValueType == ResultValueType.BigDecimalValue)
+  }
 }
diff --git a/agent/src/test/scala/za/co/absa/atum/agent/model/MeasureUnitTests.scala b/agent/src/test/scala/za/co/absa/atum/agent/model/MeasureUnitTests.scala
index fea11c9f9..c7f3d3b85 100644
--- a/agent/src/test/scala/za/co/absa/atum/agent/model/MeasureUnitTests.scala
+++ b/agent/src/test/scala/za/co/absa/atum/agent/model/MeasureUnitTests.scala
@@ -19,7 +19,7 @@ package za.co.absa.atum.agent.model
 import org.scalatest.flatspec.AnyFlatSpec
 import org.scalatest.matchers.should.Matchers
 import za.co.absa.atum.agent.AtumAgent
-import za.co.absa.atum.agent.model.AtumMeasure.{AbsSumOfValuesOfColumn, RecordCount, SumOfHashesOfColumn, SumOfValuesOfColumn}
+import za.co.absa.atum.agent.model.AtumMeasure.{AbsSumOfValuesOfColumn, RecordCount, SumOfHashesOfColumn, SumOfValuesOfColumn, SumOfTruncatedValuesOfColumn, AbsSumOfTruncatedValuesOfColumn}
 import za.co.absa.spark.commons.test.SparkTestBase
 import za.co.absa.atum.agent.AtumContext._
 import za.co.absa.atum.model.ResultValueType
@@ -30,11 +30,13 @@ class MeasureUnitTests extends AnyFlatSpec with Matchers with SparkTestBase { se
   "Measure" should "be based on the dataframe" in {
 
     // Measures
-    val measureIds: AtumMeasure = RecordCount()
-    val salaryAbsSum: AtumMeasure = AbsSumOfValuesOfColumn("salary")
+    val measureIds: AtumMeasure   = RecordCount()
+    val salaryAbsSum: AtumMeasure = AbsSumOfValuesOfColumn("salary")
+    val sumOfHashes: AtumMeasure  = SumOfHashesOfColumn("id")
 
-    val salarySum = SumOfValuesOfColumn("salary")
-    val sumOfHashes: AtumMeasure = SumOfHashesOfColumn("id")
+    val salarySum         = SumOfValuesOfColumn("salary")
+    val salaryAbsTruncSum = AbsSumOfTruncatedValuesOfColumn("salary")
+    val salaryTruncSum    = SumOfTruncatedValuesOfColumn("salary")
 
     // AtumContext contains `Measurement`
     val atumContextInstanceWithRecordCount = AtumAgent
@@ -83,12 +85,33 @@ class MeasureUnitTests extends AnyFlatSpec with Matchers with SparkTestBase { se
         .removeMeasure(salaryAbsSum)
     )
 
-    val dfPersonCntResult = measureIds.function(dfPersons)
-    val dfFullCntResult = measureIds.function(dfFull)
-    val dfFullSalaryAbsSumResult = salaryAbsSum.function(dfFull)
-    val dfFullHashResult = sumOfHashes.function(dfFull)
-    val dfExtraPersonSalarySumResult = salarySum.function(dfExtraPerson)
-    val dfFullSalarySumResult = salarySum.function(dfFull)
+    val dfExtraPersonWithDecimalSalary = spark
+      .createDataFrame(
+        Seq(
+          ("id", "firstName", "lastName", "email", "email2", "profession", "3000.98"),
+          ("id", "firstName", "lastName", "email", "email2", "profession", "-1000.76")
+        )
+      )
+      .toDF("id", "firstName", "lastName", "email", "email2", "profession", "salary")
+
+    val dfExtraDecimalPerson = dfExtraPersonWithDecimalSalary.union(dfPersons)
+
+    dfExtraDecimalPerson.createCheckpoint("a checkpoint name")(
+      atumContextWithSalaryAbsMeasure
+        .removeMeasure(measureIds)
+        .removeMeasure(salaryAbsSum)
+    )
+
+    val dfPersonCntResult                    = measureIds.function(dfPersons)
+    val dfFullCntResult                      = measureIds.function(dfFull)
+    val dfFullSalaryAbsSumResult             = salaryAbsSum.function(dfFull)
+    val dfFullHashResult                     = sumOfHashes.function(dfFull)
+    val dfExtraPersonSalarySumResult         = salarySum.function(dfExtraPerson)
+    val dfFullSalarySumResult                = salarySum.function(dfFull)
+    val dfExtraPersonSalarySumTruncResult    = salaryTruncSum.function(dfExtraDecimalPerson)
+    val dfFullSalarySumTruncResult           = salaryTruncSum.function(dfFull)
+    val dfExtraPersonSalaryAbsSumTruncResult = salaryAbsTruncSum.function(dfExtraDecimalPerson)
+    val dfFullSalaryAbsSumTruncResult        = salaryAbsTruncSum.function(dfFull)
 
     // Assertions
     assert(dfPersonCntResult.resultValue == "1000")
@@ -103,6 +126,14 @@ class MeasureUnitTests extends AnyFlatSpec with Matchers with SparkTestBase { se
     assert(dfExtraPersonSalarySumResult.resultValueType == ResultValueType.BigDecimalValue)
     assert(dfFullSalarySumResult.resultValue == "2987144")
     assert(dfFullSalarySumResult.resultValueType == ResultValueType.BigDecimalValue)
+    assert(dfExtraPersonSalarySumTruncResult.resultValue == "2989144")
+    assert(dfExtraPersonSalarySumTruncResult.resultValueType == ResultValueType.BigDecimalValue)
+    assert(dfFullSalarySumTruncResult.resultValue == "2987144")
+    assert(dfFullSalarySumTruncResult.resultValueType == ResultValueType.BigDecimalValue)
+    assert(dfExtraPersonSalaryAbsSumTruncResult.resultValue == "2991144")
+    assert(dfExtraPersonSalaryAbsSumTruncResult.resultValueType == ResultValueType.BigDecimalValue)
+    assert(dfFullSalaryAbsSumTruncResult.resultValue == "2987144")
+    assert(dfFullSalaryAbsSumTruncResult.resultValueType == ResultValueType.BigDecimalValue)
   }
 
 }
diff --git a/agent/src/test/scala/za/co/absa/atum/agent/model/MeasuresBuilderUnitTests.scala b/agent/src/test/scala/za/co/absa/atum/agent/model/MeasuresBuilderUnitTests.scala
index da74c9ba3..bb03cff15 100644
--- a/agent/src/test/scala/za/co/absa/atum/agent/model/MeasuresBuilderUnitTests.scala
+++ b/agent/src/test/scala/za/co/absa/atum/agent/model/MeasuresBuilderUnitTests.scala
@@ -28,6 +28,8 @@ class MeasuresBuilderUnitTests extends AnyFlatSpecLike {
       MeasureDTO("distinctCount", Seq("distinctCountCol")),
       MeasureDTO("aggregatedTotal", Seq("aggregatedTotalCol")),
       MeasureDTO("absAggregatedTotal", Seq("absAggregatedTotalCol")),
+      MeasureDTO("aggregatedTruncTotal", Seq("aggregatedTruncTotalCol")),
+      MeasureDTO("absAggregatedTruncTotal", Seq("absAggregatedTruncTotalCol")),
      MeasureDTO("hashCrc32", Seq("hashCrc32Col"))
     )
 
@@ -36,6 +38,8 @@ class MeasuresBuilderUnitTests extends AnyFlatSpecLike {
       DistinctRecordCount(Seq("distinctCountCol")),
       SumOfValuesOfColumn("aggregatedTotalCol"),
       AbsSumOfValuesOfColumn("absAggregatedTotalCol"),
+      SumOfTruncatedValuesOfColumn("aggregatedTruncTotalCol"),
+      AbsSumOfTruncatedValuesOfColumn("absAggregatedTruncTotalCol"),
       SumOfHashesOfColumn("hashCrc32Col")
     )
 
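Note on patch 1: the new tests expect "2" and "4" because the aggregation truncates toward zero before summing (1.98, -1.76, 1.54, 1.32 become 1, -1, 1, 1). A minimal, self-contained sketch of that truncation outside the agent follows; it is not part of the patch, and the object name, local SparkSession and sample values are illustrative only.

    import org.apache.spark.sql.SparkSession
    import org.apache.spark.sql.functions.{col, sum}
    import org.apache.spark.sql.types.{DecimalType, LongType}

    // Illustrative only: demonstrates the truncate-toward-zero behaviour behind aggregatedTruncTotal.
    object TruncSumSketch extends App {
      val spark = SparkSession.builder().master("local[*]").appName("trunc-sum-sketch").getOrCreate()
      import spark.implicits._

      val df = Seq(1.98, -1.76, 1.54, 1.32).toDF("colA")

      // Same shape as the patch-1 expression: the LongType cast truncates toward zero (1, -1, 1, 1),
      // the DecimalType(38, 0) cast merely widens the truncated value before it is summed.
      val truncSum = df.select(sum(col("colA").cast(LongType).cast(DecimalType(38, 0))).as("truncTotal"))
      truncSum.show() // expected truncTotal: 2

      spark.stop()
    }

The trailing DecimalType(38, 0) cast presumably just keeps the summed value in a decimal column, in line with the other aggregated measures.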
From ff0ed63a2d2318ef15e5aba48b01a0c03a8cb463 Mon Sep 17 00:00:00 2001
From: ABLL526
Date: Wed, 5 Mar 2025 18:46:35 +0200
Subject: [PATCH 2/3] Changes Made:

- Added the aggregatedTruncTotal Measure and the absAggregatedTruncTotal Measure.
- Added the tests for these Measures.
- Made amendments to the function to not include double casts.
---
 .../za/co/absa/atum/agent/model/Measure.scala | 24 ++++++++++---------
 .../agent/model/AtumMeasureUnitTests.scala    | 14 +++++-------
 2 files changed, 19 insertions(+), 19 deletions(-)

diff --git a/agent/src/main/scala/za/co/absa/atum/agent/model/Measure.scala b/agent/src/main/scala/za/co/absa/atum/agent/model/Measure.scala
index 90e330537..805623fbd 100644
--- a/agent/src/main/scala/za/co/absa/atum/agent/model/Measure.scala
+++ b/agent/src/main/scala/za/co/absa/atum/agent/model/Measure.scala
@@ -39,15 +39,15 @@ final case class UnknownMeasure(measureName: String, measuredColumns: Seq[String
 
 object AtumMeasure {
 
-  val supportedMeasureNames: Seq[String] = Seq(
-    RecordCount.measureName,
-    DistinctRecordCount.measureName,
-    SumOfValuesOfColumn.measureName,
-    AbsSumOfValuesOfColumn.measureName,
-    SumOfTruncatedValuesOfColumn.measureName,
-    AbsSumOfTruncatedValuesOfColumn.measureName,
-    SumOfHashesOfColumn.measureName
-  )
+//  val supportedMeasureNames: Seq[String] = Seq(
+//    RecordCount.measureName,
+//    DistinctRecordCount.measureName,
+//    SumOfValuesOfColumn.measureName,
+//    AbsSumOfValuesOfColumn.measureName,
+//    SumOfTruncatedValuesOfColumn.measureName,
+//    AbsSumOfTruncatedValuesOfColumn.measureName,
+//    SumOfHashesOfColumn.measureName
+//  )
 
   case class RecordCount private (measureName: String) extends AtumMeasure {
     private val columnExpression = count("*")
@@ -121,7 +121,8 @@ object AtumMeasure {
 
   case class SumOfTruncatedValuesOfColumn private (measureName: String, measuredCol: String) extends AtumMeasure {
     //Cast to LongType to remove decimal points then cast back to decimal to ensure compatibility
-    private val columnAggFn: Column => Column = column => sum(column.cast(LongType).cast(DecimalType(38, 0)))
+    //private val columnAggFn: Column => Column = column => sum(column.cast(LongType).cast(DecimalType(38, 0)))
+    private val columnAggFn: Column => Column = column => sum(when(column >= 0, floor(column)).otherwise(ceil(column)))
 
     override def function: MeasurementFunction = (ds: DataFrame) => {
       val dataType = ds.select(measuredCol).schema.fields(0).dataType
@@ -139,7 +140,8 @@ object AtumMeasure {
 
   case class AbsSumOfTruncatedValuesOfColumn private (measureName: String, measuredCol: String) extends AtumMeasure {
     //Cast to LongType to remove decimal points then cast back to decimal to ensure compatibility
-    private val columnAggFn: Column => Column = column => sum(abs(column.cast(LongType).cast(DecimalType(38, 0))))
+    //private val columnAggFn: Column => Column = column => sum(abs(column.cast(LongType).cast(DecimalType(38, 0))))
+    private val columnAggFn: Column => Column = column => sum(abs(when(column >= 0, floor(column)).otherwise(ceil(column))))
 
     override def function: MeasurementFunction = (ds: DataFrame) => {
       val dataType = ds.select(measuredCol).schema.fields(0).dataType
diff --git a/agent/src/test/scala/za/co/absa/atum/agent/model/AtumMeasureUnitTests.scala b/agent/src/test/scala/za/co/absa/atum/agent/model/AtumMeasureUnitTests.scala
index fea3ae1ed..aee121883 100644
--- a/agent/src/test/scala/za/co/absa/atum/agent/model/AtumMeasureUnitTests.scala
+++ b/agent/src/test/scala/za/co/absa/atum/agent/model/AtumMeasureUnitTests.scala
@@ -32,14 +32,12 @@ class AtumMeasureUnitTests extends AnyFlatSpec with Matchers with SparkTestBase
   "Measure" should "be based on the dataframe" in {
 
     // Measures
-    val measureIds: AtumMeasure = RecordCount()
-    val salaryAbsSum: AtumMeasure = AbsSumOfValuesOfColumn(
-      measuredCol = "salary"
-    )
-    val salarySum = SumOfValuesOfColumn(measuredCol = "salary")
-    val salaryTruncSum = SumOfTruncatedValuesOfColumn(measuredCol = "salary")
-    val salaryAbsTruncSum = AbsSumOfTruncatedValuesOfColumn(measuredCol = "salary")
-    val sumOfHashes: AtumMeasure = SumOfHashesOfColumn(measuredCol = "id")
+    val measureIds: AtumMeasure   = RecordCount()
+    val salaryAbsSum: AtumMeasure = AbsSumOfValuesOfColumn(measuredCol = "salary")
+    val salarySum                 = SumOfValuesOfColumn(measuredCol = "salary")
+    val salaryTruncSum            = SumOfTruncatedValuesOfColumn(measuredCol = "salary")
+    val salaryAbsTruncSum         = AbsSumOfTruncatedValuesOfColumn(measuredCol = "salary")
+    val sumOfHashes: AtumMeasure  = SumOfHashesOfColumn(measuredCol = "id")
 
     // AtumContext contains `Measurement`
     val atumContextInstanceWithRecordCount = AtumAgent
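Note on patch 2: the LongType/DecimalType double cast is replaced by a when/floor/ceil expression. Both formulations truncate toward zero, so the expected test values do not change. A small sketch comparing the two expressions on the decimal salaries used in the tests follows; it is not part of the patch, and the object name, local SparkSession and sample values are illustrative only.

    import org.apache.spark.sql.SparkSession
    import org.apache.spark.sql.functions.{ceil, col, floor, sum, when}
    import org.apache.spark.sql.types.{DecimalType, LongType}

    // Illustrative only: the two aggregation expressions should agree on every input row.
    object TruncEquivalenceSketch extends App {
      val spark = SparkSession.builder().master("local[*]").appName("trunc-equivalence-sketch").getOrCreate()
      import spark.implicits._

      val df = Seq(3000.98, -1000.76).toDF("salary")
      val c  = col("salary")

      df.select(
        sum(c.cast(LongType).cast(DecimalType(38, 0))).as("doubleCast"), // patch-1 form: the cast truncates toward zero
        sum(when(c >= 0, floor(c)).otherwise(ceil(c))).as("floorCeil")   // patch-2 form: floor for >= 0, ceil for < 0
      ).show() // both columns: 3000 + (-1000) = 2000

      spark.stop()
    }

The floor/ceil form avoids the intermediate Long value entirely, which is presumably what the commit message means by not including double casts.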
From 3a5a907bfbf6880437a2f58b560e95c753c1a248 Mon Sep 17 00:00:00 2001
From: ABLL526
Date: Fri, 7 Mar 2025 10:26:28 +0200
Subject: [PATCH 3/3] Changes Made:

- Added the aggregatedTruncTotal Measure and the absAggregatedTruncTotal Measure.
- Added the tests for these Measures.
- Made amendments to the function to not include double casts.
- Changed the result from BigDecimal to LongValue.
---
 .../za/co/absa/atum/agent/model/Measure.scala | 20 ++++---------------
 .../agent/model/AtumMeasureUnitTests.scala    | 12 +++++------
 .../atum/agent/model/MeasureUnitTests.scala   |  8 ++++----
 3 files changed, 14 insertions(+), 26 deletions(-)

diff --git a/agent/src/main/scala/za/co/absa/atum/agent/model/Measure.scala b/agent/src/main/scala/za/co/absa/atum/agent/model/Measure.scala
index 805623fbd..1fc157240 100644
--- a/agent/src/main/scala/za/co/absa/atum/agent/model/Measure.scala
+++ b/agent/src/main/scala/za/co/absa/atum/agent/model/Measure.scala
@@ -39,16 +39,6 @@ final case class UnknownMeasure(measureName: String, measuredColumns: Seq[String
 
 object AtumMeasure {
 
-//  val supportedMeasureNames: Seq[String] = Seq(
-//    RecordCount.measureName,
-//    DistinctRecordCount.measureName,
-//    SumOfValuesOfColumn.measureName,
-//    AbsSumOfValuesOfColumn.measureName,
-//    SumOfTruncatedValuesOfColumn.measureName,
-//    AbsSumOfTruncatedValuesOfColumn.measureName,
-//    SumOfHashesOfColumn.measureName
-//  )
-
   case class RecordCount private (measureName: String) extends AtumMeasure {
     private val columnExpression = count("*")
 
@@ -120,8 +110,7 @@ object AtumMeasure {
   }
 
   case class SumOfTruncatedValuesOfColumn private (measureName: String, measuredCol: String) extends AtumMeasure {
-    //Cast to LongType to remove decimal points then cast back to decimal to ensure compatibility
-    //private val columnAggFn: Column => Column = column => sum(column.cast(LongType).cast(DecimalType(38, 0)))
+
     private val columnAggFn: Column => Column = column => sum(when(column >= 0, floor(column)).otherwise(ceil(column)))
 
     override def function: MeasurementFunction = (ds: DataFrame) => {
@@ -131,7 +120,7 @@ object AtumMeasure {
     }
 
     override def measuredColumns: Seq[String] = Seq(measuredCol)
-    override val resultValueType: ResultValueType = ResultValueType.BigDecimalValue
+    override val resultValueType: ResultValueType = ResultValueType.LongValue
   }
   object SumOfTruncatedValuesOfColumn {
     private[agent] val measureName: String = "aggregatedTruncTotal"
@@ -139,8 +128,7 @@ object AtumMeasure {
   }
 
   case class AbsSumOfTruncatedValuesOfColumn private (measureName: String, measuredCol: String) extends AtumMeasure {
-    //Cast to LongType to remove decimal points then cast back to decimal to ensure compatibility
-    //private val columnAggFn: Column => Column = column => sum(abs(column.cast(LongType).cast(DecimalType(38, 0))))
+
     private val columnAggFn: Column => Column = column => sum(abs(when(column >= 0, floor(column)).otherwise(ceil(column))))
 
     override def function: MeasurementFunction = (ds: DataFrame) => {
@@ -150,7 +138,7 @@ object AtumMeasure {
     }
 
     override def measuredColumns: Seq[String] = Seq(measuredCol)
-    override val resultValueType: ResultValueType = ResultValueType.BigDecimalValue
+    override val resultValueType: ResultValueType = ResultValueType.LongValue
   }
   object AbsSumOfTruncatedValuesOfColumn {
     private[agent] val measureName: String = "absAggregatedTruncTotal"
diff --git a/agent/src/test/scala/za/co/absa/atum/agent/model/AtumMeasureUnitTests.scala b/agent/src/test/scala/za/co/absa/atum/agent/model/AtumMeasureUnitTests.scala
index aee121883..685e05de6 100644
--- a/agent/src/test/scala/za/co/absa/atum/agent/model/AtumMeasureUnitTests.scala
+++ b/agent/src/test/scala/za/co/absa/atum/agent/model/AtumMeasureUnitTests.scala
@@ -129,13 +129,13 @@ class AtumMeasureUnitTests extends AnyFlatSpec with Matchers with SparkTestBase
     assert(dfFullSalarySumResult.resultValue == "2987144")
     assert(dfFullSalarySumResult.resultValueType == ResultValueType.BigDecimalValue)
     assert(dfExtraPersonSalarySumTruncResult.resultValue == "2989144")
-    assert(dfExtraPersonSalarySumTruncResult.resultValueType == ResultValueType.BigDecimalValue)
+    assert(dfExtraPersonSalarySumTruncResult.resultValueType == ResultValueType.LongValue)
     assert(dfFullSalarySumTruncResult.resultValue == "2987144")
-    assert(dfFullSalarySumTruncResult.resultValueType == ResultValueType.BigDecimalValue)
+    assert(dfFullSalarySumTruncResult.resultValueType == ResultValueType.LongValue)
     assert(dfExtraPersonSalaryAbsSumTruncResult.resultValue == "2991144")
-    assert(dfExtraPersonSalaryAbsSumTruncResult.resultValueType == ResultValueType.BigDecimalValue)
+    assert(dfExtraPersonSalaryAbsSumTruncResult.resultValueType == ResultValueType.LongValue)
     assert(dfFullSalaryAbsSumTruncResult.resultValue == "2987144")
-    assert(dfFullSalaryAbsSumTruncResult.resultValueType == ResultValueType.BigDecimalValue)
+    assert(dfFullSalaryAbsSumTruncResult.resultValueType == ResultValueType.LongValue)
   }
 
   "AbsSumOfValuesOfColumn" should "return expected value" in {
@@ -229,7 +229,7 @@ class AtumMeasureUnitTests extends AnyFlatSpec with Matchers with SparkTestBase
     val result = distinctCount.function(df)
 
     assert(result.resultValue == "2")
-    assert(result.resultValueType == ResultValueType.BigDecimalValue)
+    assert(result.resultValueType == ResultValueType.LongValue)
   }
 
   "AbsSumTruncOfValuesOfColumn" should "return expected value" in {
@@ -244,6 +244,6 @@ class AtumMeasureUnitTests extends AnyFlatSpec with Matchers with SparkTestBase
     val result = distinctCount.function(df)
 
     assert(result.resultValue == "4")
-    assert(result.resultValueType == ResultValueType.BigDecimalValue)
+    assert(result.resultValueType == ResultValueType.LongValue)
   }
 }
diff --git a/agent/src/test/scala/za/co/absa/atum/agent/model/MeasureUnitTests.scala b/agent/src/test/scala/za/co/absa/atum/agent/model/MeasureUnitTests.scala
index c7f3d3b85..f00f5c299 100644
--- a/agent/src/test/scala/za/co/absa/atum/agent/model/MeasureUnitTests.scala
+++ b/agent/src/test/scala/za/co/absa/atum/agent/model/MeasureUnitTests.scala
@@ -127,13 +127,13 @@ class MeasureUnitTests extends AnyFlatSpec with Matchers with SparkTestBase { se
     assert(dfFullSalarySumResult.resultValue == "2987144")
     assert(dfFullSalarySumResult.resultValueType == ResultValueType.BigDecimalValue)
     assert(dfExtraPersonSalarySumTruncResult.resultValue == "2989144")
-    assert(dfExtraPersonSalarySumTruncResult.resultValueType == ResultValueType.BigDecimalValue)
+    assert(dfExtraPersonSalarySumTruncResult.resultValueType == ResultValueType.LongValue)
     assert(dfFullSalarySumTruncResult.resultValue == "2987144")
-    assert(dfFullSalarySumTruncResult.resultValueType == ResultValueType.BigDecimalValue)
+    assert(dfFullSalarySumTruncResult.resultValueType == ResultValueType.LongValue)
     assert(dfExtraPersonSalaryAbsSumTruncResult.resultValue == "2991144")
-    assert(dfExtraPersonSalaryAbsSumTruncResult.resultValueType == ResultValueType.BigDecimalValue)
+    assert(dfExtraPersonSalaryAbsSumTruncResult.resultValueType == ResultValueType.LongValue)
     assert(dfFullSalaryAbsSumTruncResult.resultValue == "2987144")
-    assert(dfFullSalaryAbsSumTruncResult.resultValueType == ResultValueType.BigDecimalValue)
+    assert(dfFullSalaryAbsSumTruncResult.resultValueType == ResultValueType.LongValue)
   }
 
 }
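Note on the series as a whole: after patch 3, aggregatedTruncTotal and absAggregatedTruncTotal report their results as LongValue. A minimal sketch of evaluating the new measures directly against a DataFrame, mirroring the unit tests above, follows; it is not part of the patch series, it assumes the atum-agent artifact on the classpath, and the object name, local SparkSession and sample data are illustrative only.

    import org.apache.spark.sql.SparkSession
    import za.co.absa.atum.agent.model.AtumMeasure.{AbsSumOfTruncatedValuesOfColumn, SumOfTruncatedValuesOfColumn}

    // Illustrative only: evaluates the two new measures the same way the unit tests do.
    object TruncMeasureUsageSketch extends App {
      val spark = SparkSession.builder().master("local[*]").appName("trunc-measure-usage").getOrCreate()
      import spark.implicits._

      val df = Seq(("a", "3000.98"), ("b", "-1000.76")).toDF("id", "salary")

      val salaryTruncSum    = SumOfTruncatedValuesOfColumn("salary")
      val salaryAbsTruncSum = AbsSumOfTruncatedValuesOfColumn("salary")

      // As in the tests, the measure result exposes the value as a string together with its ResultValueType;
      // here the expectation would be "2000" / "4000" with ResultValueType.LongValue.
      val truncResult    = salaryTruncSum.function(df)
      val absTruncResult = salaryAbsTruncSum.function(df)

      println(s"${salaryTruncSum.measureName}: ${truncResult.resultValue} (${truncResult.resultValueType})")
      println(s"${salaryAbsTruncSum.measureName}: ${absTruncResult.resultValue} (${absTruncResult.resultValueType})")

      spark.stop()
    }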