parsing options and serializing arrays #113

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open · wants to merge 19 commits into master
2 changes: 1 addition & 1 deletion build.sbt
@@ -59,7 +59,7 @@ pomExtra := (

spName := "databricks/spark-csv"

sparkVersion := "1.4.0"
sparkVersion := "1.4.1"

sparkComponents += "sql"

2 changes: 2 additions & 0 deletions project/plugins.sbt
@@ -8,6 +8,8 @@ resolvers += "sonatype-releases" at "https://oss.sonatype.org/content/repositori

resolvers += "Spark Package Main Repo" at "https://dl.bintray.com/spark-packages/maven"

resolvers += Resolver.url("scoverage-bintray", url("https://dl.bintray.com/sksamuel/sbt-plugins/"))(Resolver.ivyStylePatterns)

addSbtPlugin("com.typesafe.sbteclipse" % "sbteclipse-plugin" % "2.2.0")

addSbtPlugin("com.github.mpeltonen" % "sbt-idea" % "1.6.0")
71 changes: 53 additions & 18 deletions src/main/scala/com/databricks/spark/csv/CsvParser.scala
@@ -16,24 +16,24 @@
package com.databricks.spark.csv


import org.apache.spark.sql.{DataFrame, SQLContext}
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.{DataFrame, SQLContext}

import com.databricks.spark.csv.util.{ParserLibs, ParseModes, TextFile}

/**
* A collection of static functions for working with CSV files in Spark SQL
*/
class CsvParser {

private var useHeader: Boolean = false
private var delimiter: Character = ','
private var quote: Character = '"'
private var escape: Character = null
private var csvParsingOpts: CSVParsingOpts = CSVParsingOpts()
private var lineParsingOpts: LineParsingOpts = LineParsingOpts()
private var realNumberParsingOpts: RealNumberParsingOpts = RealNumberParsingOpts()
private var intNumberParsingOpts: IntNumberParsingOpts = IntNumberParsingOpts()
private var stringParsingOpts: StringParsingOpts = StringParsingOpts()
private var comment: Character = '#'
private var schema: StructType = null
private var parseMode: String = ParseModes.DEFAULT
private var ignoreLeadingWhiteSpace: Boolean = false
private var ignoreTrailingWhiteSpace: Boolean = false
private var parserLib: String = ParserLibs.DEFAULT
private var charset: String = TextFile.DEFAULT_CHARSET.name()
private var inferSchema: Boolean = false
@@ -44,12 +44,12 @@ class CsvParser {
}

def withDelimiter(delimiter: Character): CsvParser = {
this.delimiter = delimiter
this.csvParsingOpts.delimiter = delimiter
this
}

def withQuoteChar(quote: Character): CsvParser = {
this.quote = quote
this.csvParsingOpts.quoteChar = quote
this
}

@@ -64,7 +64,7 @@ class CsvParser {
}

def withEscape(escapeChar: Character): CsvParser = {
this.escape = escapeChar
this.csvParsingOpts.escapeChar = escapeChar
this
}

@@ -74,12 +74,12 @@ class CsvParser {
}

def withIgnoreLeadingWhiteSpace(ignore: Boolean): CsvParser = {
this.ignoreLeadingWhiteSpace = ignore
this.csvParsingOpts.ignoreLeadingWhitespace = ignore
this
}

def withIgnoreTrailingWhiteSpace(ignore: Boolean): CsvParser = {
this.ignoreTrailingWhiteSpace = ignore
this.csvParsingOpts.ignoreTrailingWhitespace = ignore
this
}

@@ -88,6 +88,41 @@ class CsvParser {
this
}

def withCsvParsingOpts(csvParsingOpts: CSVParsingOpts) = {
this.csvParsingOpts = csvParsingOpts
this
}

def withLineParsingOpts(lineParsingOpts: LineParsingOpts) = {
this.lineParsingOpts = lineParsingOpts
this
}

def withRealNumberParsingOpts(numberParsingOpts: RealNumberParsingOpts) = {
this.realNumberParsingOpts = numberParsingOpts
this
}

def withIntNumberParsingOpts(numberParsingOpts: IntNumberParsingOpts) = {
this.intNumberParsingOpts = numberParsingOpts
this
}


def withStringParsingOpts(stringParsingOpts: StringParsingOpts) = {
this.stringParsingOpts = stringParsingOpts
this
}

def withOpts(optMap: Map[String, String]) = {
this.stringParsingOpts = StringParsingOpts(optMap)
this.lineParsingOpts = LineParsingOpts(optMap)
this.realNumberParsingOpts = RealNumberParsingOpts(optMap)
this.intNumberParsingOpts = IntNumberParsingOpts(optMap)
this.csvParsingOpts = CSVParsingOpts(optMap)
this
}

def withCharset(charset: String): CsvParser = {
this.charset = charset
this
@@ -104,15 +139,15 @@ class CsvParser {
val relation: CsvRelation = CsvRelation(
path,
useHeader,
delimiter,
quote,
escape,
comment,
csvParsingOpts,
parseMode,
parserLib,
ignoreLeadingWhiteSpace,
ignoreTrailingWhiteSpace,
schema,
comment,
lineParsingOpts,
realNumberParsingOpts,
intNumberParsingOpts,
stringParsingOpts,
charset,
inferSchema)(sqlContext)
sqlContext.baseRelationToDataFrame(relation)
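For orientation, here is a minimal usage sketch of the builder methods introduced above. It assumes the option case classes (such as CSVParsingOpts) expose the mutable fields referenced in this diff (delimiter, ignoreLeadingWhitespace, and so on), and that an SQLContext and an input path are already available; the names sqlContext and cars.csv below are placeholders, not part of the patch.

// Hypothetical usage of the new option-based builder (sketch, not part of this PR's diff).
// CSVParsingOpts field names follow the references seen in CsvParser above;
// `sqlContext` and "cars.csv" are illustrative placeholders.
import org.apache.spark.sql.{DataFrame, SQLContext}
import com.databricks.spark.csv._

def loadCars(sqlContext: SQLContext): DataFrame = {
  val parsingOpts = CSVParsingOpts()         // defaults, as constructed in CsvParser
  parsingOpts.delimiter = '|'                // fields are mutated directly, as the builder methods do
  parsingOpts.ignoreLeadingWhitespace = true

  new CsvParser()
    .withUseHeader(true)
    .withParserLib("univocity")
    .withCsvParsingOpts(parsingOpts)
    .csvFile(sqlContext, "cars.csv")         // returns a DataFrame backed by CsvRelation
}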
148 changes: 97 additions & 51 deletions src/main/scala/com/databricks/spark/csv/CsvRelation.scala
@@ -31,20 +31,20 @@ import org.apache.spark.sql.types._
import com.databricks.spark.csv.util._
import com.databricks.spark.sql.readers._

case class CsvRelation protected[spark] (
location: String,
useHeader: Boolean,
delimiter: Char,
quote: Char,
escape: Character,
comment: Character,
parseMode: String,
parserLib: String,
ignoreLeadingWhiteSpace: Boolean,
ignoreTrailingWhiteSpace: Boolean,
userSchema: StructType = null,
charset: String = TextFile.DEFAULT_CHARSET.name(),
inferCsvSchema: Boolean)(@transient val sqlContext: SQLContext)
case class CsvRelation protected[spark](
location: String,
useHeader: Boolean,
csvParsingOpts: CSVParsingOpts,
parseMode: String,
parserLib: String,
userSchema: StructType = null,
comment: Character,
lineExceptionPolicy: LineParsingOpts = LineParsingOpts(),
realNumOpts: RealNumberParsingOpts = RealNumberParsingOpts(),
intNumOpts: IntNumberParsingOpts = IntNumberParsingOpts(),
stringParsingOpts: StringParsingOpts = StringParsingOpts(),
charset: String = TextFile.DEFAULT_CHARSET.name(),
inferCsvSchema: Boolean)(@transient val sqlContext: SQLContext)
extends BaseRelation with TableScan with InsertableRelation {

/**
@@ -59,8 +59,10 @@ case class CsvRelation protected[spark] (
logger.warn(s"$parseMode is not a valid parse mode. Using ${ParseModes.DEFAULT}.")
}

if((ignoreLeadingWhiteSpace || ignoreLeadingWhiteSpace) && ParserLibs.isCommonsLib(parserLib)) {
logger.warn(s"Ignore white space options may not work with Commons parserLib option")
if ((csvParsingOpts.ignoreLeadingWhitespace ||
csvParsingOpts.ignoreTrailingWhitespace) &&
ParserLibs.isCommonsLib(parserLib)) {
logger.warn(s"Ignore white space options only supported with univocity parser option")
}

private val failFast = ParseModes.isFailFastMode(parseMode)
@@ -70,16 +72,16 @@ case class CsvRelation protected[spark] (
val schema = inferSchema()

def tokenRdd(header: Array[String]): RDD[Array[String]] = {
val baseRDD = TextFile.withCharset(sqlContext.sparkContext, location, charset,
csvParsingOpts.numParts)

val baseRDD = TextFile.withCharset(sqlContext.sparkContext, location, charset)

if(ParserLibs.isUnivocityLib(parserLib)) {
if (ParserLibs.isUnivocityLib(parserLib)) {
univocityParseCSV(baseRDD, header)
} else {
val csvFormat = CSVFormat.DEFAULT
.withDelimiter(delimiter)
.withQuote(quote)
.withEscape(escape)
.withDelimiter(csvParsingOpts.delimiter)
.withQuote(csvParsingOpts.quoteChar)
.withEscape(csvParsingOpts.escapeChar)
.withSkipHeaderRecord(false)
.withHeader(header: _*)
.withCommentMarker(comment)
@@ -102,28 +104,67 @@ case class CsvRelation protected[spark] (
// By making this a lazy val we keep the RDD around, amortizing the cost of locating splits.
def buildScan = {
val schemaFields = schema.fields
tokenRdd(schemaFields.map(_.name)).flatMap{ tokens =>
tokenRdd(schemaFields.map(_.name)).flatMap { tokens =>
lazy val errorDetail = s"${tokens.mkString(csvParsingOpts.delimiter.toString)}"

if (dropMalformed && schemaFields.length != tokens.size) {
logger.warn(s"Dropping malformed line: $tokens")
if (schemaFields.length != tokens.size &&
(dropMalformed || lineExceptionPolicy.badLinePolicy == LineExceptionPolicy.Ignore)) {
logger.warn(s"Dropping malformed line: $errorDetail")
None
} else if (failFast && schemaFields.length != tokens.size) {
throw new RuntimeException(s"Malformed line in FAILFAST mode: $tokens")
} else if (schemaFields.length != tokens.size &&
(failFast || lineExceptionPolicy.badLinePolicy == LineExceptionPolicy.Abort)) {
throw new RuntimeException(s"Malformed line in FAILFAST or Abort mode: $errorDetail")
} else {
var index: Int = 0
val rowArray = new Array[Any](schemaFields.length)
try {
index = 0
while (index < schemaFields.length) {
val field = schemaFields(index)
rowArray(index) = TypeCast.castTo(tokens(index), field.dataType, field.nullable)
try {
rowArray(index) = TypeCast.castTo(tokens(index), schemaFields(index).dataType)
} catch {
case e: NumberFormatException if realNumOpts.enable &&
(schemaFields(index).dataType == DoubleType ||
schemaFields(index).dataType == FloatType) =>

rowArray(index) = if (realNumOpts.nullStrings.contains(tokens(index))) {
null
} else if (realNumOpts.nanStrings.contains(tokens(index))) {
TypeCast.castTo("NaN", schemaFields(index).dataType)
} else if (realNumOpts.infPosStrings.contains(tokens(index))) {
TypeCast.castTo("Infinity", schemaFields(index).dataType)
} else if (realNumOpts.infNegStrings.contains(tokens(index))) {
TypeCast.castTo("-Infinity", schemaFields(index).dataType)
} else {
throw new IllegalStateException(
s"Failed to parse as double/float number ${tokens(index)}")
}

case _: NumberFormatException if intNumOpts.enable &&
(schemaFields(index).dataType == IntegerType ||
schemaFields(index).dataType == LongType) =>

rowArray(index) = if (intNumOpts.nullStrings.contains(tokens(index))) {
null
} else {
throw new IllegalStateException(
s"Failed to parse as int/long number ${tokens(index)}")
}
}
index = index + 1
}
Some(Row.fromSeq(rowArray))
} catch {
case aiob: ArrayIndexOutOfBoundsException if permissive =>
(index until schemaFields.length).foreach(ind => rowArray(ind) = null)
case aiob: ArrayIndexOutOfBoundsException
if permissive || lineExceptionPolicy.badLinePolicy == LineExceptionPolicy.Fill =>
(index until schemaFields.length).foreach { ind =>
rowArray(ind) = TypeCast.castTo(lineExceptionPolicy.fillValue,
schemaFields(ind).dataType)
}
Some(Row.fromSeq(rowArray))
case NonFatal(e) if !failFast =>
logger.error(s"Exception while parsing line: $errorDetail. ", e)
None
}
}
}
@@ -133,29 +174,33 @@ case class CsvRelation protected[spark] (
if (this.userSchema != null) {
userSchema
} else {
val firstRow = if(ParserLibs.isUnivocityLib(parserLib)) {
val escapeVal = if(escape == null) '\\' else escape.charValue()
val firstRow = if (ParserLibs.isUnivocityLib(parserLib)) {
val escapeVal = if (csvParsingOpts.escapeChar == null) '\\'
else csvParsingOpts.escapeChar.charValue()
val commentChar: Char = if (comment == null) '\0' else comment
new LineCsvReader(fieldSep = delimiter, quote = quote, escape = escapeVal,
commentMarker = commentChar).parseLine(firstLine)
new LineCsvReader(fieldSep = csvParsingOpts.delimiter,
quote = csvParsingOpts.quoteChar,
escape = escapeVal,
commentMarker = commentChar)
.parseLine(firstLine)
} else {
val csvFormat = CSVFormat.DEFAULT
.withDelimiter(delimiter)
.withQuote(quote)
.withEscape(escape)
.withDelimiter(csvParsingOpts.delimiter)
.withQuote(csvParsingOpts.quoteChar)
.withEscape(csvParsingOpts.escapeChar)
.withSkipHeaderRecord(false)
CSVParser.parse(firstLine, csvFormat).getRecords.head.toArray
}
val header = if (useHeader) {
firstRow
} else {
firstRow.zipWithIndex.map { case (value, index) => s"C$index"}
firstRow.zipWithIndex.map { case (value, index) => s"C$index" }
}
if (this.inferCsvSchema) {
InferSchema(tokenRdd(header), header)
} else{
} else {
// By default fields are assumed to be StringType
val schemaFields = header.map { fieldName =>
val schemaFields = header.map { fieldName =>
StructField(fieldName.toString, StringType, nullable = true)
}
StructType(schemaFields)
@@ -178,29 +223,30 @@ case class CsvRelation protected[spark] (
}
}

private def univocityParseCSV(
file: RDD[String],
header: Seq[String]): RDD[Array[String]] = {
private def univocityParseCSV(file: RDD[String], header: Seq[String]) = {
// If header is set, make sure firstLine is materialized before sending to executors.
val filterLine = if (useHeader) firstLine else null
val dataLines = if(useHeader) file.filter(_ != filterLine) else file
val rows = dataLines.mapPartitionsWithIndex({
case (split, iter) => {
val escapeVal = if(escape == null) '\\' else escape.charValue()
val escapeVal = if (csvParsingOpts.escapeChar == null) '\\'
else csvParsingOpts.escapeChar.charValue()
val commentChar: Char = if (comment == null) '\0' else comment

new BulkCsvReader(iter, split,
headers = header, fieldSep = delimiter,
quote = quote, escape = escapeVal, commentMarker = commentChar)
headers = header, fieldSep = csvParsingOpts.delimiter,
quote = csvParsingOpts.quoteChar, escape = escapeVal,
ignoreLeadingSpace = csvParsingOpts.ignoreLeadingWhitespace,
ignoreTrailingSpace = csvParsingOpts.ignoreTrailingWhitespace,
commentMarker = commentChar)
}
}, true)

rows
}

private def parseCSV(
iter: Iterator[String],
csvFormat: CSVFormat): Iterator[Array[String]] = {
private def parseCSV(iter: Iterator[String],
csvFormat: CSVFormat): Iterator[Array[String]] = {
iter.flatMap { line =>
try {
val records = CSVParser.parse(line, csvFormat).getRecords
@@ -233,7 +279,7 @@ case class CsvRelation protected[spark] (
+ s" to INSERT OVERWRITE a CSV table:\n${e.toString}")
}
// Write the data. We assume that schema isn't changed, and we won't update it.
data.saveAsCsvFile(location, Map("delimiter" -> delimiter.toString))
data.saveAsCsvFile(location, Map("delimiter" -> csvParsingOpts.delimiter.toString))
} else {
sys.error("CSV tables only support INSERT OVERWRITE for now.")
}
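As a reading aid, the numeric fallback added to buildScan above can be summarized in isolation. The sketch below is a stand-in, not the PR's code: RealOpts and its default string sets are invented for illustration, while the real defaults live in RealNumberParsingOpts, which is defined outside this diff. In the PR itself the conversion goes through TypeCast.castTo so the result matches the column's DoubleType or FloatType.

// Illustrative stand-in for the real-number fallback in buildScan (sketch only).
// RealOpts and its defaults are assumptions; the PR's RealNumberParsingOpts
// supplies the actual nullStrings / nanStrings / infPosStrings / infNegStrings.
case class RealOpts(
  nullStrings: Set[String] = Set("", "null"),
  nanStrings: Set[String] = Set("NaN"),
  infPosStrings: Set[String] = Set("Inf", "Infinity"),
  infNegStrings: Set[String] = Set("-Inf", "-Infinity"))

def castDouble(token: String, opts: RealOpts): Any =
  try {
    token.toDouble                                      // normal path: plain numeric token
  } catch {
    case _: NumberFormatException =>
      if (opts.nullStrings.contains(token)) null        // configured strings become SQL null
      else if (opts.nanStrings.contains(token)) Double.NaN
      else if (opts.infPosStrings.contains(token)) Double.PositiveInfinity
      else if (opts.infNegStrings.contains(token)) Double.NegativeInfinity
      else throw new IllegalStateException(s"Failed to parse as double/float number $token")
  }

// castDouble("1.5", RealOpts())  == 1.5
// castDouble("-Inf", RealOpts()) == Double.NegativeInfinity
// castDouble("null", RealOpts()) == null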