diff --git a/src/main/scala/com/databricks/spark/csv/CsvParser.scala b/src/main/scala/com/databricks/spark/csv/CsvParser.scala
index 7d71195..da2c599 100644
--- a/src/main/scala/com/databricks/spark/csv/CsvParser.scala
+++ b/src/main/scala/com/databricks/spark/csv/CsvParser.scala
@@ -34,6 +34,7 @@ class CsvParser {
   private var ignoreLeadingWhiteSpace: Boolean = false
   private var ignoreTrailingWhiteSpace: Boolean = false
   private var parserLib: String = ParserLibs.DEFAULT
+  private var nullValues: Seq[String] = Seq("")
 
   def withUseHeader(flag: Boolean): CsvParser = {
@@ -81,6 +82,11 @@ class CsvParser {
     this
   }
 
+  def withNullValues(nullValues: Seq[String]): CsvParser = {
+    this.nullValues = nullValues
+    this
+  }
+
   /** Returns a Schema RDD for the given CSV path. */
   @throws[RuntimeException]
   def csvFile(sqlContext: SQLContext, path: String): DataFrame = {
@@ -94,7 +100,8 @@ class CsvParser {
       parserLib,
       ignoreLeadingWhiteSpace,
       ignoreTrailingWhiteSpace,
-      schema)(sqlContext)
+      schema,
+      nullValues.toSet)(sqlContext)
     sqlContext.baseRelationToDataFrame(relation)
   }
 
diff --git a/src/main/scala/com/databricks/spark/csv/CsvRelation.scala b/src/main/scala/com/databricks/spark/csv/CsvRelation.scala
index 2c9f30a..84242f6 100755
--- a/src/main/scala/com/databricks/spark/csv/CsvRelation.scala
+++ b/src/main/scala/com/databricks/spark/csv/CsvRelation.scala
@@ -41,7 +41,8 @@ case class CsvRelation protected[spark] (
     parserLib: String,
     ignoreLeadingWhiteSpace: Boolean,
     ignoreTrailingWhiteSpace: Boolean,
-    userSchema: StructType = null)(@transient val sqlContext: SQLContext)
+    userSchema: StructType = null,
+    nullValues: Set[String] = Set(""))(@transient val sqlContext: SQLContext)
   extends BaseRelation with TableScan with InsertableRelation {
 
   private val logger = LoggerFactory.getLogger(CsvRelation.getClass)
@@ -153,7 +154,8 @@ case class CsvRelation protected[spark] (
       try {
         index = 0
         while (index < schemaFields.length) {
-          rowArray(index) = TypeCast.castTo(tokens(index), schemaFields(index).dataType)
+          val token = if (nullValues.contains(tokens(index))) "" else tokens(index)
+          rowArray(index) = TypeCast.castTo(token, schemaFields(index).dataType)
           index = index + 1
         }
         Some(Row.fromSeq(rowArray))
@@ -195,7 +197,8 @@ case class CsvRelation protected[spark] (
           throw new RuntimeException(s"Malformed line in FAILFAST mode: $line")
         } else {
           while (index < schemaFields.length) {
-            rowArray(index) = TypeCast.castTo(tokens.get(index), schemaFields(index).dataType)
+            val token = if (nullValues.contains(tokens.get(index))) "" else tokens.get(index)
+            rowArray(index) = TypeCast.castTo(token, schemaFields(index).dataType)
             index = index + 1
           }
           Some(Row.fromSeq(rowArray))
diff --git a/src/main/scala/com/databricks/spark/csv/util/TypeCast.scala b/src/main/scala/com/databricks/spark/csv/util/TypeCast.scala
index 62c7b17..b881e5d 100644
--- a/src/main/scala/com/databricks/spark/csv/util/TypeCast.scala
+++ b/src/main/scala/com/databricks/spark/csv/util/TypeCast.scala
@@ -33,21 +33,25 @@ object TypeCast {
    * @param castType SparkSQL type
    */
   private[csv] def castTo(datum: String, castType: DataType): Any = {
-    castType match {
-      case _: ByteType => datum.toByte
-      case _: ShortType => datum.toShort
-      case _: IntegerType => datum.toInt
-      case _: LongType => datum.toLong
-      case _: FloatType => datum.toFloat
-      case _: DoubleType => datum.toDouble
-      case _: BooleanType => datum.toBoolean
-      case _: DecimalType => new BigDecimal(datum.replaceAll(",", ""))
-      // TODO(hossein): would be good to support other common timestamp formats
-      case _: TimestampType => Timestamp.valueOf(datum)
-      // TODO(hossein): would be good to support other common date formats
-      case _: DateType => Date.valueOf(datum)
-      case _: StringType => datum
-      case _ => throw new RuntimeException(s"Unsupported type: ${castType.typeName}")
+    if (datum.isEmpty && castType != StringType) {
+      null
+    } else {
+      castType match {
+        case _: ByteType => datum.toByte
+        case _: ShortType => datum.toShort
+        case _: IntegerType => datum.toInt
+        case _: LongType => datum.toLong
+        case _: FloatType => datum.toFloat
+        case _: DoubleType => datum.toDouble
+        case _: BooleanType => datum.toBoolean
+        case _: DecimalType => new BigDecimal(datum.replaceAll(",", ""))
+        // TODO(hossein): would be good to support other common timestamp formats
+        case _: TimestampType => Timestamp.valueOf(datum)
+        // TODO(hossein): would be good to support other common date formats
+        case _: DateType => Date.valueOf(datum)
+        case _: StringType => datum
+        case _ => throw new RuntimeException(s"Unsupported type: ${castType.typeName}")
+      }
     }
   }
 
diff --git a/src/test/resources/missing-values.csv b/src/test/resources/missing-values.csv
new file mode 100644
index 0000000..71c953b
--- /dev/null
+++ b/src/test/resources/missing-values.csv
@@ -0,0 +1,5 @@
+year,make,model,comment,blank
+"2012","Tesla","S","No comment",
+1997,Ford,E350,"Go get one now they are going fast",
+2015,Chevy,Volt
+NA,NULL,"T","Comment"
diff --git a/src/test/scala/com/databricks/spark/csv/CsvFastSuite.scala b/src/test/scala/com/databricks/spark/csv/CsvFastSuite.scala
index e2bab82..3b6f440 100644
--- a/src/test/scala/com/databricks/spark/csv/CsvFastSuite.scala
+++ b/src/test/scala/com/databricks/spark/csv/CsvFastSuite.scala
@@ -32,6 +32,7 @@ class CsvFastSuite extends FunSuite {
   val carsAltFile = "src/test/resources/cars-alternative.csv"
   val emptyFile = "src/test/resources/empty.csv"
   val escapeFile = "src/test/resources/escape.csv"
+  val carsWithNAs = "src/test/resources/missing-values.csv"
   val tempEmptyDir = "target/test/empty2/"
 
   val numCars = 3
@@ -93,6 +94,25 @@ class CsvFastSuite extends FunSuite {
     assert(results.size === numCars - 1)
   }
 
+  test("DSL test for handling NULL values") {
+    val results = new CsvParser()
+      .withUseHeader(true)
+      .withParserLib("univocity")
+      .withNullValues(Seq("NULL", "NA"))
+      .csvFile(TestSQLContext, carsWithNAs)
+      .collect()
+
+    assert(results.size === numCars + 1)
+
+    val results2 = new CsvParser()
+      .withUseHeader(true)
+      .withNullValues(Seq("NULL", "NA", "NaN"))
+      .csvFile(TestSQLContext, carsWithNAs)
+      .collect()
+
+    assert(results2.size === numCars + 1)
+  }
+
   test("DSL test for FAILFAST parsing mode") {
     val parser = new CsvParser()
       .withParseMode("FAILFAST")
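
Usage sketch (not part of the patch): with this change applied, a caller opts in to custom null tokens through the new builder method. This is a minimal example assuming an existing SQLContext named sqlContext; the token list and file path are taken from the new test and fixture above.

    import com.databricks.spark.csv.CsvParser

    // Tokens equal to "NULL" or "NA" are rewritten to "" before casting,
    // and TypeCast.castTo then returns null for every non-string column.
    // String columns still receive the empty string rather than null.
    val cars = new CsvParser()
      .withUseHeader(true)
      .withNullValues(Seq("NULL", "NA"))
      .csvFile(sqlContext, "src/test/resources/missing-values.csv")

Note that the default stays Seq(""), so empty fields already map to null for typed (non-string) columns even when withNullValues is never called.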