Skip to content

Commit dd3b545

Browse files
Rahul Tanwani authored and rxin committed
[SPARK-13309][SQL] Fix type inference issue with CSV data
Fix type inference issue for sparse CSV data - https://issues.apache.org/jira/browse/SPARK-13309 Author: Rahul Tanwani <[email protected]> Closes #11194 from tanwanirahul/master.
1 parent 6dfc4a7 commit dd3b545

File tree

4 files changed

+32
-10
lines changed

4 files changed

+32
-10
lines changed

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,6 @@ import org.apache.spark.rdd.RDD
2929
import org.apache.spark.sql.catalyst.analysis.HiveTypeCoercion
3030
import org.apache.spark.sql.types._
3131

32-
3332
private[csv] object CSVInferSchema {
3433

3534
/**
@@ -48,7 +47,11 @@ private[csv] object CSVInferSchema {
4847
tokenRdd.aggregate(startType)(inferRowType(nullValue), mergeRowTypes)
4948

5049
val structFields = header.zip(rootTypes).map { case (thisHeader, rootType) =>
51-
StructField(thisHeader, rootType, nullable = true)
50+
val dType = rootType match {
51+
case _: NullType => StringType
52+
case other => other
53+
}
54+
StructField(thisHeader, dType, nullable = true)
5255
}
5356

5457
StructType(structFields)
@@ -65,12 +68,8 @@ private[csv] object CSVInferSchema {
6568
}
6669

6770
def mergeRowTypes(first: Array[DataType], second: Array[DataType]): Array[DataType] = {
68-
first.zipAll(second, NullType, NullType).map { case ((a, b)) =>
69-
val tpe = findTightestCommonType(a, b).getOrElse(StringType)
70-
tpe match {
71-
case _: NullType => StringType
72-
case other => other
73-
}
71+
first.zipAll(second, NullType, NullType).map { case (a, b) =>
72+
findTightestCommonType(a, b).getOrElse(NullType)
7473
}
7574
}
7675

@@ -140,6 +139,8 @@ private[csv] object CSVInferSchema {
140139
case (t1, t2) if t1 == t2 => Some(t1)
141140
case (NullType, t1) => Some(t1)
142141
case (t1, NullType) => Some(t1)
142+
case (StringType, t2) => Some(StringType)
143+
case (t1, StringType) => Some(StringType)
143144

144145
// Promote numeric types to the highest of the two and all numeric types to unlimited decimal
145146
case (t1, t2) if Seq(t1, t2).forall(numericPrecedence.contains) =>
@@ -150,7 +151,6 @@ private[csv] object CSVInferSchema {
150151
}
151152
}
152153

153-
154154
private[csv] object CSVTypeCast {
155155

156156
/**
sql/core/src/test/resources/simple_sparse.csv

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
A,B,C,D
2+
1,,,
3+
,1,,
4+
,,1,
5+
,,,1

sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchemaSuite.scala

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,4 +68,9 @@ class InferSchemaSuite extends SparkFunSuite {
6868
assert(CSVInferSchema.inferField(DoubleType, "\\N", "\\N") == DoubleType)
6969
assert(CSVInferSchema.inferField(TimestampType, "\\N", "\\N") == TimestampType)
7070
}
71+
72+
test("Merging Nulltypes should yeild Nulltype.") {
73+
val mergedNullTypes = CSVInferSchema.mergeRowTypes(Array(NullType), Array(NullType))
74+
assert(mergedNullTypes.deep == Array(NullType).deep)
75+
}
7176
}

sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
3737
private val emptyFile = "empty.csv"
3838
private val commentsFile = "comments.csv"
3939
private val disableCommentsFile = "disable_comments.csv"
40+
private val simpleSparseFile = "simple_sparse.csv"
4041

4142
private def testFile(fileName: String): String = {
4243
Thread.currentThread().getContextClassLoader.getResource(fileName).toString
@@ -233,7 +234,6 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
233234
assert(result.schema.fieldNames.size === 1)
234235
}
235236

236-
237237
test("DDL test with empty file") {
238238
sqlContext.sql(s"""
239239
|CREATE TEMPORARY TABLE carsTable
@@ -396,4 +396,16 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
396396
verifyCars(carsCopy, withHeader = true)
397397
}
398398
}
399+
400+
test("Schema inference correctly identifies the datatype when data is sparse.") {
401+
val df = sqlContext.read
402+
.format("csv")
403+
.option("header", "true")
404+
.option("inferSchema", "true")
405+
.load(testFile(simpleSparseFile))
406+
407+
assert(
408+
df.schema.fields.map(field => field.dataType).deep ==
409+
Array(IntegerType, IntegerType, IntegerType, IntegerType).deep)
410+
}
399411
}

0 commit comments

Comments
 (0)