From 7566572079ee3639b79c50a49fad0c5fcecb0043 Mon Sep 17 00:00:00 2001
From: Daniel K
Date: Tue, 14 Dec 2021 14:43:31 +0100
Subject: [PATCH] Feature/4 std config (#9)

* #4 StandardizationConfig for Standardization; it internally contains the
  means to read configs. Tests adjusted, too.

Co-authored-by: Sasa Zejnilovic
---
 .../co/absa/standardization/Standardization.scala |  9 +++------
 .../standardization/StandardizationConfig.scala   | 13 +++++++++++++
 .../StandardizationParquetSuite.scala             | 12 ++++++------
 3 files changed, 22 insertions(+), 12 deletions(-)

diff --git a/src/main/scala/za/co/absa/standardization/Standardization.scala b/src/main/scala/za/co/absa/standardization/Standardization.scala
index 140c282..8452074 100644
--- a/src/main/scala/za/co/absa/standardization/Standardization.scala
+++ b/src/main/scala/za/co/absa/standardization/Standardization.scala
@@ -16,14 +16,12 @@
 package za.co.absa.standardization
 
-import com.typesafe.config.{Config, ConfigFactory}
 import org.apache.hadoop.conf.Configuration
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.types._
 import org.apache.spark.sql.{Column, DataFrame, SparkSession}
 import org.apache.spark.sql.types.StructType
 import org.slf4j.{Logger, LoggerFactory}
-import za.co.absa.standardization.RecordIdGeneration.getRecordIdGenerationType
 import za.co.absa.standardization.schema.{SchemaUtils, SparkUtils}
 import za.co.absa.standardization.stages.{SchemaChecker, TypeParser}
 import za.co.absa.standardization.types.{Defaults, GlobalDefaults, ParseOutput}
@@ -33,7 +31,7 @@ object Standardization {
   private implicit val defaults: Defaults = GlobalDefaults
   private val logger: Logger = LoggerFactory.getLogger(this.getClass)
 
-  def standardize(df: DataFrame, schema: StructType, generalConfig: Config = ConfigFactory.load())
+  def standardize(df: DataFrame, schema: StructType, standardizationConfig: StandardizationConfig = StandardizationConfig.fromConfig())
                  (implicit sparkSession: SparkSession): DataFrame = {
     implicit val udfLib: UDFLibrary = new UDFLibrary
     implicit val hadoopConf: Configuration = sparkSession.sparkContext.hadoopConfiguration
@@ -42,7 +40,7 @@ object Standardization {
     validateSchemaAgainstSelfInconsistencies(schema)
 
     logger.info(s"Step 2: Standardization")
-    val std = standardizeDataset(df, schema, generalConfig.getBoolean("standardization.failOnInputNotPerSchema"))
+    val std = standardizeDataset(df, schema, standardizationConfig.failOnInputNotPerSchema)
 
     logger.info(s"Step 3: Clean the final error column")
     val cleanedStd = cleanTheFinalErrorColumn(std)
@@ -50,8 +48,7 @@ object Standardization {
     val idedStd = if (SchemaUtils.fieldExists(Constants.EnceladusRecordId, cleanedStd.schema)) {
       cleanedStd // no new id regeneration
     } else {
-      val recordIdGenerationStrategy = getRecordIdGenerationType(generalConfig.getString("standardization.recordId.generation.strategy"))
-      RecordIdGeneration.addRecordIdColumnByStrategy(cleanedStd, Constants.EnceladusRecordId, recordIdGenerationStrategy)
+      RecordIdGeneration.addRecordIdColumnByStrategy(cleanedStd, Constants.EnceladusRecordId, standardizationConfig.recordIdGenerationStrategy)
     }
 
     logger.info(s"Standardization process finished, returning to the application...")
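
For illustration, a minimal sketch of the new call site (not part of the
patch): standardize now takes a typed StandardizationConfig instead of a raw
Typesafe Config. The sketch assumes application.conf supplies the two
standardization.* keys that StandardizationConfig.fromConfig() reads.

    import org.apache.spark.sql.SparkSession
    import org.apache.spark.sql.types.{LongType, StringType, StructField, StructType}
    import za.co.absa.standardization.{Standardization, StandardizationConfig}

    implicit val spark: SparkSession = SparkSession.builder()
      .appName("standardization-example")
      .master("local[*]")
      .getOrCreate()
    import spark.implicits._

    // Sample input and the schema to standardize it against
    val df = Seq((1L, "a"), (2L, "b")).toDF("id", "name")
    val schema = StructType(Seq(
      StructField("id", LongType, nullable = false),
      StructField("name", StringType, nullable = true)))

    // New-style call: the typed config is built from the default application.conf
    val standardized = Standardization.standardize(df, schema, StandardizationConfig.fromConfig())
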
diff --git a/src/main/scala/za/co/absa/standardization/StandardizationConfig.scala b/src/main/scala/za/co/absa/standardization/StandardizationConfig.scala
index 82db673..5017f28 100644
--- a/src/main/scala/za/co/absa/standardization/StandardizationConfig.scala
+++ b/src/main/scala/za/co/absa/standardization/StandardizationConfig.scala
@@ -16,7 +16,20 @@
 package za.co.absa.standardization
 
+import com.typesafe.config.{Config, ConfigFactory}
+import za.co.absa.standardization.RecordIdGeneration.getRecordIdGenerationType
+
 case class StandardizationConfig(recordIdGenerationStrategy: RecordIdGeneration.IdType,
                                  failOnInputNotPerSchema: Boolean) {
 }
+
+object StandardizationConfig {
+  def fromConfig(generalConfig: Config = ConfigFactory.load()): StandardizationConfig = {
+    StandardizationConfig(
+      getRecordIdGenerationType(generalConfig.getString("standardization.recordId.generation.strategy")),
+      generalConfig.getBoolean("standardization.failOnInputNotPerSchema")
+    )
+  }
+
+}
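
For illustration, a sketch of how fromConfig consumes configuration (not part
of the patch). The strategy value "stableHashId" is an assumption taken from
the test comment further down in this patch:

    import com.typesafe.config.ConfigFactory
    import za.co.absa.standardization.StandardizationConfig

    // The two keys fromConfig reads; the "stableHashId" strategy name is an
    // assumption based on the stableHashId test comment below.
    val conf = ConfigFactory.parseString(
      """
        |standardization.recordId.generation.strategy = "stableHashId"
        |standardization.failOnInputNotPerSchema = true
        |""".stripMargin)

    val stdConfig = StandardizationConfig.fromConfig(conf)
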
diff --git a/src/test/scala/za/co/absa/standardization/StandardizationParquetSuite.scala b/src/test/scala/za/co/absa/standardization/StandardizationParquetSuite.scala
index bda61aa..42e7ffc 100644
--- a/src/test/scala/za/co/absa/standardization/StandardizationParquetSuite.scala
+++ b/src/test/scala/za/co/absa/standardization/StandardizationParquetSuite.scala
@@ -216,7 +216,7 @@ class StandardizationParquetSuite extends AnyFunSuite with SparkTestBase {
     val schema = StructType(seq)
 
     val exception = intercept[TypeParserException] {
-      Standardization.standardize(sourceDataDF, schema, configWithSchemaValidation)
+      Standardization.standardize(sourceDataDF, schema, StandardizationConfig.fromConfig(configWithSchemaValidation))
     }
     assert(exception.getMessage == "Cannot standardize field 'id' from type integer into array")
   }
@@ -233,7 +233,7 @@ class StandardizationParquetSuite extends AnyFunSuite with SparkTestBase {
     val schema = StructType(seq)
 
     val exception = intercept[TypeParserException] {
-      Standardization.standardize(sourceDataDF, schema, configWithSchemaValidation)
+      Standardization.standardize(sourceDataDF, schema, StandardizationConfig.fromConfig(configWithSchemaValidation))
     }
     assert(exception.getMessage == "Cannot standardize field 'id' from type integer into struct")
   }
@@ -247,7 +247,7 @@ class StandardizationParquetSuite extends AnyFunSuite with SparkTestBase {
     val schema = StructType(seq)
 
     val exception = intercept[TypeParserException] {
-      Standardization.standardize(sourceDataDF, schema, configWithSchemaValidation)
+      Standardization.standardize(sourceDataDF, schema, StandardizationConfig.fromConfig(configWithSchemaValidation))
     }
     assert(exception.getMessage == "Cannot standardize field 'letters' from type array into struct")
   }
@@ -270,7 +270,7 @@ class StandardizationParquetSuite extends AnyFunSuite with SparkTestBase {
     )
     val schema = StructType(seq)
     // stableHashId will always yield the same ids
-    val destDF = Standardization.standardize(sourceDataDF, schema, stableIdConfig)
+    val destDF = Standardization.standardize(sourceDataDF, schema, StandardizationConfig.fromConfig(stableIdConfig))
 
     val actual = destDF.dataAsString(truncate = false)
     assert(actual == expected)
@@ -293,7 +293,7 @@ class StandardizationParquetSuite extends AnyFunSuite with SparkTestBase {
       StructField("struct", StructType(Seq(StructField("bar", BooleanType))), nullable = false)
     )
     val schema = StructType(seq)
-    val destDF = Standardization.standardize(sourceDataDF, schema, uuidConfig)
+    val destDF = Standardization.standardize(sourceDataDF, schema, StandardizationConfig.fromConfig(uuidConfig))
 
     // same except for the record id
     val actual = destDF.drop("enceladus_record_id").dataAsString(truncate = false)
@@ -326,7 +326,7 @@ class StandardizationParquetSuite extends AnyFunSuite with SparkTestBase {
       StructField("enceladus_record_id", StringType, nullable = false)
     )
     val schema = StructType(seq)
-    val destDF = Standardization.standardize(sourceDfWithExistingIds, schema, uuidConfig)
+    val destDF = Standardization.standardize(sourceDfWithExistingIds, schema, StandardizationConfig.fromConfig(uuidConfig))
 
     // The TrueUuids strategy does not override the existing values
     val actual = destDF.dataAsString(truncate = false)
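
Since StandardizationConfig is a plain case class, a test or caller can also
construct it directly and bypass Typesafe Config entirely. A sketch, assuming
RecordIdGeneration exposes a TrueUuids case object (suggested by the
"TrueUuids strategy" comment above) and that sourceDataDF and schema are in
scope as in the tests:

    import za.co.absa.standardization.{RecordIdGeneration, Standardization, StandardizationConfig}

    // Direct construction without Typesafe Config; RecordIdGeneration.TrueUuids
    // is assumed from the "TrueUuids strategy" comment in the test above.
    val stdConfig = StandardizationConfig(
      recordIdGenerationStrategy = RecordIdGeneration.TrueUuids,
      failOnInputNotPerSchema = true)

    val destDF = Standardization.standardize(sourceDataDF, schema, stdConfig)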