diff --git a/README.md b/README.md
index e2aba30..e21752a 100755
--- a/README.md
+++ b/README.md
@@ -56,6 +56,7 @@ When reading files the API accepts several options:
 * `inferSchema`: automatically infers column types. It requires one extra pass over the data and is false by default
 * `comment`: skip lines beginning with this character. Default is `"#"`. Disable comments by setting this to `null`.
 * `codec`: compression codec to use when saving to file. Should be the fully qualified name of a class implementing `org.apache.hadoop.io.compress.CompressionCodec`. Defaults to no compression when a codec is not specified.
+* `nullValue`: specify a string that indicates a null value; any fields matching this string will be set to null in the DataFrame
 
 The package also support saving simple (non-nested) DataFrame. When saving you can specify the delimiter and whether we should generate a header row for the table. See following examples for more details.
 
@@ -109,7 +110,7 @@ import org.apache.spark.sql.types.{StructType, StructField, StringType, IntegerT
 
 val sqlContext = new SQLContext(sc)
 val customSchema = StructType(
-    StructField("year", IntegerType, true),
+    StructField("year", IntegerType, true),
     StructField("make", StringType, true),
     StructField("model", StringType, true),
     StructField("comment", StringType, true),
@@ -155,7 +156,7 @@ import org.apache.spark.sql.SQLContext
 
 val sqlContext = new SQLContext(sc)
 val df = sqlContext.load(
-    "com.databricks.spark.csv",
+    "com.databricks.spark.csv",
     Map("path" -> "cars.csv", "header" -> "true", "inferSchema" -> "true"))
 val selectedData = df.select("year", "model")
 selectedData.save("newcars.csv", "com.databricks.spark.csv")
@@ -168,14 +169,14 @@ import org.apache.spark.sql.types.{StructType, StructField, StringType, IntegerT
 
 val sqlContext = new SQLContext(sc)
 val customSchema = StructType(
-    StructField("year", IntegerType, true),
+    StructField("year", IntegerType, true),
     StructField("make", StringType, true),
     StructField("model", StringType, true),
     StructField("comment", StringType, true),
     StructField("blank", StringType, true))
 
 val df = sqlContext.load(
-    "com.databricks.spark.csv",
+    "com.databricks.spark.csv",
     schema = customSchema,
     Map("path" -> "cars.csv", "header" -> "true"))
@@ -210,7 +211,7 @@ import org.apache.spark.sql.types.*;
 
 SQLContext sqlContext = new SQLContext(sc);
 StructType customSchema = new StructType(new StructField[] {
-    new StructField("year", DataTypes.IntegerType, true, Metadata.empty()),
+    new StructField("year", DataTypes.IntegerType, true, Metadata.empty()),
     new StructField("make", DataTypes.StringType, true, Metadata.empty()),
     new StructField("model", DataTypes.StringType, true, Metadata.empty()),
     new StructField("comment", DataTypes.StringType, true, Metadata.empty()),
diff --git a/src/main/scala/com/databricks/spark/csv/CsvParser.scala b/src/main/scala/com/databricks/spark/csv/CsvParser.scala
index 0a2f914..370ab3c 100644
--- a/src/main/scala/com/databricks/spark/csv/CsvParser.scala
+++ b/src/main/scala/com/databricks/spark/csv/CsvParser.scala
@@ -40,6 +40,7 @@ class CsvParser extends Serializable {
   private var charset: String = TextFile.DEFAULT_CHARSET.name()
   private var inferSchema: Boolean = false
   private var codec: String = null
+  private var nullValue: String = ""
 
   def withUseHeader(flag: Boolean): CsvParser = {
     this.useHeader = flag
@@ -111,6 +112,11 @@ class CsvParser extends Serializable {
     this
   }
 
+  def withNullValue(nullValue: String): CsvParser = {
+    this.nullValue = nullValue
+    this
+  }
+
   /** Returns a Schema RDD for the given CSV path. */
   @throws[RuntimeException]
   def csvFile(sqlContext: SQLContext, path: String): DataFrame = {
@@ -129,7 +135,8 @@ class CsvParser extends Serializable {
       treatEmptyValuesAsNulls,
       schema,
       inferSchema,
-      codec)(sqlContext)
+      codec,
+      nullValue)(sqlContext)
     sqlContext.baseRelationToDataFrame(relation)
   }
 
@@ -149,7 +156,8 @@ class CsvParser extends Serializable {
       treatEmptyValuesAsNulls,
       schema,
       inferSchema,
-      codec)(sqlContext)
+      codec,
+      nullValue)(sqlContext)
     sqlContext.baseRelationToDataFrame(relation)
   }
 }
diff --git a/src/main/scala/com/databricks/spark/csv/CsvRelation.scala b/src/main/scala/com/databricks/spark/csv/CsvRelation.scala
index dcab9c8..5a09176 100755
--- a/src/main/scala/com/databricks/spark/csv/CsvRelation.scala
+++ b/src/main/scala/com/databricks/spark/csv/CsvRelation.scala
@@ -46,7 +46,8 @@ case class CsvRelation protected[spark] (
     treatEmptyValuesAsNulls: Boolean,
     userSchema: StructType = null,
     inferCsvSchema: Boolean,
-    codec: String = null)(@transient val sqlContext: SQLContext)
+    codec: String = null,
+    nullValue: String = "")(@transient val sqlContext: SQLContext)
   extends BaseRelation with TableScan with PrunedScan with InsertableRelation {
 
   /**
@@ -116,7 +117,7 @@ case class CsvRelation protected[spark] (
         while (index < schemaFields.length) {
           val field = schemaFields(index)
           rowArray(index) = TypeCast.castTo(tokens(index), field.dataType, field.nullable,
-            treatEmptyValuesAsNulls)
+            treatEmptyValuesAsNulls, nullValue)
           index = index + 1
         }
         Some(Row.fromSeq(rowArray))
@@ -189,7 +190,9 @@ case class CsvRelation protected[spark] (
             indexSafeTokens(index),
             field.dataType,
             field.nullable,
-            treatEmptyValuesAsNulls)
+            treatEmptyValuesAsNulls,
+            nullValue
+          )
           subIndex = subIndex + 1
         }
         Some(Row.fromSeq(rowArray.take(requiredSize)))
@@ -235,7 +238,7 @@ case class CsvRelation protected[spark] (
       firstRow.zipWithIndex.map { case (value, index) => s"C$index"}
     }
     if (this.inferCsvSchema) {
-      InferSchema(tokenRdd(header), header)
+      InferSchema(tokenRdd(header), header, nullValue)
     } else {
       // By default fields are assumed to be StringType
       val schemaFields = header.map { fieldName =>
diff --git a/src/main/scala/com/databricks/spark/csv/DefaultSource.scala b/src/main/scala/com/databricks/spark/csv/DefaultSource.scala
index 13abf04..c2e1481 100755
--- a/src/main/scala/com/databricks/spark/csv/DefaultSource.scala
+++ b/src/main/scala/com/databricks/spark/csv/DefaultSource.scala
@@ -136,6 +136,7 @@ class DefaultSource
     } else {
       throw new Exception("Infer schema flag can be true or false")
     }
+    val nullValue = parameters.getOrElse("nullValue", "")
 
     val codec = parameters.getOrElse("codec", null)
 
@@ -154,7 +155,8 @@ class DefaultSource
       treatEmptyValuesAsNullsFlag,
       schema,
       inferSchemaFlag,
-      codec)(sqlContext)
+      codec,
+      nullValue)(sqlContext)
   }
 
   override def createRelation(
diff --git a/src/main/scala/com/databricks/spark/csv/util/InferSchema.scala b/src/main/scala/com/databricks/spark/csv/util/InferSchema.scala
index 0fef690..cc88f51 100644
--- a/src/main/scala/com/databricks/spark/csv/util/InferSchema.scala
+++ b/src/main/scala/com/databricks/spark/csv/util/InferSchema.scala
@@ -30,10 +30,15 @@ private[csv] object InferSchema {
    *     2. Merge row types to find common type
    *     3. Replace any null types with string type
    */
-  def apply(tokenRdd: RDD[Array[String]], header: Array[String]): StructType = {
+  def apply(
+      tokenRdd: RDD[Array[String]],
+      header: Array[String],
+      nullValue: String = ""): StructType = {
     val startType: Array[DataType] = Array.fill[DataType](header.length)(NullType)
-    val rootTypes: Array[DataType] = tokenRdd.aggregate(startType)(inferRowType, mergeRowTypes)
+    val rootTypes: Array[DataType] = tokenRdd.aggregate(startType)(
+      inferRowType(nullValue),
+      mergeRowTypes)
 
     val stuctFields = header.zip(rootTypes).map { case (thisHeader, rootType) =>
       StructField(thisHeader, rootType, nullable = true)
@@ -42,10 +47,11 @@ private[csv] object InferSchema {
     StructType(stuctFields)
   }
 
-  private def inferRowType(rowSoFar: Array[DataType], next: Array[String]): Array[DataType] = {
+  private def inferRowType(nullValue: String)
+      (rowSoFar: Array[DataType], next: Array[String]): Array[DataType] = {
     var i = 0
     while (i < math.min(rowSoFar.length, next.length)) {  // May have columns on right missing.
-      rowSoFar(i) = inferField(rowSoFar(i), next(i))
+      rowSoFar(i) = inferField(rowSoFar(i), next(i), nullValue)
       i+=1
     }
     rowSoFar
@@ -67,8 +73,10 @@ private[csv] object InferSchema {
    * Infer type of string field. Given known type Double, and a string "1", there is no
    * point checking if it is an Int, as the final type must be Double or higher.
    */
-  private[csv] def inferField(typeSoFar: DataType, field: String): DataType = {
-    if (field == null || field.isEmpty) {
+  private[csv] def inferField(typeSoFar: DataType,
+      field: String,
+      nullValue: String = ""): DataType = {
+    if (field == null || field.isEmpty || field == nullValue) {
       typeSoFar
     } else {
       typeSoFar match {
diff --git a/src/main/scala/com/databricks/spark/csv/util/TypeCast.scala b/src/main/scala/com/databricks/spark/csv/util/TypeCast.scala
index 226eafd..edecf97 100644
--- a/src/main/scala/com/databricks/spark/csv/util/TypeCast.scala
+++ b/src/main/scala/com/databricks/spark/csv/util/TypeCast.scala
@@ -43,8 +43,15 @@ object TypeCast {
       datum: String,
       castType: DataType,
       nullable: Boolean = true,
-      treatEmptyValuesAsNulls: Boolean = false): Any = {
-    if (datum == "" && nullable && (!castType.isInstanceOf[StringType] || treatEmptyValuesAsNulls)){
+      treatEmptyValuesAsNulls: Boolean = false,
+      nullValue: String = ""): Any = {
+    // if nullValue is not an empty string, don't require treatEmptyValuesAsNulls
+    // to be set to true
+    val nullValueIsNotEmpty = nullValue != ""
+    if (datum == nullValue &&
+      nullable &&
+      (!castType.isInstanceOf[StringType] || treatEmptyValuesAsNulls || nullValueIsNotEmpty)
+    ){
       null
     } else {
       castType match {
diff --git a/src/test/resources/null_null_numbers.csv b/src/test/resources/null_null_numbers.csv
new file mode 100644
index 0000000..d020d9f
--- /dev/null
+++ b/src/test/resources/null_null_numbers.csv
@@ -0,0 +1,4 @@
+name,age
+alice,35
+bob,null
+null,24
diff --git a/src/test/resources/null_slashn_numbers.csv b/src/test/resources/null_slashn_numbers.csv
new file mode 100644
index 0000000..4068ca8
--- /dev/null
+++ b/src/test/resources/null_slashn_numbers.csv
@@ -0,0 +1,4 @@
+name,age
+alice,35
+bob,\N
+\N,24
diff --git a/src/test/scala/com/databricks/spark/csv/CsvSuite.scala b/src/test/scala/com/databricks/spark/csv/CsvSuite.scala
index 9acc7f3..6fddab5 100755
--- a/src/test/scala/com/databricks/spark/csv/CsvSuite.scala
+++ b/src/test/scala/com/databricks/spark/csv/CsvSuite.scala
@@ -33,6 +33,8 @@ abstract class AbstractCsvSuite extends FunSuite with BeforeAndAfterAll {
   val carsAltFile = "src/test/resources/cars-alternative.csv"
= "src/test/resources/cars-alternative.csv" val carsUnbalancedQuotesFile = "src/test/resources/cars-unbalanced-quotes.csv" val nullNumbersFile = "src/test/resources/null-numbers.csv" + val nullNullNumbersFile = "src/test/resources/null_null_numbers.csv" + val nullSlashNNumbersFile = "src/test/resources/null_slashn_numbers.csv" val emptyFile = "src/test/resources/empty.csv" val ageFile = "src/test/resources/ages.csv" val escapeFile = "src/test/resources/escape.csv" @@ -572,6 +574,36 @@ abstract class AbstractCsvSuite extends FunSuite with BeforeAndAfterAll { assert(results(2).toSeq === Seq("", 24)) } + test("DSL test nullable fields with user defined null value of \"null\"") { + val results = new CsvParser() + .withSchema(StructType(List(StructField("name", StringType, false), + StructField("age", IntegerType, true)))) + .withUseHeader(true) + .withParserLib(parserLib) + .withNullValue("null") + .csvFile(sqlContext, nullNullNumbersFile) + .collect() + + assert(results.head.toSeq === Seq("alice", 35)) + assert(results(1).toSeq === Seq("bob", null)) + assert(results(2).toSeq === Seq("null", 24)) + } + + test("DSL test nullable fields with user defined null value of \"\\N\"") { + val results = new CsvParser() + .withSchema(StructType(List(StructField("name", StringType, false), + StructField("age", IntegerType, true)))) + .withUseHeader(true) + .withParserLib(parserLib) + .withNullValue("\\N") + .csvFile(sqlContext, nullSlashNNumbersFile) + .collect() + + assert(results.head.toSeq === Seq("alice", 35)) + assert(results(1).toSeq === Seq("bob", null)) + assert(results(2).toSeq === Seq("\\N", 24)) + } + test("Commented lines in CSV data") { val results: Array[Row] = new CsvParser() .withDelimiter(',') diff --git a/src/test/scala/com/databricks/spark/csv/util/InferSchemaSuite.scala b/src/test/scala/com/databricks/spark/csv/util/InferSchemaSuite.scala index 1d43c06..d713649 100644 --- a/src/test/scala/com/databricks/spark/csv/util/InferSchemaSuite.scala +++ b/src/test/scala/com/databricks/spark/csv/util/InferSchemaSuite.scala @@ -15,6 +15,15 @@ class InferSchemaSuite extends FunSuite { assert(InferSchema.inferField(NullType, "2015-08-20 15:57:00") == TimestampType) } + test("Null fields are handled properly when a nullValue is specified") { + assert(InferSchema.inferField(NullType, "null", "null") == NullType) + assert(InferSchema.inferField(StringType, "null", "null") == StringType) + assert(InferSchema.inferField(LongType, "null", "null") == LongType) + assert(InferSchema.inferField(IntegerType, "\\N", "\\N") == IntegerType) + assert(InferSchema.inferField(DoubleType, "\\N", "\\N") == DoubleType) + assert(InferSchema.inferField(TimestampType, "\\N", "\\N") == TimestampType) + } + test("String fields types are inferred correctly from other types") { assert(InferSchema.inferField(LongType, "1.0") == DoubleType) assert(InferSchema.inferField(LongType, "test") == StringType) diff --git a/src/test/scala/com/databricks/spark/csv/util/TypeCastSuite.scala b/src/test/scala/com/databricks/spark/csv/util/TypeCastSuite.scala index f2e93fc..b8e6e71 100644 --- a/src/test/scala/com/databricks/spark/csv/util/TypeCastSuite.scala +++ b/src/test/scala/com/databricks/spark/csv/util/TypeCastSuite.scala @@ -94,4 +94,13 @@ class TypeCastSuite extends FunSuite { assert(TypeCast.castTo("1,00", FloatType) == 1.0) assert(TypeCast.castTo("1,00", DoubleType) == 1.0) } + + test("Can handle mapping user specified nullValues") { + assert(TypeCast.castTo("null", StringType, true, false, "null") == null) + 
assert(TypeCast.castTo("\\N", ByteType, true, false, "\\N") == null) + assert(TypeCast.castTo("", ShortType, true, false) == null) + assert(TypeCast.castTo("null", StringType, true, true, "null") == null) + assert(TypeCast.castTo("", StringType, true, false, "") == "") + assert(TypeCast.castTo("", StringType, true, true, "") == null) + } }