#60 provide type for custom columns #90

Open: wants to merge 1 commit into master
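This pull request (addressing #60) adds a way to declare explicit types for individual CSV columns without writing out a full schema or turning on schema inference. A new columnsTypeMap: Map[String, DataType] parameter is threaded through CsvParser, CsvRelation, and the csvFile convenience method; columns named in the map get the mapped type, and every other column keeps the StringType default.

A minimal usage sketch against the API added in this diff (the local-mode context setup is illustrative; the sample file is the fixture added below):

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.types.{BooleanType, DoubleType}
import com.databricks.spark.csv._ // brings the csvFile implicit into scope

object TypedColumnsExample {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setMaster("local[2]").setAppName("typed-columns"))
    val sqlContext = new SQLContext(sc)

    // Columns listed in the map get the given type; all others stay StringType.
    val df = sqlContext.csvFile(
      "src/test/resources/cars-typed.csv",
      columnsTypeMap = Map("price" -> DoubleType, "new" -> BooleanType))

    df.printSchema() // price: double, new: boolean, everything else: string
    sc.stop()
  }
}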
11 changes: 9 additions & 2 deletions src/main/scala/com/databricks/spark/csv/CsvParser.scala
@@ -17,7 +17,7 @@ package com.databricks.spark.csv
 
 
 import org.apache.spark.sql.{DataFrame, SQLContext}
-import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.types.{DataType, StructType}
 import com.databricks.spark.csv.util.{ParserLibs, ParseModes, TextFile}
 
 /**
@@ -36,6 +36,7 @@ class CsvParser {
   private var parserLib: String = ParserLibs.DEFAULT
   private var charset: String = TextFile.DEFAULT_CHARSET.name()
   private var inferSchema: Boolean = false
+  private var columnsTypeMap: Map[String, DataType] = Map.empty
 
   def withUseHeader(flag: Boolean): CsvParser = {
     this.useHeader = flag
@@ -82,6 +83,11 @@ class CsvParser {
     this
   }
 
+  def withTypedFields(columnsTypeMap: Map[String, DataType]): CsvParser = {
+    this.columnsTypeMap = columnsTypeMap
+    this
+  }
+
   def withCharset(charset: String): CsvParser = {
     this.charset = charset
     this
@@ -107,7 +113,8 @@ class CsvParser {
       ignoreTrailingWhiteSpace,
       schema,
      charset,
-      inferSchema)(sqlContext)
+      inferSchema,
+      columnsTypeMap)(sqlContext)
     sqlContext.baseRelationToDataFrame(relation)
   }
 
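For the builder-style API, withTypedFields slots in beside the existing with* setters. A sketch of how it would be chained, assuming the builder's existing csvFile(sqlContext, path) entry point and an SQLContext already in scope:

import org.apache.spark.sql.types.IntegerType
import com.databricks.spark.csv.CsvParser

// The map is applied when the relation is built, so it composes with the
// other builder options exactly like useHeader or charset.
val df = new CsvParser()
  .withUseHeader(true)
  .withTypedFields(Map("year" -> IntegerType))
  .csvFile(sqlContext, "src/test/resources/cars.csv")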
7 changes: 5 additions & 2 deletions src/main/scala/com/databricks/spark/csv/CsvRelation.scala
@@ -27,6 +27,8 @@ import org.slf4j.LoggerFactory
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql._
 import org.apache.spark.sql.sources.{BaseRelation, InsertableRelation, TableScan}
+import org.apache.spark.sql.types.{DataType, StringType, StructField, StructType}
+import com.databricks.spark.csv.util.{ParserLibs, ParseModes, TypeCast}
 import org.apache.spark.sql.types._
 import com.databricks.spark.csv.util._
 import com.databricks.spark.sql.readers._
@@ -43,7 +45,8 @@ case class CsvRelation protected[spark] (
     ignoreTrailingWhiteSpace: Boolean,
     userSchema: StructType = null,
     charset: String = TextFile.DEFAULT_CHARSET.name(),
-    inferCsvSchema: Boolean)(@transient val sqlContext: SQLContext)
+    inferCsvSchema: Boolean,
+    columnsTypeMap: Map[String, DataType] = Map.empty)(@transient val sqlContext: SQLContext)
   extends BaseRelation with TableScan with InsertableRelation {
 
   private val logger = LoggerFactory.getLogger(CsvRelation.getClass)
@@ -148,7 +151,7 @@ case class CsvRelation protected[spark] (
     } else{
       // By default fields are assumed to be StringType
       val schemaFields = header.map { fieldName =>
-        StructField(fieldName.toString, StringType, nullable = true)
+        StructField(fieldName.toString, columnsTypeMap.getOrElse(fieldName, StringType), nullable = true)
       }
       StructType(schemaFields)
     }
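The map is consulted only on this default, header-driven path; an explicitly supplied userSchema or an inferred schema is unaffected. A small sketch of the resulting schema for a three-column header (the column names here are illustrative):

import org.apache.spark.sql.types._

val columnsTypeMap = Map("price" -> DoubleType)
val header = Seq("year", "make", "price")

// Mirrors the getOrElse lookup above: mapped columns get their type,
// the rest fall back to StringType.
val schema = StructType(header.map { name =>
  StructField(name, columnsTypeMap.getOrElse(name, StringType), nullable = true)
})
// year and make stay StringType; price becomes DoubleType.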
6 changes: 5 additions & 1 deletion src/main/scala/com/databricks/spark/csv/package.scala
@@ -17,6 +17,7 @@ package com.databricks.spark
 
 import org.apache.commons.csv.CSVFormat
 import org.apache.hadoop.io.compress.CompressionCodec
+import org.apache.spark.sql.types.DataType
 
 import org.apache.spark.sql.{DataFrame, SQLContext}
 import com.databricks.spark.csv.util.TextFile
@@ -27,7 +28,9 @@ package object csv {
    * Adds a method, `csvFile`, to SQLContext that allows reading CSV data.
    */
   implicit class CsvContext(sqlContext: SQLContext) {
+
     def csvFile(filePath: String,
+                columnsTypeMap: Map[String, DataType] = Map.empty,
                 useHeader: Boolean = true,
                 delimiter: Char = ',',
                 quote: Char = '"',
@@ -49,7 +52,8 @@ package object csv {
         ignoreLeadingWhiteSpace = ignoreLeadingWhiteSpace,
         ignoreTrailingWhiteSpace = ignoreTrailingWhiteSpace,
         charset = charset,
-        inferCsvSchema = inferSchema)(sqlContext)
+        inferCsvSchema = inferSchema,
+        columnsTypeMap = columnsTypeMap)(sqlContext)
       sqlContext.baseRelationToDataFrame(csvRelation)
     }
 
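Since csvFile takes the map as an ordinary parameter, it combines with the existing options such as mode. A hedged sketch pairing the new parameter with FAILFAST, using the fixture and values from the tests below (sqlContext is assumed to be in scope):

import org.apache.spark.sql.types.{BooleanType, DoubleType}
import com.databricks.spark.csv._ // provides the csvFile implicit

// With FAILFAST, a value that cannot be cast (e.g. "900X00.00" as a Double)
// surfaces as a SparkException once an action runs, as the new test expects.
val strict = sqlContext.csvFile(
  "src/test/resources/cars-typed-fail.csv",
  columnsTypeMap = Map("price" -> DoubleType, "new" -> BooleanType),
  mode = "FAILFAST")
strict.collect() // throws org.apache.spark.SparkException on the malformed rows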
6 changes: 6 additions & 0 deletions src/test/resources/cars-typed-fail.csv
@@ -0,0 +1,6 @@
+year,make,model,comment,price,new,blank
+"2012","Tesla","S","No comment",900X00.00,false
+
+1997,Ford,E350,"Go get one now they are going fast",23000,true,
+2015,Chevy,Volt,,40000X.6767,false
+
6 changes: 6 additions & 0 deletions src/test/resources/cars-typed.csv
@@ -0,0 +1,6 @@
+year,make,model,comment,price,new,blank
+"2012","Tesla","S","No comment",90000.00,false
+
+1997,Ford,E350,"Go get one now they are going fast",23000,true,
+2015,Chevy,Volt,,40000.6767,false
+
17 changes: 17 additions & 0 deletions src/test/scala/com/databricks/spark/csv/CsvFastSuite.scala
@@ -360,6 +360,23 @@ class CsvFastSuite extends FunSuite {
 
   }
 
+  test("DSL test custom schema") {
+
+    val results = TestSQLContext
+      .csvFile(carsFile, columnsTypeMap = Map("year" -> IntegerType))
+
+    assert(results.schema == StructType(List(
+      StructField("year", IntegerType, true),
+      StructField("make", StringType, true),
+      StructField("model", StringType, true),
+      StructField("comment", StringType, true),
+      StructField("blank", StringType, true))
+    ))
+
+    assert(results.collect().size === numCars)
+
+  }
+
   test("DSL test inferred schema passed through") {
 
     val dataFrame = TestSQLContext
33 changes: 33 additions & 0 deletions src/test/scala/com/databricks/spark/csv/CsvSuite.scala
@@ -29,6 +29,8 @@ import TestSQLContext._
 
 class CsvSuite extends FunSuite {
   val carsFile = "src/test/resources/cars.csv"
+  val carsTypedColumnsFile = "src/test/resources/cars-typed.csv"
+  val carsTypedColumnsFailFile = "src/test/resources/cars-typed-fail.csv"
   val carsFile8859 = "src/test/resources/cars_iso-8859-1.csv"
   val carsTsvFile = "src/test/resources/cars.tsv"
   val carsAltFile = "src/test/resources/cars-alternative.csv"
@@ -159,6 +161,37 @@ class CsvSuite extends FunSuite {
     assert(results.size === numCars)
   }
 
+  test("DSL test typed columns using sparkContext.csvFile") {
+    val typedColumnsMap = Map("price" -> DoubleType, "new" -> BooleanType)
+
+    val expectedColumnNamesAndTheirTypes =
+      Array("year" -> StringType.toString,
+        "make" -> StringType.toString,
+        "model" -> StringType.toString,
+        "comment" -> StringType.toString,
+        "price" -> DoubleType.toString,
+        "new" -> BooleanType.toString,
+        "blank" -> StringType.toString)
+
+    val results = TestSQLContext.csvFile(carsTypedColumnsFile, columnsTypeMap = typedColumnsMap)
+    assume(results.dtypes containsSlice expectedColumnNamesAndTheirTypes)
+  }
+
+  test("DSL test typed values using sparkContext.csvFile") {
+    val typedColumnsMap = Map("price" -> DoubleType, "new" -> BooleanType)
+
+    val results = TestSQLContext.csvFile(carsTypedColumnsFile, columnsTypeMap = typedColumnsMap)
+    assert(results.collect().map(_.getDouble(4)) === Seq(90000.00d, 23000d, 40000.6767d))
+    assert(results.collect().map(_.getBoolean(5)) === Seq(false, true, false))
+  }
+
+  test("Expect parsing error with wrong type for FailFast mode using sparkContext.csvFile") {
+    val typedColumnsMap = Map("price" -> DoubleType, "new" -> BooleanType)
+
+    intercept[SparkException] {
+      TestSQLContext.csvFile(carsTypedColumnsFailFile, columnsTypeMap = typedColumnsMap, mode = "FAILFAST").collect()
+    }
+  }
+
   test("Expect parsing error with wrong delimiter setting using sparkContext.csvFile") {
     intercept[org.apache.spark.sql.AnalysisException] {