diff --git a/core/src/main/scala/za/co/absa/db/fadb/utils/ClassFieldNamesExtractor.scala b/core/src/main/scala/za/co/absa/db/fadb/utils/ClassFieldNamesExtractor.scala new file mode 100644 index 00000000..6d71656d --- /dev/null +++ b/core/src/main/scala/za/co/absa/db/fadb/utils/ClassFieldNamesExtractor.scala @@ -0,0 +1,53 @@ +package za.co.absa.db.fadb.utils + +import za.co.absa.db.fadb.naming.NamingConvention +import za.co.absa.db.fadb.naming.implementations.SnakeCaseNaming + +import java.lang +import scala.reflect.runtime.universe._ + +object ClassFieldNamesExtractor { + + private def doExtract[T: TypeTag](namingConvention: NamingConvention): Seq[String] = { + val tpe = typeOf[T] + if (tpe.typeSymbol.isClass) { + val cl = tpe.typeSymbol.asClass + if (cl.isPrimitive) { + throw new IllegalArgumentException(s"${tpe.typeSymbol} is a primitive type, extraction is not supported") + } + if (cl.isTrait) { + throw new IllegalArgumentException(s"${tpe.typeSymbol} is a trait, extraction is not supported") + } + if (cl.isCaseClass || cl.isClass) { + tpe + .decl(termNames.CONSTRUCTOR) + .asMethod + .paramLists + .flatten + .map(_.name.decodedName.toString) + .map(namingConvention.stringPerConvention) + } else { + throw new IllegalArgumentException(s"${tpe.typeSymbol} is not a case class nor a class") + } + } else { + throw new IllegalArgumentException(s"${tpe.typeSymbol} is not a case class nor a class") + } + } + + /** + * Extracts constructor field names from case class or regular class, and converts them according to naming convention. + * @param namingConvention - the naming convention to use when converting the constructor parameters names into field name + * @tparam T - type to investigate and extract field names from + * @return - list of field names + */ + def extract[T: TypeTag]()( + implicit namingConvention: NamingConvention = SnakeCaseNaming.Implicits.namingConvention + ): Seq[String] = { + doExtract[T](namingConvention) + } + + def extract[T: TypeTag](namingConvention: NamingConvention): Seq[String] = { + doExtract[T](namingConvention) + } + +} diff --git a/core/src/test/scala/za/co/absa/db/fadb/utils/ClassFieldNamesExtractorUnitTests.scala b/core/src/test/scala/za/co/absa/db/fadb/utils/ClassFieldNamesExtractorUnitTests.scala new file mode 100644 index 00000000..ca6f8c27 --- /dev/null +++ b/core/src/test/scala/za/co/absa/db/fadb/utils/ClassFieldNamesExtractorUnitTests.scala @@ -0,0 +1,82 @@ +/* + * Copyright 2025 ABSA Group Limited + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package za.co.absa.db.fadb.utils + +import org.scalatest.funsuite.AnyFunSuiteLike +import za.co.absa.db.fadb.naming.LettersCase +import za.co.absa.db.fadb.naming.implementations.AsIsNaming +import za.co.absa.db.fadb.utils.ClassFieldNamesExtractorUnitTests._ + +class ClassFieldNamesExtractorUnitTests extends AnyFunSuiteLike { + test("Extract from case class returns its fields") { + val expected = Seq( + "int_field", + "string_field" + ) + val fieldNames = ClassFieldNamesExtractor.extract[TestCaseClass]() + assert(fieldNames == expected) + } + + test("Extract from class constructor fields returns its constructor fields, explicit naming convention") { + val expected = Seq( + "XFIELD", + "YFIELD" + ) + val fieldNames = ClassFieldNamesExtractor.extract[TestClass](new AsIsNaming(LettersCase.UpperCase)) + assert(fieldNames == expected) + } + + test("Extract from class constructor fields returns its constructor fields, implicit naming convention") { + implicit val namingConvention: AsIsNaming = new AsIsNaming(LettersCase.LowerCase) + val expected = Seq( + "xfield", + "yfield" + ) + val fieldNames = ClassFieldNamesExtractor.extract[TestClass]() + assert(fieldNames == expected) + } + + test("Extract from trait fails") { + intercept[IllegalArgumentException] { + ClassFieldNamesExtractor.extract[TestTrait] + } + } + + test("Extract fails on simple type") { + intercept[IllegalArgumentException] { + ClassFieldNamesExtractor.extract[Boolean] + } + } + +} + +object ClassFieldNamesExtractorUnitTests { + case class TestCaseClass(intField: Int, stringField: String) { + def stringFunction: String = intField.toString + stringField + } + + class TestClass(val xField: Int, val yField: String) { + val w: String = xField.toString + yField + + def z: String = xField.toString + yField + } + + trait TestTrait { + def foo: String + } + +} diff --git a/project/Dependencies.scala b/project/Dependencies.scala index 1ad183b8..c667bb41 100644 --- a/project/Dependencies.scala +++ b/project/Dependencies.scala @@ -19,6 +19,7 @@ import sbt.* object Dependencies { private def commonDependencies(scalaVersion: String): Seq[ModuleID] = Seq( + "org.scala-lang" % "scala-reflect" % scalaVersion, "org.typelevel" %% "cats-core" % "2.9.0", "org.typelevel" %% "cats-effect" % "3.5.0", "org.scalatest" %% "scalatest" % "3.1.0" % Test,