diff --git a/src/main/scala/com/databricks/spark/csv/CsvParser.scala b/src/main/scala/com/databricks/spark/csv/CsvParser.scala
index 370ab3c..8ce0f0b 100644
--- a/src/main/scala/com/databricks/spark/csv/CsvParser.scala
+++ b/src/main/scala/com/databricks/spark/csv/CsvParser.scala
@@ -15,7 +15,6 @@
  */
 package com.databricks.spark.csv
 
-
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.{DataFrame, SQLContext}
 import org.apache.spark.sql.types.StructType
@@ -117,12 +116,12 @@ class CsvParser extends Serializable {
     this
   }
 
-  /** Returns a Schema RDD for the given CSV path. */
-  @throws[RuntimeException]
-  def csvFile(sqlContext: SQLContext, path: String): DataFrame = {
-    val relation: CsvRelation = CsvRelation(
-      () => TextFile.withCharset(sqlContext.sparkContext, path, charset),
-      Some(path),
+  /** Returns a csvRelation instance based on the state definition of csv parser. */
+  private[csv] def csvRelation(sqlContext: SQLContext, csvRDD: RDD[String],
+      path: Option[String]): CsvRelation = {
+    CsvRelation(
+      () => csvRDD,
+      path,
       useHeader,
       delimiter,
       quote,
@@ -137,27 +136,17 @@ class CsvParser extends Serializable {
       inferSchema,
       codec,
       nullValue)(sqlContext)
+  }
+  /** Returns a Schema RDD for the given CSV path. */
+  @throws[RuntimeException]
+  def csvFile(sqlContext: SQLContext, path: String): DataFrame = {
+    val relation: CsvRelation = csvRelation(sqlContext,
+      TextFile.withCharset(sqlContext.sparkContext, path, charset), Some(path))
     sqlContext.baseRelationToDataFrame(relation)
   }
 
   def csvRdd(sqlContext: SQLContext, csvRDD: RDD[String]): DataFrame = {
-    val relation: CsvRelation = CsvRelation(
-      () => csvRDD,
-      None,
-      useHeader,
-      delimiter,
-      quote,
-      escape,
-      comment,
-      parseMode,
-      parserLib,
-      ignoreLeadingWhiteSpace,
-      ignoreTrailingWhiteSpace,
-      treatEmptyValuesAsNulls,
-      schema,
-      inferSchema,
-      codec,
-      nullValue)(sqlContext)
+    val relation: CsvRelation = csvRelation(sqlContext, csvRDD, None)
     sqlContext.baseRelationToDataFrame(relation)
   }
 }
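
The refactor leaves the public entry points csvFile and csvRdd unchanged and only funnels both through the new private csvRelation helper, so existing callers keep working. Below is a minimal usage sketch (not part of the patch): the local Spark setup and the "cars.csv" path are hypothetical placeholders, and the builder calls follow the spark-csv CsvParser API.

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.SQLContext
import com.databricks.spark.csv.CsvParser

// Hypothetical local Spark 1.x setup, for illustration only.
val sc = new SparkContext(new SparkConf().setAppName("csv-parser-sketch").setMaster("local[*]"))
val sqlContext = new SQLContext(sc)

// One parser configuration drives both entry points; after this patch both
// internally build the same CsvRelation via the private csvRelation helper.
val parser = new CsvParser()
  .withUseHeader(true)
  .withInferSchema(true)

// File-based entry point: loads the path with the parser's configured charset.
val fromFile = parser.csvFile(sqlContext, "cars.csv")

// RDD-based entry point: parses CSV lines that are already in memory.
val lines = sc.parallelize(Seq("make,model", "Tesla,S", "Ford,Fiesta"))
val fromRdd = parser.csvRdd(sqlContext, lines)

fromFile.printSchema()
fromRdd.show()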