Commit 98997d7

Author: Kostyantyn Spitsyn

    #60 provide type for custom columns

1 parent 451cceb

File tree

7 files changed: +81 -5 lines


src/main/scala/com/databricks/spark/csv/CsvParser.scala

+9-2
@@ -17,7 +17,7 @@ package com.databricks.spark.csv
 
 
 import org.apache.spark.sql.{DataFrame, SQLContext}
-import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.types.{DataType, StructType}
 import com.databricks.spark.csv.util.{ParserLibs, ParseModes, TextFile}
 
 /**
@@ -36,6 +36,7 @@ class CsvParser {
   private var parserLib: String = ParserLibs.DEFAULT
   private var charset: String = TextFile.DEFAULT_CHARSET.name()
   private var inferSchema: Boolean = false
+  private var columnsTypeMap: Map[String, DataType] = Map.empty
 
   def withUseHeader(flag: Boolean): CsvParser = {
     this.useHeader = flag
@@ -82,6 +83,11 @@ class CsvParser {
     this
   }
 
+  def withTypedFields(columnsTypeMap: Map[String, DataType]): CsvParser = {
+    this.columnsTypeMap = columnsTypeMap
+    this
+  }
+
   def withCharset(charset: String): CsvParser = {
     this.charset = charset
     this
@@ -107,7 +113,8 @@ class CsvParser {
       ignoreTrailingWhiteSpace,
       schema,
       charset,
-      inferSchema)(sqlContext)
+      inferSchema,
+      columnsTypeMap)(sqlContext)
     sqlContext.baseRelationToDataFrame(relation)
   }

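For context, a minimal sketch of how the new builder method could be used. This is illustrative only: it assumes an existing SQLContext named sqlContext and the parser's csvFile(sqlContext, path) entry point, whose tail is visible in the last hunk above; the file path is a placeholder.

import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.types.{DoubleType, IntegerType}
import com.databricks.spark.csv.CsvParser

// Sketch: declare types for selected columns up front;
// unlisted columns keep the StringType default.
val cars = new CsvParser()
  .withUseHeader(true)
  .withTypedFields(Map("year" -> IntegerType, "price" -> DoubleType))
  .csvFile(sqlContext, "path/to/cars.csv")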
src/main/scala/com/databricks/spark/csv/CsvRelation.scala

+5-2
@@ -27,6 +27,8 @@ import org.slf4j.LoggerFactory
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql._
 import org.apache.spark.sql.sources.{BaseRelation, InsertableRelation, TableScan}
+import org.apache.spark.sql.types.{DataType, StringType, StructField, StructType}
+import com.databricks.spark.csv.util.{ParserLibs, ParseModes, TypeCast}
 import org.apache.spark.sql.types._
 import com.databricks.spark.csv.util._
 import com.databricks.spark.sql.readers._

@@ -43,7 +45,8 @@ case class CsvRelation protected[spark] (
     ignoreTrailingWhiteSpace: Boolean,
     userSchema: StructType = null,
     charset: String = TextFile.DEFAULT_CHARSET.name(),
-    inferCsvSchema: Boolean)(@transient val sqlContext: SQLContext)
+    inferCsvSchema: Boolean,
+    columnsTypeMap: Map[String, DataType] = Map.empty)(@transient val sqlContext: SQLContext)
   extends BaseRelation with TableScan with InsertableRelation {
 
   private val logger = LoggerFactory.getLogger(CsvRelation.getClass)

@@ -148,7 +151,7 @@
     } else {
       // By default fields are assumed to be StringType
       val schemaFields = header.map { fieldName =>
-        StructField(fieldName.toString, StringType, nullable = true)
+        StructField(fieldName.toString, columnsTypeMap.getOrElse(fieldName, StringType), nullable = true)
       }
       StructType(schemaFields)
     }

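The one-line change in the last hunk is the heart of the feature: when no user schema is supplied and inference is disabled, each header field now looks up its type in columnsTypeMap and falls back to StringType. A self-contained sketch of that fallback (the header and map values below are made-up examples, not from the repo):

import org.apache.spark.sql.types._

val header = Seq("year", "make", "price")
val columnsTypeMap: Map[String, DataType] = Map("price" -> DoubleType)

// Columns named in the map get their declared type; the rest default to StringType.
val schema = StructType(header.map { fieldName =>
  StructField(fieldName, columnsTypeMap.getOrElse(fieldName, StringType), nullable = true)
})
// yields: year: string, make: string, price: double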
src/main/scala/com/databricks/spark/csv/package.scala

+5-1
@@ -17,6 +17,7 @@ package com.databricks.spark
 
 import org.apache.commons.csv.CSVFormat
 import org.apache.hadoop.io.compress.CompressionCodec
+import org.apache.spark.sql.types.DataType
 
 import org.apache.spark.sql.{DataFrame, SQLContext}
 import com.databricks.spark.csv.util.TextFile

@@ -27,7 +28,9 @@ package object csv {
    * Adds a method, `csvFile`, to SQLContext that allows reading CSV data.
    */
   implicit class CsvContext(sqlContext: SQLContext) {
+
     def csvFile(filePath: String,
+                columnsTypeMap: Map[String, DataType] = Map.empty,
                 useHeader: Boolean = true,
                 delimiter: Char = ',',
                 quote: Char = '"',

@@ -49,7 +52,8 @@
       ignoreLeadingWhiteSpace = ignoreLeadingWhiteSpace,
       ignoreTrailingWhiteSpace = ignoreTrailingWhiteSpace,
       charset = charset,
-      inferCsvSchema = inferSchema)(sqlContext)
+      inferCsvSchema = inferSchema,
+      columnsTypeMap = columnsTypeMap)(sqlContext)
     sqlContext.baseRelationToDataFrame(csvRelation)
   }
 

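A quick usage sketch for the extended implicit, again assuming an existing SQLContext named sqlContext; the path points at the test resource added below:

import org.apache.spark.sql.types.{BooleanType, DoubleType}
import com.databricks.spark.csv._  // brings the csvFile implicit into scope

val cars = sqlContext.csvFile(
  "src/test/resources/cars-typed.csv",
  columnsTypeMap = Map("price" -> DoubleType, "new" -> BooleanType))
// cars.schema: price is DoubleType, new is BooleanType, everything else StringType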
src/test/resources/cars-typed-fail.csv

+6

@@ -0,0 +1,6 @@
+year,make,model,comment,price,new,blank
+"2012","Tesla","S","No comment",900X00.00,false
+
+1997,Ford,E350,"Go get one now they are going fast",23000,true,
+2015,Chevy,Volt,,40000X.6767,false
+

(The stray "X" characters in the price values cannot be cast to DoubleType; this file drives the FAILFAST test in CsvSuite below.)

src/test/resources/cars-typed.csv

+6
@@ -0,0 +1,6 @@
+year,make,model,comment,price,new,blank
+"2012","Tesla","S","No comment",90000.00,false
+
+1997,Ford,E350,"Go get one now they are going fast",23000,true,
+2015,Chevy,Volt,,40000.6767,false
+

src/test/scala/com/databricks/spark/csv/CsvFastSuite.scala

+17
@@ -360,6 +360,23 @@ class CsvFastSuite extends FunSuite {
 
   }
 
+  test("DSL test custom schema") {
+
+    val results = TestSQLContext
+      .csvFile(carsFile, columnsTypeMap = Map("year" -> IntegerType))
+
+    assert(results.schema == StructType(List(
+      StructField("year", IntegerType, true),
+      StructField("make", StringType, true),
+      StructField("model", StringType, true),
+      StructField("comment", StringType, true),
+      StructField("blank", StringType, true))
+    ))
+
+    assert(results.collect().size === numCars)
+
+  }
+
   test("DSL test inferred schema passed through") {
 
     val dataFrame = TestSQLContext

src/test/scala/com/databricks/spark/csv/CsvSuite.scala

+33
@@ -29,6 +29,8 @@ import TestSQLContext._
 
 class CsvSuite extends FunSuite {
   val carsFile = "src/test/resources/cars.csv"
+  val carsTypedColumnsFile = "src/test/resources/cars-typed.csv"
+  val carsTypedColumnsFailFile = "src/test/resources/cars-typed-fail.csv"
   val carsFile8859 = "src/test/resources/cars_iso-8859-1.csv"
   val carsTsvFile = "src/test/resources/cars.tsv"
   val carsAltFile = "src/test/resources/cars-alternative.csv"

@@ -159,6 +161,37 @@ class CsvSuite extends FunSuite {
     assert(results.size === numCars)
   }
 
+  test("DSL test typed columns using sparkContext.csvFile") {
+    val typedColumnsMap = Map("price" -> DoubleType, "new" -> BooleanType)
+
+    val expectedColumnNamesAndTheirTypes =
+      Array("year" -> StringType.toString,
+        "make" -> StringType.toString,
+        "model" -> StringType.toString,
+        "comment" -> StringType.toString,
+        "price" -> DoubleType.toString,
+        "new" -> BooleanType.toString,
+        "blank" -> StringType.toString)
+
+    val results = TestSQLContext.csvFile(carsTypedColumnsFile, columnsTypeMap = typedColumnsMap)
+    assume(results.dtypes containsSlice expectedColumnNamesAndTheirTypes)
+  }
+
+  test("DSL test typed values using sparkContext.csvFile") {
+    val typedColumnsMap = Map("price" -> DoubleType, "new" -> BooleanType)
+
+    val results = TestSQLContext.csvFile(carsTypedColumnsFile, columnsTypeMap = typedColumnsMap)
+    assert(results.collect().map(_.getDouble(4)) === Seq(90000.00d, 23000d, 40000.6767d))
+    assert(results.collect().map(_.getBoolean(5)) === Seq(false, true, false))
+  }
+
+  test("Expect parsing error with wrong type for FailFast mode using sparkContext.csvFile") {
+    val typedColumnsMap = Map("price" -> DoubleType, "new" -> BooleanType)
+
+    intercept[SparkException] {
+      TestSQLContext.csvFile(carsTypedColumnsFailFile, columnsTypeMap = typedColumnsMap, mode = "FAILFAST").collect()
+    }
+  }
 
   test("Expect parsing error with wrong delimiter setting using sparkContext.csvFile") {
     intercept[org.apache.spark.sql.AnalysisException] {
