diff --git a/src/main/scala/com/databricks/spark/csv/CsvParser.scala b/src/main/scala/com/databricks/spark/csv/CsvParser.scala
index 3e28218..2fb5882 100644
--- a/src/main/scala/com/databricks/spark/csv/CsvParser.scala
+++ b/src/main/scala/com/databricks/spark/csv/CsvParser.scala
@@ -17,7 +17,7 @@ package com.databricks.spark.csv
 
 import org.apache.spark.sql.{DataFrame, SQLContext}
-import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.types.{DataType, StructType}
 import com.databricks.spark.csv.util.{ParserLibs, ParseModes, TextFile}
 
 /**
@@ -36,6 +36,7 @@ class CsvParser {
   private var parserLib: String = ParserLibs.DEFAULT
   private var charset: String = TextFile.DEFAULT_CHARSET.name()
   private var inferSchema: Boolean = false
+  private var columnsTypeMap: Map[String, DataType] = Map.empty
 
   def withUseHeader(flag: Boolean): CsvParser = {
     this.useHeader = flag
@@ -82,6 +83,11 @@ class CsvParser {
     this
   }
 
+  def withTypedFields(columnsTypeMap: Map[String, DataType]): CsvParser = {
+    this.columnsTypeMap = columnsTypeMap
+    this
+  }
+
   def withCharset(charset: String): CsvParser = {
     this.charset = charset
     this
@@ -107,7 +113,8 @@ class CsvParser {
       ignoreTrailingWhiteSpace,
       schema,
       charset,
-      inferSchema)(sqlContext)
+      inferSchema,
+      columnsTypeMap)(sqlContext)
     sqlContext.baseRelationToDataFrame(relation)
   }
 
diff --git a/src/main/scala/com/databricks/spark/csv/CsvRelation.scala b/src/main/scala/com/databricks/spark/csv/CsvRelation.scala
index 12d1400..e554be7 100755
--- a/src/main/scala/com/databricks/spark/csv/CsvRelation.scala
+++ b/src/main/scala/com/databricks/spark/csv/CsvRelation.scala
@@ -27,6 +27,8 @@ import org.slf4j.LoggerFactory
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql._
 import org.apache.spark.sql.sources.{BaseRelation, InsertableRelation, TableScan}
+import org.apache.spark.sql.types.{DataType, StringType, StructField, StructType}
+import com.databricks.spark.csv.util.{ParserLibs, ParseModes, TypeCast}
 import org.apache.spark.sql.types._
 import com.databricks.spark.csv.util._
 import com.databricks.spark.sql.readers._
@@ -43,7 +45,8 @@ case class CsvRelation protected[spark] (
     ignoreTrailingWhiteSpace: Boolean,
     userSchema: StructType = null,
     charset: String = TextFile.DEFAULT_CHARSET.name(),
-    inferCsvSchema: Boolean)(@transient val sqlContext: SQLContext)
+    inferCsvSchema: Boolean,
+    columnsTypeMap: Map[String, DataType] = Map.empty)(@transient val sqlContext: SQLContext)
   extends BaseRelation with TableScan with InsertableRelation {
 
   private val logger = LoggerFactory.getLogger(CsvRelation.getClass)
@@ -148,7 +151,7 @@ case class CsvRelation protected[spark] (
     } else{
       // By default fields are assumed to be StringType
       val schemaFields = header.map { fieldName =>
-        StructField(fieldName.toString, StringType, nullable = true)
+        StructField(fieldName.toString, columnsTypeMap.getOrElse(fieldName, StringType), nullable = true)
       }
       StructType(schemaFields)
     }
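Together, the two hunks above thread the new `columnsTypeMap` from the `CsvParser` builder into `CsvRelation`, where it overrides the default `StringType` on a per-column basis. A minimal sketch of the resulting builder usage, assuming an existing `SQLContext` and a hypothetical `cars.csv` path (neither is supplied by this patch):

```scala
import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.types.{BooleanType, DoubleType}
import com.databricks.spark.csv.CsvParser

// Assumed to exist in the calling code; not provided by this patch.
val sqlContext: SQLContext = ???

// "price" is read as Double and "new" as Boolean; all other columns keep
// the StringType default via the getOrElse fallback added in CsvRelation.
val cars = new CsvParser()
  .withUseHeader(true)
  .withTypedFields(Map("price" -> DoubleType, "new" -> BooleanType))
  .csvFile(sqlContext, "cars.csv")
```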
diff --git a/src/main/scala/com/databricks/spark/csv/package.scala b/src/main/scala/com/databricks/spark/csv/package.scala
index 87965bd..8067e11 100755
--- a/src/main/scala/com/databricks/spark/csv/package.scala
+++ b/src/main/scala/com/databricks/spark/csv/package.scala
@@ -17,6 +17,7 @@ package com.databricks.spark
 
 import org.apache.commons.csv.CSVFormat
 import org.apache.hadoop.io.compress.CompressionCodec
+import org.apache.spark.sql.types.DataType
 import org.apache.spark.sql.{DataFrame, SQLContext}
 
 import com.databricks.spark.csv.util.TextFile
@@ -27,7 +28,9 @@ package object csv {
   /**
    * Adds a method, `csvFile`, to SQLContext that allows reading CSV data.
   */
   implicit class CsvContext(sqlContext: SQLContext) {
+
     def csvFile(filePath: String,
+        columnsTypeMap: Map[String, DataType] = Map.empty,
         useHeader: Boolean = true,
         delimiter: Char = ',',
         quote: Char = '"',
@@ -49,7 +52,8 @@ package object csv {
         ignoreLeadingWhiteSpace = ignoreLeadingWhiteSpace,
         ignoreTrailingWhiteSpace = ignoreTrailingWhiteSpace,
         charset = charset,
-        inferCsvSchema = inferSchema)(sqlContext)
+        inferCsvSchema = inferSchema,
+        columnsTypeMap = columnsTypeMap)(sqlContext)
       sqlContext.baseRelationToDataFrame(csvRelation)
     }
 
diff --git a/src/test/resources/cars-typed-fail.csv b/src/test/resources/cars-typed-fail.csv
new file mode 100644
index 0000000..da00696
--- /dev/null
+++ b/src/test/resources/cars-typed-fail.csv
@@ -0,0 +1,6 @@
+year,make,model,comment,price,new,blank
+"2012","Tesla","S","No comment",900X00.00,false
+
+1997,Ford,E350,"Go get one now they are going fast",23000,true,
+2015,Chevy,Volt,,40000X.6767,false
+
diff --git a/src/test/resources/cars-typed.csv b/src/test/resources/cars-typed.csv
new file mode 100644
index 0000000..516a3a5
--- /dev/null
+++ b/src/test/resources/cars-typed.csv
@@ -0,0 +1,6 @@
+year,make,model,comment,price,new,blank
+"2012","Tesla","S","No comment",90000.00,false
+
+1997,Ford,E350,"Go get one now they are going fast",23000,true,
+2015,Chevy,Volt,,40000.6767,false
+
diff --git a/src/test/scala/com/databricks/spark/csv/CsvFastSuite.scala b/src/test/scala/com/databricks/spark/csv/CsvFastSuite.scala
index eeea378..804a7f6 100644
--- a/src/test/scala/com/databricks/spark/csv/CsvFastSuite.scala
+++ b/src/test/scala/com/databricks/spark/csv/CsvFastSuite.scala
@@ -360,6 +360,23 @@ class CsvFastSuite extends FunSuite {
 
   }
 
+  test("DSL test custom schema") {
+
+    val results = TestSQLContext
+      .csvFile(carsFile, columnsTypeMap = Map("year" -> IntegerType))
+
+    assert(results.schema == StructType(List(
+      StructField("year", IntegerType, true),
+      StructField("make", StringType, true),
+      StructField("model", StringType, true),
+      StructField("comment", StringType, true),
+      StructField("blank", StringType, true))
+    ))
+
+    assert(results.collect().size === numCars)
+
+  }
+
   test("DSL test inferred schema passed through") {
 
     val dataFrame = TestSQLContext
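For reference, the schema that the patched `CsvRelation` derives for the new fixture can be reproduced standalone; the header and type map below are copied from `cars-typed.csv` and the tests in this patch:

```scala
import org.apache.spark.sql.types._

// Header row of cars-typed.csv and the override map used in the tests.
val header = Seq("year", "make", "model", "comment", "price", "new", "blank")
val columnsTypeMap: Map[String, DataType] =
  Map("price" -> DoubleType, "new" -> BooleanType)

// Mirrors the patched line in CsvRelation: columns absent from the map
// fall back to StringType, and every field stays nullable.
val schema = StructType(header.map { fieldName =>
  StructField(fieldName, columnsTypeMap.getOrElse(fieldName, StringType), nullable = true)
})
// => price: DoubleType, new: BooleanType, everything else: StringType
```

The malformed prices in `cars-typed-fail.csv` (`900X00.00`, `40000X.6767`) are there so that this `DoubleType` cast fails, which the FAILFAST test below relies on.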
diff --git a/src/test/scala/com/databricks/spark/csv/CsvSuite.scala b/src/test/scala/com/databricks/spark/csv/CsvSuite.scala
index 0acb654..bff3de4 100755
--- a/src/test/scala/com/databricks/spark/csv/CsvSuite.scala
+++ b/src/test/scala/com/databricks/spark/csv/CsvSuite.scala
@@ -29,6 +29,8 @@ import TestSQLContext._
 
 class CsvSuite extends FunSuite {
   val carsFile = "src/test/resources/cars.csv"
+  val carsTypedColumnsFile = "src/test/resources/cars-typed.csv"
+  val carsTypedColumnsFailFile = "src/test/resources/cars-typed-fail.csv"
   val carsFile8859 = "src/test/resources/cars_iso-8859-1.csv"
   val carsTsvFile = "src/test/resources/cars.tsv"
   val carsAltFile = "src/test/resources/cars-alternative.csv"
@@ -159,6 +161,37 @@ class CsvSuite extends FunSuite {
     assert(results.size === numCars)
   }
 
+  test("DSL test typed columns using sparkContext.csvFile") {
+    val typedColumnsMap = Map("price" -> DoubleType, "new" -> BooleanType)
+
+    val expectedColumnNamesAndTheirTypes =
+      Array("year" -> StringType.toString,
+        "make" -> StringType.toString,
+        "model" -> StringType.toString,
+        "comment" -> StringType.toString,
+        "price" -> DoubleType.toString,
+        "new" -> BooleanType.toString,
+        "blank" -> StringType.toString)
+
+    val results = TestSQLContext.csvFile(carsTypedColumnsFile, columnsTypeMap = typedColumnsMap)
+    assert(results.dtypes containsSlice expectedColumnNamesAndTheirTypes)
+  }
+
+  test("DSL test typed values using sparkContext.csvFile") {
+    val typedColumnsMap = Map("price" -> DoubleType, "new" -> BooleanType)
+
+    val results = TestSQLContext.csvFile(carsTypedColumnsFile, columnsTypeMap = typedColumnsMap)
+    assert(results.collect().map(_.getDouble(4)) === Seq(90000.00d, 23000d, 40000.6767d))
+    assert(results.collect().map(_.getBoolean(5)) === Seq(false, true, false))
+  }
+
+  test("Expect parsing error with wrong type for FailFast mode using sparkContext.csvFile") {
+    val typedColumnsMap = Map("price" -> DoubleType, "new" -> BooleanType)
+
+    intercept[SparkException] {
+      TestSQLContext.csvFile(carsTypedColumnsFailFile, columnsTypeMap = typedColumnsMap, mode = "FAILFAST").collect()
+    }
+  }
+
   test("Expect parsing error with wrong delimiter setting using sparkContext.csvFile") {
     intercept[org.apache.spark.sql.AnalysisException] {