diff --git a/build.sbt b/build.sbt
index b3f87fe..99f7526 100755
--- a/build.sbt
+++ b/build.sbt
@@ -59,7 +59,7 @@ pomExtra := (
 
 spName := "databricks/spark-csv"
 
-sparkVersion := "1.4.0"
+sparkVersion := "1.4.1"
 
 sparkComponents += "sql"
 
diff --git a/project/plugins.sbt b/project/plugins.sbt
index b26f1ca..c87a2ee 100755
--- a/project/plugins.sbt
+++ b/project/plugins.sbt
@@ -8,6 +8,8 @@ resolvers += "sonatype-releases" at "https://oss.sonatype.org/content/repositori
 
 resolvers += "Spark Package Main Repo" at "https://dl.bintray.com/spark-packages/maven"
 
+resolvers += Resolver.url("scoverage-bintray", url("https://dl.bintray.com/sksamuel/sbt-plugins/"))(Resolver.ivyStylePatterns)
+
 addSbtPlugin("com.typesafe.sbteclipse" % "sbteclipse-plugin" % "2.2.0")
 
 addSbtPlugin("com.github.mpeltonen" % "sbt-idea" % "1.6.0")
 
diff --git a/src/main/scala/com/databricks/spark/csv/CsvParser.scala b/src/main/scala/com/databricks/spark/csv/CsvParser.scala
index cf92908..611b1a8 100644
--- a/src/main/scala/com/databricks/spark/csv/CsvParser.scala
+++ b/src/main/scala/com/databricks/spark/csv/CsvParser.scala
@@ -16,24 +16,24 @@
 
 package com.databricks.spark.csv
 
-import org.apache.spark.sql.{DataFrame, SQLContext}
 import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.{DataFrame, SQLContext}
+
 import com.databricks.spark.csv.util.{ParserLibs, ParseModes, TextFile}
 
 /**
  * A collection of static functions for working with CSV files in Spark SQL
  */
 class CsvParser {
   private var useHeader: Boolean = false
-  private var delimiter: Character = ','
-  private var quote: Character = '"'
-  private var escape: Character = null
+  private var csvParsingOpts: CSVParsingOpts = CSVParsingOpts()
+  private var lineParsingOpts: LineParsingOpts = LineParsingOpts()
+  private var realNumberParsingOpts: RealNumberParsingOpts = RealNumberParsingOpts()
+  private var intNumberParsingOpts: IntNumberParsingOpts = IntNumberParsingOpts()
+  private var stringParsingOpts: StringParsingOpts = StringParsingOpts()
   private var comment: Character = '#'
   private var schema: StructType = null
   private var parseMode: String = ParseModes.DEFAULT
-  private var ignoreLeadingWhiteSpace: Boolean = false
-  private var ignoreTrailingWhiteSpace: Boolean = false
   private var parserLib: String = ParserLibs.DEFAULT
   private var charset: String = TextFile.DEFAULT_CHARSET.name()
   private var inferSchema: Boolean = false
 
@@ -44,12 +44,12 @@ class CsvParser {
   }
 
   def withDelimiter(delimiter: Character): CsvParser = {
-    this.delimiter = delimiter
+    this.csvParsingOpts.delimiter = delimiter
     this
   }
 
   def withQuoteChar(quote: Character): CsvParser = {
-    this.quote = quote
+    this.csvParsingOpts.quoteChar = quote
     this
   }
 
@@ -64,7 +64,7 @@ class CsvParser {
   }
 
   def withEscape(escapeChar: Character): CsvParser = {
-    this.escape = escapeChar
+    this.csvParsingOpts.escapeChar = escapeChar
     this
   }
 
@@ -74,12 +74,12 @@ class CsvParser {
   }
 
   def withIgnoreLeadingWhiteSpace(ignore: Boolean): CsvParser = {
-    this.ignoreLeadingWhiteSpace = ignore
+    this.csvParsingOpts.ignoreLeadingWhitespace = ignore
     this
   }
 
   def withIgnoreTrailingWhiteSpace(ignore: Boolean): CsvParser = {
-    this.ignoreTrailingWhiteSpace = ignore
+    this.csvParsingOpts.ignoreTrailingWhitespace = ignore
     this
   }
 
@@ -88,6 +88,41 @@ class CsvParser {
     this
   }
 
+  def withCsvParsingOpts(csvParsingOpts: CSVParsingOpts): CsvParser = {
+    this.csvParsingOpts = csvParsingOpts
+    this
+  }
+
+  def withLineParsingOpts(lineParsingOpts: LineParsingOpts): CsvParser = {
+    this.lineParsingOpts = lineParsingOpts
+    this
+  }
+
+  def withRealNumberParsingOpts(numberParsingOpts: RealNumberParsingOpts): CsvParser = {
+    this.realNumberParsingOpts = numberParsingOpts
+    this
+  }
+
+  def withIntNumberParsingOpts(numberParsingOpts: IntNumberParsingOpts): CsvParser = {
+    this.intNumberParsingOpts = numberParsingOpts
+    this
+  }
+
+  def withStringParsingOpts(stringParsingOpts: StringParsingOpts): CsvParser = {
+    this.stringParsingOpts = stringParsingOpts
+    this
+  }
+
+  def withOpts(optMap: Map[String, String]): CsvParser = {
+    this.stringParsingOpts = StringParsingOpts(optMap)
+    this.lineParsingOpts = LineParsingOpts(optMap)
+    this.realNumberParsingOpts = RealNumberParsingOpts(optMap)
+    this.intNumberParsingOpts = IntNumberParsingOpts(optMap)
+    this.csvParsingOpts = CSVParsingOpts(optMap)
+    this
+  }
+
   def withCharset(charset: String): CsvParser = {
     this.charset = charset
     this
@@ -104,15 +139,15 @@ class CsvParser {
     val relation: CsvRelation = CsvRelation(
       path,
       useHeader,
-      delimiter,
-      quote,
-      escape,
-      comment,
+      csvParsingOpts,
       parseMode,
       parserLib,
-      ignoreLeadingWhiteSpace,
-      ignoreTrailingWhiteSpace,
       schema,
+      comment,
+      lineParsingOpts,
+      realNumberParsingOpts,
+      intNumberParsingOpts,
+      stringParsingOpts,
       charset,
       inferSchema)(sqlContext)
     sqlContext.baseRelationToDataFrame(relation)
diff --git a/src/main/scala/com/databricks/spark/csv/CsvRelation.scala b/src/main/scala/com/databricks/spark/csv/CsvRelation.scala
index eddd2b5..a02f04a 100755
--- a/src/main/scala/com/databricks/spark/csv/CsvRelation.scala
+++ b/src/main/scala/com/databricks/spark/csv/CsvRelation.scala
@@ -31,20 +31,20 @@ import org.apache.spark.sql.types._
 import com.databricks.spark.csv.util._
 import com.databricks.spark.sql.readers._
 
-case class CsvRelation protected[spark] (
-    location: String,
-    useHeader: Boolean,
-    delimiter: Char,
-    quote: Char,
-    escape: Character,
-    comment: Character,
-    parseMode: String,
-    parserLib: String,
-    ignoreLeadingWhiteSpace: Boolean,
-    ignoreTrailingWhiteSpace: Boolean,
-    userSchema: StructType = null,
-    charset: String = TextFile.DEFAULT_CHARSET.name(),
-    inferCsvSchema: Boolean)(@transient val sqlContext: SQLContext)
+case class CsvRelation protected[spark](
+    location: String,
+    useHeader: Boolean,
+    csvParsingOpts: CSVParsingOpts,
+    parseMode: String,
+    parserLib: String,
+    userSchema: StructType = null,
+    comment: Character,
+    lineExceptionPolicy: LineParsingOpts = LineParsingOpts(),
+    realNumOpts: RealNumberParsingOpts = RealNumberParsingOpts(),
+    intNumOpts: IntNumberParsingOpts = IntNumberParsingOpts(),
+    stringParsingOpts: StringParsingOpts = StringParsingOpts(),
+    charset: String = TextFile.DEFAULT_CHARSET.name(),
+    inferCsvSchema: Boolean)(@transient val sqlContext: SQLContext)
   extends BaseRelation with TableScan with InsertableRelation {
 
   /**
@@ -59,8 +59,10 @@ case class CsvRelation protected[spark] (
     logger.warn(s"$parseMode is not a valid parse mode. Using ${ParseModes.DEFAULT}.")
   }
 
-  if((ignoreLeadingWhiteSpace || ignoreLeadingWhiteSpace) && ParserLibs.isCommonsLib(parserLib)) {
-    logger.warn(s"Ignore white space options may not work with Commons parserLib option")
+  if ((csvParsingOpts.ignoreLeadingWhitespace ||
+    csvParsingOpts.ignoreTrailingWhitespace) &&
+    ParserLibs.isCommonsLib(parserLib)) {
+    logger.warn("Ignore whitespace options are only supported with the univocity parser")
   }
 
   private val failFast = ParseModes.isFailFastMode(parseMode)
@@ -70,16 +72,16 @@ case class CsvRelation protected[spark] (
   val schema = inferSchema()
 
   def tokenRdd(header: Array[String]): RDD[Array[String]] = {
+    val baseRDD = TextFile.withCharset(sqlContext.sparkContext, location, charset,
+      csvParsingOpts.numParts)
 
-    val baseRDD = TextFile.withCharset(sqlContext.sparkContext, location, charset)
-
-    if(ParserLibs.isUnivocityLib(parserLib)) {
+    if (ParserLibs.isUnivocityLib(parserLib)) {
       univocityParseCSV(baseRDD, header)
     } else {
       val csvFormat = CSVFormat.DEFAULT
-        .withDelimiter(delimiter)
-        .withQuote(quote)
-        .withEscape(escape)
+        .withDelimiter(csvParsingOpts.delimiter)
+        .withQuote(csvParsingOpts.quoteChar)
+        .withEscape(csvParsingOpts.escapeChar)
         .withSkipHeaderRecord(false)
         .withHeader(header: _*)
         .withCommentMarker(comment)
@@ -102,28 +104,67 @@ case class CsvRelation protected[spark] (
   // By making this a lazy val we keep the RDD around, amortizing the cost of locating splits.
   def buildScan = {
     val schemaFields = schema.fields
-    tokenRdd(schemaFields.map(_.name)).flatMap{ tokens =>
+    tokenRdd(schemaFields.map(_.name)).flatMap { tokens =>
+      lazy val errorDetail = tokens.mkString(csvParsingOpts.delimiter.toString)
 
-      if (dropMalformed && schemaFields.length != tokens.size) {
-        logger.warn(s"Dropping malformed line: $tokens")
+      if (schemaFields.length != tokens.size &&
+        (dropMalformed || lineExceptionPolicy.badLinePolicy == LineExceptionPolicy.Ignore)) {
+        logger.warn(s"Dropping malformed line: $errorDetail")
         None
-      } else if (failFast && schemaFields.length != tokens.size) {
-        throw new RuntimeException(s"Malformed line in FAILFAST mode: $tokens")
+      } else if (schemaFields.length != tokens.size &&
+        (failFast || lineExceptionPolicy.badLinePolicy == LineExceptionPolicy.Abort)) {
+        throw new RuntimeException(s"Malformed line in FAILFAST or Abort mode: $errorDetail")
       } else {
         var index: Int = 0
         val rowArray = new Array[Any](schemaFields.length)
         try {
           index = 0
           while (index < schemaFields.length) {
-            val field = schemaFields(index)
-            rowArray(index) = TypeCast.castTo(tokens(index), field.dataType, field.nullable)
+            try {
+              rowArray(index) = TypeCast.castTo(tokens(index), schemaFields(index).dataType)
+            } catch {
+              case e: NumberFormatException if realNumOpts.enable &&
+                (schemaFields(index).dataType == DoubleType ||
+                  schemaFields(index).dataType == FloatType) =>
+
+                rowArray(index) = if (realNumOpts.nullStrings.contains(tokens(index))) {
+                  null
+                } else if (realNumOpts.nanStrings.contains(tokens(index))) {
+                  TypeCast.castTo("NaN", schemaFields(index).dataType)
+                } else if (realNumOpts.infPosStrings.contains(tokens(index))) {
+                  TypeCast.castTo("Infinity", schemaFields(index).dataType)
+                } else if (realNumOpts.infNegStrings.contains(tokens(index))) {
+                  TypeCast.castTo("-Infinity", schemaFields(index).dataType)
+                } else {
+                  throw new IllegalStateException(
+                    s"Failed to parse as double/float number ${tokens(index)}")
+                }
+
+              case _: NumberFormatException if intNumOpts.enable &&
+                (schemaFields(index).dataType == IntegerType ||
+                  schemaFields(index).dataType == LongType) =>
+
+                rowArray(index) = if (intNumOpts.nullStrings.contains(tokens(index))) {
+                  null
+                } else {
+                  throw new IllegalStateException(
+                    s"Failed to parse as int/long number ${tokens(index)}")
+                }
+            }
             index = index + 1
           }
           Some(Row.fromSeq(rowArray))
         } catch {
-          case aiob: ArrayIndexOutOfBoundsException if permissive =>
-            (index until schemaFields.length).foreach(ind => rowArray(ind) = null)
+          case aiob: ArrayIndexOutOfBoundsException
+            if permissive || lineExceptionPolicy.badLinePolicy == LineExceptionPolicy.Fill =>
+            (index until schemaFields.length).foreach { ind =>
+              rowArray(ind) = TypeCast.castTo(lineExceptionPolicy.fillValue,
+                schemaFields(ind).dataType)
+            }
             Some(Row.fromSeq(rowArray))
+          case NonFatal(e) if !failFast =>
+            logger.error(s"Exception while parsing line: $errorDetail.", e)
+            None
         }
       }
     }
@@ -133,29 +174,33 @@ case class CsvRelation protected[spark] (
     if (this.userSchema != null) {
       userSchema
     } else {
-      val firstRow = if(ParserLibs.isUnivocityLib(parserLib)) {
-        val escapeVal = if(escape == null) '\\' else escape.charValue()
+      val firstRow = if (ParserLibs.isUnivocityLib(parserLib)) {
+        val escapeVal = if (csvParsingOpts.escapeChar == null) '\\'
+        else csvParsingOpts.escapeChar.charValue()
         val commentChar: Char = if (comment == null) '\0' else comment
-        new LineCsvReader(fieldSep = delimiter, quote = quote, escape = escapeVal,
-          commentMarker = commentChar).parseLine(firstLine)
+        new LineCsvReader(fieldSep = csvParsingOpts.delimiter,
+          quote = csvParsingOpts.quoteChar,
+          escape = escapeVal,
+          commentMarker = commentChar)
+          .parseLine(firstLine)
       } else {
         val csvFormat = CSVFormat.DEFAULT
-          .withDelimiter(delimiter)
-          .withQuote(quote)
-          .withEscape(escape)
+          .withDelimiter(csvParsingOpts.delimiter)
+          .withQuote(csvParsingOpts.quoteChar)
+          .withEscape(csvParsingOpts.escapeChar)
           .withSkipHeaderRecord(false)
         CSVParser.parse(firstLine, csvFormat).getRecords.head.toArray
       }
       val header = if (useHeader) {
         firstRow
       } else {
-        firstRow.zipWithIndex.map { case (value, index) => s"C$index"}
+        firstRow.zipWithIndex.map { case (value, index) => s"C$index" }
       }
       if (this.inferCsvSchema) {
         InferSchema(tokenRdd(header), header)
-      } else{
+      } else {
         // By default fields are assumed to be StringType
         val schemaFields = header.map { fieldName =>
           StructField(fieldName.toString, StringType, nullable = true)
         }
         StructType(schemaFields)
@@ -178,29 +223,30 @@ case class CsvRelation protected[spark] (
     }
   }
 
-  private def univocityParseCSV(
-      file: RDD[String],
-      header: Seq[String]): RDD[Array[String]] = {
+  private def univocityParseCSV(file: RDD[String], header: Seq[String]) = {
     // If header is set, make sure firstLine is materialized before sending to executors.
     val filterLine = if (useHeader) firstLine else null
     val dataLines = if(useHeader) file.filter(_ != filterLine) else file
 
     val rows = dataLines.mapPartitionsWithIndex({
       case (split, iter) => {
-        val escapeVal = if(escape == null) '\\' else escape.charValue()
+        val escapeVal = if (csvParsingOpts.escapeChar == null) '\\'
+        else csvParsingOpts.escapeChar.charValue()
         val commentChar: Char = if (comment == null) '\0' else comment
 
         new BulkCsvReader(iter, split,
-          headers = header, fieldSep = delimiter,
-          quote = quote, escape = escapeVal, commentMarker = commentChar)
+          headers = header, fieldSep = csvParsingOpts.delimiter,
+          quote = csvParsingOpts.quoteChar, escape = escapeVal,
+          ignoreLeadingSpace = csvParsingOpts.ignoreLeadingWhitespace,
+          ignoreTrailingSpace = csvParsingOpts.ignoreTrailingWhitespace,
+          commentMarker = commentChar)
       }
     }, true)
 
     rows
   }
 
-  private def parseCSV(
-      iter: Iterator[String],
-      csvFormat: CSVFormat): Iterator[Array[String]] = {
+  private def parseCSV(iter: Iterator[String],
+                       csvFormat: CSVFormat): Iterator[Array[String]] = {
     iter.flatMap { line =>
       try {
         val records = CSVParser.parse(line, csvFormat).getRecords
@@ -233,7 +279,7 @@ case class CsvRelation protected[spark] (
             + s" to INSERT OVERWRITE a CSV table:\n${e.toString}")
       }
       // Write the data. We assume that schema isn't changed, and we won't update it.
-      data.saveAsCsvFile(location, Map("delimiter" -> delimiter.toString))
+      data.saveAsCsvFile(location, Map("delimiter" -> csvParsingOpts.delimiter.toString))
     } else {
       sys.error("CSV tables only support INSERT OVERWRITE for now.")
     }
diff --git a/src/main/scala/com/databricks/spark/csv/DefaultSource.scala b/src/main/scala/com/databricks/spark/csv/DefaultSource.scala
index 939b0d9..ff0cad7 100755
--- a/src/main/scala/com/databricks/spark/csv/DefaultSource.scala
+++ b/src/main/scala/com/databricks/spark/csv/DefaultSource.scala
@@ -16,9 +16,10 @@
 package com.databricks.spark.csv
 
 import org.apache.hadoop.fs.Path
-import org.apache.spark.sql.{DataFrame, SaveMode, SQLContext}
+
 import org.apache.spark.sql.sources._
 import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.{DataFrame, SQLContext, SaveMode}
 import com.databricks.spark.csv.util.{ParserLibs, TextFile, TypeCast}
 
 /**
@@ -45,9 +46,9 @@ class DefaultSource
    * Parameters have to include 'path' and optionally 'delimiter', 'quote', and 'header'
    */
   override def createRelation(
-      sqlContext: SQLContext,
-      parameters: Map[String, String],
-      schema: StructType) = {
+    sqlContext: SQLContext,
+    parameters: Map[String, String],
+    schema: StructType) = {
     val path = checkPath(parameters)
 
     val delimiter = TypeCast.toChar(parameters.getOrElse("delimiter", ","))
@@ -89,10 +90,10 @@ class DefaultSource
 
     val parserLib = parameters.getOrElse("parserLib", ParserLibs.DEFAULT)
     val ignoreLeadingWhiteSpace = parameters.getOrElse("ignoreLeadingWhiteSpace", "false")
-    val ignoreLeadingWhiteSpaceFlag = if(ignoreLeadingWhiteSpace == "false") {
+    val ignoreLeadingWhiteSpaceFlag = if (ignoreLeadingWhiteSpace == "false") {
       false
-    } else if(ignoreLeadingWhiteSpace == "true") {
-      if(!ParserLibs.isUnivocityLib(parserLib)) {
+    } else if (ignoreLeadingWhiteSpace == "true") {
+      if (!ParserLibs.isUnivocityLib(parserLib)) {
         throw new Exception("Ignore whitesspace supported for Univocity parser only")
       }
       true
@@ -100,10 +101,10 @@ class DefaultSource
       throw new Exception("Ignore white space flag can be true or false")
     }
     val ignoreTrailingWhiteSpace = parameters.getOrElse("ignoreTrailingWhiteSpace", "false")
-    val ignoreTrailingWhiteSpaceFlag = if(ignoreTrailingWhiteSpace == "false") {
+    val ignoreTrailingWhiteSpaceFlag = if (ignoreTrailingWhiteSpace == "false") {
       false
-    } else if(ignoreTrailingWhiteSpace == "true") {
-      if(!ParserLibs.isUnivocityLib(parserLib)) {
+    } else if (ignoreTrailingWhiteSpace == "true") {
+      if (!ParserLibs.isUnivocityLib(parserLib)) {
         throw new Exception("Ignore whitespace supported for the Univocity parser only")
       }
       true
@@ -113,36 +114,52 @@ class DefaultSource
 
     val charset = parameters.getOrElse("charset", TextFile.DEFAULT_CHARSET.name())
     // TODO validate charset?
-
     val inferSchema = parameters.getOrElse("inferSchema", "false")
-    val inferSchemaFlag = if(inferSchema == "false") {
+    val inferSchemaFlag = if (inferSchema == "false") {
      false
-    } else if(inferSchema == "true") {
+    } else if (inferSchema == "true") {
       true
     } else {
       throw new Exception("Infer schema flag can be true or false")
     }
+    val lineParsingOpts = LineParsingOpts(parameters)
+    val realNumParsingOpts = RealNumberParsingOpts(parameters)
+    val intNumParsingOpts = IntNumberParsingOpts(parameters)
+    val stringParsingOpts = StringParsingOpts(parameters)
+
+    val csvParsingOpts = if (!parameters.exists { case (k, v) =>
+      k.startsWith("csvParsingOpts.")
+    }) {
+      CSVParsingOpts(delimiter = delimiter,
+        quoteChar = quoteChar,
+        escapeChar = escapeChar,
+        ignoreLeadingWhitespace = ignoreLeadingWhiteSpaceFlag,
+        ignoreTrailingWhitespace = ignoreTrailingWhiteSpaceFlag)
+    } else {
+      CSVParsingOpts(parameters)
+    }
+
     CsvRelation(path,
-      headerFlag,
-      delimiter,
-      quoteChar,
-      escapeChar,
-      commentChar,
-      parseMode,
-      parserLib,
-      ignoreLeadingWhiteSpaceFlag,
-      ignoreTrailingWhiteSpaceFlag,
-      schema,
-      charset,
-      inferSchemaFlag)(sqlContext)
+      useHeader = headerFlag,
+      csvParsingOpts = csvParsingOpts,
+      lineExceptionPolicy = lineParsingOpts,
+      realNumOpts = realNumParsingOpts,
+      intNumOpts = intNumParsingOpts,
+      stringParsingOpts = stringParsingOpts,
+      parseMode = parseMode,
+      parserLib = parserLib,
+      userSchema = schema,
+      comment = commentChar,
+      charset = charset,
+      inferCsvSchema = inferSchemaFlag)(sqlContext)
   }
 
   override def createRelation(
-      sqlContext: SQLContext,
-      mode: SaveMode,
-      parameters: Map[String, String],
-      data: DataFrame): BaseRelation = {
+    sqlContext: SQLContext,
+    mode: SaveMode,
+    parameters: Map[String, String],
+    data: DataFrame): BaseRelation = {
     val path = checkPath(parameters)
     val filesystemPath = new Path(path)
     val fs = filesystemPath.getFileSystem(sqlContext.sparkContext.hadoopConfiguration)
diff --git a/src/main/scala/com/databricks/spark/csv/ParsingOptions.scala b/src/main/scala/com/databricks/spark/csv/ParsingOptions.scala
new file mode 100644
index 0000000..b0849fc
--- /dev/null
+++ b/src/main/scala/com/databricks/spark/csv/ParsingOptions.scala
@@ -0,0 +1,241 @@
+// scalastyle:off
+/*
+ * Copyright 2015 Ayasdi Inc
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+// scalastyle:on
+
+package com.databricks.spark.csv
+
+import scala.collection.immutable.HashSet
+
+/**
+ * Action to take when malformed lines are found in a CSV file
+ */
+object LineExceptionPolicy {
+
+  sealed trait EnumVal
+
+  /**
+   * ignore the malformed line and continue
+   */
+  case object Ignore extends EnumVal
+
+  /**
+   * stop parsing and abort
+   */
+  case object Abort extends EnumVal
+
+  /**
+   * if fields are missing in a line, fill in the blanks
+   */
+  case object Fill extends EnumVal
+
+}
+
+object ParsingOptions {
+  val defaultNullStrings = HashSet("", "NULL", "N/A", "null", "n/a")
+  val defaultNaNStrings = HashSet("NaN", "nan")
+  val defaultInfPosString = HashSet("+Inf", "Inf", "Infinity", "+Infinity", "inf", "+inf")
+  val defaultInfNegString = HashSet("-Inf", "-inf", "-Infinity")
+
+  private[csv] def delimitedStringToSet(str: String) = {
+    str.split(",").toSet
+  }
+}
+
+/**
+ * Options to control parsing of real numbers i.e. the types Float and Double
+ * @param nanStrings NaNs
+ * @param infPosStrings positive infinity
+ * @param infNegStrings negative infinity
+ * @param nullStrings nulls
+ * @param enable false to not apply these options, default true
+ */
+case class RealNumberParsingOpts(
+    var nanStrings: Set[String] = ParsingOptions.defaultNaNStrings,
+    var infPosStrings: Set[String] = ParsingOptions.defaultInfPosString,
+    var infNegStrings: Set[String] = ParsingOptions.defaultInfNegString,
+    var nullStrings: Set[String] = ParsingOptions.defaultNullStrings,
+    var enable: Boolean = true)
+
+/**
+ * Options to control parsing of integral numbers i.e. the types Int and Long
+ * @param nullStrings nulls
+ * @param enable false to not apply these options, default true
+ */
+case class IntNumberParsingOpts(var nullStrings: Set[String] = ParsingOptions.defaultNullStrings,
+                                var enable: Boolean = true)
+
+/**
+ * Options to control parsing of strings
+ * @param emptyStringReplace replace an empty string with this string
+ * @param nullStrings nulls
+ */
+case class StringParsingOpts(var emptyStringReplace: String = "",
+                             var nullStrings: Set[String] = ParsingOptions.defaultNullStrings)
+
+/**
+ * Options to handle exceptions while parsing a line
+ * @param badLinePolicy abort, ignore line or fill with fillValue when not enough fields are parsed
+ * @param fillValue if line exception policy is to fill in the blanks, use this value to fill
+ */
+case class LineParsingOpts(
+    var badLinePolicy: LineExceptionPolicy.EnumVal = LineExceptionPolicy.Fill,
+    var fillValue: String = null)
+
+/**
+ * CSV parsing options
+ * @param delimiter fields are separated by this character
+ * @param quoteChar fields containing delimiters, other special chars are quoted using this
+ *                  character e.g. "this is a comma ,"
+ * @param escapeChar if a quote character appears in a field, it is escaped using this
+ *                   e.g. "this is a quote \""
+ * @param ignoreLeadingWhitespace ignore white space before a field
+ * @param ignoreTrailingWhitespace ignore white space after a field
+ * @param numParts number of partitions to use in sc.textFile()
+ */
+case class CSVParsingOpts(var delimiter: Character = ',',
+                          var quoteChar: Character = '"',
+                          var escapeChar: Character = '\\',
+                          var ignoreLeadingWhitespace: Boolean = true,
+                          var ignoreTrailingWhitespace: Boolean = true,
+                          var numParts: Int = 0)
+
+/**
+ * builds a [[RealNumberParsingOpts]] instance from "text"
+ * realNumParsingOpts.{nans, infs, -infs, nulls, enable} are supported
+ */
+object RealNumberParsingOpts {
+  val prefix = "realNumParsingOpts."
+
+  def apply(opts: Map[String, String]): RealNumberParsingOpts = {
+    val build = RealNumberParsingOpts()
+    for (opt <- opts if opt._1.startsWith(prefix)) {
+      (opt._1.stripPrefix(prefix), opt._2) match {
+        case ("nans", value: String) =>
+          build.nanStrings = ParsingOptions.delimitedStringToSet(value)
+        case ("infs", value: String) =>
+          build.infPosStrings = ParsingOptions.delimitedStringToSet(value)
+        case ("-infs", value: String) =>
+          build.infNegStrings = ParsingOptions.delimitedStringToSet(value)
+        case ("nulls", value: String) =>
+          build.nullStrings = ParsingOptions.delimitedStringToSet(value)
+        case ("enable", value: String) => build.enable = value.toBoolean
+        case _ => throw new IllegalArgumentException(s"Unknown option $opt")
+      }
+    }
+
+    build
+  }
+}
+
+/**
+ * builds an [[IntNumberParsingOpts]] instance from "text"
+ * intNumParsingOpts.{nulls, enable} are supported
+ */
+object IntNumberParsingOpts {
+  val prefix = "intNumParsingOpts."
+
+  def apply(opts: Map[String, String]): IntNumberParsingOpts = {
+    val build = IntNumberParsingOpts()
+    for (opt <- opts if opt._1.startsWith(prefix)) {
+      (opt._1.stripPrefix(prefix), opt._2) match {
+        case ("nulls", value: String) =>
+          build.nullStrings = ParsingOptions.delimitedStringToSet(value)
+        case ("enable", value: String) => build.enable = value.toBoolean
+        case _ => throw new IllegalArgumentException(s"Unknown option $opt")
+      }
+    }
+
+    build
+  }
+}
+
+/**
+ * builds a [[StringParsingOpts]] instance from "text"
+ * stringParsingOpts.{nulls, emptyStringReplace} are supported
+ */
+object StringParsingOpts {
+  val prefix = "stringParsingOpts."
+
+  def apply(opts: Map[String, String]): StringParsingOpts = {
+    val build = StringParsingOpts()
+    for (opt <- opts if opt._1.startsWith(prefix)) {
+      (opt._1.stripPrefix(prefix), opt._2) match {
+        case ("nulls", value: String) =>
+          build.nullStrings = ParsingOptions.delimitedStringToSet(value)
+        case ("emptyStringReplace", value: String) => build.emptyStringReplace = value
+        case _ => throw new IllegalArgumentException(s"Unknown option $opt")
+      }
+    }
+
+    build
+  }
+}
+
+/**
+ * builds a [[LineParsingOpts]] instance from "text"
+ * lineParsingOpts.{badLinePolicy, fillValue} are supported
+ * lineParsingOpts.badLinePolicy can be one of fill, ignore or abort
+ */
+object LineParsingOpts {
+  val prefix = "lineParsingOpts."
+
+  def apply(opts: Map[String, String]): LineParsingOpts = {
+    val build = LineParsingOpts()
+    for (opt <- opts if opt._1.startsWith(prefix)) {
+      (opt._1.stripPrefix(prefix), opt._2) match {
+        case ("badLinePolicy", value: String) =>
+          build.badLinePolicy = value.toLowerCase match {
+            case "fill" => LineExceptionPolicy.Fill
+            case "ignore" => LineExceptionPolicy.Ignore
+            case "abort" => LineExceptionPolicy.Abort
+            case _ => throw new IllegalArgumentException(
+              s"Unknown line exception policy $value")
+          }
+        case ("fillValue", value: String) => build.fillValue = value
+        case _ => throw new IllegalArgumentException(s"Unknown option $opt")
+      }
+    }
+
+    build
+  }
+}
+
+/**
+ * builds a [[CSVParsingOpts]] instance from "text"
+ * csvParsingOpts.{delimiter, quote, escape, ignoreLeadingSpace, ignoreTrailingSpace,
+ * numParts} are supported
+ */
+object CSVParsingOpts {
+  val prefix = "csvParsingOpts."
+
+  def apply(opts: Map[String, String]): CSVParsingOpts = {
+    val build = CSVParsingOpts()
+    for (opt <- opts if opt._1.startsWith(prefix)) {
+      (opt._1.stripPrefix(prefix), opt._2) match {
+        case ("delimiter", value: String) => build.delimiter = value.charAt(0)
+        case ("quote", value: String) => build.quoteChar = value.charAt(0)
+        case ("escape", value: String) => build.escapeChar = value.charAt(0)
+        case ("ignoreLeadingSpace", value: String) =>
+          build.ignoreLeadingWhitespace = value.toBoolean
+        case ("ignoreTrailingSpace", value: String) =>
+          build.ignoreTrailingWhitespace = value.toBoolean
+        case ("numParts", value: String) => build.numParts = value.toInt
+        case _ => throw new IllegalArgumentException(s"Unknown option $opt")
+      }
+    }
+
+    build
+  }
+}
diff --git a/src/main/scala/com/databricks/spark/csv/package.scala b/src/main/scala/com/databricks/spark/csv/package.scala
index 2812df5..414f8c0 100755
--- a/src/main/scala/com/databricks/spark/csv/package.scala
+++ b/src/main/scala/com/databricks/spark/csv/package.scala
@@ -15,6 +15,9 @@
  */
 package com.databricks.spark
 
+import scala.collection.mutable
+import scala.collection.mutable.ArrayBuffer
+
 import org.apache.commons.csv.CSVFormat
 import org.apache.hadoop.io.compress.CompressionCodec
 
@@ -39,17 +42,18 @@ package object csv {
       ignoreTrailingWhiteSpace: Boolean = false,
       charset: String = TextFile.DEFAULT_CHARSET.name(),
       inferSchema: Boolean = false) = {
+    val csvParsingOpts = CSVParsingOpts(delimiter = delimiter,
+      quoteChar = quote,
+      escapeChar = escape,
+      ignoreLeadingWhitespace = ignoreLeadingWhiteSpace,
+      ignoreTrailingWhitespace = ignoreTrailingWhiteSpace)
     val csvRelation = CsvRelation(
       location = filePath,
       useHeader = useHeader,
-      delimiter = delimiter,
-      quote = quote,
-      escape = escape,
+      csvParsingOpts = csvParsingOpts,
       comment = comment,
       parseMode = mode,
       parserLib = parserLib,
-      ignoreLeadingWhiteSpace = ignoreLeadingWhiteSpace,
-      ignoreTrailingWhiteSpace = ignoreTrailingWhiteSpace,
       charset = charset,
       inferCsvSchema = inferSchema)(sqlContext)
     sqlContext.baseRelationToDataFrame(csvRelation)
@@ -62,17 +66,19 @@ package object csv {
       ignoreTrailingWhiteSpace: Boolean = false,
       charset: String = TextFile.DEFAULT_CHARSET.name(),
       inferSchema: Boolean = false) = {
+    val csvParsingOpts = CSVParsingOpts(delimiter = '\t',
+      quoteChar = '"',
+      escapeChar = '\\',
+      ignoreLeadingWhitespace = ignoreLeadingWhiteSpace,
+      ignoreTrailingWhitespace = ignoreTrailingWhiteSpace)
+
     val csvRelation = CsvRelation(
       location = filePath,
       useHeader = useHeader,
-      delimiter = '\t',
-      quote = '"',
-      escape = '\\',
+      csvParsingOpts = csvParsingOpts,
       comment = '#',
       parseMode = "PERMISSIVE",
       parserLib = parserLib,
-      ignoreLeadingWhiteSpace = ignoreLeadingWhiteSpace,
-      ignoreTrailingWhiteSpace = ignoreTrailingWhiteSpace,
       charset = charset,
       inferCsvSchema = inferSchema)(sqlContext)
     sqlContext.baseRelationToDataFrame(csvRelation)
@@ -85,7 +91,9 @@ package object csv {
    * Saves DataFrame as csv files. By default uses ',' as delimiter, and includes header line.
    */
   def saveAsCsvFile(path: String, parameters: Map[String, String] = Map(),
-      compressionCodec: Class[_ <: CompressionCodec] = null): Unit = {
+      compressionCodec: Class[_ <: CompressionCodec] = null,
+      sparseColInfo: mutable.Map[String, mutable.Map[String, Int]] = null): Unit = {
+    // TODO(hossein): For nested types, we may want to perform special work
     val delimiter = parameters.getOrElse("delimiter", ",")
 
     val delimiterChar = if (delimiter.length == 1) {
@@ -126,8 +134,33 @@ package object csv {
     }
 
     val generateHeader = parameters.getOrElse("header", "false").toBoolean
+
+    val isSparse: Array[Boolean] = dataFrame.columns.flatMap { colName: String =>
+      if (sparseColInfo != null && sparseColInfo.contains(colName)) {
+        Array.fill(sparseColInfo(colName).size)(true)
+      } else {
+        Array(false)
+      }
+    }
+
+    def makeHeader: String = {
+      val hs = dataFrame.columns.flatMap { colName: String =>
+        if (sparseColInfo.contains(colName)) {
+          sparseColInfo(colName).toSeq.sortBy(_._2).map(_._1)
+        } else {
+          Seq(colName)
+        }
+      }
+      csvFormat.format(hs: _*)
+    }
+
     val header = if (generateHeader) {
-      csvFormat.format(dataFrame.columns.map(_.asInstanceOf[AnyRef]):_*)
+      if (sparseColInfo == null) {
+        csvFormat.format(dataFrame.columns.map(_.asInstanceOf[AnyRef]): _*)
+      } else {
+        makeHeader
+      }
     } else {
       "" // There is no need to generate header in this case
     }
@@ -151,7 +184,17 @@ package object csv {
 
         override def next: String = {
           if(!iter.isEmpty) {
-            val row = csvFormat.format(iter.next.toSeq.map(_.asInstanceOf[AnyRef]):_*)
+            def makeCsvRow(inFields: Seq[Any]): String = {
+              val fields = inFields.zipWithIndex.flatMap { case (f, i) =>
+                if (isSparse(i)) {
+                  f.asInstanceOf[ArrayBuffer[Any]]
+                } else {
+                  ArrayBuffer(f)
+                }
+              }
+              csvFormat.format(fields.map(_.asInstanceOf[AnyRef]): _*)
+            }
+            val row = makeCsvRow(iter.next.toSeq)
             if (firstRow) {
               firstRow = false
               header + csvFormat.getRecordSeparator() + row
diff --git a/src/main/scala/com/databricks/spark/csv/util/TextFile.scala b/src/main/scala/com/databricks/spark/csv/util/TextFile.scala
index 3b8d6c6..ab4f2e3 100644
--- a/src/main/scala/com/databricks/spark/csv/util/TextFile.scala
+++ b/src/main/scala/com/databricks/spark/csv/util/TextFile.scala
@@ -17,17 +17,23 @@ package com.databricks.spark.csv.util
 
 import java.nio.charset.Charset
 
-import org.apache.hadoop.io.{Text, LongWritable}
+import org.apache.hadoop.io.{LongWritable, Text}
 import org.apache.hadoop.mapred.TextInputFormat
+
 import org.apache.spark.SparkContext
 import org.apache.spark.rdd.RDD
 
 private[csv] object TextFile {
   val DEFAULT_CHARSET = Charset.forName("UTF-8")
 
-  def withCharset(context: SparkContext, location: String, charset: String): RDD[String] = {
+  def withCharset(context: SparkContext, location: String,
+                  charset: String, numParts: Int = 0): RDD[String] = {
     if (Charset.forName(charset) == DEFAULT_CHARSET) {
-      context.textFile(location)
+      if (numParts == 0) {
+        context.textFile(location)
+      } else {
+        context.textFile(location, numParts)
+      }
     } else {
       // can't pass a Charset object here cause its not serializable
       // TODO: maybe use mapPartitions instead?
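Note for reviewers: the PR wires the same settings through two entry points — the typed option case classes (`CSVParsingOpts`, `LineParsingOpts`, ...) on the `CsvParser` builder, and flat `"<prefix>.<key>"` string pairs parsed by the companion `apply(Map[String, String])` methods. A minimal usage sketch of both forms follows; the `sqlContext` value and the `data.csv` path are placeholders, not part of this change:

```scala
import org.apache.spark.sql.SQLContext
import com.databricks.spark.csv._

def loadWithOpts(sqlContext: SQLContext) = {
  // Typed form: build the option objects directly.
  val typed = new CsvParser()
    .withUseHeader(true)
    .withCsvParsingOpts(CSVParsingOpts(delimiter = '|', numParts = 4))
    .withLineParsingOpts(LineParsingOpts(
      badLinePolicy = LineExceptionPolicy.Fill, fillValue = "0"))
    .csvFile(sqlContext, "data.csv") // placeholder path

  // String-keyed form: the same settings spelled as "<prefix>.<key>" pairs,
  // routed through the CSVParsingOpts(optMap) / LineParsingOpts(optMap) parsers above.
  val untyped = new CsvParser()
    .withUseHeader(true)
    .withOpts(Map(
      "csvParsingOpts.delimiter" -> "|",
      "csvParsingOpts.numParts" -> "4",
      "lineParsingOpts.badLinePolicy" -> "fill",
      "lineParsingOpts.fillValue" -> "0"))
    .csvFile(sqlContext, "data.csv")

  (typed, untyped)
}
```

For the DDL path, `DefaultSource.createRelation` above only falls back to the legacy flat keys (`delimiter`, `quote`, `escape`, ...) when no `csvParsingOpts.*` key is present; once any dotted key appears, the whole `CSVParsingOpts` is built from the dotted keys alone.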
diff --git a/src/test/resources/cars-alternative.csv b/src/test/resources/cars-alternative.csv
index 2c1285a..baff5bc 100644
--- a/src/test/resources/cars-alternative.csv
+++ b/src/test/resources/cars-alternative.csv
@@ -1,5 +1,5 @@
 year|make|model|comment
-'2012'|'Tesla'|'S'| 'No comment'
+ '2012' |'Tesla'|'S'| 'No comment'
 
-1997|Ford|E350|'Go get one now they are going fast'
-2015|Chevy|Volt
+ 1997|Ford|E350|'Go get one now they are going fast'
+2015 |Chevy|Volt
diff --git a/src/test/resources/numbers.csv b/src/test/resources/numbers.csv
new file mode 100644
index 0000000..072247d
--- /dev/null
+++ b/src/test/resources/numbers.csv
@@ -0,0 +1,11 @@
+double, float, int, long
+1.0, 1.0, 1, 1
+ , , ,
+NaN, NaN, 3, 3
+NULL, null, N/A, n/a
+Inf, Inf, 5, 5
+-Inff, -Inff, 6, 6
+Infinity, Infinity, 7, 7
+
+
+
diff --git a/src/test/scala/com/databricks/spark/csv/CsvFastSuite.scala b/src/test/scala/com/databricks/spark/csv/CsvFastSuite.scala
index 0fa4e9e..6a41d2e 100644
--- a/src/test/scala/com/databricks/spark/csv/CsvFastSuite.scala
+++ b/src/test/scala/com/databricks/spark/csv/CsvFastSuite.scala
@@ -19,14 +19,16 @@ import java.io.File
 import java.nio.charset.UnsupportedCharsetException
 
 import org.apache.hadoop.io.compress.GzipCodec
-import org.apache.spark.sql.Row
-import org.apache.spark.sql.test._
+import org.scalatest.FunSuite
+
 import org.apache.spark.SparkException
+import org.apache.spark.sql.catalyst.expressions.Row
+import org.apache.spark.sql.test._
 import org.apache.spark.sql.types._
-import org.scalatest.FunSuite
 
 /* Implicits */
-import TestSQLContext._
+
+import org.apache.spark.sql.test.TestSQLContext._
 
 class CsvFastSuite extends FunSuite {
   val carsFile = "src/test/resources/cars.csv"
@@ -36,6 +38,7 @@ class CsvFastSuite extends FunSuite {
   val nullNumbersFile = "src/test/resources/null-numbers.csv"
   val emptyFile = "src/test/resources/empty.csv"
   val escapeFile = "src/test/resources/escape.csv"
+  val numbersFile = "src/test/resources/numbers.csv"
   val tempEmptyDir = "target/test/empty2/"
   val commentsFile = "src/test/resources/comments.csv"
   val disableCommentsFile = "src/test/resources/disable_comments.csv"
@@ -110,7 +113,7 @@ class CsvFastSuite extends FunSuite {
       s"""
         |CREATE TEMPORARY TABLE carsTable
         |(yearMade double, makeName string, modelName string, priceTag decimal,
-        | comments string, grp string)
+        |comments string, grp string)
         |USING com.databricks.spark.csv
         |OPTIONS (path "$carsTsvFile", header "true", delimiter "\t", parserLib "univocity")
       """.stripMargin.replaceAll("\n", " "))
@@ -119,6 +122,162 @@ class CsvFastSuite extends FunSuite {
     assert(sql("SELECT makeName FROM carsTable where priceTag > 60000").collect().size === 1)
   }
 
+  test("DSL test for Line exception policy") {
+    val results = new CsvParser()
+      .withUseHeader(true)
+      .withLineParsingOpts(LineParsingOpts(badLinePolicy = LineExceptionPolicy.Fill,
+        fillValue = "fill"))
+      .withParserLib("univocity")
+      .csvFile(TestSQLContext, carsFile)
+      .collect()
+    val volt = results.filter(row => row.get(2).asInstanceOf[String] == "Volt").head
+    assert(volt.get(3).asInstanceOf[String] === "fill")
+    assert(volt.get(4).asInstanceOf[String] === "fill")
+  }
+
+  test("DSL test for CSV Parsing Opts: delimiter") {
+    val results = new CsvParser()
+      .withUseHeader(true)
+      .withCsvParsingOpts(CSVParsingOpts(delimiter = '\t'))
+      .withParserLib("univocity")
+      .csvFile(TestSQLContext, carsTsvFile)
+      .select("year")
+      .collect()
+
+    assert(results.size === numCars)
+  }
+
+  test("DSL test for CSV Parsing Opts: numParts") {
+    val rdd = new CsvParser()
+      .withUseHeader(true)
+      .withCsvParsingOpts(CSVParsingOpts(numParts = 1, delimiter = '\t'))
+      .withParserLib("univocity")
+      .csvFile(TestSQLContext, carsTsvFile)
+      .select("year")
+      .rdd
+
+    assert(rdd.partitions.size === 1)
+  }
+
+  test("DSL test for CSV Parsing Opts: quote") {
+    val results = new CsvParser()
+      .withUseHeader(true)
+      .withCsvParsingOpts(CSVParsingOpts(quoteChar = '\'', delimiter = '|', numParts = 1))
+      .withParserLib("univocity")
+      .csvFile(TestSQLContext, carsAltFile)
+      .select("year")
+      .collect()
+
+    assert(results.size === numCars)
+
+    val years = results.map(_.get(0).asInstanceOf[String])
+    assert(years === Array("2012", "1997", "2015"))
+  }
+
+  test("DSL test for CSV Parsing Opts: whitespace") {
+    var results = new CsvParser()
+      .withUseHeader(true)
+      .withCsvParsingOpts(CSVParsingOpts(delimiter = '|', numParts = 1,
+        ignoreLeadingWhitespace = false))
+      .withParserLib("univocity")
+      .csvFile(TestSQLContext, carsAltFile)
+      .select("year")
+      .collect()
+
+    assert(results.size === numCars)
+
+    var years = results.map(_.get(0).asInstanceOf[String])
+    assert(years === Array(" \'2012\'", " 1997", "2015"))
+
+    results = new CsvParser()
+      .withUseHeader(true)
+      .withCsvParsingOpts(CSVParsingOpts(delimiter = '|', numParts = 1,
+        ignoreLeadingWhitespace = false))
+      .withParserLib("univocity")
+      .csvFile(TestSQLContext, carsAltFile)
+      .select("year")
+      .collect()
+
+    assert(results.size === numCars)
+
+    years = results.map(_.get(0).asInstanceOf[String])
+    assert(years === Array(" \'2012\'", " 1997", "2015"))
+
+    results = new CsvParser()
+      .withUseHeader(true)
+      .withCsvParsingOpts(CSVParsingOpts(delimiter = '|', numParts = 1,
+        ignoreTrailingWhitespace = false))
+      .withParserLib("univocity")
+      .csvFile(TestSQLContext, carsAltFile)
+      .select("year")
+      .collect()
+
+    assert(results.size === numCars)
+
+    years = results.map(_.get(0).asInstanceOf[String])
+    assert(years === Array("\'2012\' ", "1997", "2015 "))
+  }
+
+  test("DSL test for CSV Parsing Opts: special") {
+    var results = new CsvParser()
+      .withUseHeader(true)
+      .withSchema(StructType(Seq(StructField("double", DoubleType),
+        StructField("float", FloatType),
+        StructField("int", IntegerType),
+        StructField("long", LongType))))
+      .withCsvParsingOpts(CSVParsingOpts(numParts = 1))
+      .withParserLib("univocity")
+      .csvFile(TestSQLContext, numbersFile)
+      .collect()
+
+    assert(results.size === 6)
+
+    var doubles = results.map(_.get(0))
+    assert(doubles.count(_.asInstanceOf[Double].isNaN) === 1)
+    assert(doubles.count(_.asInstanceOf[Double].isInfinite) === 2)
+    assert(doubles.count(_.asInstanceOf[Double] == 0.0) === 2)
+
+    var floats = results.map(_.get(1))
+    assert(floats.count(_.asInstanceOf[Float].isNaN) === 1)
+    assert(floats.count(_.asInstanceOf[Float].isInfinite) === 2)
+    assert(floats.count(_.asInstanceOf[Float] == 0.0) === 2)
+
+    var ints = results.map(_.get(2))
+    assert(ints.count(_.asInstanceOf[Int] == 0) === 2)
+
+    var longs = results.map(_.get(3))
+    assert(longs.count(_.asInstanceOf[Long] == 0) === 2)
+
+    results = new CsvParser()
+      .withUseHeader(true)
+      .withSchema(StructType(Seq(StructField("double", DoubleType),
+        StructField("float", FloatType),
+        StructField("int", IntegerType),
+        StructField("long", LongType))))
+      .withCsvParsingOpts(CSVParsingOpts(numParts = 1))
+      .withRealNumberParsingOpts(RealNumberParsingOpts(
+        infNegStrings = ParsingOptions.defaultInfNegString + "-Inff"))
+      .withParserLib("univocity")
+      .csvFile(TestSQLContext, numbersFile)
+      .collect()
+
+    assert(results.size === 7)
+
+    doubles = results.map(_.get(0))
+    assert(doubles.count(_.asInstanceOf[Double].isNaN) === 1)
+    assert(doubles.count(_.asInstanceOf[Double].isInfinite) === 3)
+    assert(doubles.count(_.asInstanceOf[Double] == 0.0) === 2)
+
+    floats = results.map(_.get(1))
+    assert(floats.count(_.asInstanceOf[Float].isNaN) === 1)
+    assert(floats.count(_.asInstanceOf[Float].isInfinite) === 3)
+    assert(floats.count(_.asInstanceOf[Float] == 0.0) === 2)
+
+    ints = results.map(_.get(2))
+    assert(ints.count(_.asInstanceOf[Int] == 0) === 2)
+
+    longs = results.map(_.get(3))
+    assert(longs.count(_.asInstanceOf[Long] == 0) === 2)
+  }
+
   test("DSL test for DROPMALFORMED parsing mode") {
     val results = new CsvParser()
       .withParseMode("DROPMALFORMED")
@@ -137,16 +296,15 @@ class CsvFastSuite extends FunSuite {
       .withUseHeader(true)
       .withParserLib("univocity")
 
-    val exception = intercept[SparkException]{
+    val exception = intercept[SparkException] {
       parser.csvFile(TestSQLContext, carsFile)
         .select("year")
         .collect()
     }
 
-    assert(exception.getMessage.contains("Malformed line in FAILFAST mode"))
+    assert(exception.getMessage.contains("Malformed line in FAILFAST or Abort mode"))
   }
 
-
   test("DSL test with alternative delimiter and quote") {
     val results = new CsvParser()
       .withDelimiter('|')
@@ -199,7 +357,7 @@ class CsvFastSuite extends FunSuite {
       .collect()
 
     assert(results.slice(0, numCars).toSeq.map(_(0).asInstanceOf[String]) ==
-      Seq("'2012'", "1997", "2015"))
+      Seq(" '2012' ", " 1997", "2015 "))
   }
 
   test("DDL test with alternative delimiter and quote") {
@@ -227,22 +385,22 @@ class CsvFastSuite extends FunSuite {
   }
 
   test("DDL test with empty file") {
-    sql(s"""
-      |CREATE TEMPORARY TABLE carsTable
-      |(yearMade double, makeName string, modelName string, comments string, grp string)
-      |USING com.databricks.spark.csv
-      |OPTIONS (path "$emptyFile", header "false", parserLib "univocity")
+    sql( s"""
+        |CREATE TEMPORARY TABLE carsTable
+        |(yearMade double, makeName string, modelName string, comments string, grp string)
+        |USING com.databricks.spark.csv
+        |OPTIONS (path "$emptyFile", header "false", parserLib "univocity")
     """.stripMargin.replaceAll("\n", " "))
 
     assert(sql("SELECT count(*) FROM carsTable").collect().head(0) === 0)
   }
 
   test("DDL test with schema") {
-    sql(s"""
-      |CREATE TEMPORARY TABLE carsTable
-      |(yearMade double, makeName string, modelName string, comments string, grp string)
-      |USING com.databricks.spark.csv
-      |OPTIONS (path "$carsFile", header "true", parserLib "univocity")
+    sql( s"""
+        |CREATE TEMPORARY TABLE carsTable
+        |(yearMade double, makeName string, modelName string, comments string, grp string)
+        |USING com.databricks.spark.csv
+        |OPTIONS (path "$carsFile", header "true", parserLib "univocity")
     """.stripMargin.replaceAll("\n", " "))
 
     assert(sql("SELECT makeName FROM carsTable").collect().size === numCars)
@@ -269,11 +427,11 @@ class CsvFastSuite extends FunSuite {
       |USING com.databricks.spark.csv
       |OPTIONS (path "$carsFile", header "true", parserLib "univocity")
     """.stripMargin.replaceAll("\n", " "))
-    sql(s"""
-      |CREATE TEMPORARY TABLE carsTableEmpty
-      |(yearMade double, makeName string, modelName string, comments string, grp string)
-      |USING com.databricks.spark.csv
-      |OPTIONS (path "$tempEmptyDir", header "false", parserLib "univocity")
+    sql( s"""
+        |CREATE TEMPORARY TABLE carsTableEmpty
+        |(yearMade double, makeName string, modelName string, comments string, grp string)
+        |USING com.databricks.spark.csv
+        |OPTIONS (path "$tempEmptyDir", header "false", parserLib "univocity")
     """.stripMargin.replaceAll("\n", " "))
assert(sql("SELECT * FROM carsTableIO").collect().size === numCars) @@ -294,9 +452,9 @@ class CsvFastSuite extends FunSuite { val copyFilePath = tempEmptyDir + "cars-copy.csv" val cars = TestSQLContext.csvFile(carsFile, parserLib = "univocity") - cars.saveAsCsvFile(copyFilePath, Map("header" -> "true")) + cars.saveAsCsvFile(copyFilePath, Map("header" -> "true", "headerPerPart" -> "false")) - val carsCopy = TestSQLContext.csvFile(copyFilePath + "/") + val carsCopy = TestSQLContext.csvFile(copyFilePath + "/", parserLib = "univocity") assert(carsCopy.count == cars.count) assert(carsCopy.collect.map(_.toString).toSet == cars.collect.map(_.toString).toSet) @@ -309,7 +467,7 @@ class CsvFastSuite extends FunSuite { val copyFilePath = tempEmptyDir + "cars-copy.csv" val cars = TestSQLContext.csvFile(carsFile, parserLib = "univocity") - cars.saveAsCsvFile(copyFilePath, Map("header" -> "true"), classOf[GzipCodec]) + cars.saveAsCsvFile(copyFilePath, Map("header" -> "true", "headerPerPart" -> "false"), classOf[GzipCodec]) val carsCopy = TestSQLContext.csvFile(copyFilePath + "/") @@ -324,7 +482,7 @@ class CsvFastSuite extends FunSuite { val copyFilePath = tempEmptyDir + "cars-copy.csv" val cars = TestSQLContext.csvFile(carsFile, parserLib = "univocity") - cars.saveAsCsvFile(copyFilePath, Map("header" -> "true", "quote" -> "\"")) + cars.saveAsCsvFile(copyFilePath, Map("header" -> "true", "headerPerPart" -> "false", "quote" -> "\"")) val carsCopy = TestSQLContext.csvFile(copyFilePath + "/", parserLib = "univocity") @@ -339,7 +497,8 @@ class CsvFastSuite extends FunSuite { val copyFilePath = tempEmptyDir + "cars-copy.csv" val cars = TestSQLContext.csvFile(carsFile) - cars.saveAsCsvFile(copyFilePath, Map("header" -> "true", "quote" -> "!")) + cars.saveAsCsvFile(copyFilePath, + Map("header" -> "true", "headerPerPart" -> "false", "quote" -> "!")) val carsCopy = TestSQLContext.csvFile(copyFilePath + "/", quote = '!', parserLib = "univocity") @@ -353,8 +512,9 @@ class CsvFastSuite extends FunSuite { new File(tempEmptyDir).mkdirs() val copyFilePath = tempEmptyDir + "escape-copy.csv" - val escape = TestSQLContext.csvFile(escapeFile, escape='|', quote='"') - escape.saveAsCsvFile(copyFilePath, Map("header" -> "true", "quote" -> "\"")) + val escape = TestSQLContext.csvFile(escapeFile, escape = '|', quote = '"') + escape.saveAsCsvFile(copyFilePath, + Map("header" -> "true", "headerPerPart" -> "false", "quote" -> "\"")) val escapeCopy = TestSQLContext.csvFile(copyFilePath + "/", parserLib = "univocity") @@ -457,4 +617,5 @@ class CsvFastSuite extends FunSuite { } -} \ No newline at end of file +} + diff --git a/src/test/scala/com/databricks/spark/csv/CsvSuite.scala b/src/test/scala/com/databricks/spark/csv/CsvSuite.scala index 97fa63d..0214fd6 100755 --- a/src/test/scala/com/databricks/spark/csv/CsvSuite.scala +++ b/src/test/scala/com/databricks/spark/csv/CsvSuite.scala @@ -100,6 +100,18 @@ class CsvSuite extends FunSuite { assert(sql("SELECT year FROM carsTable").collect().size === numCars) } +//FIXME: the dot doesn't seem to work. 
+// test("DDL test with tab separated file, using newer options") { +// sql( +// s""" +// |CREATE TEMPORARY TABLE carsTable +// |USING com.databricks.spark.csv +// |OPTIONS (path "$carsTsvFile", header "true", "csvParsingOpts.delimiter" "\t") +// """.stripMargin.replaceAll("\n", " ")) +// +// assert(sql("SELECT year FROM carsTable").collect().size === numCars) +// } + test("DDL test parsing decimal type") { sql( s""" @@ -137,7 +149,7 @@ class CsvSuite extends FunSuite { .collect() } - assert(exception.getMessage.contains("Malformed line in FAILFAST mode")) + assert(exception.getMessage.contains("Malformed line in FAILFAST or Abort mode")) } @@ -153,6 +165,20 @@ class CsvSuite extends FunSuite { assert(results.size === numCars) } + test("DSL test with alternative delimiter and quote using simple options API") { + val optMap = Map("csvParsingOpts.quote" -> "'", + "csvParsingOpts.delimiter" -> "|" + ) + + val results = new CsvParser().withOpts(optMap) + .withUseHeader(true) + .csvFile(TestSQLContext, carsAltFile) + .select("year") + .collect() + + assert(results.size === numCars) + } + test("DSL test with alternative delimiter and quote using sparkContext.csvFile") { val results = TestSQLContext.csvFile(carsAltFile, useHeader = true, delimiter = '|', quote = '\'') @@ -178,7 +204,7 @@ class CsvSuite extends FunSuite { .collect() assert(results.slice(0, numCars).toSeq.map(_(0).asInstanceOf[String]) == - Seq("'2012'", "1997", "2015")) + Seq(" '2012' ", " 1997", "2015 ")) } test("DDL test with alternative delimiter and quote") { @@ -230,7 +256,6 @@ class CsvSuite extends FunSuite { |USING com.databricks.spark.csv |OPTIONS (path "$carsFile", header "true") """.stripMargin.replaceAll("\n", " ")) - assert(sql("SELECT makeName FROM carsTable").collect().size === numCars) assert(sql("SELECT avg(yearMade) FROM carsTable where grp = '' group by grp") .collect().head(0) === 2004.5) @@ -441,4 +466,4 @@ class CsvSuite extends FunSuite { assert(results.toSeq.map(_.toSeq) == expected) } -} \ No newline at end of file +} diff --git a/src/test/scala/com/databricks/spark/csv/OptionsSuite.scala b/src/test/scala/com/databricks/spark/csv/OptionsSuite.scala new file mode 100644 index 0000000..fc0bee2 --- /dev/null +++ b/src/test/scala/com/databricks/spark/csv/OptionsSuite.scala @@ -0,0 +1,84 @@ +// scalastyle:off +/* + * Copyright 2015 Ayasdi Inc + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +// scalastyle:on +package com.databricks.spark.csv + +import scala.collection.immutable.HashSet + +import org.scalatest.FunSuite + +class OptionsSuite extends FunSuite { + test("csv opts") { + val optMap = Map("csvParsingOpts.delimiter" -> "|", + "csvParsingOpts.quote" -> "[", + "csvParsingOpts.ignoreLeadingSpace" -> "false", + "csvParsingOpts.ignoreTrailingSpace" -> "true", + "csvParsingOpts.escape" -> ":", + "csvParsingOpts.numParts" -> "5") + val opts = CSVParsingOpts(optMap) + + assert(opts.delimiter === '|') + assert(opts.escapeChar === ':') + assert(opts.ignoreLeadingWhitespace === false) + assert(opts.ignoreTrailingWhitespace === true) + assert(opts.numParts === 5) + assert(opts.quoteChar === '[') + } + + test("line opts") { + val optMap = Map("lineParsingOpts.badLinePolicy" -> "abort", + "lineParsingOpts.fillValue" -> "duh") + val opts = LineParsingOpts(optMap) + + assert(opts.fillValue === "duh") + assert(opts.badLinePolicy === LineExceptionPolicy.Abort) + } + + test("string opts") { + val optMap = Map("stringParsingOpts.nulls" -> "abcd,efg", + "stringParsingOpts.emptyStringReplace" -> "") + val opts = StringParsingOpts(optMap) + + assert(opts.nullStrings === HashSet("abcd", "efg")) + assert(opts.emptyStringReplace === "") + } + + test("int opts") { + val optMap = Map("intNumParsingOpts.nulls" -> "abcd,efg", + "intNumParsingOpts.enable" -> "false") + val opts = IntNumberParsingOpts(optMap) + + assert(opts.nullStrings === HashSet("abcd", "efg")) + assert(opts.enable === false) + } + + test("real opts") { + val optMap = Map("realNumParsingOpts.nulls" -> "abcd,efg", + "realNumParsingOpts.enable" -> "false", + "realNumParsingOpts.nans" -> "NaN,nan", + "realNumParsingOpts.infs" -> "iinnff,IINNFF", + "realNumParsingOpts.-infs" -> "minusInf") + val opts = RealNumberParsingOpts(optMap) + + assert(opts.nullStrings === HashSet("abcd", "efg")) + assert(opts.nanStrings === HashSet("NaN", "nan")) + assert(opts.infPosStrings === HashSet("iinnff", "IINNFF")) + assert(opts.infNegStrings === HashSet("minusInf")) + assert(opts.enable === false) + } + +} diff --git a/src/test/scala/com/databricks/spark/csv/util/TypeCastSuite.scala b/src/test/scala/com/databricks/spark/csv/util/TypeCastSuite.scala index e17bed5..2fb5de4 100644 --- a/src/test/scala/com/databricks/spark/csv/util/TypeCastSuite.scala +++ b/src/test/scala/com/databricks/spark/csv/util/TypeCastSuite.scala @@ -34,6 +34,19 @@ class TypeCastSuite extends FunSuite { } } + test("Can parse special") { + val strValues = Seq("NaN", "Infinity", "-Infinity") + val doubleChecks: Seq[Double => Boolean] = Seq(x => x.isNaN, x => x.isPosInfinity, x => x.isNegInfinity) + val floatChecks: Seq[Float => Boolean] = Seq(x => x.isNaN, x => x.isPosInfinity, x => x.isNegInfinity) + + strValues.zip(doubleChecks).foreach { case (strVal, checker) => + assert(checker(TypeCast.castTo(strVal, DoubleType).asInstanceOf[Double])) + } + strValues.zip(floatChecks).foreach { case (strVal, checker) => + assert(checker(TypeCast.castTo(strVal, FloatType).asInstanceOf[Float])) + } + } + test("Can parse escaped characters") { assert(TypeCast.toChar("""\t""") === '\t') assert(TypeCast.toChar("""\r""") === '\r')
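One behavioural nuance worth calling out, since the "special" test above relies on it: the typed API can extend a default marker set (e.g. `ParsingOptions.defaultInfNegString + "-Inff"`), while the string-keyed form replaces the whole set — `delimitedStringToSet` builds a fresh `Set` from the comma-separated value. A sketch of the two, under the same assumptions as the suites above (`TestSQLContext` and the `numbers.csv` fixture):

```scala
import org.apache.spark.sql.test.TestSQLContext
import com.databricks.spark.csv._

// Typed: add "-Inff" to the default negative-infinity markers.
val extended = new CsvParser()
  .withUseHeader(true)
  .withRealNumberParsingOpts(RealNumberParsingOpts(
    infNegStrings = ParsingOptions.defaultInfNegString + "-Inff"))
  .withParserLib("univocity")
  .csvFile(TestSQLContext, "src/test/resources/numbers.csv")

// String-keyed: "-infs" replaces the defaults, so every marker that
// should still be accepted has to be listed explicitly.
val replaced = new CsvParser()
  .withUseHeader(true)
  .withOpts(Map("realNumParsingOpts.-infs" -> "-Inf,-inf,-Infinity,-Inff"))
  .withParserLib("univocity")
  .csvFile(TestSQLContext, "src/test/resources/numbers.csv")
```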