diff --git a/src/main/scala/com/databricks/spark/csv/util/InferSchema.scala b/src/main/scala/com/databricks/spark/csv/util/InferSchema.scala
index d9991a2..f138264 100644
--- a/src/main/scala/com/databricks/spark/csv/util/InferSchema.scala
+++ b/src/main/scala/com/databricks/spark/csv/util/InferSchema.scala
@@ -42,7 +42,11 @@ private[csv] object InferSchema {
         mergeRowTypes)
 
     val structFields = header.zip(rootTypes).map { case (thisHeader, rootType) =>
-      StructField(thisHeader, rootType, nullable = true)
+      val dType = rootType match {
+        case _: NullType => StringType
+        case other => other
+      }
+      StructField(thisHeader, dType, nullable = true)
     }
 
     StructType(structFields)
@@ -62,11 +66,7 @@ private[csv] object InferSchema {
       first: Array[DataType],
       second: Array[DataType]): Array[DataType] = {
     first.zipAll(second, NullType, NullType).map { case ((a, b)) =>
-      val tpe = findTightestCommonType(a, b).getOrElse(StringType)
-      tpe match {
-        case _: NullType => StringType
-        case other => other
-      }
+      findTightestCommonType(a, b).getOrElse(NullType)
     }
   }
 
@@ -93,7 +93,6 @@ private[csv] object InferSchema {
     }
   }
 
-
   private def tryParseInteger(field: String): DataType = if ((allCatch opt field.toInt).isDefined) {
     IntegerType
   } else {
@@ -152,6 +151,8 @@ private[csv] object InferSchema {
     case (t1, t2) if t1 == t2 => Some(t1)
     case (NullType, t1) => Some(t1)
     case (t1, NullType) => Some(t1)
+    case (StringType, _) => Some(StringType)
+    case (_, StringType) => Some(StringType)
 
     // Promote numeric types to the highest of the two and all numeric types to unlimited decimal
     case (t1, t2) if Seq(t1, t2).forall(numericPrecedence.contains) =>
diff --git a/src/test/resources/simple.csv b/src/test/resources/simple.csv
new file mode 100644
index 0000000..02d29ca
--- /dev/null
+++ b/src/test/resources/simple.csv
@@ -0,0 +1,5 @@
+A,B,C,D
+1,,,
+,1,,
+,,1,
+,,,1
diff --git a/src/test/scala/com/databricks/spark/csv/CsvSuite.scala b/src/test/scala/com/databricks/spark/csv/CsvSuite.scala
index 00eb846..2494be7 100755
--- a/src/test/scala/com/databricks/spark/csv/CsvSuite.scala
+++ b/src/test/scala/com/databricks/spark/csv/CsvSuite.scala
@@ -41,6 +41,7 @@ abstract class AbstractCsvSuite extends FunSuite with BeforeAndAfterAll {
   val tempEmptyDir = "target/test/empty/"
   val commentsFile = "src/test/resources/comments.csv"
   val disableCommentsFile = "src/test/resources/disable_comments.csv"
+  private val simpleDatasetFile = "src/test/resources/simple.csv"
 
   val numCars = 3
 
@@ -658,7 +659,6 @@ abstract class AbstractCsvSuite extends FunSuite with BeforeAndAfterAll {
     assert(results.toSeq.map(_.toSeq) === expected)
   }
 
-
   test("Setting comment to null disables comment support") {
     val results: Array[Row] = new CsvParser()
       .withDelimiter(',')
@@ -717,6 +717,17 @@ abstract class AbstractCsvSuite extends FunSuite with BeforeAndAfterAll {
 
     assert(results.size === numCars)
   }
+
+  test("Type/Schema inference works as expected for the simple sparse dataset.") {
+    val df = new CsvParser()
+      .withUseHeader(true)
+      .withInferSchema(true)
+      .csvFile(sqlContext, simpleDatasetFile)
+
+    assert(
+      df.schema.fields.map(_.dataType).deep ==
+        Array(IntegerType, IntegerType, IntegerType, IntegerType).deep)
+  }
 }
 
 class CsvSuite extends AbstractCsvSuite {
diff --git a/src/test/scala/com/databricks/spark/csv/util/InferSchemaSuite.scala b/src/test/scala/com/databricks/spark/csv/util/InferSchemaSuite.scala
index 41f3ba8..4bf578c 100644
--- a/src/test/scala/com/databricks/spark/csv/util/InferSchemaSuite.scala
+++ b/src/test/scala/com/databricks/spark/csv/util/InferSchemaSuite.scala
@@ -55,6 +55,12 @@ class InferSchemaSuite extends FunSuite {
     assert(InferSchema.inferField(LongType, "2015-08 14:49:00") == StringType)
   }
 
+  test("Merging NullTypes should yield NullType.") {
+    assert(
+      InferSchema.mergeRowTypes(Array(NullType),
+        Array(NullType)).deep == Array(NullType).deep)
+  }
+
   test("Type arrays are merged to highest common type") {
     assert(
       InferSchema.mergeRowTypes(Array(StringType),