Skip to content
Open
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion src/main/java/com/databricks/spark/csv/JavaCsvParser.java
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ public class JavaCsvParser {
private Boolean useHeader = true;
private Character delimiter = ',';
private Character quote = '"';
private Character escape = '\\';
private StructType schema = null;

public JavaCsvParser withUseHeader(Boolean flag) {
Expand All @@ -44,6 +45,11 @@ public JavaCsvParser withQuoteChar(Character quote) {
return this;
}

public JavaCsvParser withEscapeChar(Character escape) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This file is removed in master. No need for these changes.

this.escape = escape;
return this;
}

public JavaCsvParser withSchema(StructType schema) {
this.schema = schema;
return this;
Expand All @@ -52,7 +58,7 @@ public JavaCsvParser withSchema(StructType schema) {
/** Returns a Schema RDD for the given CSV path. */
public DataFrame csvFile(SQLContext sqlContext, String path) {
CsvRelation relation = new
CsvRelation(path, useHeader, delimiter, quote, schema, sqlContext);
CsvRelation(path, useHeader, delimiter, quote, escape, schema, sqlContext);
return sqlContext.baseRelationToDataFrame(relation);
}
}
8 changes: 7 additions & 1 deletion src/main/scala/com/databricks/spark/csv/CsvParser.scala
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ class CsvParser {
private var useHeader: Boolean = true
private var delimiter: Character = ','
private var quote: Character = '"'
private var escape: Character = null
private var schema: StructType = null

def withUseHeader(flag: Boolean): CsvParser = {
Expand All @@ -43,14 +44,19 @@ class CsvParser {
this
}

def withEscapeChar(escape: Character): CsvParser = {
this.escape = escape
this
}

def withSchema(schema: StructType): CsvParser = {
this.schema = schema
this
}

/** Returns a Schema RDD for the given CSV path. */
def csvFile(sqlContext: SQLContext, path: String): DataFrame = {
val relation: CsvRelation = CsvRelation(path, useHeader, delimiter, quote, schema)(sqlContext)
val relation: CsvRelation = CsvRelation(path, useHeader, delimiter, quote, escape, schema)(sqlContext)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is more than 100 characters.

sqlContext.baseRelationToDataFrame(relation)
}

Expand Down
5 changes: 4 additions & 1 deletion src/main/scala/com/databricks/spark/csv/CsvRelation.scala
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,8 @@ case class CsvRelation protected[spark] (
location: String,
useHeader: Boolean,
delimiter: Char,
quote: Char,
quote: Character,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Use Char

escape: Character,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Use Char

userSchema: StructType = null)(@transient val sqlContext: SQLContext)
extends BaseRelation with TableScan with InsertableRelation {

Expand All @@ -53,6 +54,7 @@ case class CsvRelation protected[spark] (
val csvFormat = CSVFormat.DEFAULT
.withDelimiter(delimiter)
.withQuote(quote)
.withEscape(escape)
.withSkipHeaderRecord(false)
.withHeader(fieldNames: _*)

Expand All @@ -78,6 +80,7 @@ case class CsvRelation protected[spark] (
val csvFormat = CSVFormat.DEFAULT
.withDelimiter(delimiter)
.withQuote(quote)
.withEscape(escape)
.withSkipHeaderRecord(false)
val firstRow = CSVParser.parse(firstLine, csvFormat).getRecords.head.toList
val header = if (useHeader) {
Expand Down
9 changes: 8 additions & 1 deletion src/main/scala/com/databricks/spark/csv/DefaultSource.scala
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,13 @@ class DefaultSource
throw new Exception("Quotation cannot be more than one character.")
}

val escape = parameters.getOrElse("escape", "\\")
val escapeChar = if (escape.length == 1) {
escape.charAt(0)
} else {
throw new Exception("Escape cannot be more than one character.")
}

val useHeader = parameters.getOrElse("header", "true")
val headerFlag = if (useHeader == "true") {
true
Expand All @@ -71,7 +78,7 @@ class DefaultSource
throw new Exception("Header flag can be true or false")
}

CsvRelation(path, headerFlag, delimiterChar, quoteChar, schema)(sqlContext)
CsvRelation(path, headerFlag, delimiterChar, quoteChar, escapeChar, schema)(sqlContext)
}

override def createRelation(
Expand Down
50 changes: 30 additions & 20 deletions src/main/scala/com/databricks/spark/csv/package.scala
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,13 @@ package com.databricks.spark

import org.apache.spark.sql.{SQLContext, DataFrame}

import org.apache.commons.csv.CSVFormat;
import org.apache.commons.csv.CSVPrinter;

import java.io.StringWriter;

import scala.collection.convert.WrapAsJava

package object csv {

/**
Expand All @@ -28,7 +35,8 @@ package object csv {
location = filePath,
useHeader = true,
delimiter = ',',
quote = '"')(sqlContext)
quote = '"',
escape = '\\')(sqlContext)
sqlContext.baseRelationToDataFrame(csvRelation)
}

Expand All @@ -37,36 +45,38 @@ package object csv {
location = filePath,
useHeader = true,
delimiter = '\t',
quote = '"')(sqlContext)
quote = '"',
escape = '\\')(sqlContext)
sqlContext.baseRelationToDataFrame(csvRelation)
}
}

implicit class CsvSchemaRDD(dataFrame: DataFrame) {
def saveAsCsvFile(path: String, parameters: Map[String, String] = Map()): Unit = {
// TODO(hossein): For nested types, we may want to perform special work
val delimiter = parameters.getOrElse("delimiter", ",")
val delimiter = parameters.getOrElse("delimiter", ",").charAt(0)
val quote = parameters.getOrElse("quote", "\"").charAt(0)
val escape = parameters.getOrElse("escape", "\\").charAt(0)
val generateHeader = parameters.getOrElse("header", "false").toBoolean
val header = if (generateHeader) {
dataFrame.columns.map(c => s""""$c"""").mkString(delimiter)
} else {
"" // There is no need to generate header in this case
}
val strRDD = dataFrame.rdd.mapPartitions { iter =>
new Iterator[String] {
var firstRow: Boolean = generateHeader
val header = dataFrame.columns

override def hasNext = iter.hasNext
var firstRow: Boolean = generateHeader
val csvFileFormat = CSVFormat.DEFAULT
.withDelimiter(delimiter)
.withQuote(quote)
.withEscape(escape)

override def next: String = {
if (firstRow) {
firstRow = false
header + "\n" + iter.next.mkString(delimiter)
} else {
iter.next.mkString(delimiter)
}
}
val strRDD = dataFrame.rdd.mapPartitions { iter =>
var firstRow: Boolean = generateHeader
val newIter = iter.map(_.toSeq.toArray)
val stringWriter = new StringWriter()
val csvPrinter = new CSVPrinter(stringWriter, csvFileFormat)
if (firstRow) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this condition is always true for each partition. Note that previously we were making the check inside the Iterator.next().

firstRow = false
csvPrinter.printRecord(header:_*)
}
csvPrinter.printRecords(WrapAsJava.asJavaIterable(newIter.toIterable))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This can be inefficient because it seems to me we are traversing the iterator multiple times for each partition. I know CSVPrinter API is not very flexible here, but can we avoid multiple traversals?

Iterator(stringWriter.toString)
}
strRDD.saveAsTextFile(path)
}
Expand Down
4 changes: 4 additions & 0 deletions src/test/resources/family-cars.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
year,make,model,comment
2012,VW,Touran,"The ideal car for \"families\" and all their \"bags\", \"boxes\" and \"barbecues\""
2013,Seat,Alhambra,"It is a great \"family\" car, for big families"
2014,Peugeot,5008,"It is a fine \"family\" car"
33 changes: 33 additions & 0 deletions src/test/scala/com/databricks/spark/csv/CsvSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,10 @@ import TestSQLContext._
class CsvSuite extends FunSuite {
val carsFile = "src/test/resources/cars.csv"
val carsAltFile = "src/test/resources/cars-alternative.csv"
val familyCarsFile = "src/test/resources/family-cars.csv"
val emptyFile = "src/test/resources/empty.csv"
val tempEmptyDir = "target/test/empty/"
val tempFamilyCarsDir = "target/test/family-cars"

test("DSL test") {
val results = TestSQLContext
Expand Down Expand Up @@ -61,6 +63,37 @@ class CsvSuite extends FunSuite {
assert(results.size === 2)
}

test("DSL test read write with escape") {
// Parse a CSV file that uses \ as the escape character
val results = new CsvParser()
.withEscapeChar('\\')
.csvFile(TestSQLContext, familyCarsFile)
// Check that the file was parsed as expected
val firstComment1 = results
.select("comment")
.collect()
.head
.getString(0)
assert(firstComment1 === "The ideal car for \"families\" and all their \"bags\", \"boxes\" and \"barbecues\"")

TestUtils.deleteRecursively(new File(tempFamilyCarsDir))
// Save the DataFrame without providing an escape character (the default is ")
results.saveAsCsvFile(tempFamilyCarsDir, Map("header" -> "true"))
// Check that the generated file is well-formed
val rawData = TestSQLContext.sparkContext.textFile(tempFamilyCarsDir).toArray
assert(rawData.contains("2012,VW,Touran,\"The ideal car for \"\"families\"\" and all their \"\"bags\"\", \"\"boxes\"\" and \"\"barbecues\"\"\""))

// Check that the generated file parses back correctly
val results2 = new CsvParser()
.csvFile(TestSQLContext, tempFamilyCarsDir)
val firstComment2 = results2
.select("comment")
.collect()
.head
.getString(0)
assert(firstComment2 === "The ideal car for \"families\" and all their \"bags\", \"boxes\" and \"barbecues\"")
}

test("DDL test with alternative delimiter and quote") {
sql(
s"""
Expand Down