diff --git a/.github/PULL_REQUEST_TEMPLATE b/.github/PULL_REQUEST_TEMPLATE index 989e95ccd0135..0e41cf1826453 100644 --- a/.github/PULL_REQUEST_TEMPLATE +++ b/.github/PULL_REQUEST_TEMPLATE @@ -2,11 +2,9 @@ (Please fill in changes proposed in this fix) - ## How was this patch tested? (Please explain how this patch was tested. E.g. unit tests, integration tests, manual tests) - - (If this patch involves UI changes, please attach a screenshot; otherwise, remove this) +Please review https://cwiki.apache.org/confluence/display/SPARK/Contributing+to+Spark before opening a pull request. diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R index 75861d5de7092..801d2ed4e7500 100644 --- a/R/pkg/R/DataFrame.R +++ b/R/pkg/R/DataFrame.R @@ -55,6 +55,19 @@ setMethod("initialize", "SparkDataFrame", function(.Object, sdf, isCached) { .Object }) +#' Set options/mode and then return the write object +#' @noRd +setWriteOptions <- function(write, path = NULL, mode = "error", ...) { + options <- varargsToStrEnv(...) + if (!is.null(path)) { + options[["path"]] <- path + } + jmode <- convertToJSaveMode(mode) + write <- callJMethod(write, "mode", jmode) + write <- callJMethod(write, "options", options) + write +} + #' @export #' @param sdf A Java object reference to the backing Scala DataFrame #' @param isCached TRUE if the SparkDataFrame is cached @@ -727,6 +740,8 @@ setMethod("toJSON", #' #' @param x A SparkDataFrame #' @param path The directory where the file is saved +#' @param mode one of 'append', 'overwrite', 'error', 'ignore' save mode (it is 'error' by default) +#' @param ... additional argument(s) passed to the method. #' #' @family SparkDataFrame functions #' @rdname write.json @@ -743,8 +758,9 @@ setMethod("toJSON", #' @note write.json since 1.6.0 setMethod("write.json", signature(x = "SparkDataFrame", path = "character"), - function(x, path) { + function(x, path, mode = "error", ...) { write <- callJMethod(x@sdf, "write") + write <- setWriteOptions(write, mode = mode, ...) invisible(callJMethod(write, "json", path)) }) @@ -755,6 +771,8 @@ setMethod("write.json", #' #' @param x A SparkDataFrame #' @param path The directory where the file is saved +#' @param mode one of 'append', 'overwrite', 'error', 'ignore' save mode (it is 'error' by default) +#' @param ... additional argument(s) passed to the method. #' #' @family SparkDataFrame functions #' @aliases write.orc,SparkDataFrame,character-method @@ -771,8 +789,9 @@ setMethod("write.json", #' @note write.orc since 2.0.0 setMethod("write.orc", signature(x = "SparkDataFrame", path = "character"), - function(x, path) { + function(x, path, mode = "error", ...) { write <- callJMethod(x@sdf, "write") + write <- setWriteOptions(write, mode = mode, ...) invisible(callJMethod(write, "orc", path)) }) @@ -783,6 +802,8 @@ setMethod("write.orc", #' #' @param x A SparkDataFrame #' @param path The directory where the file is saved +#' @param mode one of 'append', 'overwrite', 'error', 'ignore' save mode (it is 'error' by default) +#' @param ... additional argument(s) passed to the method. #' #' @family SparkDataFrame functions #' @rdname write.parquet @@ -800,8 +821,9 @@ setMethod("write.orc", #' @note write.parquet since 1.6.0 setMethod("write.parquet", signature(x = "SparkDataFrame", path = "character"), - function(x, path) { + function(x, path, mode = "error", ...) { write <- callJMethod(x@sdf, "write") + write <- setWriteOptions(write, mode = mode, ...) 
invisible(callJMethod(write, "parquet", path)) }) @@ -825,6 +847,8 @@ setMethod("saveAsParquetFile", #' #' @param x A SparkDataFrame #' @param path The directory where the file is saved +#' @param mode one of 'append', 'overwrite', 'error', 'ignore' save mode (it is 'error' by default) +#' @param ... additional argument(s) passed to the method. #' #' @family SparkDataFrame functions #' @aliases write.text,SparkDataFrame,character-method @@ -841,8 +865,9 @@ setMethod("saveAsParquetFile", #' @note write.text since 2.0.0 setMethod("write.text", signature(x = "SparkDataFrame", path = "character"), - function(x, path) { + function(x, path, mode = "error", ...) { write <- callJMethod(x@sdf, "write") + write <- setWriteOptions(write, mode = mode, ...) invisible(callJMethod(write, "text", path)) }) @@ -2637,15 +2662,9 @@ setMethod("write.df", if (is.null(source)) { source <- getDefaultSqlSource() } - jmode <- convertToJSaveMode(mode) - options <- varargsToEnv(...) - if (!is.null(path)) { - options[["path"]] <- path - } write <- callJMethod(df@sdf, "write") write <- callJMethod(write, "format", source) - write <- callJMethod(write, "mode", jmode) - write <- callJMethod(write, "options", options) + write <- setWriteOptions(write, path = path, mode = mode, ...) write <- handledCallJMethod(write, "save") }) @@ -2701,7 +2720,7 @@ setMethod("saveAsTable", source <- getDefaultSqlSource() } jmode <- convertToJSaveMode(mode) - options <- varargsToEnv(...) + options <- varargsToStrEnv(...) write <- callJMethod(df@sdf, "write") write <- callJMethod(write, "format", source) diff --git a/R/pkg/R/SQLContext.R b/R/pkg/R/SQLContext.R index baa87824beb91..0d6a229e63455 100644 --- a/R/pkg/R/SQLContext.R +++ b/R/pkg/R/SQLContext.R @@ -328,6 +328,7 @@ setMethod("toDF", signature(x = "RDD"), #' It goes through the entire dataset once to determine the schema. #' #' @param path Path of file to read. A vector of multiple paths is allowed. +#' @param ... additional external data source specific named properties. #' @return SparkDataFrame #' @rdname read.json #' @export @@ -341,11 +342,13 @@ setMethod("toDF", signature(x = "RDD"), #' @name read.json #' @method read.json default #' @note read.json since 1.6.0 -read.json.default <- function(path) { +read.json.default <- function(path, ...) { sparkSession <- getSparkSession() + options <- varargsToStrEnv(...) # Allow the user to have a more flexible definiton of the text file path paths <- as.list(suppressWarnings(normalizePath(path))) read <- callJMethod(sparkSession, "read") + read <- callJMethod(read, "options", options) sdf <- callJMethod(read, "json", paths) dataFrame(sdf) } @@ -405,16 +408,19 @@ jsonRDD <- function(sqlContext, rdd, schema = NULL, samplingRatio = 1.0) { #' Loads an ORC file, returning the result as a SparkDataFrame. #' #' @param path Path of file to read. +#' @param ... additional external data source specific named properties. #' @return SparkDataFrame #' @rdname read.orc #' @export #' @name read.orc #' @note read.orc since 2.0.0 -read.orc <- function(path) { +read.orc <- function(path, ...) { sparkSession <- getSparkSession() + options <- varargsToStrEnv(...) 
# Allow the user to have a more flexible definiton of the ORC file path path <- suppressWarnings(normalizePath(path)) read <- callJMethod(sparkSession, "read") + read <- callJMethod(read, "options", options) sdf <- callJMethod(read, "orc", path) dataFrame(sdf) } @@ -430,11 +436,13 @@ read.orc <- function(path) { #' @name read.parquet #' @method read.parquet default #' @note read.parquet since 1.6.0 -read.parquet.default <- function(path) { +read.parquet.default <- function(path, ...) { sparkSession <- getSparkSession() + options <- varargsToStrEnv(...) # Allow the user to have a more flexible definiton of the Parquet file path paths <- as.list(suppressWarnings(normalizePath(path))) read <- callJMethod(sparkSession, "read") + read <- callJMethod(read, "options", options) sdf <- callJMethod(read, "parquet", paths) dataFrame(sdf) } @@ -467,6 +475,7 @@ parquetFile <- function(x, ...) { #' Each line in the text file is a new row in the resulting SparkDataFrame. #' #' @param path Path of file to read. A vector of multiple paths is allowed. +#' @param ... additional external data source specific named properties. #' @return SparkDataFrame #' @rdname read.text #' @export @@ -479,11 +488,13 @@ parquetFile <- function(x, ...) { #' @name read.text #' @method read.text default #' @note read.text since 1.6.1 -read.text.default <- function(path) { +read.text.default <- function(path, ...) { sparkSession <- getSparkSession() + options <- varargsToStrEnv(...) # Allow the user to have a more flexible definiton of the text file path paths <- as.list(suppressWarnings(normalizePath(path))) read <- callJMethod(sparkSession, "read") + read <- callJMethod(read, "options", options) sdf <- callJMethod(read, "text", paths) dataFrame(sdf) } @@ -779,7 +790,7 @@ read.df.default <- function(path = NULL, source = NULL, schema = NULL, na.string "in 'spark.sql.sources.default' configuration by default.") } sparkSession <- getSparkSession() - options <- varargsToEnv(...) + options <- varargsToStrEnv(...) if (!is.null(path)) { options[["path"]] <- path } @@ -842,7 +853,7 @@ loadDF <- function(x = NULL, ...) { #' @note createExternalTable since 1.4.0 createExternalTable.default <- function(tableName, path = NULL, source = NULL, ...) { sparkSession <- getSparkSession() - options <- varargsToEnv(...) + options <- varargsToStrEnv(...) if (!is.null(path)) { options[["path"]] <- path } diff --git a/R/pkg/R/context.R b/R/pkg/R/context.R index fe2f3e3d10a9b..438d77a388f0e 100644 --- a/R/pkg/R/context.R +++ b/R/pkg/R/context.R @@ -87,6 +87,10 @@ objectFile <- function(sc, path, minPartitions = NULL) { #' in the list are split into \code{numSlices} slices and distributed to nodes #' in the cluster. #' +#' If the size of the serialized slices is larger than spark.r.maxAllocationLimit (200MB by default), the +#' function writes them to disk and sends the file name to the JVM. The number of slices may also be +#' increased so that each slice stays under that limit.
+#' #' @param sc SparkContext to use #' @param coll collection to parallelize #' @param numSlices number of partitions to create in the RDD @@ -120,6 +124,11 @@ parallelize <- function(sc, coll, numSlices = 1) { coll <- as.list(coll) } + sizeLimit <- getMaxAllocationLimit(sc) + objectSize <- object.size(coll) + + # For large objects we make sure the size of each slice is also smaller than sizeLimit + numSlices <- max(numSlices, ceiling(objectSize / sizeLimit)) if (numSlices > length(coll)) numSlices <- length(coll) @@ -130,12 +139,44 @@ parallelize <- function(sc, coll, numSlices = 1) { # 2-tuples of raws serializedSlices <- lapply(slices, serialize, connection = NULL) - jrdd <- callJStatic("org.apache.spark.api.r.RRDD", - "createRDDFromArray", sc, serializedSlices) + # The RPC backend cannot handle arguments larger than 2GB (INT_MAX) + # If serialized data is safely less than that threshold we send it over the RPC channel. + # Otherwise, we write it to a file and send the file name + if (objectSize < sizeLimit) { + jrdd <- callJStatic("org.apache.spark.api.r.RRDD", "createRDDFromArray", sc, serializedSlices) + } else { + fileName <- writeToTempFile(serializedSlices) + jrdd <- tryCatch(callJStatic( + "org.apache.spark.api.r.RRDD", "createRDDFromFile", sc, fileName, as.integer(numSlices)), + finally = { + file.remove(fileName) + }) + } RDD(jrdd, "byte") } +getMaxAllocationLimit <- function(sc) { + conf <- callJMethod(sc, "getConf") + as.numeric( + callJMethod(conf, + "get", + "spark.r.maxAllocationLimit", + toString(.Machine$integer.max / 10) # Default to a safe value: 200MB + )) +} + +writeToTempFile <- function(serializedSlices) { + fileName <- tempfile() + conn <- file(fileName, "wb") + for (slice in serializedSlices) { + writeBin(as.integer(length(slice)), conn, endian = "big") + writeBin(slice, conn, endian = "big") + } + close(conn) + fileName +} + #' Include this specified package on all workers #' #' This function can be used to include a package on all workers before the diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index 90a02e2778310..810aea9017743 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -651,15 +651,17 @@ setGeneric("write.jdbc", function(x, url, tableName, mode = "error", ...) { #' @rdname write.json #' @export -setGeneric("write.json", function(x, path) { standardGeneric("write.json") }) +setGeneric("write.json", function(x, path, ...) { standardGeneric("write.json") }) #' @rdname write.orc #' @export -setGeneric("write.orc", function(x, path) { standardGeneric("write.orc") }) +setGeneric("write.orc", function(x, path, ...) { standardGeneric("write.orc") }) #' @rdname write.parquet #' @export -setGeneric("write.parquet", function(x, path) { standardGeneric("write.parquet") }) +setGeneric("write.parquet", function(x, path, ...) { + standardGeneric("write.parquet") +}) #' @rdname write.parquet #' @export @@ -667,7 +669,7 @@ setGeneric("saveAsParquetFile", function(x, path) { standardGeneric("saveAsParqu #' @rdname write.text #' @export -setGeneric("write.text", function(x, path) { standardGeneric("write.text") }) +setGeneric("write.text", function(x, path, ...) { standardGeneric("write.text") }) #' @rdname schema #' @export diff --git a/R/pkg/R/utils.R b/R/pkg/R/utils.R index e69666453480c..fa8bb0f79ce80 100644 --- a/R/pkg/R/utils.R +++ b/R/pkg/R/utils.R @@ -334,6 +334,28 @@ varargsToEnv <- function(...) { env } +# Utility function to capture the varargs into an environment object, but with all values converted +# into strings.
+varargsToStrEnv <- function(...) { + pairs <- list(...) + env <- new.env() + for (name in names(pairs)) { + value <- pairs[[name]] + if (!(is.logical(value) || is.numeric(value) || is.character(value) || is.null(value))) { + stop(paste0("Unsupported type for ", name, " : ", class(value), + ". Supported types are logical, numeric, character and NULL.")) + } + if (is.logical(value)) { + env[[name]] <- tolower(as.character(value)) + } else if (is.null(value)) { + env[[name]] <- value + } else { + env[[name]] <- as.character(value) + } + } + env +} + getStorageLevel <- function(newLevel = c("DISK_ONLY", "DISK_ONLY_2", "MEMORY_AND_DISK", diff --git a/R/pkg/inst/tests/testthat/test_mllib.R b/R/pkg/inst/tests/testthat/test_mllib.R index a1eaaf20916a2..c99315726a22c 100644 --- a/R/pkg/inst/tests/testthat/test_mllib.R +++ b/R/pkg/inst/tests/testthat/test_mllib.R @@ -481,6 +481,16 @@ test_that("spark.naiveBayes", { expect_error(m <- e1071::naiveBayes(Survived ~ ., data = t1), NA) expect_equal(as.character(predict(m, t1[1, ])), "Yes") } + + # Test numeric response variable + t1$NumericSurvived <- ifelse(t1$Survived == "No", 0, 1) + t2 <- t1[-4] + df <- suppressWarnings(createDataFrame(t2)) + m <- spark.naiveBayes(df, NumericSurvived ~ ., smoothing = 0.0) + s <- summary(m) + expect_equal(as.double(s$apriori[1, 1]), 0.5833333, tolerance = 1e-6) + expect_equal(sum(s$apriori), 1) + expect_equal(as.double(s$tables[1, "Age_Adult"]), 0.5714286, tolerance = 1e-6) }) test_that("spark.survreg", { diff --git a/R/pkg/inst/tests/testthat/test_sparkSQL.R b/R/pkg/inst/tests/testthat/test_sparkSQL.R index f5ab601f274fe..af81d0586e0a6 100644 --- a/R/pkg/inst/tests/testthat/test_sparkSQL.R +++ b/R/pkg/inst/tests/testthat/test_sparkSQL.R @@ -208,6 +208,17 @@ test_that("create DataFrame from RDD", { unsetHiveContext() }) +test_that("createDataFrame uses files for large objects", { + # To simulate a large file scenario, we set spark.r.maxAllocationLimit to a smaller value + conf <- callJMethod(sparkSession, "conf") + callJMethod(conf, "set", "spark.r.maxAllocationLimit", "100") + df <- createDataFrame(iris) + + # Resetting the conf back to default value + callJMethod(conf, "set", "spark.r.maxAllocationLimit", toString(.Machine$integer.max / 10)) + expect_equal(dim(df), dim(iris)) +}) + test_that("read/write csv as DataFrame", { csvPath <- tempfile(pattern = "sparkr-test", fileext = ".csv") mockLinesCsv <- c("year,make,model,comment,blank", @@ -256,6 +267,23 @@ test_that("read/write csv as DataFrame", { unlink(csvPath2) }) +test_that("Support other types for options", { + csvPath <- tempfile(pattern = "sparkr-test", fileext = ".csv") + mockLinesCsv <- c("year,make,model,comment,blank", + "\"2012\",\"Tesla\",\"S\",\"No comment\",", + "1997,Ford,E350,\"Go get one now they are going fast\",", + "2015,Chevy,Volt", + "NA,Dummy,Placeholder") + writeLines(mockLinesCsv, csvPath) + + csvDf <- read.df(csvPath, "csv", header = "true", inferSchema = "true") + expected <- read.df(csvPath, "csv", header = TRUE, inferSchema = TRUE) + expect_equal(collect(csvDf), collect(expected)) + + expect_error(read.df(csvPath, "csv", header = TRUE, maxColumns = 3)) + unlink(csvPath) +}) + test_that("convert NAs to null type in DataFrames", { rdd <- parallelize(sc, list(list(1L, 2L), list(NA, 4L))) df <- createDataFrame(rdd, list("a", "b")) @@ -497,6 +525,19 @@ test_that("read/write json files", { unlink(jsonPath3) }) +test_that("read/write json files - compression option", { + df <- read.df(jsonPath, "json") + + jsonPath <- tempfile(pattern = 
"jsonPath", fileext = ".json") + write.json(df, jsonPath, compression = "gzip") + jsonDF <- read.json(jsonPath) + expect_is(jsonDF, "SparkDataFrame") + expect_equal(count(jsonDF), count(df)) + expect_true(length(list.files(jsonPath, pattern = ".gz")) > 0) + + unlink(jsonPath) +}) + test_that("jsonRDD() on a RDD with json string", { sqlContext <- suppressWarnings(sparkRSQL.init(sc)) rdd <- parallelize(sc, mockLines) @@ -1786,6 +1827,21 @@ test_that("read/write ORC files", { unsetHiveContext() }) +test_that("read/write ORC files - compression option", { + setHiveContext(sc) + df <- read.df(jsonPath, "json") + + orcPath2 <- tempfile(pattern = "orcPath2", fileext = ".orc") + write.orc(df, orcPath2, compression = "ZLIB") + orcDF <- read.orc(orcPath2) + expect_is(orcDF, "SparkDataFrame") + expect_equal(count(orcDF), count(df)) + expect_true(length(list.files(orcPath2, pattern = ".zlib.orc")) > 0) + + unlink(orcPath2) + unsetHiveContext() +}) + test_that("read/write Parquet files", { df <- read.df(jsonPath, "json") # Test write.df and read.df @@ -1817,6 +1873,23 @@ test_that("read/write Parquet files", { unlink(parquetPath4) }) +test_that("read/write Parquet files - compression option/mode", { + df <- read.df(jsonPath, "json") + tempPath <- tempfile(pattern = "tempPath", fileext = ".parquet") + + # Test write.df and read.df + write.parquet(df, tempPath, compression = "GZIP") + df2 <- read.parquet(tempPath) + expect_is(df2, "SparkDataFrame") + expect_equal(count(df2), 3) + expect_true(length(list.files(tempPath, pattern = ".gz.parquet")) > 0) + + write.parquet(df, tempPath, mode = "overwrite") + df3 <- read.parquet(tempPath) + expect_is(df3, "SparkDataFrame") + expect_equal(count(df3), 3) +}) + test_that("read/write text files", { # Test write.df and read.df df <- read.df(jsonPath, "text") @@ -1838,6 +1911,19 @@ test_that("read/write text files", { unlink(textPath2) }) +test_that("read/write text files - compression option", { + df <- read.df(jsonPath, "text") + + textPath <- tempfile(pattern = "textPath", fileext = ".txt") + write.text(df, textPath, compression = "GZIP") + textDF <- read.text(textPath) + expect_is(textDF, "SparkDataFrame") + expect_equal(count(textDF), count(df)) + expect_true(length(list.files(textPath, pattern = ".gz")) > 0) + + unlink(textPath) +}) + test_that("describe() and summarize() on a DataFrame", { df <- read.json(jsonPath) stats <- describe(df, "age") @@ -2534,7 +2620,7 @@ test_that("enableHiveSupport on SparkSession", { unsetHiveContext() # if we are still here, it must be built with hive conf <- callJMethod(sparkSession, "conf") - value <- callJMethod(conf, "get", "spark.sql.catalogImplementation", "") + value <- callJMethod(conf, "get", "spark.sql.catalogImplementation") expect_equal(value, "hive") }) diff --git a/R/pkg/inst/tests/testthat/test_utils.R b/R/pkg/inst/tests/testthat/test_utils.R index 69ed5549168b1..a20254e9b3fa9 100644 --- a/R/pkg/inst/tests/testthat/test_utils.R +++ b/R/pkg/inst/tests/testthat/test_utils.R @@ -217,4 +217,13 @@ test_that("rbindRaws", { }) +test_that("varargsToStrEnv", { + strenv <- varargsToStrEnv(a = 1, b = 1.1, c = TRUE, d = "abcd") + env <- varargsToEnv(a = "1", b = "1.1", c = "true", d = "abcd") + expect_equal(strenv, env) + expect_error(varargsToStrEnv(a = list(1, "a")), + paste0("Unsupported type for a : list. 
Supported types are logical, ", + "numeric, character and NULL.")) +}) + sparkR.session.stop() diff --git a/README.md b/README.md index c77c429e577cd..dd7d0e22495b3 100644 --- a/README.md +++ b/README.md @@ -97,3 +97,8 @@ building for particular Hive and Hive Thriftserver distributions. Please refer to the [Configuration Guide](http://spark.apache.org/docs/latest/configuration.html) in the online documentation for an overview on how to configure Spark. + +## Contributing + +Please review the [Contribution to Spark](https://cwiki.apache.org/confluence/display/SPARK/Contributing+to+Spark) +wiki for information on how to get started contributing to the project. diff --git a/common/tags/src/main/java/org/apache/spark/annotation/InterfaceStability.java b/common/tags/src/main/java/org/apache/spark/annotation/InterfaceStability.java new file mode 100644 index 0000000000000..323098f69c6e1 --- /dev/null +++ b/common/tags/src/main/java/org/apache/spark/annotation/InterfaceStability.java @@ -0,0 +1,49 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.annotation; + +import java.lang.annotation.Documented; + +/** + * Annotation to inform users of how much to rely on a particular package, + * class or method not changing over time. + */ +public class InterfaceStability { + + /** + * Stable APIs that retain source and binary compatibility within a major release. + * These interfaces can change from one major release to another major release + * (e.g. from 1.0 to 2.0). + */ + @Documented + public @interface Stable {}; + + /** + * APIs that are meant to evolve towards becoming stable APIs, but are not stable APIs yet. + * Evolving interfaces can change from one feature release to another release (i.e. 2.1 to 2.2). + */ + @Documented + public @interface Evolving {}; + + /** + * Unstable APIs, with no guarantee on stability. + * Classes that are unannotated are considered Unstable. 
+ */ + @Documented + public @interface Unstable {}; +} diff --git a/core/pom.xml b/core/pom.xml index 9a4f234953a23..205bbc588be09 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -320,7 +320,7 @@ net.razorvine pyrolite - 4.9 + 4.13 net.razorvine diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java index 428ff72e71a43..7835017910232 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java @@ -145,7 +145,9 @@ private UnsafeExternalSorter( // Use getSizeAsKb (not bytes) to maintain backwards compatibility for units // this.fileBufferSizeBytes = (int) conf.getSizeAsKb("spark.shuffle.file.buffer", "32k") * 1024; this.fileBufferSizeBytes = 32 * 1024; - this.writeMetrics = taskContext.taskMetrics().shuffleWriteMetrics(); + // The spill metrics are stored in a new ShuffleWriteMetrics, and then discarded (this fixes SPARK-16827). + // TODO: Instead, separate spill metrics should be stored and reported (tracked in SPARK-3577). + this.writeMetrics = new ShuffleWriteMetrics(); if (existingInMemorySorter == null) { this.inMemSorter = new UnsafeInMemorySorter( diff --git a/core/src/main/resources/org/apache/spark/ui/static/historypage-template.html b/core/src/main/resources/org/apache/spark/ui/static/historypage-template.html index a2b3826dd324b..1fd6ef4a71253 100644 --- a/core/src/main/resources/org/apache/spark/ui/static/historypage-template.html +++ b/core/src/main/resources/org/apache/spark/ui/static/historypage-template.html @@ -59,7 +59,11 @@ Last Updated - + + + Event Log + + {{#applications}} @@ -73,6 +77,7 @@ {{duration}} {{sparkUser}} {{lastUpdated}} + Download {{/attempts}} {{/applications}} diff --git a/core/src/main/resources/org/apache/spark/ui/static/historypage.js b/core/src/main/resources/org/apache/spark/ui/static/historypage.js index c8094005c65dd..2a32e18672a22 100644 --- a/core/src/main/resources/org/apache/spark/ui/static/historypage.js +++ b/core/src/main/resources/org/apache/spark/ui/static/historypage.js @@ -133,6 +133,7 @@ $(document).ready(function() { {name: 'sixth', type: "title-numeric"}, {name: 'seventh'}, {name: 'eighth'}, + {name: 'ninth'}, ], "autoWidth": false, "order": [[ 4, "desc" ]] diff --git a/core/src/main/scala/org/apache/spark/SparkConf.scala b/core/src/main/scala/org/apache/spark/SparkConf.scala index 51a699f41d15d..c9c342df82c97 100644 --- a/core/src/main/scala/org/apache/spark/SparkConf.scala +++ b/core/src/main/scala/org/apache/spark/SparkConf.scala @@ -636,7 +636,9 @@ private[spark] object SparkConf extends Logging { "Please use spark.kryoserializer.buffer instead. The default value for " + "spark.kryoserializer.buffer.mb was previously specified as '0.064'. Fractional values " + "are no longer accepted. 
To specify the equivalent now, one may use '64k'."), - DeprecatedConfig("spark.rpc", "2.0", "Not used any more.") + DeprecatedConfig("spark.rpc", "2.0", "Not used any more."), + DeprecatedConfig("spark.scheduler.executorTaskBlacklistTime", "2.1.0", + "Please use the new blacklisting options, spark.blacklist.*") ) Map(configs.map { cfg => (cfg.key -> cfg) } : _*) diff --git a/core/src/main/scala/org/apache/spark/TaskEndReason.scala b/core/src/main/scala/org/apache/spark/TaskEndReason.scala index 42690844f9610..7ca3c103dbf5b 100644 --- a/core/src/main/scala/org/apache/spark/TaskEndReason.scala +++ b/core/src/main/scala/org/apache/spark/TaskEndReason.scala @@ -92,6 +92,16 @@ case class FetchFailed( s"FetchFailed($bmAddressString, shuffleId=$shuffleId, mapId=$mapId, reduceId=$reduceId, " + s"message=\n$message\n)" } + + /** + * Fetch failures lead to a different failure handling path: (1) we don't abort the stage after + * 4 task failures, instead we immediately go back to the stage which generated the map output, + * and regenerate the missing data. (2) we don't count fetch failures for blacklisting, since + * presumably it's not the fault of the executor where the task ran, but the executor which + * stored the data. This is especially important because we might rack up a bunch of + * fetch-failures in rapid succession, on all nodes of the cluster, due to one bad node. + */ + override def countTowardsTaskFailures: Boolean = false } /** @@ -204,6 +214,7 @@ case object TaskResultLost extends TaskFailedReason { @DeveloperApi case object TaskKilled extends TaskFailedReason { override def toErrorString: String = "TaskKilled (killed intentionally)" + override def countTowardsTaskFailures: Boolean = false } /** diff --git a/core/src/main/scala/org/apache/spark/api/r/RBackendHandler.scala b/core/src/main/scala/org/apache/spark/api/r/RBackendHandler.scala index 7d5348266bf6e..1422ef888fd4a 100644 --- a/core/src/main/scala/org/apache/spark/api/r/RBackendHandler.scala +++ b/core/src/main/scala/org/apache/spark/api/r/RBackendHandler.scala @@ -168,7 +168,7 @@ private[r] class RBackendHandler(server: RBackend) } } catch { case e: Exception => - logError(s"$methodName on $objId failed") + logError(s"$methodName on $objId failed", e) writeInt(dos, -1) // Writing the error message of the cause for the exception. This will be returned // to user in the R process. diff --git a/core/src/main/scala/org/apache/spark/api/r/RRDD.scala b/core/src/main/scala/org/apache/spark/api/r/RRDD.scala index 59c8429c80172..a1a5eb8cf55e8 100644 --- a/core/src/main/scala/org/apache/spark/api/r/RRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/r/RRDD.scala @@ -24,6 +24,7 @@ import scala.reflect.ClassTag import org.apache.spark._ import org.apache.spark.api.java.{JavaPairRDD, JavaRDD, JavaSparkContext} +import org.apache.spark.api.python.PythonRDD import org.apache.spark.broadcast.Broadcast import org.apache.spark.internal.Logging import org.apache.spark.rdd.RDD @@ -140,4 +141,16 @@ private[r] object RRDD { def createRDDFromArray(jsc: JavaSparkContext, arr: Array[Array[Byte]]): JavaRDD[Array[Byte]] = { JavaRDD.fromRDD(jsc.sc.parallelize(arr, arr.length)) } + + /** + * Create an RRDD given a temporary file name. This is used to create RRDD when parallelize is + * called on large R objects.
+ * + * @param fileName name of temporary file on driver machine + * @param parallelism number of slices defaults to 4 + */ + def createRDDFromFile(jsc: JavaSparkContext, fileName: String, parallelism: Int): + JavaRDD[Array[Byte]] = { + PythonRDD.readRDDFromFile(jsc, fileName, parallelism) + } } diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala index 80611658a1640..5c052286099f5 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala @@ -24,6 +24,7 @@ import java.security.PrivilegedExceptionAction import scala.annotation.tailrec import scala.collection.mutable.{ArrayBuffer, HashMap, Map} +import scala.util.Properties import org.apache.commons.lang3.StringUtils import org.apache.hadoop.fs.Path @@ -47,7 +48,6 @@ import org.apache.spark.deploy.rest._ import org.apache.spark.launcher.SparkLauncher import org.apache.spark.util.{ChildFirstURLClassLoader, MutableURLClassLoader, Utils} - /** * Whether to submit, kill, or request the status of an application. * The latter two operations are currently supported only for standalone and Mesos cluster modes. @@ -104,6 +104,8 @@ object SparkSubmit { /___/ .__/\_,_/_/ /_/\_\ version %s /_/ """.format(SPARK_VERSION)) + printStream.println("Using Scala %s, %s, %s".format( + Properties.versionString, Properties.javaVmName, Properties.javaVersion)) printStream.println("Branch %s".format(SPARK_BRANCH)) printStream.println("Compiled by user %s on %s".format(SPARK_BUILD_USER, SPARK_BUILD_DATE)) printStream.println("Revision %s".format(SPARK_REVISION)) diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/RestSubmissionServer.scala b/core/src/main/scala/org/apache/spark/deploy/rest/RestSubmissionServer.scala index fa55d470842b3..b30c980e95a9a 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/RestSubmissionServer.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/RestSubmissionServer.scala @@ -22,9 +22,9 @@ import javax.servlet.http.{HttpServlet, HttpServletRequest, HttpServletResponse} import scala.io.Source import com.fasterxml.jackson.core.JsonProcessingException -import org.eclipse.jetty.server.{Server, ServerConnector} +import org.eclipse.jetty.server.{HttpConnectionFactory, Server, ServerConnector} import org.eclipse.jetty.servlet.{ServletContextHandler, ServletHolder} -import org.eclipse.jetty.util.thread.QueuedThreadPool +import org.eclipse.jetty.util.thread.{QueuedThreadPool, ScheduledExecutorScheduler} import org.json4s._ import org.json4s.jackson.JsonMethods._ @@ -83,7 +83,15 @@ private[spark] abstract class RestSubmissionServer( threadPool.setDaemon(true) val server = new Server(threadPool) - val connector = new ServerConnector(server) + val connector = new ServerConnector( + server, + null, + // Call this full constructor to set this, which forces daemon threads: + new ScheduledExecutorScheduler("RestSubmissionServer-JettyScheduler", true), + null, + -1, + -1, + new HttpConnectionFactory()) connector.setHost(host) connector.setPort(startPort) server.addConnector(connector) diff --git a/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala b/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala index 2956768c16417..dfd2f818acdac 100644 --- a/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala +++ b/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala @@ -17,8 +17,6 @@ package org.apache.spark.executor 
-import java.util.{ArrayList, Collections} - import scala.collection.JavaConverters._ import scala.collection.mutable.{ArrayBuffer, LinkedHashMap} @@ -27,7 +25,7 @@ import org.apache.spark.annotation.DeveloperApi import org.apache.spark.internal.Logging import org.apache.spark.scheduler.AccumulableInfo import org.apache.spark.storage.{BlockId, BlockStatus} -import org.apache.spark.util.{AccumulatorContext, AccumulatorMetadata, AccumulatorV2, LongAccumulator} +import org.apache.spark.util._ /** @@ -56,7 +54,7 @@ class TaskMetrics private[spark] () extends Serializable { private val _memoryBytesSpilled = new LongAccumulator private val _diskBytesSpilled = new LongAccumulator private val _peakExecutionMemory = new LongAccumulator - private val _updatedBlockStatuses = new BlockStatusesAccumulator + private val _updatedBlockStatuses = new CollectionAccumulator[(BlockId, BlockStatus)] /** * Time taken on the executor to deserialize this task. @@ -323,39 +321,3 @@ private[spark] object TaskMetrics extends Logging { tm } } - - -private[spark] class BlockStatusesAccumulator - extends AccumulatorV2[(BlockId, BlockStatus), java.util.List[(BlockId, BlockStatus)]] { - private val _seq = Collections.synchronizedList(new ArrayList[(BlockId, BlockStatus)]()) - - override def isZero(): Boolean = _seq.isEmpty - - override def copyAndReset(): BlockStatusesAccumulator = new BlockStatusesAccumulator - - override def copy(): BlockStatusesAccumulator = { - val newAcc = new BlockStatusesAccumulator - newAcc._seq.addAll(_seq) - newAcc - } - - override def reset(): Unit = _seq.clear() - - override def add(v: (BlockId, BlockStatus)): Unit = _seq.add(v) - - override def merge( - other: AccumulatorV2[(BlockId, BlockStatus), java.util.List[(BlockId, BlockStatus)]]): Unit = { - other match { - case o: BlockStatusesAccumulator => _seq.addAll(o.value) - case _ => throw new UnsupportedOperationException( - s"Cannot merge ${this.getClass.getName} with ${other.getClass.getName}") - } - } - - override def value: java.util.List[(BlockId, BlockStatus)] = _seq - - def setValue(newValue: java.util.List[(BlockId, BlockStatus)]): Unit = { - _seq.clear() - _seq.addAll(newValue) - } -} diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala index d536cc5097b2d..497ca92c7bc60 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/package.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala @@ -17,6 +17,8 @@ package org.apache.spark.internal +import java.util.concurrent.TimeUnit + import org.apache.spark.launcher.SparkLauncher import org.apache.spark.network.util.ByteUnit import org.apache.spark.util.Utils @@ -91,12 +93,48 @@ package object config { .toSequence .createWithDefault(Nil) - // Note: This is a SQL config but needs to be in core because the REPL depends on it - private[spark] val CATALOG_IMPLEMENTATION = ConfigBuilder("spark.sql.catalogImplementation") - .internal() - .stringConf - .checkValues(Set("hive", "in-memory")) - .createWithDefault("in-memory") + private[spark] val MAX_TASK_FAILURES = + ConfigBuilder("spark.task.maxFailures") + .intConf + .createWithDefault(4) + + // Blacklist confs + private[spark] val BLACKLIST_ENABLED = + ConfigBuilder("spark.blacklist.enabled") + .booleanConf + .createOptional + + private[spark] val MAX_TASK_ATTEMPTS_PER_EXECUTOR = + ConfigBuilder("spark.blacklist.task.maxTaskAttemptsPerExecutor") + .intConf + .createWithDefault(1) + + private[spark] val 
MAX_TASK_ATTEMPTS_PER_NODE = + ConfigBuilder("spark.blacklist.task.maxTaskAttemptsPerNode") + .intConf + .createWithDefault(2) + + private[spark] val MAX_FAILURES_PER_EXEC_STAGE = + ConfigBuilder("spark.blacklist.stage.maxFailedTasksPerExecutor") + .intConf + .createWithDefault(2) + + private[spark] val MAX_FAILED_EXEC_PER_NODE_STAGE = + ConfigBuilder("spark.blacklist.stage.maxFailedExecutorsPerNode") + .intConf + .createWithDefault(2) + + private[spark] val BLACKLIST_TIMEOUT_CONF = + ConfigBuilder("spark.blacklist.timeout") + .timeConf(TimeUnit.MILLISECONDS) + .createOptional + + private[spark] val BLACKLIST_LEGACY_TIMEOUT_CONF = + ConfigBuilder("spark.scheduler.executorTaskBlacklistTime") + .internal() + .timeConf(TimeUnit.MILLISECONDS) + .createOptional + // End blacklist confs private[spark] val LISTENER_BUS_EVENT_QUEUE_SIZE = ConfigBuilder("spark.scheduler.listenerbus.eventqueue.size") @@ -163,4 +201,9 @@ package object config { .doc("Port to use for the block managed on the driver.") .fallbackConf(BLOCK_MANAGER_PORT) + private[spark] val IGNORE_CORRUPT_FILES = ConfigBuilder("spark.files.ignoreCorruptFiles") + .doc("Whether to ignore corrupt files. If true, the Spark jobs will continue to run when " + + "encountering corrupt files and contents that have been read will still be returned.") + .booleanConf + .createWithDefault(false) } diff --git a/core/src/main/scala/org/apache/spark/partial/CountEvaluator.scala b/core/src/main/scala/org/apache/spark/partial/CountEvaluator.scala index 637492a97551b..5a5bd7fbbe2f8 100644 --- a/core/src/main/scala/org/apache/spark/partial/CountEvaluator.scala +++ b/core/src/main/scala/org/apache/spark/partial/CountEvaluator.scala @@ -17,21 +17,18 @@ package org.apache.spark.partial -import org.apache.commons.math3.distribution.NormalDistribution +import org.apache.commons.math3.distribution.{PascalDistribution, PoissonDistribution} /** * An ApproximateEvaluator for counts. - * - * TODO: There's currently a lot of shared code between this and GroupedCountEvaluator. It might - * be best to make this a special case of GroupedCountEvaluator with one group. */ private[spark] class CountEvaluator(totalOutputs: Int, confidence: Double) extends ApproximateEvaluator[Long, BoundedDouble] { - var outputsMerged = 0 - var sum: Long = 0 + private var outputsMerged = 0 + private var sum: Long = 0 - override def merge(outputId: Int, taskResult: Long) { + override def merge(outputId: Int, taskResult: Long): Unit = { outputsMerged += 1 sum += taskResult } @@ -39,18 +36,40 @@ private[spark] class CountEvaluator(totalOutputs: Int, confidence: Double) override def currentResult(): BoundedDouble = { if (outputsMerged == totalOutputs) { new BoundedDouble(sum, 1.0, sum, sum) - } else if (outputsMerged == 0) { - new BoundedDouble(0, 0.0, Double.NegativeInfinity, Double.PositiveInfinity) + } else if (outputsMerged == 0 || sum == 0) { + new BoundedDouble(0, 0.0, 0.0, Double.PositiveInfinity) } else { val p = outputsMerged.toDouble / totalOutputs - val mean = (sum + 1 - p) / p - val variance = (sum + 1) * (1 - p) / (p * p) - val stdev = math.sqrt(variance) - val confFactor = new NormalDistribution(). - inverseCumulativeProbability(1 - (1 - confidence) / 2) - val low = mean - confFactor * stdev - val high = mean + confFactor * stdev - new BoundedDouble(mean, confidence, low, high) + CountEvaluator.bound(confidence, sum, p) } } } + +private[partial] object CountEvaluator { + + def bound(confidence: Double, sum: Long, p: Double): BoundedDouble = { + // Let the total count be N. 
A fraction p has been counted already, with sum 'sum', + // as if each element from the total data set had been seen with probability p. + val dist = + if (sum <= 10000) { + // The remaining count, k=N-sum, may be modeled as negative binomial (aka Pascal), + // where there have been 'sum' successes of probability p already. (There are several + // conventions, but this is the one followed by Commons Math3.) + new PascalDistribution(sum.toInt, p) + } else { + // For large 'sum' (certainly, > Int.MaxValue!), use a Poisson approximation, which has + // a different interpretation. "sum" elements have been observed having scanned a fraction + // p of the data. This suggests data is counted at a rate of sum / p across the whole data + // set. The total expected count from the rest is distributed as + // (1-p) Poisson(sum / p) = Poisson(sum*(1-p)/p) + new PoissonDistribution(sum * (1 - p) / p) + } + // Not quite symmetric; calculate interval straight from discrete distribution + val low = dist.inverseCumulativeProbability((1 - confidence) / 2) + val high = dist.inverseCumulativeProbability((1 + confidence) / 2) + // Add 'sum' to each because distribution is just of remaining count, not observed + new BoundedDouble(sum + dist.getNumericalMean, confidence, sum + low, sum + high) + } + + +} diff --git a/core/src/main/scala/org/apache/spark/partial/GroupedCountEvaluator.scala b/core/src/main/scala/org/apache/spark/partial/GroupedCountEvaluator.scala index 5afce75680f94..d2b4187df5d50 100644 --- a/core/src/main/scala/org/apache/spark/partial/GroupedCountEvaluator.scala +++ b/core/src/main/scala/org/apache/spark/partial/GroupedCountEvaluator.scala @@ -17,15 +17,10 @@ package org.apache.spark.partial -import java.util.{HashMap => JHashMap} - -import scala.collection.JavaConverters._ import scala.collection.Map import scala.collection.mutable.HashMap import scala.reflect.ClassTag -import org.apache.commons.math3.distribution.NormalDistribution - import org.apache.spark.util.collection.OpenHashMap /** @@ -34,10 +29,10 @@ import org.apache.spark.util.collection.OpenHashMap private[spark] class GroupedCountEvaluator[T : ClassTag](totalOutputs: Int, confidence: Double) extends ApproximateEvaluator[OpenHashMap[T, Long], Map[T, BoundedDouble]] { - var outputsMerged = 0 - var sums = new OpenHashMap[T, Long]() // Sum of counts for each key + private var outputsMerged = 0 + private val sums = new OpenHashMap[T, Long]() // Sum of counts for each key - override def merge(outputId: Int, taskResult: OpenHashMap[T, Long]) { + override def merge(outputId: Int, taskResult: OpenHashMap[T, Long]): Unit = { outputsMerged += 1 taskResult.foreach { case (key, value) => sums.changeValue(key, value, _ + value) @@ -46,27 +41,12 @@ private[spark] class GroupedCountEvaluator[T : ClassTag](totalOutputs: Int, conf override def currentResult(): Map[T, BoundedDouble] = { if (outputsMerged == totalOutputs) { - val result = new JHashMap[T, BoundedDouble](sums.size) - sums.foreach { case (key, sum) => - result.put(key, new BoundedDouble(sum, 1.0, sum, sum)) - } - result.asScala + sums.map { case (key, sum) => (key, new BoundedDouble(sum, 1.0, sum, sum)) }.toMap } else if (outputsMerged == 0) { new HashMap[T, BoundedDouble] } else { val p = outputsMerged.toDouble / totalOutputs - val confFactor = new NormalDistribution(). 
- inverseCumulativeProbability(1 - (1 - confidence) / 2) - val result = new JHashMap[T, BoundedDouble](sums.size) - sums.foreach { case (key, sum) => - val mean = (sum + 1 - p) / p - val variance = (sum + 1) * (1 - p) / (p * p) - val stdev = math.sqrt(variance) - val low = mean - confFactor * stdev - val high = mean + confFactor * stdev - result.put(key, new BoundedDouble(mean, confidence, low, high)) - } - result.asScala + sums.map { case (key, sum) => (key, CountEvaluator.bound(confidence, sum, p)) }.toMap } } } diff --git a/core/src/main/scala/org/apache/spark/partial/GroupedMeanEvaluator.scala b/core/src/main/scala/org/apache/spark/partial/GroupedMeanEvaluator.scala deleted file mode 100644 index a164040684803..0000000000000 --- a/core/src/main/scala/org/apache/spark/partial/GroupedMeanEvaluator.scala +++ /dev/null @@ -1,80 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.partial - -import java.util.{HashMap => JHashMap} - -import scala.collection.JavaConverters._ -import scala.collection.Map -import scala.collection.mutable.HashMap - -import org.apache.spark.util.StatCounter - -/** - * An ApproximateEvaluator for means by key. Returns a map of key to confidence interval. 
- */ -private[spark] class GroupedMeanEvaluator[T](totalOutputs: Int, confidence: Double) - extends ApproximateEvaluator[JHashMap[T, StatCounter], Map[T, BoundedDouble]] { - - var outputsMerged = 0 - var sums = new JHashMap[T, StatCounter] // Sum of counts for each key - - override def merge(outputId: Int, taskResult: JHashMap[T, StatCounter]) { - outputsMerged += 1 - val iter = taskResult.entrySet.iterator() - while (iter.hasNext) { - val entry = iter.next() - val old = sums.get(entry.getKey) - if (old != null) { - old.merge(entry.getValue) - } else { - sums.put(entry.getKey, entry.getValue) - } - } - } - - override def currentResult(): Map[T, BoundedDouble] = { - if (outputsMerged == totalOutputs) { - val result = new JHashMap[T, BoundedDouble](sums.size) - val iter = sums.entrySet.iterator() - while (iter.hasNext) { - val entry = iter.next() - val mean = entry.getValue.mean - result.put(entry.getKey, new BoundedDouble(mean, 1.0, mean, mean)) - } - result.asScala - } else if (outputsMerged == 0) { - new HashMap[T, BoundedDouble] - } else { - val studentTCacher = new StudentTCacher(confidence) - val result = new JHashMap[T, BoundedDouble](sums.size) - val iter = sums.entrySet.iterator() - while (iter.hasNext) { - val entry = iter.next() - val counter = entry.getValue - val mean = counter.mean - val stdev = math.sqrt(counter.sampleVariance / counter.count) - val confFactor = studentTCacher.get(counter.count) - val low = mean - confFactor * stdev - val high = mean + confFactor * stdev - result.put(entry.getKey, new BoundedDouble(mean, confidence, low, high)) - } - result.asScala - } - } -} diff --git a/core/src/main/scala/org/apache/spark/partial/GroupedSumEvaluator.scala b/core/src/main/scala/org/apache/spark/partial/GroupedSumEvaluator.scala deleted file mode 100644 index 54a1beab3514b..0000000000000 --- a/core/src/main/scala/org/apache/spark/partial/GroupedSumEvaluator.scala +++ /dev/null @@ -1,88 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.partial - -import java.util.{HashMap => JHashMap} - -import scala.collection.JavaConverters._ -import scala.collection.Map -import scala.collection.mutable.HashMap - -import org.apache.spark.util.StatCounter - -/** - * An ApproximateEvaluator for sums by key. Returns a map of key to confidence interval. 
- */ -private[spark] class GroupedSumEvaluator[T](totalOutputs: Int, confidence: Double) - extends ApproximateEvaluator[JHashMap[T, StatCounter], Map[T, BoundedDouble]] { - - var outputsMerged = 0 - var sums = new JHashMap[T, StatCounter] // Sum of counts for each key - - override def merge(outputId: Int, taskResult: JHashMap[T, StatCounter]) { - outputsMerged += 1 - val iter = taskResult.entrySet.iterator() - while (iter.hasNext) { - val entry = iter.next() - val old = sums.get(entry.getKey) - if (old != null) { - old.merge(entry.getValue) - } else { - sums.put(entry.getKey, entry.getValue) - } - } - } - - override def currentResult(): Map[T, BoundedDouble] = { - if (outputsMerged == totalOutputs) { - val result = new JHashMap[T, BoundedDouble](sums.size) - val iter = sums.entrySet.iterator() - while (iter.hasNext) { - val entry = iter.next() - val sum = entry.getValue.sum - result.put(entry.getKey, new BoundedDouble(sum, 1.0, sum, sum)) - } - result.asScala - } else if (outputsMerged == 0) { - new HashMap[T, BoundedDouble] - } else { - val p = outputsMerged.toDouble / totalOutputs - val studentTCacher = new StudentTCacher(confidence) - val result = new JHashMap[T, BoundedDouble](sums.size) - val iter = sums.entrySet.iterator() - while (iter.hasNext) { - val entry = iter.next() - val counter = entry.getValue - val meanEstimate = counter.mean - val meanVar = counter.sampleVariance / counter.count - val countEstimate = (counter.count + 1 - p) / p - val countVar = (counter.count + 1) * (1 - p) / (p * p) - val sumEstimate = meanEstimate * countEstimate - val sumVar = (meanEstimate * meanEstimate * countVar) + - (countEstimate * countEstimate * meanVar) + - (meanVar * countVar) - val sumStdev = math.sqrt(sumVar) - val confFactor = studentTCacher.get(counter.count) - val low = sumEstimate - confFactor * sumStdev - val high = sumEstimate + confFactor * sumStdev - result.put(entry.getKey, new BoundedDouble(sumEstimate, confidence, low, high)) - } - result.asScala - } - } -} diff --git a/core/src/main/scala/org/apache/spark/partial/MeanEvaluator.scala b/core/src/main/scala/org/apache/spark/partial/MeanEvaluator.scala index 787a21a61fdcf..3fb2d30a800b6 100644 --- a/core/src/main/scala/org/apache/spark/partial/MeanEvaluator.scala +++ b/core/src/main/scala/org/apache/spark/partial/MeanEvaluator.scala @@ -27,10 +27,10 @@ import org.apache.spark.util.StatCounter private[spark] class MeanEvaluator(totalOutputs: Int, confidence: Double) extends ApproximateEvaluator[StatCounter, BoundedDouble] { - var outputsMerged = 0 - var counter = new StatCounter + private var outputsMerged = 0 + private val counter = new StatCounter() - override def merge(outputId: Int, taskResult: StatCounter) { + override def merge(outputId: Int, taskResult: StatCounter): Unit = { outputsMerged += 1 counter.merge(taskResult) } @@ -38,19 +38,24 @@ private[spark] class MeanEvaluator(totalOutputs: Int, confidence: Double) override def currentResult(): BoundedDouble = { if (outputsMerged == totalOutputs) { new BoundedDouble(counter.mean, 1.0, counter.mean, counter.mean) - } else if (outputsMerged == 0) { + } else if (outputsMerged == 0 || counter.count == 0) { new BoundedDouble(0, 0.0, Double.NegativeInfinity, Double.PositiveInfinity) + } else if (counter.count == 1) { + new BoundedDouble(counter.mean, confidence, Double.NegativeInfinity, Double.PositiveInfinity) } else { val mean = counter.mean val stdev = math.sqrt(counter.sampleVariance / counter.count) - val confFactor = { - if (counter.count > 100) { - new 
NormalDistribution().inverseCumulativeProbability(1 - (1 - confidence) / 2) + val confFactor = if (counter.count > 100) { + // For large n, the normal distribution is a good approximation to t-distribution + new NormalDistribution().inverseCumulativeProbability((1 + confidence) / 2) } else { + // t-distribution describes distribution of actual population mean + // note that if this goes to 0, TDistribution will throw an exception. + // Hence special casing 1 above. val degreesOfFreedom = (counter.count - 1).toInt - new TDistribution(degreesOfFreedom).inverseCumulativeProbability(1 - (1 - confidence) / 2) + new TDistribution(degreesOfFreedom).inverseCumulativeProbability((1 + confidence) / 2) } - } + // Symmetric, so confidence interval is symmetric about mean of distribution val low = mean - confFactor * stdev val high = mean + confFactor * stdev new BoundedDouble(mean, confidence, low, high) diff --git a/core/src/main/scala/org/apache/spark/partial/StudentTCacher.scala b/core/src/main/scala/org/apache/spark/partial/StudentTCacher.scala deleted file mode 100644 index 55acb9ca64d3f..0000000000000 --- a/core/src/main/scala/org/apache/spark/partial/StudentTCacher.scala +++ /dev/null @@ -1,46 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.partial - -import org.apache.commons.math3.distribution.{NormalDistribution, TDistribution} - -/** - * A utility class for caching Student's T distribution values for a given confidence level - * and various sample sizes. This is used by the MeanEvaluator to efficiently calculate - * confidence intervals for many keys. 
- */ -private[spark] class StudentTCacher(confidence: Double) { - - val NORMAL_APPROX_SAMPLE_SIZE = 100 // For samples bigger than this, use Gaussian approximation - - val normalApprox = new NormalDistribution().inverseCumulativeProbability(1 - (1 - confidence) / 2) - val cache = Array.fill[Double](NORMAL_APPROX_SAMPLE_SIZE)(-1.0) - - def get(sampleSize: Long): Double = { - if (sampleSize >= NORMAL_APPROX_SAMPLE_SIZE) { - normalApprox - } else { - val size = sampleSize.toInt - if (cache(size) < 0) { - val tDist = new TDistribution(size - 1) - cache(size) = tDist.inverseCumulativeProbability(1 - (1 - confidence) / 2) - } - cache(size) - } - } -} diff --git a/core/src/main/scala/org/apache/spark/partial/SumEvaluator.scala b/core/src/main/scala/org/apache/spark/partial/SumEvaluator.scala index 5fe33583166c3..1988052b733e6 100644 --- a/core/src/main/scala/org/apache/spark/partial/SumEvaluator.scala +++ b/core/src/main/scala/org/apache/spark/partial/SumEvaluator.scala @@ -30,10 +30,10 @@ private[spark] class SumEvaluator(totalOutputs: Int, confidence: Double) extends ApproximateEvaluator[StatCounter, BoundedDouble] { // modified in merge - var outputsMerged = 0 - val counter = new StatCounter + private var outputsMerged = 0 + private val counter = new StatCounter() - override def merge(outputId: Int, taskResult: StatCounter) { + override def merge(outputId: Int, taskResult: StatCounter): Unit = { outputsMerged += 1 counter.merge(taskResult) } @@ -45,34 +45,45 @@ private[spark] class SumEvaluator(totalOutputs: Int, confidence: Double) new BoundedDouble(0, 0.0, Double.NegativeInfinity, Double.PositiveInfinity) } else { val p = outputsMerged.toDouble / totalOutputs + // Expected value of unobserved is presumed equal to that of the observed data val meanEstimate = counter.mean - val countEstimate = (counter.count + 1 - p) / p + // Expected size of rest of the data is proportional + val countEstimate = counter.count * (1 - p) / p + // Expected sum is simply their product val sumEstimate = meanEstimate * countEstimate + // Variance of unobserved data is presumed equal to that of the observed data val meanVar = counter.sampleVariance / counter.count - // branch at this point because counter.count == 1 implies counter.sampleVariance == Nan + // branch at this point because count == 1 implies counter.sampleVariance == Nan // and we don't want to ever return a bound of NaN if (meanVar.isNaN || counter.count == 1) { - new BoundedDouble(sumEstimate, confidence, Double.NegativeInfinity, Double.PositiveInfinity) + // add sum because estimate is of unobserved data sum + new BoundedDouble( + counter.sum + sumEstimate, confidence, Double.NegativeInfinity, Double.PositiveInfinity) } else { - val countVar = (counter.count + 1) * (1 - p) / (p * p) + // See CountEvaluator. Variance of population count here follows from negative binomial + val countVar = counter.count * (1 - p) / (p * p) + // Var(Sum) = Var(Mean*Count) = + // [E(Mean)]^2 * Var(Count) + [E(Count)]^2 * Var(Mean) + Var(Mean) * Var(Count) val sumVar = (meanEstimate * meanEstimate * countVar) + (countEstimate * countEstimate * meanVar) + (meanVar * countVar) val sumStdev = math.sqrt(sumVar) val confFactor = if (counter.count > 100) { - new NormalDistribution().inverseCumulativeProbability(1 - (1 - confidence) / 2) + new NormalDistribution().inverseCumulativeProbability((1 + confidence) / 2) } else { // note that if this goes to 0, TDistribution will throw an exception. // Hence special casing 1 above. 
val degreesOfFreedom = (counter.count - 1).toInt - new TDistribution(degreesOfFreedom).inverseCumulativeProbability(1 - (1 - confidence) / 2) + new TDistribution(degreesOfFreedom).inverseCumulativeProbability((1 + confidence) / 2) } - + // Symmetric, so confidence interval is symmetric about mean of distribution val low = sumEstimate - confFactor * sumStdev val high = sumEstimate + confFactor * sumStdev - new BoundedDouble(sumEstimate, confidence, low, high) + // add sum because estimate is of unobserved data sum + new BoundedDouble( + counter.sum + sumEstimate, confidence, counter.sum + low, counter.sum + high) } } } diff --git a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala index 4640b5dc2f654..e1cf3938de098 100644 --- a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala @@ -17,7 +17,7 @@ package org.apache.spark.rdd -import java.io.EOFException +import java.io.IOException import java.text.SimpleDateFormat import java.util.Date @@ -43,6 +43,7 @@ import org.apache.spark.annotation.DeveloperApi import org.apache.spark.broadcast.Broadcast import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.internal.Logging +import org.apache.spark.internal.config.IGNORE_CORRUPT_FILES import org.apache.spark.rdd.HadoopRDD.HadoopMapPartitionsWithSplitRDD import org.apache.spark.scheduler.{HDFSCacheTaskLocation, HostTaskLocation} import org.apache.spark.storage.StorageLevel @@ -139,6 +140,8 @@ class HadoopRDD[K, V]( private val shouldCloneJobConf = sparkContext.conf.getBoolean("spark.hadoop.cloneConf", false) + private val ignoreCorruptFiles = sparkContext.conf.get(IGNORE_CORRUPT_FILES) + // Returns a JobConf that will be used on slaves to obtain input splits for Hadoop reads. 
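As an aside on the evaluator changes above: the quantile rewrite uses the identity 1 - (1 - confidence) / 2 == (1 + confidence) / 2, and the returned BoundedDouble now folds the already-observed sum back into the estimate and its bounds. The following standalone Scala sketch mirrors that arithmetic; the method name and parameters are illustrative, not part of the patch, and it assumes more than one observed value so the t-distribution is defined.

import org.apache.commons.math3.distribution.{NormalDistribution, TDistribution}

def approximateSum(
    observedSum: Double,
    observedMean: Double,
    sampleVariance: Double,
    observedCount: Long,   // assumed > 1, as the patch special-cases count == 1
    p: Double,             // fraction of output partitions merged so far
    confidence: Double): (Double, Double, Double) = {
  // Unobserved data is presumed to look like the observed data
  val countEstimate = observedCount * (1 - p) / p
  val sumEstimate = observedMean * countEstimate
  val meanVar = sampleVariance / observedCount
  // Variance of the unobserved count (negative binomial, as in CountEvaluator)
  val countVar = observedCount * (1 - p) / (p * p)
  // Var(Mean * Count) = [E(Mean)]^2 * Var(Count) + [E(Count)]^2 * Var(Mean) + Var(Mean) * Var(Count)
  val sumVar = (observedMean * observedMean * countVar) +
    (countEstimate * countEstimate * meanVar) + (meanVar * countVar)
  val sumStdev = math.sqrt(sumVar)
  // Two-sided quantile: 1 - (1 - confidence) / 2 == (1 + confidence) / 2
  val confFactor = if (observedCount > 100) {
    new NormalDistribution().inverseCumulativeProbability((1 + confidence) / 2)
  } else {
    new TDistribution((observedCount - 1).toDouble)
      .inverseCumulativeProbability((1 + confidence) / 2)
  }
  // The estimate covers only the unobserved data, so add the observed sum back in
  val total = observedSum + sumEstimate
  (total, total - confFactor * sumStdev, total + confFactor * sumStdev)
}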
protected def getJobConf(): JobConf = { val conf: Configuration = broadcastedConf.value.value @@ -253,8 +256,7 @@ class HadoopRDD[K, V]( try { finished = !reader.next(key, value) } catch { - case eof: EOFException => - finished = true + case e: IOException if ignoreCorruptFiles => finished = true } if (!finished) { inputMetrics.incRecordsRead(1) diff --git a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala index 1c7aec919bdc4..baf31fb658870 100644 --- a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala @@ -17,6 +17,7 @@ package org.apache.spark.rdd +import java.io.IOException import java.text.SimpleDateFormat import java.util.Date @@ -33,6 +34,7 @@ import org.apache.spark._ import org.apache.spark.annotation.DeveloperApi import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.internal.Logging +import org.apache.spark.internal.config.IGNORE_CORRUPT_FILES import org.apache.spark.rdd.NewHadoopRDD.NewHadoopMapPartitionsWithSplitRDD import org.apache.spark.storage.StorageLevel import org.apache.spark.util.{SerializableConfiguration, ShutdownHookManager} @@ -85,6 +87,8 @@ class NewHadoopRDD[K, V]( private val shouldCloneJobConf = sparkContext.conf.getBoolean("spark.hadoop.cloneConf", false) + private val ignoreCorruptFiles = sparkContext.conf.get(IGNORE_CORRUPT_FILES) + def getConf: Configuration = { val conf: Configuration = confBroadcast.value.value if (shouldCloneJobConf) { @@ -179,7 +183,11 @@ class NewHadoopRDD[K, V]( override def hasNext: Boolean = { if (!finished && !havePair) { - finished = !reader.nextKeyValue + try { + finished = !reader.nextKeyValue + } catch { + case e: IOException if ignoreCorruptFiles => finished = true + } if (finished) { // Close and release the reader here; close() will also be called when the task // completes, but for tasks that read from many files, it helps to release the diff --git a/core/src/main/scala/org/apache/spark/rdd/ReliableCheckpointRDD.scala b/core/src/main/scala/org/apache/spark/rdd/ReliableCheckpointRDD.scala index ab6554fd8a7e7..eac901d10067c 100644 --- a/core/src/main/scala/org/apache/spark/rdd/ReliableCheckpointRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/ReliableCheckpointRDD.scala @@ -69,10 +69,10 @@ private[spark] class ReliableCheckpointRDD[T: ClassTag]( val inputFiles = fs.listStatus(cpath) .map(_.getPath) .filter(_.getName.startsWith("part-")) - .sortBy(_.toString) + .sortBy(_.getName.stripPrefix("part-").toInt) // Fail fast if input files are invalid inputFiles.zipWithIndex.foreach { case (path, i) => - if (!path.toString.endsWith(ReliableCheckpointRDD.checkpointFileName(i))) { + if (path.getName != ReliableCheckpointRDD.checkpointFileName(i)) { throw new SparkException(s"Invalid checkpoint file: $path") } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/BlacklistTracker.scala b/core/src/main/scala/org/apache/spark/scheduler/BlacklistTracker.scala new file mode 100644 index 0000000000000..fca4c6d37e446 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/scheduler/BlacklistTracker.scala @@ -0,0 +1,114 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.scheduler + +import org.apache.spark.SparkConf +import org.apache.spark.internal.Logging +import org.apache.spark.internal.config +import org.apache.spark.util.Utils + +private[scheduler] object BlacklistTracker extends Logging { + + private val DEFAULT_TIMEOUT = "1h" + + /** + * Returns true if the blacklist is enabled, based on checking the configuration in the following + * order: + * 1. Is it specifically enabled or disabled? + * 2. Is it enabled via the legacy timeout conf? + * 3. Default is off + */ + def isBlacklistEnabled(conf: SparkConf): Boolean = { + conf.get(config.BLACKLIST_ENABLED) match { + case Some(enabled) => + enabled + case None => + // if they've got a non-zero setting for the legacy conf, always enable the blacklist, + // otherwise, use the default. + val legacyKey = config.BLACKLIST_LEGACY_TIMEOUT_CONF.key + conf.get(config.BLACKLIST_LEGACY_TIMEOUT_CONF).exists { legacyTimeout => + if (legacyTimeout == 0) { + logWarning(s"Turning off blacklisting due to legacy configuration: $legacyKey == 0") + false + } else { + logWarning(s"Turning on blacklisting due to legacy configuration: $legacyKey > 0") + true + } + } + } + } + + def getBlacklistTimeout(conf: SparkConf): Long = { + conf.get(config.BLACKLIST_TIMEOUT_CONF).getOrElse { + conf.get(config.BLACKLIST_LEGACY_TIMEOUT_CONF).getOrElse { + Utils.timeStringAsMs(DEFAULT_TIMEOUT) + } + } + } + + /** + * Verify that blacklist configurations are consistent; if not, throw an exception. Should only + * be called if blacklisting is enabled. + * + * The configuration for the blacklist is expected to adhere to a few invariants. Default + * values follow these rules of course, but users may unwittingly change one configuration + * without making the corresponding adjustment elsewhere. This ensures we fail-fast when + * there are such misconfigurations. + */ + def validateBlacklistConfs(conf: SparkConf): Unit = { + + def mustBePos(k: String, v: String): Unit = { + throw new IllegalArgumentException(s"$k was $v, but must be > 0.") + } + + Seq( + config.MAX_TASK_ATTEMPTS_PER_EXECUTOR, + config.MAX_TASK_ATTEMPTS_PER_NODE, + config.MAX_FAILURES_PER_EXEC_STAGE, + config.MAX_FAILED_EXEC_PER_NODE_STAGE + ).foreach { config => + val v = conf.get(config) + if (v <= 0) { + mustBePos(config.key, v.toString) + } + } + + val timeout = getBlacklistTimeout(conf) + if (timeout <= 0) { + // first, figure out where the timeout came from, to include the right conf in the message. 
+ conf.get(config.BLACKLIST_TIMEOUT_CONF) match { + case Some(t) => + mustBePos(config.BLACKLIST_TIMEOUT_CONF.key, timeout.toString) + case None => + mustBePos(config.BLACKLIST_LEGACY_TIMEOUT_CONF.key, timeout.toString) + } + } + + val maxTaskFailures = conf.get(config.MAX_TASK_FAILURES) + val maxNodeAttempts = conf.get(config.MAX_TASK_ATTEMPTS_PER_NODE) + + if (maxNodeAttempts >= maxTaskFailures) { + throw new IllegalArgumentException(s"${config.MAX_TASK_ATTEMPTS_PER_NODE.key} " + + s"( = ${maxNodeAttempts}) was >= ${config.MAX_TASK_FAILURES.key} " + + s"( = ${maxTaskFailures} ). Though blacklisting is enabled, with this configuration, " + + s"Spark will not be robust to one bad node. Decrease " + + s"${config.MAX_TASK_ATTEMPTS_PER_NODE.key}, increase ${config.MAX_TASK_FAILURES.key}, " + + s"or disable blacklisting with ${config.BLACKLIST_ENABLED.key}") + } + } +} diff --git a/core/src/main/scala/org/apache/spark/scheduler/ExecutorFailuresInTaskSet.scala b/core/src/main/scala/org/apache/spark/scheduler/ExecutorFailuresInTaskSet.scala new file mode 100644 index 0000000000000..20ab27d127aba --- /dev/null +++ b/core/src/main/scala/org/apache/spark/scheduler/ExecutorFailuresInTaskSet.scala @@ -0,0 +1,50 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.scheduler + +import scala.collection.mutable.HashMap + +/** + * Small helper for tracking failed tasks for blacklisting purposes. Info on all failures on one + * executor, within one task set. + */ +private[scheduler] class ExecutorFailuresInTaskSet(val node: String) { + /** + * Mapping from index of the tasks in the taskset, to the number of times it has failed on this + * executor. + */ + val taskToFailureCount = HashMap[Int, Int]() + + def updateWithFailure(taskIndex: Int): Unit = { + val prevFailureCount = taskToFailureCount.getOrElse(taskIndex, 0) + taskToFailureCount(taskIndex) = prevFailureCount + 1 + } + + def numUniqueTasksWithFailures: Int = taskToFailureCount.size + + /** + * Return the number of times this executor has failed on the given task index. 
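To make the last invariant checked in validateBlacklistConfs concrete, here is a stripped-down sketch of that check; the parameter names are illustrative, and the real code reads both values from the config entries referenced above.

def checkNodeAttemptsInvariant(maxTaskFailures: Int, maxNodeAttempts: Int): Unit = {
  // If a task may be attempted on one node as many times as it may fail overall,
  // every allowed attempt could land on a single bad node before the job gives up,
  // so blacklisting would not protect the job from that node.
  if (maxNodeAttempts >= maxTaskFailures) {
    throw new IllegalArgumentException(
      s"maxTaskAttemptsPerNode ( = $maxNodeAttempts) must be < maxTaskFailures ( = $maxTaskFailures)")
  }
}

// checkNodeAttemptsInvariant(maxTaskFailures = 4, maxNodeAttempts = 2)  // passes
// checkNodeAttemptsInvariant(maxTaskFailures = 4, maxNodeAttempts = 4)  // throws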
+ */ + def getNumTaskFailures(index: Int): Int = { + taskToFailureCount.getOrElse(index, 0) + } + + override def toString(): String = { + s"numUniqueTasksWithFailures = $numUniqueTasksWithFailures; " + + s"tasksToFailureCount = $taskToFailureCount" + } +} diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala index 0ad4730fe20a6..3e3f1ad031e66 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala @@ -22,14 +22,14 @@ import java.util.{Timer, TimerTask} import java.util.concurrent.TimeUnit import java.util.concurrent.atomic.AtomicLong -import scala.collection.mutable.ArrayBuffer -import scala.collection.mutable.HashMap -import scala.collection.mutable.HashSet +import scala.collection.Set +import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet} import scala.util.Random import org.apache.spark._ import org.apache.spark.TaskState.TaskState import org.apache.spark.internal.Logging +import org.apache.spark.internal.config import org.apache.spark.scheduler.SchedulingMode.SchedulingMode import org.apache.spark.scheduler.TaskLocality.TaskLocality import org.apache.spark.scheduler.local.LocalSchedulerBackend @@ -57,7 +57,7 @@ private[spark] class TaskSchedulerImpl( isLocal: Boolean = false) extends TaskScheduler with Logging { - def this(sc: SparkContext) = this(sc, sc.conf.getInt("spark.task.maxFailures", 4)) + def this(sc: SparkContext) = this(sc, sc.conf.get(config.MAX_TASK_FAILURES)) val conf = sc.conf @@ -100,7 +100,7 @@ private[spark] class TaskSchedulerImpl( // The set of executors we have on each host; this is used to compute hostsAlive, which // in turn is used to decide when we can attain data locality on a given host - protected val executorsByHost = new HashMap[String, HashSet[String]] + protected val hostToExecutors = new HashMap[String, HashSet[String]] protected val hostsByRack = new HashMap[String, HashSet[String]] @@ -243,8 +243,8 @@ private[spark] class TaskSchedulerImpl( } } manager.parent.removeSchedulable(manager) - logInfo("Removed TaskSet %s, whose tasks have all completed, from pool %s" - .format(manager.taskSet.id, manager.parent.name)) + logInfo(s"Removed TaskSet ${manager.taskSet.id}, whose tasks have all completed, from pool" + + s" ${manager.parent.name}") } private def resourceOfferSingleTaskSet( @@ -291,11 +291,11 @@ private[spark] class TaskSchedulerImpl( // Also track if new executor is added var newExecAvail = false for (o <- offers) { - if (!executorsByHost.contains(o.host)) { - executorsByHost(o.host) = new HashSet[String]() + if (!hostToExecutors.contains(o.host)) { + hostToExecutors(o.host) = new HashSet[String]() } if (!executorIdToTaskCount.contains(o.executorId)) { - executorsByHost(o.host) += o.executorId + hostToExecutors(o.host) += o.executorId executorAdded(o.executorId, o.host) executorIdToHost(o.executorId) = o.host executorIdToTaskCount(o.executorId) = 0 @@ -334,7 +334,7 @@ private[spark] class TaskSchedulerImpl( } while (launchedTaskAtCurrentMaxLocality) } if (!launchedAnyTask) { - taskSet.abortIfCompletelyBlacklisted(executorIdToHost.keys) + taskSet.abortIfCompletelyBlacklisted(hostToExecutors) } } @@ -542,10 +542,10 @@ private[spark] class TaskSchedulerImpl( executorIdToTaskCount -= executorId val host = executorIdToHost(executorId) - val execs = executorsByHost.getOrElse(host, new HashSet) + val execs = hostToExecutors.getOrElse(host, 
new HashSet) execs -= executorId if (execs.isEmpty) { - executorsByHost -= host + hostToExecutors -= host for (rack <- getRackForHost(host); hosts <- hostsByRack.get(rack)) { hosts -= host if (hosts.isEmpty) { @@ -565,11 +565,11 @@ private[spark] class TaskSchedulerImpl( } def getExecutorsAliveOnHost(host: String): Option[Set[String]] = synchronized { - executorsByHost.get(host).map(_.toSet) + hostToExecutors.get(host).map(_.toSet) } def hasExecutorsAliveOnHost(host: String): Boolean = synchronized { - executorsByHost.contains(host) + hostToExecutors.contains(host) } def hasHostAliveOnRack(rack: String): Boolean = synchronized { @@ -662,5 +662,4 @@ private[spark] object TaskSchedulerImpl { retval.toList } - } diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetBlacklist.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetBlacklist.scala new file mode 100644 index 0000000000000..f4b0f55b7686a --- /dev/null +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetBlacklist.scala @@ -0,0 +1,128 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.scheduler + +import scala.collection.mutable.{HashMap, HashSet} + +import org.apache.spark.SparkConf +import org.apache.spark.internal.Logging +import org.apache.spark.internal.config +import org.apache.spark.util.Clock + +/** + * Handles blacklisting executors and nodes within a taskset. This includes blacklisting specific + * (task, executor) / (task, nodes) pairs, and also completely blacklisting executors and nodes + * for the entire taskset. + * + * THREADING: This class is a helper to [[TaskSetManager]]; as with the methods in + * [[TaskSetManager]] this class is designed only to be called from code with a lock on the + * TaskScheduler (e.g. its event handlers). It should not be called from other threads. + */ +private[scheduler] class TaskSetBlacklist(val conf: SparkConf, val stageId: Int, val clock: Clock) + extends Logging { + + private val MAX_TASK_ATTEMPTS_PER_EXECUTOR = conf.get(config.MAX_TASK_ATTEMPTS_PER_EXECUTOR) + private val MAX_TASK_ATTEMPTS_PER_NODE = conf.get(config.MAX_TASK_ATTEMPTS_PER_NODE) + private val MAX_FAILURES_PER_EXEC_STAGE = conf.get(config.MAX_FAILURES_PER_EXEC_STAGE) + private val MAX_FAILED_EXEC_PER_NODE_STAGE = conf.get(config.MAX_FAILED_EXEC_PER_NODE_STAGE) + + /** + * A map from each executor to the task failures on that executor. + */ + val execToFailures = new HashMap[String, ExecutorFailuresInTaskSet]() + + /** + * Map from node to all executors on it with failures. Needed because we want to know about + * executors on a node even after they have died. (We don't want to bother tracking the + * node -> execs mapping in the usual case when there aren't any failures). 
+   */
+  private val nodeToExecsWithFailures = new HashMap[String, HashSet[String]]()
+  private val nodeToBlacklistedTaskIndexes = new HashMap[String, HashSet[Int]]()
+  private val blacklistedExecs = new HashSet[String]()
+  private val blacklistedNodes = new HashSet[String]()
+
+  /**
+   * Return true if this executor is blacklisted for the given task. This does *not*
+   * need to return true if the executor is blacklisted for the entire stage.
+   * That is to keep this method as fast as possible in the inner-loop of the
+   * scheduler, where those filters will have already been applied.
+   */
+  def isExecutorBlacklistedForTask(executorId: String, index: Int): Boolean = {
+    execToFailures.get(executorId).exists { execFailures =>
+      execFailures.getNumTaskFailures(index) >= MAX_TASK_ATTEMPTS_PER_EXECUTOR
+    }
+  }
+
+  def isNodeBlacklistedForTask(node: String, index: Int): Boolean = {
+    nodeToBlacklistedTaskIndexes.get(node).exists(_.contains(index))
+  }
+
+  /**
+   * Return true if this executor is blacklisted for the given stage. Completely ignores
+   * anything to do with the node the executor is on. That
+   * is to keep this method as fast as possible in the inner-loop of the scheduler, where those
+   * filters will already have been applied.
+   */
+  def isExecutorBlacklistedForTaskSet(executorId: String): Boolean = {
+    blacklistedExecs.contains(executorId)
+  }
+
+  def isNodeBlacklistedForTaskSet(node: String): Boolean = {
+    blacklistedNodes.contains(node)
+  }
+
+  private[scheduler] def updateBlacklistForFailedTask(
+      host: String,
+      exec: String,
+      index: Int): Unit = {
+    val execFailures = execToFailures.getOrElseUpdate(exec, new ExecutorFailuresInTaskSet(host))
+    execFailures.updateWithFailure(index)
+
+    // check if this task has also failed on other executors on the same host -- if it's gone
+    // over the limit, blacklist this task from the entire host.
+    val execsWithFailuresOnNode = nodeToExecsWithFailures.getOrElseUpdate(host, new HashSet())
+    execsWithFailuresOnNode += exec
+    val failuresOnHost = execsWithFailuresOnNode.toIterator.flatMap { exec =>
+      execToFailures.get(exec).map { failures =>
+        // We count task attempts here, not the number of unique executors with failures. This is
+        // because jobs are aborted based on the number of task attempts; if we counted unique
+        // executors, it would be hard to configure this so as to ensure that you try another
+        // node before hitting the max number of task failures.
+        failures.getNumTaskFailures(index)
+      }
+    }.sum
+    if (failuresOnHost >= MAX_TASK_ATTEMPTS_PER_NODE) {
+      nodeToBlacklistedTaskIndexes.getOrElseUpdate(host, new HashSet()) += index
+    }
+
+    // Check if enough tasks have failed on the executor to blacklist it for the entire stage.
+    if (execFailures.numUniqueTasksWithFailures >= MAX_FAILURES_PER_EXEC_STAGE) {
+      if (blacklistedExecs.add(exec)) {
+        logInfo(s"Blacklisting executor ${exec} for stage $stageId")
+        // This executor has been pushed into the blacklist for this stage. Let's check if it
+        // pushes the whole node into the blacklist.
+ val blacklistedExecutorsOnNode = + execsWithFailuresOnNode.filter(blacklistedExecs.contains(_)) + if (blacklistedExecutorsOnNode.size >= MAX_FAILED_EXEC_PER_NODE_STAGE) { + if (blacklistedNodes.add(host)) { + logInfo(s"Blacklisting ${host} for stage $stageId") + } + } + } + } + } +} diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala index 226bed284a40a..9491bc7a0497e 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala @@ -22,9 +22,7 @@ import java.nio.ByteBuffer import java.util.Arrays import java.util.concurrent.ConcurrentLinkedQueue -import scala.collection.mutable.ArrayBuffer -import scala.collection.mutable.HashMap -import scala.collection.mutable.HashSet +import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet} import scala.math.{max, min} import scala.util.control.NonFatal @@ -53,19 +51,9 @@ private[spark] class TaskSetManager( sched: TaskSchedulerImpl, val taskSet: TaskSet, val maxTaskFailures: Int, - clock: Clock = new SystemClock()) - extends Schedulable with Logging { + clock: Clock = new SystemClock()) extends Schedulable with Logging { - val conf = sched.sc.conf - - /* - * Sometimes if an executor is dead or in an otherwise invalid state, the driver - * does not realize right away leading to repeated task failures. If enabled, - * this temporarily prevents a task from re-launching on an executor where - * it just failed. - */ - private val EXECUTOR_TASK_BLACKLIST_TIMEOUT = - conf.getLong("spark.scheduler.executorTaskBlacklistTime", 0L) + private val conf = sched.sc.conf // Quantile of tasks at which to start speculation val SPECULATION_QUANTILE = conf.getDouble("spark.speculation.quantile", 0.75) @@ -83,8 +71,6 @@ private[spark] class TaskSetManager( val copiesRunning = new Array[Int](numTasks) val successful = new Array[Boolean](numTasks) private val numFailures = new Array[Int](numTasks) - // key is taskId (aka TaskInfo.index), value is a Map of executor id to when it failed - private val failedExecutors = new HashMap[Int, HashMap[String, Long]]() val taskAttempts = Array.fill[List[TaskInfo]](numTasks)(Nil) var tasksSuccessful = 0 @@ -98,6 +84,14 @@ private[spark] class TaskSetManager( var totalResultSize = 0L var calculatedTasks = 0 + private val taskSetBlacklistHelperOpt: Option[TaskSetBlacklist] = { + if (BlacklistTracker.isBlacklistEnabled(conf)) { + Some(new TaskSetBlacklist(conf, stageId, clock)) + } else { + None + } + } + val runningTasksSet = new HashSet[Long] override def runningTasks: Int = runningTasksSet.size @@ -245,12 +239,15 @@ private[spark] class TaskSetManager( * This method also cleans up any tasks in the list that have already * been launched, since we want that to happen lazily. 
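A hypothetical walk-through of the escalation implemented in TaskSetBlacklist above. It assumes blacklisting is enabled with limits of 1 task attempt per executor, 2 attempts per node, and 2 failed tasks per executor per stage; these limits and the SparkConf named conf are assumptions for the example, and the real values come from the config entries read at the top of the class.

import org.apache.spark.util.SystemClock

val blacklist = new TaskSetBlacklist(conf, stageId = 0, clock = new SystemClock)

blacklist.updateBlacklistForFailedTask(host = "host1", exec = "exec1", index = 0)
blacklist.isExecutorBlacklistedForTask("exec1", 0)   // true: task 0 hit its per-executor attempt limit
blacklist.isExecutorBlacklistedForTaskSet("exec1")   // false: only one distinct task has failed so far

blacklist.updateBlacklistForFailedTask(host = "host1", exec = "exec1", index = 1)
blacklist.isExecutorBlacklistedForTaskSet("exec1")   // true: two distinct tasks have now failed on exec1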
*/ - private def dequeueTaskFromList(execId: String, list: ArrayBuffer[Int]): Option[Int] = { + private def dequeueTaskFromList( + execId: String, + host: String, + list: ArrayBuffer[Int]): Option[Int] = { var indexOffset = list.size while (indexOffset > 0) { indexOffset -= 1 val index = list(indexOffset) - if (!executorIsBlacklisted(execId, index)) { + if (!isTaskBlacklistedOnExecOrNode(index, execId, host)) { // This should almost always be list.trimEnd(1) to remove tail list.remove(indexOffset) if (copiesRunning(index) == 0 && !successful(index)) { @@ -266,19 +263,11 @@ private[spark] class TaskSetManager( taskAttempts(taskIndex).exists(_.host == host) } - /** - * Is this re-execution of a failed task on an executor it already failed in before - * EXECUTOR_TASK_BLACKLIST_TIMEOUT has elapsed ? - */ - private[scheduler] def executorIsBlacklisted(execId: String, taskId: Int): Boolean = { - if (failedExecutors.contains(taskId)) { - val failed = failedExecutors.get(taskId).get - - return failed.contains(execId) && - clock.getTimeMillis() - failed.get(execId).get < EXECUTOR_TASK_BLACKLIST_TIMEOUT + private def isTaskBlacklistedOnExecOrNode(index: Int, execId: String, host: String): Boolean = { + taskSetBlacklistHelperOpt.exists { blacklist => + blacklist.isNodeBlacklistedForTask(host, index) || + blacklist.isExecutorBlacklistedForTask(execId, index) } - - false } /** @@ -292,8 +281,10 @@ private[spark] class TaskSetManager( { speculatableTasks.retain(index => !successful(index)) // Remove finished tasks from set - def canRunOnHost(index: Int): Boolean = - !hasAttemptOnHost(index, host) && !executorIsBlacklisted(execId, index) + def canRunOnHost(index: Int): Boolean = { + !hasAttemptOnHost(index, host) && + !isTaskBlacklistedOnExecOrNode(index, execId, host) + } if (!speculatableTasks.isEmpty) { // Check for process-local tasks; note that tasks can be process-local @@ -366,19 +357,19 @@ private[spark] class TaskSetManager( private def dequeueTask(execId: String, host: String, maxLocality: TaskLocality.Value) : Option[(Int, TaskLocality.Value, Boolean)] = { - for (index <- dequeueTaskFromList(execId, getPendingTasksForExecutor(execId))) { + for (index <- dequeueTaskFromList(execId, host, getPendingTasksForExecutor(execId))) { return Some((index, TaskLocality.PROCESS_LOCAL, false)) } if (TaskLocality.isAllowed(maxLocality, TaskLocality.NODE_LOCAL)) { - for (index <- dequeueTaskFromList(execId, getPendingTasksForHost(host))) { + for (index <- dequeueTaskFromList(execId, host, getPendingTasksForHost(host))) { return Some((index, TaskLocality.NODE_LOCAL, false)) } } if (TaskLocality.isAllowed(maxLocality, TaskLocality.NO_PREF)) { // Look for noPref tasks after NODE_LOCAL for minimize cross-rack traffic - for (index <- dequeueTaskFromList(execId, pendingTasksWithNoPrefs)) { + for (index <- dequeueTaskFromList(execId, host, pendingTasksWithNoPrefs)) { return Some((index, TaskLocality.PROCESS_LOCAL, false)) } } @@ -386,14 +377,14 @@ private[spark] class TaskSetManager( if (TaskLocality.isAllowed(maxLocality, TaskLocality.RACK_LOCAL)) { for { rack <- sched.getRackForHost(host) - index <- dequeueTaskFromList(execId, getPendingTasksForRack(rack)) + index <- dequeueTaskFromList(execId, host, getPendingTasksForRack(rack)) } { return Some((index, TaskLocality.RACK_LOCAL, false)) } } if (TaskLocality.isAllowed(maxLocality, TaskLocality.ANY)) { - for (index <- dequeueTaskFromList(execId, allPendingTasks)) { + for (index <- dequeueTaskFromList(execId, host, allPendingTasks)) { return Some((index, 
TaskLocality.ANY, false)) } } @@ -421,7 +412,11 @@ private[spark] class TaskSetManager( maxLocality: TaskLocality.TaskLocality) : Option[TaskDescription] = { - if (!isZombie) { + val offerBlacklisted = taskSetBlacklistHelperOpt.exists { blacklist => + blacklist.isNodeBlacklistedForTaskSet(host) || + blacklist.isExecutorBlacklistedForTaskSet(execId) + } + if (!isZombie && !offerBlacklisted) { val curTime = clock.getTimeMillis() var allowedLocality = maxLocality @@ -434,60 +429,59 @@ private[spark] class TaskSetManager( } } - dequeueTask(execId, host, allowedLocality) match { - case Some((index, taskLocality, speculative)) => - // Found a task; do some bookkeeping and return a task description - val task = tasks(index) - val taskId = sched.newTaskId() - // Do various bookkeeping - copiesRunning(index) += 1 - val attemptNum = taskAttempts(index).size - val info = new TaskInfo(taskId, index, attemptNum, curTime, - execId, host, taskLocality, speculative) - taskInfos(taskId) = info - taskAttempts(index) = info :: taskAttempts(index) - // Update our locality level for delay scheduling - // NO_PREF will not affect the variables related to delay scheduling - if (maxLocality != TaskLocality.NO_PREF) { - currentLocalityIndex = getLocalityIndex(taskLocality) - lastLaunchTime = curTime - } - // Serialize and return the task - val startTime = clock.getTimeMillis() - val serializedTask: ByteBuffer = try { - Task.serializeWithDependencies(task, sched.sc.addedFiles, sched.sc.addedJars, ser) - } catch { - // If the task cannot be serialized, then there's no point to re-attempt the task, - // as it will always fail. So just abort the whole task-set. - case NonFatal(e) => - val msg = s"Failed to serialize task $taskId, not attempting to retry it." - logError(msg, e) - abort(s"$msg Exception during serialization: $e") - throw new TaskNotSerializableException(e) - } - if (serializedTask.limit > TaskSetManager.TASK_SIZE_TO_WARN_KB * 1024 && - !emittedTaskSizeWarning) { - emittedTaskSizeWarning = true - logWarning(s"Stage ${task.stageId} contains a task of very large size " + - s"(${serializedTask.limit / 1024} KB). The maximum recommended task size is " + - s"${TaskSetManager.TASK_SIZE_TO_WARN_KB} KB.") - } - addRunningTask(taskId) - - // We used to log the time it takes to serialize the task, but task size is already - // a good proxy to task serialization time. 
- // val timeTaken = clock.getTime() - startTime - val taskName = s"task ${info.id} in stage ${taskSet.id}" - logInfo(s"Starting $taskName (TID $taskId, $host, partition ${task.partitionId}," + - s" $taskLocality, ${serializedTask.limit} bytes)") - - sched.dagScheduler.taskStarted(task, info) - return Some(new TaskDescription(taskId = taskId, attemptNumber = attemptNum, execId, - taskName, index, serializedTask)) - case _ => + dequeueTask(execId, host, allowedLocality).map { case ((index, taskLocality, speculative)) => + // Found a task; do some bookkeeping and return a task description + val task = tasks(index) + val taskId = sched.newTaskId() + // Do various bookkeeping + copiesRunning(index) += 1 + val attemptNum = taskAttempts(index).size + val info = new TaskInfo(taskId, index, attemptNum, curTime, + execId, host, taskLocality, speculative) + taskInfos(taskId) = info + taskAttempts(index) = info :: taskAttempts(index) + // Update our locality level for delay scheduling + // NO_PREF will not affect the variables related to delay scheduling + if (maxLocality != TaskLocality.NO_PREF) { + currentLocalityIndex = getLocalityIndex(taskLocality) + lastLaunchTime = curTime + } + // Serialize and return the task + val startTime = clock.getTimeMillis() + val serializedTask: ByteBuffer = try { + Task.serializeWithDependencies(task, sched.sc.addedFiles, sched.sc.addedJars, ser) + } catch { + // If the task cannot be serialized, then there's no point to re-attempt the task, + // as it will always fail. So just abort the whole task-set. + case NonFatal(e) => + val msg = s"Failed to serialize task $taskId, not attempting to retry it." + logError(msg, e) + abort(s"$msg Exception during serialization: $e") + throw new TaskNotSerializableException(e) + } + if (serializedTask.limit > TaskSetManager.TASK_SIZE_TO_WARN_KB * 1024 && + !emittedTaskSizeWarning) { + emittedTaskSizeWarning = true + logWarning(s"Stage ${task.stageId} contains a task of very large size " + + s"(${serializedTask.limit / 1024} KB). The maximum recommended task size is " + + s"${TaskSetManager.TASK_SIZE_TO_WARN_KB} KB.") + } + addRunningTask(taskId) + + // We used to log the time it takes to serialize the task, but task size is already + // a good proxy to task serialization time. + // val timeTaken = clock.getTime() - startTime + val taskName = s"task ${info.id} in stage ${taskSet.id}" + logInfo(s"Starting $taskName (TID $taskId, $host, executor ${info.executorId}, " + + s"partition ${task.partitionId}, $taskLocality, ${serializedTask.limit} bytes)") + + sched.dagScheduler.taskStarted(task, info) + new TaskDescription(taskId = taskId, attemptNumber = attemptNum, execId, + taskName, index, serializedTask) } + } else { + None } - None } private def maybeFinishTaskSet() { @@ -589,37 +583,56 @@ private[spark] class TaskSetManager( * the hang as quickly as we could have, but we'll always detect the hang eventually, and the * method is faster in the typical case. In the worst case, this method can take * O(maxTaskFailures + numTasks) time, but it will be faster when there haven't been any task - * failures (this is because the method picks on unscheduled task, and then iterates through each - * executor until it finds one that the task hasn't failed on already). + * failures (this is because the method picks one unscheduled task, and then iterates through each + * executor until it finds one that the task isn't blacklisted on). 
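The heart of the new check below is a nested forall over hosts and their executors. A minimal standalone sketch of that shape, with the blacklist lookups abstracted into hypothetical predicates:

def blacklistedEverywhere(
    hostToExecutors: Map[String, Set[String]],
    nodeBlacklisted: String => Boolean,
    execBlacklisted: String => Boolean): Boolean = {
  hostToExecutors.forall { case (host, execsOnHost) =>
    // A host is unusable if the node itself is blacklisted, or if every executor on it is;
    // only when every host is unusable is the task unschedulable everywhere.
    nodeBlacklisted(host) || execsOnHost.forall(execBlacklisted)
  }
}

// blacklistedEverywhere(Map("h1" -> Set("e1", "e2")), _ => false, Set("e1"))        // false: e2 is usable
// blacklistedEverywhere(Map("h1" -> Set("e1", "e2")), _ => false, Set("e1", "e2"))  // true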
*/ - private[scheduler] def abortIfCompletelyBlacklisted(executors: Iterable[String]): Unit = { - - val pendingTask: Option[Int] = { - // usually this will just take the last pending task, but because of the lazy removal - // from each list, we may need to go deeper in the list. We poll from the end because - // failed tasks are put back at the end of allPendingTasks, so we're more likely to find - // an unschedulable task this way. - val indexOffset = allPendingTasks.lastIndexWhere { indexInTaskSet => - copiesRunning(indexInTaskSet) == 0 && !successful(indexInTaskSet) - } - if (indexOffset == -1) { - None - } else { - Some(allPendingTasks(indexOffset)) - } - } + private[scheduler] def abortIfCompletelyBlacklisted( + hostToExecutors: HashMap[String, HashSet[String]]): Unit = { + taskSetBlacklistHelperOpt.foreach { taskSetBlacklist => + // Only look for unschedulable tasks when at least one executor has registered. Otherwise, + // task sets will be (unnecessarily) aborted in cases when no executors have registered yet. + if (hostToExecutors.nonEmpty) { + // find any task that needs to be scheduled + val pendingTask: Option[Int] = { + // usually this will just take the last pending task, but because of the lazy removal + // from each list, we may need to go deeper in the list. We poll from the end because + // failed tasks are put back at the end of allPendingTasks, so we're more likely to find + // an unschedulable task this way. + val indexOffset = allPendingTasks.lastIndexWhere { indexInTaskSet => + copiesRunning(indexInTaskSet) == 0 && !successful(indexInTaskSet) + } + if (indexOffset == -1) { + None + } else { + Some(allPendingTasks(indexOffset)) + } + } - // If no executors have registered yet, don't abort the stage, just wait. We probably - // got here because a task set was added before the executors registered. - if (executors.nonEmpty) { - // take any task that needs to be scheduled, and see if we can find some executor it *could* - // run on - pendingTask.foreach { taskId => - if (executors.forall(executorIsBlacklisted(_, taskId))) { - val execs = executors.toIndexedSeq.sorted.mkString("(", ",", ")") - val partition = tasks(taskId).partitionId - abort(s"Aborting ${taskSet} because task $taskId (partition $partition)" + - s" has already failed on executors $execs, and no other executors are available.") + pendingTask.foreach { indexInTaskSet => + // try to find some executor this task can run on. Its possible that some *other* + // task isn't schedulable anywhere, but we will discover that in some later call, + // when that unschedulable task is the last task remaining. + val blacklistedEverywhere = hostToExecutors.forall { case (host, execsOnHost) => + // Check if the task can run on the node + val nodeBlacklisted = + taskSetBlacklist.isNodeBlacklistedForTaskSet(host) || + taskSetBlacklist.isNodeBlacklistedForTask(host, indexInTaskSet) + if (nodeBlacklisted) { + true + } else { + // Check if the task can run on any of the executors + execsOnHost.forall { exec => + taskSetBlacklist.isExecutorBlacklistedForTaskSet(exec) || + taskSetBlacklist.isExecutorBlacklistedForTask(exec, indexInTaskSet) + } + } + } + if (blacklistedEverywhere) { + val partition = tasks(indexInTaskSet).partitionId + abort(s"Aborting $taskSet because task $indexInTaskSet (partition $partition) " + + s"cannot run anywhere due to node and executor blacklist. 
Blacklisting behavior " + + s"can be configured via spark.blacklist.*.") + } } } } @@ -677,8 +690,9 @@ private[spark] class TaskSetManager( } if (!successful(index)) { tasksSuccessful += 1 - logInfo("Finished task %s in stage %s (TID %d) in %d ms on %s (%d/%d)".format( - info.id, taskSet.id, info.taskId, info.duration, info.host, tasksSuccessful, numTasks)) + logInfo(s"Finished task ${info.id} in stage ${taskSet.id} (TID ${info.taskId}) in" + + s" ${info.duration} ms on ${info.host} (executor ${info.executorId})" + + s" ($tasksSuccessful/$numTasks)") // Mark successful and stop if all the tasks have succeeded. successful(index) = true if (tasksSuccessful == numTasks) { @@ -688,7 +702,6 @@ private[spark] class TaskSetManager( logInfo("Ignoring task-finished event for " + info.id + " in stage " + taskSet.id + " because task " + index + " has already completed successfully") } - failedExecutors.remove(index) maybeFinishTaskSet() } @@ -706,8 +719,8 @@ private[spark] class TaskSetManager( val index = info.index copiesRunning(index) -= 1 var accumUpdates: Seq[AccumulatorV2[_, _]] = Seq.empty - val failureReason = s"Lost task ${info.id} in stage ${taskSet.id} (TID $tid, ${info.host}): " + - reason.toErrorString + val failureReason = s"Lost task ${info.id} in stage ${taskSet.id} (TID $tid, ${info.host}," + + s" executor ${info.executorId}): ${reason.toErrorString}" val failureException: Option[Throwable] = reason match { case fetchFailed: FetchFailed => logWarning(failureReason) @@ -715,7 +728,6 @@ private[spark] class TaskSetManager( successful(index) = true tasksSuccessful += 1 } - // Not adding to failed executors for FetchFailed. isZombie = true None @@ -751,8 +763,8 @@ private[spark] class TaskSetManager( logWarning(failureReason) } else { logInfo( - s"Lost task ${info.id} in stage ${taskSet.id} (TID $tid) on executor ${info.host}: " + - s"${ef.className} (${ef.description}) [duplicate $dupCount]") + s"Lost task ${info.id} in stage ${taskSet.id} (TID $tid) on ${info.host}, executor" + + s" ${info.executorId}: ${ef.className} (${ef.description}) [duplicate $dupCount]") } ef.exception @@ -766,9 +778,7 @@ private[spark] class TaskSetManager( logWarning(failureReason) None } - // always add to failed executors - failedExecutors.getOrElseUpdate(index, new HashMap[String, Long]()). 
- put(info.executorId, clock.getTimeMillis()) + sched.dagScheduler.taskEnded(tasks(index), reason, null, accumUpdates, info) if (successful(index)) { @@ -780,7 +790,9 @@ private[spark] class TaskSetManager( addPendingTask(index) } - if (!isZombie && state != TaskState.KILLED && reason.countTowardsTaskFailures) { + if (!isZombie && reason.countTowardsTaskFailures) { + taskSetBlacklistHelperOpt.foreach(_.updateBlacklistForFailedTask( + info.host, info.executorId, index)) assert (null != failureReason) numFailures(index) += 1 if (numFailures(index) >= maxTaskFailures) { diff --git a/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala b/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala index 24f3f757157f3..35c3c8d00f99b 100644 --- a/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala +++ b/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala @@ -27,12 +27,12 @@ import scala.xml.Node import org.eclipse.jetty.client.api.Response import org.eclipse.jetty.proxy.ProxyServlet -import org.eclipse.jetty.server.{Request, Server, ServerConnector} +import org.eclipse.jetty.server.{HttpConnectionFactory, Request, Server, ServerConnector} import org.eclipse.jetty.server.handler._ import org.eclipse.jetty.servlet._ import org.eclipse.jetty.servlets.gzip.GzipHandler import org.eclipse.jetty.util.component.LifeCycle -import org.eclipse.jetty.util.thread.QueuedThreadPool +import org.eclipse.jetty.util.thread.{QueuedThreadPool, ScheduledExecutorScheduler} import org.json4s.JValue import org.json4s.jackson.JsonMethods.{pretty, render} @@ -294,7 +294,15 @@ private[spark] object JettyUtils extends Logging { val server = new Server(pool) val connectors = new ArrayBuffer[ServerConnector] // Create a connector on port currentPort to listen for HTTP requests - val httpConnector = new ServerConnector(server) + val httpConnector = new ServerConnector( + server, + null, + // Call this full constructor to set this, which forces daemon threads: + new ScheduledExecutorScheduler(s"$serverName-JettyScheduler", true), + null, + -1, + -1, + new HttpConnectionFactory()) httpConnector.setPort(currentPort) connectors += httpConnector diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala index c04964ec66479..f6713097b9349 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala @@ -216,6 +216,7 @@ private[ui] class AllJobsPage(parent: JobsTab) extends WebUIPage("") { private def jobsTable( request: HttpServletRequest, + tableHeaderId: String, jobTag: String, jobs: Seq[JobUIData]): Seq[Node] = { val allParameters = request.getParameterMap.asScala.toMap @@ -256,6 +257,7 @@ private[ui] class AllJobsPage(parent: JobsTab) extends WebUIPage("") { try { new JobPagedTable( jobs, + tableHeaderId, jobTag, UIUtils.prependBaseUri(parent.basePath), "jobs", // subPath @@ -288,9 +290,9 @@ private[ui] class AllJobsPage(parent: JobsTab) extends WebUIPage("") { val completedJobs = listener.completedJobs.reverse.toSeq val failedJobs = listener.failedJobs.reverse.toSeq - val activeJobsTable = jobsTable(request, "activeJob", activeJobs) - val completedJobsTable = jobsTable(request, "completedJob", completedJobs) - val failedJobsTable = jobsTable(request, "failedJob", failedJobs) + val activeJobsTable = jobsTable(request, "active", "activeJob", activeJobs) + val completedJobsTable = jobsTable(request, "completed", "completedJob", completedJobs) + val 
failedJobsTable = jobsTable(request, "failed", "failedJob", failedJobs) val shouldShowActiveJobs = activeJobs.nonEmpty val shouldShowCompletedJobs = completedJobs.nonEmpty @@ -455,23 +457,11 @@ private[ui] class JobDataSource( * Return Ordering according to sortColumn and desc */ private def ordering(sortColumn: String, desc: Boolean): Ordering[JobTableRowData] = { - val ordering = sortColumn match { - case "Job Id" | "Job Id (Job Group)" => new Ordering[JobTableRowData] { - override def compare(x: JobTableRowData, y: JobTableRowData): Int = - Ordering.Int.compare(x.jobData.jobId, y.jobData.jobId) - } - case "Description" => new Ordering[JobTableRowData] { - override def compare(x: JobTableRowData, y: JobTableRowData): Int = - Ordering.String.compare(x.lastStageDescription, y.lastStageDescription) - } - case "Submitted" => new Ordering[JobTableRowData] { - override def compare(x: JobTableRowData, y: JobTableRowData): Int = - Ordering.Long.compare(x.submissionTime, y.submissionTime) - } - case "Duration" => new Ordering[JobTableRowData] { - override def compare(x: JobTableRowData, y: JobTableRowData): Int = - Ordering.Long.compare(x.duration, y.duration) - } + val ordering: Ordering[JobTableRowData] = sortColumn match { + case "Job Id" | "Job Id (Job Group)" => Ordering.by(_.jobData.jobId) + case "Description" => Ordering.by(x => (x.lastStageDescription, x.lastStageName)) + case "Submitted" => Ordering.by(_.submissionTime) + case "Duration" => Ordering.by(_.duration) case "Stages: Succeeded/Total" | "Tasks (for all stages): Succeeded/Total" => throw new IllegalArgumentException(s"Unsortable column: $sortColumn") case unknownColumn => throw new IllegalArgumentException(s"Unknown column: $unknownColumn") @@ -486,6 +476,7 @@ private[ui] class JobDataSource( } private[ui] class JobPagedTable( data: Seq[JobUIData], + tableHeaderId: String, jobTag: String, basePath: String, subPath: String, @@ -498,8 +489,7 @@ private[ui] class JobPagedTable( sortColumn: String, desc: Boolean ) extends PagedTable[JobTableRowData] { - val parameterPath = UIUtils.prependBaseUri(basePath) + s"/$subPath/?" + - parameterOtherTable.mkString("&") + val parameterPath = basePath + s"/$subPath/?" 
+ parameterOtherTable.mkString("&") override def tableId: String = jobTag + "-table" @@ -528,12 +518,13 @@ private[ui] class JobPagedTable( s"&$pageNumberFormField=$page" + s"&$jobTag.sort=$encodedSortColumn" + s"&$jobTag.desc=$desc" + - s"&$pageSizeFormField=$pageSize" + s"&$pageSizeFormField=$pageSize" + + s"#$tableHeaderId" } override def goButtonFormPath: String = { val encodedSortColumn = URLEncoder.encode(sortColumn, "UTF-8") - s"$parameterPath&$jobTag.sort=$encodedSortColumn&$jobTag.desc=$desc" + s"$parameterPath&$jobTag.sort=$encodedSortColumn&$jobTag.desc=$desc#$tableHeaderId" } override def headers: Seq[Node] = { @@ -557,7 +548,8 @@ private[ui] class JobPagedTable( parameterPath + s"&$jobTag.sort=${URLEncoder.encode(header, "UTF-8")}" + s"&$jobTag.desc=${!desc}" + - s"&$jobTag.pageSize=$pageSize") + s"&$jobTag.pageSize=$pageSize" + + s"#$tableHeaderId") val arrow = if (desc) "▾" else "▴" // UP or DOWN @@ -572,7 +564,8 @@ private[ui] class JobPagedTable( val headerLink = Unparsed( parameterPath + s"&$jobTag.sort=${URLEncoder.encode(header, "UTF-8")}" + - s"&$jobTag.pageSize=$pageSize") + s"&$jobTag.pageSize=$pageSize" + + s"#$tableHeaderId") diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/AllStagesPage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/AllStagesPage.scala index cba8f82dd77a6..fe6ca1099e6b0 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/AllStagesPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/AllStagesPage.scala @@ -41,19 +41,19 @@ private[ui] class AllStagesPage(parent: StagesTab) extends WebUIPage("") { val subPath = "stages" val activeStagesTable = - new StageTableBase(request, activeStages, "activeStage", parent.basePath, subPath, + new StageTableBase(request, activeStages, "active", "activeStage", parent.basePath, subPath, parent.progressListener, parent.isFairScheduler, killEnabled = parent.killEnabled, isFailedStage = false) val pendingStagesTable = - new StageTableBase(request, pendingStages, "pendingStage", parent.basePath, subPath, - parent.progressListener, parent.isFairScheduler, + new StageTableBase(request, pendingStages, "pending", "pendingStage", parent.basePath, + subPath, parent.progressListener, parent.isFairScheduler, killEnabled = false, isFailedStage = false) val completedStagesTable = - new StageTableBase(request, completedStages, "completedStage", parent.basePath, subPath, - parent.progressListener, parent.isFairScheduler, + new StageTableBase(request, completedStages, "completed", "completedStage", parent.basePath, + subPath, parent.progressListener, parent.isFairScheduler, killEnabled = false, isFailedStage = false) val failedStagesTable = - new StageTableBase(request, failedStages, "failedStage", parent.basePath, subPath, + new StageTableBase(request, failedStages, "failed", "failedStage", parent.basePath, subPath, parent.progressListener, parent.isFairScheduler, killEnabled = false, isFailedStage = true) diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/JobPage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/JobPage.scala index 2f7f8976a8899..0ff9e5e9411ca 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/JobPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/JobPage.scala @@ -230,20 +230,27 @@ private[ui] class JobPage(parent: JobsTab) extends WebUIPage("job") { val basePath = "jobs/job" + val pendingOrSkippedTableId = + if (isComplete) { + "pending" + } else { + "skipped" + } + val activeStagesTable = - new StageTableBase(request, activeStages, "activeStage", 
parent.basePath, + new StageTableBase(request, activeStages, "active", "activeStage", parent.basePath, basePath, parent.jobProgresslistener, parent.isFairScheduler, killEnabled = parent.killEnabled, isFailedStage = false) val pendingOrSkippedStagesTable = - new StageTableBase(request, pendingOrSkippedStages, "pendingStage", parent.basePath, - basePath, parent.jobProgresslistener, parent.isFairScheduler, + new StageTableBase(request, pendingOrSkippedStages, pendingOrSkippedTableId, "pendingStage", + parent.basePath, basePath, parent.jobProgresslistener, parent.isFairScheduler, killEnabled = false, isFailedStage = false) val completedStagesTable = - new StageTableBase(request, completedStages, "completedStage", parent.basePath, + new StageTableBase(request, completedStages, "completed", "completedStage", parent.basePath, basePath, parent.jobProgresslistener, parent.isFairScheduler, killEnabled = false, isFailedStage = false) val failedStagesTable = - new StageTableBase(request, failedStages, "failedStage", parent.basePath, + new StageTableBase(request, failedStages, "failed", "failedStage", parent.basePath, basePath, parent.jobProgresslistener, parent.isFairScheduler, killEnabled = false, isFailedStage = true) diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/PoolPage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/PoolPage.scala index f9cb717918592..8ee70d27cc09f 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/PoolPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/PoolPage.scala @@ -44,7 +44,7 @@ private[ui] class PoolPage(parent: StagesTab) extends WebUIPage("pool") { } val shouldShowActiveStages = activeStages.nonEmpty val activeStagesTable = - new StageTableBase(request, activeStages, "activeStage", parent.basePath, "stages/pool", + new StageTableBase(request, activeStages, "", "activeStage", parent.basePath, "stages/pool", parent.progressListener, parent.isFairScheduler, parent.killEnabled, isFailedStage = false) diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala index c322ae0972ad7..8c7cefe200739 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala @@ -1050,89 +1050,38 @@ private[ui] class TaskDataSource( * Return Ordering according to sortColumn and desc */ private def ordering(sortColumn: String, desc: Boolean): Ordering[TaskTableRowData] = { - val ordering = sortColumn match { - case "Index" => new Ordering[TaskTableRowData] { - override def compare(x: TaskTableRowData, y: TaskTableRowData): Int = - Ordering.Int.compare(x.index, y.index) - } - case "ID" => new Ordering[TaskTableRowData] { - override def compare(x: TaskTableRowData, y: TaskTableRowData): Int = - Ordering.Long.compare(x.taskId, y.taskId) - } - case "Attempt" => new Ordering[TaskTableRowData] { - override def compare(x: TaskTableRowData, y: TaskTableRowData): Int = - Ordering.Int.compare(x.attempt, y.attempt) - } - case "Status" => new Ordering[TaskTableRowData] { - override def compare(x: TaskTableRowData, y: TaskTableRowData): Int = - Ordering.String.compare(x.status, y.status) - } - case "Locality Level" => new Ordering[TaskTableRowData] { - override def compare(x: TaskTableRowData, y: TaskTableRowData): Int = - Ordering.String.compare(x.taskLocality, y.taskLocality) - } - case "Executor ID / Host" => new Ordering[TaskTableRowData] { - override def compare(x: TaskTableRowData, y: TaskTableRowData): Int = - 
Ordering.String.compare(x.executorIdAndHost, y.executorIdAndHost) - } - case "Launch Time" => new Ordering[TaskTableRowData] { - override def compare(x: TaskTableRowData, y: TaskTableRowData): Int = - Ordering.Long.compare(x.launchTime, y.launchTime) - } - case "Duration" => new Ordering[TaskTableRowData] { - override def compare(x: TaskTableRowData, y: TaskTableRowData): Int = - Ordering.Long.compare(x.duration, y.duration) - } - case "Scheduler Delay" => new Ordering[TaskTableRowData] { - override def compare(x: TaskTableRowData, y: TaskTableRowData): Int = - Ordering.Long.compare(x.schedulerDelay, y.schedulerDelay) - } - case "Task Deserialization Time" => new Ordering[TaskTableRowData] { - override def compare(x: TaskTableRowData, y: TaskTableRowData): Int = - Ordering.Long.compare(x.taskDeserializationTime, y.taskDeserializationTime) - } - case "GC Time" => new Ordering[TaskTableRowData] { - override def compare(x: TaskTableRowData, y: TaskTableRowData): Int = - Ordering.Long.compare(x.gcTime, y.gcTime) - } - case "Result Serialization Time" => new Ordering[TaskTableRowData] { - override def compare(x: TaskTableRowData, y: TaskTableRowData): Int = - Ordering.Long.compare(x.serializationTime, y.serializationTime) - } - case "Getting Result Time" => new Ordering[TaskTableRowData] { - override def compare(x: TaskTableRowData, y: TaskTableRowData): Int = - Ordering.Long.compare(x.gettingResultTime, y.gettingResultTime) - } - case "Peak Execution Memory" => new Ordering[TaskTableRowData] { - override def compare(x: TaskTableRowData, y: TaskTableRowData): Int = - Ordering.Long.compare(x.peakExecutionMemoryUsed, y.peakExecutionMemoryUsed) - } + val ordering: Ordering[TaskTableRowData] = sortColumn match { + case "Index" => Ordering.by(_.index) + case "ID" => Ordering.by(_.taskId) + case "Attempt" => Ordering.by(_.attempt) + case "Status" => Ordering.by(_.status) + case "Locality Level" => Ordering.by(_.taskLocality) + case "Executor ID / Host" => Ordering.by(_.executorIdAndHost) + case "Launch Time" => Ordering.by(_.launchTime) + case "Duration" => Ordering.by(_.duration) + case "Scheduler Delay" => Ordering.by(_.schedulerDelay) + case "Task Deserialization Time" => Ordering.by(_.taskDeserializationTime) + case "GC Time" => Ordering.by(_.gcTime) + case "Result Serialization Time" => Ordering.by(_.serializationTime) + case "Getting Result Time" => Ordering.by(_.gettingResultTime) + case "Peak Execution Memory" => Ordering.by(_.peakExecutionMemoryUsed) case "Accumulators" => if (hasAccumulators) { - new Ordering[TaskTableRowData] { - override def compare(x: TaskTableRowData, y: TaskTableRowData): Int = - Ordering.String.compare(x.accumulators.get, y.accumulators.get) - } + Ordering.by(_.accumulators.get) } else { throw new IllegalArgumentException( "Cannot sort by Accumulators because of no accumulators") } case "Input Size / Records" => if (hasInput) { - new Ordering[TaskTableRowData] { - override def compare(x: TaskTableRowData, y: TaskTableRowData): Int = - Ordering.Long.compare(x.input.get.inputSortable, y.input.get.inputSortable) - } + Ordering.by(_.input.get.inputSortable) } else { throw new IllegalArgumentException( "Cannot sort by Input Size / Records because of no inputs") } case "Output Size / Records" => if (hasOutput) { - new Ordering[TaskTableRowData] { - override def compare(x: TaskTableRowData, y: TaskTableRowData): Int = - Ordering.Long.compare(x.output.get.outputSortable, y.output.get.outputSortable) - } + Ordering.by(_.output.get.outputSortable) } else { throw new 
IllegalArgumentException( "Cannot sort by Output Size / Records because of no outputs") @@ -1140,33 +1089,21 @@ private[ui] class TaskDataSource( // ShuffleRead case "Shuffle Read Blocked Time" => if (hasShuffleRead) { - new Ordering[TaskTableRowData] { - override def compare(x: TaskTableRowData, y: TaskTableRowData): Int = - Ordering.Long.compare(x.shuffleRead.get.shuffleReadBlockedTimeSortable, - y.shuffleRead.get.shuffleReadBlockedTimeSortable) - } + Ordering.by(_.shuffleRead.get.shuffleReadBlockedTimeSortable) } else { throw new IllegalArgumentException( "Cannot sort by Shuffle Read Blocked Time because of no shuffle reads") } case "Shuffle Read Size / Records" => if (hasShuffleRead) { - new Ordering[TaskTableRowData] { - override def compare(x: TaskTableRowData, y: TaskTableRowData): Int = - Ordering.Long.compare(x.shuffleRead.get.shuffleReadSortable, - y.shuffleRead.get.shuffleReadSortable) - } + Ordering.by(_.shuffleRead.get.shuffleReadSortable) } else { throw new IllegalArgumentException( "Cannot sort by Shuffle Read Size / Records because of no shuffle reads") } case "Shuffle Remote Reads" => if (hasShuffleRead) { - new Ordering[TaskTableRowData] { - override def compare(x: TaskTableRowData, y: TaskTableRowData): Int = - Ordering.Long.compare(x.shuffleRead.get.shuffleReadRemoteSortable, - y.shuffleRead.get.shuffleReadRemoteSortable) - } + Ordering.by(_.shuffleRead.get.shuffleReadRemoteSortable) } else { throw new IllegalArgumentException( "Cannot sort by Shuffle Remote Reads because of no shuffle reads") @@ -1174,22 +1111,14 @@ private[ui] class TaskDataSource( // ShuffleWrite case "Write Time" => if (hasShuffleWrite) { - new Ordering[TaskTableRowData] { - override def compare(x: TaskTableRowData, y: TaskTableRowData): Int = - Ordering.Long.compare(x.shuffleWrite.get.writeTimeSortable, - y.shuffleWrite.get.writeTimeSortable) - } + Ordering.by(_.shuffleWrite.get.writeTimeSortable) } else { throw new IllegalArgumentException( "Cannot sort by Write Time because of no shuffle writes") } case "Shuffle Write Size / Records" => if (hasShuffleWrite) { - new Ordering[TaskTableRowData] { - override def compare(x: TaskTableRowData, y: TaskTableRowData): Int = - Ordering.Long.compare(x.shuffleWrite.get.shuffleWriteSortable, - y.shuffleWrite.get.shuffleWriteSortable) - } + Ordering.by(_.shuffleWrite.get.shuffleWriteSortable) } else { throw new IllegalArgumentException( "Cannot sort by Shuffle Write Size / Records because of no shuffle writes") @@ -1197,30 +1126,19 @@ private[ui] class TaskDataSource( // BytesSpilled case "Shuffle Spill (Memory)" => if (hasBytesSpilled) { - new Ordering[TaskTableRowData] { - override def compare(x: TaskTableRowData, y: TaskTableRowData): Int = - Ordering.Long.compare(x.bytesSpilled.get.memoryBytesSpilledSortable, - y.bytesSpilled.get.memoryBytesSpilledSortable) - } + Ordering.by(_.bytesSpilled.get.memoryBytesSpilledSortable) } else { throw new IllegalArgumentException( "Cannot sort by Shuffle Spill (Memory) because of no spills") } case "Shuffle Spill (Disk)" => if (hasBytesSpilled) { - new Ordering[TaskTableRowData] { - override def compare(x: TaskTableRowData, y: TaskTableRowData): Int = - Ordering.Long.compare(x.bytesSpilled.get.diskBytesSpilledSortable, - y.bytesSpilled.get.diskBytesSpilledSortable) - } + Ordering.by(_.bytesSpilled.get.diskBytesSpilledSortable) } else { throw new IllegalArgumentException( "Cannot sort by Shuffle Spill (Disk) because of no spills") } - case "Errors" => new Ordering[TaskTableRowData] { - override def compare(x: 
TaskTableRowData, y: TaskTableRowData): Int = - Ordering.String.compare(x.error, y.error) - } + case "Errors" => Ordering.by(_.error) case unknownColumn => throw new IllegalArgumentException(s"Unknown column: $unknownColumn") } if (desc) { diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala index 2a04e8fc7d007..9b9b4681ba5db 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala @@ -34,6 +34,7 @@ import org.apache.spark.util.Utils private[ui] class StageTableBase( request: HttpServletRequest, stages: Seq[StageInfo], + tableHeaderID: String, stageTag: String, basePath: String, subPath: String, @@ -77,6 +78,7 @@ private[ui] class StageTableBase( val toNodeSeq = try { new StagePagedTable( stages, + tableHeaderID, stageTag, basePath, subPath, @@ -107,7 +109,6 @@ private[ui] class StageTableRowData( val stageId: Int, val attemptId: Int, val schedulingPool: String, - val description: String, val descriptionOption: Option[String], val submissionTime: Long, val formattedSubmissionTime: String, @@ -126,11 +127,12 @@ private[ui] class MissingStageTableRowData( stageInfo: StageInfo, stageId: Int, attemptId: Int) extends StageTableRowData( - stageInfo, None, stageId, attemptId, "", "", None, 0, "", -1, "", 0, "", 0, "", 0, "", 0, "") + stageInfo, None, stageId, attemptId, "", None, 0, "", -1, "", 0, "", 0, "", 0, "", 0, "") /** Page showing list of all ongoing and recently finished stages */ private[ui] class StagePagedTable( stages: Seq[StageInfo], + tableHeaderId: String, stageTag: String, basePath: String, subPath: String, @@ -173,12 +175,13 @@ private[ui] class StagePagedTable( s"&$pageNumberFormField=$page" + s"&$stageTag.sort=$encodedSortColumn" + s"&$stageTag.desc=$desc" + - s"&$pageSizeFormField=$pageSize" + s"&$pageSizeFormField=$pageSize" + + s"#$tableHeaderId" } override def goButtonFormPath: String = { val encodedSortColumn = URLEncoder.encode(sortColumn, "UTF-8") - s"$parameterPath&$stageTag.sort=$encodedSortColumn&$stageTag.desc=$desc" + s"$parameterPath&$stageTag.sort=$encodedSortColumn&$stageTag.desc=$desc#$tableHeaderId" } override def headers: Seq[Node] = { @@ -226,7 +229,8 @@ private[ui] class StagePagedTable( parameterPath + s"&$stageTag.sort=${URLEncoder.encode(header, "UTF-8")}" + s"&$stageTag.desc=${!desc}" + - s"&$stageTag.pageSize=$pageSize") + s"&$stageTag.pageSize=$pageSize") + + s"#$tableHeaderId" val arrow = if (desc) "▾" else "▴" // UP or DOWN @@ -241,7 +245,8 @@ private[ui] class StagePagedTable( val headerLink = Unparsed( parameterPath + s"&$stageTag.sort=${URLEncoder.encode(header, "UTF-8")}" + - s"&$stageTag.pageSize=$pageSize") + s"&$stageTag.pageSize=$pageSize") + + s"#$tableHeaderId" @@ -464,7 +469,6 @@ private[ui] class StageDataSource( s.stageId, s.attemptId, stageData.schedulingPool, - description.getOrElse(""), description, s.submissionTime.getOrElse(0), formattedSubmissionTime, @@ -485,43 +489,16 @@ private[ui] class StageDataSource( * Return Ordering according to sortColumn and desc */ private def ordering(sortColumn: String, desc: Boolean): Ordering[StageTableRowData] = { - val ordering = sortColumn match { - case "Stage Id" => new Ordering[StageTableRowData] { - override def compare(x: StageTableRowData, y: StageTableRowData): Int = - Ordering.Int.compare(x.stageId, y.stageId) - } - case "Pool Name" => new Ordering[StageTableRowData] { - override def compare(x: StageTableRowData, y: 
StageTableRowData): Int = - Ordering.String.compare(x.schedulingPool, y.schedulingPool) - } - case "Description" => new Ordering[StageTableRowData] { - override def compare(x: StageTableRowData, y: StageTableRowData): Int = - Ordering.String.compare(x.description, y.description) - } - case "Submitted" => new Ordering[StageTableRowData] { - override def compare(x: StageTableRowData, y: StageTableRowData): Int = - Ordering.Long.compare(x.submissionTime, y.submissionTime) - } - case "Duration" => new Ordering[StageTableRowData] { - override def compare(x: StageTableRowData, y: StageTableRowData): Int = - Ordering.Long.compare(x.duration, y.duration) - } - case "Input" => new Ordering[StageTableRowData] { - override def compare(x: StageTableRowData, y: StageTableRowData): Int = - Ordering.Long.compare(x.inputRead, y.inputRead) - } - case "Output" => new Ordering[StageTableRowData] { - override def compare(x: StageTableRowData, y: StageTableRowData): Int = - Ordering.Long.compare(x.outputWrite, y.outputWrite) - } - case "Shuffle Read" => new Ordering[StageTableRowData] { - override def compare(x: StageTableRowData, y: StageTableRowData): Int = - Ordering.Long.compare(x.shuffleRead, y.shuffleRead) - } - case "Shuffle Write" => new Ordering[StageTableRowData] { - override def compare(x: StageTableRowData, y: StageTableRowData): Int = - Ordering.Long.compare(x.shuffleWrite, y.shuffleWrite) - } + val ordering: Ordering[StageTableRowData] = sortColumn match { + case "Stage Id" => Ordering.by(_.stageId) + case "Pool Name" => Ordering.by(_.schedulingPool) + case "Description" => Ordering.by(x => (x.descriptionOption, x.stageInfo.name)) + case "Submitted" => Ordering.by(_.submissionTime) + case "Duration" => Ordering.by(_.duration) + case "Input" => Ordering.by(_.inputRead) + case "Output" => Ordering.by(_.outputWrite) + case "Shuffle Read" => Ordering.by(_.shuffleRead) + case "Shuffle Write" => Ordering.by(_.shuffleWrite) case "Tasks: Succeeded/Total" => throw new IllegalArgumentException(s"Unsortable column: $sortColumn") case unknownColumn => throw new IllegalArgumentException(s"Unknown column: $unknownColumn") diff --git a/core/src/main/scala/org/apache/spark/ui/storage/RDDPage.scala b/core/src/main/scala/org/apache/spark/ui/storage/RDDPage.scala index 606d15d599e81..227e940c9c50c 100644 --- a/core/src/main/scala/org/apache/spark/ui/storage/RDDPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/storage/RDDPage.scala @@ -197,27 +197,12 @@ private[ui] class BlockDataSource( * Return Ordering according to sortColumn and desc */ private def ordering(sortColumn: String, desc: Boolean): Ordering[BlockTableRowData] = { - val ordering = sortColumn match { - case "Block Name" => new Ordering[BlockTableRowData] { - override def compare(x: BlockTableRowData, y: BlockTableRowData): Int = - Ordering.String.compare(x.blockName, y.blockName) - } - case "Storage Level" => new Ordering[BlockTableRowData] { - override def compare(x: BlockTableRowData, y: BlockTableRowData): Int = - Ordering.String.compare(x.storageLevel, y.storageLevel) - } - case "Size in Memory" => new Ordering[BlockTableRowData] { - override def compare(x: BlockTableRowData, y: BlockTableRowData): Int = - Ordering.Long.compare(x.memoryUsed, y.memoryUsed) - } - case "Size on Disk" => new Ordering[BlockTableRowData] { - override def compare(x: BlockTableRowData, y: BlockTableRowData): Int = - Ordering.Long.compare(x.diskUsed, y.diskUsed) - } - case "Executors" => new Ordering[BlockTableRowData] { - override def compare(x: BlockTableRowData, y: 
BlockTableRowData): Int = - Ordering.String.compare(x.executors, y.executors) - } + val ordering: Ordering[BlockTableRowData] = sortColumn match { + case "Block Name" => Ordering.by(_.blockName) + case "Storage Level" => Ordering.by(_.storageLevel) + case "Size in Memory" => Ordering.by(_.memoryUsed) + case "Size on Disk" => Ordering.by(_.diskUsed) + case "Executors" => Ordering.by(_.executors) case unknownColumn => throw new IllegalArgumentException(s"Unknown column: $unknownColumn") } if (desc) { diff --git a/core/src/main/scala/org/apache/spark/util/AccumulatorV2.scala b/core/src/main/scala/org/apache/spark/util/AccumulatorV2.scala index 470d912ecff13..d3ddd39131326 100644 --- a/core/src/main/scala/org/apache/spark/util/AccumulatorV2.scala +++ b/core/src/main/scala/org/apache/spark/util/AccumulatorV2.scala @@ -444,7 +444,9 @@ class CollectionAccumulator[T] extends AccumulatorV2[T, java.util.List[T]] { override def copy(): CollectionAccumulator[T] = { val newAcc = new CollectionAccumulator[T] - newAcc._list.addAll(_list) + _list.synchronized { + newAcc._list.addAll(_list) + } newAcc } diff --git a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala index f4fa7b4061640..c11eb3ffa4601 100644 --- a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala +++ b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala @@ -281,7 +281,7 @@ private[spark] object JsonProtocol { ("Finish Time" -> taskInfo.finishTime) ~ ("Failed" -> taskInfo.failed) ~ ("Killed" -> taskInfo.killed) ~ - ("Accumulables" -> JArray(taskInfo.accumulables.map(accumulableInfoToJson).toList)) + ("Accumulables" -> JArray(taskInfo.accumulables.toList.map(accumulableInfoToJson))) } def accumulableInfoToJson(accumulableInfo: AccumulableInfo): JValue = { diff --git a/core/src/test/scala/org/apache/spark/FileSuite.scala b/core/src/test/scala/org/apache/spark/FileSuite.scala index 993834f8d7d42..cc52bb1d23cd5 100644 --- a/core/src/test/scala/org/apache/spark/FileSuite.scala +++ b/core/src/test/scala/org/apache/spark/FileSuite.scala @@ -17,7 +17,8 @@ package org.apache.spark -import java.io.{File, FileWriter} +import java.io._ +import java.util.zip.GZIPOutputStream import scala.io.Source @@ -29,6 +30,7 @@ import org.apache.hadoop.mapreduce.lib.input.{FileSplit => NewFileSplit, TextInp import org.apache.hadoop.mapreduce.lib.output.{TextOutputFormat => NewTextOutputFormat} import org.apache.spark.input.PortableDataStream +import org.apache.spark.internal.config.IGNORE_CORRUPT_FILES import org.apache.spark.rdd.{HadoopRDD, NewHadoopRDD} import org.apache.spark.storage.StorageLevel import org.apache.spark.util.Utils @@ -541,4 +543,62 @@ class FileSuite extends SparkFunSuite with LocalSparkContext { }.collect() assert(inputPaths.toSet === Set(s"$outDir/part-00000", s"$outDir/part-00001")) } + + test("spark.files.ignoreCorruptFiles should work both HadoopRDD and NewHadoopRDD") { + val inputFile = File.createTempFile("input-", ".gz") + try { + // Create a corrupt gzip file + val byteOutput = new ByteArrayOutputStream() + val gzip = new GZIPOutputStream(byteOutput) + try { + gzip.write(Array[Byte](1, 2, 3, 4)) + } finally { + gzip.close() + } + val bytes = byteOutput.toByteArray + val o = new FileOutputStream(inputFile) + try { + // It's corrupt since we only write half of bytes into the file. 
+ o.write(bytes.take(bytes.length / 2)) + } finally { + o.close() + } + + // Reading a corrupt gzip file should throw EOFException + sc = new SparkContext("local", "test") + // Test HadoopRDD + var e = intercept[SparkException] { + sc.textFile(inputFile.toURI.toString).collect() + } + assert(e.getCause.isInstanceOf[EOFException]) + assert(e.getCause.getMessage === "Unexpected end of input stream") + // Test NewHadoopRDD + e = intercept[SparkException] { + sc.newAPIHadoopFile( + inputFile.toURI.toString, + classOf[NewTextInputFormat], + classOf[LongWritable], + classOf[Text]).collect() + } + assert(e.getCause.isInstanceOf[EOFException]) + assert(e.getCause.getMessage === "Unexpected end of input stream") + sc.stop() + + val conf = new SparkConf().set(IGNORE_CORRUPT_FILES, true) + sc = new SparkContext("local", "test", conf) + // Test HadoopRDD + assert(sc.textFile(inputFile.toURI.toString).collect().isEmpty) + // Test NewHadoopRDD + assert { + sc.newAPIHadoopFile( + inputFile.toURI.toString, + classOf[NewTextInputFormat], + classOf[LongWritable], + classOf[Text]).collect().isEmpty + } + } finally { + inputFile.delete() + } + } + } diff --git a/core/src/test/scala/org/apache/spark/SparkFunSuite.scala b/core/src/test/scala/org/apache/spark/SparkFunSuite.scala index cd876807f890e..18077c08c9dcc 100644 --- a/core/src/test/scala/org/apache/spark/SparkFunSuite.scala +++ b/core/src/test/scala/org/apache/spark/SparkFunSuite.scala @@ -18,6 +18,8 @@ package org.apache.spark // scalastyle:off +import java.io.File + import org.scalatest.{BeforeAndAfterAll, FunSuite, Outcome} import org.apache.spark.internal.Logging @@ -41,6 +43,15 @@ abstract class SparkFunSuite } } + // helper function + protected final def getTestResourceFile(file: String): File = { + new File(getClass.getClassLoader.getResource(file).getFile) + } + + protected final def getTestResourcePath(file: String): String = { + getTestResourceFile(file).getCanonicalPath + } + /** * Log the suite name and the test name before and after each test. 
* diff --git a/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala b/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala index 5b316b2f6b4b7..a595bc174a310 100644 --- a/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala @@ -59,8 +59,8 @@ class HistoryServerSuite extends SparkFunSuite with BeforeAndAfter with Matchers with JsonTestUtils with Eventually with WebBrowser with LocalSparkContext with ResetSystemProperties { - private val logDir = new File("src/test/resources/spark-events") - private val expRoot = new File("src/test/resources/HistoryServerExpectations/") + private val logDir = getTestResourcePath("spark-events") + private val expRoot = getTestResourceFile("HistoryServerExpectations") private var provider: FsHistoryProvider = null private var server: HistoryServer = null @@ -68,7 +68,7 @@ class HistoryServerSuite extends SparkFunSuite with BeforeAndAfter with Matchers def init(): Unit = { val conf = new SparkConf() - .set("spark.history.fs.logDirectory", logDir.getAbsolutePath) + .set("spark.history.fs.logDirectory", logDir) .set("spark.history.fs.update.interval", "0") .set("spark.testing", "true") provider = new FsHistoryProvider(conf) diff --git a/core/src/test/scala/org/apache/spark/partial/CountEvaluatorSuite.scala b/core/src/test/scala/org/apache/spark/partial/CountEvaluatorSuite.scala new file mode 100644 index 0000000000000..da3256bd882e8 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/partial/CountEvaluatorSuite.scala @@ -0,0 +1,43 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.partial + +import org.apache.spark.SparkFunSuite + +class CountEvaluatorSuite extends SparkFunSuite { + + test("test count 0") { + val evaluator = new CountEvaluator(10, 0.95) + assert(new BoundedDouble(0.0, 0.0, 0.0, Double.PositiveInfinity) == evaluator.currentResult()) + evaluator.merge(1, 0) + assert(new BoundedDouble(0.0, 0.0, 0.0, Double.PositiveInfinity) == evaluator.currentResult()) + } + + test("test count >= 1") { + val evaluator = new CountEvaluator(10, 0.95) + evaluator.merge(1, 1) + assert(new BoundedDouble(10.0, 0.95, 1.0, 36.0) == evaluator.currentResult()) + evaluator.merge(1, 3) + assert(new BoundedDouble(20.0, 0.95, 7.0, 41.0) == evaluator.currentResult()) + evaluator.merge(1, 8) + assert(new BoundedDouble(40.0, 0.95, 24.0, 61.0) == evaluator.currentResult()) + (4 to 10).foreach(_ => evaluator.merge(1, 10)) + assert(new BoundedDouble(82.0, 1.0, 82.0, 82.0) == evaluator.currentResult()) + } + +} diff --git a/core/src/test/scala/org/apache/spark/partial/MeanEvaluatorSuite.scala b/core/src/test/scala/org/apache/spark/partial/MeanEvaluatorSuite.scala new file mode 100644 index 0000000000000..eaa1262b4199f --- /dev/null +++ b/core/src/test/scala/org/apache/spark/partial/MeanEvaluatorSuite.scala @@ -0,0 +1,57 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.partial + +import org.apache.spark.SparkFunSuite +import org.apache.spark.util.StatCounter + +class MeanEvaluatorSuite extends SparkFunSuite { + + test("test count 0") { + val evaluator = new MeanEvaluator(10, 0.95) + assert(new BoundedDouble(0.0, 0.0, Double.NegativeInfinity, Double.PositiveInfinity) == + evaluator.currentResult()) + evaluator.merge(1, new StatCounter()) + assert(new BoundedDouble(0.0, 0.0, Double.NegativeInfinity, Double.PositiveInfinity) == + evaluator.currentResult()) + evaluator.merge(1, new StatCounter(Seq(0.0))) + assert(new BoundedDouble(0.0, 0.95, Double.NegativeInfinity, Double.PositiveInfinity) == + evaluator.currentResult()) + } + + test("test count 1") { + val evaluator = new MeanEvaluator(10, 0.95) + evaluator.merge(1, new StatCounter(Seq(1.0))) + assert(new BoundedDouble(1.0, 0.95, Double.NegativeInfinity, Double.PositiveInfinity) == + evaluator.currentResult()) + } + + test("test count > 1") { + val evaluator = new MeanEvaluator(10, 0.95) + evaluator.merge(1, new StatCounter(Seq(1.0))) + evaluator.merge(1, new StatCounter(Seq(3.0))) + assert(new BoundedDouble(2.0, 0.95, -10.706204736174746, 14.706204736174746) == + evaluator.currentResult()) + evaluator.merge(1, new StatCounter(Seq(8.0))) + assert(new BoundedDouble(4.0, 0.95, -4.9566858949231225, 12.956685894923123) == + evaluator.currentResult()) + (4 to 10).foreach(_ => evaluator.merge(1, new StatCounter(Seq(9.0)))) + assert(new BoundedDouble(7.5, 1.0, 7.5, 7.5) == evaluator.currentResult()) + } + +} diff --git a/core/src/test/scala/org/apache/spark/partial/SumEvaluatorSuite.scala b/core/src/test/scala/org/apache/spark/partial/SumEvaluatorSuite.scala index a79f5b4d74467..e212db73627e7 100644 --- a/core/src/test/scala/org/apache/spark/partial/SumEvaluatorSuite.scala +++ b/core/src/test/scala/org/apache/spark/partial/SumEvaluatorSuite.scala @@ -17,61 +17,34 @@ package org.apache.spark.partial -import org.apache.spark._ +import org.apache.spark.SparkFunSuite import org.apache.spark.util.StatCounter -class SumEvaluatorSuite extends SparkFunSuite with SharedSparkContext { +class SumEvaluatorSuite extends SparkFunSuite { test("correct handling of count 1") { + // sanity check: + assert(new BoundedDouble(2.0, 0.95, 1.1, 1.2) == new BoundedDouble(2.0, 0.95, 1.1, 1.2)) - // setup - val counter = new StatCounter(List(2.0)) // count of 10 because it's larger than 1, // and 0.95 because that's the default val evaluator = new SumEvaluator(10, 0.95) // arbitrarily assign id 1 - evaluator.merge(1, counter) - - // execute - val res = evaluator.currentResult() - // 38.0 - 7.1E-15 because that's how the maths shakes out - val targetMean = 38.0 - 7.1E-15 - - // Sanity check that equality works on BoundedDouble - assert(new BoundedDouble(2.0, 0.95, 1.1, 1.2) == new BoundedDouble(2.0, 0.95, 1.1, 1.2)) - - // actual test - assert(res == - new BoundedDouble(targetMean, 0.950, Double.NegativeInfinity, Double.PositiveInfinity)) + evaluator.merge(1, new StatCounter(Seq(2.0))) + assert(new BoundedDouble(20.0, 0.95, Double.NegativeInfinity, Double.PositiveInfinity) == + evaluator.currentResult()) } test("correct handling of count 0") { - - // setup - val counter = new StatCounter(List()) - // count of 10 because it's larger than 0, - // and 0.95 because that's the default val evaluator = new SumEvaluator(10, 0.95) - // arbitrarily assign id 1 - evaluator.merge(1, counter) - - // execute - val res = evaluator.currentResult() - // assert - assert(res == new BoundedDouble(0, 0.0, 
Double.NegativeInfinity, Double.PositiveInfinity)) + evaluator.merge(1, new StatCounter()) + assert(new BoundedDouble(0, 0.0, Double.NegativeInfinity, Double.PositiveInfinity) == + evaluator.currentResult()) } test("correct handling of NaN") { - - // setup - val counter = new StatCounter(List(1, Double.NaN, 2)) - // count of 10 because it's larger than 0, - // and 0.95 because that's the default val evaluator = new SumEvaluator(10, 0.95) - // arbitrarily assign id 1 - evaluator.merge(1, counter) - - // execute + evaluator.merge(1, new StatCounter(Seq(1, Double.NaN, 2))) val res = evaluator.currentResult() // assert - note semantics of == in face of NaN assert(res.mean.isNaN) @@ -81,27 +54,24 @@ class SumEvaluatorSuite extends SparkFunSuite with SharedSparkContext { } test("correct handling of > 1 values") { - - // setup - val counter = new StatCounter(List(1, 3, 2)) - // count of 10 because it's larger than 0, - // and 0.95 because that's the default val evaluator = new SumEvaluator(10, 0.95) - // arbitrarily assign id 1 - evaluator.merge(1, counter) - - // execute + evaluator.merge(1, new StatCounter(Seq(1.0, 3.0, 2.0))) val res = evaluator.currentResult() + assert(new BoundedDouble(60.0, 0.95, -101.7362525347778, 221.7362525347778) == + evaluator.currentResult()) + } - // These vals because that's how the maths shakes out - val targetMean = 78.0 - val targetLow = -117.617 + 2.732357258139473E-5 - val targetHigh = 273.617 - 2.7323572624027292E-5 - val target = new BoundedDouble(targetMean, 0.95, targetLow, targetHigh) - - - // check that values are within expected tolerance of expectation - assert(res == target) + test("test count > 1") { + val evaluator = new SumEvaluator(10, 0.95) + evaluator.merge(1, new StatCounter().merge(1.0)) + evaluator.merge(1, new StatCounter().merge(3.0)) + assert(new BoundedDouble(20.0, 0.95, -186.4513905077019, 226.4513905077019) == + evaluator.currentResult()) + evaluator.merge(1, new StatCounter().merge(8.0)) + assert(new BoundedDouble(40.0, 0.95, -72.75723361226733, 152.75723361226733) == + evaluator.currentResult()) + (4 to 10).foreach(_ => evaluator.merge(1, new StatCounter().merge(9.0))) + assert(new BoundedDouble(75.0, 1.0, 75.0, 75.0) == evaluator.currentResult()) } } diff --git a/core/src/test/scala/org/apache/spark/scheduler/BlacklistIntegrationSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/BlacklistIntegrationSuite.scala index 14c8b664d4d8b..f6015cd51c2bd 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/BlacklistIntegrationSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/BlacklistIntegrationSuite.scala @@ -16,10 +16,10 @@ */ package org.apache.spark.scheduler -import scala.concurrent.Await import scala.concurrent.duration._ import org.apache.spark._ +import org.apache.spark.internal.config class BlacklistIntegrationSuite extends SchedulerIntegrationSuite[MultiExecutorMockBackend]{ @@ -42,7 +42,10 @@ class BlacklistIntegrationSuite extends SchedulerIntegrationSuite[MultiExecutorM // Test demonstrating the issue -- without a config change, the scheduler keeps scheduling // according to locality preferences, and so the job fails - testScheduler("If preferred node is bad, without blacklist job will fail") { + testScheduler("If preferred node is bad, without blacklist job will fail", + extraConfs = Seq( + config.BLACKLIST_ENABLED.key -> "false" + )) { val rdd = new MockRDDWithLocalityPrefs(sc, 10, Nil, badHost) withBackend(badHostBackend _) { val jobFuture = submit(rdd, (0 until 10).toArray) @@ -51,37 +54,38 @@ 
class BlacklistIntegrationSuite extends SchedulerIntegrationSuite[MultiExecutorM assertDataStructuresEmpty(noFailure = false) } - // even with the blacklist turned on, if maxTaskFailures is not more than the number - // of executors on the bad node, then locality preferences will lead to us cycling through - // the executors on the bad node, and still failing the job testScheduler( - "With blacklist on, job will still fail if there are too many bad executors on bad host", + "With default settings, job can succeed despite multiple bad executors on node", extraConfs = Seq( - // set this to something much longer than the test duration so that executors don't get - // removed from the blacklist during the test - ("spark.scheduler.executorTaskBlacklistTime", "10000000") + config.BLACKLIST_ENABLED.key -> "true", + config.MAX_TASK_FAILURES.key -> "4", + "spark.testing.nHosts" -> "2", + "spark.testing.nExecutorsPerHost" -> "5", + "spark.testing.nCoresPerExecutor" -> "10" ) ) { - val rdd = new MockRDDWithLocalityPrefs(sc, 10, Nil, badHost) + // To reliably reproduce the failure that would occur without blacklisting, we have to use 1 + // task. That way, we ensure this 1 task gets rotated through enough bad executors on the host + // to fail the taskSet, before we have a bunch of different tasks fail in the executors so we + // blacklist them. + // But the point here is -- without blacklisting, we would never schedule anything on the good + // host-1 before we hit too many failures trying our preferred host-0. + val rdd = new MockRDDWithLocalityPrefs(sc, 1, Nil, badHost) withBackend(badHostBackend _) { - val jobFuture = submit(rdd, (0 until 10).toArray) + val jobFuture = submit(rdd, (0 until 1).toArray) awaitJobTermination(jobFuture, duration) } - assertDataStructuresEmpty(noFailure = false) + assertDataStructuresEmpty(noFailure = true) } - // Here we run with the blacklist on, and maxTaskFailures high enough that we'll eventually - // schedule on a good node and succeed the job + // Here we run with the blacklist on, and the default config takes care of having this + // robust to one bad node. 
testScheduler( "Bad node with multiple executors, job will still succeed with the right confs", extraConfs = Seq( - // set this to something much longer than the test duration so that executors don't get - // removed from the blacklist during the test - ("spark.scheduler.executorTaskBlacklistTime", "10000000"), - // this has to be higher than the number of executors on the bad host - ("spark.task.maxFailures", "5"), + config.BLACKLIST_ENABLED.key -> "true", // just to avoid this test taking too long - ("spark.locality.wait", "10ms") + "spark.locality.wait" -> "10ms" ) ) { val rdd = new MockRDDWithLocalityPrefs(sc, 10, Nil, badHost) @@ -98,9 +102,7 @@ class BlacklistIntegrationSuite extends SchedulerIntegrationSuite[MultiExecutorM testScheduler( "SPARK-15865 Progress with fewer executors than maxTaskFailures", extraConfs = Seq( - // set this to something much longer than the test duration so that executors don't get - // removed from the blacklist during the test - "spark.scheduler.executorTaskBlacklistTime" -> "10000000", + config.BLACKLIST_ENABLED.key -> "true", "spark.testing.nHosts" -> "2", "spark.testing.nExecutorsPerHost" -> "1", "spark.testing.nCoresPerExecutor" -> "1" @@ -112,9 +114,9 @@ class BlacklistIntegrationSuite extends SchedulerIntegrationSuite[MultiExecutorM } withBackend(runBackend _) { val jobFuture = submit(new MockRDD(sc, 10, Nil), (0 until 10).toArray) - Await.ready(jobFuture, duration) + awaitJobTermination(jobFuture, duration) val pattern = ("Aborting TaskSet 0.0 because task .* " + - "already failed on executors \\(.*\\), and no other executors are available").r + "cannot run anywhere due to node and executor blacklist").r assert(pattern.findFirstIn(failure.getMessage).isDefined, s"Couldn't find $pattern in ${failure.getMessage()}") } diff --git a/core/src/test/scala/org/apache/spark/scheduler/BlacklistTrackerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/BlacklistTrackerSuite.scala new file mode 100644 index 0000000000000..b2e7ec5df015c --- /dev/null +++ b/core/src/test/scala/org/apache/spark/scheduler/BlacklistTrackerSuite.scala @@ -0,0 +1,81 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.scheduler + +import org.apache.spark.{SparkConf, SparkFunSuite} +import org.apache.spark.internal.config + +class BlacklistTrackerSuite extends SparkFunSuite { + + test("blacklist still respects legacy configs") { + val conf = new SparkConf().setMaster("local") + assert(!BlacklistTracker.isBlacklistEnabled(conf)) + conf.set(config.BLACKLIST_LEGACY_TIMEOUT_CONF, 5000L) + assert(BlacklistTracker.isBlacklistEnabled(conf)) + assert(5000 === BlacklistTracker.getBlacklistTimeout(conf)) + // the new conf takes precedence, though + conf.set(config.BLACKLIST_TIMEOUT_CONF, 1000L) + assert(1000 === BlacklistTracker.getBlacklistTimeout(conf)) + + // if you explicitly set the legacy conf to 0, that also would disable blacklisting + conf.set(config.BLACKLIST_LEGACY_TIMEOUT_CONF, 0L) + assert(!BlacklistTracker.isBlacklistEnabled(conf)) + // but again, the new conf takes precedence + conf.set(config.BLACKLIST_ENABLED, true) + assert(BlacklistTracker.isBlacklistEnabled(conf)) + assert(1000 === BlacklistTracker.getBlacklistTimeout(conf)) + } + + test("check blacklist configuration invariants") { + val conf = new SparkConf().setMaster("yarn-cluster") + Seq( + (2, 2), + (2, 3) + ).foreach { case (maxTaskFailures, maxNodeAttempts) => + conf.set(config.MAX_TASK_FAILURES, maxTaskFailures) + conf.set(config.MAX_TASK_ATTEMPTS_PER_NODE.key, maxNodeAttempts.toString) + val excMsg = intercept[IllegalArgumentException] { + BlacklistTracker.validateBlacklistConfs(conf) + }.getMessage() + assert(excMsg === s"${config.MAX_TASK_ATTEMPTS_PER_NODE.key} " + + s"( = ${maxNodeAttempts}) was >= ${config.MAX_TASK_FAILURES.key} " + + s"( = ${maxTaskFailures} ). Though blacklisting is enabled, with this configuration, " + + s"Spark will not be robust to one bad node.
Decrease " + + s"${config.MAX_TASK_ATTEMPTS_PER_NODE.key}, increase ${config.MAX_TASK_FAILURES.key}, " + + s"or disable blacklisting with ${config.BLACKLIST_ENABLED.key}") + } + + conf.remove(config.MAX_TASK_FAILURES) + conf.remove(config.MAX_TASK_ATTEMPTS_PER_NODE) + + Seq( + config.MAX_TASK_ATTEMPTS_PER_EXECUTOR, + config.MAX_TASK_ATTEMPTS_PER_NODE, + config.MAX_FAILURES_PER_EXEC_STAGE, + config.MAX_FAILED_EXEC_PER_NODE_STAGE, + config.BLACKLIST_TIMEOUT_CONF + ).foreach { config => + conf.set(config.key, "0") + val excMsg = intercept[IllegalArgumentException] { + BlacklistTracker.validateBlacklistConfs(conf) + }.getMessage() + assert(excMsg.contains(s"${config.key} was 0, but must be > 0.")) + conf.remove(config) + } + } +} diff --git a/core/src/test/scala/org/apache/spark/scheduler/SchedulerIntegrationSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/SchedulerIntegrationSuite.scala index 5cd548bbc72d9..c28aa06623a60 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/SchedulerIntegrationSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/SchedulerIntegrationSuite.scala @@ -620,9 +620,9 @@ class BasicSchedulerIntegrationSuite extends SchedulerIntegrationSuite[SingleCor val duration = Duration(1, SECONDS) awaitJobTermination(jobFuture, duration) } + assertDataStructuresEmpty() assert(results === (0 until 10).map { idx => idx -> (42 + idx) }.toMap) assert(stageToAttempts === Map(0 -> Set(0, 1), 1 -> Set(0, 1))) - assertDataStructuresEmpty() } testScheduler("job failure after 4 attempts") { @@ -634,7 +634,7 @@ class BasicSchedulerIntegrationSuite extends SchedulerIntegrationSuite[SingleCor val jobFuture = submit(new MockRDD(sc, 10, Nil), (0 until 10).toArray) val duration = Duration(1, SECONDS) awaitJobTermination(jobFuture, duration) - failure.getMessage.contains("test task failure") + assert(failure.getMessage.contains("test task failure")) } assertDataStructuresEmpty(noFailure = false) } diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala index 61787b54f824f..f5f1947661d9a 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala @@ -20,6 +20,7 @@ package org.apache.spark.scheduler import org.scalatest.BeforeAndAfterEach import org.apache.spark._ +import org.apache.spark.internal.config import org.apache.spark.internal.Logging class FakeSchedulerBackend extends SchedulerBackend { @@ -32,7 +33,6 @@ class FakeSchedulerBackend extends SchedulerBackend { class TaskSchedulerImplSuite extends SparkFunSuite with LocalSparkContext with BeforeAndAfterEach with Logging { - var failedTaskSetException: Option[Throwable] = None var failedTaskSetReason: String = null var failedTaskSet = false @@ -60,10 +60,11 @@ class TaskSchedulerImplSuite extends SparkFunSuite with LocalSparkContext with B } def setupScheduler(confs: (String, String)*): TaskSchedulerImpl = { - sc = new SparkContext("local", "TaskSchedulerImplSuite") + val conf = new SparkConf().setMaster("local").setAppName("TaskSchedulerImplSuite") confs.foreach { case (k, v) => - sc.conf.set(k, v) + conf.set(k, v) } + sc = new SparkContext(conf) taskScheduler = new TaskSchedulerImpl(sc) taskScheduler.initialize(new FakeSchedulerBackend) // Need to initialize a DAGScheduler for the taskScheduler to use for callbacks. 
@@ -287,9 +288,7 @@ class TaskSchedulerImplSuite extends SparkFunSuite with LocalSparkContext with B // schedulable on another executor. However, that executor may fail later on, leaving the // first task with no place to run. val taskScheduler = setupScheduler( - // set this to something much longer than the test duration so that executors don't get - // removed from the blacklist during the test - "spark.scheduler.executorTaskBlacklistTime" -> "10000000" + config.BLACKLIST_ENABLED.key -> "true" ) val taskSet = FakeTask.createTaskSet(2) @@ -328,8 +327,9 @@ class TaskSchedulerImplSuite extends SparkFunSuite with LocalSparkContext with B assert(tsm.isZombie) assert(failedTaskSet) val idx = failedTask.index - assert(failedTaskSetReason == s"Aborting TaskSet 0.0 because task $idx (partition $idx) has " + - s"already failed on executors (executor0), and no other executors are available.") + assert(failedTaskSetReason === s"Aborting TaskSet 0.0 because task $idx (partition $idx) " + + s"cannot run anywhere due to node and executor blacklist. Blacklisting behavior can be " + + s"configured via spark.blacklist.*.") } test("don't abort if there is an executor available, though it hasn't had scheduled tasks yet") { @@ -339,9 +339,7 @@ class TaskSchedulerImplSuite extends SparkFunSuite with LocalSparkContext with B // available and not bail on the job val taskScheduler = setupScheduler( - // set this to something much longer than the test duration so that executors don't get - // removed from the blacklist during the test - "spark.scheduler.executorTaskBlacklistTime" -> "10000000" + config.BLACKLIST_ENABLED.key -> "true" ) val taskSet = FakeTask.createTaskSet(2, (0 until 2).map { _ => Seq(TaskLocation("host0")) }: _*) @@ -377,7 +375,7 @@ class TaskSchedulerImplSuite extends SparkFunSuite with LocalSparkContext with B val taskScheduler = setupScheduler() taskScheduler.submitTasks(FakeTask.createTaskSet(2, 0, - (0 until 2).map { _ => Seq(TaskLocation("host0", "executor2"))}: _* + (0 until 2).map { _ => Seq(TaskLocation("host0", "executor2")) }: _* )) val taskDescs = taskScheduler.resourceOffers(IndexedSeq( diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSetBlacklistSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSetBlacklistSuite.scala new file mode 100644 index 0000000000000..8c902af5685ff --- /dev/null +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSetBlacklistSuite.scala @@ -0,0 +1,163 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.spark.scheduler + +import org.apache.spark.{SparkConf, SparkFunSuite} +import org.apache.spark.internal.config +import org.apache.spark.util.{ManualClock, SystemClock} + +class TaskSetBlacklistSuite extends SparkFunSuite { + + test("Blacklisting tasks, executors, and nodes") { + val conf = new SparkConf().setAppName("test").setMaster("local") + .set(config.BLACKLIST_ENABLED.key, "true") + val clock = new ManualClock + + val taskSetBlacklist = new TaskSetBlacklist(conf, stageId = 0, clock = clock) + clock.setTime(0) + // We will mark task 0 & 1 failed on both executor 1 & 2. + // We should blacklist all executors on that host, for all tasks for the stage. Note the API + // will return false for isExecutorBacklistedForTaskSet even when the node is blacklisted, so + // the executor is implicitly blacklisted (this makes sense with how the scheduler uses the + // blacklist) + + // First, mark task 0 as failed on exec1. + // task 0 should be blacklisted on exec1, and nowhere else + taskSetBlacklist.updateBlacklistForFailedTask("hostA", exec = "exec1", index = 0) + for { + executor <- (1 to 4).map(_.toString) + index <- 0 until 10 + } { + val shouldBeBlacklisted = (executor == "exec1" && index == 0) + assert(taskSetBlacklist.isExecutorBlacklistedForTask(executor, index) === shouldBeBlacklisted) + } + assert(!taskSetBlacklist.isExecutorBlacklistedForTaskSet("exec1")) + assert(!taskSetBlacklist.isNodeBlacklistedForTaskSet("hostA")) + + // Mark task 1 failed on exec1 -- this pushes the executor into the blacklist + taskSetBlacklist.updateBlacklistForFailedTask("hostA", exec = "exec1", index = 1) + assert(taskSetBlacklist.isExecutorBlacklistedForTaskSet("exec1")) + assert(!taskSetBlacklist.isNodeBlacklistedForTaskSet("hostA")) + // Mark one task as failed on exec2 -- not enough for any further blacklisting yet. + taskSetBlacklist.updateBlacklistForFailedTask("hostA", exec = "exec2", index = 0) + assert(taskSetBlacklist.isExecutorBlacklistedForTaskSet("exec1")) + assert(!taskSetBlacklist.isExecutorBlacklistedForTaskSet("exec2")) + assert(!taskSetBlacklist.isNodeBlacklistedForTaskSet("hostA")) + // Mark another task as failed on exec2 -- now we blacklist exec2, which also leads to + // blacklisting the entire node. + taskSetBlacklist.updateBlacklistForFailedTask("hostA", exec = "exec2", index = 1) + assert(taskSetBlacklist.isExecutorBlacklistedForTaskSet("exec1")) + assert(taskSetBlacklist.isExecutorBlacklistedForTaskSet("exec2")) + assert(taskSetBlacklist.isNodeBlacklistedForTaskSet("hostA")) + // Make sure the blacklist has the correct per-task && per-executor responses, over a wider + // range of inputs. + for { + executor <- (1 to 4).map(e => s"exec$e") + index <- 0 until 10 + } { + withClue(s"exec = $executor; index = $index") { + val badExec = (executor == "exec1" || executor == "exec2") + val badIndex = (index == 0 || index == 1) + assert( + // this ignores whether the executor is blacklisted entirely for the taskset -- that is + // intentional, it keeps it fast and is sufficient for usage in the scheduler. 
+ taskSetBlacklist.isExecutorBlacklistedForTask(executor, index) === (badExec && badIndex)) + assert(taskSetBlacklist.isExecutorBlacklistedForTaskSet(executor) === badExec) + } + } + assert(taskSetBlacklist.isNodeBlacklistedForTaskSet("hostA")) + val execToFailures = taskSetBlacklist.execToFailures + assert(execToFailures.keySet === Set("exec1", "exec2")) + + Seq("exec1", "exec2").foreach { exec => + assert( + execToFailures(exec).taskToFailureCount === Map( + 0 -> 1, + 1 -> 1 + ) + ) + } + } + + test("multiple attempts for the same task count once") { + // Make sure that for blacklisting tasks, the node counts task attempts, not executors. But for + // stage-level blacklisting, we count unique tasks. The reason for this difference is, with + // task-attempt blacklisting, we want to make it easy to configure so that you ensure a node + // is blacklisted before the taskset is completely aborted because of spark.task.maxFailures. + // But with stage-blacklisting, we want to make sure we're not just counting one bad task + // that has failed many times. + + val conf = new SparkConf().setMaster("local").setAppName("test") + .set(config.MAX_TASK_ATTEMPTS_PER_EXECUTOR, 2) + .set(config.MAX_TASK_ATTEMPTS_PER_NODE, 3) + .set(config.MAX_FAILURES_PER_EXEC_STAGE, 2) + .set(config.MAX_FAILED_EXEC_PER_NODE_STAGE, 3) + val taskSetBlacklist = new TaskSetBlacklist(conf, stageId = 0, new SystemClock()) + // Fail a task twice on hostA, exec:1 + taskSetBlacklist.updateBlacklistForFailedTask("hostA", exec = "1", index = 0) + taskSetBlacklist.updateBlacklistForFailedTask("hostA", exec = "1", index = 0) + assert(taskSetBlacklist.isExecutorBlacklistedForTask("1", 0)) + assert(!taskSetBlacklist.isNodeBlacklistedForTask("hostA", 0)) + assert(!taskSetBlacklist.isExecutorBlacklistedForTaskSet("1")) + assert(!taskSetBlacklist.isNodeBlacklistedForTaskSet("hostA")) + + // Fail the same task once more on hostA, exec:2 + taskSetBlacklist.updateBlacklistForFailedTask("hostA", exec = "2", index = 0) + assert(taskSetBlacklist.isNodeBlacklistedForTask("hostA", 0)) + assert(!taskSetBlacklist.isExecutorBlacklistedForTaskSet("2")) + assert(!taskSetBlacklist.isNodeBlacklistedForTaskSet("hostA")) + + // Fail another task on hostA, exec:1. Now that executor has failures on two different tasks, + // so its blacklisted + taskSetBlacklist.updateBlacklistForFailedTask("hostA", exec = "1", index = 1) + assert(taskSetBlacklist.isExecutorBlacklistedForTaskSet("1")) + assert(!taskSetBlacklist.isNodeBlacklistedForTaskSet("hostA")) + + // Fail a third task on hostA, exec:2, so that exec is blacklisted for the whole task set + taskSetBlacklist.updateBlacklistForFailedTask("hostA", exec = "2", index = 2) + assert(taskSetBlacklist.isExecutorBlacklistedForTaskSet("2")) + assert(!taskSetBlacklist.isNodeBlacklistedForTaskSet("hostA")) + + // Fail a fourth & fifth task on hostA, exec:3. Now we've got three executors that are + // blacklisted for the taskset, so blacklist the whole node. 
+ taskSetBlacklist.updateBlacklistForFailedTask("hostA", exec = "3", index = 3) + taskSetBlacklist.updateBlacklistForFailedTask("hostA", exec = "3", index = 4) + assert(taskSetBlacklist.isExecutorBlacklistedForTaskSet("3")) + assert(taskSetBlacklist.isNodeBlacklistedForTaskSet("hostA")) + } + + test("only blacklist nodes for the task set when all the blacklisted executors are all on " + + "same host") { + // we blacklist executors on two different hosts within one taskSet -- make sure that doesn't + // lead to any node blacklisting + val conf = new SparkConf().setAppName("test").setMaster("local") + .set(config.BLACKLIST_ENABLED.key, "true") + val taskSetBlacklist = new TaskSetBlacklist(conf, stageId = 0, new SystemClock()) + taskSetBlacklist.updateBlacklistForFailedTask("hostA", exec = "1", index = 0) + taskSetBlacklist.updateBlacklistForFailedTask("hostA", exec = "1", index = 1) + assert(taskSetBlacklist.isExecutorBlacklistedForTaskSet("1")) + assert(!taskSetBlacklist.isNodeBlacklistedForTaskSet("hostA")) + + taskSetBlacklist.updateBlacklistForFailedTask("hostB", exec = "2", index = 0) + taskSetBlacklist.updateBlacklistForFailedTask("hostB", exec = "2", index = 1) + assert(taskSetBlacklist.isExecutorBlacklistedForTaskSet("1")) + assert(taskSetBlacklist.isExecutorBlacklistedForTaskSet("2")) + assert(!taskSetBlacklist.isNodeBlacklistedForTaskSet("hostA")) + assert(!taskSetBlacklist.isNodeBlacklistedForTaskSet("hostB")) + } + +} diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala index 7d6ad08036cb4..69edcf3347243 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala @@ -25,6 +25,7 @@ import scala.collection.mutable.ArrayBuffer import org.mockito.Mockito.{mock, verify} import org.apache.spark._ +import org.apache.spark.internal.config import org.apache.spark.internal.Logging import org.apache.spark.util.{AccumulatorV2, ManualClock} @@ -103,7 +104,7 @@ class FakeTaskScheduler(sc: SparkContext, liveExecutors: (String, String)* /* ex val host = executorIdToHost.get(execId) assert(host != None) val hostId = host.get - val executorsOnHost = executorsByHost(hostId) + val executorsOnHost = hostToExecutors(hostId) executorsOnHost -= execId for (rack <- getRackForHost(hostId); hosts <- hostsByRack.get(rack)) { hosts -= hostId @@ -125,7 +126,7 @@ class FakeTaskScheduler(sc: SparkContext, liveExecutors: (String, String)* /* ex def addExecutor(execId: String, host: String) { executors.put(execId, host) - val executorsOnHost = executorsByHost.getOrElseUpdate(host, new mutable.HashSet[String]) + val executorsOnHost = hostToExecutors.getOrElseUpdate(host, new mutable.HashSet[String]) executorsOnHost += execId executorIdToHost += execId -> host for (rack <- getRackForHost(host)) { @@ -411,7 +412,8 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg test("executors should be blacklisted after task failure, in spite of locality preferences") { val rescheduleDelay = 300L val conf = new SparkConf(). - set("spark.scheduler.executorTaskBlacklistTime", rescheduleDelay.toString). + set(config.BLACKLIST_ENABLED, true). + set(config.BLACKLIST_TIMEOUT_CONF, rescheduleDelay). 
// don't wait to jump locality levels in this test set("spark.locality.wait", "0") @@ -475,19 +477,24 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg assert(manager.resourceOffer("exec2", "host2", ANY).isEmpty) } - // After reschedule delay, scheduling on exec1 should be possible. + // Despite advancing beyond the time for expiring executors from within the blacklist, + // we *never* expire from *within* the stage blacklist clock.advance(rescheduleDelay) { val offerResult = manager.resourceOffer("exec1", "host1", PROCESS_LOCAL) - assert(offerResult.isDefined, "Expect resource offer to return a task") + assert(offerResult.isEmpty) + } + { + val offerResult = manager.resourceOffer("exec3", "host3", ANY) + assert(offerResult.isDefined) assert(offerResult.get.index === 0) - assert(offerResult.get.executorId === "exec1") + assert(offerResult.get.executorId === "exec3") - assert(manager.resourceOffer("exec1", "host1", PROCESS_LOCAL).isEmpty) + assert(manager.resourceOffer("exec3", "host3", ANY).isEmpty) - // Cause exec1 to fail : failure 4 + // Cause exec3 to fail : failure 4 manager.handleFailedTask(offerResult.get.taskId, TaskState.FINISHED, TaskResultLost) } @@ -859,6 +866,114 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg assert(sched.endedTasks(3) === Success) } + test("Killing speculative tasks does not count towards aborting the taskset") { + sc = new SparkContext("local", "test") + sched = new FakeTaskScheduler(sc, ("exec1", "host1"), ("exec2", "host2")) + val taskSet = FakeTask.createTaskSet(5) + // Set the speculation multiplier to be 0 so speculative tasks are launched immediately + sc.conf.set("spark.speculation.multiplier", "0.0") + sc.conf.set("spark.speculation.quantile", "0.6") + val clock = new ManualClock() + val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES, clock) + val accumUpdatesByTask: Array[Seq[AccumulatorV2[_, _]]] = taskSet.tasks.map { task => + task.metrics.internalAccums + } + // Offer resources for 5 tasks to start + val tasks = new ArrayBuffer[TaskDescription]() + for ((k, v) <- List( + "exec1" -> "host1", + "exec1" -> "host1", + "exec1" -> "host1", + "exec2" -> "host2", + "exec2" -> "host2")) { + val taskOption = manager.resourceOffer(k, v, NO_PREF) + assert(taskOption.isDefined) + val task = taskOption.get + assert(task.executorId === k) + tasks += task + } + assert(sched.startedTasks.toSet === (0 until 5).toSet) + // Complete 3 tasks and leave 2 tasks in running + for (id <- Set(0, 1, 2)) { + manager.handleSuccessfulTask(id, createTaskResult(id, accumUpdatesByTask(id))) + assert(sched.endedTasks(id) === Success) + } + + def runningTaskForIndex(index: Int): TaskDescription = { + tasks.find { task => + task.index == index && !sched.endedTasks.contains(task.taskId) + }.getOrElse { + throw new RuntimeException(s"couldn't find index $index in " + + s"tasks: ${tasks.map{t => t.index -> t.taskId}} with endedTasks:" + + s" ${sched.endedTasks.keys}") + } + } + + // have each of the running tasks fail 3 times (not enough to abort the stage) + (0 until 3).foreach { attempt => + Seq(3, 4).foreach { index => + val task = runningTaskForIndex(index) + logInfo(s"failing task $task") + val endReason = ExceptionFailure("a", "b", Array(), "c", None) + manager.handleFailedTask(task.taskId, TaskState.FAILED, endReason) + sched.endedTasks(task.taskId) = endReason + assert(!manager.isZombie) + val nextTask = manager.resourceOffer(s"exec2", s"host2", NO_PREF) + assert(nextTask.isDefined, s"no 
offer for attempt $attempt of $index") + tasks += nextTask.get + } + } + + // we can't be sure which one of our running tasks will get another speculative copy + val originalTasks = Seq(3, 4).map { index => index -> runningTaskForIndex(index) }.toMap + + // checkSpeculatableTasks checks that the task runtime is greater than the threshold for + // speculating. Since we use a threshold of 0 for speculation, tasks need to be running for + // > 0ms, so advance the clock by 1ms here. + clock.advance(1) + assert(manager.checkSpeculatableTasks(0)) + // Offer resource to start the speculative attempt for the running task + val taskOption5 = manager.resourceOffer("exec1", "host1", NO_PREF) + assert(taskOption5.isDefined) + val speculativeTask = taskOption5.get + assert(speculativeTask.index === 3 || speculativeTask.index === 4) + assert(speculativeTask.taskId === 11) + assert(speculativeTask.executorId === "exec1") + assert(speculativeTask.attemptNumber === 4) + sched.backend = mock(classOf[SchedulerBackend]) + // Complete the speculative attempt for the running task + manager.handleSuccessfulTask(speculativeTask.taskId, createTaskResult(3, accumUpdatesByTask(3))) + // Verify that it kills other running attempt + val origTask = originalTasks(speculativeTask.index) + verify(sched.backend).killTask(origTask.taskId, "exec2", true) + // Because the SchedulerBackend was a mock, the 2nd copy of the task won't actually be + // killed, so the FakeTaskScheduler is only told about the successful completion + // of the speculated task. + assert(sched.endedTasks(3) === Success) + // also because the scheduler is a mock, our manager isn't notified about the task killed event, + // so we do that manually + manager.handleFailedTask(origTask.taskId, TaskState.KILLED, TaskKilled) + // this task has "failed" 4 times, but one of them doesn't count, so keep running the stage + assert(manager.tasksSuccessful === 4) + assert(!manager.isZombie) + + // now run another speculative task + val taskOpt6 = manager.resourceOffer("exec1", "host1", NO_PREF) + assert(taskOpt6.isDefined) + val speculativeTask2 = taskOpt6.get + assert(speculativeTask2.index === 3 || speculativeTask2.index === 4) + assert(speculativeTask2.index !== speculativeTask.index) + assert(speculativeTask2.attemptNumber === 4) + // Complete the speculative attempt for the running task + manager.handleSuccessfulTask(speculativeTask2.taskId, + createTaskResult(3, accumUpdatesByTask(3))) + // Verify that it kills other running attempt + val origTask2 = originalTasks(speculativeTask2.index) + verify(sched.backend).killTask(origTask2.taskId, "exec2", true) + assert(manager.tasksSuccessful === 5) + assert(manager.isZombie) + } + private def createTaskResult( id: Int, accumUpdates: Seq[AccumulatorV2[_, _]] = Seq.empty): DirectTaskResult[Int] = { diff --git a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerDistributedSuite.scala b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerDistributedSuite.scala index c1484b0afa85f..46aa9c37986cc 100644 --- a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerDistributedSuite.scala +++ b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerDistributedSuite.scala @@ -20,6 +20,7 @@ package org.apache.spark.serializer import com.esotericsoftware.kryo.Kryo import org.apache.spark._ +import org.apache.spark.internal.config import org.apache.spark.serializer.KryoDistributedTest._ import org.apache.spark.util.Utils @@ -29,7 +30,8 @@ class KryoSerializerDistributedSuite extends 
SparkFunSuite with LocalSparkContex val conf = new SparkConf(false) .set("spark.serializer", "org.apache.spark.serializer.KryoSerializer") .set("spark.kryo.registrator", classOf[AppJarRegistrator].getName) - .set("spark.task.maxFailures", "1") + .set(config.MAX_TASK_FAILURES, 1) + .set(config.BLACKLIST_ENABLED, false) val jar = TestUtils.createJarWithClasses(List(AppJarRegistrator.customClassName)) conf.setJars(List(jar.getPath)) diff --git a/core/src/test/scala/org/apache/spark/ui/UISuite.scala b/core/src/test/scala/org/apache/spark/ui/UISuite.scala index dbb8dca4c8dab..4abcfb7e51914 100644 --- a/core/src/test/scala/org/apache/spark/ui/UISuite.scala +++ b/core/src/test/scala/org/apache/spark/ui/UISuite.scala @@ -53,9 +53,10 @@ class UISuite extends SparkFunSuite { } private def sslEnabledConf(): (SparkConf, SSLOptions) = { + val keyStoreFilePath = getTestResourcePath("spark.keystore") val conf = new SparkConf() .set("spark.ssl.ui.enabled", "true") - .set("spark.ssl.ui.keyStore", "./src/test/resources/spark.keystore") + .set("spark.ssl.ui.keyStore", keyStoreFilePath) .set("spark.ssl.ui.keyStorePassword", "123456") .set("spark.ssl.ui.keyPassword", "123456") (conf, new SecurityManager(conf).getSSLOptions("ui")) diff --git a/dev/create-release/generate-changelist.py b/dev/create-release/generate-changelist.py deleted file mode 100755 index 2e1a35a629342..0000000000000 --- a/dev/create-release/generate-changelist.py +++ /dev/null @@ -1,148 +0,0 @@ -#!/usr/bin/python - -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -# Creates CHANGES.txt from git history. -# -# Usage: -# First set the new release version and old CHANGES.txt version in this file. -# Make sure you have SPARK_HOME set. 
-# $ python generate-changelist.py - - -import os -import sys -import subprocess -import time -import traceback - -SPARK_HOME = os.environ["SPARK_HOME"] -NEW_RELEASE_VERSION = "1.0.0" -PREV_RELEASE_GIT_TAG = "v0.9.1" - -CHANGELIST = "CHANGES.txt" -OLD_CHANGELIST = "%s.old" % (CHANGELIST) -NEW_CHANGELIST = "%s.new" % (CHANGELIST) -TMP_CHANGELIST = "%s.tmp" % (CHANGELIST) - -# date before first PR in TLP Spark repo -SPARK_REPO_CHANGE_DATE1 = time.strptime("2014-02-26", "%Y-%m-%d") -# date after last PR in incubator Spark repo -SPARK_REPO_CHANGE_DATE2 = time.strptime("2014-03-01", "%Y-%m-%d") -# Threshold PR number that differentiates PRs to TLP -# and incubator repos -SPARK_REPO_PR_NUM_THRESH = 200 - -LOG_FILE_NAME = "changes_%s" % time.strftime("%h_%m_%Y_%I_%M_%S") -LOG_FILE = open(LOG_FILE_NAME, 'w') - - -def run_cmd(cmd): - try: - print >> LOG_FILE, "Running command: %s" % cmd - output = subprocess.check_output(cmd, shell=True, stderr=LOG_FILE) - print >> LOG_FILE, "Output: %s" % output - return output - except: - traceback.print_exc() - cleanup() - sys.exit(1) - - -def append_to_changelist(string): - with open(TMP_CHANGELIST, "a") as f: - print >> f, string - - -def cleanup(ask=True): - if ask is True: - print "OK to delete temporary and log files? (y/N): " - response = raw_input() - if ask is False or (ask is True and response == "y"): - if os.path.isfile(TMP_CHANGELIST): - os.remove(TMP_CHANGELIST) - if os.path.isfile(OLD_CHANGELIST): - os.remove(OLD_CHANGELIST) - LOG_FILE.close() - os.remove(LOG_FILE_NAME) - - -print "Generating new %s for Spark release %s" % (CHANGELIST, NEW_RELEASE_VERSION) -os.chdir(SPARK_HOME) -if os.path.isfile(TMP_CHANGELIST): - os.remove(TMP_CHANGELIST) -if os.path.isfile(OLD_CHANGELIST): - os.remove(OLD_CHANGELIST) - -append_to_changelist("Spark Change Log") -append_to_changelist("----------------") -append_to_changelist("") -append_to_changelist("Release %s" % NEW_RELEASE_VERSION) -append_to_changelist("") - -print "Getting commits between tag %s and HEAD" % PREV_RELEASE_GIT_TAG -hashes = run_cmd("git log %s..HEAD --pretty='%%h'" % PREV_RELEASE_GIT_TAG).split() - -print "Getting details of %s commits" % len(hashes) -for h in hashes: - date = run_cmd("git log %s -1 --pretty='%%ad' --date=iso | head -1" % h).strip() - subject = run_cmd("git log %s -1 --pretty='%%s' | head -1" % h).strip() - body = run_cmd("git log %s -1 --pretty='%%b'" % h) - committer = run_cmd("git log %s -1 --pretty='%%cn <%%ce>' | head -1" % h).strip() - body_lines = body.split("\n") - - if "Merge pull" in subject: - # Parse old format commit message - append_to_changelist(" %s %s" % (h, date)) - append_to_changelist(" %s" % subject) - append_to_changelist(" [%s]" % body_lines[0]) - append_to_changelist("") - - elif "maven-release" not in subject: - # Parse new format commit message - # Get authors from commit message, committer otherwise - authors = [committer] - if "Author:" in body: - authors = [line.split(":")[1].strip() for line in body_lines if "Author:" in line] - - # Generate GitHub PR URL for easy access if possible - github_url = "" - if "Closes #" in body: - pr_num = [line.split()[1].lstrip("#") for line in body_lines if "Closes #" in line][0] - github_url = "github.com/apache/spark/pull/%s" % pr_num - day = time.strptime(date.split()[0], "%Y-%m-%d") - if (day < SPARK_REPO_CHANGE_DATE1 or - (day < SPARK_REPO_CHANGE_DATE2 and pr_num < SPARK_REPO_PR_NUM_THRESH)): - github_url = "github.com/apache/incubator-spark/pull/%s" % pr_num - - append_to_changelist(" %s" % subject) - 
append_to_changelist(" %s" % ', '.join(authors)) - # for author in authors: - # append_to_changelist(" %s" % author) - append_to_changelist(" %s" % date) - if len(github_url) > 0: - append_to_changelist(" Commit: %s, %s" % (h, github_url)) - else: - append_to_changelist(" Commit: %s" % h) - append_to_changelist("") - -# Append old change list -print "Appending changelist from tag %s" % PREV_RELEASE_GIT_TAG -run_cmd("git show %s:%s | tail -n +3 >> %s" % (PREV_RELEASE_GIT_TAG, CHANGELIST, TMP_CHANGELIST)) -run_cmd("cp %s %s" % (TMP_CHANGELIST, NEW_CHANGELIST)) -print "New change list generated as %s" % NEW_CHANGELIST -cleanup(False) diff --git a/dev/deps/spark-deps-hadoop-2.2 b/dev/deps/spark-deps-hadoop-2.2 index f4f92c6d20c23..b30f8c347c0af 100644 --- a/dev/deps/spark-deps-hadoop-2.2 +++ b/dev/deps/spark-deps-hadoop-2.2 @@ -141,7 +141,7 @@ pmml-model-1.2.15.jar pmml-schema-1.2.15.jar protobuf-java-2.5.0.jar py4j-0.10.3.jar -pyrolite-4.9.jar +pyrolite-4.13.jar scala-compiler-2.11.8.jar scala-library-2.11.8.jar scala-parser-combinators_2.11-1.0.4.jar diff --git a/dev/deps/spark-deps-hadoop-2.3 b/dev/deps/spark-deps-hadoop-2.3 index 3db013f1a7585..5b3a7651dd299 100644 --- a/dev/deps/spark-deps-hadoop-2.3 +++ b/dev/deps/spark-deps-hadoop-2.3 @@ -148,7 +148,7 @@ pmml-model-1.2.15.jar pmml-schema-1.2.15.jar protobuf-java-2.5.0.jar py4j-0.10.3.jar -pyrolite-4.9.jar +pyrolite-4.13.jar scala-compiler-2.11.8.jar scala-library-2.11.8.jar scala-parser-combinators_2.11-1.0.4.jar diff --git a/dev/deps/spark-deps-hadoop-2.4 b/dev/deps/spark-deps-hadoop-2.4 index 71710109a16ac..e323efe30f64b 100644 --- a/dev/deps/spark-deps-hadoop-2.4 +++ b/dev/deps/spark-deps-hadoop-2.4 @@ -148,7 +148,7 @@ pmml-model-1.2.15.jar pmml-schema-1.2.15.jar protobuf-java-2.5.0.jar py4j-0.10.3.jar -pyrolite-4.9.jar +pyrolite-4.13.jar scala-compiler-2.11.8.jar scala-library-2.11.8.jar scala-parser-combinators_2.11-1.0.4.jar diff --git a/dev/deps/spark-deps-hadoop-2.6 b/dev/deps/spark-deps-hadoop-2.6 index cb30fda253c0a..77d97e5365b9f 100644 --- a/dev/deps/spark-deps-hadoop-2.6 +++ b/dev/deps/spark-deps-hadoop-2.6 @@ -156,7 +156,7 @@ pmml-model-1.2.15.jar pmml-schema-1.2.15.jar protobuf-java-2.5.0.jar py4j-0.10.3.jar -pyrolite-4.9.jar +pyrolite-4.13.jar scala-compiler-2.11.8.jar scala-library-2.11.8.jar scala-parser-combinators_2.11-1.0.4.jar diff --git a/dev/deps/spark-deps-hadoop-2.7 b/dev/deps/spark-deps-hadoop-2.7 index 9008aa80bc877..572edfa0cc29e 100644 --- a/dev/deps/spark-deps-hadoop-2.7 +++ b/dev/deps/spark-deps-hadoop-2.7 @@ -157,7 +157,7 @@ pmml-model-1.2.15.jar pmml-schema-1.2.15.jar protobuf-java-2.5.0.jar py4j-0.10.3.jar -pyrolite-4.9.jar +pyrolite-4.13.jar scala-compiler-2.11.8.jar scala-library-2.11.8.jar scala-parser-combinators_2.11-1.0.4.jar diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py index 5f14683d9a52f..b34ab51f3b996 100644 --- a/dev/sparktestsupport/modules.py +++ b/dev/sparktestsupport/modules.py @@ -241,6 +241,17 @@ def __hash__(self): ] ) +streaming_kafka_0_10 = Module( + name="streaming-kafka-0-10", + dependencies=[streaming], + source_file_regexes=[ + "external/kafka-0-10", + "external/kafka-0-10-assembly", + ], + sbt_test_goals=[ + "streaming-kafka-0-10/test", + ] +) streaming_flume_sink = Module( name="streaming-flume-sink", diff --git a/docs/building-spark.md b/docs/building-spark.md index da7eeb8348378..f5acee6b90059 100644 --- a/docs/building-spark.md +++ b/docs/building-spark.md @@ -91,13 +91,13 @@ Examples: ./build/mvn -Pyarn -Phadoop-2.3 -Dhadoop.version=2.3.0 
-DskipTests clean package # Apache Hadoop 2.4.X or 2.5.X - ./build/mvn -Pyarn -Phadoop-2.4 -Dhadoop.version=VERSION -DskipTests clean package + ./build/mvn -Pyarn -Phadoop-2.4 -Dhadoop.version=2.4.0 -DskipTests clean package # Apache Hadoop 2.6.X ./build/mvn -Pyarn -Phadoop-2.6 -Dhadoop.version=2.6.0 -DskipTests clean package # Apache Hadoop 2.7.X and later - ./build/mvn -Pyarn -Phadoop-2.7 -Dhadoop.version=VERSION -DskipTests clean package + ./build/mvn -Pyarn -Phadoop-2.7 -Dhadoop.version=2.7.0 -DskipTests clean package # Different versions of HDFS and YARN. ./build/mvn -Pyarn -Phadoop-2.3 -Dhadoop.version=2.3.0 -Dyarn.version=2.2.0 -DskipTests clean package diff --git a/docs/configuration.md b/docs/configuration.md index 82ce232b336d9..373e22d71a872 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -1245,6 +1245,49 @@ Apart from these, the following properties are also available, and may be useful The interval length for the scheduler to revive the worker resource offers to run tasks. + + spark.blacklist.enabled + + false + + + If set to "true", prevent Spark from scheduling tasks on executors that have been blacklisted + due to too many task failures. The blacklisting algorithm can be further controlled by the + other "spark.blacklist" configuration options. + + + + spark.blacklist.task.maxTaskAttemptsPerExecutor + 1 + + (Experimental) For a given task, how many times it can be retried on one executor before the + executor is blacklisted for that task. + + + + spark.blacklist.task.maxTaskAttemptsPerNode + 2 + + (Experimental) For a given task, how many times it can be retried on one node, before the entire + node is blacklisted for that task. + + + + spark.blacklist.stage.maxFailedTasksPerExecutor + 2 + + (Experimental) How many different tasks must fail on one executor, within one stage, before the + executor is blacklisted for that stage. + + + + spark.blacklist.stage.maxFailedExecutorsPerNode + 2 + + (Experimental) How many different executors are marked as blacklisted for a given stage, before + the entire node is marked as failed for the stage. + + spark.speculation false diff --git a/docs/programming-guide.md b/docs/programming-guide.md index 74d5ee1ca6b3f..20b4bee0f58e1 100644 --- a/docs/programming-guide.md +++ b/docs/programming-guide.md @@ -1373,7 +1373,7 @@ res2: Long = 10 {% endhighlight %} While this code used the built-in support for accumulators of type Long, programmers can also -create their own types by subclassing [AccumulatorV2](api/scala/index.html#org.apache.spark.AccumulatorV2). +create their own types by subclassing [AccumulatorV2](api/scala/index.html#org.apache.spark.util.AccumulatorV2). The AccumulatorV2 abstract class has several methods which need to be overridden: `reset` for resetting the accumulator to zero, `add` for adding another value into the accumulator, and `merge` for merging another same-type accumulator into this one. For other methods that need to be overridden, please refer to the Scala API documentation.
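Before the guide's own `MyVector` example below, a minimal sketch of the subclassing pattern may help; this is an illustrative, hypothetical accumulator (collecting a `Set[String]` of distinct values), not code taken from this patch:

```scala
import org.apache.spark.util.AccumulatorV2

// Hypothetical accumulator that collects distinct strings across tasks.
class StringSetAccumulator extends AccumulatorV2[String, Set[String]] {
  private var set = Set.empty[String]

  override def isZero: Boolean = set.isEmpty
  override def copy(): StringSetAccumulator = {
    val acc = new StringSetAccumulator
    acc.set = set
    acc
  }
  // reset: return the accumulator to its zero value
  override def reset(): Unit = { set = Set.empty[String] }
  // add: fold a single input value into the accumulator
  override def add(v: String): Unit = { set += v }
  // merge: combine a partial result from another same-type accumulator
  override def merge(other: AccumulatorV2[String, Set[String]]): Unit = { set ++= other.value }
  override def value: Set[String] = set
}

// Typical use: register it with the SparkContext, then update it inside actions, e.g.
//   val acc = new StringSetAccumulator
//   sc.register(acc, "distinctWords")
//   rdd.foreach(word => acc.add(word))
```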
For example, supposing we had a `MyVector` class representing mathematical vectors, we could write: diff --git a/docs/quick-start.md b/docs/quick-start.md index 2eab8d19aa4c6..cb9a378199562 100644 --- a/docs/quick-start.md +++ b/docs/quick-start.md @@ -240,7 +240,8 @@ object SimpleApp { val logData = sc.textFile(logFile, 2).cache() val numAs = logData.filter(line => line.contains("a")).count() val numBs = logData.filter(line => line.contains("b")).count() - println("Lines with a: %s, Lines with b: %s".format(numAs, numBs)) + println(s"Lines with a: $numAs, Lines with b: $numBs") + sc.stop() } } {% endhighlight %} @@ -328,6 +329,8 @@ public class SimpleApp { }).count(); System.out.println("Lines with a: " + numAs + ", lines with b: " + numBs); + + sc.stop(); } } {% endhighlight %} @@ -407,6 +410,8 @@ numAs = logData.filter(lambda s: 'a' in s).count() numBs = logData.filter(lambda s: 'b' in s).count() print("Lines with a: %i, lines with b: %i" % (numAs, numBs)) + +sc.stop() {% endhighlight %} diff --git a/docs/running-on-mesos.md b/docs/running-on-mesos.md index 173961deaadcb..77b06fcf33740 100644 --- a/docs/running-on-mesos.md +++ b/docs/running-on-mesos.md @@ -498,6 +498,15 @@ See the [configuration page](configuration.html) for information on Spark config in the history server. + + spark.mesos.gpus.max + 0 + + Set the maximum number of GPU resources to acquire for this job. Note that executors will still launch when no GPU resources are found + since this configuration is just an upper limit and not a guaranteed amount. + + + diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index 71bdd19c16dbb..d0f43ab0a9cc9 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -220,6 +220,41 @@ The `sql` function enables applications to run SQL queries programmatically and +## Global Temporary View + +Temporary views in Spark SQL are session-scoped and will disappear if the session that creates them +terminates. If you want to have a temporary view that is shared among all sessions and kept alive +until the Spark application terminates, you can create a global temporary view. A global temporary +view is tied to a system preserved database `global_temp`, and we must use the qualified name to +refer to it, e.g. `SELECT * FROM global_temp.view1`. + +<div class="codetabs">
+<div data-lang="scala" markdown="1">
+{% include_example global_temp_view scala/org/apache/spark/examples/sql/SparkSQLExample.scala %}
+</div>
+
+<div data-lang="java" markdown="1">
+{% include_example global_temp_view java/org/apache/spark/examples/sql/JavaSparkSQLExample.java %}
+</div>
+
+<div data-lang="python" markdown="1">
+{% include_example global_temp_view python/sql/basic.py %}
+</div>
+
+<div data-lang="sql" markdown="1">
+
+{% highlight sql %}
+
+CREATE GLOBAL TEMPORARY VIEW temp_view AS SELECT a + 1, b * 2 FROM tbl
+
+SELECT * FROM global_temp.temp_view
+
+{% endhighlight %}
+
+</div>
+</div>
+ + ## Creating Datasets Datasets are similar to RDDs, however, instead of using Java serialization or Kryo they use @@ -1014,16 +1049,20 @@ bin/spark-shell --driver-class-path postgresql-9.4.1207.jar --jars postgresql-9. {% endhighlight %} Tables from the remote database can be loaded as a DataFrame or Spark SQL Temporary table using -the Data Sources API. The following options are supported: +the Data Sources API. Users can specify the JDBC connection properties in the data source options. +user and password are normally provided as connection properties for +logging into the data sources. In addition to the connection properties, Spark also supports +the following case-sensitive options: + - + + + + + + + + + + + - +
Property NameMeaning
url - The JDBC URL to connect to. + The JDBC URL to connect to. The source-specific connection properties may be specified in the URL. e.g., jdbc:postgresql://localhost/test?user=fred&password=secret
dbtable @@ -1048,28 +1087,42 @@ the Data Sources API. The following options are supported: partitionColumn must be a numeric column from the table in question. Notice that lowerBound and upperBound are just used to decide the partition stride, not for filtering the rows in table. So all rows in the table will be - partitioned and returned. + partitioned and returned. This option applies only to reading.
fetchsize - The JDBC fetch size, which determines how many rows to fetch per round trip. This can help performance on JDBC drivers which default to low fetch size (eg. Oracle with 10 rows). + The JDBC fetch size, which determines how many rows to fetch per round trip. This can help performance on JDBC drivers which default to low fetch size (eg. Oracle with 10 rows). This option applies only to reading.
batchsize + The JDBC batch size, which determines how many rows to insert per round trip. This can help performance on JDBC drivers. This option applies only to writing. It defaults to 1000. +
isolationLevel + The transaction isolation level, which applies to the current connection. It can be one of NONE, READ_COMMITTED, READ_UNCOMMITTED, REPEATABLE_READ, or SERIALIZABLE, corresponding to standard transaction isolation levels defined by JDBC's Connection object, with default of READ_UNCOMMITTED. This option applies only to writing. Please refer to the documentation in java.sql.Connection. +
truncate - This is a JDBC writer related option. When SaveMode.Overwrite is enabled, this option causes Spark to truncate an existing table instead of dropping and recreating it. This can be more efficient, and prevents the table metadata (e.g. indices) from being removed. However, it will not work in some cases, such as when the new data has a different schema. It defaults to false. + This is a JDBC writer related option. When SaveMode.Overwrite is enabled, this option causes Spark to truncate an existing table instead of dropping and recreating it. This can be more efficient, and prevents the table metadata (e.g., indices) from being removed. However, it will not work in some cases, such as when the new data has a different schema. It defaults to false. This option applies only to writing.
createTableOptions - This is a JDBC writer related option. If specified, this option allows setting of database-specific table and partition options when creating a table. For example: CREATE TABLE t (name string) ENGINE=InnoDB. + This is a JDBC writer related option. If specified, this option allows setting of database-specific table and partition options when creating a table (e.g., CREATE TABLE t (name string) ENGINE=InnoDB.). This option applies only to writing.
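A hedged Scala sketch of combining a few of the options above follows; it assumes an existing `SparkSession` named `spark`, the URL, table names, and credentials are placeholders, and the write path assumes a Spark version where saving through the `jdbc` format is supported:

```scala
// Read: connection properties plus a read-only option (fetchsize).
val jdbcDF = spark.read
  .format("jdbc")
  .option("url", "jdbc:postgresql://localhost/test")  // placeholder URL
  .option("dbtable", "schema.tablename")              // placeholder table
  .option("user", "username")
  .option("password", "password")
  .option("fetchsize", "1000")                        // applies only to reading
  .load()

// Write: write-only options (batchsize, truncate) combined with an overwrite save mode.
jdbcDF.write
  .format("jdbc")
  .option("url", "jdbc:postgresql://localhost/test")
  .option("dbtable", "schema.tablename_copy")         // placeholder target table
  .option("user", "username")
  .option("password", "password")
  .option("batchsize", "5000")                        // applies only to writing
  .option("truncate", "true")                         // only honored with SaveMode.Overwrite
  .mode("overwrite")
  .save()
```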
@@ -1101,11 +1154,11 @@ USING org.apache.spark.sql.jdbc OPTIONS ( url "jdbc:postgresql:dbserver", dbtable "schema.tablename", - user 'username', + user 'username', password 'password' ) -INSERT INTO TABLE jdbcTable +INSERT INTO TABLE jdbcTable SELECT * FROM resultTable {% endhighlight %} @@ -1293,7 +1346,7 @@ options. - Dataset API and DataFrame API are unified. In Scala, `DataFrame` becomes a type alias for `Dataset[Row]`, while Java API users must replace `DataFrame` with `Dataset`. Both the typed - transformations (e.g. `map`, `filter`, and `groupByKey`) and untyped transformations (e.g. + transformations (e.g., `map`, `filter`, and `groupByKey`) and untyped transformations (e.g., `select` and `groupBy`) are available on the Dataset class. Since compile-time type-safety in Python and R is not a language feature, the concept of Dataset does not apply to these languages’ APIs. Instead, `DataFrame` remains the primary programing abstraction, which is analogous to the @@ -1342,7 +1395,7 @@ options. - Timestamps are now stored at a precision of 1us, rather than 1ns - In the `sql` dialect, floating point numbers are now parsed as decimal. HiveQL parsing remains unchanged. - - The canonical name of SQL/DataFrame functions are now lower case (e.g. sum vs SUM). + - The canonical name of SQL/DataFrame functions are now lower case (e.g., sum vs SUM). - JSON data source will not automatically load new files that are created by other applications (i.e. files that are not inserted to the dataset through Spark SQL). For a JSON persistent table (i.e. the metadata of the table is stored in Hive Metastore), @@ -1357,7 +1410,7 @@ options. Based on user feedback, we created a new, more fluid API for reading data in (`SQLContext.read`) and writing data out (`DataFrame.write`), -and deprecated the old APIs (e.g. `SQLContext.parquetFile`, `SQLContext.jsonFile`). +and deprecated the old APIs (e.g., `SQLContext.parquetFile`, `SQLContext.jsonFile`). See the API docs for `SQLContext.read` (
Scala, diff --git a/docs/streaming-kafka-0-10-integration.md b/docs/streaming-kafka-0-10-integration.md index 44c39e39446de..456b8453383db 100644 --- a/docs/streaming-kafka-0-10-integration.md +++ b/docs/streaming-kafka-0-10-integration.md @@ -27,7 +27,7 @@ For Scala/Java applications using SBT/Maven project definitions, link your strea "bootstrap.servers" -> "localhost:9092,anotherhost:9092", "key.deserializer" -> classOf[StringDeserializer], "value.deserializer" -> classOf[StringDeserializer], - "group.id" -> "example", + "group.id" -> "use_a_separate_group_id_for_each_stream", "auto.offset.reset" -> "latest", "enable.auto.commit" -> (false: java.lang.Boolean) ) @@ -48,7 +48,7 @@ Each item in the stream is a [ConsumerRecord](http://kafka.apache.org/0100/javad For possible kafkaParams, see [Kafka consumer config docs](http://kafka.apache.org/documentation.html#newconsumerconfigs). -Note that enable.auto.commit is disabled, for discussion see [Storing Offsets](streaming-kafka-0-10-integration.html#storing-offsets) below. +Note that the example sets enable.auto.commit to false, for discussion see [Storing Offsets](streaming-kafka-0-10-integration.html#storing-offsets) below. ### LocationStrategies The new Kafka consumer API will pre-fetch messages into buffers. Therefore it is important for performance reasons that the Spark integration keep cached consumers on executors (rather than recreating them for each batch), and prefer to schedule partitions on the host locations that have the appropriate consumers. @@ -57,6 +57,9 @@ In most cases, you should use `LocationStrategies.PreferConsistent` as shown abo The cache for consumers has a default maximum size of 64. If you expect to be handling more than (64 * number of executors) Kafka partitions, you can change this setting via `spark.streaming.kafka.consumer.cache.maxCapacity` +The cache is keyed by topicpartition and group.id, so use a **separate** `group.id` for each call to `createDirectStream`. + + ### ConsumerStrategies The new Kafka consumer API has a number of different ways to specify topics, some of which require considerable post-object-instantiation setup. `ConsumerStrategies` provides an abstraction that allows Spark to obtain properly configured consumers even after restart from checkpoint. 
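Because the consumer cache described above is keyed by topic partition and group.id, each call to `createDirectStream` should use its own group. A minimal Scala sketch, assuming an existing `StreamingContext` named `ssc` and placeholder broker and topic names:

```scala
import org.apache.kafka.common.serialization.StringDeserializer
import org.apache.spark.streaming.kafka010.KafkaUtils
import org.apache.spark.streaming.kafka010.LocationStrategies.PreferConsistent
import org.apache.spark.streaming.kafka010.ConsumerStrategies.Subscribe

val baseParams = Map[String, Object](
  "bootstrap.servers" -> "localhost:9092",            // placeholder brokers
  "key.deserializer" -> classOf[StringDeserializer],
  "value.deserializer" -> classOf[StringDeserializer],
  "auto.offset.reset" -> "latest",
  "enable.auto.commit" -> (false: java.lang.Boolean)
)

// Each stream gets its own group.id so cached consumers do not collide.
val streamA = KafkaUtils.createDirectStream[String, String](
  ssc, PreferConsistent,
  Subscribe[String, String](Seq("topicA"), baseParams + ("group.id" -> "group_for_stream_a")))

val streamB = KafkaUtils.createDirectStream[String, String](
  ssc, PreferConsistent,
  Subscribe[String, String](Seq("topicB"), baseParams + ("group.id" -> "group_for_stream_b")))
```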
diff --git a/examples/src/main/java/org/apache/spark/examples/sql/JavaSparkSQLExample.java b/examples/src/main/java/org/apache/spark/examples/sql/JavaSparkSQLExample.java index cff9032f52b5a..c5770d147a6b5 100644 --- a/examples/src/main/java/org/apache/spark/examples/sql/JavaSparkSQLExample.java +++ b/examples/src/main/java/org/apache/spark/examples/sql/JavaSparkSQLExample.java @@ -54,6 +54,7 @@ import org.apache.spark.sql.types.StructField; import org.apache.spark.sql.types.StructType; // $example off:programmatic_schema$ +import org.apache.spark.sql.AnalysisException; // $example on:untyped_ops$ // col("...") is preferable to df.col("...") @@ -84,7 +85,7 @@ public void setAge(int age) { } // $example off:create_ds$ - public static void main(String[] args) { + public static void main(String[] args) throws AnalysisException { // $example on:init_session$ SparkSession spark = SparkSession .builder() @@ -101,7 +102,7 @@ public static void main(String[] args) { spark.stop(); } - private static void runBasicDataFrameExample(SparkSession spark) { + private static void runBasicDataFrameExample(SparkSession spark) throws AnalysisException { // $example on:create_df$ Dataset df = spark.read().json("examples/src/main/resources/people.json"); @@ -176,6 +177,31 @@ private static void runBasicDataFrameExample(SparkSession spark) { // | 19| Justin| // +----+-------+ // $example off:run_sql$ + + // $example on:global_temp_view$ + // Register the DataFrame as a global temporary view + df.createGlobalTempView("people"); + + // Global temporary view is tied to a system preserved database `global_temp` + spark.sql("SELECT * FROM global_temp.people").show(); + // +----+-------+ + // | age| name| + // +----+-------+ + // |null|Michael| + // | 30| Andy| + // | 19| Justin| + // +----+-------+ + + // Global temporary view is cross-session + spark.newSession().sql("SELECT * FROM global_temp.people").show(); + // +----+-------+ + // | age| name| + // +----+-------+ + // |null|Michael| + // | 30| Andy| + // | 19| Justin| + // +----+-------+ + // $example off:global_temp_view$ } private static void runDatasetCreationExample(SparkSession spark) { diff --git a/examples/src/main/python/sql/basic.py b/examples/src/main/python/sql/basic.py index fdc017aed97c1..ebcf66995b477 100644 --- a/examples/src/main/python/sql/basic.py +++ b/examples/src/main/python/sql/basic.py @@ -114,6 +114,31 @@ def basic_df_example(spark): # +----+-------+ # $example off:run_sql$ + # $example on:global_temp_view$ + # Register the DataFrame as a global temporary view + df.createGlobalTempView("people") + + # Global temporary view is tied to a system preserved database `global_temp` + spark.sql("SELECT * FROM global_temp.people").show() + # +----+-------+ + # | age| name| + # +----+-------+ + # |null|Michael| + # | 30| Andy| + # | 19| Justin| + # +----+-------+ + + # Global temporary view is cross-session + spark.newSession().sql("SELECT * FROM global_temp.people").show() + # +----+-------+ + # | age| name| + # +----+-------+ + # |null|Michael| + # | 30| Andy| + # | 19| Justin| + # +----+-------+ + # $example off:global_temp_view$ + def schema_inference_example(spark): # $example on:schema_inferring$ diff --git a/examples/src/main/scala/org/apache/spark/examples/sql/SparkSQLExample.scala b/examples/src/main/scala/org/apache/spark/examples/sql/SparkSQLExample.scala index 129b81d5fbbf3..f27c403c5b388 100644 --- a/examples/src/main/scala/org/apache/spark/examples/sql/SparkSQLExample.scala +++ 
b/examples/src/main/scala/org/apache/spark/examples/sql/SparkSQLExample.scala @@ -135,6 +135,31 @@ object SparkSQLExample { // | 19| Justin| // +----+-------+ // $example off:run_sql$ + + // $example on:global_temp_view$ + // Register the DataFrame as a global temporary view + df.createGlobalTempView("people") + + // Global temporary view is tied to a system preserved database `global_temp` + spark.sql("SELECT * FROM global_temp.people").show() + // +----+-------+ + // | age| name| + // +----+-------+ + // |null|Michael| + // | 30| Andy| + // | 19| Justin| + // +----+-------+ + + // Global temporary view is cross-session + spark.newSession().sql("SELECT * FROM global_temp.people").show() + // +----+-------+ + // | age| name| + // +----+-------+ + // |null|Michael| + // | 30| Andy| + // | 19| Justin| + // +----+-------+ + // $example off:global_temp_view$ } private def runDatasetCreationExample(spark: SparkSession): Unit = { diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala index 6c03070398fca..c640b93b0a2ee 100644 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala @@ -22,7 +22,6 @@ import java.util.concurrent.atomic.AtomicInteger import scala.util.Random import org.apache.kafka.clients.producer.RecordMetadata -import org.scalatest.BeforeAndAfter import org.scalatest.time.SpanSugar._ import org.apache.spark.sql.execution.streaming._ @@ -344,7 +343,7 @@ class KafkaSourceSuite extends KafkaSourceTest { } -class KafkaSourceStressSuite extends KafkaSourceTest with BeforeAndAfter { +class KafkaSourceStressSuite extends KafkaSourceTest { import testImplicits._ @@ -358,12 +357,6 @@ class KafkaSourceStressSuite extends KafkaSourceTest with BeforeAndAfter { start + Random.nextInt(start + end - 1) } - after { - for (topic <- testUtils.getAllTopicsAndPartitionSize().toMap.keys) { - testUtils.deleteTopic(topic) - } - } - test("stress test with multiple topics and partitions") { topics.foreach { topic => testUtils.createTopic(topic, partitions = nextInt(1, 6)) diff --git a/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/ConsumerStrategy.scala b/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/ConsumerStrategy.scala index 60255fc655e5f..778c06ea16a2b 100644 --- a/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/ConsumerStrategy.scala +++ b/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/ConsumerStrategy.scala @@ -104,6 +104,8 @@ private case class Subscribe[K, V]( toSeek.asScala.foreach { case (topicPartition, offset) => consumer.seek(topicPartition, offset) } + // we've called poll, we must pause or next poll may consume messages and set position + consumer.pause(consumer.assignment()) } consumer @@ -154,6 +156,8 @@ private case class SubscribePattern[K, V]( toSeek.asScala.foreach { case (topicPartition, offset) => consumer.seek(topicPartition, offset) } + // we've called poll, we must pause or next poll may consume messages and set position + consumer.pause(consumer.assignment()) } consumer diff --git a/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/DirectKafkaInputDStream.scala b/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/DirectKafkaInputDStream.scala index 
13827f68f2cb5..432537ebf05b2 100644 --- a/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/DirectKafkaInputDStream.scala +++ b/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/DirectKafkaInputDStream.scala @@ -161,12 +161,31 @@ private[spark] class DirectKafkaInputDStream[K, V]( } } + /** + * The concern here is that poll might consume messages despite being paused, + * which would throw off consumer position. Fix position if this happens. + */ + private def paranoidPoll(c: Consumer[K, V]): Unit = { + val msgs = c.poll(0) + if (!msgs.isEmpty) { + // position should be minimum offset per topicpartition + msgs.asScala.foldLeft(Map[TopicPartition, Long]()) { (acc, m) => + val tp = new TopicPartition(m.topic, m.partition) + val off = acc.get(tp).map(o => Math.min(o, m.offset)).getOrElse(m.offset) + acc + (tp -> off) + }.foreach { case (tp, off) => + logInfo(s"poll(0) returned messages, seeking $tp to $off to compensate") + c.seek(tp, off) + } + } + } + /** * Returns the latest (highest) available offsets, taking new partitions into account. */ protected def latestOffsets(): Map[TopicPartition, Long] = { val c = consumer - c.poll(0) + paranoidPoll(c) val parts = c.assignment().asScala // make sure new partitions are reflected in currentOffsets @@ -223,7 +242,7 @@ private[spark] class DirectKafkaInputDStream[K, V]( override def start(): Unit = { val c = consumer - c.poll(0) + paranoidPoll(c) if (currentOffsets.isEmpty) { currentOffsets = c.assignment().asScala.map { tp => tp -> c.position(tp) diff --git a/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/DirectKafkaStreamSuite.scala b/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/DirectKafkaStreamSuite.scala index e04f35eceb1b4..02aec43c3b34f 100644 --- a/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/DirectKafkaStreamSuite.scala +++ b/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/DirectKafkaStreamSuite.scala @@ -159,17 +159,19 @@ class DirectKafkaStreamSuite } test("pattern based subscription") { - val topics = List("pat1", "pat2", "advanced3") - // Should match 2 out of 3 topics + val topics = List("pat1", "pat2", "pat3", "advanced3") + // Should match 3 out of 4 topics val pat = """pat\d""".r.pattern val data = Map("a" -> 7, "b" -> 9) topics.foreach { t => kafkaTestUtils.createTopic(t) kafkaTestUtils.sendMessages(t, data) } - val offsets = Map(new TopicPartition("pat2", 0) -> 3L) - // 2 matching topics, one of which starts 3 messages later - val expectedTotal = (data.values.sum * 2) - 3 + val offsets = Map( + new TopicPartition("pat2", 0) -> 3L, + new TopicPartition("pat3", 0) -> 4L) + // 3 matching topics, two of which start a total of 7 messages later + val expectedTotal = (data.values.sum * 3) - 7 val kafkaParams = getKafkaParams("auto.offset.reset" -> "earliest") ssc = new StreamingContext(sparkConf, Milliseconds(1000)) diff --git a/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala b/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala index a64b5768c57b2..e67bf3e328f94 100644 --- a/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala +++ b/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala @@ -59,6 +59,8 @@ private[spark] class MesosCoarseGrainedSchedulerBackend( // Maximum number of cores to acquire 
(TODO: we'll need more flexible controls here) val maxCores = conf.get("spark.cores.max", Int.MaxValue.toString).toInt + val maxGpus = conf.getInt("spark.mesos.gpus.max", 0) + private[this] val shutdownTimeoutMS = conf.getTimeAsMs("spark.mesos.coarse.shutdownTimeout", "10s") .ensuring(_ >= 0, "spark.mesos.coarse.shutdownTimeout must be >= 0") @@ -72,7 +74,9 @@ private[spark] class MesosCoarseGrainedSchedulerBackend( // Cores we have acquired with each Mesos task ID val coresByTaskId = new mutable.HashMap[String, Int] + val gpusByTaskId = new mutable.HashMap[String, Int] var totalCoresAcquired = 0 + var totalGpusAcquired = 0 // SlaveID -> Slave // This map accumulates entries for the duration of the job. Slaves are never deleted, because @@ -396,6 +400,8 @@ private[spark] class MesosCoarseGrainedSchedulerBackend( launchTasks = true val taskId = newMesosTaskId() val offerCPUs = getResource(resources, "cpus").toInt + val taskGPUs = Math.min( + Math.max(0, maxGpus - totalGpusAcquired), getResource(resources, "gpus").toInt) val taskCPUs = executorCores(offerCPUs) val taskMemory = executorMemory(sc) @@ -403,7 +409,7 @@ private[spark] class MesosCoarseGrainedSchedulerBackend( slaves.getOrElseUpdate(slaveId, new Slave(offer.getHostname)).taskIDs.add(taskId) val (resourcesLeft, resourcesToUse) = - partitionTaskResources(resources, taskCPUs, taskMemory) + partitionTaskResources(resources, taskCPUs, taskMemory, taskGPUs) val taskBuilder = MesosTaskInfo.newBuilder() .setTaskId(TaskID.newBuilder().setValue(taskId.toString).build()) @@ -425,6 +431,10 @@ private[spark] class MesosCoarseGrainedSchedulerBackend( remainingResources(offerId) = resourcesLeft.asJava totalCoresAcquired += taskCPUs coresByTaskId(taskId) = taskCPUs + if (taskGPUs > 0) { + totalGpusAcquired += taskGPUs + gpusByTaskId(taskId) = taskGPUs + } } } } @@ -432,21 +442,28 @@ private[spark] class MesosCoarseGrainedSchedulerBackend( } /** Extracts task needed resources from a list of available resources. */ - private def partitionTaskResources(resources: JList[Resource], taskCPUs: Int, taskMemory: Int) + private def partitionTaskResources( + resources: JList[Resource], + taskCPUs: Int, + taskMemory: Int, + taskGPUs: Int) : (List[Resource], List[Resource]) = { // partition cpus & mem val (afterCPUResources, cpuResourcesToUse) = partitionResources(resources, "cpus", taskCPUs) val (afterMemResources, memResourcesToUse) = partitionResources(afterCPUResources.asJava, "mem", taskMemory) + val (afterGPUResources, gpuResourcesToUse) = + partitionResources(afterMemResources.asJava, "gpus", taskGPUs) // If user specifies port numbers in SparkConfig then consecutive tasks will not be launched // on the same host. This essentially means one executor per host. 
// TODO: handle network isolator case val (nonPortResources, portResourcesToUse) = - partitionPortResources(nonZeroPortValuesFromConfig(sc.conf), afterMemResources) + partitionPortResources(nonZeroPortValuesFromConfig(sc.conf), afterGPUResources) - (nonPortResources, cpuResourcesToUse ++ memResourcesToUse ++ portResourcesToUse) + (nonPortResources, + cpuResourcesToUse ++ memResourcesToUse ++ portResourcesToUse ++ gpuResourcesToUse) } private def canLaunchTask(slaveId: String, resources: JList[Resource]): Boolean = { @@ -513,6 +530,11 @@ private[spark] class MesosCoarseGrainedSchedulerBackend( totalCoresAcquired -= cores coresByTaskId -= taskId } + // Also remove the gpus we have remembered for this task, if it's in the hashmap + for (gpus <- gpusByTaskId.get(taskId)) { + totalGpusAcquired -= gpus + gpusByTaskId -= taskId + } // If it was a failure, mark the slave as failed for blacklisting purposes if (TaskState.isFailed(state)) { slave.taskFailures += 1 diff --git a/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala b/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala index 2963d161d6700..73cc241239c4c 100644 --- a/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala +++ b/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala @@ -27,6 +27,7 @@ import scala.util.control.NonFatal import com.google.common.base.Splitter import org.apache.mesos.{MesosSchedulerDriver, Protos, Scheduler, SchedulerDriver} import org.apache.mesos.Protos.{TaskState => MesosTaskState, _} +import org.apache.mesos.Protos.FrameworkInfo.Capability import org.apache.mesos.protobuf.{ByteString, GeneratedMessage} import org.apache.spark.{SparkConf, SparkContext, SparkException} @@ -93,6 +94,10 @@ trait MesosSchedulerUtils extends Logging { conf.getOption("spark.mesos.role").foreach { role => fwInfoBuilder.setRole(role) } + val maxGpus = conf.getInt("spark.mesos.gpus.max", 0) + if (maxGpus > 0) { + fwInfoBuilder.addCapabilities(Capability.newBuilder().setType(Capability.Type.GPU_RESOURCES)) + } if (credBuilder.hasPrincipal) { new MesosSchedulerDriver( scheduler, fwInfoBuilder.build(), masterUrl, credBuilder.build()) diff --git a/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackendSuite.scala b/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackendSuite.scala index c3ab488e2aa69..75ba02e470e27 100644 --- a/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackendSuite.scala +++ b/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackendSuite.scala @@ -67,7 +67,7 @@ class MesosCoarseGrainedSchedulerBackendSuite extends SparkFunSuite val minMem = backend.executorMemory(sc) val minCpu = 4 - val offers = List((minMem, minCpu)) + val offers = List(Resources(minMem, minCpu)) // launches a task on a valid offer offerResources(offers) @@ -95,8 +95,8 @@ class MesosCoarseGrainedSchedulerBackendSuite extends SparkFunSuite // launches a task on a valid offer val minMem = backend.executorMemory(sc) + 1024 val minCpu = 4 - val offer1 = (minMem, minCpu) - val offer2 = (minMem, 1) + val offer1 = Resources(minMem, minCpu) + val offer2 = Resources(minMem, 1) offerResources(List(offer1, offer2)) verifyTaskLaunched(driver, "o1") @@ -115,7 +115,7 @@ class MesosCoarseGrainedSchedulerBackendSuite extends SparkFunSuite 
setBackend(Map("spark.executor.cores" -> executorCores.toString)) val executorMemory = backend.executorMemory(sc) - val offers = List((executorMemory * 2, executorCores + 1)) + val offers = List(Resources(executorMemory * 2, executorCores + 1)) offerResources(offers) val taskInfos = verifyTaskLaunched(driver, "o1") @@ -130,7 +130,7 @@ class MesosCoarseGrainedSchedulerBackendSuite extends SparkFunSuite val executorMemory = backend.executorMemory(sc) val offerCores = 10 - offerResources(List((executorMemory * 2, offerCores))) + offerResources(List(Resources(executorMemory * 2, offerCores))) val taskInfos = verifyTaskLaunched(driver, "o1") assert(taskInfos.length == 1) @@ -144,7 +144,7 @@ class MesosCoarseGrainedSchedulerBackendSuite extends SparkFunSuite setBackend(Map("spark.cores.max" -> maxCores.toString)) val executorMemory = backend.executorMemory(sc) - offerResources(List((executorMemory, maxCores + 1))) + offerResources(List(Resources(executorMemory, maxCores + 1))) val taskInfos = verifyTaskLaunched(driver, "o1") assert(taskInfos.length == 1) @@ -153,9 +153,38 @@ class MesosCoarseGrainedSchedulerBackendSuite extends SparkFunSuite assert(cpus == maxCores) } + test("mesos does not acquire gpus if not specified") { + setBackend() + + val executorMemory = backend.executorMemory(sc) + offerResources(List(Resources(executorMemory, 1, 1))) + + val taskInfos = verifyTaskLaunched(driver, "o1") + assert(taskInfos.length == 1) + + val gpus = backend.getResource(taskInfos.head.getResourcesList, "gpus") + assert(gpus == 0.0) + } + + + test("mesos does not acquire more than spark.mesos.gpus.max") { + val maxGpus = 5 + setBackend(Map("spark.mesos.gpus.max" -> maxGpus.toString)) + + val executorMemory = backend.executorMemory(sc) + offerResources(List(Resources(executorMemory, 1, maxGpus + 1))) + + val taskInfos = verifyTaskLaunched(driver, "o1") + assert(taskInfos.length == 1) + + val gpus = backend.getResource(taskInfos.head.getResourcesList, "gpus") + assert(gpus == maxGpus) + } + + test("mesos declines offers that violate attribute constraints") { setBackend(Map("spark.mesos.constraints" -> "x:true")) - offerResources(List((backend.executorMemory(sc), 4))) + offerResources(List(Resources(backend.executorMemory(sc), 4))) verifyDeclinedOffer(driver, createOfferId("o1"), true) } @@ -165,8 +194,8 @@ class MesosCoarseGrainedSchedulerBackendSuite extends SparkFunSuite val executorMemory = backend.executorMemory(sc) offerResources(List( - (executorMemory, maxCores + 1), - (executorMemory, maxCores + 1))) + Resources(executorMemory, maxCores + 1), + Resources(executorMemory, maxCores + 1))) verifyTaskLaunched(driver, "o1") verifyDeclinedOffer(driver, createOfferId("o2"), true) @@ -180,8 +209,8 @@ class MesosCoarseGrainedSchedulerBackendSuite extends SparkFunSuite val executorMemory = backend.executorMemory(sc) offerResources(List( - (executorMemory * 2, executorCores * 2), - (executorMemory * 2, executorCores * 2))) + Resources(executorMemory * 2, executorCores * 2), + Resources(executorMemory * 2, executorCores * 2))) verifyTaskLaunched(driver, "o1") verifyTaskLaunched(driver, "o2") @@ -193,7 +222,7 @@ class MesosCoarseGrainedSchedulerBackendSuite extends SparkFunSuite // offer with room for two executors val executorMemory = backend.executorMemory(sc) - offerResources(List((executorMemory * 2, executorCores * 2))) + offerResources(List(Resources(executorMemory * 2, executorCores * 2))) // verify two executors were started on a single offer val taskInfos = verifyTaskLaunched(driver, "o1") @@ -397,7 
+426,7 @@ class MesosCoarseGrainedSchedulerBackendSuite extends SparkFunSuite setBackend() // launches a task on a valid offer - val offers = List((backend.executorMemory(sc), 1)) + val offers = List(Resources(backend.executorMemory(sc), 1)) offerResources(offers) verifyTaskLaunched(driver, "o1") @@ -434,6 +463,8 @@ class MesosCoarseGrainedSchedulerBackendSuite extends SparkFunSuite assert(launchedTasks.head.getCommand.getUrisList.asScala(0).getValue == url) } + private case class Resources(mem: Int, cpus: Int, gpus: Int = 0) + private def verifyDeclinedOffer(driver: SchedulerDriver, offerId: OfferID, filter: Boolean = false): Unit = { @@ -444,9 +475,9 @@ class MesosCoarseGrainedSchedulerBackendSuite extends SparkFunSuite } } - private def offerResources(offers: List[(Int, Int)], startId: Int = 1): Unit = { + private def offerResources(offers: List[Resources], startId: Int = 1): Unit = { val mesosOffers = offers.zipWithIndex.map {case (offer, i) => - createOffer(s"o${i + startId}", s"s${i + startId}", offer._1, offer._2)} + createOffer(s"o${i + startId}", s"s${i + startId}", offer.mem, offer.cpus, None, offer.gpus)} backend.resourceOffers(driver, mesosOffers.asJava) } diff --git a/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/Utils.scala b/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/Utils.scala index fa9406f5f0553..7ebb294aa9080 100644 --- a/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/Utils.scala +++ b/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/Utils.scala @@ -32,8 +32,9 @@ object Utils { offerId: String, slaveId: String, mem: Int, - cpu: Int, - ports: Option[(Long, Long)] = None): Offer = { + cpus: Int, + ports: Option[(Long, Long)] = None, + gpus: Int = 0): Offer = { val builder = Offer.newBuilder() builder.addResourcesBuilder() .setName("mem") @@ -42,7 +43,7 @@ object Utils { builder.addResourcesBuilder() .setName("cpus") .setType(Value.Type.SCALAR) - .setScalar(Scalar.newBuilder().setValue(cpu)) + .setScalar(Scalar.newBuilder().setValue(cpus)) ports.foreach { resourcePorts => builder.addResourcesBuilder() .setName("ports") @@ -50,6 +51,12 @@ object Utils { .setRanges(Ranges.newBuilder().addRange(MesosRange.newBuilder() .setBegin(resourcePorts._1).setEnd(resourcePorts._2).build())) } + if (gpus > 0) { + builder.addResourcesBuilder() + .setName("gpus") + .setType(Value.Type.SCALAR) + .setScalar(Scalar.newBuilder().setValue(gpus)) + } builder.setId(createOfferId(offerId)) .setFrameworkId(FrameworkID.newBuilder() .setValue("f1")) @@ -82,4 +89,3 @@ object Utils { TaskID.newBuilder().setValue(taskId).build() } } - diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala index 329961a25d984..862a468745fbd 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala @@ -78,7 +78,6 @@ private[classification] trait LogisticRegressionParams extends ProbabilisticClas /** * Param for the name of family which is a description of the label distribution * to be used in the model. - * Supported options: "auto", "multinomial", "binomial". * Supported options: * - "auto": Automatically select the family based on the number of classes: * If numClasses == 1 || numClasses == 2, set to "binomial". 
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala index e565a6fd3ece2..994ed993c99df 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala @@ -110,16 +110,28 @@ class NaiveBayes @Since("1.5.0") ( @Since("2.1.0") def setWeightCol(value: String): this.type = set(weightCol, value) - override protected def train(dataset: Dataset[_]): NaiveBayesModel = { - val numClasses = getNumClasses(dataset) + /** + * ml assumes input labels in range [0, numClasses). But this implementation + * is also called by mllib NaiveBayes which allows other kinds of input labels + * such as {-1, +1}. Here we use this parameter to switch between different processing logic. + * It should be removed when we remove mllib NaiveBayes. + */ + private[spark] var isML: Boolean = true - if (isDefined(thresholds)) { - require($(thresholds).length == numClasses, this.getClass.getSimpleName + - ".train() called with non-matching numClasses and thresholds.length." + - s" numClasses=$numClasses, but thresholds has length ${$(thresholds).length}") - } + private[spark] def setIsML(isML: Boolean): this.type = { + this.isML = isML + this + } - val numFeatures = dataset.select(col($(featuresCol))).head().getAs[Vector](0).size + override protected def train(dataset: Dataset[_]): NaiveBayesModel = { + if (isML) { + val numClasses = getNumClasses(dataset) + if (isDefined(thresholds)) { + require($(thresholds).length == numClasses, this.getClass.getSimpleName + + ".train() called with non-matching numClasses and thresholds.length." + + s" numClasses=$numClasses, but thresholds has length ${$(thresholds).length}") + } + } val requireNonnegativeValues: Vector => Unit = (v: Vector) => { val values = v match { @@ -153,6 +165,7 @@ class NaiveBayes @Since("1.5.0") ( } } + val numFeatures = dataset.select(col($(featuresCol))).head().getAs[Vector](0).size val w = if (!isDefined(weightCol) || $(weightCol).isEmpty) lit(1.0) else col($(weightCol)) // Aggregates term frequencies per label. @@ -176,6 +189,7 @@ class NaiveBayes @Since("1.5.0") ( val numLabels = aggregated.length val numDocuments = aggregated.map(_._2._1).sum + val labelArray = new Array[Double](numLabels) val piArray = new Array[Double](numLabels) val thetaArray = new Array[Double](numLabels * numFeatures) @@ -183,6 +197,7 @@ class NaiveBayes @Since("1.5.0") ( val piLogDenom = math.log(numDocuments + numLabels * lambda) var i = 0 aggregated.foreach { case (label, (n, sumTermFreqs)) => + labelArray(i) = label piArray(i) = math.log(n + lambda) - piLogDenom val thetaLogDenom = $(modelType) match { case Multinomial => math.log(sumTermFreqs.values.sum + numFeatures * lambda) @@ -201,7 +216,7 @@ class NaiveBayes @Since("1.5.0") ( val pi = Vectors.dense(piArray) val theta = new DenseMatrix(numLabels, numFeatures, thetaArray, true) - new NaiveBayesModel(uid, pi, theta) + new NaiveBayesModel(uid, pi, theta).setOldLabels(labelArray) } @Since("1.5.0") @@ -239,6 +254,19 @@ class NaiveBayesModel private[ml] ( import NaiveBayes.{Bernoulli, Multinomial} + /** + * mllib NaiveBayes is a wrapper of ml implementation currently. + * Input labels of mllib could be {-1, +1} and mllib NaiveBayesModel exposes labels, + * both of which are different from ml, so we should store the labels sequentially + * to be called by mllib. This should be removed when we remove mllib NaiveBayes. 
+ */ + private[spark] var oldLabels: Array[Double] = null + + private[spark] def setOldLabels(labels: Array[Double]): this.type = { + this.oldLabels = labels + this + } + /** * Bernoulli scoring requires log(condprob) if 1, log(1-condprob) if 0. * This precomputes log(1.0 - exp(theta)) and its sum which are used for the linear algebra diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala index 2ee899bcca564..389898666eb8e 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala @@ -26,7 +26,7 @@ import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.ml.{Estimator, Model, Pipeline, PipelineModel, PipelineStage, Transformer} import org.apache.spark.ml.attribute.AttributeGroup import org.apache.spark.ml.linalg.VectorUDT -import org.apache.spark.ml.param.{Param, ParamMap} +import org.apache.spark.ml.param.{BooleanParam, Param, ParamMap} import org.apache.spark.ml.param.shared.{HasFeaturesCol, HasLabelCol} import org.apache.spark.ml.util._ import org.apache.spark.sql.{DataFrame, Dataset} @@ -104,6 +104,27 @@ class RFormula @Since("1.5.0") (@Since("1.5.0") override val uid: String) @Since("1.5.0") def setLabelCol(value: String): this.type = set(labelCol, value) + /** + * Force to index label whether it is numeric or string type. + * Usually we index label only when it is string type. + * If the formula was used by classification algorithms, + * we can force to index label even it is numeric type by setting this param with true. + * Default: false. + * @group param + */ + @Since("2.1.0") + val forceIndexLabel: BooleanParam = new BooleanParam(this, "forceIndexLabel", + "Force to index label whether it is numeric or string") + setDefault(forceIndexLabel -> false) + + /** @group getParam */ + @Since("2.1.0") + def getForceIndexLabel: Boolean = $(forceIndexLabel) + + /** @group setParam */ + @Since("2.1.0") + def setForceIndexLabel(value: Boolean): this.type = set(forceIndexLabel, value) + /** Whether the formula specifies fitting an intercept. 
*/ private[ml] def hasIntercept: Boolean = { require(isDefined(formula), "Formula must be defined first.") @@ -167,8 +188,8 @@ class RFormula @Since("1.5.0") (@Since("1.5.0") override val uid: String) encoderStages += new VectorAttributeRewriter($(featuresCol), prefixesToRewrite.toMap) encoderStages += new ColumnPruner(tempColumns.toSet) - if (dataset.schema.fieldNames.contains(resolvedFormula.label) && - dataset.schema(resolvedFormula.label).dataType == StringType) { + if ((dataset.schema.fieldNames.contains(resolvedFormula.label) && + dataset.schema(resolvedFormula.label).dataType == StringType) || $(forceIndexLabel)) { encoderStages += new StringIndexer() .setInputCol(resolvedFormula.label) .setOutputCol($(labelCol)) @@ -181,6 +202,8 @@ class RFormula @Since("1.5.0") (@Since("1.5.0") override val uid: String) @Since("1.5.0") // optimistic schema; does not contain any ML attributes override def transformSchema(schema: StructType): StructType = { + require(!hasLabelCol(schema) || !$(forceIndexLabel), + "If label column already exists, forceIndexLabel can not be set with true.") if (hasLabelCol(schema)) { StructType(schema.fields :+ StructField($(featuresCol), new VectorUDT, true)) } else { diff --git a/mllib/src/main/scala/org/apache/spark/ml/linalg/MatrixUDT.scala b/mllib/src/main/scala/org/apache/spark/ml/linalg/MatrixUDT.scala index a1e53662f02a8..f4a8556c71f6e 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/linalg/MatrixUDT.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/linalg/MatrixUDT.scala @@ -18,7 +18,7 @@ package org.apache.spark.ml.linalg import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.{GenericMutableRow, UnsafeArrayData} +import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeArrayData} import org.apache.spark.sql.types._ /** @@ -46,7 +46,7 @@ private[spark] class MatrixUDT extends UserDefinedType[Matrix] { } override def serialize(obj: Matrix): InternalRow = { - val row = new GenericMutableRow(7) + val row = new GenericInternalRow(7) obj match { case sm: SparseMatrix => row.setByte(0, 0) diff --git a/mllib/src/main/scala/org/apache/spark/ml/linalg/VectorUDT.scala b/mllib/src/main/scala/org/apache/spark/ml/linalg/VectorUDT.scala index 0b9b2ff5c5e26..917861309c573 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/linalg/VectorUDT.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/linalg/VectorUDT.scala @@ -18,7 +18,7 @@ package org.apache.spark.ml.linalg import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.{GenericMutableRow, UnsafeArrayData} +import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeArrayData} import org.apache.spark.sql.types._ /** @@ -42,14 +42,14 @@ private[spark] class VectorUDT extends UserDefinedType[Vector] { override def serialize(obj: Vector): InternalRow = { obj match { case SparseVector(size, indices, values) => - val row = new GenericMutableRow(4) + val row = new GenericInternalRow(4) row.setByte(0, 0) row.setInt(1, size) row.update(2, UnsafeArrayData.fromPrimitiveArray(indices)) row.update(3, UnsafeArrayData.fromPrimitiveArray(values)) row case DenseVector(values) => - val row = new GenericMutableRow(4) + val row = new GenericInternalRow(4) row.setByte(0, 1) row.setNullAt(1) row.setNullAt(2) diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/NaiveBayesWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/NaiveBayesWrapper.scala index d1a39fea76ef8..4fdab2dd94655 100644 --- 
a/mllib/src/main/scala/org/apache/spark/ml/r/NaiveBayesWrapper.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/r/NaiveBayesWrapper.scala @@ -59,6 +59,7 @@ private[r] object NaiveBayesWrapper extends MLReadable[NaiveBayesWrapper] { def fit(formula: String, data: DataFrame, smoothing: Double): NaiveBayesWrapper = { val rFormula = new RFormula() .setFormula(formula) + .setForceIndexLabel(true) RWrapperUtils.checkDataColumns(rFormula, data) val rFormulaModel = rFormula.fit(data) // get labels and feature names from output schema diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala index 536c58f998080..025ed20c75a04 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala @@ -188,17 +188,18 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String val numFeatures = dataset.select(col($(featuresCol))).first().getAs[Vector](0).size val w = if (!isDefined(weightCol) || $(weightCol).isEmpty) lit(1.0) else col($(weightCol)) + val instances: RDD[Instance] = dataset.select( + col($(labelCol)).cast(DoubleType), w, col($(featuresCol))).rdd.map { + case Row(label: Double, weight: Double, features: Vector) => + Instance(label, weight, features) + } + if (($(solver) == "auto" && $(elasticNetParam) == 0.0 && numFeatures <= WeightedLeastSquares.MAX_NUM_FEATURES) || $(solver) == "normal") { require($(elasticNetParam) == 0.0, "Only L2 regularization can be used when normal " + "solver is used.'") // For low dimensional data, WeightedLeastSquares is more efficiently since the // training algorithm only requires one pass through the data. (SPARK-10668) - val instances: RDD[Instance] = dataset.select( - col($(labelCol)).cast(DoubleType), w, col($(featuresCol))).rdd.map { - case Row(label: Double, weight: Double, features: Vector) => - Instance(label, weight, features) - } val optimizer = new WeightedLeastSquares($(fitIntercept), $(regParam), $(standardization), true) @@ -221,12 +222,6 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String return lrModel.setSummary(trainingSummary) } - val instances: RDD[Instance] = - dataset.select(col($(labelCol)), w, col($(featuresCol))).rdd.map { - case Row(label: Double, weight: Double, features: Vector) => - Instance(label, weight, features) - } - val handlePersistence = dataset.rdd.getStorageLevel == StorageLevel.NONE if (handlePersistence) instances.persist(StorageLevel.MEMORY_AND_DISK) diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala index 0b7ad92b3cf30..b504f411d256d 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala @@ -705,14 +705,17 @@ private[spark] object RandomForest extends Logging { node.stats } + val validFeatureSplits = + Range(0, binAggregates.metadata.numFeaturesPerNode).view.map { featureIndexIdx => + featuresForNode.map(features => (featureIndexIdx, features(featureIndexIdx))) + .getOrElse((featureIndexIdx, featureIndexIdx)) + }.withFilter { case (_, featureIndex) => + binAggregates.metadata.numSplits(featureIndex) != 0 + } + // For each (feature, split), calculate the gain, and select the best (feature, split). 
val (bestSplit, bestSplitStats) = - Range(0, binAggregates.metadata.numFeaturesPerNode).map { featureIndexIdx => - val featureIndex = if (featuresForNode.nonEmpty) { - featuresForNode.get.apply(featureIndexIdx) - } else { - featureIndexIdx - } + validFeatureSplits.map { case (featureIndexIdx, featureIndex) => val numSplits = binAggregates.metadata.numSplits(featureIndex) if (binAggregates.metadata.isContinuous(featureIndex)) { // Cumulative sum (scanLeft) of bin statistics. @@ -966,7 +969,7 @@ private[spark] object RandomForest extends Logging { * NOTE: `metadata.numbins` will be changed accordingly * if there are not enough splits to be found * @param featureIndex feature index to find splits - * @return array of splits + * @return array of split thresholds */ private[tree] def findSplitsForContinuousFeature( featureSamples: Iterable[Double], @@ -975,7 +978,9 @@ private[spark] object RandomForest extends Logging { require(metadata.isContinuous(featureIndex), "findSplitsForContinuousFeature can only be used to find splits for a continuous feature.") - val splits = { + val splits = if (featureSamples.isEmpty) { + Array.empty[Double] + } else { val numSplits = metadata.numSplits(featureIndex) // get count for each distinct value @@ -987,9 +992,9 @@ private[spark] object RandomForest extends Logging { val valueCounts = valueCountMap.toSeq.sortBy(_._1).toArray // if possible splits is not enough or just enough, just return all possible splits - val possibleSplits = valueCounts.length + val possibleSplits = valueCounts.length - 1 if (possibleSplits <= numSplits) { - valueCounts.map(_._1) + valueCounts.map(_._1).init } else { // stride between splits val stride: Double = numSamples.toDouble / (numSplits + 1) @@ -1023,12 +1028,6 @@ private[spark] object RandomForest extends Logging { splitsBuilder.result() } } - - // TODO: Do not fail; just ignore the useless feature. - assert(splits.length > 0, - s"DecisionTree could not handle feature $featureIndex since it had only 1 unique value." + - " Please remove this feature and then try again.") - splits } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala index 32d6968a4e85f..33561be4b5bc1 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala @@ -364,15 +364,10 @@ class NaiveBayes private ( val nb = new NewNaiveBayes() .setModelType(modelType) .setSmoothing(lambda) + .setIsML(false) - val labels = data.map(_.label).distinct().collect().sorted - - // Input labels for [[org.apache.spark.ml.classification.NaiveBayes]] must be - // in range [0, numClasses). 
- val dataset = data.map { - case LabeledPoint(label, features) => - (labels.indexOf(label).toDouble, features.asML) - }.toDF("label", "features") + val dataset = data.map { case LabeledPoint(label, features) => (label, features.asML) } + .toDF("label", "features") val newModel = nb.fit(dataset) @@ -383,7 +378,9 @@ class NaiveBayes private ( theta(i)(j) = v } - new NaiveBayesModel(labels, pi, theta, modelType) + require(newModel.oldLabels != null, + "The underlying ML NaiveBayes training does not produce labels.") + new NaiveBayesModel(newModel.oldLabels, pi, theta, modelType) } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala index 23141aaf42b49..68a7b3b6763af 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala @@ -43,18 +43,17 @@ import org.apache.spark.util.random.XORShiftRandom class KMeans private ( private var k: Int, private var maxIterations: Int, - private var runs: Int, private var initializationMode: String, private var initializationSteps: Int, private var epsilon: Double, private var seed: Long) extends Serializable with Logging { /** - * Constructs a KMeans instance with default parameters: {k: 2, maxIterations: 20, runs: 1, + * Constructs a KMeans instance with default parameters: {k: 2, maxIterations: 20, * initializationMode: "k-means||", initializationSteps: 2, epsilon: 1e-4, seed: random}. */ @Since("0.8.0") - def this() = this(2, 20, 1, KMeans.K_MEANS_PARALLEL, 2, 1e-4, Utils.random.nextLong()) + def this() = this(2, 20, KMeans.K_MEANS_PARALLEL, 2, 1e-4, Utils.random.nextLong()) /** * Number of clusters to create (k). @@ -112,15 +111,17 @@ class KMeans private ( * This function has no effect since Spark 2.0.0. */ @Since("1.4.0") + @deprecated("This has no effect and always returns 1", "2.1.0") def getRuns: Int = { logWarning("Getting number of runs has no effect since Spark 2.0.0.") - runs + 1 } /** * This function has no effect since Spark 2.0.0. 
*/ @Since("0.8.0") + @deprecated("This has no effect", "2.1.0") def setRuns(runs: Int): this.type = { logWarning("Setting number of runs has no effect since Spark 2.0.0.") this @@ -239,17 +240,9 @@ class KMeans private ( val initStartTime = System.nanoTime() - // Only one run is allowed when initialModel is given - val numRuns = if (initialModel.nonEmpty) { - if (runs > 1) logWarning("Ignoring runs; one run is allowed when initialModel is given.") - 1 - } else { - runs - } - val centers = initialModel match { case Some(kMeansCenters) => - Array(kMeansCenters.clusterCenters.map(s => new VectorWithNorm(s))) + kMeansCenters.clusterCenters.map(new VectorWithNorm(_)) case None => if (initializationMode == KMeans.RANDOM) { initRandom(data) @@ -258,89 +251,62 @@ class KMeans private ( } } val initTimeInSeconds = (System.nanoTime() - initStartTime) / 1e9 - logInfo(s"Initialization with $initializationMode took " + "%.3f".format(initTimeInSeconds) + - " seconds.") + logInfo(f"Initialization with $initializationMode took $initTimeInSeconds%.3f seconds.") - val active = Array.fill(numRuns)(true) - val costs = Array.fill(numRuns)(0.0) - - var activeRuns = new ArrayBuffer[Int] ++ (0 until numRuns) + var converged = false + var cost = 0.0 var iteration = 0 val iterationStartTime = System.nanoTime() - instr.foreach(_.logNumFeatures(centers(0)(0).vector.size)) + instr.foreach(_.logNumFeatures(centers.head.vector.size)) - // Execute iterations of Lloyd's algorithm until all runs have converged - while (iteration < maxIterations && !activeRuns.isEmpty) { - type WeightedPoint = (Vector, Long) - def mergeContribs(x: WeightedPoint, y: WeightedPoint): WeightedPoint = { - axpy(1.0, x._1, y._1) - (y._1, x._2 + y._2) - } - - val activeCenters = activeRuns.map(r => centers(r)).toArray - val costAccums = activeRuns.map(_ => sc.doubleAccumulator) - - val bcActiveCenters = sc.broadcast(activeCenters) + // Execute iterations of Lloyd's algorithm until converged + while (iteration < maxIterations && !converged) { + val costAccum = sc.doubleAccumulator + val bcCenters = sc.broadcast(centers) // Find the sum and count of points mapping to each center val totalContribs = data.mapPartitions { points => - val thisActiveCenters = bcActiveCenters.value - val runs = thisActiveCenters.length - val k = thisActiveCenters(0).length - val dims = thisActiveCenters(0)(0).vector.size + val thisCenters = bcCenters.value + val dims = thisCenters.head.vector.size - val sums = Array.fill(runs, k)(Vectors.zeros(dims)) - val counts = Array.fill(runs, k)(0L) + val sums = Array.fill(thisCenters.length)(Vectors.zeros(dims)) + val counts = Array.fill(thisCenters.length)(0L) points.foreach { point => - (0 until runs).foreach { i => - val (bestCenter, cost) = KMeans.findClosest(thisActiveCenters(i), point) - costAccums(i).add(cost) - val sum = sums(i)(bestCenter) - axpy(1.0, point.vector, sum) - counts(i)(bestCenter) += 1 - } + val (bestCenter, cost) = KMeans.findClosest(thisCenters, point) + costAccum.add(cost) + val sum = sums(bestCenter) + axpy(1.0, point.vector, sum) + counts(bestCenter) += 1 } - val contribs = for (i <- 0 until runs; j <- 0 until k) yield { - ((i, j), (sums(i)(j), counts(i)(j))) - } - contribs.iterator - }.reduceByKey(mergeContribs).collectAsMap() - - bcActiveCenters.destroy(blocking = false) - - // Update the cluster centers and costs for each active run - for ((run, i) <- activeRuns.zipWithIndex) { - var changed = false - var j = 0 - while (j < k) { - val (sum, count) = totalContribs((i, j)) - if (count != 0) { - scal(1.0 / 
count, sum) - val newCenter = new VectorWithNorm(sum) - if (KMeans.fastSquaredDistance(newCenter, centers(run)(j)) > epsilon * epsilon) { - changed = true - } - centers(run)(j) = newCenter - } - j += 1 - } - if (!changed) { - active(run) = false - logInfo("Run " + run + " finished in " + (iteration + 1) + " iterations") + counts.indices.filter(counts(_) > 0).map(j => (j, (sums(j), counts(j)))).iterator + }.reduceByKey { case ((sum1, count1), (sum2, count2)) => + axpy(1.0, sum2, sum1) + (sum1, count1 + count2) + }.collectAsMap() + + bcCenters.destroy(blocking = false) + + // Update the cluster centers and costs + converged = true + totalContribs.foreach { case (j, (sum, count)) => + scal(1.0 / count, sum) + val newCenter = new VectorWithNorm(sum) + if (converged && KMeans.fastSquaredDistance(newCenter, centers(j)) > epsilon * epsilon) { + converged = false } - costs(run) = costAccums(i).value + centers(j) = newCenter } - activeRuns = activeRuns.filter(active(_)) + cost = costAccum.value iteration += 1 } val iterationTimeInSeconds = (System.nanoTime() - iterationStartTime) / 1e9 - logInfo(s"Iterations took " + "%.3f".format(iterationTimeInSeconds) + " seconds.") + logInfo(f"Iterations took $iterationTimeInSeconds%.3f seconds.") if (iteration == maxIterations) { logInfo(s"KMeans reached the max number of iterations: $maxIterations.") @@ -348,59 +314,43 @@ class KMeans private ( logInfo(s"KMeans converged in $iteration iterations.") } - val (minCost, bestRun) = costs.zipWithIndex.min + logInfo(s"The cost is $cost.") - logInfo(s"The cost for the best run is $minCost.") - - new KMeansModel(centers(bestRun).map(_.vector)) + new KMeansModel(centers.map(_.vector)) } /** - * Initialize `runs` sets of cluster centers at random. + * Initialize a set of cluster centers at random. */ - private def initRandom(data: RDD[VectorWithNorm]) - : Array[Array[VectorWithNorm]] = { - // Sample all the cluster centers in one pass to avoid repeated scans - val sample = data.takeSample(true, runs * k, new XORShiftRandom(this.seed).nextInt()).toSeq - Array.tabulate(runs)(r => sample.slice(r * k, (r + 1) * k).map { v => - new VectorWithNorm(Vectors.dense(v.vector.toArray), v.norm) - }.toArray) + private def initRandom(data: RDD[VectorWithNorm]): Array[VectorWithNorm] = { + data.takeSample(true, k, new XORShiftRandom(this.seed).nextInt()).map(_.toDense) } /** - * Initialize `runs` sets of cluster centers using the k-means|| algorithm by Bahmani et al. + * Initialize a set of cluster centers using the k-means|| algorithm by Bahmani et al. * (Bahmani et al., Scalable K-Means++, VLDB 2012). This is a variant of k-means++ that tries - * to find with dissimilar cluster centers by starting with a random center and then doing + * to find dissimilar cluster centers by starting with a random center and then doing * passes where more centers are chosen with probability proportional to their squared distance * to the current cluster set. It results in a provable approximation to an optimal clustering. * * The original paper can be found at http://theory.stanford.edu/~sergei/papers/vldb12-kmpar.pdf. */ - private def initKMeansParallel(data: RDD[VectorWithNorm]) - : Array[Array[VectorWithNorm]] = { + private def initKMeansParallel(data: RDD[VectorWithNorm]): Array[VectorWithNorm] = { // Initialize empty centers and point costs. 
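With the multi-run machinery removed, each Lloyd iteration above keeps a single cost accumulator and a single set of centers, and stops once no center moves farther than epsilon. A minimal single-machine Python sketch of that loop, for illustration only (plain lists stand in for RDDs and broadcasts):

    import math
    import random

    def lloyd(points, k, max_iterations=20, epsilon=1e-4, seed=42):
        centers = random.Random(seed).sample(points, k)   # single random initialization
        for _ in range(max_iterations):
            sums = [[0.0] * len(points[0]) for _ in range(k)]
            counts = [0] * k
            for p in points:                              # assign each point to its closest center
                j = min(range(k), key=lambda c: sum((a - b) ** 2 for a, b in zip(p, centers[c])))
                counts[j] += 1
                sums[j] = [a + b for a, b in zip(sums[j], p)]
            converged = True
            for j in range(k):                            # recompute centers and track movement
                if counts[j] == 0:
                    continue
                new_center = [s / counts[j] for s in sums[j]]
                if math.dist(new_center, centers[j]) > epsilon:
                    converged = False
                centers[j] = new_center
            if converged:
                break
        return centers

    print(lloyd([(0.0, 0.0), (0.0, 1.0), (9.0, 8.0), (8.0, 9.0)], k=2))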
- val centers = Array.tabulate(runs)(r => ArrayBuffer.empty[VectorWithNorm]) - var costs = data.map(_ => Array.fill(runs)(Double.PositiveInfinity)) + var costs = data.map(_ => Double.PositiveInfinity) - // Initialize each run's first center to a random point. + // Initialize the first center to a random point. val seed = new XORShiftRandom(this.seed).nextInt() - val sample = data.takeSample(true, runs, seed).toSeq + val sample = data.takeSample(false, 1, seed) // Could be empty if data is empty; fail with a better message early: - require(sample.size >= runs, s"Required $runs samples but got ${sample.size} from $data") - val newCenters = Array.tabulate(runs)(r => ArrayBuffer(sample(r).toDense)) - - /** Merges new centers to centers. */ - def mergeNewCenters(): Unit = { - var r = 0 - while (r < runs) { - centers(r) ++= newCenters(r) - newCenters(r).clear() - r += 1 - } - } + require(sample.nonEmpty, s"No samples available from $data") + + val centers = ArrayBuffer[VectorWithNorm]() + var newCenters = Seq(sample.head.toDense) + centers ++= newCenters - // On each step, sample 2 * k points on average for each run with probability proportional - // to their squared distance from that run's centers. Note that only distances between points + // On each step, sample 2 * k points on average with probability proportional + // to their squared distance from the centers. Note that only distances between points // and new centers are computed in each iteration. var step = 0 var bcNewCentersList = ArrayBuffer[Broadcast[_]]() @@ -409,74 +359,39 @@ class KMeans private ( bcNewCentersList += bcNewCenters val preCosts = costs costs = data.zip(preCosts).map { case (point, cost) => - Array.tabulate(runs) { r => - math.min(KMeans.pointCost(bcNewCenters.value(r), point), cost(r)) - } - }.persist(StorageLevel.MEMORY_AND_DISK) - val sumCosts = costs - .aggregate(new Array[Double](runs))( - seqOp = (s, v) => { - // s += v - var r = 0 - while (r < runs) { - s(r) += v(r) - r += 1 - } - s - }, - combOp = (s0, s1) => { - // s0 += s1 - var r = 0 - while (r < runs) { - s0(r) += s1(r) - r += 1 - } - s0 - } - ) + math.min(KMeans.pointCost(bcNewCenters.value, point), cost) + }.persist(StorageLevel.MEMORY_AND_DISK) + val sumCosts = costs.sum() bcNewCenters.unpersist(blocking = false) preCosts.unpersist(blocking = false) - val chosen = data.zip(costs).mapPartitionsWithIndex { (index, pointsWithCosts) => + val chosen = data.zip(costs).mapPartitionsWithIndex { (index, pointCosts) => val rand = new XORShiftRandom(seed ^ (step << 16) ^ index) - pointsWithCosts.flatMap { case (p, c) => - val rs = (0 until runs).filter { r => - rand.nextDouble() < 2.0 * c(r) * k / sumCosts(r) - } - if (rs.nonEmpty) Some((p, rs)) else None - } + pointCosts.filter { case (_, c) => rand.nextDouble() < 2.0 * c * k / sumCosts }.map(_._1) }.collect() - mergeNewCenters() - chosen.foreach { case (p, rs) => - rs.foreach(newCenters(_) += p.toDense) - } + newCenters = chosen.map(_.toDense) + centers ++= newCenters step += 1 } - mergeNewCenters() costs.unpersist(blocking = false) bcNewCentersList.foreach(_.destroy(false)) - // Finally, we might have a set of more than k candidate centers for each run; weigh each - // candidate by the number of points in the dataset mapping to it and run a local k-means++ - // on the weighted centers to pick just k of them - val bcCenters = data.context.broadcast(centers) - val weightMap = data.flatMap { p => - Iterator.tabulate(runs) { r => - ((r, KMeans.findClosest(bcCenters.value(r), p)._1), 1.0) - } - }.reduceByKey(_ + 
_).collectAsMap() + if (centers.size == k) { + centers.toArray + } else { + // Finally, we might have a set of more or less than k candidate centers; weight each + // candidate by the number of points in the dataset mapping to it and run a local k-means++ + // on the weighted centers to pick k of them + val bcCenters = data.context.broadcast(centers) + val countMap = data.map(KMeans.findClosest(bcCenters.value, _)._1).countByValue() - bcCenters.destroy(blocking = false) + bcCenters.destroy(blocking = false) - val finalCenters = (0 until runs).par.map { r => - val myCenters = centers(r).toArray - val myWeights = (0 until myCenters.length).map(i => weightMap.getOrElse((r, i), 0.0)).toArray - LocalKMeans.kMeansPlusPlus(r, myCenters, myWeights, k, 30) + val myWeights = centers.indices.map(countMap.getOrElse(_, 0L).toDouble).toArray + LocalKMeans.kMeansPlusPlus(0, centers.toArray, myWeights, k, 30) } - - finalCenters.toArray } } @@ -493,6 +408,52 @@ object KMeans { @Since("0.8.0") val K_MEANS_PARALLEL = "k-means||" + /** + * Trains a k-means model using the given set of parameters. + * + * @param data Training points as an `RDD` of `Vector` types. + * @param k Number of clusters to create. + * @param maxIterations Maximum number of iterations allowed. + * @param initializationMode The initialization algorithm. This can either be "random" or + * "k-means||". (default: "k-means||") + * @param seed Random seed for cluster initialization. Default is to generate seed based + * on system time. + */ + @Since("2.1.0") + def train( + data: RDD[Vector], + k: Int, + maxIterations: Int, + initializationMode: String, + seed: Long): KMeansModel = { + new KMeans().setK(k) + .setMaxIterations(maxIterations) + .setInitializationMode(initializationMode) + .setSeed(seed) + .run(data) + } + + /** + * Trains a k-means model using the given set of parameters. + * + * @param data Training points as an `RDD` of `Vector` types. + * @param k Number of clusters to create. + * @param maxIterations Maximum number of iterations allowed. + * @param initializationMode The initialization algorithm. This can either be "random" or + * "k-means||". (default: "k-means||") + */ + @Since("2.1.0") + def train( + data: RDD[Vector], + k: Int, + maxIterations: Int, + initializationMode: String): KMeansModel = { + new KMeans().setK(k) + .setMaxIterations(maxIterations) + .setInitializationMode(initializationMode) + .run(data) + } + /** * Trains a k-means model using the given set of parameters. * @@ -506,6 +467,7 @@ object KMeans { * on system time. */ @Since("1.3.0") + @deprecated("Use train method without 'runs'", "2.1.0") def train( data: RDD[Vector], k: Int, @@ -531,6 +493,7 @@ object KMeans { * "k-means||". (default: "k-means||") */ @Since("0.8.0") + @deprecated("Use train method without 'runs'", "2.1.0") def train( data: RDD[Vector], k: Int, @@ -551,19 +514,24 @@ object KMeans { data: RDD[Vector], k: Int, maxIterations: Int): KMeansModel = { - train(data, k, maxIterations, 1, K_MEANS_PARALLEL) + new KMeans().setK(k) + .setMaxIterations(maxIterations) + .run(data) } /** * Trains a k-means model using specified parameters and the default values for unspecified. 
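Because the runs-based train overloads are now deprecated no-ops, callers can switch to the new signatures added above. A hedged PySpark equivalent (the runs argument is likewise ignored on the Python side), assuming an active SparkContext sc:

    from numpy import array
    from pyspark.mllib.clustering import KMeans

    data = sc.parallelize([array([0.0, 0.0]), array([1.0, 1.0]),
                           array([9.0, 8.0]), array([8.0, 9.0])])
    model = KMeans.train(data, k=2, maxIterations=10,
                         initializationMode="k-means||", seed=50)
    print(model.clusterCenters)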
*/ @Since("0.8.0") + @deprecated("Use train method without 'runs'", "2.1.0") def train( data: RDD[Vector], k: Int, maxIterations: Int, runs: Int): KMeansModel = { - train(data, k, maxIterations, runs, K_MEANS_PARALLEL) + new KMeans().setK(k) + .setMaxIterations(maxIterations) + .run(data) } /** diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala index 6642999a2121f..542a69b3ef8cf 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala @@ -28,7 +28,7 @@ import com.github.fommil.netlib.BLAS.{getInstance => blas} import org.apache.spark.annotation.Since import org.apache.spark.ml.{linalg => newlinalg} import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.{GenericMutableRow, UnsafeArrayData} +import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeArrayData} import org.apache.spark.sql.types._ /** @@ -189,7 +189,7 @@ private[spark] class MatrixUDT extends UserDefinedType[Matrix] { } override def serialize(obj: Matrix): InternalRow = { - val row = new GenericMutableRow(7) + val row = new GenericInternalRow(7) obj match { case sm: SparseMatrix => row.setByte(0, 0) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala index 91f065831c804..fbd217af74ecb 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala @@ -34,7 +34,7 @@ import org.apache.spark.annotation.{AlphaComponent, Since} import org.apache.spark.ml.{linalg => newlinalg} import org.apache.spark.mllib.util.NumericParser import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.{GenericMutableRow, UnsafeArrayData} +import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeArrayData} import org.apache.spark.sql.types._ /** @@ -214,14 +214,14 @@ class VectorUDT extends UserDefinedType[Vector] { override def serialize(obj: Vector): InternalRow = { obj match { case SparseVector(size, indices, values) => - val row = new GenericMutableRow(4) + val row = new GenericInternalRow(4) row.setByte(0, 0) row.setInt(1, size) row.update(2, UnsafeArrayData.fromPrimitiveArray(indices)) row.update(3, UnsafeArrayData.fromPrimitiveArray(values)) row case DenseVector(values) => - val row = new GenericMutableRow(4) + val row = new GenericInternalRow(4) row.setByte(0, 1) row.setNullAt(1) row.setNullAt(2) diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala index 97c268f3d5c97..c664460d7d8bb 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala @@ -57,7 +57,7 @@ class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul } } - test("label column already exists") { + test("label column already exists and forceIndexLabel was set with false") { val formula = new RFormula().setFormula("y ~ x").setLabelCol("y") val original = Seq((0, 1.0), (2, 2.0)).toDF("x", "y") val model = formula.fit(original) @@ -66,6 +66,14 @@ class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul assert(resultSchema.toString == 
model.transform(original).schema.toString) } + test("label column already exists but forceIndexLabel was set with true") { + val formula = new RFormula().setFormula("y ~ x").setLabelCol("y").setForceIndexLabel(true) + val original = spark.createDataFrame(Seq((0, 1.0), (2, 2.0))).toDF("x", "y") + intercept[IllegalArgumentException] { + formula.fit(original) + } + } + test("label column already exists but is not numeric type") { val formula = new RFormula().setFormula("y ~ x").setLabelCol("y") val original = Seq((0, true), (2, false)).toDF("x", "y") @@ -137,6 +145,23 @@ class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul assert(result.collect() === expected.collect()) } + test("force to index label even it is numeric type") { + val formula = new RFormula().setFormula("id ~ a + b").setForceIndexLabel(true) + val original = spark.createDataFrame( + Seq((1.0, "foo", 4), (1.0, "bar", 4), (0.0, "bar", 5), (1.0, "baz", 5)) + ).toDF("id", "a", "b") + val model = formula.fit(original) + val result = model.transform(original) + val expected = spark.createDataFrame( + Seq( + (1.0, "foo", 4, Vectors.dense(0.0, 1.0, 4.0), 0.0), + (1.0, "bar", 4, Vectors.dense(1.0, 0.0, 4.0), 0.0), + (0.0, "bar", 5, Vectors.dense(1.0, 0.0, 5.0), 1.0), + (1.0, "baz", 5, Vectors.dense(0.0, 0.0, 5.0), 0.0)) + ).toDF("id", "a", "b", "features", "label") + assert(result.collect() === expected.collect()) + } + test("attribute generation") { val formula = new RFormula().setFormula("id ~ a + b") val original = Seq((1, "foo", 4), (2, "bar", 4), (3, "bar", 5), (4, "baz", 5)) diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala index 5ae371b489aa5..1c94ec67d79d1 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala @@ -1015,12 +1015,14 @@ class LinearRegressionSuite } test("should support all NumericType labels and not support other types") { - val lr = new LinearRegression().setMaxIter(1) - MLTestingUtils.checkNumericTypes[LinearRegressionModel, LinearRegression]( - lr, spark, isClassification = false) { (expected, actual) => + for (solver <- Seq("auto", "l-bfgs", "normal")) { + val lr = new LinearRegression().setMaxIter(1).setSolver(solver) + MLTestingUtils.checkNumericTypes[LinearRegressionModel, LinearRegression]( + lr, spark, isClassification = false) { (expected, actual) => assert(expected.intercept === actual.intercept) assert(expected.coefficients === actual.coefficients) } + } } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala index 79b19ea5ad206..499d386e66413 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala @@ -115,7 +115,7 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { ) val featureSamples = Array(1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 3).map(_.toDouble) val splits = RandomForest.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0) - assert(splits.length === 3) + assert(splits === Array(1.0, 2.0)) // check returned splits are distinct assert(splits.distinct.length === splits.length) } @@ -129,23 +129,53 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { ) 
val featureSamples = Array(2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 4, 5).map(_.toDouble) val splits = RandomForest.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0) - assert(splits.length === 2) - assert(splits(0) === 2.0) - assert(splits(1) === 3.0) + assert(splits === Array(2.0, 3.0)) } // find splits when most samples close to the maximum { val fakeMetadata = new DecisionTreeMetadata(1, 0, 0, 0, Map(), Set(), - Array(3), Gini, QuantileStrategy.Sort, + Array(2), Gini, QuantileStrategy.Sort, 0, 0, 0.0, 0, 0 ) val featureSamples = Array(0, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2).map(_.toDouble) val splits = RandomForest.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0) - assert(splits.length === 1) - assert(splits(0) === 1.0) + assert(splits === Array(1.0)) } + + // find splits for constant feature + { + val fakeMetadata = new DecisionTreeMetadata(1, 0, 0, 0, + Map(), Set(), + Array(3), Gini, QuantileStrategy.Sort, + 0, 0, 0.0, 0, 0 + ) + val featureSamples = Array(0, 0, 0).map(_.toDouble) + val featureSamplesEmpty = Array.empty[Double] + val splits = RandomForest.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0) + assert(splits === Array[Double]()) + val splitsEmpty = + RandomForest.findSplitsForContinuousFeature(featureSamplesEmpty, fakeMetadata, 0) + assert(splitsEmpty === Array[Double]()) + } + } + + test("train with constant features") { + val lp = LabeledPoint(1.0, Vectors.dense(0.0, 0.0, 0.0)) + val data = Array.fill(5)(lp) + val rdd = sc.parallelize(data) + val strategy = new OldStrategy( + OldAlgo.Classification, + Gini, + maxDepth = 2, + numClasses = 2, + maxBins = 100, + categoricalFeaturesInfo = Map(0 -> 1, 1 -> 5)) + val Array(tree) = RandomForest.run(rdd, strategy, 1, "all", 42L, instr = None) + assert(tree.rootNode.impurity === -1.0) + assert(tree.depth === 0) + assert(tree.rootNode.prediction === lp.label) } test("Multiclass classification with unordered categorical features: split calculations") { diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 163e3f2fdea40..ae72d37a0b61c 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -55,7 +55,11 @@ object MimaExcludes { ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.catalog.Catalog.getFunction"), ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.catalog.Catalog.databaseExists"), ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.catalog.Catalog.tableExists"), - ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.catalog.Catalog.functionExists") + ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.catalog.Catalog.functionExists"), + // [SPARK-17338][SQL] add global temp view + ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.catalog.Catalog.dropGlobalTempView"), + ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.sql.catalog.Catalog.dropTempView"), + ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.catalog.Catalog.dropTempView") ) } diff --git a/python/pyspark/conf.py b/python/pyspark/conf.py index 924da3eecf214..64b6f238e9c32 100644 --- a/python/pyspark/conf.py +++ b/python/pyspark/conf.py @@ -52,6 +52,14 @@ >>> sorted(conf.getAll(), key=lambda p: p[0]) [(u'spark.executorEnv.VAR1', u'value1'), (u'spark.executorEnv.VAR3', u'value3'), \ (u'spark.executorEnv.VAR4', u'value4'), (u'spark.home', u'/path')] +>>> conf._jconf.setExecutorEnv("VAR5", 
"value5") +JavaObject id... +>>> print(conf.toDebugString()) +spark.executorEnv.VAR1=value1 +spark.executorEnv.VAR3=value3 +spark.executorEnv.VAR4=value4 +spark.executorEnv.VAR5=value5 +spark.home=/path """ __all__ = ['SparkConf'] @@ -101,13 +109,24 @@ def __init__(self, loadDefaults=True, _jvm=None, _jconf=None): self._jconf = _jconf else: from pyspark.context import SparkContext - SparkContext._ensure_initialized() _jvm = _jvm or SparkContext._jvm - self._jconf = _jvm.SparkConf(loadDefaults) + + if _jvm is not None: + # JVM is created, so create self._jconf directly through JVM + self._jconf = _jvm.SparkConf(loadDefaults) + self._conf = None + else: + # JVM is not created, so store data in self._conf first + self._jconf = None + self._conf = {} def set(self, key, value): """Set a configuration property.""" - self._jconf.set(key, unicode(value)) + # Try to set self._jconf first if JVM is created, set self._conf if JVM is not created yet. + if self._jconf is not None: + self._jconf.set(key, unicode(value)) + else: + self._conf[key] = unicode(value) return self def setIfMissing(self, key, value): @@ -118,17 +137,17 @@ def setIfMissing(self, key, value): def setMaster(self, value): """Set master URL to connect to.""" - self._jconf.setMaster(value) + self.set("spark.master", value) return self def setAppName(self, value): """Set application name.""" - self._jconf.setAppName(value) + self.set("spark.app.name", value) return self def setSparkHome(self, value): """Set path where Spark is installed on worker nodes.""" - self._jconf.setSparkHome(value) + self.set("spark.home", value) return self def setExecutorEnv(self, key=None, value=None, pairs=None): @@ -136,10 +155,10 @@ def setExecutorEnv(self, key=None, value=None, pairs=None): if (key is not None and pairs is not None) or (key is None and pairs is None): raise Exception("Either pass one key-value pair or a list of pairs") elif key is not None: - self._jconf.setExecutorEnv(key, value) + self.set("spark.executorEnv." + key, value) elif pairs is not None: for (k, v) in pairs: - self._jconf.setExecutorEnv(k, v) + self.set("spark.executorEnv." 
+ k, v) return self def setAll(self, pairs): @@ -149,35 +168,49 @@ def setAll(self, pairs): :param pairs: list of key-value pairs to set """ for (k, v) in pairs: - self._jconf.set(k, v) + self.set(k, v) return self def get(self, key, defaultValue=None): """Get the configured value for some key, or return a default otherwise.""" if defaultValue is None: # Py4J doesn't call the right get() if we pass None - if not self._jconf.contains(key): - return None - return self._jconf.get(key) + if self._jconf is not None: + if not self._jconf.contains(key): + return None + return self._jconf.get(key) + else: + if key not in self._conf: + return None + return self._conf[key] else: - return self._jconf.get(key, defaultValue) + if self._jconf is not None: + return self._jconf.get(key, defaultValue) + else: + return self._conf.get(key, defaultValue) def getAll(self): """Get all values as a list of key-value pairs.""" - pairs = [] - for elem in self._jconf.getAll(): - pairs.append((elem._1(), elem._2())) - return pairs + if self._jconf is not None: + return [(elem._1(), elem._2()) for elem in self._jconf.getAll()] + else: + return self._conf.items() def contains(self, key): """Does this configuration contain a given key?""" - return self._jconf.contains(key) + if self._jconf is not None: + return self._jconf.contains(key) + else: + return key in self._conf def toDebugString(self): """ Returns a printable version of the configuration, as a list of key=value pairs, one per line. """ - return self._jconf.toDebugString() + if self._jconf is not None: + return self._jconf.toDebugString() + else: + return '\n'.join('%s=%s' % (k, v) for k, v in self._conf.items()) def _test(): diff --git a/python/pyspark/context.py b/python/pyspark/context.py index a3dd1950a522f..1b2e199c395be 100644 --- a/python/pyspark/context.py +++ b/python/pyspark/context.py @@ -109,7 +109,7 @@ def __init__(self, master=None, appName=None, sparkHome=None, pyFiles=None, ValueError:... """ self._callsite = first_spark_call() or CallSite(None, None, None) - SparkContext._ensure_initialized(self, gateway=gateway) + SparkContext._ensure_initialized(self, gateway=gateway, conf=conf) try: self._do_init(master, appName, sparkHome, pyFiles, environment, batchSize, serializer, conf, jsc, profiler_cls) @@ -121,7 +121,15 @@ def __init__(self, master=None, appName=None, sparkHome=None, pyFiles=None, def _do_init(self, master, appName, sparkHome, pyFiles, environment, batchSize, serializer, conf, jsc, profiler_cls): self.environment = environment or {} - self._conf = conf or SparkConf(_jvm=self._jvm) + # java gateway must have been launched at this point. + if conf is not None and conf._jconf is not None: + # conf has been initialized in JVM properly, so use conf directly. This represent the + # scenario that JVM has been launched before SparkConf is created (e.g. SparkContext is + # created and then stopped, and we create a new SparkConf and new SparkContext again) + self._conf = conf + else: + self._conf = SparkConf(_jvm=SparkContext._jvm) + self._batchSize = batchSize # -1 represents an unlimited batch size self._unbatched_serializer = serializer if batchSize == 0: @@ -232,14 +240,14 @@ def _initialize_context(self, jconf): return self._jvm.JavaSparkContext(jconf) @classmethod - def _ensure_initialized(cls, instance=None, gateway=None): + def _ensure_initialized(cls, instance=None, gateway=None, conf=None): """ Checks whether a SparkContext is initialized or not. Throws error if a SparkContext is already running. 
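The pyspark.conf and java_gateway changes above let a SparkConf be populated before any JVM exists: entries live in a plain Python dict and are forwarded to spark-submit as --conf flags when the gateway is launched. A small usage sketch, written as a standalone script with no gateway assumed up front:

    from pyspark import SparkConf, SparkContext

    conf = SparkConf()                          # no JVM yet: values go into the Python-side dict
    conf.setMaster("local[2]").setAppName("pre-jvm-conf")
    conf.set("spark.ui.enabled", "false")
    print(conf.toDebugString())                 # rendered from the dict before the JVM starts

    sc = SparkContext(conf=conf)                # conf is handed to launch_gateway() as --conf flags
    print(sc.getConf().get("spark.app.name"))
    sc.stop()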
""" with SparkContext._lock: if not SparkContext._gateway: - SparkContext._gateway = gateway or launch_gateway() + SparkContext._gateway = gateway or launch_gateway(conf) SparkContext._jvm = SparkContext._gateway.jvm if instance: diff --git a/python/pyspark/java_gateway.py b/python/pyspark/java_gateway.py index f76cadcf62438..c1cf843d84388 100644 --- a/python/pyspark/java_gateway.py +++ b/python/pyspark/java_gateway.py @@ -32,7 +32,12 @@ from pyspark.serializers import read_int -def launch_gateway(): +def launch_gateway(conf=None): + """ + launch jvm gateway + :param conf: spark configuration passed to spark-submit + :return: + """ if "PYSPARK_GATEWAY_PORT" in os.environ: gateway_port = int(os.environ["PYSPARK_GATEWAY_PORT"]) else: @@ -41,13 +46,17 @@ def launch_gateway(): # proper classpath and settings from spark-env.sh on_windows = platform.system() == "Windows" script = "./bin/spark-submit.cmd" if on_windows else "./bin/spark-submit" + command = [os.path.join(SPARK_HOME, script)] + if conf: + for k, v in conf.getAll(): + command += ['--conf', '%s=%s' % (k, v)] submit_args = os.environ.get("PYSPARK_SUBMIT_ARGS", "pyspark-shell") if os.environ.get("SPARK_TESTING"): submit_args = ' '.join([ "--conf spark.ui.enabled=false", submit_args ]) - command = [os.path.join(SPARK_HOME, script)] + shlex.split(submit_args) + command = command + shlex.split(submit_args) # Start a socket that will be used by PythonGatewayServer to communicate its port to us callback_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py index ea60fab029582..3f763a10d4066 100644 --- a/python/pyspark/ml/classification.py +++ b/python/pyspark/ml/classification.py @@ -981,7 +981,7 @@ def trees(self): @inherit_doc class NaiveBayes(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, HasProbabilityCol, - HasRawPredictionCol, HasThresholds, JavaMLWritable, JavaMLReadable): + HasRawPredictionCol, HasThresholds, HasWeightCol, JavaMLWritable, JavaMLReadable): """ Naive Bayes Classifiers. It supports both Multinomial and Bernoulli NB. `Multinomial NB @@ -995,23 +995,23 @@ class NaiveBayes(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, H >>> from pyspark.sql import Row >>> from pyspark.ml.linalg import Vectors >>> df = spark.createDataFrame([ - ... Row(label=0.0, features=Vectors.dense([0.0, 0.0])), - ... Row(label=0.0, features=Vectors.dense([0.0, 1.0])), - ... Row(label=1.0, features=Vectors.dense([1.0, 0.0]))]) - >>> nb = NaiveBayes(smoothing=1.0, modelType="multinomial") + ... Row(label=0.0, weight=0.1, features=Vectors.dense([0.0, 0.0])), + ... Row(label=0.0, weight=0.5, features=Vectors.dense([0.0, 1.0])), + ... 
Row(label=1.0, weight=1.0, features=Vectors.dense([1.0, 0.0]))]) + >>> nb = NaiveBayes(smoothing=1.0, modelType="multinomial", weightCol="weight") >>> model = nb.fit(df) >>> model.pi - DenseVector([-0.51..., -0.91...]) + DenseVector([-0.81..., -0.58...]) >>> model.theta - DenseMatrix(2, 2, [-1.09..., -0.40..., -0.40..., -1.09...], 1) + DenseMatrix(2, 2, [-0.91..., -0.51..., -0.40..., -1.09...], 1) >>> test0 = sc.parallelize([Row(features=Vectors.dense([1.0, 0.0]))]).toDF() >>> result = model.transform(test0).head() >>> result.prediction 1.0 >>> result.probability - DenseVector([0.42..., 0.57...]) + DenseVector([0.32..., 0.67...]) >>> result.rawPrediction - DenseVector([-1.60..., -1.32...]) + DenseVector([-1.72..., -0.99...]) >>> test1 = sc.parallelize([Row(features=Vectors.sparse(2, [0], [1.0]))]).toDF() >>> model.transform(test1).head().prediction 1.0 @@ -1045,11 +1045,11 @@ class NaiveBayes(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, H @keyword_only def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", probabilityCol="probability", rawPredictionCol="rawPrediction", smoothing=1.0, - modelType="multinomial", thresholds=None): + modelType="multinomial", thresholds=None, weightCol=None): """ __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \ probabilityCol="probability", rawPredictionCol="rawPrediction", smoothing=1.0, \ - modelType="multinomial", thresholds=None) + modelType="multinomial", thresholds=None, weightCol=None) """ super(NaiveBayes, self).__init__() self._java_obj = self._new_java_obj( @@ -1062,11 +1062,11 @@ def __init__(self, featuresCol="features", labelCol="label", predictionCol="pred @since("1.5.0") def setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", probabilityCol="probability", rawPredictionCol="rawPrediction", smoothing=1.0, - modelType="multinomial", thresholds=None): + modelType="multinomial", thresholds=None, weightCol=None): """ setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \ probabilityCol="probability", rawPredictionCol="rawPrediction", smoothing=1.0, \ - modelType="multinomial", thresholds=None) + modelType="multinomial", thresholds=None, weightCol=None) Sets params for Naive Bayes. """ kwargs = self.setParams._input_kwargs diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index ed81eb16df3cd..0e2ae19ca39aa 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -2017,8 +2017,7 @@ def repartition(self, numPartitions): >>> len(rdd.repartition(10).glom().collect()) 10 """ - jrdd = self._jrdd.repartition(numPartitions) - return RDD(jrdd, self.ctx, self._jrdd_deserializer) + return self.coalesce(numPartitions, shuffle=True) def coalesce(self, numPartitions, shuffle=False): """ @@ -2029,7 +2028,15 @@ def coalesce(self, numPartitions, shuffle=False): >>> sc.parallelize([1, 2, 3, 4, 5], 3).coalesce(1).glom().collect() [[1, 2, 3, 4, 5]] """ - jrdd = self._jrdd.coalesce(numPartitions, shuffle) + if shuffle: + # In Scala's repartition code, we will distribute elements evenly across output + # partitions. However, the RDD from Python is serialized as a single binary data, + # so the distribution fails and produces highly skewed partitions. We need to + # convert it to a RDD of java object before repartitioning. 
+ data_java_rdd = self._to_java_object_rdd().coalesce(numPartitions, shuffle) + jrdd = self.ctx._jvm.SerDeUtil.javaToPython(data_java_rdd) + else: + jrdd = self._jrdd.coalesce(numPartitions, shuffle) return RDD(jrdd, self.ctx, self._jrdd_deserializer) def zip(self, other): diff --git a/python/pyspark/sql/catalog.py b/python/pyspark/sql/catalog.py index 3c5030722f307..a36d02e0db134 100644 --- a/python/pyspark/sql/catalog.py +++ b/python/pyspark/sql/catalog.py @@ -167,8 +167,12 @@ def createExternalTable(self, tableName, path=None, source=None, schema=None, ** @since(2.0) def dropTempView(self, viewName): - """Drops the temporary view with the given view name in the catalog. + """Drops the local temporary view with the given view name in the catalog. If the view has been cached before, then it will also be uncached. + Returns true if this view is dropped successfully, false otherwise. + + Note that, the return type of this method was None in Spark 2.0, but changed to Boolean + in Spark 2.1. >>> spark.createDataFrame([(1, 1)]).createTempView("my_table") >>> spark.table("my_table").collect() @@ -181,6 +185,23 @@ def dropTempView(self, viewName): """ self._jcatalog.dropTempView(viewName) + @since(2.1) + def dropGlobalTempView(self, viewName): + """Drops the global temporary view with the given view name in the catalog. + If the view has been cached before, then it will also be uncached. + Returns true if this view is dropped successfully, false otherwise. + + >>> spark.createDataFrame([(1, 1)]).createGlobalTempView("my_table") + >>> spark.table("global_temp.my_table").collect() + [Row(_1=1, _2=1)] + >>> spark.catalog.dropGlobalTempView("my_table") + >>> spark.table("global_temp.my_table") # doctest: +IGNORE_EXCEPTION_DETAIL + Traceback (most recent call last): + ... + AnalysisException: ... + """ + self._jcatalog.dropGlobalTempView(viewName) + @ignore_unicode_prefix @since(2.0) def registerFunction(self, name, f, returnType=StringType()): diff --git a/python/pyspark/sql/context.py b/python/pyspark/sql/context.py index 7482be8bda5c4..8264dcf8a97d2 100644 --- a/python/pyspark/sql/context.py +++ b/python/pyspark/sql/context.py @@ -386,7 +386,7 @@ def tables(self, dbName=None): >>> sqlContext.registerDataFrameAsTable(df, "table1") >>> df2 = sqlContext.tables() >>> df2.filter("tableName = 'table1'").first() - Row(tableName=u'table1', isTemporary=True) + Row(database=u'', tableName=u'table1', isTemporary=True) """ if dbName is None: return DataFrame(self._ssql_ctx.tables(), self) diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 0ac481a8a8b56..ce277eb204d13 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -131,7 +131,7 @@ def registerTempTable(self, name): @since(2.0) def createTempView(self, name): - """Creates a temporary view with this DataFrame. + """Creates a local temporary view with this DataFrame. The lifetime of this temporary table is tied to the :class:`SparkSession` that was used to create this :class:`DataFrame`. @@ -153,7 +153,7 @@ def createTempView(self, name): @since(2.0) def createOrReplaceTempView(self, name): - """Creates or replaces a temporary view with this DataFrame. + """Creates or replaces a local temporary view with this DataFrame. The lifetime of this temporary table is tied to the :class:`SparkSession` that was used to create this :class:`DataFrame`. 
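The catalog and DataFrame additions around here introduce global temporary views alongside the existing session-local ones. A quick PySpark sketch of the new calls, assuming an active SparkSession spark (the method names come straight from this patch):

    df = spark.createDataFrame([(1, 1)], ["a", "b"])

    df.createGlobalTempView("people")             # registered under the global_temp database
    spark.sql("SELECT * FROM global_temp.people").show()

    spark.catalog.dropGlobalTempView("people")    # also uncaches the view if it was cached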
@@ -169,6 +169,27 @@ def createOrReplaceTempView(self, name): """ self._jdf.createOrReplaceTempView(name) + @since(2.1) + def createGlobalTempView(self, name): + """Creates a global temporary view with this DataFrame. + + The lifetime of this temporary view is tied to this Spark application. + throws :class:`TempTableAlreadyExistsException`, if the view name already exists in the + catalog. + + >>> df.createGlobalTempView("people") + >>> df2 = spark.sql("select * from global_temp.people") + >>> sorted(df.collect()) == sorted(df2.collect()) + True + >>> df.createGlobalTempView("people") # doctest: +IGNORE_EXCEPTION_DETAIL + Traceback (most recent call last): + ... + AnalysisException: u"Temporary table 'people' already exists;" + >>> spark.catalog.dropGlobalTempView("people") + + """ + self._jdf.createGlobalTempView(name) + @property @since(1.4) def write(self): @@ -640,25 +661,24 @@ def join(self, other, on=None, how=None): if on is not None and not isinstance(on, list): on = [on] - if on is None or len(on) == 0: - jdf = self._jdf.crossJoin(other._jdf) - elif isinstance(on[0], basestring): - if how is None: - jdf = self._jdf.join(other._jdf, self._jseq(on), "inner") + if on is not None: + if isinstance(on[0], basestring): + on = self._jseq(on) else: - assert isinstance(how, basestring), "how should be basestring" - jdf = self._jdf.join(other._jdf, self._jseq(on), how) + assert isinstance(on[0], Column), "on should be Column or list of Column" + if len(on) > 1: + on = reduce(lambda x, y: x.__and__(y), on) + else: + on = on[0] + on = on._jc + + if on is None and how is None: + jdf = self._jdf.crossJoin(other._jdf) else: - assert isinstance(on[0], Column), "on should be Column or list of Column" - if len(on) > 1: - on = reduce(lambda x, y: x.__and__(y), on) - else: - on = on[0] if how is None: - jdf = self._jdf.join(other._jdf, on._jc, "inner") - else: - assert isinstance(how, basestring), "how should be basestring" - jdf = self._jdf.join(other._jdf, on._jc, how) + how = "inner" + assert isinstance(how, basestring), "how should be basestring" + jdf = self._jdf.join(other._jdf, on, how) return DataFrame(jdf, self.sql_ctx) @since(1.6) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 45d6bf944b702..7fa3fd2de7ddf 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -112,11 +112,8 @@ def _(): 'sinh': 'Computes the hyperbolic sine of the given value.', 'tan': 'Computes the tangent of the given value.', 'tanh': 'Computes the hyperbolic tangent of the given value.', - 'toDegrees': 'Converts an angle measured in radians to an approximately equivalent angle ' + - 'measured in degrees.', - 'toRadians': 'Converts an angle measured in degrees to an approximately equivalent angle ' + - 'measured in radians.', - + 'toDegrees': '.. note:: Deprecated in 2.1, use degrees instead.', + 'toRadians': '.. note:: Deprecated in 2.1, use radians instead.', 'bitwiseNOT': 'Computes bitwise not.', } @@ -135,7 +132,15 @@ def _(): 'kurtosis': 'Aggregate function: returns the kurtosis of the values in a group.', 'collect_list': 'Aggregate function: returns a list of objects with duplicates.', 'collect_set': 'Aggregate function: returns a set of objects with duplicate elements' + - ' eliminated.' 
+ ' eliminated.', +} + +_functions_2_1 = { + # unary math functions + 'degrees': 'Converts an angle measured in radians to an approximately equivalent angle ' + + 'measured in degrees.', + 'radians': 'Converts an angle measured in degrees to an approximately equivalent angle ' + + 'measured in radians.', } # math functions that take two arguments as input @@ -182,21 +187,31 @@ def _(): globals()[_name] = since(1.6)(_create_window_function(_name, _doc)) for _name, _doc in _functions_1_6.items(): globals()[_name] = since(1.6)(_create_function(_name, _doc)) +for _name, _doc in _functions_2_1.items(): + globals()[_name] = since(2.1)(_create_function(_name, _doc)) del _name, _doc @since(1.3) def approxCountDistinct(col, rsd=None): + """ + .. note:: Deprecated in 2.1, use approx_count_distinct instead. + """ + return approx_count_distinct(col, rsd) + + +@since(2.1) +def approx_count_distinct(col, rsd=None): """Returns a new :class:`Column` for approximate distinct count of ``col``. - >>> df.agg(approxCountDistinct(df.age).alias('c')).collect() + >>> df.agg(approx_count_distinct(df.age).alias('c')).collect() [Row(c=2)] """ sc = SparkContext._active_spark_context if rsd is None: - jc = sc._jvm.functions.approxCountDistinct(_to_java_column(col)) + jc = sc._jvm.functions.approx_count_distinct(_to_java_column(col)) else: - jc = sc._jvm.functions.approxCountDistinct(_to_java_column(col), rsd) + jc = sc._jvm.functions.approx_count_distinct(_to_java_column(col), rsd) return Column(jc) diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py index 3ad6f80de9fdf..91c2b17049fa1 100644 --- a/python/pyspark/sql/readwriter.py +++ b/python/pyspark/sql/readwriter.py @@ -289,8 +289,8 @@ def text(self, paths): [Row(value=u'hello'), Row(value=u'this')] """ if isinstance(paths, basestring): - path = [paths] - return self._df(self._jreader.text(self._spark._sc._jvm.PythonUtils.toSeq(path))) + paths = [paths] + return self._df(self._jreader.text(self._spark._sc._jvm.PythonUtils.toSeq(paths))) @since(2.0) def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=None, diff --git a/python/pyspark/sql/session.py b/python/pyspark/sql/session.py index 8418abf99c8d5..1e40b9c39fc4f 100644 --- a/python/pyspark/sql/session.py +++ b/python/pyspark/sql/session.py @@ -176,7 +176,7 @@ def getOrCreate(self): sc._conf.set(key, value) session = SparkSession(sc) for key, value in self._options.items(): - session.conf.set(key, value) + session._jsparkSession.sessionState().conf().setConfString(key, value) for key, value in self._options.items(): session.sparkContext._conf.set(key, value) return session diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index c2171c277cac3..51d5e7ab0568e 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -1508,6 +1508,12 @@ def test_toDF_with_schema_string(self): self.assertEqual(df.schema.simpleString(), "struct") self.assertEqual(df.collect(), [Row(key=i) for i in range(100)]) + # Regression test for invalid join methods when on is None, Spark-14761 + def test_invalid_join_method(self): + df1 = self.spark.createDataFrame([("Alice", 5), ("Bob", 8)], ["name", "age"]) + df2 = self.spark.createDataFrame([("Alice", 80), ("Bob", 90)], ["name", "height"]) + self.assertRaises(IllegalArgumentException, lambda: df1.join(df2, how="invalid-join-type")) + def test_conf(self): spark = self.spark spark.conf.set("bogo", "sipeo") @@ -1702,6 +1708,20 @@ def test_cache(self): "does_not_exist", lambda: 
spark.catalog.uncacheTable("does_not_exist")) + def test_read_text_file_list(self): + df = self.spark.read.text(['python/test_support/sql/text-test.txt', + 'python/test_support/sql/text-test.txt']) + count = df.count() + self.assertEquals(count, 4) + + def test_BinaryType_serialization(self): + # Pyrolite version <= 4.9 could not serialize BinaryType with Python3 SPARK-17808 + schema = StructType([StructField('mybytes', BinaryType())]) + data = [[bytearray(b'here is my data')], + [bytearray(b'and here is some more')]] + df = self.spark.createDataFrame(data, schema=schema) + df.collect() + class HiveSparkSubmitTests(SparkSubmitTests): @@ -1853,6 +1873,38 @@ def test_window_functions_without_partitionBy(self): for r, ex in zip(rs, expected): self.assertEqual(tuple(r), ex[:len(r)]) + def test_window_functions_cumulative_sum(self): + df = self.spark.createDataFrame([("one", 1), ("two", 2)], ["key", "value"]) + from pyspark.sql import functions as F + + # Test cumulative sum + sel = df.select( + df.key, + F.sum(df.value).over(Window.rowsBetween(Window.unboundedPreceding, 0))) + rs = sorted(sel.collect()) + expected = [("one", 1), ("two", 3)] + for r, ex in zip(rs, expected): + self.assertEqual(tuple(r), ex[:len(r)]) + + # Test boundary values less than JVM's Long.MinValue and make sure we don't overflow + sel = df.select( + df.key, + F.sum(df.value).over(Window.rowsBetween(Window.unboundedPreceding - 1, 0))) + rs = sorted(sel.collect()) + expected = [("one", 1), ("two", 3)] + for r, ex in zip(rs, expected): + self.assertEqual(tuple(r), ex[:len(r)]) + + # Test boundary values greater than JVM's Long.MaxValue and make sure we don't overflow + frame_end = Window.unboundedFollowing + 1 + sel = df.select( + df.key, + F.sum(df.value).over(Window.rowsBetween(Window.currentRow, frame_end))) + rs = sorted(sel.collect()) + expected = [("one", 3), ("two", 2)] + for r, ex in zip(rs, expected): + self.assertEqual(tuple(r), ex[:len(r)]) + def test_collect_functions(self): df = self.spark.createDataFrame([(1, "1"), (2, "2"), (1, "2"), (1, "2")], ["key", "value"]) from pyspark.sql import functions diff --git a/python/pyspark/sql/window.py b/python/pyspark/sql/window.py index 46663f69a0881..c345e623f1cb1 100644 --- a/python/pyspark/sql/window.py +++ b/python/pyspark/sql/window.py @@ -36,8 +36,8 @@ class Window(object): For example: - >>> # PARTITION BY country ORDER BY date ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW - >>> window = Window.partitionBy("country").orderBy("date").rowsBetween(-sys.maxsize, 0) + >>> # ORDER BY date ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW + >>> window = Window.orderBy("date").rowsBetween(Window.unboundedPreceding, Window.currentRow) >>> # PARTITION BY country ORDER BY date RANGE BETWEEN 3 PRECEDING AND 3 FOLLOWING >>> window = Window.orderBy("date").partitionBy("country").rangeBetween(-3, 3) @@ -46,6 +46,16 @@ class Window(object): .. versionadded:: 1.4 """ + + _JAVA_MIN_LONG = -(1 << 63) # -9223372036854775808 + _JAVA_MAX_LONG = (1 << 63) - 1 # 9223372036854775807 + + unboundedPreceding = _JAVA_MIN_LONG + + unboundedFollowing = _JAVA_MAX_LONG + + currentRow = 0 + @staticmethod @since(1.4) def partitionBy(*cols): @@ -66,6 +76,66 @@ def orderBy(*cols): jspec = sc._jvm.org.apache.spark.sql.expressions.Window.orderBy(_to_java_cols(cols)) return WindowSpec(jspec) + @staticmethod + @since(2.1) + def rowsBetween(start, end): + """ + Creates a :class:`WindowSpec` with the frame boundaries defined, + from `start` (inclusive) to `end` (inclusive). 
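The new Window.unboundedPreceding, Window.unboundedFollowing, and Window.currentRow markers replace the old -sys.maxsize/sys.maxsize convention for frame boundaries. A short usage sketch mirroring the cumulative-sum test added above, assuming an active SparkSession spark:

    from pyspark.sql import functions as F
    from pyspark.sql.window import Window

    df = spark.createDataFrame([("one", 1), ("two", 2)], ["key", "value"])
    w = Window.orderBy("key").rowsBetween(Window.unboundedPreceding, Window.currentRow)
    df.select(df.key, F.sum(df.value).over(w).alias("running_total")).show()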
+ + Both `start` and `end` are relative positions from the current row. + For example, "0" means "current row", while "-1" means the row before + the current row, and "5" means the fifth row after the current row. + + We recommend users use ``Window.unboundedPreceding``, ``Window.unboundedFollowing``, + and ``Window.currentRow`` to specify special boundary values, rather than using integral + values directly. + + :param start: boundary start, inclusive. + The frame is unbounded if this is ``Window.unboundedPreceding``, or + any value less than or equal to -9223372036854775808. + :param end: boundary end, inclusive. + The frame is unbounded if this is ``Window.unboundedFollowing``, or + any value greater than or equal to 9223372036854775807. + """ + if start <= Window._JAVA_MIN_LONG: + start = Window.unboundedPreceding + if end >= Window._JAVA_MAX_LONG: + end = Window.unboundedFollowing + sc = SparkContext._active_spark_context + jspec = sc._jvm.org.apache.spark.sql.expressions.Window.rowsBetween(start, end) + return WindowSpec(jspec) + + @staticmethod + @since(2.1) + def rangeBetween(start, end): + """ + Creates a :class:`WindowSpec` with the frame boundaries defined, + from `start` (inclusive) to `end` (inclusive). + + Both `start` and `end` are relative from the current row. For example, + "0" means "current row", while "-1" means one off before the current row, + and "5" means the five off after the current row. + + We recommend users use ``Window.unboundedPreceding``, ``Window.unboundedFollowing``, + and ``Window.currentRow`` to specify special boundary values, rather than using integral + values directly. + + :param start: boundary start, inclusive. + The frame is unbounded if this is ``Window.unboundedPreceding``, or + any value less than or equal to -9223372036854775808. + :param end: boundary end, inclusive. + The frame is unbounded if this is ``Window.unboundedFollowing``, or + any value greater than or equal to 9223372036854775807. + """ + if start <= Window._JAVA_MIN_LONG: + start = Window.unboundedPreceding + if end >= Window._JAVA_MAX_LONG: + end = Window.unboundedFollowing + sc = SparkContext._active_spark_context + jspec = sc._jvm.org.apache.spark.sql.expressions.Window.rangeBetween(start, end) + return WindowSpec(jspec) + class WindowSpec(object): """ @@ -79,9 +149,6 @@ class WindowSpec(object): .. versionadded:: 1.4 """ - _JAVA_MAX_LONG = (1 << 63) - 1 - _JAVA_MIN_LONG = - (1 << 63) - def __init__(self, jspec): self._jspec = jspec @@ -112,15 +179,21 @@ def rowsBetween(self, start, end): For example, "0" means "current row", while "-1" means the row before the current row, and "5" means the fifth row after the current row. + We recommend users use ``Window.unboundedPreceding``, ``Window.unboundedFollowing``, + and ``Window.currentRow`` to specify special boundary values, rather than using integral + values directly. + :param start: boundary start, inclusive. - The frame is unbounded if this is ``-sys.maxsize`` (or lower). + The frame is unbounded if this is ``Window.unboundedPreceding``, or + any value less than or equal to -9223372036854775808. :param end: boundary end, inclusive. - The frame is unbounded if this is ``sys.maxsize`` (or higher). + The frame is unbounded if this is ``Window.unboundedFollowing``, or + any value greater than or equal to 9223372036854775807. 
""" - if start <= -sys.maxsize: - start = self._JAVA_MIN_LONG - if end >= sys.maxsize: - end = self._JAVA_MAX_LONG + if start <= Window._JAVA_MIN_LONG: + start = Window.unboundedPreceding + if end >= Window._JAVA_MAX_LONG: + end = Window.unboundedFollowing return WindowSpec(self._jspec.rowsBetween(start, end)) @since(1.4) @@ -132,15 +205,21 @@ def rangeBetween(self, start, end): "0" means "current row", while "-1" means one off before the current row, and "5" means the five off after the current row. + We recommend users use ``Window.unboundedPreceding``, ``Window.unboundedFollowing``, + and ``Window.currentRow`` to specify special boundary values, rather than using integral + values directly. + :param start: boundary start, inclusive. - The frame is unbounded if this is ``-sys.maxsize`` (or lower). + The frame is unbounded if this is ``Window.unboundedPreceding``, or + any value less than or equal to -9223372036854775808. :param end: boundary end, inclusive. - The frame is unbounded if this is ``sys.maxsize`` (or higher). + The frame is unbounded if this is ``Window.unboundedFollowing``, or + any value greater than or equal to 9223372036854775807. """ - if start <= -sys.maxsize: - start = self._JAVA_MIN_LONG - if end >= sys.maxsize: - end = self._JAVA_MAX_LONG + if start <= Window._JAVA_MIN_LONG: + start = Window.unboundedPreceding + if end >= Window._JAVA_MAX_LONG: + end = Window.unboundedFollowing return WindowSpec(self._jspec.rangeBetween(start, end)) diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py index b0756911bfc10..3e0bd16d85ca4 100644 --- a/python/pyspark/tests.py +++ b/python/pyspark/tests.py @@ -914,6 +914,16 @@ def test_repartitionAndSortWithinPartitions(self): self.assertEqual(partitions[0], [(0, 5), (0, 8), (2, 6)]) self.assertEqual(partitions[1], [(1, 3), (3, 8), (3, 8)]) + def test_repartition_no_skewed(self): + num_partitions = 20 + a = self.sc.parallelize(range(int(1000)), 2) + l = a.repartition(num_partitions).glom().map(len).collect() + zeros = len([x for x in l if x == 0]) + self.assertTrue(zeros == 0) + l = a.coalesce(num_partitions, True).glom().map(len).collect() + zeros = len([x for x in l if x == 0]) + self.assertTrue(zeros == 0) + def test_distinct(self): rdd = self.sc.parallelize((1, 2, 3)*10, 10) self.assertEqual(rdd.getNumPartitions(), 10) diff --git a/repl/scala-2.11/src/main/scala/org/apache/spark/repl/Main.scala b/repl/scala-2.11/src/main/scala/org/apache/spark/repl/Main.scala index 5dfe18ad49822..fec4d49379591 100644 --- a/repl/scala-2.11/src/main/scala/org/apache/spark/repl/Main.scala +++ b/repl/scala-2.11/src/main/scala/org/apache/spark/repl/Main.scala @@ -22,9 +22,9 @@ import java.io.File import scala.tools.nsc.GenericRunnerSettings import org.apache.spark._ -import org.apache.spark.internal.config.CATALOG_IMPLEMENTATION import org.apache.spark.internal.Logging import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.internal.StaticSQLConf.CATALOG_IMPLEMENTATION import org.apache.spark.util.Utils object Main extends Logging { diff --git a/repl/scala-2.11/src/test/scala/org/apache/spark/repl/ReplSuite.scala b/repl/scala-2.11/src/test/scala/org/apache/spark/repl/ReplSuite.scala index f7d7a4f041315..9262e938c2a60 100644 --- a/repl/scala-2.11/src/test/scala/org/apache/spark/repl/ReplSuite.scala +++ b/repl/scala-2.11/src/test/scala/org/apache/spark/repl/ReplSuite.scala @@ -24,8 +24,8 @@ import scala.collection.mutable.ArrayBuffer import org.apache.commons.lang3.StringEscapeUtils import org.apache.log4j.{Level, LogManager} import 
org.apache.spark.{SparkContext, SparkFunSuite} -import org.apache.spark.internal.config._ import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.internal.StaticSQLConf.CATALOG_IMPLEMENTATION import org.apache.spark.util.Utils class ReplSuite extends SparkFunSuite { diff --git a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 index 6a94def65f360..b599a884957a8 100644 --- a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 +++ b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 @@ -111,11 +111,12 @@ statement | ALTER TABLE tableIdentifier RECOVER PARTITIONS #recoverPartitions | DROP TABLE (IF EXISTS)? tableIdentifier PURGE? #dropTable | DROP VIEW (IF EXISTS)? tableIdentifier #dropTable - | CREATE (OR REPLACE)? TEMPORARY? VIEW (IF NOT EXISTS)? tableIdentifier + | CREATE (OR REPLACE)? (GLOBAL? TEMPORARY)? + VIEW (IF NOT EXISTS)? tableIdentifier identifierCommentList? (COMMENT STRING)? (PARTITIONED ON identifierList)? (TBLPROPERTIES tablePropertyList)? AS query #createView - | CREATE (OR REPLACE)? TEMPORARY VIEW + | CREATE (OR REPLACE)? GLOBAL? TEMPORARY VIEW tableIdentifier ('(' colTypeList ')')? tableProvider (OPTIONS tablePropertyList)? #createTempViewUsing | ALTER VIEW tableIdentifier AS? query #alterViewQuery @@ -584,7 +585,7 @@ intervalValue dataType : complex=ARRAY '<' dataType '>' #complexDataType | complex=MAP '<' dataType ',' dataType '>' #complexDataType - | complex=STRUCT ('<' colTypeList? '>' | NEQ) #complexDataType + | complex=STRUCT ('<' complexColTypeList? '>' | NEQ) #complexDataType | identifier ('(' INTEGER_VALUE (',' INTEGER_VALUE)* ')')? #primitiveDataType ; @@ -593,7 +594,15 @@ colTypeList ; colType - : identifier ':'? dataType (COMMENT STRING)? + : identifier dataType (COMMENT STRING)? + ; + +complexColTypeList + : complexColType (',' complexColType)* + ; + +complexColType + : identifier ':' dataType (COMMENT STRING)? ; whenClause @@ -668,7 +677,7 @@ nonReserved | MAP | ARRAY | STRUCT | LATERAL | WINDOW | REDUCE | TRANSFORM | USING | SERDE | SERDEPROPERTIES | RECORDREADER | DELIMITED | FIELDS | TERMINATED | COLLECTION | ITEMS | KEYS | ESCAPED | LINES | SEPARATED - | EXTENDED | REFRESH | CLEAR | CACHE | UNCACHE | LAZY | TEMPORARY | OPTIONS + | EXTENDED | REFRESH | CLEAR | CACHE | UNCACHE | LAZY | GLOBAL | TEMPORARY | OPTIONS | GROUPING | CUBE | ROLLUP | EXPLAIN | FORMAT | LOGICAL | FORMATTED | CODEGEN | TABLESAMPLE | USE | TO | BUCKET | PERCENTLIT | OUT | OF @@ -856,6 +865,7 @@ CACHE: 'CACHE'; UNCACHE: 'UNCACHE'; LAZY: 'LAZY'; FORMATTED: 'FORMATTED'; +GLOBAL: 'GLOBAL'; TEMPORARY: 'TEMPORARY' | 'TEMP'; OPTIONS: 'OPTIONS'; UNSET: 'UNSET'; diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/RowFactory.java b/sql/catalyst/src/main/java/org/apache/spark/sql/RowFactory.java index 5ed60fe78d116..2ce1fdcbf56ae 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/RowFactory.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/RowFactory.java @@ -17,16 +17,22 @@ package org.apache.spark.sql; +import org.apache.spark.annotation.InterfaceStability; import org.apache.spark.sql.catalyst.expressions.GenericRow; /** * A factory class used to construct {@link Row} objects. + * + * @since 1.3.0 */ +@InterfaceStability.Stable public class RowFactory { /** * Create a {@link Row} from the given arguments. Position i in the argument list becomes * position i in the created {@link Row} object. 
+ * + * @since 1.3.0 */ public static Row create(Object ... values) { return new GenericRow(values); diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java index 9027652d57f14..c3f0abac244cf 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java @@ -59,7 +59,7 @@ * * Instances of `UnsafeRow` act as pointers to row data stored in this format. */ -public final class UnsafeRow extends MutableRow implements Externalizable, KryoSerializable { +public final class UnsafeRow extends InternalRow implements Externalizable, KryoSerializable { ////////////////////////////////////////////////////////////////////////////// // Static methods diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/streaming/OutputMode.java b/sql/catalyst/src/main/java/org/apache/spark/sql/streaming/OutputMode.java index 41e2582921198..49a18df2c72c0 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/streaming/OutputMode.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/streaming/OutputMode.java @@ -18,6 +18,7 @@ package org.apache.spark.sql.streaming; import org.apache.spark.annotation.Experimental; +import org.apache.spark.annotation.InterfaceStability; import org.apache.spark.sql.InternalOutputModes; /** @@ -29,6 +30,7 @@ * @since 2.0.0 */ @Experimental +@InterfaceStability.Evolving public class OutputMode { /** diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/types/DataTypes.java b/sql/catalyst/src/main/java/org/apache/spark/sql/types/DataTypes.java index 747ab1809fc0a..0f8570fe470bd 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/types/DataTypes.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/types/DataTypes.java @@ -19,10 +19,15 @@ import java.util.*; +import org.apache.spark.annotation.InterfaceStability; + /** * To get/create specific data type, users should use singleton objects and factory methods * provided by this class. + * + * @since 1.3.0 */ +@InterfaceStability.Stable public class DataTypes { /** * Gets the StringType object. 
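For reference, RowFactory and DataTypes (annotated as stable above) are the Java-facing factory entry points for building rows and schemas by hand. A minimal Scala sketch of their typical use follows; it is illustrative only and not part of this patch, and the field names and sample values are assumptions:

import java.util.Arrays

import org.apache.spark.sql.{Row, RowFactory}
import org.apache.spark.sql.types.DataTypes

// Build a schema with the DataTypes factory methods; field names are hypothetical.
val schema = DataTypes.createStructType(Arrays.asList(
  DataTypes.createStructField("id", DataTypes.IntegerType, false),
  DataTypes.createStructField("name", DataTypes.StringType, true)))

// Position i in the argument list becomes position i in the created Row.
val row: Row = RowFactory.create(Integer.valueOf(1), "spark")

// The schema would typically be paired with such rows via
// spark.createDataFrame(rowRDD, schema).
assert(row.getString(1) == "spark")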
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/types/SQLUserDefinedType.java b/sql/catalyst/src/main/java/org/apache/spark/sql/types/SQLUserDefinedType.java index 110ed460cc8fa..1290614a3207d 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/types/SQLUserDefinedType.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/types/SQLUserDefinedType.java @@ -20,6 +20,7 @@ import java.lang.annotation.*; import org.apache.spark.annotation.DeveloperApi; +import org.apache.spark.annotation.InterfaceStability; /** * ::DeveloperApi:: @@ -30,6 +31,7 @@ @DeveloperApi @Retention(RetentionPolicy.RUNTIME) @Target(ElementType.TYPE) +@InterfaceStability.Evolving public @interface SQLUserDefinedType { /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/AnalysisException.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/AnalysisException.scala index 6911843999392..f3003306acc6d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/AnalysisException.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/AnalysisException.scala @@ -17,17 +17,16 @@ package org.apache.spark.sql -import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.annotation.InterfaceStability import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan -// TODO: don't swallow original stack trace if it exists - /** - * :: DeveloperApi :: * Thrown when a query fails to analyze, usually because the query itself is invalid. + * + * @since 1.3.0 */ -@DeveloperApi +@InterfaceStability.Stable class AnalysisException protected[sql] ( val message: String, val line: Option[Int] = None, diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala index 501c1304dbedb..b9f8c46443021 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql import scala.annotation.implicitNotFound import scala.reflect.ClassTag -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.{Experimental, InterfaceStability} import org.apache.spark.sql.types._ @@ -67,6 +67,7 @@ import org.apache.spark.sql.types._ * @since 1.6.0 */ @Experimental +@InterfaceStability.Evolving @implicitNotFound("Unable to find encoder for type stored in a Dataset. 
Primitive types " + "(Int, String, etc) and Product types (case classes) are supported by importing " + "spark.implicits._ Support for serializing other types will be added in future " + diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoders.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoders.scala index e72f67c48a296..dc90659a676e0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoders.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoders.scala @@ -22,7 +22,7 @@ import java.lang.reflect.Modifier import scala.reflect.{classTag, ClassTag} import scala.reflect.runtime.universe.TypeTag -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.{Experimental, InterfaceStability} import org.apache.spark.sql.catalyst.analysis.GetColumnByOrdinal import org.apache.spark.sql.catalyst.encoders.{encoderFor, ExpressionEncoder} import org.apache.spark.sql.catalyst.expressions.{BoundReference, Cast} @@ -36,6 +36,7 @@ import org.apache.spark.sql.types._ * @since 1.6.0 */ @Experimental +@InterfaceStability.Evolving object Encoders { /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala index e16850efbea5f..65f91429648c1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala @@ -20,9 +20,14 @@ package org.apache.spark.sql import scala.collection.JavaConverters._ import scala.util.hashing.MurmurHash3 +import org.apache.spark.annotation.InterfaceStability import org.apache.spark.sql.catalyst.expressions.GenericRow import org.apache.spark.sql.types.StructType +/** + * @since 1.3.0 + */ +@InterfaceStability.Stable object Row { /** * This method can be used to extract fields from a [[Row]] object in a pattern match. Example: @@ -117,8 +122,9 @@ object Row { * } * }}} * - * @group row + * @since 1.3.0 */ +@InterfaceStability.Stable trait Row extends Serializable { /** Number of elements in the Row. */ def size: Int = length @@ -351,7 +357,7 @@ trait Row extends Serializable { }.toMap } - override def toString(): String = s"[${this.mkString(",")}]" + override def toString: String = s"[${this.mkString(",")}]" /** * Make a copy of the current [[Row]] object. @@ -456,7 +462,7 @@ trait Row extends Serializable { def mkString(start: String, sep: String, end: String): String = toSeq.mkString(start, sep, end) /** - * Returns the value of a given fieldName. + * Returns the value at position i. * * @throws UnsupportedOperationException when schema is not defined. * @throws ClassCastException when data type does not match. 
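The Row companion object documented above exposes Row both as a constructor and as a pattern-match extractor. A short sketch of that usage, illustrative only and not part of this patch; the sample rows are assumptions:

import org.apache.spark.sql.Row

// Hypothetical rows; Row(...) constructs a row positionally.
val rows = Seq(Row("alice", 1), Row("bob", 2))

// Positional extraction in a pattern match, as described in the Row scaladoc.
val counts = rows.collect { case Row(name: String, count: Int) => name -> count }

// Access by position; getAs[T](i) throws ClassCastException when the type does not match,
// consistent with the @throws notes in the getAs scaladoc above.
val firstName = rows.head.getAs[String](0)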
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala index eba95c5c8b908..f498e071b50a3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.catalyst import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.types.{DataType, StructType} +import org.apache.spark.sql.types.{DataType, Decimal, StructType} /** * An abstract class for row used internal in Spark SQL, which only contain the columns as @@ -31,6 +31,27 @@ abstract class InternalRow extends SpecializedGetters with Serializable { // This is only use for test and will throw a null pointer exception if the position is null. def getString(ordinal: Int): String = getUTF8String(ordinal).toString + def setNullAt(i: Int): Unit + + def update(i: Int, value: Any): Unit + + // default implementation (slow) + def setBoolean(i: Int, value: Boolean): Unit = update(i, value) + def setByte(i: Int, value: Byte): Unit = update(i, value) + def setShort(i: Int, value: Short): Unit = update(i, value) + def setInt(i: Int, value: Int): Unit = update(i, value) + def setLong(i: Int, value: Long): Unit = update(i, value) + def setFloat(i: Int, value: Float): Unit = update(i, value) + def setDouble(i: Int, value: Double): Unit = update(i, value) + + /** + * Update the decimal column at `i`. + * + * Note: In order to support update decimal with precision > 18 in UnsafeRow, + * CAN NOT call setNullAt() for decimal column on UnsafeRow, call setDecimal(i, null, precision). + */ + def setDecimal(i: Int, value: Decimal, precision: Int) { update(i, value) } + /** * Make a copy of the current [[InternalRow]] object. */ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index ae8869ff25f2d..536d38777f89d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -458,12 +458,12 @@ class Analyzer( i.copy(table = EliminateSubqueryAliases(lookupTableFromCatalog(u))) case u: UnresolvedRelation => val table = u.tableIdentifier - if (table.database.isDefined && conf.runSQLonFile && + if (table.database.isDefined && conf.runSQLonFile && !catalog.isTemporaryTable(table) && (!catalog.databaseExists(table.database.get) || !catalog.tableExists(table))) { - // If the table does not exist, and the database part is specified, and we support - // running SQL directly on files, then let's just return the original UnresolvedRelation. - // It is possible we are matching a query like "select * from parquet.`/path/to/query`". - // The plan will get resolved later. + // If the database part is specified, and we support running SQL directly on files, and + // it's not a temporary view, and the table does not exist, then let's just return the + // original UnresolvedRelation. It is possible we are matching a query like "select * + // from parquet.`/path/to/query`". The plan will get resolved later. // Note that we are testing (!db_exists || !table_exists) because the catalog throws // an exception from tableExists if the database does not exist. 
u diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/GlobalTempViewManager.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/GlobalTempViewManager.scala new file mode 100644 index 0000000000000..6095ac0bc9c50 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/GlobalTempViewManager.scala @@ -0,0 +1,121 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.catalog + +import javax.annotation.concurrent.GuardedBy + +import scala.collection.mutable + +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.analysis.TempTableAlreadyExistsException +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.util.StringUtils + + +/** + * A thread-safe manager for global temporary views, providing atomic operations to manage them, + * e.g. create, update, and remove. + * + * Note that the view name is always case-sensitive here; callers are responsible for formatting + * the view name w.r.t. the case-sensitivity config. + * + * @param database The system preserved virtual database that keeps all the global temporary views. + */ +class GlobalTempViewManager(val database: String) { + + /** List of view definitions, mapping from view name to logical plan. */ + @GuardedBy("this") + private val viewDefinitions = new mutable.HashMap[String, LogicalPlan] + + /** + * Returns the global view definition which matches the given name, or None if not found. + */ + def get(name: String): Option[LogicalPlan] = synchronized { + viewDefinitions.get(name) + } + + /** + * Creates a global temp view, or issues an exception if the view already exists and + * `overrideIfExists` is false. + */ + def create( + name: String, + viewDefinition: LogicalPlan, + overrideIfExists: Boolean): Unit = synchronized { + if (!overrideIfExists && viewDefinitions.contains(name)) { + throw new TempTableAlreadyExistsException(name) + } + viewDefinitions.put(name, viewDefinition) + } + + /** + * Updates the global temp view if it exists; returns true if updated, false otherwise. + */ + def update( + name: String, + viewDefinition: LogicalPlan): Boolean = synchronized { + if (viewDefinitions.contains(name)) { + viewDefinitions.put(name, viewDefinition) + true + } else { + false + } + } + + /** + * Removes the global temp view if it exists; returns true if removed, false otherwise. + */ + def remove(name: String): Boolean = synchronized { + viewDefinitions.remove(name).isDefined + } + + /** + * Renames the global temp view if the source view exists and the destination view does not + * exist, or issues an exception if the source view exists but the destination view already exists.
Returns + * true if renamed, false otherwise. + */ + def rename(oldName: String, newName: String): Boolean = synchronized { + if (viewDefinitions.contains(oldName)) { + if (viewDefinitions.contains(newName)) { + throw new AnalysisException( + s"rename temporary view from '$oldName' to '$newName': destination view already exists") + } + + val viewDefinition = viewDefinitions(oldName) + viewDefinitions.remove(oldName) + viewDefinitions.put(newName, viewDefinition) + true + } else { + false + } + } + + /** + * Lists the names of all global temporary views. + */ + def listViewNames(pattern: String): Seq[String] = synchronized { + StringUtils.filterPattern(viewDefinitions.keys.toSeq, pattern) + } + + /** + * Clears all the global temporary views. + */ + def clear(): Unit = synchronized { + viewDefinitions.clear() + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala index 8c01c7a3f2bd5..fe41c41a6eb20 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala @@ -47,6 +47,7 @@ object SessionCatalog { */ class SessionCatalog( externalCatalog: ExternalCatalog, + globalTempViewManager: GlobalTempViewManager, functionResourceLoader: FunctionResourceLoader, functionRegistry: FunctionRegistry, conf: CatalystConf, @@ -61,6 +62,7 @@ class SessionCatalog( conf: CatalystConf) { this( externalCatalog, + new GlobalTempViewManager("global_temp"), DummyFunctionResourceLoader, functionRegistry, conf, @@ -142,8 +144,13 @@ class SessionCatalog( // ---------------------------------------------------------------------------- def createDatabase(dbDefinition: CatalogDatabase, ignoreIfExists: Boolean): Unit = { - val qualifiedPath = makeQualifiedPath(dbDefinition.locationUri).toString val dbName = formatDatabaseName(dbDefinition.name) + if (dbName == globalTempViewManager.database) { + throw new AnalysisException( + s"${globalTempViewManager.database} is a system preserved database, " + + "you cannot create a database with this name.") + } + val qualifiedPath = makeQualifiedPath(dbDefinition.locationUri).toString externalCatalog.createDatabase( dbDefinition.copy(name = dbName, locationUri = qualifiedPath), ignoreIfExists) @@ -154,7 +161,7 @@ class SessionCatalog( if (dbName == DEFAULT_DATABASE) { throw new AnalysisException(s"Can not drop default database") } else if (dbName == getCurrentDatabase) { - throw new AnalysisException(s"Can not drop current database `${dbName}`") + throw new AnalysisException(s"Can not drop current database `$dbName`") } externalCatalog.dropDatabase(dbName, ignoreIfNotExists, cascade) } @@ -188,6 +195,13 @@ class SessionCatalog( def setCurrentDatabase(db: String): Unit = { val dbName = formatDatabaseName(db) + if (dbName == globalTempViewManager.database) { + throw new AnalysisException( + s"${globalTempViewManager.database} is a system preserved database, " + + "you cannot use it as current database. To access global temporary views, you should " + + "use qualified name with the GLOBAL_TEMP_DATABASE, e.g. SELECT * FROM " + + s"${globalTempViewManager.database}.viewName.") + } requireDbExists(dbName) synchronized { currentDb = dbName } } @@ -329,7 +343,7 @@ class SessionCatalog( // ---------------------------------------------- /** - * Create a temporary table. + * Create a local temporary view. 
*/ def createTempView( name: String, @@ -343,17 +357,67 @@ class SessionCatalog( } /** - * Return a temporary view exactly as it was stored. + * Create a global temporary view. + */ + def createGlobalTempView( + name: String, + viewDefinition: LogicalPlan, + overrideIfExists: Boolean): Unit = { + globalTempViewManager.create(formatTableName(name), viewDefinition, overrideIfExists) + } + + /** + * Alter the definition of a local/global temp view matching the given name, returns true if a + * temp view is matched and altered, false otherwise. + */ + def alterTempViewDefinition( + name: TableIdentifier, + viewDefinition: LogicalPlan): Boolean = synchronized { + val viewName = formatTableName(name.table) + if (name.database.isEmpty) { + if (tempTables.contains(viewName)) { + createTempView(viewName, viewDefinition, overrideIfExists = true) + true + } else { + false + } + } else if (formatDatabaseName(name.database.get) == globalTempViewManager.database) { + globalTempViewManager.update(viewName, viewDefinition) + } else { + false + } + } + + /** + * Return a local temporary view exactly as it was stored. */ def getTempView(name: String): Option[LogicalPlan] = synchronized { tempTables.get(formatTableName(name)) } /** - * Drop a temporary view. + * Return a global temporary view exactly as it was stored. + */ + def getGlobalTempView(name: String): Option[LogicalPlan] = { + globalTempViewManager.get(formatTableName(name)) + } + + /** + * Drop a local temporary view. + * + * Returns true if this view is dropped successfully, false otherwise. + */ + def dropTempView(name: String): Boolean = synchronized { + tempTables.remove(formatTableName(name)).isDefined + } + + /** + * Drop a global temporary view. + * + * Returns true if this view is dropped successfully, false otherwise. 
*/ - def dropTempView(name: String): Unit = synchronized { - tempTables.remove(formatTableName(name)) + def dropGlobalTempView(name: String): Boolean = { + globalTempViewManager.remove(formatTableName(name)) } // ------------------------------------------------------------- @@ -371,9 +435,7 @@ class SessionCatalog( */ def getTempViewOrPermanentTableMetadata(name: TableIdentifier): CatalogTable = synchronized { val table = formatTableName(name.table) - if (name.database.isDefined) { - getTableMetadata(name) - } else { + if (name.database.isEmpty) { getTempView(table).map { plan => CatalogTable( identifier = TableIdentifier(table), @@ -381,6 +443,16 @@ class SessionCatalog( storage = CatalogStorageFormat.empty, schema = plan.output.toStructType) }.getOrElse(getTableMetadata(name)) + } else if (formatDatabaseName(name.database.get) == globalTempViewManager.database) { + globalTempViewManager.get(table).map { plan => + CatalogTable( + identifier = TableIdentifier(table, Some(globalTempViewManager.database)), + tableType = CatalogTableType.VIEW, + storage = CatalogStorageFormat.empty, + schema = plan.output.toStructType) + }.getOrElse(throw new NoSuchTableException(globalTempViewManager.database, table)) + } else { + getTableMetadata(name) } } @@ -393,21 +465,25 @@ class SessionCatalog( */ def renameTable(oldName: TableIdentifier, newName: String): Unit = synchronized { val db = formatDatabaseName(oldName.database.getOrElse(currentDb)) - requireDbExists(db) val oldTableName = formatTableName(oldName.table) val newTableName = formatTableName(newName) - if (oldName.database.isDefined || !tempTables.contains(oldTableName)) { - requireTableExists(TableIdentifier(oldTableName, Some(db))) - requireTableNotExists(TableIdentifier(newTableName, Some(db))) - externalCatalog.renameTable(db, oldTableName, newTableName) + if (db == globalTempViewManager.database) { + globalTempViewManager.rename(oldTableName, newTableName) } else { - if (tempTables.contains(newTableName)) { - throw new AnalysisException( - s"RENAME TEMPORARY TABLE from '$oldName' to '$newName': destination table already exists") + requireDbExists(db) + if (oldName.database.isDefined || !tempTables.contains(oldTableName)) { + requireTableExists(TableIdentifier(oldTableName, Some(db))) + requireTableNotExists(TableIdentifier(newTableName, Some(db))) + externalCatalog.renameTable(db, oldTableName, newTableName) + } else { + if (tempTables.contains(newTableName)) { + throw new AnalysisException(s"RENAME TEMPORARY TABLE from '$oldName' to '$newName': " + + "destination table already exists") + } + val table = tempTables(oldTableName) + tempTables.remove(oldTableName) + tempTables.put(newTableName, table) } - val table = tempTables(oldTableName) - tempTables.remove(oldTableName) - tempTables.put(newTableName, table) } } @@ -424,17 +500,24 @@ class SessionCatalog( purge: Boolean): Unit = synchronized { val db = formatDatabaseName(name.database.getOrElse(currentDb)) val table = formatTableName(name.table) - if (name.database.isDefined || !tempTables.contains(table)) { - requireDbExists(db) - // When ignoreIfNotExists is false, no exception is issued when the table does not exist. - // Instead, log it as an error message. 
- if (tableExists(TableIdentifier(table, Option(db)))) { - externalCatalog.dropTable(db, table, ignoreIfNotExists = true, purge = purge) - } else if (!ignoreIfNotExists) { - throw new NoSuchTableException(db = db, table = table) + if (db == globalTempViewManager.database) { + val viewExists = globalTempViewManager.remove(table) + if (!viewExists && !ignoreIfNotExists) { + throw new NoSuchTableException(globalTempViewManager.database, table) } } else { - tempTables.remove(table) + if (name.database.isDefined || !tempTables.contains(table)) { + requireDbExists(db) + // When ignoreIfNotExists is false, no exception is issued when the table does not exist. + // Instead, log it as an error message. + if (tableExists(TableIdentifier(table, Option(db)))) { + externalCatalog.dropTable(db, table, ignoreIfNotExists = true, purge = purge) + } else if (!ignoreIfNotExists) { + throw new NoSuchTableException(db = db, table = table) + } + } else { + tempTables.remove(table) + } } } @@ -445,6 +528,9 @@ class SessionCatalog( * If no database is specified, this will first attempt to return a temporary table/view with * the same name, then, if that does not exist, return the table/view from the current database. * + * Note that the global temp view database is also valid here; this will return the global temp + * view matching the given name. + * * If the relation is a view, the relation will be wrapped in a [[SubqueryAlias]] which will * track the name of the view. */ @@ -453,7 +539,11 @@ val db = formatDatabaseName(name.database.getOrElse(currentDb)) val table = formatTableName(name.table) val relationAlias = alias.getOrElse(table) - if (name.database.isDefined || !tempTables.contains(table)) { + if (db == globalTempViewManager.database) { + globalTempViewManager.get(table).map { viewDef => + SubqueryAlias(relationAlias, viewDef, Some(name)) + }.getOrElse(throw new NoSuchTableException(db, table)) + } else if (name.database.isDefined || !tempTables.contains(table)) { val metadata = externalCatalog.getTable(db, table) val view = Option(metadata.tableType).collect { case CatalogTableType.VIEW => name @@ -472,27 +562,48 @@ * explicitly specified. */ def isTemporaryTable(name: TableIdentifier): Boolean = synchronized { - name.database.isEmpty && tempTables.contains(formatTableName(name.table)) + val table = formatTableName(name.table) + if (name.database.isEmpty) { + tempTables.contains(table) + } else if (formatDatabaseName(name.database.get) == globalTempViewManager.database) { + globalTempViewManager.get(table).isDefined + } else { + false + } } /** - * List all tables in the specified database, including temporary tables. + * List all tables in the specified database, including local temporary tables. + * + * Note that, if the specified database is the global temporary view database, we will list + * global temporary views. */ def listTables(db: String): Seq[TableIdentifier] = listTables(db, "*") /** - * List all matching tables in the specified database, including temporary tables. + * List all matching tables in the specified database, including local temporary tables. + * + * Note that, if the specified database is the global temporary view database, we will list + * global temporary views.
*/ def listTables(db: String, pattern: String): Seq[TableIdentifier] = { val dbName = formatDatabaseName(db) - requireDbExists(dbName) - val dbTables = - externalCatalog.listTables(dbName, pattern).map { t => TableIdentifier(t, Some(dbName)) } - synchronized { - val _tempTables = StringUtils.filterPattern(tempTables.keys.toSeq, pattern) - .map { t => TableIdentifier(t) } - dbTables ++ _tempTables + val dbTables = if (dbName == globalTempViewManager.database) { + globalTempViewManager.listViewNames(pattern).map { name => + TableIdentifier(name, Some(globalTempViewManager.database)) + } + } else { + requireDbExists(dbName) + externalCatalog.listTables(dbName, pattern).map { name => + TableIdentifier(name, Some(dbName)) + } + } + val localTempViews = synchronized { + StringUtils.filterPattern(tempTables.keys.toSeq, pattern).map { name => + TableIdentifier(name) + } } + dbTables ++ localTempViews } /** @@ -504,6 +615,8 @@ class SessionCatalog( // If the database is not defined, there is a good chance this is a temp table. if (name.database.isEmpty) { tempTables.get(formatTableName(name.table)).foreach(_.refresh()) + } else if (formatDatabaseName(name.database.get) == globalTempViewManager.database) { + globalTempViewManager.get(formatTableName(name.table)).foreach(_.refresh()) } } @@ -919,6 +1032,7 @@ class SessionCatalog( } } tempTables.clear() + globalTempViewManager.clear() functionRegistry.clear() // restore built-in functions FunctionRegistry.builtin.listFunction().foreach { f => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala index b96b744b4fa98..82e1a8a7cad96 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala @@ -256,7 +256,7 @@ case class ExpressionEncoder[T]( private lazy val extractProjection = GenerateUnsafeProjection.generate(serializer) @transient - private lazy val inputRow = new GenericMutableRow(1) + private lazy val inputRow = new GenericInternalRow(1) @transient private lazy val constructProjection = GenerateSafeProjection.generate(deserializer :: Nil) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Canonicalize.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Canonicalize.scala index 07ba7d5e4a849..e876450c73fde 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Canonicalize.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Canonicalize.scala @@ -62,6 +62,13 @@ object Canonicalize extends { case a: Add => orderCommutative(a, { case Add(l, r) => Seq(l, r) }).reduce(Add) case m: Multiply => orderCommutative(m, { case Multiply(l, r) => Seq(l, r) }).reduce(Multiply) + case o: Or => + orderCommutative(o, { case Or(l, r) if l.deterministic && r.deterministic => Seq(l, r) }) + .reduce(Or) + case a: And => + orderCommutative(a, { case And(l, r) if l.deterministic && r.deterministic => Seq(l, r)}) + .reduce(And) + case EqualTo(l, r) if l.hashCode() > r.hashCode() => EqualTo(r, l) case EqualNullSafe(l, r) if l.hashCode() > r.hashCode() => EqualNullSafe(r, l) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index 70fff51956255..58fd65f62ffe7 
100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -403,7 +403,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w case (fromField, toField) => cast(fromField.dataType, toField.dataType) } // TODO: Could be faster? - val newRow = new GenericMutableRow(from.fields.length) + val newRow = new GenericInternalRow(from.fields.length) buildCast[InternalRow](_, row => { var i = 0 while (i < row.numFields) { @@ -657,7 +657,12 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w private[this] def castToIntervalCode(from: DataType): CastFunction = from match { case StringType => (c, evPrim, evNull) => - s"$evPrim = CalendarInterval.fromString($c.toString());" + s"""$evPrim = CalendarInterval.fromString($c.toString()); + if(${evPrim} == null) { + ${evNull} = true; + } + """.stripMargin + } private[this] def decimalToTimestampCode(d: String): String = @@ -892,7 +897,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w val fieldsCasts = from.fields.zip(to.fields).map { case (fromField, toField) => nullSafeCastFunction(fromField.dataType, toField.dataType, ctx) } - val rowClass = classOf[GenericMutableRow].getName + val rowClass = classOf[GenericInternalRow].getName val result = ctx.freshName("result") val tmpRow = ctx.freshName("tmpRow") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/JoinedRow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/JoinedRow.scala index ed894f6d6e10e..7770684a5b399 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/JoinedRow.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/JoinedRow.scala @@ -123,6 +123,22 @@ class JoinedRow extends InternalRow { override def anyNull: Boolean = row1.anyNull || row2.anyNull + override def setNullAt(i: Int): Unit = { + if (i < row1.numFields) { + row1.setNullAt(i) + } else { + row2.setNullAt(i - row1.numFields) + } + } + + override def update(i: Int, value: Any): Unit = { + if (i < row1.numFields) { + row1.update(i, value) + } else { + row2.update(i - row1.numFields, value) + } + } + override def copy(): InternalRow = { val copy1 = row1.copy() val copy2 = row2.copy() diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala index c8d18667f7c4a..a81fa1ce3adcc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala @@ -69,10 +69,10 @@ case class InterpretedMutableProjection(expressions: Seq[Expression]) extends Mu }) private[this] val exprArray = expressions.toArray - private[this] var mutableRow: MutableRow = new GenericMutableRow(exprArray.length) + private[this] var mutableRow: InternalRow = new GenericInternalRow(exprArray.length) def currentValue: InternalRow = mutableRow - override def target(row: MutableRow): MutableProjection = { + override def target(row: InternalRow): MutableProjection = { mutableRow = row this } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificInternalRow.scala 
similarity index 98% rename from sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala rename to sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificInternalRow.scala index 61ca7272dfa61..74e0b4691d4cc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificInternalRow.scala @@ -22,7 +22,7 @@ import org.apache.spark.sql.types._ /** * A parent class for mutable container objects that are reused when the values are changed, - * resulting in less garbage. These values are held by a [[SpecificMutableRow]]. + * resulting in less garbage. These values are held by a [[SpecificInternalRow]]. * * The following code was roughly used to generate these objects: * {{{ @@ -191,8 +191,7 @@ final class MutableAny extends MutableValue { * based on the dataTypes of each column. The intent is to decrease garbage when modifying the * values of primitive columns. */ -final class SpecificMutableRow(val values: Array[MutableValue]) - extends MutableRow with BaseGenericInternalRow { +final class SpecificInternalRow(val values: Array[MutableValue]) extends BaseGenericInternalRow { def this(dataTypes: Seq[DataType]) = this( diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/HyperLogLogPlusPlus.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/HyperLogLogPlusPlus.scala index 1d218da6db806..83c8d400c5d6a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/HyperLogLogPlusPlus.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/HyperLogLogPlusPlus.scala @@ -155,7 +155,7 @@ case class HyperLogLogPlusPlus( aggBufferAttributes.map(_.newInstance()) /** Fill all words with zeros. */ - override def initialize(buffer: MutableRow): Unit = { + override def initialize(buffer: InternalRow): Unit = { var word = 0 while (word < numWords) { buffer.setLong(mutableAggBufferOffset + word, 0) @@ -168,7 +168,7 @@ case class HyperLogLogPlusPlus( * * Variable names in the HLL++ paper match variable names in the code. */ - override def update(buffer: MutableRow, input: InternalRow): Unit = { + override def update(buffer: InternalRow, input: InternalRow): Unit = { val v = child.eval(input) if (v != null) { // Create the hashed value 'x'. @@ -200,7 +200,7 @@ case class HyperLogLogPlusPlus( * Merge the HLL buffers by iterating through the registers in both buffers and select the * maximum number of leading zeros for each register. */ - override def merge(buffer1: MutableRow, buffer2: InternalRow): Unit = { + override def merge(buffer1: InternalRow, buffer2: InternalRow): Unit = { var idx = 0 var wordOffset = 0 while (wordOffset < numWords) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/PivotFirst.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/PivotFirst.scala index 16c03c500ad08..087606077295f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/PivotFirst.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/PivotFirst.scala @@ -30,7 +30,7 @@ object PivotFirst { // Currently UnsafeRow does not support the generic update method (throws // UnsupportedOperationException), so we need to explicitly support each DataType. 
- private val updateFunction: PartialFunction[DataType, (MutableRow, Int, Any) => Unit] = { + private val updateFunction: PartialFunction[DataType, (InternalRow, Int, Any) => Unit] = { case DoubleType => (row, offset, value) => row.setDouble(offset, value.asInstanceOf[Double]) case IntegerType => @@ -89,9 +89,9 @@ case class PivotFirst( val indexSize = pivotIndex.size - private val updateRow: (MutableRow, Int, Any) => Unit = PivotFirst.updateFunction(valueDataType) + private val updateRow: (InternalRow, Int, Any) => Unit = PivotFirst.updateFunction(valueDataType) - override def update(mutableAggBuffer: MutableRow, inputRow: InternalRow): Unit = { + override def update(mutableAggBuffer: InternalRow, inputRow: InternalRow): Unit = { val pivotColValue = pivotColumn.eval(inputRow) if (pivotColValue != null) { // We ignore rows whose pivot column value is not in the list of pivot column values. @@ -105,7 +105,7 @@ case class PivotFirst( } } - override def merge(mutableAggBuffer: MutableRow, inputAggBuffer: InternalRow): Unit = { + override def merge(mutableAggBuffer: InternalRow, inputAggBuffer: InternalRow): Unit = { for (i <- 0 until indexSize) { if (!inputAggBuffer.isNullAt(inputAggBufferOffset + i)) { val value = inputAggBuffer.get(inputAggBufferOffset + i, valueDataType) @@ -114,7 +114,7 @@ case class PivotFirst( } } - override def initialize(mutableAggBuffer: MutableRow): Unit = valueDataType match { + override def initialize(mutableAggBuffer: InternalRow): Unit = valueDataType match { case d: DecimalType => // Per doc of setDecimal we need to do this instead of setNullAt for DecimalType. for (i <- 0 until indexSize) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala index 78a388d20630b..89eb864e94702 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala @@ -60,11 +60,11 @@ abstract class Collect extends ImperativeAggregate { protected[this] val buffer: Growable[Any] with Iterable[Any] - override def initialize(b: MutableRow): Unit = { + override def initialize(b: InternalRow): Unit = { buffer.clear() } - override def update(b: MutableRow, input: InternalRow): Unit = { + override def update(b: InternalRow, input: InternalRow): Unit = { // Do not allow null values. We follow the semantics of Hive's collect_list/collect_set here. 
// See: org.apache.hadoop.hive.ql.udf.generic.GenericUDAFMkCollectionEvaluator val value = child.eval(input) @@ -73,7 +73,7 @@ abstract class Collect extends ImperativeAggregate { } } - override def merge(buffer: MutableRow, input: InternalRow): Unit = { + override def merge(buffer: InternalRow, input: InternalRow): Unit = { sys.error("Collect cannot be used in partial aggregations.") } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala index b5c0844fbf310..f3fd58bc98ef6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala @@ -307,14 +307,14 @@ abstract class ImperativeAggregate extends AggregateFunction with CodegenFallbac * * Use `fieldNumber + mutableAggBufferOffset` to access fields of `mutableAggBuffer`. */ - def initialize(mutableAggBuffer: MutableRow): Unit + def initialize(mutableAggBuffer: InternalRow): Unit /** * Updates its aggregation buffer, located in `mutableAggBuffer`, based on the given `inputRow`. * * Use `fieldNumber + mutableAggBufferOffset` to access fields of `mutableAggBuffer`. */ - def update(mutableAggBuffer: MutableRow, inputRow: InternalRow): Unit + def update(mutableAggBuffer: InternalRow, inputRow: InternalRow): Unit /** * Combines new intermediate results from the `inputAggBuffer` with the existing intermediate @@ -323,7 +323,7 @@ abstract class ImperativeAggregate extends AggregateFunction with CodegenFallbac * Use `fieldNumber + mutableAggBufferOffset` to access fields of `mutableAggBuffer`. * Use `fieldNumber + inputAggBufferOffset` to access fields of `inputAggBuffer`. */ - def merge(mutableAggBuffer: MutableRow, inputAggBuffer: InternalRow): Unit + def merge(mutableAggBuffer: InternalRow, inputAggBuffer: InternalRow): Unit } /** @@ -504,16 +504,16 @@ abstract class TypedImperativeAggregate[T] extends ImperativeAggregate { /** De-serializes the serialized format Array[Byte], and produces aggregation buffer object T */ def deserialize(storageFormat: Array[Byte]): T - final override def initialize(buffer: MutableRow): Unit = { + final override def initialize(buffer: InternalRow): Unit = { val bufferObject = createAggregationBuffer() buffer.update(mutableAggBufferOffset, bufferObject) } - final override def update(buffer: MutableRow, input: InternalRow): Unit = { + final override def update(buffer: InternalRow, input: InternalRow): Unit = { update(getBufferObject(buffer), input) } - final override def merge(buffer: MutableRow, inputBuffer: InternalRow): Unit = { + final override def merge(buffer: InternalRow, inputBuffer: InternalRow): Unit = { val bufferObject = getBufferObject(buffer) // The inputBuffer stores serialized aggregation buffer object produced by partial aggregate val inputObject = deserialize(inputBuffer.getBinary(inputAggBufferOffset)) @@ -547,7 +547,7 @@ abstract class TypedImperativeAggregate[T] extends ImperativeAggregate { * This is only called when doing Partial or PartialMerge mode aggregation, before the framework * shuffle out aggregate buffers. 
*/ - final def serializeAggregateBufferInPlace(buffer: MutableRow): Unit = { + final def serializeAggregateBufferInPlace(buffer: InternalRow): Unit = { buffer(mutableAggBufferOffset) = serialize(getBufferObject(buffer)) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 574943d3d21f0..6cab50ae1bf8d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -819,7 +819,7 @@ class CodeAndComment(val body: String, val comment: collection.Map[String, Strin */ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Logging { - protected val genericMutableRowType: String = classOf[GenericMutableRow].getName + protected val genericMutableRowType: String = classOf[GenericInternalRow].getName /** * Generates a class for a given input expression. Called when there is not cached code @@ -889,7 +889,6 @@ object CodeGenerator extends Logging { classOf[UnsafeArrayData].getName, classOf[MapData].getName, classOf[UnsafeMapData].getName, - classOf[MutableRow].getName, classOf[Expression].getName )) evaluator.setExtendedClass(classOf[GeneratedClass]) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala index 13d61af1c9b40..5c4b56b0b224c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala @@ -24,10 +24,10 @@ import org.apache.spark.sql.catalyst.expressions.aggregate.NoOp abstract class BaseMutableProjection extends MutableProjection /** - * Generates byte code that produces a [[MutableRow]] object that can update itself based on a new + * Generates byte code that produces a [[InternalRow]] object that can update itself based on a new * input [[InternalRow]] for a fixed set of [[Expression Expressions]]. * It exposes a `target` method, which is used to set the row that will be updated. - * The internal [[MutableRow]] object created internally is used only when `target` is not used. + * The internal [[InternalRow]] object created internally is used only when `target` is not used. 
*/ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], MutableProjection] { @@ -102,7 +102,7 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], MutableP class SpecificMutableProjection extends ${classOf[BaseMutableProjection].getName} { private Object[] references; - private MutableRow mutableRow; + private InternalRow mutableRow; ${ctx.declareMutableStates()} public SpecificMutableProjection(Object[] references) { @@ -113,7 +113,7 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], MutableP ${ctx.declareAddedFunctions()} - public ${classOf[BaseMutableProjection].getName} target(MutableRow row) { + public ${classOf[BaseMutableProjection].getName} target(InternalRow row) { mutableRow = row; return this; } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala index 1c98c9ed10705..2773e1a666212 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala @@ -30,7 +30,7 @@ import org.apache.spark.sql.types._ abstract class BaseProjection extends Projection {} /** - * Generates byte code that produces a [[MutableRow]] object (not an [[UnsafeRow]]) that can update + * Generates byte code that produces a [[InternalRow]] object (not an [[UnsafeRow]]) that can update * itself based on a new input [[InternalRow]] for a fixed set of [[Expression Expressions]]. */ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection] { @@ -164,12 +164,12 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection] class SpecificSafeProjection extends ${classOf[BaseProjection].getName} { private Object[] references; - private MutableRow mutableRow; + private InternalRow mutableRow; ${ctx.declareMutableStates()} public SpecificSafeProjection(Object[] references) { this.references = references; - mutableRow = (MutableRow) references[references.length - 1]; + mutableRow = (InternalRow) references[references.length - 1]; ${ctx.initMutableStates()} } @@ -188,7 +188,7 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection] logDebug(s"code for ${expressions.mkString(",")}:\n${CodeFormatter.format(code)}") val c = CodeGenerator.compile(code) - val resultRow = new SpecificMutableRow(expressions.map(_.dataType)) + val resultRow = new SpecificInternalRow(expressions.map(_.dataType)) c.generate(ctx.references.toArray :+ resultRow).asInstanceOf[Projection] } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala index a6125c61e508a..1510a4796683c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala @@ -81,7 +81,7 @@ package object expressions { def currentValue: InternalRow /** Uses the given row to store the output of the projection. 
*/ - def target(row: MutableRow): MutableProjection + def target(row: InternalRow): MutableProjection } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala index 73dceb35ac50e..751b821e1b009 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala @@ -157,33 +157,6 @@ trait BaseGenericInternalRow extends InternalRow { } } -/** - * An extended interface to [[InternalRow]] that allows the values for each column to be updated. - * Setting a value through a primitive function implicitly marks that column as not null. - */ -abstract class MutableRow extends InternalRow { - def setNullAt(i: Int): Unit - - def update(i: Int, value: Any): Unit - - // default implementation (slow) - def setBoolean(i: Int, value: Boolean): Unit = { update(i, value) } - def setByte(i: Int, value: Byte): Unit = { update(i, value) } - def setShort(i: Int, value: Short): Unit = { update(i, value) } - def setInt(i: Int, value: Int): Unit = { update(i, value) } - def setLong(i: Int, value: Long): Unit = { update(i, value) } - def setFloat(i: Int, value: Float): Unit = { update(i, value) } - def setDouble(i: Int, value: Double): Unit = { update(i, value) } - - /** - * Update the decimal column at `i`. - * - * Note: In order to support update decimal with precision > 18 in UnsafeRow, - * CAN NOT call setNullAt() for decimal column on UnsafeRow, call setDecimal(i, null, precision). - */ - def setDecimal(i: Int, value: Decimal, precision: Int) { update(i, value) } -} - /** * A row implementation that uses an array of objects as the underlying storage. Note that, while * the array is not copied, and thus could technically be mutated after creation, this is not @@ -230,24 +203,9 @@ class GenericInternalRow(val values: Array[Any]) extends BaseGenericInternalRow override def numFields: Int = values.length - override def copy(): GenericInternalRow = this -} - -class GenericMutableRow(values: Array[Any]) extends MutableRow with BaseGenericInternalRow { - /** No-arg constructor for serialization. */ - protected def this() = this(null) - - def this(size: Int) = this(new Array[Any](size)) - - override protected def genericGet(ordinal: Int) = values(ordinal) - - override def toSeq(fieldTypes: Seq[DataType]): Seq[Any] = values - - override def numFields: Int = values.length - override def setNullAt(i: Int): Unit = { values(i) = null} override def update(i: Int, value: Any): Unit = { values(i) = value } - override def copy(): InternalRow = new GenericInternalRow(values.clone()) + override def copy(): GenericInternalRow = this } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/identifiers.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/identifiers.scala index d7b48ceca591a..834897b85023d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/identifiers.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/identifiers.scala @@ -17,7 +17,6 @@ package org.apache.spark.sql.catalyst - /** * An identifier that optionally specifies a database. * @@ -29,8 +28,16 @@ sealed trait IdentifierWithDatabase { def database: Option[String] + /* + * Escapes back-ticks within the identifier name with double-back-ticks. 
+ */ + private def quoteIdentifier(name: String): String = name.replace("`", "``") + def quotedString: String = { - if (database.isDefined) s"`${database.get}`.`$identifier`" else s"`$identifier`" + val replacedId = quoteIdentifier(identifier) + val replacedDb = database.map(quoteIdentifier(_)) + + if (replacedDb.isDefined) s"`${replacedDb.get}`.`$replacedId`" else s"`$replacedId`" } def unquotedString: String = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala index f80e6373d2f89..e476cb11a3517 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala @@ -105,7 +105,7 @@ class JacksonParser( } emptyRow } else { - val row = new GenericMutableRow(schema.length) + val row = new GenericInternalRow(schema.length) for (corruptIndex <- schema.getFieldIndex(columnNameOfCorruptRecord)) { require(schema(corruptIndex).dataType == StringType) row.update(corruptIndex, UTF8String.fromString(record)) @@ -363,7 +363,7 @@ class JacksonParser( parser: JsonParser, schema: StructType, fieldConverters: Seq[ValueConverter]): InternalRow = { - val row = new GenericMutableRow(schema.length) + val row = new GenericInternalRow(schema.length) while (nextUntil(parser, JsonToken.END_OBJECT)) { schema.getFieldIndex(parser.getCurrentName) match { case Some(index) => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index bf3f30279a6fe..929c1c4f2d9e4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -316,7 +316,7 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging { // Create the attributes. val (attributes, schemaLess) = if (colTypeList != null) { // Typed return columns. - (createStructType(colTypeList).toAttributes, false) + (createSchema(colTypeList).toAttributes, false) } else if (identifierSeq != null) { // Untyped return columns. val attrs = visitIdentifierSeq(identifierSeq).map { name => @@ -1450,14 +1450,14 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging { case SqlBaseParser.MAP => MapType(typedVisit(ctx.dataType(0)), typedVisit(ctx.dataType(1))) case SqlBaseParser.STRUCT => - createStructType(ctx.colTypeList()) + createStructType(ctx.complexColTypeList()) } } /** - * Create a [[StructType]] from a sequence of [[StructField]]s. + * Create top level table schema. */ - protected def createStructType(ctx: ColTypeListContext): StructType = { + protected def createSchema(ctx: ColTypeListContext): StructType = { StructType(Option(ctx).toSeq.flatMap(visitColTypeList)) } @@ -1476,4 +1476,28 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging { val structField = StructField(identifier.getText, typedVisit(dataType), nullable = true) if (STRING == null) structField else structField.withComment(string(STRING)) } + + /** + * Create a [[StructType]] from a sequence of [[StructField]]s. + */ + protected def createStructType(ctx: ComplexColTypeListContext): StructType = { + StructType(Option(ctx).toSeq.flatMap(visitComplexColTypeList)) + } + + /** + * Create a [[StructType]] from a number of column definitions. 
+ */ + override def visitComplexColTypeList( + ctx: ComplexColTypeListContext): Seq[StructField] = withOrigin(ctx) { + ctx.complexColType().asScala.map(visitComplexColType) + } + + /** + * Create a [[StructField]] from a column definition. + */ + override def visitComplexColType(ctx: ComplexColTypeContext): StructField = withOrigin(ctx) { + import ctx._ + val structField = StructField(identifier.getText, typedVisit(dataType), nullable = true) + if (STRING == null) structField else structField.withComment(string(STRING)) + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala index 1981fd8f0a1b5..76dbb7cf0aec1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.types import scala.reflect.runtime.universe.TypeTag -import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.annotation.InterfaceStability import org.apache.spark.sql.catalyst.expressions.Expression /** @@ -131,10 +131,11 @@ protected[sql] abstract class AtomicType extends DataType { /** - * :: DeveloperApi :: * Numeric data types. + * + * @since 1.3.0 */ -@DeveloperApi +@InterfaceStability.Stable abstract class NumericType extends AtomicType { // Unfortunately we can't get this implicitly as that breaks Spark Serialization. In order for // implicitly[Numeric[JvmType]] to be valid, we have to change JvmType from a type variable to a diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayType.scala index 82a03b0afc002..5d70ef01373f5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayType.scala @@ -21,9 +21,15 @@ import scala.math.Ordering import org.json4s.JsonDSL._ -import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.annotation.InterfaceStability import org.apache.spark.sql.catalyst.util.ArrayData +/** + * Companion object for ArrayType. + * + * @since 1.3.0 + */ +@InterfaceStability.Stable object ArrayType extends AbstractDataType { /** Construct a [[ArrayType]] object with the given element type. The `containsNull` is true. */ def apply(elementType: DataType): ArrayType = ArrayType(elementType, containsNull = true) @@ -37,9 +43,7 @@ object ArrayType extends AbstractDataType { override private[sql] def simpleString: String = "array" } - /** - * :: DeveloperApi :: * The data type for collections of multiple values. * Internally these are represented as columns that contain a ``scala.collection.Seq``. * @@ -51,8 +55,10 @@ object ArrayType extends AbstractDataType { * * @param elementType The data type of values. * @param containsNull Indicates if values have `null` values + * + * @since 1.3.0 */ -@DeveloperApi +@InterfaceStability.Stable case class ArrayType(elementType: DataType, containsNull: Boolean) extends DataType { /** No-arg constructor for kryo. 
*/ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/BinaryType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/BinaryType.scala index c40e140e8c5c6..a4a358a242c70 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/BinaryType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/BinaryType.scala @@ -20,17 +20,16 @@ package org.apache.spark.sql.types import scala.math.Ordering import scala.reflect.runtime.universe.typeTag -import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.annotation.InterfaceStability import org.apache.spark.sql.catalyst.ScalaReflectionLock import org.apache.spark.sql.catalyst.util.TypeUtils /** - * :: DeveloperApi :: * The data type representing `Array[Byte]` values. * Please use the singleton [[DataTypes.BinaryType]]. */ -@DeveloperApi +@InterfaceStability.Stable class BinaryType private() extends AtomicType { // The companion object and this class is separated so the companion object also subclasses // this type. Otherwise, the companion object would be of type "BinaryType$" in byte code. @@ -54,5 +53,8 @@ class BinaryType private() extends AtomicType { private[spark] override def asNullable: BinaryType = this } - +/** + * @since 1.3.0 + */ +@InterfaceStability.Stable case object BinaryType extends BinaryType diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/BooleanType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/BooleanType.scala index 2d8ee3d9bc286..059f89f9cda32 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/BooleanType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/BooleanType.scala @@ -20,15 +20,16 @@ package org.apache.spark.sql.types import scala.math.Ordering import scala.reflect.runtime.universe.typeTag -import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.annotation.InterfaceStability import org.apache.spark.sql.catalyst.ScalaReflectionLock /** - * :: DeveloperApi :: * The data type representing `Boolean` values. Please use the singleton [[DataTypes.BooleanType]]. + * + * @since 1.3.0 */ -@DeveloperApi +@InterfaceStability.Stable class BooleanType private() extends AtomicType { // The companion object and this class is separated so the companion object also subclasses // this type. Otherwise, the companion object would be of type "BooleanType$" in byte code. @@ -45,5 +46,8 @@ class BooleanType private() extends AtomicType { private[spark] override def asNullable: BooleanType = this } - +/** + * @since 1.3.0 + */ +@InterfaceStability.Stable case object BooleanType extends BooleanType diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ByteType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ByteType.scala index d37130e27ba5a..bc6251f024e58 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ByteType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ByteType.scala @@ -20,14 +20,15 @@ package org.apache.spark.sql.types import scala.math.{Integral, Numeric, Ordering} import scala.reflect.runtime.universe.typeTag -import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.annotation.InterfaceStability import org.apache.spark.sql.catalyst.ScalaReflectionLock /** - * :: DeveloperApi :: * The data type representing `Byte` values. Please use the singleton [[DataTypes.ByteType]]. 
+ * + * @since 1.3.0 */ -@DeveloperApi +@InterfaceStability.Stable class ByteType private() extends IntegralType { // The companion object and this class is separated so the companion object also subclasses // this type. Otherwise, the companion object would be of type "ByteType$" in byte code. @@ -48,4 +49,9 @@ class ByteType private() extends IntegralType { private[spark] override def asNullable: ByteType = this } + +/** + * @since 1.3.0 + */ +@InterfaceStability.Stable case object ByteType extends ByteType diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/CalendarIntervalType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/CalendarIntervalType.scala index 3565f52c21f69..e121044288e5a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/CalendarIntervalType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/CalendarIntervalType.scala @@ -17,19 +17,19 @@ package org.apache.spark.sql.types -import org.apache.spark.annotation.DeveloperApi - +import org.apache.spark.annotation.InterfaceStability /** - * :: DeveloperApi :: * The data type representing calendar time intervals. The calendar time interval is stored * internally in two components: number of months the number of microseconds. * * Note that calendar intervals are not comparable. * * Please use the singleton [[DataTypes.CalendarIntervalType]]. + * + * @since 1.5.0 */ -@DeveloperApi +@InterfaceStability.Stable class CalendarIntervalType private() extends DataType { override def defaultSize: Int = 16 @@ -37,4 +37,8 @@ class CalendarIntervalType private() extends DataType { private[spark] override def asNullable: CalendarIntervalType = this } +/** + * @since 1.5.0 + */ +@InterfaceStability.Stable case object CalendarIntervalType extends CalendarIntervalType diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala index 4fc65cbce15bd..312585df1516b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala @@ -22,15 +22,16 @@ import org.json4s.JsonAST.JValue import org.json4s.JsonDSL._ import org.json4s.jackson.JsonMethods._ -import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.annotation.InterfaceStability import org.apache.spark.sql.catalyst.expressions.Expression import org.apache.spark.util.Utils /** - * :: DeveloperApi :: * The base type of all Spark SQL data types. 
+ * + * @since 1.3.0 */ -@DeveloperApi +@InterfaceStability.Stable abstract class DataType extends AbstractDataType { /** * Enables matching against DataType for expressions: @@ -94,6 +95,10 @@ abstract class DataType extends AbstractDataType { } +/** + * @since 1.3.0 + */ +@InterfaceStability.Stable object DataType { def fromJson(json: String): DataType = parseDataType(parse(json)) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DateType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DateType.scala index 2c966230e447e..8d0ecc051f4ce 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DateType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DateType.scala @@ -20,19 +20,20 @@ package org.apache.spark.sql.types import scala.math.Ordering import scala.reflect.runtime.universe.typeTag -import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.annotation.InterfaceStability import org.apache.spark.sql.catalyst.ScalaReflectionLock /** - * :: DeveloperApi :: * A date type, supporting "0001-01-01" through "9999-12-31". * * Please use the singleton [[DataTypes.DateType]]. * * Internally, this is represented as the number of days from 1970-01-01. + * + * @since 1.3.0 */ -@DeveloperApi +@InterfaceStability.Stable class DateType private() extends AtomicType { // The companion object and this class is separated so the companion object also subclasses // this type. Otherwise, the companion object would be of type "DateType$" in byte code. @@ -51,5 +52,8 @@ class DateType private() extends AtomicType { private[spark] override def asNullable: DateType = this } - +/** + * @since 1.3.0 + */ +@InterfaceStability.Stable case object DateType extends DateType diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala index 70859052872dd..465fb83669a76 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.types import java.lang.{Long => JLong} import java.math.{BigInteger, MathContext, RoundingMode} -import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.annotation.InterfaceStability /** * A mutable implementation of BigDecimal that can hold a Long if values are small enough. 
@@ -30,6 +30,7 @@ import org.apache.spark.annotation.DeveloperApi * - If decimalVal is set, it represents the whole decimal value * - Otherwise, the decimal value is longVal / (10 ** _scale) */ +@InterfaceStability.Unstable final class Decimal extends Ordered[Decimal] with Serializable { import org.apache.spark.sql.types.Decimal._ @@ -185,7 +186,6 @@ final class Decimal extends Ordered[Decimal] with Serializable { override def toString: String = toBigDecimal.toString() - @DeveloperApi def toDebugString: String = { if (decimalVal.ne(null)) { s"Decimal(expanded,$decimalVal,$precision,$scale})" @@ -380,6 +380,7 @@ final class Decimal extends Ordered[Decimal] with Serializable { } } +@InterfaceStability.Unstable object Decimal { val ROUND_HALF_UP = BigDecimal.RoundingMode.HALF_UP val ROUND_HALF_EVEN = BigDecimal.RoundingMode.HALF_EVEN diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala index 6500875f95e54..d7ca0cbeedcd3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala @@ -19,14 +19,13 @@ package org.apache.spark.sql.types import scala.reflect.runtime.universe.typeTag -import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.annotation.InterfaceStability import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.ScalaReflectionLock import org.apache.spark.sql.catalyst.expressions.Expression /** - * :: DeveloperApi :: * The data type representing `java.math.BigDecimal` values. * A Decimal that must have fixed precision (the maximum number of digits) and scale (the number * of digits on right side of dot). @@ -36,8 +35,10 @@ import org.apache.spark.sql.catalyst.expressions.Expression * The default precision and scale is (10, 0). * * Please use [[DataTypes.createDecimalType()]] to create a specific instance. + * + * @since 1.3.0 */ -@DeveloperApi +@InterfaceStability.Stable case class DecimalType(precision: Int, scale: Int) extends FractionalType { if (scale > precision) { @@ -101,7 +102,12 @@ case class DecimalType(precision: Int, scale: Int) extends FractionalType { } -/** Extra factory methods and pattern matchers for Decimals */ +/** + * Extra factory methods and pattern matchers for Decimals. + * + * @since 1.3.0 + */ +@InterfaceStability.Stable object DecimalType extends AbstractDataType { import scala.math.min diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DoubleType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DoubleType.scala index e553f65f3c99d..c21ac0e43eee0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DoubleType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DoubleType.scala @@ -21,15 +21,16 @@ import scala.math.{Fractional, Numeric, Ordering} import scala.math.Numeric.DoubleAsIfIntegral import scala.reflect.runtime.universe.typeTag -import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.annotation.InterfaceStability import org.apache.spark.sql.catalyst.ScalaReflectionLock import org.apache.spark.util.Utils /** - * :: DeveloperApi :: * The data type representing `Double` values. Please use the singleton [[DataTypes.DoubleType]]. 
+ * + * @since 1.3.0 */ -@DeveloperApi +@InterfaceStability.Stable class DoubleType private() extends FractionalType { // The companion object and this class is separated so the companion object also subclasses // this type. Otherwise, the companion object would be of type "DoubleType$" in byte code. @@ -51,4 +52,8 @@ class DoubleType private() extends FractionalType { private[spark] override def asNullable: DoubleType = this } +/** + * @since 1.3.0 + */ +@InterfaceStability.Stable case object DoubleType extends DoubleType diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/FloatType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/FloatType.scala index ae9aa9eefaf2a..c5bf8883bad93 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/FloatType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/FloatType.scala @@ -21,15 +21,16 @@ import scala.math.{Fractional, Numeric, Ordering} import scala.math.Numeric.FloatAsIfIntegral import scala.reflect.runtime.universe.typeTag -import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.annotation.InterfaceStability import org.apache.spark.sql.catalyst.ScalaReflectionLock import org.apache.spark.util.Utils /** - * :: DeveloperApi :: * The data type representing `Float` values. Please use the singleton [[DataTypes.FloatType]]. + * + * @since 1.3.0 */ -@DeveloperApi +@InterfaceStability.Stable class FloatType private() extends FractionalType { // The companion object and this class is separated so the companion object also subclasses // this type. Otherwise, the companion object would be of type "FloatType$" in byte code. @@ -51,4 +52,9 @@ class FloatType private() extends FractionalType { private[spark] override def asNullable: FloatType = this } + +/** + * @since 1.3.0 + */ +@InterfaceStability.Stable case object FloatType extends FloatType diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/IntegerType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/IntegerType.scala index 38a7b8ee52651..724e59c0bcbf4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/IntegerType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/IntegerType.scala @@ -20,15 +20,16 @@ package org.apache.spark.sql.types import scala.math.{Integral, Numeric, Ordering} import scala.reflect.runtime.universe.typeTag -import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.annotation.InterfaceStability import org.apache.spark.sql.catalyst.ScalaReflectionLock /** - * :: DeveloperApi :: * The data type representing `Int` values. Please use the singleton [[DataTypes.IntegerType]]. + * + * @since 1.3.0 */ -@DeveloperApi +@InterfaceStability.Stable class IntegerType private() extends IntegralType { // The companion object and this class is separated so the companion object also subclasses // this type. Otherwise, the companion object would be of type "IntegerType$" in byte code. 
@@ -49,4 +50,8 @@ class IntegerType private() extends IntegralType { private[spark] override def asNullable: IntegerType = this } +/** + * @since 1.3.0 + */ +@InterfaceStability.Stable case object IntegerType extends IntegerType diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/LongType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/LongType.scala index 88aff0c87755c..42285a9d0aa29 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/LongType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/LongType.scala @@ -20,14 +20,15 @@ package org.apache.spark.sql.types import scala.math.{Integral, Numeric, Ordering} import scala.reflect.runtime.universe.typeTag -import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.annotation.InterfaceStability import org.apache.spark.sql.catalyst.ScalaReflectionLock /** - * :: DeveloperApi :: * The data type representing `Long` values. Please use the singleton [[DataTypes.LongType]]. + * + * @since 1.3.0 */ -@DeveloperApi +@InterfaceStability.Stable class LongType private() extends IntegralType { // The companion object and this class is separated so the companion object also subclasses // this type. Otherwise, the companion object would be of type "LongType$" in byte code. @@ -48,5 +49,8 @@ class LongType private() extends IntegralType { private[spark] override def asNullable: LongType = this } - +/** + * @since 1.3.0 + */ +@InterfaceStability.Stable case object LongType extends LongType diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/MapType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/MapType.scala index 178960929bd83..3a32aa43d1c3a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/MapType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/MapType.scala @@ -20,10 +20,9 @@ package org.apache.spark.sql.types import org.json4s.JsonAST.JValue import org.json4s.JsonDSL._ -import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.annotation.InterfaceStability /** - * :: DeveloperApi :: * The data type for Maps. Keys in a map are not allowed to have `null` values. * * Please use [[DataTypes.createMapType()]] to create a specific instance. @@ -32,7 +31,7 @@ import org.apache.spark.annotation.DeveloperApi * @param valueType The data type of map values. * @param valueContainsNull Indicates if map values have `null` values. 
*/ -@DeveloperApi +@InterfaceStability.Stable case class MapType( keyType: DataType, valueType: DataType, @@ -76,7 +75,10 @@ case class MapType( } } - +/** + * @since 1.3.0 + */ +@InterfaceStability.Stable object MapType extends AbstractDataType { override private[sql] def defaultConcreteType: DataType = apply(NullType, NullType) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Metadata.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Metadata.scala index 657bd86ce17d9..3aa4bf619f274 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Metadata.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Metadata.scala @@ -22,22 +22,22 @@ import scala.collection.mutable import org.json4s._ import org.json4s.jackson.JsonMethods._ -import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.annotation.InterfaceStability /** - * :: DeveloperApi :: - * * Metadata is a wrapper over Map[String, Any] that limits the value type to simple ones: Boolean, * Long, Double, String, Metadata, Array[Boolean], Array[Long], Array[Double], Array[String], and * Array[Metadata]. JSON is used for serialization. * * The default constructor is private. User should use either [[MetadataBuilder]] or - * [[Metadata.fromJson()]] to create Metadata instances. + * `Metadata.fromJson()` to create Metadata instances. * * @param map an immutable map that stores the data + * + * @since 1.3.0 */ -@DeveloperApi +@InterfaceStability.Stable sealed class Metadata private[types] (private[types] val map: Map[String, Any]) extends Serializable { @@ -114,6 +114,10 @@ sealed class Metadata private[types] (private[types] val map: Map[String, Any]) private[sql] def jsonValue: JValue = Metadata.toJsonValue(this) } +/** + * @since 1.3.0 + */ +@InterfaceStability.Stable object Metadata { private[this] val _empty = new Metadata(Map.empty) @@ -218,11 +222,11 @@ object Metadata { } /** - * :: DeveloperApi :: - * * Builder for [[Metadata]]. If there is a key collision, the latter will overwrite the former. + * + * @since 1.3.0 */ -@DeveloperApi +@InterfaceStability.Stable class MetadataBuilder { private val map: mutable.Map[String, Any] = mutable.Map.empty diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/NullType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/NullType.scala index aa84115c2e42c..bdf9a819d007b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/NullType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/NullType.scala @@ -17,14 +17,15 @@ package org.apache.spark.sql.types -import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.annotation.InterfaceStability /** - * :: DeveloperApi :: * The data type representing `NULL` values. Please use the singleton [[DataTypes.NullType]]. + * + * @since 1.3.0 */ -@DeveloperApi +@InterfaceStability.Stable class NullType private() extends DataType { // The companion object and this class is separated so the companion object also subclasses // this type. Otherwise, the companion object would be of type "NullType$" in byte code. 
@@ -34,4 +35,8 @@ class NullType private() extends DataType { private[spark] override def asNullable: NullType = this } +/** + * @since 1.3.0 + */ +@InterfaceStability.Stable case object NullType extends NullType diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ShortType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ShortType.scala index 486cf585284df..3fee299d578cc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ShortType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ShortType.scala @@ -20,14 +20,15 @@ package org.apache.spark.sql.types import scala.math.{Integral, Numeric, Ordering} import scala.reflect.runtime.universe.typeTag -import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.annotation.InterfaceStability import org.apache.spark.sql.catalyst.ScalaReflectionLock /** - * :: DeveloperApi :: * The data type representing `Short` values. Please use the singleton [[DataTypes.ShortType]]. + * + * @since 1.3.0 */ -@DeveloperApi +@InterfaceStability.Stable class ShortType private() extends IntegralType { // The companion object and this class is separated so the companion object also subclasses // this type. Otherwise, the companion object would be of type "ShortType$" in byte code. @@ -48,4 +49,8 @@ class ShortType private() extends IntegralType { private[spark] override def asNullable: ShortType = this } +/** + * @since 1.3.0 + */ +@InterfaceStability.Stable case object ShortType extends ShortType diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StringType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StringType.scala index 44a25361f31c4..5d5a6f52a305b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StringType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StringType.scala @@ -20,15 +20,16 @@ package org.apache.spark.sql.types import scala.math.Ordering import scala.reflect.runtime.universe.typeTag -import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.annotation.InterfaceStability import org.apache.spark.sql.catalyst.ScalaReflectionLock import org.apache.spark.unsafe.types.UTF8String /** - * :: DeveloperApi :: * The data type representing `String` values. Please use the singleton [[DataTypes.StringType]]. + * + * @since 1.3.0 */ -@DeveloperApi +@InterfaceStability.Stable class StringType private() extends AtomicType { // The companion object and this class is separated so the companion object also subclasses // this type. Otherwise, the companion object would be of type "StringType$" in byte code. @@ -45,5 +46,9 @@ class StringType private() extends AtomicType { private[spark] override def asNullable: StringType = this } +/** + * @since 1.3.0 + */ +@InterfaceStability.Stable case object StringType extends StringType diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructField.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructField.scala index cb8bf616968e5..2c18fdcc497fe 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructField.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructField.scala @@ -20,6 +20,8 @@ package org.apache.spark.sql.types import org.json4s.JsonAST.JValue import org.json4s.JsonDSL._ +import org.apache.spark.annotation.InterfaceStability + /** * A field inside a StructType. * @param name The name of this field. 
@@ -27,7 +29,10 @@ import org.json4s.JsonDSL._ * @param nullable Indicates if values of this field can be `null` values. * @param metadata The metadata of this field. The metadata should be preserved during * transformation if the content of the column is not modified, e.g, in selection. + * + * @since 1.3.0 */ +@InterfaceStability.Stable case class StructField( name: String, dataType: DataType, diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala index dd4c88c4c43bc..0205c13aa986d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala @@ -23,14 +23,13 @@ import scala.util.Try import org.json4s.JsonDSL._ import org.apache.spark.SparkException -import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.annotation.InterfaceStability import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, InterpretedOrdering} import org.apache.spark.sql.catalyst.parser.{CatalystSqlParser, LegacyTypeStringParser} import org.apache.spark.sql.catalyst.util.quoteIdentifier import org.apache.spark.util.Utils /** - * :: DeveloperApi :: * A [[StructType]] object can be constructed by * {{{ * StructType(fields: Seq[StructField]) @@ -90,8 +89,10 @@ import org.apache.spark.util.Utils * val row = Row(Row(1, 2, true)) * // row: Row = [[1,2,true]] * }}} + * + * @since 1.3.0 */ -@DeveloperApi +@InterfaceStability.Stable case class StructType(fields: Array[StructField]) extends DataType with Seq[StructField] { /** No-arg constructor for kryo. */ @@ -138,7 +139,7 @@ case class StructType(fields: Array[StructField]) extends DataType with Seq[Stru * .add("c", StringType) */ def add(name: String, dataType: DataType): StructType = { - StructType(fields :+ new StructField(name, dataType, nullable = true, Metadata.empty)) + StructType(fields :+ StructField(name, dataType, nullable = true, Metadata.empty)) } /** @@ -150,7 +151,7 @@ case class StructType(fields: Array[StructField]) extends DataType with Seq[Stru * .add("c", StringType, true) */ def add(name: String, dataType: DataType, nullable: Boolean): StructType = { - StructType(fields :+ new StructField(name, dataType, nullable, Metadata.empty)) + StructType(fields :+ StructField(name, dataType, nullable, Metadata.empty)) } /** @@ -167,7 +168,7 @@ case class StructType(fields: Array[StructField]) extends DataType with Seq[Stru dataType: DataType, nullable: Boolean, metadata: Metadata): StructType = { - StructType(fields :+ new StructField(name, dataType, nullable, metadata)) + StructType(fields :+ StructField(name, dataType, nullable, metadata)) } /** @@ -347,7 +348,7 @@ case class StructType(fields: Array[StructField]) extends DataType with Seq[Stru private[sql] override def simpleString(maxNumberFields: Int): String = { val builder = new StringBuilder val fieldTypes = fields.take(maxNumberFields).map { - case f => s"${f.name}: ${f.dataType.simpleString(maxNumberFields)}" + f => s"${f.name}: ${f.dataType.simpleString(maxNumberFields)}" } builder.append("struct<") builder.append(fieldTypes.mkString(", ")) @@ -393,6 +394,10 @@ case class StructType(fields: Array[StructField]) extends DataType with Seq[Stru InterpretedOrdering.forSchema(this.fields.map(_.dataType)) } +/** + * @since 1.3.0 + */ +@InterfaceStability.Stable object StructType extends AbstractDataType { /** @@ -469,7 +474,7 @@ object StructType extends AbstractDataType 
{ nullable = leftNullable || rightNullable) } .orElse { - optionalMeta.putBoolean(metadataKeyForOptionalField, true) + optionalMeta.putBoolean(metadataKeyForOptionalField, value = true) Some(leftField.copy(metadata = optionalMeta.build())) } .foreach(newFields += _) @@ -479,7 +484,7 @@ object StructType extends AbstractDataType { rightFields .filterNot(f => leftMapped.get(f.name).nonEmpty) .foreach { f => - optionalMeta.putBoolean(metadataKeyForOptionalField, true) + optionalMeta.putBoolean(metadataKeyForOptionalField, value = true) newFields += f.copy(metadata = optionalMeta.build()) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/TimestampType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/TimestampType.scala index 2be9b2d76c9fe..4540d8358acad 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/TimestampType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/TimestampType.scala @@ -20,16 +20,17 @@ package org.apache.spark.sql.types import scala.math.Ordering import scala.reflect.runtime.universe.typeTag -import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.annotation.InterfaceStability import org.apache.spark.sql.catalyst.ScalaReflectionLock /** - * :: DeveloperApi :: * The data type representing `java.sql.Timestamp` values. * Please use the singleton [[DataTypes.TimestampType]]. + * + * @since 1.3.0 */ -@DeveloperApi +@InterfaceStability.Stable class TimestampType private() extends AtomicType { // The companion object and this class is separated so the companion object also subclasses // this type. Otherwise, the companion object would be of type "TimestampType$" in byte code. @@ -48,4 +49,8 @@ class TimestampType private() extends AtomicType { private[spark] override def asNullable: TimestampType = this } +/** + * @since 1.3.0 + */ +@InterfaceStability.Stable case object TimestampType extends TimestampType diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UserDefinedType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UserDefinedType.scala index 894631382f8ce..c33219c95b50a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UserDefinedType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UserDefinedType.scala @@ -22,8 +22,6 @@ import java.util.Objects import org.json4s.JsonAST.JValue import org.json4s.JsonDSL._ -import org.apache.spark.annotation.DeveloperApi - /** * The data type for User Defined Types (UDTs). * @@ -96,12 +94,10 @@ abstract class UserDefinedType[UserType >: Null] extends DataType with Serializa } /** - * :: DeveloperApi :: * The user defined type in Python. * * Note: This can only be accessed via Python UDF, or accessed as serialized object. 
*/ -@DeveloperApi private[sql] class PythonUserDefinedType( val sqlType: DataType, override val pyUDT: String, diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala index 85563ddedc165..43b6afd9ad896 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala @@ -23,7 +23,7 @@ import java.sql.{Date, Timestamp} import scala.reflect.runtime.universe.typeOf import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.catalyst.expressions.{BoundReference, Literal, SpecificMutableRow} +import org.apache.spark.sql.catalyst.expressions.{BoundReference, Literal, SpecificInternalRow} import org.apache.spark.sql.catalyst.expressions.objects.NewInstance import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -94,7 +94,7 @@ object TestingUDT { .add("c", DoubleType, nullable = false) override def serialize(n: NestedStruct): Any = { - val row = new SpecificMutableRow(sqlType.asInstanceOf[StructType].map(_.dataType)) + val row = new SpecificInternalRow(sqlType.asInstanceOf[StructType].map(_.dataType)) row.setInt(0, n.a) row.setLong(1, n.b) row.setDouble(2, n.c) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala index 5c35baacef2fa..b748595fc4f2d 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala @@ -767,6 +767,7 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { test("cast between string and interval") { import org.apache.spark.unsafe.types.CalendarInterval + checkEvaluation(Cast(Literal(""), CalendarIntervalType), null) checkEvaluation(Cast(Literal("interval -3 month 7 hours"), CalendarIntervalType), new CalendarInterval(-3, 7 * CalendarInterval.MICROS_PER_HOUR)) checkEvaluation(Cast(Literal.create( diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala index 5588b4429164c..0cb201e4dae3e 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala @@ -68,7 +68,7 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper { val length = 5000 val expressions = List.fill(length)(EqualTo(Literal(1), Literal(1))) val plan = GenerateMutableProjection.generate(expressions) - val actual = plan(new GenericMutableRow(length)).toSeq(expressions.map(_.dataType)) + val actual = plan(new GenericInternalRow(length)).toSeq(expressions.map(_.dataType)) val expected = Seq.fill(length)(true) if (!checkResult(actual, expected)) { @@ -91,7 +91,7 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper { val expression = CaseWhen((1 to cases).map(generateCase(_))) val plan = GenerateMutableProjection.generate(Seq(expression)) - val input = new GenericMutableRow(Array[Any](UTF8String.fromString(s"${clauses}:${cases}"))) + val input = new GenericInternalRow(Array[Any](UTF8String.fromString(s"${clauses}:${cases}"))) 
val actual = plan(input).toSeq(Seq(expression.dataType)) assert(actual(0) == cases) @@ -101,7 +101,7 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper { val length = 5000 val expressions = Seq(CreateArray(List.fill(length)(EqualTo(Literal(1), Literal(1))))) val plan = GenerateMutableProjection.generate(expressions) - val actual = plan(new GenericMutableRow(length)).toSeq(expressions.map(_.dataType)) + val actual = plan(new GenericInternalRow(length)).toSeq(expressions.map(_.dataType)) val expected = Seq(new GenericArrayData(Seq.fill(length)(true))) if (!checkResult(actual, expected)) { @@ -116,7 +116,7 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper { case (expr, i) => Seq(Literal(i), expr) })) val plan = GenerateMutableProjection.generate(expressions) - val actual = plan(new GenericMutableRow(length)).toSeq(expressions.map(_.dataType)).map { + val actual = plan(new GenericInternalRow(length)).toSeq(expressions.map(_.dataType)).map { case m: ArrayBasedMapData => ArrayBasedMapData.toScalaMap(m) } val expected = (0 until length).map((_, true)).toMap :: Nil @@ -130,7 +130,7 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper { val length = 5000 val expressions = Seq(CreateStruct(List.fill(length)(EqualTo(Literal(1), Literal(1))))) val plan = GenerateMutableProjection.generate(expressions) - val actual = plan(new GenericMutableRow(length)).toSeq(expressions.map(_.dataType)) + val actual = plan(new GenericInternalRow(length)).toSeq(expressions.map(_.dataType)) val expected = Seq(InternalRow(Seq.fill(length)(true): _*)) if (!checkResult(actual, expected)) { @@ -145,7 +145,7 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper { expr => Seq(Literal(expr.toString), expr) })) val plan = GenerateMutableProjection.generate(expressions) - val actual = plan(new GenericMutableRow(length)).toSeq(expressions.map(_.dataType)) + val actual = plan(new GenericInternalRow(length)).toSeq(expressions.map(_.dataType)) val expected = Seq(InternalRow(Seq.fill(length)(true): _*)) if (!checkResult(actual, expected)) { @@ -158,7 +158,7 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper { val schema = StructType(Seq.fill(length)(StructField("int", IntegerType))) val expressions = Seq(CreateExternalRow(Seq.fill(length)(Literal(1)), schema)) val plan = GenerateMutableProjection.generate(expressions) - val actual = plan(new GenericMutableRow(length)).toSeq(expressions.map(_.dataType)) + val actual = plan(new GenericInternalRow(length)).toSeq(expressions.map(_.dataType)) val expected = Seq(Row.fromSeq(Seq.fill(length)(1))) if (!checkResult(actual, expected)) { @@ -174,7 +174,7 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper { Literal.create("PST", StringType)) } val plan = GenerateMutableProjection.generate(expressions) - val actual = plan(new GenericMutableRow(length)).toSeq(expressions.map(_.dataType)) + val actual = plan(new GenericInternalRow(length)).toSeq(expressions.map(_.dataType)) val expected = Seq.fill(length)( DateTimeUtils.fromJavaTimestamp(Timestamp.valueOf("2015-07-24 07:00:00"))) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionSetSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionSetSuite.scala index 60939ee0eda5d..c587d4f632531 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionSetSuite.scala +++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionSetSuite.scala @@ -80,6 +80,88 @@ class ExpressionSetSuite extends SparkFunSuite { setTest(1, Not(aUpper >= 1), aUpper < 1, Not(Literal(1) <= aUpper), Literal(1) > aUpper) setTest(1, Not(aUpper <= 1), aUpper > 1, Not(Literal(1) >= aUpper), Literal(1) < aUpper) + // Reordering AND/OR expressions + setTest(1, aUpper > bUpper && aUpper <= 10, aUpper <= 10 && aUpper > bUpper) + setTest(1, + aUpper > bUpper && bUpper > 100 && aUpper <= 10, + bUpper > 100 && aUpper <= 10 && aUpper > bUpper) + + setTest(1, aUpper > bUpper || aUpper <= 10, aUpper <= 10 || aUpper > bUpper) + setTest(1, + aUpper > bUpper || bUpper > 100 || aUpper <= 10, + bUpper > 100 || aUpper <= 10 || aUpper > bUpper) + + setTest(1, + (aUpper <= 10 && aUpper > bUpper) || bUpper > 100, + bUpper > 100 || (aUpper <= 10 && aUpper > bUpper)) + + setTest(1, + aUpper >= bUpper || (aUpper > 10 && bUpper < 10), + (bUpper < 10 && aUpper > 10) || aUpper >= bUpper) + + // More complicated cases mixing AND/OR + // Three predicates in the following: + // (bUpper > 100) + // (aUpper < 100 && bUpper <= aUpper) + // (aUpper >= 10 && bUpper >= 50) + // They can be reordered and the sub-predicates contained in each of them can be reordered too. + setTest(1, + (bUpper > 100) || (aUpper < 100 && bUpper <= aUpper) || (aUpper >= 10 && bUpper >= 50), + (aUpper >= 10 && bUpper >= 50) || (bUpper > 100) || (aUpper < 100 && bUpper <= aUpper), + (bUpper >= 50 && aUpper >= 10) || (bUpper <= aUpper && aUpper < 100) || (bUpper > 100)) + + // Two predicates in the following: + // (bUpper > 100 && aUpper < 100 && bUpper <= aUpper) + // (aUpper >= 10 && bUpper >= 50) + setTest(1, + (bUpper > 100 && aUpper < 100 && bUpper <= aUpper) || (aUpper >= 10 && bUpper >= 50), + (aUpper >= 10 && bUpper >= 50) || (aUpper < 100 && bUpper > 100 && bUpper <= aUpper), + (bUpper >= 50 && aUpper >= 10) || (bUpper <= aUpper && aUpper < 100 && bUpper > 100)) + + // Three predicates in the following: + // (aUpper >= 10) + // (bUpper <= 10 && aUpper === bUpper && aUpper < 100) + // (bUpper >= 100) + setTest(1, + (aUpper >= 10) || (bUpper <= 10 && aUpper === bUpper && aUpper < 100) || (bUpper >= 100), + (aUpper === bUpper && aUpper < 100 && bUpper <= 10) || (bUpper >= 100) || (aUpper >= 10), + (aUpper < 100 && bUpper <= 10 && aUpper === bUpper) || (aUpper >= 10) || (bUpper >= 100), + ((bUpper <= 10 && aUpper === bUpper) && aUpper < 100) || ((aUpper >= 10) || (bUpper >= 100))) + + // Don't reorder non-deterministic expression in AND/OR. + setTest(2, Rand(1L) > aUpper && aUpper <= 10, aUpper <= 10 && Rand(1L) > aUpper) + setTest(2, + aUpper > bUpper && bUpper > 100 && Rand(1L) > aUpper, + bUpper > 100 && Rand(1L) > aUpper && aUpper > bUpper) + + setTest(2, Rand(1L) > aUpper || aUpper <= 10, aUpper <= 10 || Rand(1L) > aUpper) + setTest(2, + aUpper > bUpper || aUpper <= Rand(1L) || aUpper <= 10, + aUpper <= Rand(1L) || aUpper <= 10 || aUpper > bUpper) + + // Partial reorder case: we don't reorder non-deterministic expressions, + // but we can reorder sub-expressions in deterministic AND/OR expressions. + // There are two predicates: + // (aUpper > bUpper || bUpper > 100) => we can reorder sub-expressions in it. 
+ // (aUpper === Rand(1L)) + setTest(1, + (aUpper > bUpper || bUpper > 100) && aUpper === Rand(1L), + (bUpper > 100 || aUpper > bUpper) && aUpper === Rand(1L)) + + // There are three predicates: + // (Rand(1L) > aUpper) + // (aUpper <= Rand(1L) && aUpper > bUpper) + // (aUpper > 10 && bUpper > 10) => we can reorder sub-expressions in it. + setTest(1, + Rand(1L) > aUpper || (aUpper <= Rand(1L) && aUpper > bUpper) || (aUpper > 10 && bUpper > 10), + Rand(1L) > aUpper || (aUpper <= Rand(1L) && aUpper > bUpper) || (bUpper > 10 && aUpper > 10)) + + // Same predicates as above, but a negative case when we reorder non-deterministic + // expression in (aUpper <= Rand(1L) && aUpper > bUpper). + setTest(2, + Rand(1L) > aUpper || (aUpper <= Rand(1L) && aUpper > bUpper) || (aUpper > 10 && bUpper > 10), + Rand(1L) > aUpper || (aUpper > bUpper && aUpper <= Rand(1L)) || (aUpper > 10 && bUpper > 10)) + test("add to / remove from set") { val initialSet = ExpressionSet(aUpper + 1 :: Nil) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MapDataSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MapDataSuite.scala index 0f1264c7c3269..25a675a90276d 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MapDataSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MapDataSuite.scala @@ -45,7 +45,7 @@ class MapDataSuite extends SparkFunSuite { // UnsafeMapData val unsafeConverter = UnsafeProjection.create(Array[DataType](MapType(StringType, IntegerType))) - val row = new GenericMutableRow(1) + val row = new GenericInternalRow(1) def toUnsafeMap(map: ArrayBasedMapData): UnsafeMapData = { row.update(0, map) val unsafeRow = unsafeConverter.apply(row) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala index 90790dda753f8..cf3cbe270753e 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala @@ -37,7 +37,7 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { val fieldTypes: Array[DataType] = Array(LongType, LongType, IntegerType) val converter = UnsafeProjection.create(fieldTypes) - val row = new SpecificMutableRow(fieldTypes) + val row = new SpecificInternalRow(fieldTypes) row.setLong(0, 0) row.setLong(1, 1) row.setInt(2, 2) @@ -75,7 +75,7 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { val fieldTypes: Array[DataType] = Array(LongType, StringType, BinaryType) val converter = UnsafeProjection.create(fieldTypes) - val row = new SpecificMutableRow(fieldTypes) + val row = new SpecificInternalRow(fieldTypes) row.setLong(0, 0) row.update(1, UTF8String.fromString("Hello")) row.update(2, "World".getBytes(StandardCharsets.UTF_8)) @@ -94,7 +94,7 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { val fieldTypes: Array[DataType] = Array(LongType, StringType, DateType, TimestampType) val converter = UnsafeProjection.create(fieldTypes) - val row = new SpecificMutableRow(fieldTypes) + val row = new SpecificInternalRow(fieldTypes) row.setLong(0, 0) row.update(1, UTF8String.fromString("Hello")) row.update(2, DateTimeUtils.fromJavaDate(Date.valueOf("1970-01-01"))) @@ -138,7 +138,7 @@ class UnsafeRowConverterSuite extends 
SparkFunSuite with Matchers { val converter = UnsafeProjection.create(fieldTypes) val rowWithAllNullColumns: InternalRow = { - val r = new SpecificMutableRow(fieldTypes) + val r = new SpecificInternalRow(fieldTypes) for (i <- fieldTypes.indices) { r.setNullAt(i) } @@ -167,7 +167,7 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { // columns, then the serialized row representation should be identical to what we would get by // creating an entirely null row via the converter val rowWithNoNullColumns: InternalRow = { - val r = new SpecificMutableRow(fieldTypes) + val r = new SpecificInternalRow(fieldTypes) r.setNullAt(0) r.setBoolean(1, false) r.setByte(2, 20) @@ -243,11 +243,11 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { test("NaN canonicalization") { val fieldTypes: Array[DataType] = Array(FloatType, DoubleType) - val row1 = new SpecificMutableRow(fieldTypes) + val row1 = new SpecificInternalRow(fieldTypes) row1.setFloat(0, java.lang.Float.intBitsToFloat(0x7f800001)) row1.setDouble(1, java.lang.Double.longBitsToDouble(0x7ff0000000000001L)) - val row2 = new SpecificMutableRow(fieldTypes) + val row2 = new SpecificInternalRow(fieldTypes) row2.setFloat(0, java.lang.Float.intBitsToFloat(0x7fffffff)) row2.setDouble(1, java.lang.Double.longBitsToDouble(0x7fffffffffffffffL)) @@ -263,7 +263,7 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { val converter = UnsafeProjection.create(fieldTypes) - val row = new GenericMutableRow(fieldTypes.length) + val row = new GenericInternalRow(fieldTypes.length) row.update(0, InternalRow(1)) row.update(1, InternalRow(InternalRow(2L))) @@ -324,7 +324,7 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { ) val converter = UnsafeProjection.create(fieldTypes) - val row = new GenericMutableRow(fieldTypes.length) + val row = new GenericInternalRow(fieldTypes.length) row.update(0, createArray(1, 2)) row.update(1, createArray(createArray(3, 4))) @@ -359,7 +359,7 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { val innerMap = createMap(5, 6)(7, 8) val map2 = createMap(9)(innerMap) - val row = new GenericMutableRow(fieldTypes.length) + val row = new GenericInternalRow(fieldTypes.length) row.update(0, map1) row.update(1, map2) @@ -400,7 +400,7 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { ) val converter = UnsafeProjection.create(fieldTypes) - val row = new GenericMutableRow(fieldTypes.length) + val row = new GenericInternalRow(fieldTypes.length) row.update(0, InternalRow(createArray(1))) row.update(1, createArray(InternalRow(2L))) @@ -439,7 +439,7 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { ) val converter = UnsafeProjection.create(fieldTypes) - val row = new GenericMutableRow(fieldTypes.length) + val row = new GenericInternalRow(fieldTypes.length) row.update(0, InternalRow(createMap(1)(2))) row.update(1, createMap(3)(InternalRow(4L))) @@ -485,7 +485,7 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { ) val converter = UnsafeProjection.create(fieldTypes) - val row = new GenericMutableRow(fieldTypes.length) + val row = new GenericInternalRow(fieldTypes.length) row.update(0, createArray(createMap(1)(2))) row.update(1, createMap(3)(createArray(4))) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentileSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentileSuite.scala index 
61298a1b72d77..8456e244609bc 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentileSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentileSuite.scala @@ -23,7 +23,7 @@ import org.apache.spark.sql.catalyst.analysis.{SimpleAnalyzer, UnresolvedAttribu import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.TypeCheckFailure import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ -import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeReference, BoundReference, Cast, CreateArray, DecimalLiteral, GenericMutableRow, Literal} +import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeReference, BoundReference, Cast, CreateArray, DecimalLiteral, GenericInternalRow, Literal} import org.apache.spark.sql.catalyst.expressions.aggregate.ApproximatePercentile.{PercentileDigest, PercentileDigestSerializer} import org.apache.spark.sql.catalyst.plans.logical.LocalRelation import org.apache.spark.sql.catalyst.util.ArrayData @@ -144,7 +144,8 @@ class ApproximatePercentileSuite extends SparkFunSuite { .withNewInputAggBufferOffset(inputAggregationBufferOffset) .withNewMutableAggBufferOffset(mutableAggregationBufferOffset) - val mutableAggBuffer = new GenericMutableRow(new Array[Any](mutableAggregationBufferOffset + 1)) + val mutableAggBuffer = new GenericInternalRow( + new Array[Any](mutableAggregationBufferOffset + 1)) agg.initialize(mutableAggBuffer) val dataCount = 10 (1 to dataCount).foreach { data => @@ -154,7 +155,7 @@ class ApproximatePercentileSuite extends SparkFunSuite { // Serialize the aggregation buffer val serialized = mutableAggBuffer.getBinary(mutableAggregationBufferOffset) - val inputAggBuffer = new GenericMutableRow(Array[Any](null, serialized)) + val inputAggBuffer = new GenericInternalRow(Array[Any](null, serialized)) // Phase 2: final mode aggregation // Re-initialize the aggregation buffer @@ -311,7 +312,7 @@ class ApproximatePercentileSuite extends SparkFunSuite { test("class ApproximatePercentile, null handling") { val childExpression = Cast(BoundReference(0, IntegerType, nullable = true), DoubleType) val agg = new ApproximatePercentile(childExpression, Literal(0.5D)) - val buffer = new GenericMutableRow(new Array[Any](1)) + val buffer = new GenericInternalRow(new Array[Any](1)) agg.initialize(buffer) // Empty aggregation buffer assert(agg.eval(buffer) == null) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/HyperLogLogPlusPlusSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/HyperLogLogPlusPlusSuite.scala index f5374229ca5cd..17f6b71bb270b 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/HyperLogLogPlusPlusSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/HyperLogLogPlusPlusSuite.scala @@ -22,28 +22,29 @@ import java.util.Random import scala.collection.mutable import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.catalyst.expressions.{BoundReference, MutableRow, SpecificMutableRow} +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.{BoundReference, SpecificInternalRow} import org.apache.spark.sql.types.{DataType, IntegerType} class HyperLogLogPlusPlusSuite extends SparkFunSuite { /** Create a HLL++ instance and an input and output buffer. 
*/
def createEstimator(rsd: Double, dt: DataType = IntegerType):
- (HyperLogLogPlusPlus, MutableRow, MutableRow) = {
- val input = new SpecificMutableRow(Seq(dt))
+ (HyperLogLogPlusPlus, InternalRow, InternalRow) = {
+ val input = new SpecificInternalRow(Seq(dt))
val hll = new HyperLogLogPlusPlus(new BoundReference(0, dt, true), rsd)
val buffer = createBuffer(hll)
(hll, input, buffer)
}
- def createBuffer(hll: HyperLogLogPlusPlus): MutableRow = {
- val buffer = new SpecificMutableRow(hll.aggBufferAttributes.map(_.dataType))
+ def createBuffer(hll: HyperLogLogPlusPlus): InternalRow = {
+ val buffer = new SpecificInternalRow(hll.aggBufferAttributes.map(_.dataType))
hll.initialize(buffer)
buffer
}
/** Evaluate the estimate. It should be within 3*SD's of the given true rsd. */
- def evaluateEstimate(hll: HyperLogLogPlusPlus, buffer: MutableRow, cardinality: Int): Unit = {
+ def evaluateEstimate(hll: HyperLogLogPlusPlus, buffer: InternalRow, cardinality: Int): Unit = {
val estimate = hll.eval(buffer).asInstanceOf[Long].toDouble
val error = math.abs((estimate / cardinality.toDouble) - 1.0d)
assert(error < hll.trueRsd * 3.0d, "Error should be within 3 std. errors.")
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DataTypeParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DataTypeParserSuite.scala
index 020fb16f6f3d5..3964fa3924b24 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DataTypeParserSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DataTypeParserSuite.scala
@@ -116,6 +116,7 @@ class DataTypeParserSuite extends SparkFunSuite {
unsupported("it is not a data type")
unsupported("struct<x+y: int, 2col: double>")
unsupported("struct<x: int")
+ unsupported("struct<x int, y string>")
// DataType parser accepts certain reserved keywords.
checkDataType(
@@ -125,16 +126,11 @@ class DataTypeParserSuite extends SparkFunSuite {
StructField("DATE", BooleanType, true) :: Nil)
)
- // Define struct columns without ':'
- checkDataType(
- "struct<x int, y string>",
- (new StructType).add("x", IntegerType).add("y", StringType))
-
- checkDataType(
- "struct<`x``y` int>",
- (new StructType).add("x`y", IntegerType))
-
// Use SQL keywords.
checkDataType("struct<end: long, select: int, from: string>",
(new StructType).add("end", LongType).add("select", IntegerType).add("from", StringType))
+
+ // DataType parser accepts comments.
+ checkDataType("Struct<x: INT, y: STRING COMMENT 'test'>",
+ (new StructType).add("x", IntegerType).add("y", StringType, true, "test"))
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala
index 0fb1138478a9b..17cfc8158803b 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala
@@ -18,7 +18,7 @@ package org.apache.spark.sql.catalyst.parser
import java.sql.{Date, Timestamp}
-import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier}
+import org.apache.spark.sql.catalyst.FunctionIdentifier
import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, _}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.PlanTest
@@ -535,4 +535,13 @@ class ExpressionParserSuite extends PlanTest {
// ".123BD" should not be treated as token of type BIGDECIMAL_LITERAL
assertEqual("a.123BD_column", UnresolvedAttribute("a.123BD_column"))
}
+
+ test("SPARK-17832 function identifier contains backtick") {
+ val complexName = FunctionIdentifier("`ba`r", Some("`fo`o"))
+ assertEqual(complexName.quotedString, UnresolvedAttribute("`fo`o.`ba`r"))
+ intercept(complexName.unquotedString, "mismatched input")
+ // A function identifier that contains consecutive backticks should be treated correctly.
+ val complexName2 = FunctionIdentifier("ba``r", Some("fo``o"))
+ assertEqual(complexName2.quotedString, UnresolvedAttribute("fo``o.ba``r"))
+ }
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/TableIdentifierParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/TableIdentifierParserSuite.scala
index 793be8953d07a..7d46011b410e2 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/TableIdentifierParserSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/TableIdentifierParserSuite.scala
@@ -104,4 +104,14 @@ class TableIdentifierParserSuite extends SparkFunSuite {
// ".123BD" should not be treated as token of type BIGDECIMAL_LITERAL
assert(parseTableIdentifier("a.123BD_LIST") == TableIdentifier("123BD_LIST", Some("a")))
}
+
+ test("SPARK-17832 table identifier - contains backtick") {
+ val complexName = TableIdentifier("`weird`table`name", Some("`d`b`1"))
+ assert(complexName === parseTableIdentifier("```d``b``1`.```weird``table``name`"))
+ assert(complexName === parseTableIdentifier(complexName.quotedString))
+ intercept[ParseException](parseTableIdentifier(complexName.unquotedString))
+ // A table identifier that contains consecutive backticks should be treated correctly.
+ val complexName2 = TableIdentifier("x``y", Some("d``b")) + assert(complexName2 === parseTableIdentifier(complexName2.quotedString)) + } } diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java index 62abc2a821a3a..a6ce4c2edc232 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java @@ -21,8 +21,7 @@ import org.apache.spark.memory.MemoryMode; import org.apache.spark.sql.catalyst.InternalRow; -import org.apache.spark.sql.catalyst.expressions.GenericMutableRow; -import org.apache.spark.sql.catalyst.expressions.MutableRow; +import org.apache.spark.sql.catalyst.expressions.GenericInternalRow; import org.apache.spark.sql.catalyst.expressions.UnsafeRow; import org.apache.spark.sql.catalyst.util.ArrayData; import org.apache.spark.sql.catalyst.util.MapData; @@ -91,7 +90,7 @@ public void close() { * Adapter class to interop with existing components that expect internal row. A lot of * performance is lost with this translation. */ - public static final class Row extends MutableRow { + public static final class Row extends InternalRow { protected int rowId; private final ColumnarBatch parent; private final int fixedLenRowSize; @@ -129,7 +128,7 @@ public void markFiltered() { * Revisit this. This is expensive. This is currently only used in test paths. */ public InternalRow copy() { - GenericMutableRow row = new GenericMutableRow(columns.length); + GenericInternalRow row = new GenericInternalRow(columns.length); for (int i = 0; i < numFields(); i++) { if (isNullAt(i)) { row.setNullAt(i); diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala index 63da501f18cca..d22bb17934ce7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql import scala.language.implicitConversions -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.{Experimental, InterfaceStability} import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.encoders.{encoderFor, ExpressionEncoder} @@ -59,6 +59,7 @@ private[sql] object Column { * * @since 1.6.0 */ +@InterfaceStability.Stable class TypedColumn[-T, U]( expr: Expression, private[sql] val encoder: ExpressionEncoder[U]) @@ -124,6 +125,7 @@ class TypedColumn[-T, U]( * * @since 1.3.0 */ +@InterfaceStability.Stable class Column(protected[sql] val expr: Expression) extends Logging { def this(name: String) = this(name match { @@ -1185,6 +1187,7 @@ class Column(protected[sql] val expr: Expression) extends Logging { * @since 1.3.0 */ @Experimental +@InterfaceStability.Evolving class ColumnName(name: String) extends Column(name) { /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala index ad00966a917ad..65a9c008f9650 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala @@ -21,7 +21,7 @@ import java.{lang => jl} import scala.collection.JavaConverters._ -import org.apache.spark.annotation.Experimental +import 
org.apache.spark.annotation.{Experimental, InterfaceStability} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.functions._ import org.apache.spark.sql.types._ @@ -34,6 +34,7 @@ import org.apache.spark.sql.types._ * @since 1.3.1 */ @Experimental +@InterfaceStability.Evolving final class DataFrameNaFunctions private[sql](df: DataFrame) { /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala index b84fb2fb95914..a716a916b7f7f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala @@ -24,11 +24,12 @@ import scala.collection.JavaConverters._ import org.apache.spark.api.java.JavaRDD import org.apache.spark.internal.Logging import org.apache.spark.Partition +import org.apache.spark.annotation.InterfaceStability import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.json.{JacksonParser, JSONOptions} import org.apache.spark.sql.execution.LogicalRDD import org.apache.spark.sql.execution.datasources.DataSource -import org.apache.spark.sql.execution.datasources.jdbc.{JDBCPartition, JDBCPartitioningInfo, JDBCRelation} +import org.apache.spark.sql.execution.datasources.jdbc._ import org.apache.spark.sql.execution.datasources.json.InferSchema import org.apache.spark.sql.types.StructType @@ -38,6 +39,7 @@ import org.apache.spark.sql.types.StructType * * @since 1.4.0 */ +@InterfaceStability.Stable class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { /** @@ -229,13 +231,10 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { table: String, parts: Array[Partition], connectionProperties: Properties): DataFrame = { - val props = new Properties() - extraOptions.foreach { case (key, value) => - props.put(key, value) - } - // connectionProperties should override settings in extraOptions - props.putAll(connectionProperties) - val relation = JDBCRelation(url, table, parts, props)(sparkSession) + // connectionProperties should override settings in extraOptions. 
+ val params = extraOptions.toMap ++ connectionProperties.asScala.toMap + val options = new JDBCOptions(url, table, params) + val relation = JDBCRelation(parts, options)(sparkSession) sparkSession.baseRelationToDataFrame(relation) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala index d69be36917360..a212bb6205328 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala @@ -21,7 +21,7 @@ import java.{lang => jl, util => ju} import scala.collection.JavaConverters._ -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.{Experimental, InterfaceStability} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.execution.stat._ import org.apache.spark.sql.types._ @@ -34,6 +34,7 @@ import org.apache.spark.util.sketch.{BloomFilter, CountMinSketch} * @since 1.4.0 */ @Experimental +@InterfaceStability.Evolving final class DataFrameStatFunctions private[sql](df: DataFrame) { /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala index 7374a8e045035..35ef050dcb169 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala @@ -21,12 +21,12 @@ import java.util.Properties import scala.collection.JavaConverters._ +import org.apache.spark.annotation.InterfaceStability import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation import org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogStorageFormat, CatalogTable, CatalogTableType} import org.apache.spark.sql.catalyst.plans.logical.InsertIntoTable import org.apache.spark.sql.execution.datasources.{CaseInsensitiveMap, CreateTable, DataSource, HadoopFsRelation} -import org.apache.spark.sql.execution.datasources.jdbc.{JDBCOptions, JdbcUtils} import org.apache.spark.sql.types.StructType /** @@ -35,6 +35,7 @@ import org.apache.spark.sql.types.StructType * * @since 1.4.0 */ +@InterfaceStability.Stable final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { private val df = ds.toDF() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index 9cfbdffd02582..e59a483075c94 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -26,7 +26,7 @@ import scala.util.control.NonFatal import org.apache.commons.lang3.StringUtils -import org.apache.spark.annotation.{DeveloperApi, Experimental} +import org.apache.spark.annotation.{DeveloperApi, Experimental, InterfaceStability} import org.apache.spark.api.java.JavaRDD import org.apache.spark.api.java.function._ import org.apache.spark.api.python.{PythonRDD, SerDeUtil} @@ -42,7 +42,7 @@ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.util.usePrettyExpression import org.apache.spark.sql.execution.{FileRelation, LogicalRDD, QueryExecution, SQLExecution} -import org.apache.spark.sql.execution.command.{CreateViewCommand, ExplainCommand} +import org.apache.spark.sql.execution.command.{CreateViewCommand, ExplainCommand, GlobalTempView, LocalTempView} import 
org.apache.spark.sql.execution.datasources.LogicalRelation import org.apache.spark.sql.execution.datasources.json.JacksonGenerator import org.apache.spark.sql.execution.python.EvaluatePython @@ -149,9 +149,10 @@ private[sql] object Dataset { * * @since 1.6.0 */ +@InterfaceStability.Stable class Dataset[T] private[sql]( @transient val sparkSession: SparkSession, - @DeveloperApi @transient val queryExecution: QueryExecution, + @DeveloperApi @InterfaceStability.Unstable @transient val queryExecution: QueryExecution, encoder: Encoder[T]) extends Serializable { @@ -369,6 +370,7 @@ class Dataset[T] private[sql]( * @since 1.6.0 */ @Experimental + @InterfaceStability.Evolving def as[U : Encoder]: Dataset[U] = Dataset[U](sparkSession, logicalPlan) /** @@ -477,6 +479,7 @@ class Dataset[T] private[sql]( * @since 2.0.0 */ @Experimental + @InterfaceStability.Evolving def isStreaming: Boolean = logicalPlan.isStreaming /** @@ -798,6 +801,7 @@ class Dataset[T] private[sql]( * @since 1.6.0 */ @Experimental + @InterfaceStability.Evolving def joinWith[U](other: Dataset[U], condition: Column, joinType: String): Dataset[(T, U)] = { // Creates a Join node and resolve it first, to get join condition resolved, self-join resolved, // etc. @@ -869,6 +873,7 @@ class Dataset[T] private[sql]( * @since 1.6.0 */ @Experimental + @InterfaceStability.Evolving def joinWith[U](other: Dataset[U], condition: Column): Dataset[(T, U)] = { joinWith(other, condition, "inner") } @@ -1071,6 +1076,7 @@ class Dataset[T] private[sql]( * @since 1.6.0 */ @Experimental + @InterfaceStability.Evolving def select[U1](c1: TypedColumn[T, U1]): Dataset[U1] = { implicit val encoder = c1.encoder val project = Project(c1.withInputType(exprEnc, logicalPlan.output).named :: Nil, @@ -1105,6 +1111,7 @@ class Dataset[T] private[sql]( * @since 1.6.0 */ @Experimental + @InterfaceStability.Evolving def select[U1, U2](c1: TypedColumn[T, U1], c2: TypedColumn[T, U2]): Dataset[(U1, U2)] = selectUntyped(c1, c2).asInstanceOf[Dataset[(U1, U2)]] @@ -1116,6 +1123,7 @@ class Dataset[T] private[sql]( * @since 1.6.0 */ @Experimental + @InterfaceStability.Evolving def select[U1, U2, U3]( c1: TypedColumn[T, U1], c2: TypedColumn[T, U2], @@ -1130,6 +1138,7 @@ class Dataset[T] private[sql]( * @since 1.6.0 */ @Experimental + @InterfaceStability.Evolving def select[U1, U2, U3, U4]( c1: TypedColumn[T, U1], c2: TypedColumn[T, U2], @@ -1145,6 +1154,7 @@ class Dataset[T] private[sql]( * @since 1.6.0 */ @Experimental + @InterfaceStability.Evolving def select[U1, U2, U3, U4, U5]( c1: TypedColumn[T, U1], c2: TypedColumn[T, U2], @@ -1315,6 +1325,7 @@ class Dataset[T] private[sql]( * @since 1.6.0 */ @Experimental + @InterfaceStability.Evolving def reduce(func: (T, T) => T): T = rdd.reduce(func) /** @@ -1327,6 +1338,7 @@ class Dataset[T] private[sql]( * @since 1.6.0 */ @Experimental + @InterfaceStability.Evolving def reduce(func: ReduceFunction[T]): T = reduce(func.call(_, _)) /** @@ -1338,6 +1350,7 @@ class Dataset[T] private[sql]( * @since 2.0.0 */ @Experimental + @InterfaceStability.Evolving def groupByKey[K: Encoder](func: T => K): KeyValueGroupedDataset[K, T] = { val inputPlan = logicalPlan val withGroupingKey = AppendColumns(func, inputPlan) @@ -1360,6 +1373,7 @@ class Dataset[T] private[sql]( * @since 2.0.0 */ @Experimental + @InterfaceStability.Evolving def groupByKey[K](func: MapFunction[T, K], encoder: Encoder[K]): KeyValueGroupedDataset[K, T] = groupByKey(func.call(_))(encoder) @@ -1878,17 +1892,25 @@ class Dataset[T] private[sql]( def dropDuplicates(colNames: 
Seq[String]): Dataset[T] = withTypedPlan { val resolver = sparkSession.sessionState.analyzer.resolver val allColumns = queryExecution.analyzed.output - val groupCols = colNames.map { colName => - allColumns.find(col => resolver(col.name, colName)).getOrElse( + val groupCols = colNames.flatMap { colName => + // It is possible that more than one column has the same name, + // so we call filter instead of find. + val cols = allColumns.filter(col => resolver(col.name, colName)) + if (cols.isEmpty) { throw new AnalysisException( - s"""Cannot resolve column name "$colName" among (${schema.fieldNames.mkString(", ")})""")) + s"""Cannot resolve column name "$colName" among (${schema.fieldNames.mkString(", ")})""") + } + cols } val groupColExprIds = groupCols.map(_.exprId) val aggCols = logicalPlan.output.map { attr => if (groupColExprIds.contains(attr.exprId)) { attr } else { - Alias(new First(attr).toAggregateExpression(), attr.name)() + // Removing duplicate rows should not change output attributes. We should keep + // the original exprId of the attribute. Otherwise, selecting a column from the original + // dataset would cause an analysis exception due to an unresolved attribute. + Alias(new First(attr).toAggregateExpression(), attr.name)(exprId = attr.exprId) } } Aggregate(groupCols, aggCols, logicalPlan) @@ -2028,6 +2050,7 @@ class Dataset[T] private[sql]( * @since 1.6.0 */ @Experimental + @InterfaceStability.Evolving def filter(func: T => Boolean): Dataset[T] = { withTypedPlan(TypedFilter(func, logicalPlan)) } @@ -2041,6 +2064,7 @@ class Dataset[T] private[sql]( * @since 1.6.0 */ @Experimental + @InterfaceStability.Evolving def filter(func: FilterFunction[T]): Dataset[T] = { withTypedPlan(TypedFilter(func, logicalPlan)) } @@ -2054,6 +2078,7 @@ class Dataset[T] private[sql]( * @since 1.6.0 */ @Experimental + @InterfaceStability.Evolving def map[U : Encoder](func: T => U): Dataset[U] = withTypedPlan { MapElements[T, U](func, logicalPlan) } @@ -2067,6 +2092,7 @@ class Dataset[T] private[sql]( * @since 1.6.0 */ @Experimental + @InterfaceStability.Evolving def map[U](func: MapFunction[T, U], encoder: Encoder[U]): Dataset[U] = { implicit val uEnc = encoder withTypedPlan(MapElements[T, U](func, logicalPlan)) @@ -2081,6 +2107,7 @@ class Dataset[T] private[sql]( * @since 1.6.0 */ @Experimental + @InterfaceStability.Evolving def mapPartitions[U : Encoder](func: Iterator[T] => Iterator[U]): Dataset[U] = { new Dataset[U]( sparkSession, @@ -2097,6 +2124,7 @@ class Dataset[T] private[sql]( * @since 1.6.0 */ @Experimental + @InterfaceStability.Evolving def mapPartitions[U](f: MapPartitionsFunction[T, U], encoder: Encoder[U]): Dataset[U] = { val func: (Iterator[T]) => Iterator[U] = x => f.call(x.asJava).asScala mapPartitions(func)(encoder) @@ -2127,6 +2155,7 @@ class Dataset[T] private[sql]( * @since 1.6.0 */ @Experimental + @InterfaceStability.Evolving def flatMap[U : Encoder](func: T => TraversableOnce[U]): Dataset[U] = mapPartitions(_.flatMap(func)) @@ -2140,6 +2169,7 @@ class Dataset[T] private[sql]( * @since 1.6.0 */ @Experimental + @InterfaceStability.Evolving def flatMap[U](f: FlatMapFunction[T, U], encoder: Encoder[U]): Dataset[U] = { val func: (T) => Iterator[U] = x => f.call(x).asScala flatMap(func)(encoder) @@ -2433,9 +2463,13 @@ class Dataset[T] private[sql]( } /** - * Creates a temporary view using the given name. The lifetime of this + * Creates a local temporary view using the given name. The lifetime of this * temporary view is tied to the [[SparkSession]] that was used to create this Dataset.
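The dropDuplicates change a few hunks above keeps the original exprId of non-grouping attributes precisely so that columns of the original Dataset remain resolvable after deduplication. A minimal usage sketch of that behaviour (the data and column names are invented for illustration, and `spark` is assumed to be an existing SparkSession):

import spark.implicits._

val df = spark.range(10).toDF("id").withColumn("category", $"id" % 3)
val deduped = df.dropDuplicates("category")
// Referencing a column through the original DataFrame still resolves, because the
// aggregate produced by dropDuplicates reuses the original attribute ids instead of
// aliasing fresh ones.
deduped.select(df("id"), df("category")).show()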
* + * Local temporary view is session-scoped. Its lifetime is the lifetime of the session that + * created it, i.e. it will be automatically dropped when the session terminates. It's not + * tied to any databases, i.e. we can't use `db1.view1` to reference a local temporary view. + * * @throws AnalysisException if the view name already exists * * @group basic @@ -2443,21 +2477,46 @@ class Dataset[T] private[sql]( */ @throws[AnalysisException] def createTempView(viewName: String): Unit = withPlan { - createViewCommand(viewName, replace = false) + createTempViewCommand(viewName, replace = false, global = false) } + + /** - * Creates a temporary view using the given name. The lifetime of this + * Creates a local temporary view using the given name. The lifetime of this * temporary view is tied to the [[SparkSession]] that was used to create this Dataset. * * @group basic * @since 2.0.0 */ def createOrReplaceTempView(viewName: String): Unit = withPlan { - createViewCommand(viewName, replace = true) + createTempViewCommand(viewName, replace = true, global = false) } - private def createViewCommand(viewName: String, replace: Boolean): CreateViewCommand = { + /** + * Creates a global temporary view using the given name. The lifetime of this + * temporary view is tied to this Spark application. + * + * Global temporary view is cross-session. Its lifetime is the lifetime of the Spark application, + * i.e. it will be automatically dropped when the application terminates. It's tied to a system + * preserved database `_global_temp`, and we must use the qualified name to refer a global temp + * view, e.g. `SELECT * FROM _global_temp.view1`. + * + * @throws AnalysisException if the view name already exists + * + * @group basic + * @since 2.1.0 + */ + @throws[AnalysisException] + def createGlobalTempView(viewName: String): Unit = withPlan { + createTempViewCommand(viewName, replace = false, global = true) + } + + private def createTempViewCommand( + viewName: String, + replace: Boolean, + global: Boolean): CreateViewCommand = { + val viewType = if (global) GlobalTempView else LocalTempView CreateViewCommand( name = sparkSession.sessionState.sqlParser.parseTableIdentifier(viewName), userSpecifiedColumns = Nil, @@ -2467,17 +2526,15 @@ class Dataset[T] private[sql]( child = logicalPlan, allowExisting = false, replace = replace, - isTemporary = true) + viewType = viewType) } /** - * :: Experimental :: * Interface for saving the content of the non-streaming Dataset out into external storage. * * @group basic * @since 1.6.0 */ - @Experimental def write: DataFrameWriter[T] = { if (isStreaming) { logicalPlan.failAnalysis( @@ -2494,6 +2551,7 @@ class Dataset[T] private[sql]( * @since 2.0.0 */ @Experimental + @InterfaceStability.Evolving def writeStream: DataStreamWriter[T] = { if (!isStreaming) { logicalPlan.failAnalysis( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DatasetHolder.scala b/sql/core/src/main/scala/org/apache/spark/sql/DatasetHolder.scala index 47b81c17a31dc..18bccee98f610 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DatasetHolder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DatasetHolder.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql +import org.apache.spark.annotation.InterfaceStability + /** * A container for a [[Dataset]], used for implicit conversions in Scala. 
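To make the local-versus-global distinction documented above concrete, here is a small usage sketch of the new createGlobalTempView API (view name and data are invented; `spark` is an existing SparkSession):

val people = spark.range(5).toDF("id")
people.createGlobalTempView("people")
// As documented above, global temp views are registered under the system database
// `_global_temp` and must be referenced with the qualified name.
spark.sql("SELECT * FROM _global_temp.people").show()
// The view is also visible from other sessions of the same application.
spark.newSession().sql("SELECT count(*) FROM _global_temp.people").show()

By contrast, a view created with createTempView or createOrReplaceTempView disappears together with the session that created it.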
* @@ -27,6 +29,7 @@ package org.apache.spark.sql * * @since 1.6.0 */ +@InterfaceStability.Stable case class DatasetHolder[T] private[sql](private val ds: Dataset[T]) { // This is declared with parentheses to prevent the Scala compiler from treating diff --git a/sql/core/src/main/scala/org/apache/spark/sql/ExperimentalMethods.scala b/sql/core/src/main/scala/org/apache/spark/sql/ExperimentalMethods.scala index a435734b0caef..1e8ba51e59e33 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/ExperimentalMethods.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/ExperimentalMethods.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.{Experimental, InterfaceStability} import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.rules.Rule @@ -33,6 +33,7 @@ import org.apache.spark.sql.catalyst.rules.Rule * @since 1.3.0 */ @Experimental +@InterfaceStability.Unstable class ExperimentalMethods private[sql]() { /** @@ -41,10 +42,8 @@ class ExperimentalMethods private[sql]() { * * @since 1.3.0 */ - @Experimental @volatile var extraStrategies: Seq[Strategy] = Nil - @Experimental @volatile var extraOptimizations: Seq[Rule[LogicalPlan]] = Nil } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/ForeachWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/ForeachWriter.scala index f56b25b5576f1..1163035e315fc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/ForeachWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/ForeachWriter.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.{Experimental, InterfaceStability} import org.apache.spark.sql.streaming.StreamingQuery /** @@ -68,8 +68,11 @@ import org.apache.spark.sql.streaming.StreamingQuery * @since 2.0.0 */ @Experimental +@InterfaceStability.Evolving abstract class ForeachWriter[T] extends Serializable { + // TODO: Move this to org.apache.spark.sql.util or consolidate this with batch API. + /** * Called when starting to process one partition of new data in the executor. The `version` is * for data deduplication when there are failures. 
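As a rough illustration of the ForeachWriter contract described here (open, then process for each record, then close, per partition and version), a minimal sink might look like the sketch below; the class name and the println output are invented:

class ConsoleForeachWriter extends ForeachWriter[String] {
  // Returning false tells the engine to skip this partition/version, which is how
  // already-processed data can be ignored after a failure.
  override def open(partitionId: Long, version: Long): Boolean = true
  override def process(value: String): Unit = println(s"got: $value")
  // errorOrNull is null when the partition finished without error.
  override def close(errorOrNull: Throwable): Unit = ()
}

Such a writer would typically be attached with something like df.writeStream.foreach(new ConsoleForeachWriter).start().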
When recovering from a failure, some data may diff --git a/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala index cea16fba76e47..828eb94efe598 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql import scala.collection.JavaConverters._ -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.{Experimental, InterfaceStability} import org.apache.spark.api.java.function._ import org.apache.spark.sql.catalyst.encoders.{encoderFor, ExpressionEncoder} import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, CreateStruct} @@ -36,6 +36,7 @@ import org.apache.spark.sql.expressions.ReduceAggregator * @since 2.0.0 */ @Experimental +@InterfaceStability.Evolving class KeyValueGroupedDataset[K, V] private[sql]( kEncoder: Encoder[K], vEncoder: Encoder[V], diff --git a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala index 6c3fe07709fa3..f019d1e9daceb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala @@ -20,10 +20,9 @@ package org.apache.spark.sql import scala.collection.JavaConverters._ import scala.language.implicitConversions +import org.apache.spark.annotation.InterfaceStability import org.apache.spark.broadcast.Broadcast -import org.apache.spark.sql.api.r.SQLUtils._ import org.apache.spark.sql.catalyst.analysis.{Star, UnresolvedAlias, UnresolvedAttribute, UnresolvedFunction} -import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, FlatMapGroupsInR, Pivot} @@ -43,6 +42,7 @@ import org.apache.spark.sql.types.StructType * * @since 2.0.0 */ +@InterfaceStability.Stable class RelationalGroupedDataset protected[sql]( df: DataFrame, groupingExprs: Seq[Expression], diff --git a/sql/core/src/main/scala/org/apache/spark/sql/RuntimeConfig.scala b/sql/core/src/main/scala/org/apache/spark/sql/RuntimeConfig.scala index 7e07e0cb84a87..9108d19d0a0c2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/RuntimeConfig.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/RuntimeConfig.scala @@ -17,8 +17,9 @@ package org.apache.spark.sql +import org.apache.spark.annotation.InterfaceStability import org.apache.spark.internal.config.{ConfigEntry, OptionalConfigEntry} -import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.internal.{SQLConf, StaticSQLConf} /** @@ -28,6 +29,7 @@ import org.apache.spark.sql.internal.SQLConf * * @since 2.0.0 */ +@InterfaceStability.Stable class RuntimeConfig private[sql](sqlConf: SQLConf = new SQLConf) { /** @@ -36,6 +38,7 @@ class RuntimeConfig private[sql](sqlConf: SQLConf = new SQLConf) { * @since 2.0.0 */ def set(key: String, value: String): Unit = { + requireNonStaticConf(key) sqlConf.setConfString(key, value) } @@ -45,6 +48,7 @@ class RuntimeConfig private[sql](sqlConf: SQLConf = new SQLConf) { * @since 2.0.0 */ def set(key: String, value: Boolean): Unit = { + requireNonStaticConf(key) set(key, value.toString) } @@ -54,6 +58,7 @@ class RuntimeConfig 
private[sql](sqlConf: SQLConf = new SQLConf) { * @since 2.0.0 */ def set(key: String, value: Long): Unit = { + requireNonStaticConf(key) set(key, value.toString) } @@ -122,6 +127,7 @@ class RuntimeConfig private[sql](sqlConf: SQLConf = new SQLConf) { * @since 2.0.0 */ def unset(key: String): Unit = { + requireNonStaticConf(key) sqlConf.unsetConf(key) } @@ -132,4 +138,9 @@ class RuntimeConfig private[sql](sqlConf: SQLConf = new SQLConf) { sqlConf.contains(key) } + private def requireNonStaticConf(key: String): Unit = { + if (StaticSQLConf.globalConfKeys.contains(key)) { + throw new AnalysisException(s"Cannot modify the value of a static config: $key") + } + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index 2edf2e1972053..3c5cf037c578d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -24,7 +24,7 @@ import scala.collection.immutable import scala.reflect.runtime.universe.TypeTag import org.apache.spark.{SparkConf, SparkContext} -import org.apache.spark.annotation.{DeveloperApi, Experimental} +import org.apache.spark.annotation.{DeveloperApi, Experimental, InterfaceStability} import org.apache.spark.api.java.{JavaRDD, JavaSparkContext} import org.apache.spark.internal.Logging import org.apache.spark.internal.config.ConfigEntry @@ -55,6 +55,7 @@ import org.apache.spark.sql.util.ExecutionListenerManager * @groupname Ungrouped Support functions for language integrated queries * @since 1.0.0 */ +@InterfaceStability.Stable class SQLContext private[sql](val sparkSession: SparkSession) extends Logging with Serializable { @@ -95,6 +96,7 @@ class SQLContext private[sql](val sparkSession: SparkSession) * that listen for execution metrics. 
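The requireNonStaticConf guard added to RuntimeConfig above means session-level options can still be changed through spark.conf, while keys registered in StaticSQLConf are rejected. A hedged sketch of the observable behaviour (`spark` is an existing SparkSession; the second key is assumed to be static based on the StaticSQLConf import seen elsewhere in this patch):

// Runtime (per-session) configuration is still mutable.
spark.conf.set("spark.sql.shuffle.partitions", "400")

// Static configuration is fixed for the lifetime of the application and now fails fast.
try {
  spark.conf.set("spark.sql.catalogImplementation", "hive")
} catch {
  case e: org.apache.spark.sql.AnalysisException =>
    println(e.getMessage) // Cannot modify the value of a static config: ...
}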
*/ @Experimental + @InterfaceStability.Evolving def listenerManager: ExecutionListenerManager = sparkSession.listenerManager /** @@ -166,6 +168,7 @@ class SQLContext private[sql](val sparkSession: SparkSession) */ @Experimental @transient + @InterfaceStability.Unstable def experimental: ExperimentalMethods = sparkSession.experimental /** @@ -261,6 +264,7 @@ class SQLContext private[sql](val sparkSession: SparkSession) * @since 1.3.0 */ @Experimental + @InterfaceStability.Evolving object implicits extends SQLImplicits with Serializable { protected override def _sqlContext: SQLContext = self } @@ -274,6 +278,7 @@ class SQLContext private[sql](val sparkSession: SparkSession) * @since 1.3.0 */ @Experimental + @InterfaceStability.Evolving def createDataFrame[A <: Product : TypeTag](rdd: RDD[A]): DataFrame = { sparkSession.createDataFrame(rdd) } @@ -286,6 +291,7 @@ class SQLContext private[sql](val sparkSession: SparkSession) * @since 1.3.0 */ @Experimental + @InterfaceStability.Evolving def createDataFrame[A <: Product : TypeTag](data: Seq[A]): DataFrame = { sparkSession.createDataFrame(data) } @@ -333,6 +339,7 @@ class SQLContext private[sql](val sparkSession: SparkSession) * @since 1.3.0 */ @DeveloperApi + @InterfaceStability.Evolving def createDataFrame(rowRDD: RDD[Row], schema: StructType): DataFrame = { sparkSession.createDataFrame(rowRDD, schema) } @@ -376,6 +383,7 @@ class SQLContext private[sql](val sparkSession: SparkSession) * @group dataset */ @Experimental + @InterfaceStability.Evolving def createDataset[T : Encoder](data: Seq[T]): Dataset[T] = { sparkSession.createDataset(data) } @@ -413,6 +421,7 @@ class SQLContext private[sql](val sparkSession: SparkSession) * @group dataset */ @Experimental + @InterfaceStability.Evolving def createDataset[T : Encoder](data: java.util.List[T]): Dataset[T] = { sparkSession.createDataset(data) } @@ -436,6 +445,7 @@ class SQLContext private[sql](val sparkSession: SparkSession) * @since 1.3.0 */ @DeveloperApi + @InterfaceStability.Evolving def createDataFrame(rowRDD: JavaRDD[Row], schema: StructType): DataFrame = { sparkSession.createDataFrame(rowRDD, schema) } @@ -450,6 +460,7 @@ class SQLContext private[sql](val sparkSession: SparkSession) * @since 1.6.0 */ @DeveloperApi + @InterfaceStability.Evolving def createDataFrame(rows: java.util.List[Row], schema: StructType): DataFrame = { sparkSession.createDataFrame(rows, schema) } @@ -515,6 +526,7 @@ class SQLContext private[sql](val sparkSession: SparkSession) * @since 2.0.0 */ @Experimental + @InterfaceStability.Evolving def readStream: DataStreamReader = sparkSession.readStream @@ -632,6 +644,7 @@ class SQLContext private[sql](val sparkSession: SparkSession) * @group dataframe */ @Experimental + @InterfaceStability.Evolving def range(end: Long): DataFrame = sparkSession.range(end).toDF() /** @@ -643,6 +656,7 @@ class SQLContext private[sql](val sparkSession: SparkSession) * @group dataframe */ @Experimental + @InterfaceStability.Evolving def range(start: Long, end: Long): DataFrame = sparkSession.range(start, end).toDF() /** @@ -654,6 +668,7 @@ class SQLContext private[sql](val sparkSession: SparkSession) * @group dataframe */ @Experimental + @InterfaceStability.Evolving def range(start: Long, end: Long, step: Long): DataFrame = { sparkSession.range(start, end, step).toDF() } @@ -668,6 +683,7 @@ class SQLContext private[sql](val sparkSession: SparkSession) * @group dataframe */ @Experimental + @InterfaceStability.Evolving def range(start: Long, end: Long, step: Long, numPartitions: Int): DataFrame = 
{ sparkSession.range(start, end, step, numPartitions).toDF() } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala index 440952572d8c4..73d16d8a10fd6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql import scala.language.implicitConversions import scala.reflect.runtime.universe.TypeTag +import org.apache.spark.annotation.InterfaceStability import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder @@ -28,6 +29,7 @@ import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder * * @since 1.6.0 */ +@InterfaceStability.Evolving abstract class SQLImplicits { protected def _sqlContext: SQLContext diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala index 6d7ac0f6c1bb2..137c426b4b88d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala @@ -26,10 +26,9 @@ import scala.reflect.runtime.universe.TypeTag import scala.util.control.NonFatal import org.apache.spark.{SPARK_VERSION, SparkConf, SparkContext} -import org.apache.spark.annotation.{DeveloperApi, Experimental} +import org.apache.spark.annotation.{DeveloperApi, Experimental, InterfaceStability} import org.apache.spark.api.java.JavaRDD import org.apache.spark.internal.Logging -import org.apache.spark.internal.config.CATALOG_IMPLEMENTATION import org.apache.spark.rdd.RDD import org.apache.spark.scheduler.{SparkListener, SparkListenerApplicationEnd} import org.apache.spark.sql.catalog.Catalog @@ -41,6 +40,7 @@ import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.datasources.LogicalRelation import org.apache.spark.sql.execution.ui.SQLListener import org.apache.spark.sql.internal.{CatalogImpl, SessionState, SharedState} +import org.apache.spark.sql.internal.StaticSQLConf.CATALOG_IMPLEMENTATION import org.apache.spark.sql.sources.BaseRelation import org.apache.spark.sql.streaming._ import org.apache.spark.sql.types.{DataType, LongType, StructType} @@ -68,6 +68,7 @@ import org.apache.spark.util.Utils * .getOrCreate() * }}} */ +@InterfaceStability.Stable class SparkSession private( @transient val sparkContext: SparkContext, @transient private val existingSharedState: Option[SharedState]) @@ -137,6 +138,7 @@ class SparkSession private( * @since 2.0.0 */ @Experimental + @InterfaceStability.Evolving def listenerManager: ExecutionListenerManager = sessionState.listenerManager /** @@ -147,6 +149,7 @@ class SparkSession private( * @since 2.0.0 */ @Experimental + @InterfaceStability.Unstable def experimental: ExperimentalMethods = sessionState.experimentalMethods /** @@ -190,6 +193,7 @@ class SparkSession private( * @since 2.0.0 */ @Experimental + @InterfaceStability.Unstable def streams: StreamingQueryManager = sessionState.streamingQueryManager /** @@ -229,6 +233,7 @@ class SparkSession private( * @return 2.0.0 */ @Experimental + @InterfaceStability.Evolving def emptyDataset[T: Encoder]: Dataset[T] = { val encoder = implicitly[Encoder[T]] new Dataset(self, LocalRelation(encoder.schema.toAttributes), encoder) @@ -241,6 +246,7 @@ class SparkSession private( * @since 2.0.0 */ @Experimental + @InterfaceStability.Evolving def createDataFrame[A <: Product : TypeTag](rdd: RDD[A]): DataFrame 
= { SparkSession.setActiveSession(this) val encoder = Encoders.product[A] @@ -254,6 +260,7 @@ class SparkSession private( * @since 2.0.0 */ @Experimental + @InterfaceStability.Evolving def createDataFrame[A <: Product : TypeTag](data: Seq[A]): DataFrame = { SparkSession.setActiveSession(this) val schema = ScalaReflection.schemaFor[A].dataType.asInstanceOf[StructType] @@ -293,6 +300,7 @@ class SparkSession private( * @since 2.0.0 */ @DeveloperApi + @InterfaceStability.Evolving def createDataFrame(rowRDD: RDD[Row], schema: StructType): DataFrame = { createDataFrame(rowRDD, schema, needsConversion = true) } @@ -306,6 +314,7 @@ class SparkSession private( * @since 2.0.0 */ @DeveloperApi + @InterfaceStability.Evolving def createDataFrame(rowRDD: JavaRDD[Row], schema: StructType): DataFrame = { createDataFrame(rowRDD.rdd, schema) } @@ -319,6 +328,7 @@ class SparkSession private( * @since 2.0.0 */ @DeveloperApi + @InterfaceStability.Evolving def createDataFrame(rows: java.util.List[Row], schema: StructType): DataFrame = { Dataset.ofRows(self, LocalRelation.fromExternalRows(schema.toAttributes, rows.asScala)) } @@ -410,6 +420,7 @@ class SparkSession private( * @since 2.0.0 */ @Experimental + @InterfaceStability.Evolving def createDataset[T : Encoder](data: Seq[T]): Dataset[T] = { val enc = encoderFor[T] val attributes = enc.schema.toAttributes @@ -428,6 +439,7 @@ class SparkSession private( * @since 2.0.0 */ @Experimental + @InterfaceStability.Evolving def createDataset[T : Encoder](data: RDD[T]): Dataset[T] = { Dataset[T](self, ExternalRDD(data, self)) } @@ -449,6 +461,7 @@ class SparkSession private( * @since 2.0.0 */ @Experimental + @InterfaceStability.Evolving def createDataset[T : Encoder](data: java.util.List[T]): Dataset[T] = { createDataset(data.asScala) } @@ -461,6 +474,7 @@ class SparkSession private( * @since 2.0.0 */ @Experimental + @InterfaceStability.Evolving def range(end: Long): Dataset[java.lang.Long] = range(0, end) /** @@ -471,6 +485,7 @@ class SparkSession private( * @since 2.0.0 */ @Experimental + @InterfaceStability.Evolving def range(start: Long, end: Long): Dataset[java.lang.Long] = { range(start, end, step = 1, numPartitions = sparkContext.defaultParallelism) } @@ -483,6 +498,7 @@ class SparkSession private( * @since 2.0.0 */ @Experimental + @InterfaceStability.Evolving def range(start: Long, end: Long, step: Long): Dataset[java.lang.Long] = { range(start, end, step, numPartitions = sparkContext.defaultParallelism) } @@ -496,6 +512,7 @@ class SparkSession private( * @since 2.0.0 */ @Experimental + @InterfaceStability.Evolving def range(start: Long, end: Long, step: Long, numPartitions: Int): Dataset[java.lang.Long] = { new Dataset(self, Range(start, end, step, numPartitions), Encoders.LONG) } @@ -596,6 +613,7 @@ class SparkSession private( * @since 2.0.0 */ @Experimental + @InterfaceStability.Evolving def readStream: DataStreamReader = new DataStreamReader(self) @@ -614,6 +632,7 @@ class SparkSession private( * @since 2.0.0 */ @Experimental + @InterfaceStability.Evolving object implicits extends SQLImplicits with Serializable { protected override def _sqlContext: SQLContext = SparkSession.this.sqlContext } @@ -670,11 +689,13 @@ class SparkSession private( } +@InterfaceStability.Stable object SparkSession { /** * Builder for [[SparkSession]]. 
*/ + @InterfaceStability.Stable class Builder extends Logging { private[this] val options = new scala.collection.mutable.HashMap[String, String] @@ -791,7 +812,7 @@ object SparkSession { // Get the session from current thread's active session. var session = activeThreadSession.get() if ((session ne null) && !session.sparkContext.isStopped) { - options.foreach { case (k, v) => session.conf.set(k, v) } + options.foreach { case (k, v) => session.sessionState.conf.setConfString(k, v) } if (options.nonEmpty) { logWarning("Use an existing SparkSession, some configuration may not take effect.") } @@ -803,7 +824,7 @@ object SparkSession { // If the current thread does not have an active session, get it from the global session. session = defaultSession.get() if ((session ne null) && !session.sparkContext.isStopped) { - options.foreach { case (k, v) => session.conf.set(k, v) } + options.foreach { case (k, v) => session.sessionState.conf.setConfString(k, v) } if (options.nonEmpty) { logWarning("Use an existing SparkSession, some configuration may not take effect.") } @@ -829,7 +850,7 @@ object SparkSession { sc } session = new SparkSession(sparkContext) - options.foreach { case (k, v) => session.conf.set(k, v) } + options.foreach { case (k, v) => session.sessionState.conf.setConfString(k, v) } defaultSession.set(session) // Register a successfully instantiated context to the singleton. This should be at the diff --git a/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala b/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala index b006236481a29..617a14793697b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql import scala.reflect.runtime.universe.TypeTag import scala.util.Try +import org.apache.spark.annotation.InterfaceStability import org.apache.spark.internal.Logging import org.apache.spark.sql.api.java._ import org.apache.spark.sql.catalyst.analysis.FunctionRegistry @@ -36,6 +37,7 @@ import org.apache.spark.sql.types.DataType * * @since 1.3.0 */ +@InterfaceStability.Stable class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends Logging { protected[sql] def registerPython(name: String, udf: UserDefinedPythonFunction): Unit = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala index 7d8ea03a27910..9de6510c634b3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala @@ -28,11 +28,11 @@ import org.apache.spark.SparkContext import org.apache.spark.api.java.{JavaRDD, JavaSparkContext} import org.apache.spark.api.r.SerDe import org.apache.spark.broadcast.Broadcast -import org.apache.spark.internal.config.CATALOG_IMPLEMENTATION import org.apache.spark.rdd.RDD import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema import org.apache.spark.sql.execution.command.ShowTablesCommand +import org.apache.spark.sql.internal.StaticSQLConf.CATALOG_IMPLEMENTATION import org.apache.spark.sql.types._ private[sql] object SQLUtils extends Logging { @@ -64,7 +64,7 @@ private[sql] object SQLUtils extends Logging { spark: SparkSession, sparkConfigMap: JMap[Object, Object]): Unit = { for ((name, value) <- sparkConfigMap.asScala) { - spark.conf.set(name.toString, value.toString) + 
spark.sessionState.conf.setConfString(name.toString, value.toString) } for ((name, value) <- sparkConfigMap.asScala) { spark.sparkContext.conf.set(name.toString, value.toString) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/catalog/Catalog.scala b/sql/core/src/main/scala/org/apache/spark/sql/catalog/Catalog.scala index 7f2762c7dac92..18cba8ce28b4d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/catalog/Catalog.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/catalog/Catalog.scala @@ -262,13 +262,36 @@ abstract class Catalog { options: Map[String, String]): DataFrame /** - * Drops the temporary view with the given view name in the catalog. + * Drops the local temporary view with the given view name in the catalog. * If the view has been cached before, then it will also be uncached. * + * Local temporary view is session-scoped. Its lifetime is the lifetime of the session that + * created it, i.e. it will be automatically dropped when the session terminates. It's not + * tied to any databases, i.e. we can't use `db1.view1` to reference a local temporary view. + * + * Note that the return type of this method was Unit in Spark 2.0, but changed to Boolean + * in Spark 2.1. + * * @param viewName the name of the view to be dropped. + * @return true if the view is dropped successfully, false otherwise. * @since 2.0.0 */ - def dropTempView(viewName: String): Unit + def dropTempView(viewName: String): Boolean + + /** + * Drops the global temporary view with the given view name in the catalog. + * If the view has been cached before, then it will also be uncached. + * + * Global temporary view is cross-session. Its lifetime is the lifetime of the Spark application, + * i.e. it will be automatically dropped when the application terminates. It's tied to a system + * preserved database `_global_temp`, and we must use the qualified name to refer to a global temp + * view, e.g. `SELECT * FROM _global_temp.view1`. + * + * @param viewName the name of the view to be dropped. + * @return true if the view is dropped successfully, false otherwise. + * @since 2.1.0 + */ + def dropGlobalTempView(viewName: String): Boolean /** * Returns true if the table is currently cached in-memory.
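Since dropTempView and dropGlobalTempView now return Boolean, callers can tell whether a view was actually removed. A short sketch of the documented behaviour (view names invented; `spark` is an existing SparkSession):

spark.range(3).createOrReplaceTempView("t_local")
assert(spark.catalog.dropTempView("t_local"))    // true: the view existed and was dropped
assert(!spark.catalog.dropTempView("t_local"))   // false: nothing left to drop

spark.range(3).createGlobalTempView("t_global")
assert(spark.catalog.dropGlobalTempView("t_global"))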
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala index 6c4248c60e893..d3a22228623e1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala @@ -32,7 +32,7 @@ object RDDConversions { def productToRowRdd[A <: Product](data: RDD[A], outputTypes: Seq[DataType]): RDD[InternalRow] = { data.mapPartitions { iterator => val numColumns = outputTypes.length - val mutableRow = new GenericMutableRow(numColumns) + val mutableRow = new GenericInternalRow(numColumns) val converters = outputTypes.map(CatalystTypeConverters.createToCatalystConverter) iterator.map { r => var i = 0 @@ -52,7 +52,7 @@ object RDDConversions { def rowToRowRdd(data: RDD[Row], outputTypes: Seq[DataType]): RDD[InternalRow] = { data.mapPartitions { iterator => val numColumns = outputTypes.length - val mutableRow = new GenericMutableRow(numColumns) + val mutableRow = new GenericInternalRow(numColumns) val converters = outputTypes.map(CatalystTypeConverters.createToCatalystConverter) iterator.map { r => var i = 0 diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala index 383b3a233fc27..cb45a6d78b9b6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala @@ -21,15 +21,14 @@ import java.nio.charset.StandardCharsets import java.sql.Timestamp import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{AnalysisException, Row, SparkSession, SQLContext} +import org.apache.spark.sql.{AnalysisException, Row, SparkSession} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.UnsupportedOperationChecker import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, ReturnAnswer} import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.catalyst.util.DateTimeUtils -import org.apache.spark.sql.execution.command.{DescribeTableCommand, ExecutedCommandExec} +import org.apache.spark.sql.execution.command.{DescribeTableCommand, ExecutedCommandExec, ShowTablesCommand} import org.apache.spark.sql.execution.exchange.{EnsureRequirements, ReuseExchange} -import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.{BinaryType, DateType, DecimalType, TimestampType, _} import org.apache.spark.util.Utils @@ -125,6 +124,9 @@ class QueryExecution(val sparkSession: SparkSession, val logical: LogicalPlan) { .mkString("\t") } } + // SHOW TABLES in Hive only output table names, while ours outputs database, table name, isTemp. 
+ case command: ExecutedCommandExec if command.cmd.isInstanceOf[ShowTablesCommand] => + command.executeCollect().map(_.getString(1)) case command: ExecutedCommandExec => command.executeCollect().map(_.getString(0)) case other => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala index 085bb9fc3c6cc..be2eddbb0e423 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala @@ -29,9 +29,9 @@ import org.apache.spark.sql.catalyst.parser._ import org.apache.spark.sql.catalyst.parser.SqlBaseParser._ import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, OneRowRelation, ScriptInputOutputSchema} import org.apache.spark.sql.execution.command._ -import org.apache.spark.sql.execution.datasources.{CreateTable, CreateTempViewUsing, _} +import org.apache.spark.sql.execution.datasources.{CreateTable, _} import org.apache.spark.sql.internal.{HiveSerDe, SQLConf, VariableSubstitution} -import org.apache.spark.sql.types.{DataType, StructType} +import org.apache.spark.sql.types.StructType /** * Concrete parser for Spark SQL statements. @@ -340,7 +340,7 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder { if (provider.toLowerCase == "hive") { throw new AnalysisException("Cannot create hive serde table with CREATE TABLE USING") } - val schema = Option(ctx.colTypeList()).map(createStructType) + val schema = Option(ctx.colTypeList()).map(createSchema) val partitionColumnNames = Option(ctx.partitionColumnNames) .map(visitIdentifierList(_).toArray) @@ -385,7 +385,7 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder { logWarning(s"CREATE TEMPORARY TABLE ... USING ... is deprecated, please use " + "CREATE TEMPORARY VIEW ... USING ... instead") - CreateTempViewUsing(table, schema, replace = true, provider, options) + CreateTempViewUsing(table, schema, replace = true, global = false, provider, options) } else { CreateTable(tableDesc, mode, None) } @@ -399,8 +399,9 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder { ctx: CreateTempViewUsingContext): LogicalPlan = withOrigin(ctx) { CreateTempViewUsing( tableIdent = visitTableIdentifier(ctx.tableIdentifier()), - userSpecifiedSchema = Option(ctx.colTypeList()).map(createStructType), + userSpecifiedSchema = Option(ctx.colTypeList()).map(createSchema), replace = ctx.REPLACE != null, + global = ctx.GLOBAL != null, provider = ctx.tableProvider.qualifiedName.getText, options = Option(ctx.tablePropertyList).map(visitPropertyKeyValues).getOrElse(Map.empty)) } @@ -1269,7 +1270,7 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder { * * For example: * {{{ - * CREATE [OR REPLACE] [TEMPORARY] VIEW [IF NOT EXISTS] [db_name.]view_name + * CREATE [OR REPLACE] [[GLOBAL] TEMPORARY] VIEW [IF NOT EXISTS] [db_name.]view_name * [(column_name [COMMENT column_comment], ...) 
] * [COMMENT view_comment] * [TBLPROPERTIES (property_name = property_value, ...)] @@ -1286,6 +1287,14 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder { } } + val viewType = if (ctx.TEMPORARY == null) { + PersistedView + } else if (ctx.GLOBAL != null) { + GlobalTempView + } else { + LocalTempView + } + CreateViewCommand( name = visitTableIdentifier(ctx.tableIdentifier), userSpecifiedColumns = userSpecifiedColumns, @@ -1295,7 +1304,7 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder { child = plan(ctx.query), allowExisting = ctx.EXISTS != null, replace = ctx.REPLACE != null, - isTemporary = ctx.TEMPORARY != null) + viewType = viewType) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala index f335912ba2c32..7c11fdb9792e8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala @@ -153,7 +153,7 @@ abstract class AggregationIterator( protected def generateProcessRow( expressions: Seq[AggregateExpression], functions: Seq[AggregateFunction], - inputAttributes: Seq[Attribute]): (MutableRow, InternalRow) => Unit = { + inputAttributes: Seq[Attribute]): (InternalRow, InternalRow) => Unit = { val joinedRow = new JoinedRow if (expressions.nonEmpty) { val mergeExpressions = functions.zipWithIndex.flatMap { @@ -168,9 +168,9 @@ abstract class AggregationIterator( case (ae: ImperativeAggregate, i) => expressions(i).mode match { case Partial | Complete => - (buffer: MutableRow, row: InternalRow) => ae.update(buffer, row) + (buffer: InternalRow, row: InternalRow) => ae.update(buffer, row) case PartialMerge | Final => - (buffer: MutableRow, row: InternalRow) => ae.merge(buffer, row) + (buffer: InternalRow, row: InternalRow) => ae.merge(buffer, row) } }.toArray // This projection is used to merge buffer values for all expression-based aggregates. @@ -178,7 +178,7 @@ abstract class AggregationIterator( val updateProjection = newMutableProjection(mergeExpressions, aggregationBufferSchema ++ inputAttributes) - (currentBuffer: MutableRow, row: InternalRow) => { + (currentBuffer: InternalRow, row: InternalRow) => { // Process all expression-based aggregate functions. updateProjection.target(currentBuffer)(joinedRow(currentBuffer, row)) // Process all imperative aggregate functions. @@ -190,11 +190,11 @@ abstract class AggregationIterator( } } else { // Grouping only. - (currentBuffer: MutableRow, row: InternalRow) => {} + (currentBuffer: InternalRow, row: InternalRow) => {} } } - protected val processRow: (MutableRow, InternalRow) => Unit = + protected val processRow: (InternalRow, InternalRow) => Unit = generateProcessRow(aggregateExpressions, aggregateFunctions, inputAttributes) protected val groupingProjection: UnsafeProjection = @@ -202,7 +202,7 @@ abstract class AggregationIterator( protected val groupingAttributes = groupingExpressions.map(_.toAttribute) // Initializing the function used to generate the output row. 
- protected def generateResultProjection(): (UnsafeRow, MutableRow) => UnsafeRow = { + protected def generateResultProjection(): (UnsafeRow, InternalRow) => UnsafeRow = { val joinedRow = new JoinedRow val modes = aggregateExpressions.map(_.mode).distinct val bufferAttributes = aggregateFunctions.flatMap(_.aggBufferAttributes) @@ -211,14 +211,14 @@ abstract class AggregationIterator( case ae: DeclarativeAggregate => ae.evaluateExpression case agg: AggregateFunction => NoOp } - val aggregateResult = new SpecificMutableRow(aggregateAttributes.map(_.dataType)) + val aggregateResult = new SpecificInternalRow(aggregateAttributes.map(_.dataType)) val expressionAggEvalProjection = newMutableProjection(evalExpressions, bufferAttributes) expressionAggEvalProjection.target(aggregateResult) val resultProjection = UnsafeProjection.create(resultExpressions, groupingAttributes ++ aggregateAttributes) - (currentGroupingKey: UnsafeRow, currentBuffer: MutableRow) => { + (currentGroupingKey: UnsafeRow, currentBuffer: InternalRow) => { // Generate results for all expression-based aggregate functions. expressionAggEvalProjection(currentBuffer) // Generate results for all imperative aggregate functions. @@ -244,7 +244,7 @@ abstract class AggregationIterator( } } - (currentGroupingKey: UnsafeRow, currentBuffer: MutableRow) => { + (currentGroupingKey: UnsafeRow, currentBuffer: InternalRow) => { // Serializes the generic object stored in aggregation buffer var i = 0 while (i < typedImperativeAggregates.length) { @@ -256,17 +256,17 @@ abstract class AggregationIterator( } else { // Grouping-only: we only output values based on grouping expressions. val resultProjection = UnsafeProjection.create(resultExpressions, groupingAttributes) - (currentGroupingKey: UnsafeRow, currentBuffer: MutableRow) => { + (currentGroupingKey: UnsafeRow, currentBuffer: InternalRow) => { resultProjection(currentGroupingKey) } } } - protected val generateOutput: (UnsafeRow, MutableRow) => UnsafeRow = + protected val generateOutput: (UnsafeRow, InternalRow) => UnsafeRow = generateResultProjection() /** Initializes buffer values for all aggregate functions. */ - protected def initializeBuffer(buffer: MutableRow): Unit = { + protected def initializeBuffer(buffer: InternalRow): Unit = { expressionAggInitialProjection.target(buffer)(EmptyRow) var i = 0 while (i < allImperativeAggregateFunctions.length) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala index c2b1ef0fe3c2c..bea2dce1a7657 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala @@ -49,11 +49,11 @@ class SortBasedAggregationIterator( * Creates a new aggregation buffer and initializes buffer values * for all aggregate functions. 
*/ - private def newBuffer: MutableRow = { + private def newBuffer: InternalRow = { val bufferSchema = aggregateFunctions.flatMap(_.aggBufferAttributes) val bufferRowSize: Int = bufferSchema.length - val genericMutableBuffer = new GenericMutableRow(bufferRowSize) + val genericMutableBuffer = new GenericInternalRow(bufferRowSize) val useUnsafeBuffer = bufferSchema.map(_.dataType).forall(UnsafeRow.isMutable) val buffer = if (useUnsafeBuffer) { @@ -84,7 +84,7 @@ class SortBasedAggregationIterator( private[this] var sortedInputHasNewGroup: Boolean = false // The aggregation buffer used by the sort-based aggregation. - private[this] val sortBasedAggregationBuffer: MutableRow = newBuffer + private[this] val sortBasedAggregationBuffer: InternalRow = newBuffer // This safe projection is used to turn the input row into safe row. This is necessary // because the input row may be produced by unsafe projection in child operator and all the diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala index 4e072a92cc772..2988161ee5e7b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala @@ -118,7 +118,7 @@ class TungstenAggregationIterator( private def createNewAggregationBuffer(): UnsafeRow = { val bufferSchema = aggregateFunctions.flatMap(_.aggBufferAttributes) val buffer: UnsafeRow = UnsafeProjection.create(bufferSchema.map(_.dataType)) - .apply(new GenericMutableRow(bufferSchema.length)) + .apply(new GenericInternalRow(bufferSchema.length)) // Initialize declarative aggregates' buffer values expressionAggInitialProjection.target(buffer)(EmptyRow) // Initialize imperative aggregates' buffer values @@ -127,7 +127,7 @@ class TungstenAggregationIterator( } // Creates a function used to generate output rows. - override protected def generateResultProjection(): (UnsafeRow, MutableRow) => UnsafeRow = { + override protected def generateResultProjection(): (UnsafeRow, InternalRow) => UnsafeRow = { val modes = aggregateExpressions.map(_.mode).distinct if (modes.nonEmpty && !modes.contains(Final) && !modes.contains(Complete)) { // Fast path for partial aggregation, UnsafeRowJoiner is usually faster than projection @@ -137,7 +137,7 @@ class TungstenAggregationIterator( val bufferSchema = StructType.fromAttributes(bufferAttributes) val unsafeRowJoiner = GenerateUnsafeRowJoiner.create(groupingKeySchema, bufferSchema) - (currentGroupingKey: UnsafeRow, currentBuffer: MutableRow) => { + (currentGroupingKey: UnsafeRow, currentBuffer: InternalRow) => { unsafeRowJoiner.join(currentGroupingKey, currentBuffer.asInstanceOf[UnsafeRow]) } } else { @@ -300,7 +300,7 @@ class TungstenAggregationIterator( private[this] val sortBasedAggregationBuffer: UnsafeRow = createNewAggregationBuffer() // The function used to process rows in a group - private[this] var sortBasedProcessRow: (MutableRow, InternalRow) => Unit = null + private[this] var sortBasedProcessRow: (InternalRow, InternalRow) => Unit = null // Processes rows in the current group. It will stop when it find a new group. 
private def processCurrentSortedGroup(): Unit = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala index 586e1456ac69e..67760f334e406 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.execution.aggregate import org.apache.spark.internal.Logging import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} -import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Expression, MutableRow, _} +import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Expression, _} import org.apache.spark.sql.catalyst.expressions.aggregate.ImperativeAggregate import org.apache.spark.sql.catalyst.expressions.codegen.GenerateMutableProjection import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction} @@ -96,18 +96,18 @@ sealed trait BufferSetterGetterUtils { getters } - def createSetters(schema: StructType): Array[((MutableRow, Int, Any) => Unit)] = { + def createSetters(schema: StructType): Array[((InternalRow, Int, Any) => Unit)] = { val dataTypes = schema.fields.map(_.dataType) - val setters = new Array[(MutableRow, Int, Any) => Unit](dataTypes.length) + val setters = new Array[(InternalRow, Int, Any) => Unit](dataTypes.length) var i = 0 while (i < setters.length) { setters(i) = dataTypes(i) match { case NullType => - (row: MutableRow, ordinal: Int, value: Any) => row.setNullAt(ordinal) + (row: InternalRow, ordinal: Int, value: Any) => row.setNullAt(ordinal) case b: BooleanType => - (row: MutableRow, ordinal: Int, value: Any) => + (row: InternalRow, ordinal: Int, value: Any) => if (value != null) { row.setBoolean(ordinal, value.asInstanceOf[Boolean]) } else { @@ -115,7 +115,7 @@ sealed trait BufferSetterGetterUtils { } case ByteType => - (row: MutableRow, ordinal: Int, value: Any) => + (row: InternalRow, ordinal: Int, value: Any) => if (value != null) { row.setByte(ordinal, value.asInstanceOf[Byte]) } else { @@ -123,7 +123,7 @@ sealed trait BufferSetterGetterUtils { } case ShortType => - (row: MutableRow, ordinal: Int, value: Any) => + (row: InternalRow, ordinal: Int, value: Any) => if (value != null) { row.setShort(ordinal, value.asInstanceOf[Short]) } else { @@ -131,7 +131,7 @@ sealed trait BufferSetterGetterUtils { } case IntegerType => - (row: MutableRow, ordinal: Int, value: Any) => + (row: InternalRow, ordinal: Int, value: Any) => if (value != null) { row.setInt(ordinal, value.asInstanceOf[Int]) } else { @@ -139,7 +139,7 @@ sealed trait BufferSetterGetterUtils { } case LongType => - (row: MutableRow, ordinal: Int, value: Any) => + (row: InternalRow, ordinal: Int, value: Any) => if (value != null) { row.setLong(ordinal, value.asInstanceOf[Long]) } else { @@ -147,7 +147,7 @@ sealed trait BufferSetterGetterUtils { } case FloatType => - (row: MutableRow, ordinal: Int, value: Any) => + (row: InternalRow, ordinal: Int, value: Any) => if (value != null) { row.setFloat(ordinal, value.asInstanceOf[Float]) } else { @@ -155,7 +155,7 @@ sealed trait BufferSetterGetterUtils { } case DoubleType => - (row: MutableRow, ordinal: Int, value: Any) => + (row: InternalRow, ordinal: Int, value: Any) => if (value != null) { row.setDouble(ordinal, value.asInstanceOf[Double]) } else { @@ -164,13 +164,13 @@ sealed trait 
BufferSetterGetterUtils { case dt: DecimalType => val precision = dt.precision - (row: MutableRow, ordinal: Int, value: Any) => + (row: InternalRow, ordinal: Int, value: Any) => // To make it work with UnsafeRow, we cannot use setNullAt. // Please see the comment of UnsafeRow's setDecimal. row.setDecimal(ordinal, value.asInstanceOf[Decimal], precision) case DateType => - (row: MutableRow, ordinal: Int, value: Any) => + (row: InternalRow, ordinal: Int, value: Any) => if (value != null) { row.setInt(ordinal, value.asInstanceOf[Int]) } else { @@ -178,7 +178,7 @@ sealed trait BufferSetterGetterUtils { } case TimestampType => - (row: MutableRow, ordinal: Int, value: Any) => + (row: InternalRow, ordinal: Int, value: Any) => if (value != null) { row.setLong(ordinal, value.asInstanceOf[Long]) } else { @@ -186,7 +186,7 @@ sealed trait BufferSetterGetterUtils { } case other => - (row: MutableRow, ordinal: Int, value: Any) => + (row: InternalRow, ordinal: Int, value: Any) => if (value != null) { row.update(ordinal, value) } else { @@ -209,7 +209,7 @@ private[aggregate] class MutableAggregationBufferImpl( toCatalystConverters: Array[Any => Any], toScalaConverters: Array[Any => Any], bufferOffset: Int, - var underlyingBuffer: MutableRow) + var underlyingBuffer: InternalRow) extends MutableAggregationBuffer with BufferSetterGetterUtils { private[this] val offsets: Array[Int] = { @@ -413,13 +413,13 @@ case class ScalaUDAF( null) } - override def initialize(buffer: MutableRow): Unit = { + override def initialize(buffer: InternalRow): Unit = { mutableAggregateBuffer.underlyingBuffer = buffer udaf.initialize(mutableAggregateBuffer) } - override def update(buffer: MutableRow, input: InternalRow): Unit = { + override def update(buffer: InternalRow, input: InternalRow): Unit = { mutableAggregateBuffer.underlyingBuffer = buffer udaf.update( @@ -427,7 +427,7 @@ case class ScalaUDAF( inputToScalaConverters(inputProjection(input)).asInstanceOf[Row]) } - override def merge(buffer1: MutableRow, buffer2: InternalRow): Unit = { + override def merge(buffer1: InternalRow, buffer2: InternalRow): Unit = { mutableAggregateBuffer.underlyingBuffer = buffer1 inputAggregateBuffer.underlyingInputBuffer = buffer2 diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnAccessor.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnAccessor.scala index 7cde04b62619e..6241b79d9affc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnAccessor.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnAccessor.scala @@ -21,15 +21,16 @@ import java.nio.{ByteBuffer, ByteOrder} import scala.annotation.tailrec -import org.apache.spark.sql.catalyst.expressions.{MutableRow, UnsafeArrayData, UnsafeMapData, UnsafeRow} +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.{UnsafeArrayData, UnsafeMapData, UnsafeRow} import org.apache.spark.sql.execution.columnar.compression.CompressibleColumnAccessor import org.apache.spark.sql.types._ /** * An `Iterator` like trait used to extract values from columnar byte buffer. When a value is * extracted from the buffer, instead of directly returning it, the value is set into some field of - * a [[MutableRow]]. In this way, boxing cost can be avoided by leveraging the setter methods - * for primitive values provided by [[MutableRow]]. + * an [[InternalRow]].
In this way, boxing cost can be avoided by leveraging the setter methods + * for primitive values provided by [[InternalRow]]. */ private[columnar] trait ColumnAccessor { initialize() @@ -38,7 +39,7 @@ private[columnar] trait ColumnAccessor { def hasNext: Boolean - def extractTo(row: MutableRow, ordinal: Int): Unit + def extractTo(row: InternalRow, ordinal: Int): Unit protected def underlyingBuffer: ByteBuffer } @@ -52,11 +53,11 @@ private[columnar] abstract class BasicColumnAccessor[JvmType]( override def hasNext: Boolean = buffer.hasRemaining - override def extractTo(row: MutableRow, ordinal: Int): Unit = { + override def extractTo(row: InternalRow, ordinal: Int): Unit = { extractSingle(row, ordinal) } - def extractSingle(row: MutableRow, ordinal: Int): Unit = { + def extractSingle(row: InternalRow, ordinal: Int): Unit = { columnType.extract(buffer, row, ordinal) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnType.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnType.scala index d27d8c362dd9a..703bde25316df 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnType.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnType.scala @@ -92,7 +92,7 @@ private[columnar] sealed abstract class ColumnType[JvmType] { * `row(ordinal)`. Subclasses should override this method to avoid boxing/unboxing costs whenever * possible. */ - def extract(buffer: ByteBuffer, row: MutableRow, ordinal: Int): Unit = { + def extract(buffer: ByteBuffer, row: InternalRow, ordinal: Int): Unit = { setField(row, ordinal, extract(buffer)) } @@ -125,13 +125,13 @@ private[columnar] sealed abstract class ColumnType[JvmType] { * Sets `row(ordinal)` to `field`. Subclasses should override this method to avoid boxing/unboxing * costs whenever possible. */ - def setField(row: MutableRow, ordinal: Int, value: JvmType): Unit + def setField(row: InternalRow, ordinal: Int, value: JvmType): Unit /** * Copies `from(fromOrdinal)` to `to(toOrdinal)`. Subclasses should override this method to avoid * boxing/unboxing costs whenever possible. 
*/ - def copyField(from: InternalRow, fromOrdinal: Int, to: MutableRow, toOrdinal: Int): Unit = { + def copyField(from: InternalRow, fromOrdinal: Int, to: InternalRow, toOrdinal: Int): Unit = { setField(to, toOrdinal, getField(from, fromOrdinal)) } @@ -149,7 +149,7 @@ private[columnar] object NULL extends ColumnType[Any] { override def defaultSize: Int = 0 override def append(v: Any, buffer: ByteBuffer): Unit = {} override def extract(buffer: ByteBuffer): Any = null - override def setField(row: MutableRow, ordinal: Int, value: Any): Unit = row.setNullAt(ordinal) + override def setField(row: InternalRow, ordinal: Int, value: Any): Unit = row.setNullAt(ordinal) override def getField(row: InternalRow, ordinal: Int): Any = null } @@ -177,18 +177,18 @@ private[columnar] object INT extends NativeColumnType(IntegerType, 4) { ByteBufferHelper.getInt(buffer) } - override def extract(buffer: ByteBuffer, row: MutableRow, ordinal: Int): Unit = { + override def extract(buffer: ByteBuffer, row: InternalRow, ordinal: Int): Unit = { row.setInt(ordinal, ByteBufferHelper.getInt(buffer)) } - override def setField(row: MutableRow, ordinal: Int, value: Int): Unit = { + override def setField(row: InternalRow, ordinal: Int, value: Int): Unit = { row.setInt(ordinal, value) } override def getField(row: InternalRow, ordinal: Int): Int = row.getInt(ordinal) - override def copyField(from: InternalRow, fromOrdinal: Int, to: MutableRow, toOrdinal: Int) { + override def copyField(from: InternalRow, fromOrdinal: Int, to: InternalRow, toOrdinal: Int) { to.setInt(toOrdinal, from.getInt(fromOrdinal)) } } @@ -206,17 +206,17 @@ private[columnar] object LONG extends NativeColumnType(LongType, 8) { ByteBufferHelper.getLong(buffer) } - override def extract(buffer: ByteBuffer, row: MutableRow, ordinal: Int): Unit = { + override def extract(buffer: ByteBuffer, row: InternalRow, ordinal: Int): Unit = { row.setLong(ordinal, ByteBufferHelper.getLong(buffer)) } - override def setField(row: MutableRow, ordinal: Int, value: Long): Unit = { + override def setField(row: InternalRow, ordinal: Int, value: Long): Unit = { row.setLong(ordinal, value) } override def getField(row: InternalRow, ordinal: Int): Long = row.getLong(ordinal) - override def copyField(from: InternalRow, fromOrdinal: Int, to: MutableRow, toOrdinal: Int) { + override def copyField(from: InternalRow, fromOrdinal: Int, to: InternalRow, toOrdinal: Int) { to.setLong(toOrdinal, from.getLong(fromOrdinal)) } } @@ -234,17 +234,17 @@ private[columnar] object FLOAT extends NativeColumnType(FloatType, 4) { ByteBufferHelper.getFloat(buffer) } - override def extract(buffer: ByteBuffer, row: MutableRow, ordinal: Int): Unit = { + override def extract(buffer: ByteBuffer, row: InternalRow, ordinal: Int): Unit = { row.setFloat(ordinal, ByteBufferHelper.getFloat(buffer)) } - override def setField(row: MutableRow, ordinal: Int, value: Float): Unit = { + override def setField(row: InternalRow, ordinal: Int, value: Float): Unit = { row.setFloat(ordinal, value) } override def getField(row: InternalRow, ordinal: Int): Float = row.getFloat(ordinal) - override def copyField(from: InternalRow, fromOrdinal: Int, to: MutableRow, toOrdinal: Int) { + override def copyField(from: InternalRow, fromOrdinal: Int, to: InternalRow, toOrdinal: Int) { to.setFloat(toOrdinal, from.getFloat(fromOrdinal)) } } @@ -262,17 +262,17 @@ private[columnar] object DOUBLE extends NativeColumnType(DoubleType, 8) { ByteBufferHelper.getDouble(buffer) } - override def extract(buffer: ByteBuffer, row: MutableRow, ordinal: Int): 
Unit = { + override def extract(buffer: ByteBuffer, row: InternalRow, ordinal: Int): Unit = { row.setDouble(ordinal, ByteBufferHelper.getDouble(buffer)) } - override def setField(row: MutableRow, ordinal: Int, value: Double): Unit = { + override def setField(row: InternalRow, ordinal: Int, value: Double): Unit = { row.setDouble(ordinal, value) } override def getField(row: InternalRow, ordinal: Int): Double = row.getDouble(ordinal) - override def copyField(from: InternalRow, fromOrdinal: Int, to: MutableRow, toOrdinal: Int) { + override def copyField(from: InternalRow, fromOrdinal: Int, to: InternalRow, toOrdinal: Int) { to.setDouble(toOrdinal, from.getDouble(fromOrdinal)) } } @@ -288,17 +288,17 @@ private[columnar] object BOOLEAN extends NativeColumnType(BooleanType, 1) { override def extract(buffer: ByteBuffer): Boolean = buffer.get() == 1 - override def extract(buffer: ByteBuffer, row: MutableRow, ordinal: Int): Unit = { + override def extract(buffer: ByteBuffer, row: InternalRow, ordinal: Int): Unit = { row.setBoolean(ordinal, buffer.get() == 1) } - override def setField(row: MutableRow, ordinal: Int, value: Boolean): Unit = { + override def setField(row: InternalRow, ordinal: Int, value: Boolean): Unit = { row.setBoolean(ordinal, value) } override def getField(row: InternalRow, ordinal: Int): Boolean = row.getBoolean(ordinal) - override def copyField(from: InternalRow, fromOrdinal: Int, to: MutableRow, toOrdinal: Int) { + override def copyField(from: InternalRow, fromOrdinal: Int, to: InternalRow, toOrdinal: Int) { to.setBoolean(toOrdinal, from.getBoolean(fromOrdinal)) } } @@ -316,17 +316,17 @@ private[columnar] object BYTE extends NativeColumnType(ByteType, 1) { buffer.get() } - override def extract(buffer: ByteBuffer, row: MutableRow, ordinal: Int): Unit = { + override def extract(buffer: ByteBuffer, row: InternalRow, ordinal: Int): Unit = { row.setByte(ordinal, buffer.get()) } - override def setField(row: MutableRow, ordinal: Int, value: Byte): Unit = { + override def setField(row: InternalRow, ordinal: Int, value: Byte): Unit = { row.setByte(ordinal, value) } override def getField(row: InternalRow, ordinal: Int): Byte = row.getByte(ordinal) - override def copyField(from: InternalRow, fromOrdinal: Int, to: MutableRow, toOrdinal: Int) { + override def copyField(from: InternalRow, fromOrdinal: Int, to: InternalRow, toOrdinal: Int) { to.setByte(toOrdinal, from.getByte(fromOrdinal)) } } @@ -344,17 +344,17 @@ private[columnar] object SHORT extends NativeColumnType(ShortType, 2) { buffer.getShort() } - override def extract(buffer: ByteBuffer, row: MutableRow, ordinal: Int): Unit = { + override def extract(buffer: ByteBuffer, row: InternalRow, ordinal: Int): Unit = { row.setShort(ordinal, buffer.getShort()) } - override def setField(row: MutableRow, ordinal: Int, value: Short): Unit = { + override def setField(row: InternalRow, ordinal: Int, value: Short): Unit = { row.setShort(ordinal, value) } override def getField(row: InternalRow, ordinal: Int): Short = row.getShort(ordinal) - override def copyField(from: InternalRow, fromOrdinal: Int, to: MutableRow, toOrdinal: Int) { + override def copyField(from: InternalRow, fromOrdinal: Int, to: InternalRow, toOrdinal: Int) { to.setShort(toOrdinal, from.getShort(fromOrdinal)) } } @@ -366,7 +366,7 @@ private[columnar] object SHORT extends NativeColumnType(ShortType, 2) { private[columnar] trait DirectCopyColumnType[JvmType] extends ColumnType[JvmType] { // copy the bytes from ByteBuffer to UnsafeRow - override def extract(buffer: ByteBuffer, row: 
MutableRow, ordinal: Int): Unit = { + override def extract(buffer: ByteBuffer, row: InternalRow, ordinal: Int): Unit = { if (row.isInstanceOf[MutableUnsafeRow]) { val numBytes = buffer.getInt val cursor = buffer.position() @@ -407,7 +407,7 @@ private[columnar] object STRING UTF8String.fromBytes(buffer.array(), buffer.arrayOffset() + cursor, length) } - override def setField(row: MutableRow, ordinal: Int, value: UTF8String): Unit = { + override def setField(row: InternalRow, ordinal: Int, value: UTF8String): Unit = { if (row.isInstanceOf[MutableUnsafeRow]) { row.asInstanceOf[MutableUnsafeRow].writer.write(ordinal, value) } else { @@ -419,7 +419,7 @@ private[columnar] object STRING row.getUTF8String(ordinal) } - override def copyField(from: InternalRow, fromOrdinal: Int, to: MutableRow, toOrdinal: Int) { + override def copyField(from: InternalRow, fromOrdinal: Int, to: InternalRow, toOrdinal: Int) { setField(to, toOrdinal, getField(from, fromOrdinal)) } @@ -433,7 +433,7 @@ private[columnar] case class COMPACT_DECIMAL(precision: Int, scale: Int) Decimal(ByteBufferHelper.getLong(buffer), precision, scale) } - override def extract(buffer: ByteBuffer, row: MutableRow, ordinal: Int): Unit = { + override def extract(buffer: ByteBuffer, row: InternalRow, ordinal: Int): Unit = { if (row.isInstanceOf[MutableUnsafeRow]) { // copy it as Long row.setLong(ordinal, ByteBufferHelper.getLong(buffer)) @@ -459,11 +459,11 @@ private[columnar] case class COMPACT_DECIMAL(precision: Int, scale: Int) row.getDecimal(ordinal, precision, scale) } - override def setField(row: MutableRow, ordinal: Int, value: Decimal): Unit = { + override def setField(row: InternalRow, ordinal: Int, value: Decimal): Unit = { row.setDecimal(ordinal, value, precision) } - override def copyField(from: InternalRow, fromOrdinal: Int, to: MutableRow, toOrdinal: Int) { + override def copyField(from: InternalRow, fromOrdinal: Int, to: InternalRow, toOrdinal: Int) { setField(to, toOrdinal, getField(from, fromOrdinal)) } } @@ -497,7 +497,7 @@ private[columnar] object BINARY extends ByteArrayColumnType[Array[Byte]](16) { def dataType: DataType = BinaryType - override def setField(row: MutableRow, ordinal: Int, value: Array[Byte]): Unit = { + override def setField(row: InternalRow, ordinal: Int, value: Array[Byte]): Unit = { row.update(ordinal, value) } @@ -522,7 +522,7 @@ private[columnar] case class LARGE_DECIMAL(precision: Int, scale: Int) row.getDecimal(ordinal, precision, scale) } - override def setField(row: MutableRow, ordinal: Int, value: Decimal): Unit = { + override def setField(row: InternalRow, ordinal: Int, value: Decimal): Unit = { row.setDecimal(ordinal, value, precision) } @@ -553,7 +553,7 @@ private[columnar] case class STRUCT(dataType: StructType) override def defaultSize: Int = 20 - override def setField(row: MutableRow, ordinal: Int, value: UnsafeRow): Unit = { + override def setField(row: InternalRow, ordinal: Int, value: UnsafeRow): Unit = { row.update(ordinal, value) } @@ -591,7 +591,7 @@ private[columnar] case class ARRAY(dataType: ArrayType) override def defaultSize: Int = 28 - override def setField(row: MutableRow, ordinal: Int, value: UnsafeArrayData): Unit = { + override def setField(row: InternalRow, ordinal: Int, value: UnsafeArrayData): Unit = { row.update(ordinal, value) } @@ -630,7 +630,7 @@ private[columnar] case class MAP(dataType: MapType) override def defaultSize: Int = 68 - override def setField(row: MutableRow, ordinal: Int, value: UnsafeMapData): Unit = { + override def setField(row: InternalRow, ordinal: 
Int, value: UnsafeMapData): Unit = { row.update(ordinal, value) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala index 96bd338f092e5..14024d6c10558 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala @@ -36,8 +36,7 @@ abstract class ColumnarIterator extends Iterator[InternalRow] { * * WARNING: These setter MUST be called in increasing order of ordinals. */ -class MutableUnsafeRow(val writer: UnsafeRowWriter) extends GenericMutableRow(null) { - +class MutableUnsafeRow(val writer: UnsafeRowWriter) extends BaseGenericInternalRow { override def isNullAt(i: Int): Boolean = writer.isNullAt(i) override def setNullAt(i: Int): Unit = writer.setNullAt(i) @@ -55,6 +54,9 @@ class MutableUnsafeRow(val writer: UnsafeRowWriter) extends GenericMutableRow(nu override def update(i: Int, v: Any): Unit = throw new UnsupportedOperationException // all other methods inherited from GenericMutableRow are not need + override protected def genericGet(ordinal: Int): Any = throw new UnsupportedOperationException + override def numFields: Int = throw new UnsupportedOperationException + override def copy(): InternalRow = throw new UnsupportedOperationException } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/NullableColumnAccessor.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/NullableColumnAccessor.scala index 2465633162c4e..2f09757aa341c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/NullableColumnAccessor.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/NullableColumnAccessor.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.execution.columnar import java.nio.{ByteBuffer, ByteOrder} -import org.apache.spark.sql.catalyst.expressions.MutableRow +import org.apache.spark.sql.catalyst.InternalRow private[columnar] trait NullableColumnAccessor extends ColumnAccessor { private var nullsBuffer: ByteBuffer = _ @@ -39,7 +39,7 @@ private[columnar] trait NullableColumnAccessor extends ColumnAccessor { super.initialize() } - abstract override def extractTo(row: MutableRow, ordinal: Int): Unit = { + abstract override def extractTo(row: InternalRow, ordinal: Int): Unit = { if (pos == nextNullIndex) { seenNulls += 1 diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/compression/CompressibleColumnAccessor.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/compression/CompressibleColumnAccessor.scala index 6579b5068e65a..e1d13ad0e94e5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/compression/CompressibleColumnAccessor.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/compression/CompressibleColumnAccessor.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.execution.columnar.compression -import org.apache.spark.sql.catalyst.expressions.MutableRow +import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.execution.columnar.{ColumnAccessor, NativeColumnAccessor} import org.apache.spark.sql.types.AtomicType @@ -33,7 +33,7 @@ private[columnar] trait CompressibleColumnAccessor[T <: AtomicType] extends Colu abstract override def hasNext: Boolean = super.hasNext || decoder.hasNext - override def 
extractSingle(row: MutableRow, ordinal: Int): Unit = { + override def extractSingle(row: InternalRow, ordinal: Int): Unit = { decoder.next(row, ordinal) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/compression/CompressionScheme.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/compression/CompressionScheme.scala index b90d00b15b180..6e4f1c5b80684 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/compression/CompressionScheme.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/compression/CompressionScheme.scala @@ -20,7 +20,6 @@ package org.apache.spark.sql.execution.columnar.compression import java.nio.{ByteBuffer, ByteOrder} import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.MutableRow import org.apache.spark.sql.execution.columnar.{ColumnType, NativeColumnType} import org.apache.spark.sql.types.AtomicType @@ -39,7 +38,7 @@ private[columnar] trait Encoder[T <: AtomicType] { } private[columnar] trait Decoder[T <: AtomicType] { - def next(row: MutableRow, ordinal: Int): Unit + def next(row: InternalRow, ordinal: Int): Unit def hasNext: Boolean } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/compression/compressionSchemes.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/compression/compressionSchemes.scala index 941f03b745a07..ee99c90a751d9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/compression/compressionSchemes.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/compression/compressionSchemes.scala @@ -22,7 +22,7 @@ import java.nio.ByteBuffer import scala.collection.mutable import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.{MutableRow, SpecificMutableRow} +import org.apache.spark.sql.catalyst.expressions.SpecificInternalRow import org.apache.spark.sql.execution.columnar._ import org.apache.spark.sql.types._ @@ -56,7 +56,7 @@ private[columnar] case object PassThrough extends CompressionScheme { class Decoder[T <: AtomicType](buffer: ByteBuffer, columnType: NativeColumnType[T]) extends compression.Decoder[T] { - override def next(row: MutableRow, ordinal: Int): Unit = { + override def next(row: InternalRow, ordinal: Int): Unit = { columnType.extract(buffer, row, ordinal) } @@ -86,7 +86,7 @@ private[columnar] case object RunLengthEncoding extends CompressionScheme { private var _compressedSize = 0 // Using `MutableRow` to store the last value to avoid boxing/unboxing cost. 
- private val lastValue = new SpecificMutableRow(Seq(columnType.dataType)) + private val lastValue = new SpecificInternalRow(Seq(columnType.dataType)) private var lastRun = 0 override def uncompressedSize: Int = _uncompressedSize @@ -117,9 +117,9 @@ private[columnar] case object RunLengthEncoding extends CompressionScheme { to.putInt(RunLengthEncoding.typeId) if (from.hasRemaining) { - val currentValue = new SpecificMutableRow(Seq(columnType.dataType)) + val currentValue = new SpecificInternalRow(Seq(columnType.dataType)) var currentRun = 1 - val value = new SpecificMutableRow(Seq(columnType.dataType)) + val value = new SpecificInternalRow(Seq(columnType.dataType)) columnType.extract(from, currentValue, 0) @@ -156,7 +156,7 @@ private[columnar] case object RunLengthEncoding extends CompressionScheme { private var valueCount = 0 private var currentValue: T#InternalType = _ - override def next(row: MutableRow, ordinal: Int): Unit = { + override def next(row: InternalRow, ordinal: Int): Unit = { if (valueCount == run) { currentValue = columnType.extract(buffer) run = ByteBufferHelper.getInt(buffer) @@ -273,7 +273,7 @@ private[columnar] case object DictionaryEncoding extends CompressionScheme { Array.fill[Any](elementNum)(columnType.extract(buffer).asInstanceOf[Any]) } - override def next(row: MutableRow, ordinal: Int): Unit = { + override def next(row: InternalRow, ordinal: Int): Unit = { columnType.setField(row, ordinal, dictionary(buffer.getShort()).asInstanceOf[T#InternalType]) } @@ -356,7 +356,7 @@ private[columnar] case object BooleanBitSet extends CompressionScheme { private var visited: Int = 0 - override def next(row: MutableRow, ordinal: Int): Unit = { + override def next(row: InternalRow, ordinal: Int): Unit = { val bit = visited % BITS_PER_LONG visited += 1 @@ -443,7 +443,7 @@ private[columnar] case object IntDelta extends CompressionScheme { override def hasNext: Boolean = buffer.hasRemaining - override def next(row: MutableRow, ordinal: Int): Unit = { + override def next(row: InternalRow, ordinal: Int): Unit = { val delta = buffer.get() prev = if (delta > Byte.MinValue) prev + delta else ByteBufferHelper.getInt(buffer) row.setInt(ordinal, prev) @@ -523,7 +523,7 @@ private[columnar] case object LongDelta extends CompressionScheme { override def hasNext: Boolean = buffer.hasRemaining - override def next(row: MutableRow, ordinal: Int): Unit = { + override def next(row: InternalRow, ordinal: Int): Unit = { val delta = buffer.get() prev = if (delta > Byte.MinValue) prev + delta else ByteBufferHelper.getLong(buffer) row.setLong(ordinal, prev) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala index 01ac89868d100..45fa293e58951 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala @@ -183,17 +183,20 @@ case class DropTableCommand( override def run(sparkSession: SparkSession): Seq[Row] = { val catalog = sparkSession.sessionState.catalog - // If the command DROP VIEW is to drop a table or DROP TABLE is to drop a view - // issue an exception. - catalog.getTableMetadataOption(tableName).map(_.tableType match { - case CatalogTableType.VIEW if !isView => - throw new AnalysisException( - "Cannot drop a view with DROP TABLE. Please use DROP VIEW instead") - case o if o != CatalogTableType.VIEW && isView => - throw new AnalysisException( - s"Cannot drop a table with DROP VIEW. 
Please use DROP TABLE instead") - case _ => - }) + + if (!catalog.isTemporaryTable(tableName) && catalog.tableExists(tableName)) { + // If the command DROP VIEW is to drop a table or DROP TABLE is to drop a view + // issue an exception. + catalog.getTableMetadata(tableName).tableType match { + case CatalogTableType.VIEW if !isView => + throw new AnalysisException( + "Cannot drop a view with DROP TABLE. Please use DROP VIEW instead") + case o if o != CatalogTableType.VIEW && isView => + throw new AnalysisException( + s"Cannot drop a table with DROP VIEW. Please use DROP TABLE instead") + case _ => + } + } try { sparkSession.sharedState.cacheManager.uncacheQuery( sparkSession.table(tableName.quotedString)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala index 08de6cd4242c5..424ef58d76c5e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala @@ -579,9 +579,10 @@ case class ShowTablesCommand( databaseName: Option[String], tableIdentifierPattern: Option[String]) extends RunnableCommand { - // The result of SHOW TABLES has two columns, tableName and isTemporary. + // The result of SHOW TABLES has three columns: database, tableName and isTemporary. override val output: Seq[Attribute] = { - AttributeReference("tableName", StringType, nullable = false)() :: + AttributeReference("database", StringType, nullable = false)() :: + AttributeReference("tableName", StringType, nullable = false)() :: AttributeReference("isTemporary", BooleanType, nullable = false)() :: Nil } @@ -592,9 +593,9 @@ case class ShowTablesCommand( val db = databaseName.getOrElse(catalog.getCurrentDatabase) val tables = tableIdentifierPattern.map(catalog.listTables(db, _)).getOrElse(catalog.listTables(db)) - tables.map { t => - val isTemp = t.database.isEmpty - Row(t.table, isTemp) + tables.map { tableIdent => + val isTemp = catalog.isTemporaryTable(tableIdent) + Row(tableIdent.database.getOrElse(""), tableIdent.table, isTemp) } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/views.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/views.scala index 15340ee921f68..bbcd9c4ef564c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/views.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/views.scala @@ -19,13 +19,46 @@ package org.apache.spark.sql.execution.command import scala.util.control.NonFatal -import org.apache.spark.sql.{AnalysisException, Row, SparkSession} +import org.apache.spark.sql.{AnalysisException, Dataset, Row, SparkSession} import org.apache.spark.sql.catalyst.{SQLBuilder, TableIdentifier} import org.apache.spark.sql.catalyst.catalog.{CatalogStorageFormat, CatalogTable, CatalogTableType} import org.apache.spark.sql.catalyst.expressions.Alias import org.apache.spark.sql.catalyst.plans.QueryPlan import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project} -import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.execution.datasources.{DataSource, LogicalRelation} +import org.apache.spark.sql.types.{MetadataBuilder, StructType} + + +/** + * ViewType is used to specify the expected view type when we want to create or replace a view in + * [[CreateViewCommand]]. + */ +sealed trait ViewType + +/** + * LocalTempView means session-scoped local temporary views. 
Their lifetime is the lifetime of the + * session that created them, i.e. they will be automatically dropped when the session terminates. They're + * not tied to any database, i.e. we can't use `db1.view1` to reference a local temporary view. + */ +object LocalTempView extends ViewType + +/** + * GlobalTempView means cross-session global temporary views. Their lifetime is the lifetime of the + * Spark application, i.e. they will be automatically dropped when the application terminates. They're + * tied to a system-preserved database `_global_temp`, and we must use the qualified name to refer to a + * global temp view, e.g. SELECT * FROM _global_temp.view1. + */ +object GlobalTempView extends ViewType + +/** + * PersistedView means cross-session persisted views. Persisted views stay until they are + * explicitly dropped by user command. They are always tied to a database, defaulting to the current + * database if not specified. + * + * Note that an existing persisted view with the same name is not visible to the current session + * while a local temporary view with that name exists, unless the view name is qualified by its database. + */ +object PersistedView extends ViewType /** @@ -46,10 +79,7 @@ import org.apache.spark.sql.types.StructType * already exists, throws analysis exception. * @param replace if true, and if the view already exists, updates it; if false, and if the view * already exists, throws analysis exception. - * @param isTemporary if true, the view is created as a temporary view. Temporary views are dropped - * at the end of current Spark session. Existing permanent relations with the same - * name are not visible to the current session while the temporary view exists, - * unless they are specified with full qualified table name with database prefix. + * @param viewType the expected view type to be created with this command. */ case class CreateViewCommand( name: TableIdentifier, @@ -60,20 +90,21 @@ case class CreateViewCommand( child: LogicalPlan, allowExisting: Boolean, replace: Boolean, - isTemporary: Boolean) + viewType: ViewType) extends RunnableCommand { override protected def innerChildren: Seq[QueryPlan[_]] = Seq(child) - if (!isTemporary) { - require(originalText.isDefined, - "The table to created with CREATE VIEW must have 'originalText'.") + if (viewType == PersistedView) { + require(originalText.isDefined, "'originalText' must be provided to create permanent view") } if (allowExisting && replace) { throw new AnalysisException("CREATE VIEW with both IF NOT EXISTS and REPLACE is not allowed.") } + private def isTemporary = viewType == LocalTempView || viewType == GlobalTempView + // Disallows 'CREATE TEMPORARY VIEW IF NOT EXISTS' to be consistent with 'CREATE TEMPORARY TABLE' if (allowExisting && isTemporary) { throw new AnalysisException( @@ -99,72 +130,53 @@ case class CreateViewCommand( s"(num: `${analyzedPlan.output.length}`) does not match the number of column names " + s"specified by CREATE VIEW (num: `${userSpecifiedColumns.length}`).") } - val sessionState = sparkSession.sessionState - - if (isTemporary) { - createTemporaryView(sparkSession, analyzedPlan) - } else { - // Adds default database for permanent table if it doesn't exist, so that tableExists() - // only check permanent tables.
- val database = name.database.getOrElse(sessionState.catalog.getCurrentDatabase) - val qualifiedName = name.copy(database = Option(database)) - - if (sessionState.catalog.tableExists(qualifiedName)) { - val tableMetadata = sessionState.catalog.getTableMetadata(qualifiedName) - if (allowExisting) { - // Handles `CREATE VIEW IF NOT EXISTS v0 AS SELECT ...`. Does nothing when the target view - // already exists. - } else if (tableMetadata.tableType != CatalogTableType.VIEW) { - throw new AnalysisException(s"$qualifiedName is not a view") - } else if (replace) { - // Handles `CREATE OR REPLACE VIEW v0 AS SELECT ...` - sessionState.catalog.alterTable(prepareTable(sparkSession, analyzedPlan)) - } else { - // Handles `CREATE VIEW v0 AS SELECT ...`. Throws exception when the target view already - // exists. - throw new AnalysisException( - s"View $qualifiedName already exists. If you want to update the view definition, " + - "please use ALTER VIEW AS or CREATE OR REPLACE VIEW AS") - } - } else { - // Create the view if it doesn't exist. - sessionState.catalog.createTable( - prepareTable(sparkSession, analyzedPlan), ignoreIfExists = false) - } - } - Seq.empty[Row] - } - - private def createTemporaryView(sparkSession: SparkSession, analyzedPlan: LogicalPlan): Unit = { - val catalog = sparkSession.sessionState.catalog - // Projects column names to alias names - val logicalPlan = if (userSpecifiedColumns.isEmpty) { + val aliasedPlan = if (userSpecifiedColumns.isEmpty) { analyzedPlan } else { val projectList = analyzedPlan.output.zip(userSpecifiedColumns).map { - case (attr, (colName, _)) => Alias(attr, colName)() + case (attr, (colName, None)) => Alias(attr, colName)() + case (attr, (colName, Some(colComment))) => + val meta = new MetadataBuilder().putString("comment", colComment).build() + Alias(attr, colName)(explicitMetadata = Some(meta)) } sparkSession.sessionState.executePlan(Project(projectList, analyzedPlan)).analyzed } - catalog.createTempView(name.table, logicalPlan, replace) + val catalog = sparkSession.sessionState.catalog + if (viewType == LocalTempView) { + catalog.createTempView(name.table, aliasedPlan, overrideIfExists = replace) + } else if (viewType == GlobalTempView) { + catalog.createGlobalTempView(name.table, aliasedPlan, overrideIfExists = replace) + } else if (catalog.tableExists(name)) { + val tableMetadata = catalog.getTableMetadata(name) + if (allowExisting) { + // Handles `CREATE VIEW IF NOT EXISTS v0 AS SELECT ...`. Does nothing when the target view + // already exists. + } else if (tableMetadata.tableType != CatalogTableType.VIEW) { + throw new AnalysisException(s"$name is not a view") + } else if (replace) { + // Handles `CREATE OR REPLACE VIEW v0 AS SELECT ...` + catalog.alterTable(prepareTable(sparkSession, aliasedPlan)) + } else { + // Handles `CREATE VIEW v0 AS SELECT ...`. Throws exception when the target view already + // exists. + throw new AnalysisException( + s"View $name already exists. If you want to update the view definition, " + + "please use ALTER VIEW AS or CREATE OR REPLACE VIEW AS") + } + } else { + // Create the view if it doesn't exist. + catalog.createTable(prepareTable(sparkSession, aliasedPlan), ignoreIfExists = false) + } + Seq.empty[Row] } /** * Returns a [[CatalogTable]] that can be used to save in the catalog. This comment canonicalize * SQL based on the analyzed plan, and also creates the proper schema for the view. 
*/ - private def prepareTable(sparkSession: SparkSession, analyzedPlan: LogicalPlan): CatalogTable = { - val aliasedPlan = if (userSpecifiedColumns.isEmpty) { - analyzedPlan - } else { - val projectList = analyzedPlan.output.zip(userSpecifiedColumns).map { - case (attr, (colName, _)) => Alias(attr, colName)() - } - sparkSession.sessionState.executePlan(Project(projectList, analyzedPlan)).analyzed - } - + private def prepareTable(sparkSession: SparkSession, aliasedPlan: LogicalPlan): CatalogTable = { val viewSQL: String = new SQLBuilder(aliasedPlan).toSQL // Validate the view SQL - make sure we can parse it and analyze it. @@ -176,19 +188,11 @@ case class CreateViewCommand( throw new RuntimeException(s"Failed to analyze the canonicalized SQL: $viewSQL", e) } - val viewSchema = if (userSpecifiedColumns.isEmpty) { - aliasedPlan.schema - } else { - StructType(aliasedPlan.schema.zip(userSpecifiedColumns).map { - case (field, (_, comment)) => comment.map(field.withComment).getOrElse(field) - }) - } - CatalogTable( identifier = name, tableType = CatalogTableType.VIEW, storage = CatalogStorageFormat.empty, - schema = viewSchema, + schema = aliasedPlan.schema, properties = properties, viewOriginalText = originalText, viewText = Some(viewSQL), @@ -222,8 +226,8 @@ case class AlterViewAsCommand( qe.assertAnalyzed() val analyzedPlan = qe.analyzed - if (session.sessionState.catalog.isTemporaryTable(name)) { - session.sessionState.catalog.createTempView(name.table, analyzedPlan, overrideIfExists = true) + if (session.sessionState.catalog.alterTempViewDefinition(name, analyzedPlan)) { + // a local/global temp view has been altered, we are done. } else { alterPermanentView(session, analyzedPlan) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala index 693b4c4d0e5e9..6f9ed50a02b09 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala @@ -273,7 +273,7 @@ object DataSourceStrategy extends Strategy with Logging { // Get the bucket ID based on the bucketing values. // Restriction: Bucket pruning works iff the bucketing column has one and only one column. 
def getBucketId(bucketColumn: Attribute, numBuckets: Int, value: Any): Int = { - val mutableRow = new SpecificMutableRow(Seq(bucketColumn.dataType)) + val mutableRow = new SpecificInternalRow(Seq(bucketColumn.dataType)) mutableRow(0) = Cast(Literal(value), bucketColumn.dataType).eval(null) val bucketIdGeneration = UnsafeProjection.create( HashPartitioning(bucketColumn :: Nil, numBuckets).partitionIdExpression :: Nil, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileScanRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileScanRDD.scala index c66da3a83198d..89944570df662 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileScanRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileScanRDD.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.execution.datasources +import java.io.IOException + import scala.collection.mutable import org.apache.spark.{Partition => RDDPartition, TaskContext} @@ -25,6 +27,7 @@ import org.apache.spark.rdd.{InputFileNameHolder, RDD} import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.execution.vectorized.ColumnarBatch +import org.apache.spark.util.NextIterator /** * A part (i.e. "block") of a single file that should be read, along with partition column values @@ -62,6 +65,8 @@ class FileScanRDD( @transient val filePartitions: Seq[FilePartition]) extends RDD[InternalRow](sparkSession.sparkContext, Nil) { + private val ignoreCorruptFiles = sparkSession.sessionState.conf.ignoreCorruptFiles + override def compute(split: RDDPartition, context: TaskContext): Iterator[InternalRow] = { val iterator = new Iterator[Object] with AutoCloseable { private val inputMetrics = context.taskMetrics().inputMetrics @@ -119,7 +124,30 @@ class FileScanRDD( InputFileNameHolder.setInputFileName(currentFile.filePath) try { - currentIterator = readFunction(currentFile) + if (ignoreCorruptFiles) { + currentIterator = new NextIterator[Object] { + private val internalIter = readFunction(currentFile) + + override def getNext(): AnyRef = { + try { + if (internalIter.hasNext) { + internalIter.next() + } else { + finished = true + null + } + } catch { + case e: IOException => + finished = true + null + } + } + + override def close(): Unit = {} + } + } else { + currentIterator = readFunction(currentFile) + } } catch { case e: java.io.FileNotFoundException => throw new java.io.FileNotFoundException( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala index 4e662a52a7bb7..a3691158ee758 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala @@ -59,14 +59,8 @@ class CSVFileFormat extends TextBasedFileFormat with DataSourceRegister { val rdd = baseRdd(sparkSession, csvOptions, paths) val firstLine = findFirstLine(csvOptions, rdd) val firstRow = new CsvReader(csvOptions).parseLine(firstLine) - - val header = if (csvOptions.headerFlag) { - firstRow.zipWithIndex.map { case (value, index) => - if (value == null || value.isEmpty || value == csvOptions.nullValue) s"_c$index" else value - } - } else { - firstRow.zipWithIndex.map { case (value, index) => s"_c$index" } - } + val caseSensitive = 
sparkSession.sessionState.conf.caseSensitiveAnalysis + val header = makeSafeHeader(firstRow, csvOptions, caseSensitive) val parsedRdd = tokenRdd(sparkSession, csvOptions, header, paths) val schema = if (csvOptions.inferSchemaFlag) { @@ -74,13 +68,51 @@ class CSVFileFormat extends TextBasedFileFormat with DataSourceRegister { } else { // By default fields are assumed to be StringType val schemaFields = header.map { fieldName => - StructField(fieldName.toString, StringType, nullable = true) + StructField(fieldName, StringType, nullable = true) } StructType(schemaFields) } Some(schema) } + /** + * Generates a header from the given row which is null-safe and duplicate-safe. + */ + private def makeSafeHeader( + row: Array[String], + options: CSVOptions, + caseSensitive: Boolean): Array[String] = { + if (options.headerFlag) { + val duplicates = { + val headerNames = row.filter(_ != null) + .map(name => if (caseSensitive) name else name.toLowerCase) + headerNames.diff(headerNames.distinct).distinct + } + + row.zipWithIndex.map { case (value, index) => + if (value == null || value.isEmpty || value == options.nullValue) { + // When there are empty strings or the values set in `nullValue`, put the + // index as the suffix. + s"_c$index" + } else if (!caseSensitive && duplicates.contains(value.toLowerCase)) { + // When there are case-insensitive duplicates, put the index as the suffix. + s"$value$index" + } else if (duplicates.contains(value)) { + // When there are duplicates, put the index as the suffix. + s"$value$index" + } else { + value + } + } + } else { + row.zipWithIndex.map { case (_, index) => + // Uses default column names, "_c#" where # is its position of fields + // when header option is disabled. + s"_c$index" + } + } + } + override def prepareWrite( sparkSession: SparkSession, job: Job, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala index 33b170bc31f62..55cb26d6513af 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala @@ -29,7 +29,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.rdd.RDD import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.GenericMutableRow +import org.apache.spark.sql.catalyst.expressions.GenericInternalRow import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.execution.datasources.{OutputWriter, OutputWriterFactory, PartitionedFile, WriterContainer} import org.apache.spark.sql.types._ @@ -88,7 +88,7 @@ object CSVRelation extends Logging { case (field, index) => safeRequiredIndices(safeRequiredFields.indexOf(field)) = index } val requiredSize = requiredFields.length - val row = new GenericMutableRow(requiredSize) + val row = new GenericInternalRow(requiredSize) (tokens: Array[String], numMalformedRows) => { if (params.dropMalformed && schemaFields.length != tokens.length) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala index fa95af2648cf9..59fb48ffea598 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala @@ -40,16 
+40,20 @@ case class CreateTable( override def innerChildren: Seq[QueryPlan[_]] = query.toSeq } +/** + * Create or replace a local/global temporary view with the given data source. + */ case class CreateTempViewUsing( tableIdent: TableIdentifier, userSpecifiedSchema: Option[StructType], replace: Boolean, + global: Boolean, provider: String, options: Map[String, String]) extends RunnableCommand { if (tableIdent.database.isDefined) { throw new AnalysisException( - s"Temporary table '$tableIdent' should not have specified a database") + s"Temporary view '$tableIdent' should not have specified a database") } def run(sparkSession: SparkSession): Seq[Row] = { @@ -58,10 +62,16 @@ case class CreateTempViewUsing( userSpecifiedSchema = userSpecifiedSchema, className = provider, options = options) - sparkSession.sessionState.catalog.createTempView( - tableIdent.table, - Dataset.ofRows(sparkSession, LogicalRelation(dataSource.resolveRelation())).logicalPlan, - replace) + + val catalog = sparkSession.sessionState.catalog + val viewDefinition = Dataset.ofRows( + sparkSession, LogicalRelation(dataSource.resolveRelation())).logicalPlan + + if (global) { + catalog.createGlobalTempView(tableIdent.table, viewDefinition, replace) + } else { + catalog.createTempView(tableIdent.table, viewDefinition, replace) + } Seq.empty[Row] } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala index bcf65e53afa73..fcd7409159def 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala @@ -17,6 +17,11 @@ package org.apache.spark.sql.execution.datasources.jdbc +import java.sql.{Connection, DriverManager} +import java.util.Properties + +import scala.collection.mutable.ArrayBuffer + /** * Options for the JDBC data source. */ @@ -24,40 +29,115 @@ class JDBCOptions( @transient private val parameters: Map[String, String]) extends Serializable { + import JDBCOptions._ + + def this(url: String, table: String, parameters: Map[String, String]) = { + this(parameters ++ Map( + JDBCOptions.JDBC_URL -> url, + JDBCOptions.JDBC_TABLE_NAME -> table)) + } + + val asConnectionProperties: Properties = { + val properties = new Properties() + // We should avoid passing the options into properties. See SPARK-17776.
+ parameters.filterKeys(!jdbcOptionNames.contains(_)) + .foreach { case (k, v) => properties.setProperty(k, v) } + properties + } + // ------------------------------------------------------------ // Required parameters // ------------------------------------------------------------ - require(parameters.isDefinedAt("url"), "Option 'url' is required.") - require(parameters.isDefinedAt("dbtable"), "Option 'dbtable' is required.") + require(parameters.isDefinedAt(JDBC_URL), s"Option '$JDBC_URL' is required.") + require(parameters.isDefinedAt(JDBC_TABLE_NAME), s"Option '$JDBC_TABLE_NAME' is required.") // a JDBC URL - val url = parameters("url") + val url = parameters(JDBC_URL) // name of table - val table = parameters("dbtable") + val table = parameters(JDBC_TABLE_NAME) + + // ------------------------------------------------------------ + // Optional parameters + // ------------------------------------------------------------ + val driverClass = { + val userSpecifiedDriverClass = parameters.get(JDBC_DRIVER_CLASS) + userSpecifiedDriverClass.foreach(DriverRegistry.register) + + // Performing this part of the logic on the driver guards against the corner-case where the + // driver returned for a URL is different on the driver and executors due to classpath + // differences. + userSpecifiedDriverClass.getOrElse { + DriverManager.getDriver(url).getClass.getCanonicalName + } + } // ------------------------------------------------------------ - // Optional parameter list + // Optional parameters only for reading // ------------------------------------------------------------ // the column used to partition - val partitionColumn = parameters.getOrElse("partitionColumn", null) + val partitionColumn = parameters.getOrElse(JDBC_PARTITION_COLUMN, null) // the lower bound of partition column - val lowerBound = parameters.getOrElse("lowerBound", null) + val lowerBound = parameters.getOrElse(JDBC_LOWER_BOUND, null) // the upper bound of the partition column - val upperBound = parameters.getOrElse("upperBound", null) + val upperBound = parameters.getOrElse(JDBC_UPPER_BOUND, null) // the number of partitions - val numPartitions = parameters.getOrElse("numPartitions", null) - + val numPartitions = parameters.getOrElse(JDBC_NUM_PARTITIONS, null) require(partitionColumn == null || (lowerBound != null && upperBound != null && numPartitions != null), - "If 'partitionColumn' is specified then 'lowerBound', 'upperBound'," + - " and 'numPartitions' are required.") + s"If '$JDBC_PARTITION_COLUMN' is specified then '$JDBC_LOWER_BOUND', '$JDBC_UPPER_BOUND'," + + s" and '$JDBC_NUM_PARTITIONS' are required.") + val fetchSize = { + val size = parameters.getOrElse(JDBC_BATCH_FETCH_SIZE, "0").toInt + require(size >= 0, + s"Invalid value `${size.toString}` for parameter " + + s"`$JDBC_BATCH_FETCH_SIZE`. The minimum value is 0. When the value is 0, " + + "the JDBC driver ignores the value and does the estimates.") + size + } // ------------------------------------------------------------ - // The options for DataFrameWriter + // Optional parameters only for writing // ------------------------------------------------------------ // if to truncate the table from the JDBC database - val isTruncate = parameters.getOrElse("truncate", "false").toBoolean + val isTruncate = parameters.getOrElse(JDBC_TRUNCATE, "false").toBoolean // the create table option , which can be table_options or partition_options. 
// E.g., "CREATE TABLE t (name string) ENGINE=InnoDB DEFAULT CHARSET=utf8" // TODO: to reuse the existing partition parameters for those partition specific options - val createTableOptions = parameters.getOrElse("createTableOptions", "") + val createTableOptions = parameters.getOrElse(JDBC_CREATE_TABLE_OPTIONS, "") + val batchSize = { + val size = parameters.getOrElse(JDBC_BATCH_INSERT_SIZE, "1000").toInt + require(size >= 1, + s"Invalid value `${size.toString}` for parameter " + + s"`$JDBC_BATCH_INSERT_SIZE`. The minimum value is 1.") + size + } + val isolationLevel = + parameters.getOrElse(JDBC_TXN_ISOLATION_LEVEL, "READ_UNCOMMITTED") match { + case "NONE" => Connection.TRANSACTION_NONE + case "READ_UNCOMMITTED" => Connection.TRANSACTION_READ_UNCOMMITTED + case "READ_COMMITTED" => Connection.TRANSACTION_READ_COMMITTED + case "REPEATABLE_READ" => Connection.TRANSACTION_REPEATABLE_READ + case "SERIALIZABLE" => Connection.TRANSACTION_SERIALIZABLE + } +} + +object JDBCOptions { + private val jdbcOptionNames = ArrayBuffer.empty[String] + + private def newOption(name: String): String = { + jdbcOptionNames += name + name + } + + val JDBC_URL = newOption("url") + val JDBC_TABLE_NAME = newOption("dbtable") + val JDBC_DRIVER_CLASS = newOption("driver") + val JDBC_PARTITION_COLUMN = newOption("partitionColumn") + val JDBC_LOWER_BOUND = newOption("lowerBound") + val JDBC_UPPER_BOUND = newOption("upperBound") + val JDBC_NUM_PARTITIONS = newOption("numPartitions") + val JDBC_BATCH_FETCH_SIZE = newOption("fetchsize") + val JDBC_TRUNCATE = newOption("truncate") + val JDBC_CREATE_TABLE_OPTIONS = newOption("createTableOptions") + val JDBC_BATCH_INSERT_SIZE = newOption("batchsize") + val JDBC_TXN_ISOLATION_LEVEL = newOption("isolationLevel") } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala index f10615ebe4bcf..c0fabc81e42a4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala @@ -18,7 +18,6 @@ package org.apache.spark.sql.execution.datasources.jdbc import java.sql.{Connection, Date, PreparedStatement, ResultSet, SQLException, Timestamp} -import java.util.Properties import scala.util.control.NonFatal @@ -46,17 +45,18 @@ object JDBCRDD extends Logging { * Takes a (schema, table) specification and returns the table's Catalyst * schema. * - * @param url - The JDBC url to fetch information from. - * @param table - The table name of the desired table. This may also be a - * SQL query wrapped in parentheses. + * @param options - JDBC options that contains url, table and other information. * * @return A StructType giving the table's Catalyst schema. * @throws SQLException if the table specification is garbage. * @throws SQLException if the table contains an unsupported type. 
*/ - def resolveTable(url: String, table: String, properties: Properties): StructType = { + def resolveTable(options: JDBCOptions): StructType = { + val url = options.url + val table = options.table + val properties = options.asConnectionProperties val dialect = JdbcDialects.get(url) - val conn: Connection = JdbcUtils.createConnectionFactory(url, properties)() + val conn: Connection = JdbcUtils.createConnectionFactory(options)() try { val statement = conn.prepareStatement(dialect.getSchemaQuery(table)) try { @@ -143,43 +143,38 @@ object JDBCRDD extends Logging { }) } - - /** * Build and return JDBCRDD from the given information. * * @param sc - Your SparkContext. * @param schema - The Catalyst schema of the underlying database table. - * @param url - The JDBC url to connect to. - * @param fqTable - The fully-qualified table name (or paren'd SQL query) to use. * @param requiredColumns - The names of the columns to SELECT. * @param filters - The filters to include in all WHERE clauses. * @param parts - An array of JDBCPartitions specifying partition ids and * per-partition WHERE clauses. + * @param options - JDBC options that contains url, table and other information. * * @return An RDD representing "SELECT requiredColumns FROM fqTable". */ def scanTable( sc: SparkContext, schema: StructType, - url: String, - properties: Properties, - fqTable: String, requiredColumns: Array[String], filters: Array[Filter], - parts: Array[Partition]): RDD[InternalRow] = { + parts: Array[Partition], + options: JDBCOptions): RDD[InternalRow] = { + val url = options.url val dialect = JdbcDialects.get(url) val quotedColumns = requiredColumns.map(colName => dialect.quoteIdentifier(colName)) new JDBCRDD( sc, - JdbcUtils.createConnectionFactory(url, properties), + JdbcUtils.createConnectionFactory(options), pruneSchema(schema, requiredColumns), - fqTable, quotedColumns, filters, parts, url, - properties) + options) } } @@ -192,12 +187,11 @@ private[jdbc] class JDBCRDD( sc: SparkContext, getConnection: () => Connection, schema: StructType, - fqTable: String, columns: Array[String], filters: Array[Filter], partitions: Array[Partition], url: String, - properties: Properties) + options: JDBCOptions) extends RDD[InternalRow](sc, Nil) { /** @@ -211,7 +205,7 @@ private[jdbc] class JDBCRDD( private val columnList: String = { val sb = new StringBuilder() columns.foreach(x => sb.append(",").append(x)) - if (sb.length == 0) "1" else sb.substring(1) + if (sb.isEmpty) "1" else sb.substring(1) } /** @@ -286,7 +280,7 @@ private[jdbc] class JDBCRDD( conn = getConnection() val dialect = JdbcDialects.get(url) import scala.collection.JavaConverters._ - dialect.beforeFetch(conn, properties.asScala.toMap) + dialect.beforeFetch(conn, options.asConnectionProperties.asScala.toMap) // H2's JDBC driver does not support the setSchema() method. We pass a // fully-qualified table name in the SELECT statement. I don't know how to @@ -294,15 +288,10 @@ private[jdbc] class JDBCRDD( val myWhereClause = getWhereClause(part) - val sqlText = s"SELECT $columnList FROM $fqTable $myWhereClause" + val sqlText = s"SELECT $columnList FROM ${options.table} $myWhereClause" stmt = conn.prepareStatement(sqlText, ResultSet.TYPE_FORWARD_ONLY, ResultSet.CONCUR_READ_ONLY) - val fetchSize = properties.getProperty(JdbcUtils.JDBC_BATCH_FETCH_SIZE, "0").toInt - require(fetchSize >= 0, - s"Invalid value `${fetchSize.toString}` for parameter " + - s"`${JdbcUtils.JDBC_BATCH_FETCH_SIZE}`. The minimum value is 0. 
When the value is 0, " + - "the JDBC driver ignores the value and does the estimates.") - stmt.setFetchSize(fetchSize) + stmt.setFetchSize(options.fetchSize) rs = stmt.executeQuery() val rowsIterator = JdbcUtils.resultSetToSparkInternalRows(rs, schema, inputMetrics) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala index 11613dd912eca..672c21c6ac734 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala @@ -17,8 +17,6 @@ package org.apache.spark.sql.execution.datasources.jdbc -import java.util.Properties - import scala.collection.mutable.ArrayBuffer import org.apache.spark.internal.Logging @@ -102,10 +100,7 @@ private[sql] object JDBCRelation extends Logging { } private[sql] case class JDBCRelation( - url: String, - table: String, - parts: Array[Partition], - properties: Properties = new Properties())(@transient val sparkSession: SparkSession) + parts: Array[Partition], jdbcOptions: JDBCOptions)(@transient val sparkSession: SparkSession) extends BaseRelation with PrunedFilteredScan with InsertableRelation { @@ -114,7 +109,7 @@ private[sql] case class JDBCRelation( override val needConversion: Boolean = false - override val schema: StructType = JDBCRDD.resolveTable(url, table, properties) + override val schema: StructType = JDBCRDD.resolveTable(jdbcOptions) // Check if JDBCRDD.compileFilter can accept input filters override def unhandledFilters(filters: Array[Filter]): Array[Filter] = { @@ -126,15 +121,16 @@ private[sql] case class JDBCRelation( JDBCRDD.scanTable( sparkSession.sparkContext, schema, - url, - properties, - table, requiredColumns, filters, - parts).asInstanceOf[RDD[Row]] + parts, + jdbcOptions).asInstanceOf[RDD[Row]] } override def insert(data: DataFrame, overwrite: Boolean): Unit = { + val url = jdbcOptions.url + val table = jdbcOptions.table + val properties = jdbcOptions.asConnectionProperties data.write .mode(if (overwrite) SaveMode.Overwrite else SaveMode.Append) .jdbc(url, table, properties) @@ -142,6 +138,6 @@ private[sql] case class JDBCRelation( override def toString: String = { // credentials should not be included in the plan output, table information is sufficient. 
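On the read side refactored above, the user-facing fetchsize option now flows through JDBCOptions.fetchSize into stmt.setFetchSize. A hedged sketch with placeholder values (spark is an existing SparkSession):

{{{
val people = spark.read
  .format("jdbc")
  .option("url", "jdbc:postgresql://dbhost:5432/test")
  .option("dbtable", "people")   // may also be a parenthesized subquery
  .option("fetchsize", "100")    // becomes options.fetchSize, applied via stmt.setFetchSize
  .load()

// For requiredColumns = Array("NAME", "AGE") and a pushed-down filter, the per-partition
// query built above looks roughly like:
//   SELECT "NAME","AGE" FROM people WHERE AGE > 21   (identifier quoting is dialect-specific)
}}}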
- s"JDBCRelation(${table})" + s"JDBCRelation(${jdbcOptions.table})" } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcRelationProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcRelationProvider.scala index ae04af2479c8d..4420b3b18a907 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcRelationProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcRelationProvider.scala @@ -17,11 +17,8 @@ package org.apache.spark.sql.execution.datasources.jdbc -import java.util.Properties - -import scala.collection.JavaConverters.mapAsJavaMapConverter - import org.apache.spark.sql.{AnalysisException, DataFrame, SaveMode, SQLContext} +import org.apache.spark.sql.execution.datasources.jdbc.JdbcUtils._ import org.apache.spark.sql.sources.{BaseRelation, CreatableRelationProvider, DataSourceRegister, RelationProvider} class JdbcRelationProvider extends CreatableRelationProvider @@ -45,72 +42,53 @@ class JdbcRelationProvider extends CreatableRelationProvider partitionColumn, lowerBound.toLong, upperBound.toLong, numPartitions.toInt) } val parts = JDBCRelation.columnPartition(partitionInfo) - val properties = new Properties() // Additional properties that we will pass to getConnection - parameters.foreach(kv => properties.setProperty(kv._1, kv._2)) - JDBCRelation(jdbcOptions.url, jdbcOptions.table, parts, properties)(sqlContext.sparkSession) + JDBCRelation(parts, jdbcOptions)(sqlContext.sparkSession) } - /* - * The following structure applies to this code: - * | tableExists | !tableExists - *------------------------------------------------------------------------------------ - * Ignore | BaseRelation | CreateTable, saveTable, BaseRelation - * ErrorIfExists | ERROR | CreateTable, saveTable, BaseRelation - * Overwrite* | (DropTable, CreateTable,) | CreateTable, saveTable, BaseRelation - * | saveTable, BaseRelation | - * Append | saveTable, BaseRelation | CreateTable, saveTable, BaseRelation - * - * *Overwrite & tableExists with truncate, will not drop & create, but instead truncate - */ override def createRelation( sqlContext: SQLContext, mode: SaveMode, parameters: Map[String, String], - data: DataFrame): BaseRelation = { + df: DataFrame): BaseRelation = { val jdbcOptions = new JDBCOptions(parameters) val url = jdbcOptions.url val table = jdbcOptions.table + val createTableOptions = jdbcOptions.createTableOptions + val isTruncate = jdbcOptions.isTruncate - val props = new Properties() - props.putAll(parameters.asJava) - val conn = JdbcUtils.createConnectionFactory(url, props)() - + val conn = JdbcUtils.createConnectionFactory(jdbcOptions)() try { val tableExists = JdbcUtils.tableExists(conn, url, table) + if (tableExists) { + mode match { + case SaveMode.Overwrite => + if (isTruncate && isCascadingTruncateTable(url) == Some(false)) { + // In this case, we should truncate table and then load. 
+ truncateTable(conn, table) + saveTable(df, url, table, jdbcOptions) + } else { + // Otherwise, do not truncate the table, instead drop and recreate it + dropTable(conn, table) + createTable(df.schema, url, table, createTableOptions, conn) + saveTable(df, url, table, jdbcOptions) + } - val (doCreate, doSave) = (mode, tableExists) match { - case (SaveMode.Ignore, true) => (false, false) - case (SaveMode.ErrorIfExists, true) => throw new AnalysisException( - s"Table or view '$table' already exists, and SaveMode is set to ErrorIfExists.") - case (SaveMode.Overwrite, true) => - if (jdbcOptions.isTruncate && JdbcUtils.isCascadingTruncateTable(url) == Some(false)) { - JdbcUtils.truncateTable(conn, table) - (false, true) - } else { - JdbcUtils.dropTable(conn, table) - (true, true) - } - case (SaveMode.Append, true) => (false, true) - case (_, true) => throw new IllegalArgumentException(s"Unexpected SaveMode, '$mode'," + - " for handling existing tables.") - case (_, false) => (true, true) - } + case SaveMode.Append => + saveTable(df, url, table, jdbcOptions) + + case SaveMode.ErrorIfExists => + throw new AnalysisException( + s"Table or view '$table' already exists. SaveMode: ErrorIfExists.") - if (doCreate) { - val schema = JdbcUtils.schemaString(data, url) - // To allow certain options to append when create a new table, which can be - // table_options or partition_options. - // E.g., "CREATE TABLE t (name string) ENGINE=InnoDB DEFAULT CHARSET=utf8" - val createtblOptions = jdbcOptions.createTableOptions - val sql = s"CREATE TABLE $table ($schema) $createtblOptions" - val statement = conn.createStatement - try { - statement.executeUpdate(sql) - } finally { - statement.close() + case SaveMode.Ignore => + // With `SaveMode.Ignore` mode, if table already exists, the save operation is expected + // to not save the contents of the DataFrame and to not change the existing data. + // Therefore, it is okay to do nothing here and then just return the relation below. } + } else { + createTable(df.schema, url, table, createTableOptions, conn) + saveTable(df, url, table, jdbcOptions) } - if (doSave) JdbcUtils.saveTable(data, url, table, props) } finally { conn.close() } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala index 3db1d1f109fb7..e32db73bd6c6a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala @@ -18,7 +18,6 @@ package org.apache.spark.sql.execution.datasources.jdbc import java.sql.{Connection, Driver, DriverManager, PreparedStatement, ResultSet, ResultSetMetaData, SQLException} -import java.util.Properties import scala.collection.JavaConverters._ import scala.util.Try @@ -30,7 +29,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.sql.{DataFrame, Row} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.encoders.RowEncoder -import org.apache.spark.sql.catalyst.expressions.{MutableRow, SpecificMutableRow} +import org.apache.spark.sql.catalyst.expressions.SpecificInternalRow import org.apache.spark.sql.catalyst.util.{DateTimeUtils, GenericArrayData} import org.apache.spark.sql.jdbc.{JdbcDialect, JdbcDialects, JdbcType} import org.apache.spark.sql.types._ @@ -41,27 +40,13 @@ import org.apache.spark.util.NextIterator * Util functions for JDBC tables. 
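The save-mode dispatch above is driven entirely by the public writer API; a hedged sketch of how each branch is reached (URL, table and df are placeholders):

{{{
import java.util.Properties
import org.apache.spark.sql.SaveMode

// Overwrite with "truncate" and a dialect whose TRUNCATE does not cascade: truncate, then save.
// Otherwise Overwrite drops and recreates the table before saving.
df.write
  .mode(SaveMode.Overwrite)
  .option("truncate", "true")
  .jdbc("jdbc:postgresql://dbhost:5432/test", "people", new Properties())

// Append only calls saveTable; ErrorIfExists fails when the table exists;
// Ignore leaves an existing table untouched and creates-and-saves when it does not exist.
}}}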
*/ object JdbcUtils extends Logging { - - // the property names are case sensitive - val JDBC_BATCH_FETCH_SIZE = "fetchsize" - val JDBC_BATCH_INSERT_SIZE = "batchsize" - val JDBC_TXN_ISOLATION_LEVEL = "isolationLevel" - /** * Returns a factory for creating connections to the given JDBC URL. * - * @param url the JDBC url to connect to. - * @param properties JDBC connection properties. + * @param options - JDBC options that contains url, table and other information. */ - def createConnectionFactory(url: String, properties: Properties): () => Connection = { - val userSpecifiedDriverClass = Option(properties.getProperty("driver")) - userSpecifiedDriverClass.foreach(DriverRegistry.register) - // Performing this part of the logic on the driver guards against the corner-case where the - // driver returned for a URL is different on the driver and executors due to classpath - // differences. - val driverClass: String = userSpecifiedDriverClass.getOrElse { - DriverManager.getDriver(url).getClass.getCanonicalName - } + def createConnectionFactory(options: JDBCOptions): () => Connection = { + val driverClass: String = options.driverClass () => { DriverRegistry.register(driverClass) val driver: Driver = DriverManager.getDrivers.asScala.collectFirst { @@ -71,7 +56,7 @@ object JdbcUtils extends Logging { throw new IllegalStateException( s"Did not find registered driver with class $driverClass") } - driver.connect(url, properties) + driver.connect(options.url, options.asConnectionProperties) } } @@ -283,7 +268,7 @@ object JdbcUtils extends Logging { new NextIterator[InternalRow] { private[this] val rs = resultSet private[this] val getters: Array[JDBCValueGetter] = makeGetters(schema) - private[this] val mutableRow = new SpecificMutableRow(schema.fields.map(x => x.dataType)) + private[this] val mutableRow = new SpecificInternalRow(schema.fields.map(x => x.dataType)) override protected def close(): Unit = { try { @@ -314,22 +299,22 @@ object JdbcUtils extends Logging { // A `JDBCValueGetter` is responsible for getting a value from `ResultSet` into a field // for `MutableRow`. The last argument `Int` means the index for the value to be set in // the row and also used for the value in `ResultSet`. - private type JDBCValueGetter = (ResultSet, MutableRow, Int) => Unit + private type JDBCValueGetter = (ResultSet, InternalRow, Int) => Unit /** * Creates `JDBCValueGetter`s according to [[StructType]], which can set - * each value from `ResultSet` to each field of [[MutableRow]] correctly. + * each value from `ResultSet` to each field of [[InternalRow]] correctly. */ private def makeGetters(schema: StructType): Array[JDBCValueGetter] = schema.fields.map(sf => makeGetter(sf.dataType, sf.metadata)) private def makeGetter(dt: DataType, metadata: Metadata): JDBCValueGetter = dt match { case BooleanType => - (rs: ResultSet, row: MutableRow, pos: Int) => + (rs: ResultSet, row: InternalRow, pos: Int) => row.setBoolean(pos, rs.getBoolean(pos + 1)) case DateType => - (rs: ResultSet, row: MutableRow, pos: Int) => + (rs: ResultSet, row: InternalRow, pos: Int) => // DateTimeUtils.fromJavaDate does not handle null value, so we need to check it. val dateVal = rs.getDate(pos + 1) if (dateVal != null) { @@ -347,25 +332,25 @@ object JdbcUtils extends Logging { // retrieve it, you will get wrong result 199.99. // So it is needed to set precision and scale for Decimal based on JDBC metadata. 
case DecimalType.Fixed(p, s) => - (rs: ResultSet, row: MutableRow, pos: Int) => + (rs: ResultSet, row: InternalRow, pos: Int) => val decimal = nullSafeConvert[java.math.BigDecimal](rs.getBigDecimal(pos + 1), d => Decimal(d, p, s)) row.update(pos, decimal) case DoubleType => - (rs: ResultSet, row: MutableRow, pos: Int) => + (rs: ResultSet, row: InternalRow, pos: Int) => row.setDouble(pos, rs.getDouble(pos + 1)) case FloatType => - (rs: ResultSet, row: MutableRow, pos: Int) => + (rs: ResultSet, row: InternalRow, pos: Int) => row.setFloat(pos, rs.getFloat(pos + 1)) case IntegerType => - (rs: ResultSet, row: MutableRow, pos: Int) => + (rs: ResultSet, row: InternalRow, pos: Int) => row.setInt(pos, rs.getInt(pos + 1)) case LongType if metadata.contains("binarylong") => - (rs: ResultSet, row: MutableRow, pos: Int) => + (rs: ResultSet, row: InternalRow, pos: Int) => val bytes = rs.getBytes(pos + 1) var ans = 0L var j = 0 @@ -376,20 +361,20 @@ object JdbcUtils extends Logging { row.setLong(pos, ans) case LongType => - (rs: ResultSet, row: MutableRow, pos: Int) => + (rs: ResultSet, row: InternalRow, pos: Int) => row.setLong(pos, rs.getLong(pos + 1)) case ShortType => - (rs: ResultSet, row: MutableRow, pos: Int) => + (rs: ResultSet, row: InternalRow, pos: Int) => row.setShort(pos, rs.getShort(pos + 1)) case StringType => - (rs: ResultSet, row: MutableRow, pos: Int) => + (rs: ResultSet, row: InternalRow, pos: Int) => // TODO(davies): use getBytes for better performance, if the encoding is UTF-8 row.update(pos, UTF8String.fromString(rs.getString(pos + 1))) case TimestampType => - (rs: ResultSet, row: MutableRow, pos: Int) => + (rs: ResultSet, row: InternalRow, pos: Int) => val t = rs.getTimestamp(pos + 1) if (t != null) { row.setLong(pos, DateTimeUtils.fromJavaTimestamp(t)) @@ -398,7 +383,7 @@ object JdbcUtils extends Logging { } case BinaryType => - (rs: ResultSet, row: MutableRow, pos: Int) => + (rs: ResultSet, row: InternalRow, pos: Int) => row.update(pos, rs.getBytes(pos + 1)) case ArrayType(et, _) => @@ -437,7 +422,7 @@ object JdbcUtils extends Logging { case _ => (array: Object) => array.asInstanceOf[Array[Any]] } - (rs: ResultSet, row: MutableRow, pos: Int) => + (rs: ResultSet, row: InternalRow, pos: Int) => val array = nullSafeConvert[Object]( rs.getArray(pos + 1).getArray, array => new GenericArrayData(elementConversion.apply(array))) @@ -550,10 +535,6 @@ object JdbcUtils extends Logging { batchSize: Int, dialect: JdbcDialect, isolationLevel: Int): Iterator[Byte] = { - require(batchSize >= 1, - s"Invalid value `${batchSize.toString}` for parameter " + - s"`${JdbcUtils.JDBC_BATCH_INSERT_SIZE}`. The minimum value is 1.") - val conn = getConnection() var committed = false @@ -657,10 +638,10 @@ object JdbcUtils extends Logging { /** * Compute the schema string for this RDD. 
*/ - def schemaString(df: DataFrame, url: String): String = { + def schemaString(schema: StructType, url: String): String = { val sb = new StringBuilder() val dialect = JdbcDialects.get(url) - df.schema.fields foreach { field => + schema.fields foreach { field => val name = dialect.quoteIdentifier(field.name) val typ: String = getJdbcType(field.dataType, dialect).databaseTypeDefinition val nullable = if (field.nullable) "" else "NOT NULL" @@ -676,25 +657,41 @@ object JdbcUtils extends Logging { df: DataFrame, url: String, table: String, - properties: Properties) { + options: JDBCOptions) { val dialect = JdbcDialects.get(url) val nullTypes: Array[Int] = df.schema.fields.map { field => getJdbcType(field.dataType, dialect).jdbcNullType } val rddSchema = df.schema - val getConnection: () => Connection = createConnectionFactory(url, properties) - val batchSize = properties.getProperty(JDBC_BATCH_INSERT_SIZE, "1000").toInt - val isolationLevel = - properties.getProperty(JDBC_TXN_ISOLATION_LEVEL, "READ_UNCOMMITTED") match { - case "NONE" => Connection.TRANSACTION_NONE - case "READ_UNCOMMITTED" => Connection.TRANSACTION_READ_UNCOMMITTED - case "READ_COMMITTED" => Connection.TRANSACTION_READ_COMMITTED - case "REPEATABLE_READ" => Connection.TRANSACTION_REPEATABLE_READ - case "SERIALIZABLE" => Connection.TRANSACTION_SERIALIZABLE - } + val getConnection: () => Connection = createConnectionFactory(options) + val batchSize = options.batchSize + val isolationLevel = options.isolationLevel df.foreachPartition(iterator => savePartition( getConnection, table, iterator, rddSchema, nullTypes, batchSize, dialect, isolationLevel) ) } + + /** + * Creates a table with a given schema. + */ + def createTable( + schema: StructType, + url: String, + table: String, + createTableOptions: String, + conn: Connection): Unit = { + val strSchema = schemaString(schema, url) + // Create the table if the table does not exist. + // To allow certain options to append when create a new table, which can be + // table_options or partition_options. + // E.g., "CREATE TABLE t (name string) ENGINE=InnoDB DEFAULT CHARSET=utf8" + val sql = s"CREATE TABLE $table ($strSchema) $createTableOptions" + val statement = conn.createStatement + try { + statement.executeUpdate(sql) + } finally { + statement.close() + } + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRowConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRowConverter.scala index 9ffc2b5dd8a56..33dcf2f3fd167 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRowConverter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRowConverter.scala @@ -40,7 +40,7 @@ import org.apache.spark.unsafe.types.UTF8String /** * A [[ParentContainerUpdater]] is used by a Parquet converter to set converted values to some * corresponding parent container. For example, a converter for a `StructType` field may set - * converted values to a [[MutableRow]]; or a converter for array elements may append converted + * converted values to a [[InternalRow]]; or a converter for array elements may append converted * values to an [[ArrayBuffer]]. */ private[parquet] trait ParentContainerUpdater { @@ -155,7 +155,7 @@ private[parquet] class ParquetRowConverter( * Updater used together with field converters within a [[ParquetRowConverter]]. It propagates * converted filed values to the `ordinal`-th cell in `currentRow`. 
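To make the DDL generated by the createTable helper above concrete, here is a hedged example for a small schema (the table name is a placeholder; exact type names depend on the JdbcDialect):

{{{
import org.apache.spark.sql.types._

val schema = StructType(Seq(
  StructField("name", StringType),
  StructField("age", IntegerType, nullable = false)))

// schemaString(schema, url) yields something like: "name" TEXT , "age" INTEGER NOT NULL
// so createTable executes roughly:
//   CREATE TABLE people ("name" TEXT , "age" INTEGER NOT NULL) <createTableOptions, if any>
}}}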
*/ - private final class RowUpdater(row: MutableRow, ordinal: Int) extends ParentContainerUpdater { + private final class RowUpdater(row: InternalRow, ordinal: Int) extends ParentContainerUpdater { override def set(value: Any): Unit = row(ordinal) = value override def setBoolean(value: Boolean): Unit = row.setBoolean(ordinal, value) override def setByte(value: Byte): Unit = row.setByte(ordinal, value) @@ -166,7 +166,7 @@ private[parquet] class ParquetRowConverter( override def setFloat(value: Float): Unit = row.setFloat(ordinal, value) } - private val currentRow = new SpecificMutableRow(catalystType.map(_.dataType)) + private val currentRow = new SpecificInternalRow(catalystType.map(_.dataType)) private val unsafeProjection = UnsafeProjection.create(catalystType) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoinExec.scala index 43cdce7de8c7f..bfe7e3dea45df 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoinExec.scala @@ -119,7 +119,7 @@ case class BroadcastNestedLoopJoinExec( streamed.execute().mapPartitionsInternal { streamedIter => val buildRows = relation.value val joinedRow = new JoinedRow - val nulls = new GenericMutableRow(broadcast.output.size) + val nulls = new GenericInternalRow(broadcast.output.size) // Returns an iterator to avoid copy the rows. new Iterator[InternalRow] { @@ -205,14 +205,14 @@ case class BroadcastNestedLoopJoinExec( val joinedRow = new JoinedRow if (condition.isDefined) { - val resultRow = new GenericMutableRow(Array[Any](null)) + val resultRow = new GenericInternalRow(Array[Any](null)) streamedIter.map { row => val result = buildRows.exists(r => boundCondition(joinedRow(row, r))) resultRow.setBoolean(0, result) joinedRow(row, resultRow) } } else { - val resultRow = new GenericMutableRow(Array[Any](buildRows.nonEmpty)) + val resultRow = new GenericInternalRow(Array[Any](buildRows.nonEmpty)) streamedIter.map { row => joinedRow(row, resultRow) } @@ -293,7 +293,7 @@ case class BroadcastNestedLoopJoinExec( } val notMatchedBroadcastRows: Seq[InternalRow] = { - val nulls = new GenericMutableRow(streamed.output.size) + val nulls = new GenericInternalRow(streamed.output.size) val buf: CompactBuffer[InternalRow] = new CompactBuffer() val joinedRow = new JoinedRow joinedRow.withLeft(nulls) @@ -311,7 +311,7 @@ case class BroadcastNestedLoopJoinExec( val matchedStreamRows = streamRdd.mapPartitionsInternal { streamedIter => val buildRows = relation.value val joinedRow = new JoinedRow - val nulls = new GenericMutableRow(broadcast.output.size) + val nulls = new GenericInternalRow(broadcast.output.size) streamedIter.flatMap { streamedRow => var i = 0 diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala index fb6bfa7b2735c..05c5e2f4cd77b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala @@ -63,45 +63,16 @@ trait HashJoin { protected lazy val (buildKeys, streamedKeys) = { require(leftKeys.map(_.dataType) == rightKeys.map(_.dataType), "Join keys from two sides should have same types") - val lkeys = rewriteKeyExpr(leftKeys).map(BindReferences.bindReference(_, 
left.output)) - val rkeys = rewriteKeyExpr(rightKeys).map(BindReferences.bindReference(_, right.output)) + val lkeys = HashJoin.rewriteKeyExpr(leftKeys).map(BindReferences.bindReference(_, left.output)) + val rkeys = HashJoin.rewriteKeyExpr(rightKeys) + .map(BindReferences.bindReference(_, right.output)) buildSide match { case BuildLeft => (lkeys, rkeys) case BuildRight => (rkeys, lkeys) } } - /** - * Try to rewrite the key as LongType so we can use getLong(), if they key can fit with a long. - * - * If not, returns the original expressions. - */ - private def rewriteKeyExpr(keys: Seq[Expression]): Seq[Expression] = { - var keyExpr: Expression = null - var width = 0 - keys.foreach { e => - e.dataType match { - case dt: IntegralType if dt.defaultSize <= 8 - width => - if (width == 0) { - if (e.dataType != LongType) { - keyExpr = Cast(e, LongType) - } else { - keyExpr = e - } - width = dt.defaultSize - } else { - val bits = dt.defaultSize * 8 - keyExpr = BitwiseOr(ShiftLeft(keyExpr, Literal(bits)), - BitwiseAnd(Cast(e, LongType), Literal((1L << bits) - 1))) - width -= bits - } - // TODO: support BooleanType, DateType and TimestampType - case other => - return keys - } - } - keyExpr :: Nil - } + protected def buildSideKeyGenerator(): Projection = UnsafeProjection.create(buildKeys) @@ -192,7 +163,7 @@ trait HashJoin { streamIter: Iterator[InternalRow], hashedRelation: HashedRelation): Iterator[InternalRow] = { val joinKeys = streamSideKeyGenerator() - val result = new GenericMutableRow(Array[Any](null)) + val result = new GenericInternalRow(Array[Any](null)) val joinedRow = new JoinedRow streamIter.map { current => val key = joinKeys(current) @@ -247,3 +218,31 @@ trait HashJoin { } } } + +object HashJoin { + /** + * Try to rewrite the key as LongType so we can use getLong(), if they key can fit with a long. + * + * If not, returns the original expressions. 
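The packing performed by the rewritten rewriteKeyExpr just below is easiest to see on a concrete key pair, say (a: Int, b: Short), whose default sizes sum to 6 bytes and therefore fit in one long:

{{{
// keyExpr starts as Cast(a, LongType), then b (2 bytes = 16 bits) is folded in:
//   keyExpr = BitwiseOr(ShiftLeft(keyExpr, 16), BitwiseAnd(Cast(b, LongType), 0xFFFF))
// For a = 3 and b = 5 the packed key is:
val packed = (3L << 16) | (5L & 0xFFFFL)   // == 196613L, read back with a single getLong()
}}}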
+ */ + private[joins] def rewriteKeyExpr(keys: Seq[Expression]): Seq[Expression] = { + assert(keys.nonEmpty) + // TODO: support BooleanType, DateType and TimestampType + if (keys.exists(!_.dataType.isInstanceOf[IntegralType]) + || keys.map(_.dataType.defaultSize).sum > 8) { + return keys + } + + var keyExpr: Expression = if (keys.head.dataType != LongType) { + Cast(keys.head, LongType) + } else { + keys.head + } + keys.tail.foreach { e => + val bits = e.dataType.defaultSize * 8 + keyExpr = BitwiseOr(ShiftLeft(keyExpr, Literal(bits)), + BitwiseAnd(Cast(e, LongType), Literal((1L << bits) - 1))) + } + keyExpr :: Nil + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala index 81b3e1d224ab6..ecf7cf289f034 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala @@ -275,7 +275,7 @@ case class SortMergeJoinExec( case j: ExistenceJoin => new RowIterator { private[this] var currentLeftRow: InternalRow = _ - private[this] val result: MutableRow = new GenericMutableRow(Array[Any](null)) + private[this] val result: InternalRow = new GenericInternalRow(Array[Any](null)) private[this] val smjScanner = new SortMergeJoinScanner( createLeftKeyGenerator(), createRightKeyGenerator(), diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala index c7e267152b5cd..2acc5110e8950 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala @@ -141,7 +141,7 @@ object ObjectOperator { def serializeObjectToRow(serializer: Seq[Expression]): Any => UnsafeRow = { val proj = GenerateUnsafeProjection.generate(serializer) val objType = serializer.head.collect { case b: BoundReference => b.dataType }.head - val objRow = new SpecificMutableRow(objType :: Nil) + val objRow = new SpecificInternalRow(objType :: Nil) (o: Any) => { objRow(0) = o proj(objRow) @@ -149,7 +149,7 @@ object ObjectOperator { } def wrapObjectToRow(objType: DataType): Any => InternalRow = { - val outputRow = new SpecificMutableRow(objType :: Nil) + val outputRow = new SpecificInternalRow(objType :: Nil) (o: Any) => { outputRow(0) = o outputRow diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExec.scala index d9bf4d3ccf698..dcaf2c76d479d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExec.scala @@ -17,18 +17,21 @@ package org.apache.spark.sql.execution.python +import java.io.File + import scala.collection.JavaConverters._ import scala.collection.mutable.ArrayBuffer import net.razorvine.pickle.{Pickler, Unpickler} -import org.apache.spark.TaskContext +import org.apache.spark.{SparkEnv, TaskContext} import org.apache.spark.api.python.{ChainedPythonFunctions, PythonRunner} import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.sql.types.{DataType, StructField, StructType} +import 
org.apache.spark.util.Utils /** @@ -37,9 +40,25 @@ import org.apache.spark.sql.types.{DataType, StructField, StructType} * Python evaluation works by sending the necessary (projected) input data via a socket to an * external Python process, and combine the result from the Python process with the original row. * - * For each row we send to Python, we also put it in a queue. For each output row from Python, + * For each row we send to Python, we also put it in a queue first. For each output row from Python, * we drain the queue to find the original input row. Note that if the Python process is way too - * slow, this could lead to the queue growing unbounded and eventually run out of memory. + * slow, this could lead to the queue growing unbounded and spill into disk when run out of memory. + * + * Here is a diagram to show how this works: + * + * Downstream (for parent) + * / \ + * / socket (output of UDF) + * / \ + * RowQueue Python + * \ / + * \ socket (input of UDF) + * \ / + * upstream (from child) + * + * The rows sent to and received from Python are packed into batches (100 rows) and serialized, + * there should be always some rows buffered in the socket or Python process, so the pulling from + * RowQueue ALWAYS happened after pushing into it. */ case class BatchEvalPythonExec(udfs: Seq[PythonUDF], output: Seq[Attribute], child: SparkPlan) extends SparkPlan { @@ -70,7 +89,11 @@ case class BatchEvalPythonExec(udfs: Seq[PythonUDF], output: Seq[Attribute], chi // The queue used to buffer input rows so we can drain it to // combine input with output from Python. - val queue = new java.util.concurrent.ConcurrentLinkedQueue[InternalRow]() + val queue = HybridRowQueue(TaskContext.get().taskMemoryManager(), + new File(Utils.getLocalDir(SparkEnv.get.conf)), child.output.length) + TaskContext.get().addTaskCompletionListener({ ctx => + queue.close() + }) val (pyFuncs, inputs) = udfs.map(collectFunctions).unzip @@ -98,7 +121,7 @@ case class BatchEvalPythonExec(udfs: Seq[PythonUDF], output: Seq[Attribute], chi // For each row, add it to the queue. 
val inputIterator = iter.grouped(100).map { inputRows => val toBePickled = inputRows.map { inputRow => - queue.add(inputRow) + queue.add(inputRow.asInstanceOf[UnsafeRow]) val row = projection(inputRow) if (needConversion) { EvaluatePython.toJava(row, schema) @@ -124,7 +147,7 @@ case class BatchEvalPythonExec(udfs: Seq[PythonUDF], output: Seq[Attribute], chi .compute(inputIterator, context.partitionId(), context) val unpickle = new Unpickler - val mutableRow = new GenericMutableRow(1) + val mutableRow = new GenericInternalRow(1) val joined = new JoinedRow val resultType = if (udfs.length == 1) { udfs.head.dataType @@ -132,7 +155,6 @@ case class BatchEvalPythonExec(udfs: Seq[PythonUDF], output: Seq[Attribute], chi StructType(udfs.map(u => StructField("", u.dataType, u.nullable))) } val resultProj = UnsafeProjection.create(output, output) - outputIterator.flatMap { pickedResult => val unpickledBatch = unpickle.loads(pickedResult) unpickledBatch.asInstanceOf[java.util.ArrayList[Any]].asScala @@ -144,7 +166,7 @@ case class BatchEvalPythonExec(udfs: Seq[PythonUDF], output: Seq[Attribute], chi } else { EvaluatePython.fromJava(result, resultType).asInstanceOf[InternalRow] } - resultProj(joined(queue.poll(), row)) + resultProj(joined(queue.remove(), row)) } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/RowQueue.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/RowQueue.scala new file mode 100644 index 0000000000000..422a3f862d96f --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/RowQueue.scala @@ -0,0 +1,280 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +package org.apache.spark.sql.execution.python + +import java.io._ + +import com.google.common.io.Closeables + +import org.apache.spark.SparkException +import org.apache.spark.memory.{MemoryConsumer, TaskMemoryManager} +import org.apache.spark.sql.catalyst.expressions.UnsafeRow +import org.apache.spark.unsafe.Platform +import org.apache.spark.unsafe.memory.MemoryBlock + +/** + * A RowQueue is an FIFO queue for UnsafeRow. + * + * This RowQueue is ONLY designed and used for Python UDF, which has only one writer and only one + * reader, the reader ALWAYS ran behind the writer. See the doc of class [[BatchEvalPythonExec]] + * on how it works. + */ +private[python] trait RowQueue { + + /** + * Add a row to the end of it, returns true iff the row has been added to the queue. + */ + def add(row: UnsafeRow): Boolean + + /** + * Retrieve and remove the first row, returns null if it's empty. + * + * It can only be called after add is called, otherwise it will fail (NPE). + */ + def remove(): UnsafeRow + + /** + * Cleanup all the resources. + */ + def close(): Unit +} + +/** + * A RowQueue that is based on in-memory page. 
UnsafeRows are appended into it until it's full. + * Another thread could read from it at the same time (behind the writer). + * + * The format of UnsafeRow in page: + * [4 bytes to hold length of record (N)] [N bytes to hold record] [...] + * + * -1 length means end of page. + */ +private[python] abstract class InMemoryRowQueue(val page: MemoryBlock, numFields: Int) + extends RowQueue { + private val base: AnyRef = page.getBaseObject + private val endOfPage: Long = page.getBaseOffset + page.size + // the first location where a new row would be written + private var writeOffset = page.getBaseOffset + // points to the start of the next row to read + private var readOffset = page.getBaseOffset + private val resultRow = new UnsafeRow(numFields) + + def add(row: UnsafeRow): Boolean = synchronized { + val size = row.getSizeInBytes + if (writeOffset + 4 + size > endOfPage) { + // if there is not enough space in this page to hold the new record + if (writeOffset + 4 <= endOfPage) { + // if there's extra space at the end of the page, store a special "end-of-page" length (-1) + Platform.putInt(base, writeOffset, -1) + } + false + } else { + Platform.putInt(base, writeOffset, size) + Platform.copyMemory(row.getBaseObject, row.getBaseOffset, base, writeOffset + 4, size) + writeOffset += 4 + size + true + } + } + + def remove(): UnsafeRow = synchronized { + assert(readOffset <= writeOffset, "reader should not go beyond writer") + if (readOffset + 4 > endOfPage || Platform.getInt(base, readOffset) < 0) { + null + } else { + val size = Platform.getInt(base, readOffset) + resultRow.pointTo(base, readOffset + 4, size) + readOffset += 4 + size + resultRow + } + } +} + +/** + * A RowQueue that is backed by a file on disk. This queue will stop accepting new rows once any + * reader has begun reading from the queue. + */ +private[python] case class DiskRowQueue(file: File, fields: Int) extends RowQueue { + private var out = new DataOutputStream( + new BufferedOutputStream(new FileOutputStream(file.toString))) + private var unreadBytes = 0L + + private var in: DataInputStream = _ + private val resultRow = new UnsafeRow(fields) + + def add(row: UnsafeRow): Boolean = synchronized { + if (out == null) { + // Another thread is reading, stop writing this one + return false + } + out.writeInt(row.getSizeInBytes) + out.write(row.getBytes) + unreadBytes += 4 + row.getSizeInBytes + true + } + + def remove(): UnsafeRow = synchronized { + if (out != null) { + out.close() + out = null + in = new DataInputStream(new BufferedInputStream(new FileInputStream(file.toString))) + } + + if (unreadBytes > 0) { + val size = in.readInt() + val bytes = new Array[Byte](size) + in.readFully(bytes) + unreadBytes -= 4 + size + resultRow.pointTo(bytes, size) + resultRow + } else { + null + } + } + + def close(): Unit = synchronized { + Closeables.close(out, true) + out = null + Closeables.close(in, true) + in = null + if (file.exists()) { + file.delete() + } + } +} + +/** + * A RowQueue that has a list of RowQueues, which could be in memory or disk. + * + * HybridRowQueue could be safely appended in one thread, and pulled in another thread in the same + * time. 
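A minimal usage sketch of the HybridRowQueue defined just below, mirroring how BatchEvalPythonExec uses it (the task context, SparkEnv and row are assumed to exist on an executor):

{{{
val queue = HybridRowQueue(
  TaskContext.get().taskMemoryManager(),
  new File(Utils.getLocalDir(SparkEnv.get.conf)),
  row.numFields())
queue.add(row)                 // lands in an in-memory page, or in a DiskRowQueue after a spill
val original = queue.remove()  // FIFO: rows come back in exactly the order they were added
queue.close()                  // frees pages and deletes any temporary files
}}}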
+ */ +private[python] case class HybridRowQueue( + memManager: TaskMemoryManager, + tempDir: File, + numFields: Int) + extends MemoryConsumer(memManager) with RowQueue { + + // Each buffer should have at least one row + private var queues = new java.util.LinkedList[RowQueue]() + + private var writing: RowQueue = _ + private var reading: RowQueue = _ + + // exposed for testing + private[python] def numQueues(): Int = queues.size() + + def spill(size: Long, trigger: MemoryConsumer): Long = { + if (trigger == this) { + // When it's triggered by itself, it should write upcoming rows into disk instead of copying + // the rows already in the queue. + return 0L + } + var released = 0L + synchronized { + // poll out all the buffers and add them back in the same order to make sure that the rows + // are in correct order. + val newQueues = new java.util.LinkedList[RowQueue]() + while (!queues.isEmpty) { + val queue = queues.remove() + val newQueue = if (!queues.isEmpty && queue.isInstanceOf[InMemoryRowQueue]) { + val diskQueue = createDiskQueue() + var row = queue.remove() + while (row != null) { + diskQueue.add(row) + row = queue.remove() + } + released += queue.asInstanceOf[InMemoryRowQueue].page.size() + queue.close() + diskQueue + } else { + queue + } + newQueues.add(newQueue) + } + queues = newQueues + } + released + } + + private def createDiskQueue(): RowQueue = { + DiskRowQueue(File.createTempFile("buffer", "", tempDir), numFields) + } + + private def createNewQueue(required: Long): RowQueue = { + val page = try { + allocatePage(required) + } catch { + case _: OutOfMemoryError => + null + } + val buffer = if (page != null) { + new InMemoryRowQueue(page, numFields) { + override def close(): Unit = { + freePage(page) + } + } + } else { + createDiskQueue() + } + + synchronized { + queues.add(buffer) + } + buffer + } + + def add(row: UnsafeRow): Boolean = { + if (writing == null || !writing.add(row)) { + writing = createNewQueue(4 + row.getSizeInBytes) + if (!writing.add(row)) { + throw new SparkException(s"failed to push a row into $writing") + } + } + true + } + + def remove(): UnsafeRow = { + var row: UnsafeRow = null + if (reading != null) { + row = reading.remove() + } + if (row == null) { + if (reading != null) { + reading.close() + } + synchronized { + reading = queues.remove() + } + assert(reading != null, s"queue should not be empty") + row = reading.remove() + assert(row != null, s"$reading should have at least one row") + } + row + } + + def close(): Unit = { + if (reading != null) { + reading.close() + reading = null + } + synchronized { + while (!queues.isEmpty) { + queues.remove().close() + } + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala index 822f49ecab47b..c02b15498748f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.execution.stat import org.apache.spark.internal.Logging import org.apache.spark.sql.{Column, DataFrame, Dataset, Row} -import org.apache.spark.sql.catalyst.expressions.{Cast, GenericMutableRow} +import org.apache.spark.sql.catalyst.expressions.{Cast, GenericInternalRow} import org.apache.spark.sql.catalyst.plans.logical.LocalRelation import org.apache.spark.sql.catalyst.util.QuantileSummaries import org.apache.spark.sql.functions._ @@ -186,7 +186,7 @@ 
object StatFunctions extends Logging { require(columnSize < 1e4, s"The number of distinct values for $col2, can't " + s"exceed 1e4. Currently $columnSize") val table = counts.groupBy(_.get(0)).map { case (col1Item, rows) => - val countsRow = new GenericMutableRow(columnSize + 1) + val countsRow = new GenericInternalRow(columnSize + 1) rows.foreach { (row: Row) => // row.get(0) is column 1 // row.get(1) is column 2 diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/CompactibleFileStreamLog.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/CompactibleFileStreamLog.scala index 027b5bbfab8d6..c14feea91ed7d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/CompactibleFileStreamLog.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/CompactibleFileStreamLog.scala @@ -17,9 +17,10 @@ package org.apache.spark.sql.execution.streaming -import java.io.IOException +import java.io.{InputStream, IOException, OutputStream} import java.nio.charset.StandardCharsets.UTF_8 +import scala.io.{Source => IOSource} import scala.reflect.ClassTag import org.apache.hadoop.fs.{Path, PathFilter} @@ -93,20 +94,25 @@ abstract class CompactibleFileStreamLog[T: ClassTag]( } } - override def serialize(logData: Array[T]): Array[Byte] = { - (metadataLogVersion +: logData.map(serializeData)).mkString("\n").getBytes(UTF_8) + override def serialize(logData: Array[T], out: OutputStream): Unit = { + // called inside a try-finally where the underlying stream is closed in the caller + out.write(metadataLogVersion.getBytes(UTF_8)) + logData.foreach { data => + out.write('\n') + out.write(serializeData(data).getBytes(UTF_8)) + } } - override def deserialize(bytes: Array[Byte]): Array[T] = { - val lines = new String(bytes, UTF_8).split("\n") - if (lines.length == 0) { + override def deserialize(in: InputStream): Array[T] = { + val lines = IOSource.fromInputStream(in, UTF_8.name()).getLines() + if (!lines.hasNext) { throw new IllegalStateException("Incomplete log file") } - val version = lines(0) + val version = lines.next() if (version != metadataLogVersion) { throw new IllegalStateException(s"Unknown log version: ${version}") } - lines.slice(1, lines.length).map(deserializeData) + lines.map(deserializeData).toArray } override def add(batchId: Long, logs: Array[T]): Boolean = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLog.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLog.scala index 39a0f3341389c..c7235320fd6bd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLog.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLog.scala @@ -17,8 +17,7 @@ package org.apache.spark.sql.execution.streaming -import java.io.{FileNotFoundException, IOException} -import java.nio.ByteBuffer +import java.io.{FileNotFoundException, InputStream, IOException, OutputStream} import java.util.{ConcurrentModificationException, EnumSet, UUID} import scala.reflect.ClassTag @@ -29,7 +28,6 @@ import org.apache.hadoop.fs._ import org.apache.hadoop.fs.permission.FsPermission import org.apache.spark.internal.Logging -import org.apache.spark.network.util.JavaUtils import org.apache.spark.serializer.JavaSerializer import org.apache.spark.sql.SparkSession import org.apache.spark.util.UninterruptibleThread @@ -88,12 +86,16 @@ class HDFSMetadataLog[T: ClassTag](sparkSession: SparkSession, path: String) } } - 
protected def serialize(metadata: T): Array[Byte] = { - JavaUtils.bufferToArray(serializer.serialize(metadata)) + protected def serialize(metadata: T, out: OutputStream): Unit = { + // called inside a try-finally where the underlying stream is closed in the caller + val outStream = serializer.serializeStream(out) + outStream.writeObject(metadata) } - protected def deserialize(bytes: Array[Byte]): T = { - serializer.deserialize[T](ByteBuffer.wrap(bytes)) + protected def deserialize(in: InputStream): T = { + // called inside a try-finally where the underlying stream is closed in the caller + val inStream = serializer.deserializeStream(in) + inStream.readObject[T]() } /** @@ -114,7 +116,7 @@ class HDFSMetadataLog[T: ClassTag](sparkSession: SparkSession, path: String) // Only write metadata when the batch has not yet been written Thread.currentThread match { case ut: UninterruptibleThread => - ut.runUninterruptibly { writeBatch(batchId, serialize(metadata)) } + ut.runUninterruptibly { writeBatch(batchId, metadata, serialize) } case _ => throw new IllegalStateException( "HDFSMetadataLog.add() must be executed on a o.a.spark.util.UninterruptibleThread") @@ -129,7 +131,7 @@ class HDFSMetadataLog[T: ClassTag](sparkSession: SparkSession, path: String) * There may be multiple [[HDFSMetadataLog]] using the same metadata path. Although it is not a * valid behavior, we still need to prevent it from destroying the files. */ - private def writeBatch(batchId: Long, bytes: Array[Byte]): Unit = { + private def writeBatch(batchId: Long, metadata: T, writer: (T, OutputStream) => Unit): Unit = { // Use nextId to create a temp file var nextId = 0 while (true) { @@ -137,9 +139,9 @@ class HDFSMetadataLog[T: ClassTag](sparkSession: SparkSession, path: String) try { val output = fileManager.create(tempPath) try { - output.write(bytes) + writer(metadata, output) } finally { - output.close() + IOUtils.closeQuietly(output) } try { // Try to commit the batch @@ -193,10 +195,9 @@ class HDFSMetadataLog[T: ClassTag](sparkSession: SparkSession, path: String) if (fileManager.exists(batchMetadataFile)) { val input = fileManager.open(batchMetadataFile) try { - val bytes = IOUtils.toByteArray(input) - Some(deserialize(bytes)) + Some(deserialize(input)) } finally { - input.close() + IOUtils.closeQuietly(input) } } else { logDebug(s"Unable to find batch $batchMetadataFile") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/AggregateProcessor.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/AggregateProcessor.scala index d3a46d020dbbf..c9f5d3b3d92d7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/AggregateProcessor.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/AggregateProcessor.scala @@ -123,7 +123,7 @@ private[window] final class AggregateProcessor( private[this] val join = new JoinedRow private[this] val numImperatives = imperatives.length - private[this] val buffer = new SpecificMutableRow(bufferSchema.toSeq.map(_.dataType)) + private[this] val buffer = new SpecificInternalRow(bufferSchema.toSeq.map(_.dataType)) initialProjection.target(buffer) updateProjection.target(buffer) @@ -154,6 +154,6 @@ private[window] final class AggregateProcessor( } /** Evaluate buffer. 
*/ - def evaluate(target: MutableRow): Unit = + def evaluate(target: InternalRow): Unit = evaluateProjection.target(target)(buffer) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowExec.scala index 7a6a30f120386..1dd281ebf1034 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowExec.scala @@ -204,7 +204,7 @@ case class WindowExec( val factory = key match { // Offset Frame case ("OFFSET", RowFrame, Some(offset), Some(h)) if offset == h => - target: MutableRow => + target: InternalRow => new OffsetWindowFunctionFrame( target, ordinal, @@ -217,7 +217,7 @@ case class WindowExec( // Growing Frame. case ("AGGREGATE", frameType, None, Some(high)) => - target: MutableRow => { + target: InternalRow => { new UnboundedPrecedingWindowFunctionFrame( target, processor, @@ -226,7 +226,7 @@ case class WindowExec( // Shrinking Frame. case ("AGGREGATE", frameType, Some(low), None) => - target: MutableRow => { + target: InternalRow => { new UnboundedFollowingWindowFunctionFrame( target, processor, @@ -235,7 +235,7 @@ case class WindowExec( // Moving Frame. case ("AGGREGATE", frameType, Some(low), Some(high)) => - target: MutableRow => { + target: InternalRow => { new SlidingWindowFunctionFrame( target, processor, @@ -245,7 +245,7 @@ case class WindowExec( // Entire Partition Frame. case ("AGGREGATE", frameType, None, None) => - target: MutableRow => { + target: InternalRow => { new UnboundedWindowFunctionFrame(target, processor) } } @@ -312,7 +312,7 @@ case class WindowExec( val inputFields = child.output.length var sorter: UnsafeExternalSorter = null var rowBuffer: RowBuffer = null - val windowFunctionResult = new SpecificMutableRow(expressions.map(_.dataType)) + val windowFunctionResult = new SpecificInternalRow(expressions.map(_.dataType)) val frames = factories.map(_(windowFunctionResult)) val numFrames = frames.length private[this] def fetchNextPartition() { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowFunctionFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowFunctionFrame.scala index 2ab9faab7a59b..70efc0f78ddb0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowFunctionFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowFunctionFrame.scala @@ -56,7 +56,7 @@ private[window] abstract class WindowFunctionFrame { * @param offset by which rows get moved within a partition. */ private[window] final class OffsetWindowFunctionFrame( - target: MutableRow, + target: InternalRow, ordinal: Int, expressions: Array[OffsetWindowFunction], inputSchema: Seq[Attribute], @@ -136,7 +136,7 @@ private[window] final class OffsetWindowFunctionFrame( * @param ubound comparator used to identify the upper bound of an output row. */ private[window] final class SlidingWindowFunctionFrame( - target: MutableRow, + target: InternalRow, processor: AggregateProcessor, lbound: BoundOrdering, ubound: BoundOrdering) @@ -217,7 +217,7 @@ private[window] final class SlidingWindowFunctionFrame( * @param processor to calculate the row values with. 
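Each factory case above corresponds to a frame a user can request through the DataFrame API; for instance, a hedged example with placeholder column names that exercises the "Moving Frame" (SlidingWindowFunctionFrame) path:

{{{
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.avg

// ROWS BETWEEN 3 PRECEDING AND 3 FOLLOWING
val movingAvg = avg("amount").over(
  Window.partitionBy("country").orderBy("date").rowsBetween(-3, 3))
df.withColumn("moving_avg", movingAvg)
}}}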
 */ private[window] final class UnboundedWindowFunctionFrame( - target: MutableRow, + target: InternalRow, processor: AggregateProcessor) extends WindowFunctionFrame { @@ -255,7 +255,7 @@ private[window] final class UnboundedWindowFunctionFrame( * @param ubound comparator used to identify the upper bound of an output row. */ private[window] final class UnboundedPrecedingWindowFunctionFrame( - target: MutableRow, + target: InternalRow, processor: AggregateProcessor, ubound: BoundOrdering) extends WindowFunctionFrame { @@ -317,7 +317,7 @@ private[window] final class UnboundedPrecedingWindowFunctionFrame( * @param lbound comparator used to identify the lower bound of an output row. */ private[window] final class UnboundedFollowingWindowFunctionFrame( - target: MutableRow, + target: InternalRow, processor: AggregateProcessor, lbound: BoundOrdering) extends WindowFunctionFrame { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/Window.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/Window.scala index c29ec6f426789..3c1f6e897ea62 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/Window.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/Window.scala @@ -27,7 +27,8 @@ import org.apache.spark.sql.catalyst.expressions._ * * {{{ * // PARTITION BY country ORDER BY date ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW - * Window.partitionBy("country").orderBy("date").rowsBetween(Long.MinValue, 0) + * Window.partitionBy("country").orderBy("date") + * .rowsBetween(Window.unboundedPreceding, Window.currentRow) * * // PARTITION BY country ORDER BY date ROWS BETWEEN 3 PRECEDING AND 3 FOLLOWING * Window.partitionBy("country").orderBy("date").rowsBetween(-3, 3) @@ -42,7 +43,7 @@ object Window { * Creates a [[WindowSpec]] with the partitioning defined. * @since 1.4.0 */ - @_root_.scala.annotation.varargs + @scala.annotation.varargs def partitionBy(colName: String, colNames: String*): WindowSpec = { spec.partitionBy(colName, colNames : _*) } @@ -51,7 +52,7 @@ object Window { * Creates a [[WindowSpec]] with the partitioning defined. * @since 1.4.0 */ - @_root_.scala.annotation.varargs + @scala.annotation.varargs def partitionBy(cols: Column*): WindowSpec = { spec.partitionBy(cols : _*) } @@ -60,7 +61,7 @@ object Window { * Creates a [[WindowSpec]] with the ordering defined. * @since 1.4.0 */ - @_root_.scala.annotation.varargs + @scala.annotation.varargs def orderBy(colName: String, colNames: String*): WindowSpec = { spec.orderBy(colName, colNames : _*) } @@ -69,11 +70,92 @@ object Window { * Creates a [[WindowSpec]] with the ordering defined. * @since 1.4.0 */ - @_root_.scala.annotation.varargs + @scala.annotation.varargs def orderBy(cols: Column*): WindowSpec = { spec.orderBy(cols : _*) } + /** + * Value representing the first row in the partition, equivalent to "UNBOUNDED PRECEDING" in SQL. + * This can be used to specify the frame boundaries: + * + * {{{ + * Window.rowsBetween(Window.unboundedPreceding, Window.currentRow) + * }}} + * + * @since 2.1.0 + */ + def unboundedPreceding: Long = Long.MinValue + + /** + * Value representing the last row in the partition, equivalent to "UNBOUNDED FOLLOWING" in SQL. + * This can be used to specify the frame boundaries: + * + * {{{ + * Window.rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing) + * }}} + * + * @since 2.1.0 + */ + def unboundedFollowing: Long = Long.MaxValue + + /** + * Value representing the current row.
This can be used to specify the frame boundaries: + * + * {{{ + * Window.rowsBetween(Window.unboundedPreceding, Window.currentRow) + * }}} + * + * @since 2.1.0 + */ + def currentRow: Long = 0 + + /** + * Creates a [[WindowSpec]] with the frame boundaries defined, + * from `start` (inclusive) to `end` (inclusive). + * + * Both `start` and `end` are relative positions from the current row. For example, "0" means + * "current row", while "-1" means the row before the current row, and "5" means the fifth row + * after the current row. + * + * We recommend users use [[Window.unboundedPreceding]], [[Window.unboundedFollowing]], + * and [[Window.currentRow]] to specify special boundary values, rather than using integral + * values directly. + * + * @param start boundary start, inclusive. The frame is unbounded if this is + * the minimum long value ([[Window.unboundedPreceding]]). + * @param end boundary end, inclusive. The frame is unbounded if this is the + * maximum long value ([[Window.unboundedFollowing]]). + * @since 2.1.0 + */ + // Note: when updating the doc for this method, also update WindowSpec.rowsBetween. + def rowsBetween(start: Long, end: Long): WindowSpec = { + spec.rowsBetween(start, end) + } + + /** + * Creates a [[WindowSpec]] with the frame boundaries defined, + * from `start` (inclusive) to `end` (inclusive). + * + * Both `start` and `end` are relative from the current row. For example, "0" means "current row", + * while "-1" means one off before the current row, and "5" means the five off after the + * current row. + * + * We recommend users use [[Window.unboundedPreceding]], [[Window.unboundedFollowing]], + * and [[Window.currentRow]] to specify special boundary values, rather than using integral + * values directly. + * + * @param start boundary start, inclusive. The frame is unbounded if this is + * the minimum long value ([[Window.unboundedPreceding]]). + * @param end boundary end, inclusive. The frame is unbounded if this is the + * maximum long value ([[Window.unboundedFollowing]]). + * @since 2.1.0 + */ + // Note: when updating the doc for this method, also update WindowSpec.rangeBetween. + def rangeBetween(start: Long, end: Long): WindowSpec = { + spec.rangeBetween(start, end) + } + private[sql] def spec: WindowSpec = { new WindowSpec(Seq.empty, Seq.empty, UnspecifiedFrame) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala index d716da2668675..8ebed399bf2d0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala @@ -39,7 +39,7 @@ class WindowSpec private[sql]( * Defines the partitioning columns in a [[WindowSpec]]. * @since 1.4.0 */ - @_root_.scala.annotation.varargs + @scala.annotation.varargs def partitionBy(colName: String, colNames: String*): WindowSpec = { partitionBy((colName +: colNames).map(Column(_)): _*) } @@ -48,7 +48,7 @@ class WindowSpec private[sql]( * Defines the partitioning columns in a [[WindowSpec]]. * @since 1.4.0 */ - @_root_.scala.annotation.varargs + @scala.annotation.varargs def partitionBy(cols: Column*): WindowSpec = { new WindowSpec(cols.map(_.expr), orderSpec, frame) } @@ -57,7 +57,7 @@ class WindowSpec private[sql]( * Defines the ordering columns in a [[WindowSpec]]. 
* @since 1.4.0 */ - @_root_.scala.annotation.varargs + @scala.annotation.varargs def orderBy(colName: String, colNames: String*): WindowSpec = { orderBy((colName +: colNames).map(Column(_)): _*) } @@ -66,7 +66,7 @@ class WindowSpec private[sql]( * Defines the ordering columns in a [[WindowSpec]]. * @since 1.4.0 */ - @_root_.scala.annotation.varargs + @scala.annotation.varargs def orderBy(cols: Column*): WindowSpec = { val sortOrder: Seq[SortOrder] = cols.map { col => col.expr match { @@ -86,12 +86,17 @@ class WindowSpec private[sql]( * "current row", while "-1" means the row before the current row, and "5" means the fifth row * after the current row. * - * @param start boundary start, inclusive. - * The frame is unbounded if this is the minimum long value. - * @param end boundary end, inclusive. - * The frame is unbounded if this is the maximum long value. + * We recommend users use [[Window.unboundedPreceding]], [[Window.unboundedFollowing]], + * and [[Window.currentRow]] to specify special boundary values, rather than using integral + * values directly. + * + * @param start boundary start, inclusive. The frame is unbounded if this is + * the minimum long value ([[Window.unboundedPreceding]]). + * @param end boundary end, inclusive. The frame is unbounded if this is the + * maximum long value ([[Window.unboundedFollowing]]). * @since 1.4.0 */ + // Note: when updating the doc for this method, also update Window.rowsBetween. def rowsBetween(start: Long, end: Long): WindowSpec = { between(RowFrame, start, end) } @@ -103,12 +108,17 @@ class WindowSpec private[sql]( * while "-1" means one off before the current row, and "5" means the five off after the * current row. * - * @param start boundary start, inclusive. - * The frame is unbounded if this is the minimum long value. - * @param end boundary end, inclusive. - * The frame is unbounded if this is the maximum long value. + * We recommend users use [[Window.unboundedPreceding]], [[Window.unboundedFollowing]], + * and [[Window.currentRow]] to specify special boundary values, rather than using integral + * values directly. + * + * @param start boundary start, inclusive. The frame is unbounded if this is + * the minimum long value ([[Window.unboundedPreceding]]). + * @param end boundary end, inclusive. The frame is unbounded if this is the + * maximum long value ([[Window.unboundedFollowing]]). * @since 1.4.0 */ + // Note: when updating the doc for this method, also update Window.rangeBetween. def rangeBetween(start: Long, end: Long): WindowSpec = { between(RangeFrame, start, end) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/udaf.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/udaf.scala index eac658c6176cb..5417a0e481158 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/udaf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/udaf.scala @@ -106,7 +106,7 @@ abstract class UserDefinedAggregateFunction extends Serializable { /** * Creates a [[Column]] for this UDAF using given [[Column]]s as input arguments. */ - @_root_.scala.annotation.varargs + @scala.annotation.varargs def apply(exprs: Column*): Column = { val aggregateExpression = AggregateExpression( @@ -120,7 +120,7 @@ abstract class UserDefinedAggregateFunction extends Serializable { * Creates a [[Column]] for this UDAF using the distinct values of the given * [[Column]]s as input arguments. 
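// Illustration of the rows-vs-range distinction the docs above describe (same assumed `spark`
// and toy `df` as in the previous sketch).
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.sum

// ROWS BETWEEN 1 PRECEDING AND CURRENT ROW: the physically previous row plus the current row.
val byRows  = Window.orderBy("value").rowsBetween(-1, 0)
// RANGE BETWEEN 1 PRECEDING AND CURRENT ROW: every row whose `value` lies in [value - 1, value].
val byRange = Window.orderBy("value").rangeBetween(-1, 0)

df.select(sum("value").over(byRows), sum("value").over(byRange)).show()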
*/ - @_root_.scala.annotation.varargs + @scala.annotation.varargs def distinct(exprs: Column*): Column = { val aggregateExpression = AggregateExpression( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 3bc1c5b90031d..de4943152720c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -22,7 +22,7 @@ import scala.language.implicitConversions import scala.reflect.runtime.universe.{typeTag, TypeTag} import scala.util.Try -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.{Experimental, InterfaceStability} import org.apache.spark.sql.catalyst.ScalaReflection import org.apache.spark.sql.catalyst.analysis.{Star, UnresolvedFunction} import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder @@ -38,7 +38,7 @@ import org.apache.spark.util.Utils /** * :: Experimental :: - * Functions available for [[DataFrame]]. + * Functions available for DataFrame operations. * * @groupname udf_funcs UDF functions * @groupname agg_funcs Aggregate functions @@ -54,6 +54,7 @@ import org.apache.spark.util.Utils * @since 1.3.0 */ @Experimental +@InterfaceStability.Evolving // scalastyle:off object functions { // scalastyle:on @@ -182,13 +183,43 @@ object functions { // Aggregate functions ////////////////////////////////////////////////////////////////////////////////////////////// + /** + * @group agg_funcs + * @since 1.3.0 + */ + @deprecated("Use approx_count_distinct", "2.1.0") + def approxCountDistinct(e: Column): Column = approx_count_distinct(e) + + /** + * @group agg_funcs + * @since 1.3.0 + */ + @deprecated("Use approx_count_distinct", "2.1.0") + def approxCountDistinct(columnName: String): Column = approx_count_distinct(columnName) + + /** + * @group agg_funcs + * @since 1.3.0 + */ + @deprecated("Use approx_count_distinct", "2.1.0") + def approxCountDistinct(e: Column, rsd: Double): Column = approx_count_distinct(e, rsd) + + /** + * @group agg_funcs + * @since 1.3.0 + */ + @deprecated("Use approx_count_distinct", "2.1.0") + def approxCountDistinct(columnName: String, rsd: Double): Column = { + approx_count_distinct(Column(columnName), rsd) + } + /** * Aggregate function: returns the approximate number of distinct items in a group. * * @group agg_funcs - * @since 1.3.0 + * @since 2.1.0 */ - def approxCountDistinct(e: Column): Column = withAggregateFunction { + def approx_count_distinct(e: Column): Column = withAggregateFunction { HyperLogLogPlusPlus(e.expr) } @@ -196,9 +227,9 @@ object functions { * Aggregate function: returns the approximate number of distinct items in a group. * * @group agg_funcs - * @since 1.3.0 + * @since 2.1.0 */ - def approxCountDistinct(columnName: String): Column = approxCountDistinct(column(columnName)) + def approx_count_distinct(columnName: String): Column = approx_count_distinct(column(columnName)) /** * Aggregate function: returns the approximate number of distinct items in a group. 
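// A minimal, hypothetical UDAF, shown only to illustrate how the `apply` and `distinct` entry
// points above are called; it is not part of this patch.
import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types.{DataType, LongType, StructType}

object LongSum extends UserDefinedAggregateFunction {
  def inputSchema: StructType = new StructType().add("value", LongType)
  def bufferSchema: StructType = new StructType().add("sum", LongType)
  def dataType: DataType = LongType
  def deterministic: Boolean = true
  def initialize(buffer: MutableAggregationBuffer): Unit = buffer(0) = 0L
  def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    if (!input.isNullAt(0)) buffer(0) = buffer.getLong(0) + input.getLong(0)
  }
  def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    buffer1(0) = buffer1.getLong(0) + buffer2.getLong(0)
  }
  def evaluate(buffer: Row): Any = buffer.getLong(0)
}

// `apply` builds a Column over all input values, `distinct` over the distinct values:
//   df.groupBy("key").agg(LongSum(df("value")), LongSum.distinct(df("value")))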
@@ -206,9 +237,9 @@ object functions { * @param rsd maximum estimation error allowed (default = 0.05) * * @group agg_funcs - * @since 1.3.0 + * @since 2.1.0 */ - def approxCountDistinct(e: Column, rsd: Double): Column = withAggregateFunction { + def approx_count_distinct(e: Column, rsd: Double): Column = withAggregateFunction { HyperLogLogPlusPlus(e.expr, rsd, 0, 0) } @@ -218,10 +249,10 @@ object functions { * @param rsd maximum estimation error allowed (default = 0.05) * * @group agg_funcs - * @since 1.3.0 + * @since 2.1.0 */ - def approxCountDistinct(columnName: String, rsd: Double): Column = { - approxCountDistinct(Column(columnName), rsd) + def approx_count_distinct(columnName: String, rsd: Double): Column = { + approx_count_distinct(Column(columnName), rsd) } /** @@ -1949,37 +1980,65 @@ object functions { */ def tanh(columnName: String): Column = tanh(Column(columnName)) + /** + * @group math_funcs + * @since 1.4.0 + */ + @deprecated("Use degrees", "2.1.0") + def toDegrees(e: Column): Column = degrees(e) + + /** + * @group math_funcs + * @since 1.4.0 + */ + @deprecated("Use degrees", "2.1.0") + def toDegrees(columnName: String): Column = degrees(Column(columnName)) + /** * Converts an angle measured in radians to an approximately equivalent angle measured in degrees. * * @group math_funcs - * @since 1.4.0 + * @since 2.1.0 */ - def toDegrees(e: Column): Column = withExpr { ToDegrees(e.expr) } + def degrees(e: Column): Column = withExpr { ToDegrees(e.expr) } /** * Converts an angle measured in radians to an approximately equivalent angle measured in degrees. * + * @group math_funcs + * @since 2.1.0 + */ + def degrees(columnName: String): Column = degrees(Column(columnName)) + + /** + * @group math_funcs + * @since 1.4.0 + */ + @deprecated("Use radians", "2.1.0") + def toRadians(e: Column): Column = radians(e) + + /** * @group math_funcs * @since 1.4.0 */ - def toDegrees(columnName: String): Column = toDegrees(Column(columnName)) + @deprecated("Use radians", "2.1.0") + def toRadians(columnName: String): Column = radians(Column(columnName)) /** * Converts an angle measured in degrees to an approximately equivalent angle measured in radians. * * @group math_funcs - * @since 1.4.0 + * @since 2.1.0 */ - def toRadians(e: Column): Column = withExpr { ToRadians(e.expr) } + def radians(e: Column): Column = withExpr { ToRadians(e.expr) } /** * Converts an angle measured in degrees to an approximately equivalent angle measured in radians. 
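// Usage sketch for the renamed functions (assumes `spark` and a DataFrame `df` with a numeric
// `value` column). The camelCase names still compile but now raise deprecation warnings.
import org.apache.spark.sql.functions.{approx_count_distinct, degrees, lit, radians}

df.select(
  approx_count_distinct("value"),           // default rsd = 0.05
  approx_count_distinct(df("value"), 0.01), // tighter error bound
  degrees(lit(math.Pi)),                    // ~= 180.0
  radians(lit(180.0)))                      // ~= Pi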
* * @group math_funcs - * @since 1.4.0 + * @since 2.1.0 */ - def toRadians(columnName: String): Column = toRadians(Column(columnName)) + def radians(columnName: String): Column = radians(Column(columnName)) ////////////////////////////////////////////////////////////////////////////////////////////// // Misc functions @@ -2672,6 +2731,7 @@ object functions { * @since 2.0.0 */ @Experimental + @InterfaceStability.Evolving def window( timeColumn: Column, windowDuration: String, @@ -2725,6 +2785,7 @@ object functions { * @since 2.0.0 */ @Experimental + @InterfaceStability.Evolving def window(timeColumn: Column, windowDuration: String, slideDuration: String): Column = { window(timeColumn, windowDuration, slideDuration, "0 second") } @@ -2763,6 +2824,7 @@ object functions { * @since 2.0.0 */ @Experimental + @InterfaceStability.Evolving def window(timeColumn: Column, windowDuration: String): Column = { window(timeColumn, windowDuration, windowDuration, "0 second") } @@ -3096,5 +3158,4 @@ object functions { def callUDF(udfName: String, cols: Column*): Column = withExpr { UnresolvedFunction(udfName, cols.map(_.expr), isDistinct = false) } - } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala index e412e1b4b302a..f6c297e91b7c5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala @@ -94,20 +94,19 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog { */ @throws[AnalysisException]("database does not exist") override def listTables(dbName: String): Dataset[Table] = { - requireDatabaseExists(dbName) val tables = sessionCatalog.listTables(dbName).map(makeTable) CatalogImpl.makeDataset(tables, sparkSession) } private def makeTable(tableIdent: TableIdentifier): Table = { val metadata = sessionCatalog.getTempViewOrPermanentTableMetadata(tableIdent) - val database = metadata.identifier.database + val isTemp = sessionCatalog.isTemporaryTable(tableIdent) new Table( name = tableIdent.table, - database = database.orNull, + database = metadata.identifier.database.orNull, description = metadata.comment.orNull, - tableType = if (database.isEmpty) "TEMPORARY" else metadata.tableType.name, - isTemporary = database.isEmpty) + tableType = if (isTemp) "TEMPORARY" else metadata.tableType.name, + isTemporary = isTemp) } /** @@ -365,20 +364,35 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog { } /** - * Drops the temporary view with the given view name in the catalog. + * Drops the local temporary view with the given view name in the catalog. * If the view has been cached/persisted before, it's also unpersisted. * * @param viewName the name of the view to be dropped. * @group ddl_ops * @since 2.0.0 */ - override def dropTempView(viewName: String): Unit = { - sparkSession.sessionState.catalog.getTempView(viewName).foreach { tempView => + override def dropTempView(viewName: String): Boolean = { + sparkSession.sessionState.catalog.getTempView(viewName).exists { tempView => sparkSession.sharedState.cacheManager.uncacheQuery(Dataset.ofRows(sparkSession, tempView)) sessionCatalog.dropTempView(viewName) } } + /** + * Drops the global temporary view with the given view name in the catalog. + * If the view has been cached/persisted before, it's also unpersisted. + * + * @param viewName the name of the view to be dropped. 
+ * @group ddl_ops + * @since 2.1.0 + */ + override def dropGlobalTempView(viewName: String): Boolean = { + sparkSession.sessionState.catalog.getGlobalTempView(viewName).exists { viewDef => + sparkSession.sharedState.cacheManager.uncacheQuery(Dataset.ofRows(sparkSession, viewDef)) + sessionCatalog.dropGlobalTempView(viewName) + } + } + /** * Returns true if the table is currently cached in-memory. * diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index fecdf792fd14a..192083e2ea5f5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -41,7 +41,7 @@ object SQLConf { private val sqlConfEntries = java.util.Collections.synchronizedMap( new java.util.HashMap[String, ConfigEntry[_]]()) - private def register(entry: ConfigEntry[_]): Unit = sqlConfEntries.synchronized { + private[sql] def register(entry: ConfigEntry[_]): Unit = sqlConfEntries.synchronized { require(!sqlConfEntries.containsKey(entry.key), s"Duplicate SQLConfigEntry. ${entry.key} has been registered") sqlConfEntries.put(entry.key, entry) @@ -326,18 +326,6 @@ object SQLConf { .booleanConf .createWithDefault(true) - // This is used to control the when we will split a schema's JSON string to multiple pieces - // in order to fit the JSON string in metastore's table property (by default, the value has - // a length restriction of 4000 characters). We will split the JSON string of a schema - // to its length exceeds the threshold. - val SCHEMA_STRING_LENGTH_THRESHOLD = - SQLConfigBuilder("spark.sql.sources.schemaStringLengthThreshold") - .doc("The maximum length allowed in a single cell when " + - "storing additional schema information in Hive's metastore.") - .internal() - .intConf - .createWithDefault(4000) - val PARTITION_COLUMN_TYPE_INFERENCE = SQLConfigBuilder("spark.sql.sources.partitionColumnTypeInference.enabled") .doc("When true, automatically infer the data types for partitioned columns.") @@ -588,6 +576,12 @@ object SQLConf { .doubleConf .createWithDefault(0.05) + val IGNORE_CORRUPT_FILES = SQLConfigBuilder("spark.sql.files.ignoreCorruptFiles") + .doc("Whether to ignore corrupt files. If true, the Spark jobs will continue to run when " + + "encountering corrupt files and contents that have been read will still be returned.") + .booleanConf + .createWithDefault(false) + object Deprecated { val MAPRED_REDUCE_TASKS = "mapred.reduce.tasks" } @@ -736,10 +730,6 @@ private[sql] class SQLConf extends Serializable with CatalystConf with Logging { def bucketingEnabled: Boolean = getConf(SQLConf.BUCKETING_ENABLED) - // Do not use a value larger than 4000 as the default value of this property. - // See the comments of SCHEMA_STRING_LENGTH_THRESHOLD above for more information. 
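// Sketch of the new flag in use (the path is illustrative, assuming a SparkSession `spark`).
// With the flag enabled, file-based reads keep whatever was read successfully instead of
// failing the whole job on a corrupt file.
spark.conf.set("spark.sql.files.ignoreCorruptFiles", "true")
val events = spark.read.parquet("/data/events")   // corrupt part-files are skipped, not fatal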
- def schemaStringLengthThreshold: Int = getConf(SCHEMA_STRING_LENGTH_THRESHOLD) - def dataFrameEagerAnalysis: Boolean = getConf(DATAFRAME_EAGER_ANALYSIS) def dataFrameSelfJoinAutoResolveAmbiguity: Boolean = @@ -759,6 +749,8 @@ private[sql] class SQLConf extends Serializable with CatalystConf with Logging { def warehousePath: String = new Path(getConf(WAREHOUSE_PATH)).toString + def ignoreCorruptFiles: Boolean = getConf(IGNORE_CORRUPT_FILES) + override def orderByOrdinal: Boolean = getConf(ORDER_BY_ORDINAL) override def groupByOrdinal: Boolean = getConf(GROUP_BY_ORDINAL) @@ -886,3 +878,46 @@ private[sql] class SQLConf extends Serializable with CatalystConf with Logging { } } +/** + * Static SQL configuration is a cross-session, immutable Spark configuration. External users can + * see the static sql configs via `SparkSession.conf`, but can NOT set/unset them. + */ +object StaticSQLConf { + val globalConfKeys = java.util.Collections.synchronizedSet(new java.util.HashSet[String]()) + + private def buildConf(key: String): ConfigBuilder = { + ConfigBuilder(key).onCreate { entry => + globalConfKeys.add(entry.key) + SQLConf.register(entry) + } + } + + val CATALOG_IMPLEMENTATION = buildConf("spark.sql.catalogImplementation") + .internal() + .stringConf + .checkValues(Set("hive", "in-memory")) + .createWithDefault("in-memory") + + val GLOBAL_TEMP_DATABASE = buildConf("spark.sql.globalTempDatabase") + .internal() + .stringConf + .createWithDefault("global_temp") + + // This is used to control when we will split a schema's JSON string to multiple pieces + // in order to fit the JSON string in metastore's table property (by default, the value has + // a length restriction of 4000 characters, so do not use a value larger than 4000 as the default + // value of this property). We will split the JSON string of a schema to its length exceeds the + // threshold. Note that, this conf is only read in HiveExternalCatalog which is cross-session, + // that's why this conf has to be a static SQL conf. 
+ val SCHEMA_STRING_LENGTH_THRESHOLD = buildConf("spark.sql.sources.schemaStringLengthThreshold") + .doc("The maximum length allowed in a single cell when " + + "storing additional schema information in Hive's metastore.") + .internal() + .intConf + .createWithDefault(4000) + + val DEBUG_MODE = buildConf("spark.sql.debug") + .internal() + .booleanConf + .createWithDefault(false) +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala index 9f7d0019c6b92..8759dfe39ce1d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala @@ -95,6 +95,7 @@ private[sql] class SessionState(sparkSession: SparkSession) { */ lazy val catalog = new SessionCatalog( sparkSession.sharedState.externalCatalog, + sparkSession.sharedState.globalTempViewManager, functionResourceLoader, functionRegistry, conf, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SharedState.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SharedState.scala index 6387f0150631c..c6083b372a2db 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SharedState.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SharedState.scala @@ -22,13 +22,14 @@ import scala.util.control.NonFatal import org.apache.hadoop.conf.Configuration -import org.apache.spark.{SparkConf, SparkContext} +import org.apache.spark.{SparkConf, SparkContext, SparkException} import org.apache.spark.internal.config._ import org.apache.spark.internal.Logging import org.apache.spark.sql.{SparkSession, SQLContext} -import org.apache.spark.sql.catalyst.catalog.{ExternalCatalog, InMemoryCatalog} +import org.apache.spark.sql.catalyst.catalog.{ExternalCatalog, GlobalTempViewManager, InMemoryCatalog} import org.apache.spark.sql.execution.CacheManager import org.apache.spark.sql.execution.ui.{SQLListener, SQLTab} +import org.apache.spark.sql.internal.StaticSQLConf._ import org.apache.spark.util.{MutableURLClassLoader, Utils} @@ -37,39 +38,14 @@ import org.apache.spark.util.{MutableURLClassLoader, Utils} */ private[sql] class SharedState(val sparkContext: SparkContext) extends Logging { - /** - * Class for caching query results reused in future executions. - */ - val cacheManager: CacheManager = new CacheManager - - /** - * A listener for SQL-specific [[org.apache.spark.scheduler.SparkListenerEvent]]s. - */ - val listener: SQLListener = createListenerAndUI(sparkContext) - + // Load hive-site.xml into hadoopConf and determine the warehouse path we want to use, based on + // the config from both hive and Spark SQL. Finally set the warehouse config value to sparkConf. { val configFile = Utils.getContextOrSparkClassLoader.getResource("hive-site.xml") if (configFile != null) { sparkContext.hadoopConfiguration.addResource(configFile) } - } - - /** - * A catalog that interacts with external systems. - */ - lazy val externalCatalog: ExternalCatalog = - SharedState.reflect[ExternalCatalog, SparkConf, Configuration]( - SharedState.externalCatalogClassName(sparkContext.conf), - sparkContext.conf, - sparkContext.hadoopConfiguration) - - /** - * A classloader used to load all user-added jar. 
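// Sketch (assuming a running SparkSession `spark`): static confs can be read through the session
// conf, but changing them only takes effect when supplied before the session is built.
val globalTempDb = spark.conf.get("spark.sql.globalTempDatabase")   // "global_temp" by default
// e.g. SparkSession.builder().config("spark.sql.globalTempDatabase", "my_global_db").getOrCreate()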
- */ - val jarClassLoader = new NonClosableMutableURLClassLoader( - org.apache.spark.util.Utils.getContextOrSparkClassLoader) - { // Set the Hive metastore warehouse path to the one we use val tempConf = new SQLConf sparkContext.conf.getAll.foreach { case (k, v) => tempConf.setConfString(k, v) } @@ -93,6 +69,48 @@ private[sql] class SharedState(val sparkContext: SparkContext) extends Logging { logInfo(s"Warehouse path is '${tempConf.warehousePath}'.") } + /** + * Class for caching query results reused in future executions. + */ + val cacheManager: CacheManager = new CacheManager + + /** + * A listener for SQL-specific [[org.apache.spark.scheduler.SparkListenerEvent]]s. + */ + val listener: SQLListener = createListenerAndUI(sparkContext) + + /** + * A catalog that interacts with external systems. + */ + val externalCatalog: ExternalCatalog = + SharedState.reflect[ExternalCatalog, SparkConf, Configuration]( + SharedState.externalCatalogClassName(sparkContext.conf), + sparkContext.conf, + sparkContext.hadoopConfiguration) + + /** + * A manager for global temporary views. + */ + val globalTempViewManager = { + // System preserved database should not exists in metastore. However it's hard to guarantee it + // for every session, because case-sensitivity differs. Here we always lowercase it to make our + // life easier. + val globalTempDB = sparkContext.conf.get(GLOBAL_TEMP_DATABASE).toLowerCase + if (externalCatalog.databaseExists(globalTempDB)) { + throw new SparkException( + s"$globalTempDB is a system preserved database, please rename your existing database " + + "to resolve the name conflict, or set a different value for " + + s"${GLOBAL_TEMP_DATABASE.key}, and launch your Spark application again.") + } + new GlobalTempViewManager(globalTempDB) + } + + /** + * A classloader used to load all user-added jar. + */ + val jarClassLoader = new NonClosableMutableURLClassLoader( + org.apache.spark.util.Utils.getContextOrSparkClassLoader) + /** * Create a SQLListener then add it into SparkContext, and create a SQLTab if there is SparkUI. 
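// Usage sketch for the global temp views managed above (assumes `spark`; "global_temp" is the
// default value of spark.sql.globalTempDatabase).
import spark.implicits._

Seq(1 -> "a").toDF("i", "j").createGlobalTempView("people")

// Global temp views live in the reserved database and must be referenced with a qualified name:
spark.sql("SELECT * FROM global_temp.people").show()
// ...and they are visible to every session that shares the same SparkContext:
spark.newSession().table("global_temp.people").show()

// Like dropTempView, dropGlobalTempView now reports whether a view was actually dropped:
val dropped: Boolean = spark.catalog.dropGlobalTempView("people")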
*/ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala index 3f540d6258a0d..4f61a328f47ca 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.jdbc import java.sql.{Connection, Types} -import org.apache.spark.sql.execution.datasources.jdbc.JdbcUtils +import org.apache.spark.sql.execution.datasources.jdbc.{JDBCOptions, JdbcUtils} import org.apache.spark.sql.types._ @@ -94,7 +94,7 @@ private object PostgresDialect extends JdbcDialect { // // See: https://jdbc.postgresql.org/documentation/head/query.html#query-with-cursor // - if (properties.getOrElse(JdbcUtils.JDBC_BATCH_FETCH_SIZE, "0").toInt > 0) { + if (properties.getOrElse(JDBCOptions.JDBC_BATCH_FETCH_SIZE, "0").toInt > 0) { connection.setAutoCommit(false) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/package.scala index 28d8bc3de68b8..161e0102f0b43 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/package.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/package.scala @@ -17,8 +17,8 @@ package org.apache.spark -import org.apache.spark.annotation.DeveloperApi -import org.apache.spark.sql.execution.{SparkPlan, SparkStrategy} +import org.apache.spark.annotation.{DeveloperApi, InterfaceStability} +import org.apache.spark.sql.execution.SparkStrategy /** * Allows the execution of relational queries, including those expressed in SQL using Spark. @@ -40,6 +40,7 @@ package object sql { * [[org.apache.spark.sql.sources]] */ @DeveloperApi + @InterfaceStability.Unstable type Strategy = SparkStrategy type DataFrame = Dataset[Row] diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala index 864a9cd3eb89d..87b73062180e4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala @@ -21,7 +21,7 @@ import scala.collection.JavaConverters._ import org.apache.spark.annotation.Experimental import org.apache.spark.internal.Logging -import org.apache.spark.sql.{DataFrame, Dataset, SparkSession} +import org.apache.spark.sql.{AnalysisException, DataFrame, Dataset, SparkSession} import org.apache.spark.sql.execution.datasources.DataSource import org.apache.spark.sql.execution.streaming.StreamingRelation import org.apache.spark.sql.types.StructType @@ -283,6 +283,37 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo */ def text(path: String): DataFrame = format("text").load(path) + /** + * Loads text file(s) and returns a [[Dataset]] of String. The underlying schema of the Dataset + * contains a single string column named "value". + * + * If the directory structure of the text files contains partitioning information, those are + * ignored in the resulting Dataset. To include partitioning information as columns, use `text`. + * + * Each line in the text file is a new element in the resulting Dataset. For example: + * {{{ + * // Scala: + * spark.readStream.textFile("/path/to/spark/README.md") + * + * // Java: + * spark.readStream().textFile("/path/to/spark/README.md") + * }}} + * + * You can set the following text-specific options to deal with text files: + *
+ * <ul>
+ * <li>`maxFilesPerTrigger` (default: no max limit): sets the maximum number of new files to be
+ * considered in every trigger.</li>
+ * </ul>
+ * + * @param path input path + * @since 2.1.0 + */ + def textFile(path: String): Dataset[String] = { + if (userSpecifiedSchema.nonEmpty) { + throw new AnalysisException("User specified schema not supported with `textFile`") + } + text(path).select("value").as[String](sparkSession.implicits.newStringEncoder) + } /////////////////////////////////////////////////////////////////////////////////////// // Builder pattern config options diff --git a/sql/core/src/main/scala/org/apache/spark/sql/util/QueryExecutionListener.scala b/sql/core/src/main/scala/org/apache/spark/sql/util/QueryExecutionListener.scala index 3cae5355eecc6..5e93fc469a41f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/util/QueryExecutionListener.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/util/QueryExecutionListener.scala @@ -22,7 +22,7 @@ import java.util.concurrent.locks.ReentrantReadWriteLock import scala.collection.mutable.ListBuffer import scala.util.control.NonFatal -import org.apache.spark.annotation.{DeveloperApi, Experimental} +import org.apache.spark.annotation.{DeveloperApi, Experimental, InterfaceStability} import org.apache.spark.internal.Logging import org.apache.spark.sql.execution.QueryExecution @@ -34,6 +34,7 @@ import org.apache.spark.sql.execution.QueryExecution * multiple different threads. */ @Experimental +@InterfaceStability.Evolving trait QueryExecutionListener { /** @@ -68,6 +69,7 @@ trait QueryExecutionListener { * Manager for [[QueryExecutionListener]]. See [[org.apache.spark.sql.SQLContext.listenerManager]]. */ @Experimental +@InterfaceStability.Evolving class ExecutionListenerManager private[sql] () extends Logging { /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala similarity index 95% rename from sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowSuite.scala rename to sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala index c6f8c3ad3fc93..1255c49104718 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala @@ -22,7 +22,10 @@ import org.apache.spark.sql.functions._ import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types.{DataType, LongType, StructType} -class DataFrameWindowSuite extends QueryTest with SharedSQLContext { +/** + * Window function testing for DataFrame API. 
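// Sketch of the new streaming textFile reader (paths illustrative; assumes `spark`). Options
// such as maxFilesPerTrigger are set before calling textFile, exactly as for text().
val lines = spark.readStream
  .option("maxFilesPerTrigger", "1")
  .textFile("/data/incoming")

val query = lines.filter(_.nonEmpty)
  .writeStream
  .format("console")
  .start()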
+ */ +class DataFrameWindowFunctionsSuite extends QueryTest with SharedSQLContext { import testImplicits._ test("reuse window partitionBy") { @@ -47,6 +50,16 @@ class DataFrameWindowSuite extends QueryTest with SharedSQLContext { Row(1, "1") :: Row(2, "2") :: Row(null, null) :: Row(null, null) :: Nil) } + test("Window.rowsBetween") { + val df = Seq(("one", 1), ("two", 2)).toDF("key", "value") + // Running (cumulative) sum + checkAnswer( + df.select('key, sum("value").over( + Window.rowsBetween(Window.unboundedPreceding, Window.currentRow))), + Row("one", 1) :: Row("two", 3) :: Nil + ) + } + test("lead") { val df = Seq((1, "1"), (2, "2"), (1, "1"), (2, "2")).toDF("key", "value") df.createOrReplaceTempView("window_table") @@ -144,9 +157,11 @@ class DataFrameWindowSuite extends QueryTest with SharedSQLContext { df.select( $"key", last("key").over( - Window.partitionBy($"value").orderBy($"key").rowsBetween(0, Long.MaxValue)), + Window.partitionBy($"value").orderBy($"key") + .rowsBetween(Window.currentRow, Window.unboundedFollowing)), last("key").over( - Window.partitionBy($"value").orderBy($"key").rowsBetween(Long.MinValue, 0)), + Window.partitionBy($"value").orderBy($"key") + .rowsBetween(Window.unboundedPreceding, Window.currentRow)), last("key").over(Window.partitionBy($"value").orderBy($"key").rowsBetween(-1, 1))), Seq(Row(1, 1, 1, 1), Row(2, 3, 2, 3), Row(3, 3, 3, 3), Row(1, 4, 1, 2), Row(2, 4, 2, 4), Row(4, 4, 4, 4))) @@ -228,7 +243,7 @@ class DataFrameWindowSuite extends QueryTest with SharedSQLContext { $"key", var_pop($"value").over(window), var_samp($"value").over(window), - approxCountDistinct($"value").over(window)), + approx_count_distinct($"value").over(window)), Seq.fill(4)(Row("a", 1.0d / 4.0d, 1.0d / 3.0d, 2)) ++ Seq.fill(3)(Row("b", 2.0d / 3.0d, 1.0d, 3))) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index 3243f352a5337..5fce9b4fe97ea 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -872,6 +872,23 @@ class DatasetSuite extends QueryTest with SharedSQLContext { ("a", 1), ("a", 2), ("b", 1)) } + test("dropDuplicates: columns with same column name") { + val ds1 = Seq(("a", 1), ("a", 2), ("b", 1), ("a", 1)).toDS() + val ds2 = Seq(("a", 1), ("a", 2), ("b", 1), ("a", 1)).toDS() + // The dataset joined has two columns of the same name "_2". 
+ val joined = ds1.join(ds2, "_1").select(ds1("_2").as[Int], ds2("_2").as[Int]) + checkDataset( + joined.dropDuplicates(), + (1, 2), (1, 1), (2, 1), (2, 2)) + } + + test("dropDuplicates should not change child plan output") { + val ds = Seq(("a", 1), ("a", 2), ("b", 1), ("a", 1)).toDS() + checkDataset( + ds.dropDuplicates("_1").select(ds("_1").as[String], ds("_2").as[Int]), + ("a", 1), ("b", 1)) + } + test("SPARK-16097: Encoders.tuple should handle null object correctly") { val enc = Encoders.tuple(Encoders.tuple(Encoders.STRING, Encoders.STRING), Encoders.STRING) val data = Seq((("a", "b"), "c"), (null, "d")) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala index 0de7f2321f398..6944c6f848179 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala @@ -148,19 +148,19 @@ class MathExpressionsSuite extends QueryTest with SharedSQLContext { testOneToOneMathFunction(tanh, math.tanh) } - test("toDegrees") { - testOneToOneMathFunction(toDegrees, math.toDegrees) + test("degrees") { + testOneToOneMathFunction(degrees, math.toDegrees) checkAnswer( sql("SELECT degrees(0), degrees(1), degrees(1.5)"), - Seq((1, 2)).toDF().select(toDegrees(lit(0)), toDegrees(lit(1)), toDegrees(lit(1.5))) + Seq((1, 2)).toDF().select(degrees(lit(0)), degrees(lit(1)), degrees(lit(1.5))) ) } - test("toRadians") { - testOneToOneMathFunction(toRadians, math.toRadians) + test("radians") { + testOneToOneMathFunction(radians, math.toRadians) checkAnswer( sql("SELECT radians(0), radians(1), radians(1.5)"), - Seq((1, 2)).toDF().select(toRadians(lit(0)), toRadians(lit(1)), toRadians(lit(1.5))) + Seq((1, 2)).toDF().select(radians(lit(0)), radians(lit(1)), radians(lit(1.5))) ) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala index 34936b38fb5d4..7516be315dd2d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.catalyst.expressions.{GenericMutableRow, SpecificMutableRow} +import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, SpecificInternalRow} import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -27,7 +27,7 @@ class RowSuite extends SparkFunSuite with SharedSQLContext { import testImplicits._ test("create row") { - val expected = new GenericMutableRow(4) + val expected = new GenericInternalRow(4) expected.setInt(0, 2147483647) expected.update(1, UTF8String.fromString("this is a string")) expected.setBoolean(2, false) @@ -49,7 +49,7 @@ class RowSuite extends SparkFunSuite with SharedSQLContext { } test("SpecificMutableRow.update with null") { - val row = new SpecificMutableRow(Seq(IntegerType)) + val row = new SpecificInternalRow(Seq(IntegerType)) row(0) = null assert(row.isNullAt(0)) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLContextSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLContextSuite.scala index 001c1a1d85313..2b35db411e2ab 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLContextSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLContextSuite.scala @@ -88,11 +88,11 @@ 
class SQLContextSuite extends SparkFunSuite with SharedSparkContext { df.createOrReplaceTempView("listtablessuitetable") assert( sqlContext.tables().filter("tableName = 'listtablessuitetable'").collect().toSeq == - Row("listtablessuitetable", true) :: Nil) + Row("", "listtablessuitetable", true) :: Nil) assert( sqlContext.sql("SHOW tables").filter("tableName = 'listtablessuitetable'").collect().toSeq == - Row("listtablessuitetable", true) :: Nil) + Row("", "listtablessuitetable", true) :: Nil) sqlContext.sessionState.catalog.dropTable( TableIdentifier("listtablessuitetable"), ignoreIfNotExists = true, purge = false) @@ -105,11 +105,11 @@ class SQLContextSuite extends SparkFunSuite with SharedSparkContext { df.createOrReplaceTempView("listtablessuitetable") assert( sqlContext.tables("default").filter("tableName = 'listtablessuitetable'").collect().toSeq == - Row("listtablessuitetable", true) :: Nil) + Row("", "listtablessuitetable", true) :: Nil) assert( sqlContext.sql("show TABLES in default").filter("tableName = 'listtablessuitetable'") - .collect().toSeq == Row("listtablessuitetable", true) :: Nil) + .collect().toSeq == Row("", "listtablessuitetable", true) :: Nil) sqlContext.sessionState.catalog.dropTable( TableIdentifier("listtablessuitetable"), ignoreIfNotExists = true, purge = false) @@ -122,7 +122,8 @@ class SQLContextSuite extends SparkFunSuite with SharedSparkContext { df.createOrReplaceTempView("listtablessuitetable") val expectedSchema = StructType( - StructField("tableName", StringType, false) :: + StructField("database", StringType, false) :: + StructField("tableName", StringType, false) :: StructField("isTemporary", BooleanType, false) :: Nil) Seq(sqlContext.tables(), sqlContext.sql("SHOW TABLes")).foreach { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/TypedImperativeAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/TypedImperativeAggregateSuite.scala index b5eb16b6f650b..ffa26f1f8250f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/TypedImperativeAggregateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/TypedImperativeAggregateSuite.scala @@ -21,7 +21,7 @@ import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, Da import org.apache.spark.sql.TypedImperativeAggregateSuite.TypedMax import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.{BoundReference, Expression, GenericMutableRow, SpecificMutableRow} +import org.apache.spark.sql.catalyst.expressions.{BoundReference, Expression, GenericInternalRow, SpecificInternalRow} import org.apache.spark.sql.catalyst.expressions.aggregate.TypedImperativeAggregate import org.apache.spark.sql.execution.aggregate.SortAggregateExec import org.apache.spark.sql.expressions.Window @@ -64,7 +64,7 @@ class TypedImperativeAggregateSuite extends QueryTest with SharedSQLContext { assert(agg.eval(mergeBuffer) == data.map(_._1).max) // Tests low level eval(row: InternalRow) API. - val row = new GenericMutableRow(Array(mergeBuffer): Array[Any]) + val row = new GenericInternalRow(Array(mergeBuffer): Array[Any]) // Evaluates directly on row consist of aggregation buffer object. 
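// Sketch of the behavior change the updated assertions encode: table listings now expose a
// `database` column ahead of `tableName` and `isTemporary` (the row below is made up).
spark.sql("SHOW TABLES").show()
// +--------+------------+-----------+
// |database|   tableName|isTemporary|
// +--------+------------+-----------+
// |        |my_temp_view|       true|
// +--------+------------+-----------+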
assert(agg.eval(row) == data.map(_._1).max) @@ -73,7 +73,7 @@ class TypedImperativeAggregateSuite extends QueryTest with SharedSQLContext { test("supports SpecificMutableRow as mutable row") { val aggregationBufferSchema = Seq(IntegerType, LongType, BinaryType, IntegerType) val aggBufferOffset = 2 - val buffer = new SpecificMutableRow(aggregationBufferSchema) + val buffer = new SpecificInternalRow(aggregationBufferSchema) val agg = new TypedMax(BoundReference(ordinal = 1, dataType = IntegerType, nullable = false)) .withNewMutableAggBufferOffset(aggBufferOffset) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/GlobalTempViewSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/GlobalTempViewSuite.scala new file mode 100644 index 0000000000000..391bcb8b35d02 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/GlobalTempViewSuite.scala @@ -0,0 +1,168 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution + +import org.apache.spark.sql.{AnalysisException, QueryTest, Row} +import org.apache.spark.sql.catalog.Table +import org.apache.spark.sql.catalyst.TableIdentifier +import org.apache.spark.sql.catalyst.analysis.NoSuchTableException +import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.types.StructType + +class GlobalTempViewSuite extends QueryTest with SharedSQLContext { + import testImplicits._ + + override protected def beforeAll(): Unit = { + super.beforeAll() + globalTempDB = spark.sharedState.globalTempViewManager.database + } + + private var globalTempDB: String = _ + + test("basic semantic") { + sql("CREATE GLOBAL TEMP VIEW src AS SELECT 1, 'a'") + + // If there is no database in table name, we should try local temp view first, if not found, + // try table/view in current database, which is "default" in this case. So we expect + // NoSuchTableException here. + intercept[NoSuchTableException](spark.table("src")) + + // Use qualified name to refer to the global temp view explicitly. + checkAnswer(spark.table(s"$globalTempDB.src"), Row(1, "a")) + + // Table name without database will never refer to a global temp view. + intercept[NoSuchTableException](sql("DROP VIEW src")) + + sql(s"DROP VIEW $globalTempDB.src") + // The global temp view should be dropped successfully. + intercept[NoSuchTableException](spark.table(s"$globalTempDB.src")) + + // We can also use Dataset API to create global temp view + Seq(1 -> "a").toDF("i", "j").createGlobalTempView("src") + checkAnswer(spark.table(s"$globalTempDB.src"), Row(1, "a")) + + // Use qualified name to rename a global temp view. 
+ sql(s"ALTER VIEW $globalTempDB.src RENAME TO src2") + intercept[NoSuchTableException](spark.table(s"$globalTempDB.src")) + checkAnswer(spark.table(s"$globalTempDB.src2"), Row(1, "a")) + + // Use qualified name to alter a global temp view. + sql(s"ALTER VIEW $globalTempDB.src2 AS SELECT 2, 'b'") + checkAnswer(spark.table(s"$globalTempDB.src2"), Row(2, "b")) + + // We can also use Catalog API to drop global temp view + spark.catalog.dropGlobalTempView("src2") + intercept[NoSuchTableException](spark.table(s"$globalTempDB.src2")) + } + + test("global temp view is shared among all sessions") { + try { + sql("CREATE GLOBAL TEMP VIEW src AS SELECT 1, 2") + checkAnswer(spark.table(s"$globalTempDB.src"), Row(1, 2)) + val newSession = spark.newSession() + checkAnswer(newSession.table(s"$globalTempDB.src"), Row(1, 2)) + } finally { + spark.catalog.dropGlobalTempView("src") + } + } + + test("global temp view database should be preserved") { + val e = intercept[AnalysisException](sql(s"CREATE DATABASE $globalTempDB")) + assert(e.message.contains("system preserved database")) + + val e2 = intercept[AnalysisException](sql(s"USE $globalTempDB")) + assert(e2.message.contains("system preserved database")) + } + + test("CREATE GLOBAL TEMP VIEW USING") { + withTempPath { path => + try { + Seq(1 -> "a").toDF("i", "j").write.parquet(path.getAbsolutePath) + sql(s"CREATE GLOBAL TEMP VIEW src USING parquet OPTIONS (PATH '${path.getAbsolutePath}')") + checkAnswer(spark.table(s"$globalTempDB.src"), Row(1, "a")) + sql(s"INSERT INTO $globalTempDB.src SELECT 2, 'b'") + checkAnswer(spark.table(s"$globalTempDB.src"), Row(1, "a") :: Row(2, "b") :: Nil) + } finally { + spark.catalog.dropGlobalTempView("src") + } + } + } + + test("CREATE TABLE LIKE should work for global temp view") { + try { + sql("CREATE GLOBAL TEMP VIEW src AS SELECT 1 AS a, '2' AS b") + sql(s"CREATE TABLE cloned LIKE ${globalTempDB}.src") + val tableMeta = spark.sessionState.catalog.getTableMetadata(TableIdentifier("cloned")) + assert(tableMeta.schema == new StructType().add("a", "int", false).add("b", "string", false)) + } finally { + spark.catalog.dropGlobalTempView("src") + sql("DROP TABLE default.cloned") + } + } + + test("list global temp views") { + try { + sql("CREATE GLOBAL TEMP VIEW v1 AS SELECT 3, 4") + sql("CREATE TEMP VIEW v2 AS SELECT 1, 2") + + checkAnswer(sql(s"SHOW TABLES IN $globalTempDB"), + Row(globalTempDB, "v1", true) :: + Row("", "v2", true) :: Nil) + + assert(spark.catalog.listTables(globalTempDB).collect().toSeq.map(_.name) == Seq("v1", "v2")) + } finally { + spark.catalog.dropTempView("v1") + spark.catalog.dropGlobalTempView("v2") + } + } + + test("should lookup global temp view if and only if global temp db is specified") { + try { + sql("CREATE GLOBAL TEMP VIEW same_name AS SELECT 3, 4") + sql("CREATE TEMP VIEW same_name AS SELECT 1, 2") + + checkAnswer(sql("SELECT * FROM same_name"), Row(1, 2)) + + // we never lookup global temp views if database is not specified in table name + spark.catalog.dropTempView("same_name") + intercept[AnalysisException](sql("SELECT * FROM same_name")) + + // Use qualified name to lookup a global temp view. 
+ checkAnswer(sql(s"SELECT * FROM $globalTempDB.same_name"), Row(3, 4)) + } finally { + spark.catalog.dropTempView("same_name") + spark.catalog.dropGlobalTempView("same_name") + } + } + + test("public Catalog should recognize global temp view") { + try { + sql("CREATE GLOBAL TEMP VIEW src AS SELECT 1, 2") + + assert(spark.catalog.tableExists(globalTempDB, "src")) + assert(spark.catalog.getTable(globalTempDB, "src").toString == new Table( + name = "src", + database = globalTempDB, + description = null, + tableType = "TEMPORARY", + isTemporary = true).toString) + } finally { + spark.catalog.dropGlobalTempView("src") + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala index 6712d32924890..679150e9ae4c0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala @@ -17,13 +17,17 @@ package org.apache.spark.sql.execution +import org.apache.spark.sql.SaveMode import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier} +import org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogStorageFormat, CatalogTable, CatalogTableType} import org.apache.spark.sql.catalyst.parser.ParseException import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.execution.command.{DescribeFunctionCommand, DescribeTableCommand, ShowFunctionsCommand} -import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.execution.datasources.{CreateTable, CreateTempViewUsing} +import org.apache.spark.sql.internal.{HiveSerDe, SQLConf} +import org.apache.spark.sql.types.{IntegerType, LongType, StringType, StructType} /** * Parser test cases for rules defined in [[SparkSqlParser]]. @@ -35,8 +39,23 @@ class SparkSqlParserSuite extends PlanTest { private lazy val parser = new SparkSqlParser(new SQLConf) + /** + * Normalizes plans: + * - CreateTable the createTime in tableDesc will replaced by -1L. 
+ */ + private def normalizePlan(plan: LogicalPlan): LogicalPlan = { + plan match { + case CreateTable(tableDesc, mode, query) => + val newTableDesc = tableDesc.copy(createTime = -1L) + CreateTable(newTableDesc, mode, query) + case _ => plan // Don't transform + } + } + private def assertEqual(sqlCommand: String, plan: LogicalPlan): Unit = { - comparePlans(parser.parsePlan(sqlCommand), plan) + val normalized1 = normalizePlan(parser.parsePlan(sqlCommand)) + val normalized2 = normalizePlan(plan) + comparePlans(normalized1, normalized2) } private def intercept(sqlCommand: String, messages: String*): Unit = { @@ -68,9 +87,124 @@ class SparkSqlParserSuite extends PlanTest { DescribeFunctionCommand(FunctionIdentifier("bar", database = None), isExtended = true)) assertEqual("describe function foo.bar", DescribeFunctionCommand( - FunctionIdentifier("bar", database = Option("foo")), isExtended = false)) + FunctionIdentifier("bar", database = Some("foo")), isExtended = false)) assertEqual("describe function extended f.bar", - DescribeFunctionCommand(FunctionIdentifier("bar", database = Option("f")), isExtended = true)) + DescribeFunctionCommand(FunctionIdentifier("bar", database = Some("f")), isExtended = true)) + } + + private def createTableUsing( + table: String, + database: Option[String] = None, + tableType: CatalogTableType = CatalogTableType.MANAGED, + storage: CatalogStorageFormat = CatalogStorageFormat.empty, + schema: StructType = new StructType, + provider: Option[String] = Some("parquet"), + partitionColumnNames: Seq[String] = Seq.empty, + bucketSpec: Option[BucketSpec] = None, + mode: SaveMode = SaveMode.ErrorIfExists, + query: Option[LogicalPlan] = None): CreateTable = { + CreateTable( + CatalogTable( + identifier = TableIdentifier(table, database), + tableType = tableType, + storage = storage, + schema = schema, + provider = provider, + partitionColumnNames = partitionColumnNames, + bucketSpec = bucketSpec + ), mode, query + ) + } + + private def createTable( + table: String, + database: Option[String] = None, + tableType: CatalogTableType = CatalogTableType.MANAGED, + storage: CatalogStorageFormat = CatalogStorageFormat.empty.copy( + inputFormat = HiveSerDe.sourceToSerDe("textfile").get.inputFormat, + outputFormat = HiveSerDe.sourceToSerDe("textfile").get.outputFormat), + schema: StructType = new StructType, + provider: Option[String] = Some("hive"), + partitionColumnNames: Seq[String] = Seq.empty, + comment: Option[String] = None, + mode: SaveMode = SaveMode.ErrorIfExists, + query: Option[LogicalPlan] = None): CreateTable = { + CreateTable( + CatalogTable( + identifier = TableIdentifier(table, database), + tableType = tableType, + storage = storage, + schema = schema, + provider = provider, + partitionColumnNames = partitionColumnNames, + comment = comment + ), mode, query + ) + } + + test("create table - schema") { + assertEqual("CREATE TABLE my_tab(a INT COMMENT 'test', b STRING)", + createTable( + table = "my_tab", + schema = (new StructType) + .add("a", IntegerType, nullable = true, "test") + .add("b", StringType) + ) + ) + assertEqual("CREATE TABLE my_tab(a INT COMMENT 'test', b STRING) " + + "PARTITIONED BY (c INT, d STRING COMMENT 'test2')", + createTable( + table = "my_tab", + schema = (new StructType) + .add("a", IntegerType, nullable = true, "test") + .add("b", StringType) + .add("c", IntegerType) + .add("d", StringType, nullable = true, "test2"), + partitionColumnNames = Seq("c", "d") + ) + ) + assertEqual("CREATE TABLE my_tab(id BIGINT, nested STRUCT)", + createTable( + 
table = "my_tab", + schema = (new StructType) + .add("id", LongType) + .add("nested", (new StructType) + .add("col1", StringType) + .add("col2", IntegerType) + ) + ) + ) + // Partitioned by a StructType should be accepted by `SparkSqlParser` but will fail an analyze + // rule in `AnalyzeCreateTable`. + assertEqual("CREATE TABLE my_tab(a INT COMMENT 'test', b STRING) " + + "PARTITIONED BY (nested STRUCT)", + createTable( + table = "my_tab", + schema = (new StructType) + .add("a", IntegerType, nullable = true, "test") + .add("b", StringType) + .add("nested", (new StructType) + .add("col1", StringType) + .add("col2", IntegerType) + ), + partitionColumnNames = Seq("nested") + ) + ) + intercept("CREATE TABLE my_tab(a: INT COMMENT 'test', b: STRING)", + "no viable alternative at input") + } + + test("create table using - schema") { + assertEqual("CREATE TABLE my_tab(a INT COMMENT 'test', b STRING) USING parquet", + createTableUsing( + table = "my_tab", + schema = (new StructType) + .add("a", IntegerType, nullable = true, "test") + .add("b", StringType) + ) + ) + intercept("CREATE TABLE my_tab(a: INT COMMENT 'test', b: STRING) USING parquet", + "no viable alternative at input") } test("SPARK-17328 Fix NPE with EXPLAIN DESCRIBE TABLE") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/ColumnTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/ColumnTypeSuite.scala index 805b5667287ea..5f2a3aaff634c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/ColumnTypeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/ColumnTypeSuite.scala @@ -24,7 +24,7 @@ import org.apache.spark.SparkFunSuite import org.apache.spark.internal.Logging import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.CatalystTypeConverters -import org.apache.spark.sql.catalyst.expressions.{GenericMutableRow, UnsafeProjection} +import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeProjection} import org.apache.spark.sql.execution.columnar.ColumnarTestUtils._ import org.apache.spark.sql.types._ @@ -54,7 +54,7 @@ class ColumnTypeSuite extends SparkFunSuite with Logging { expected: Int): Unit = { assertResult(expected, s"Wrong actualSize for $columnType") { - val row = new GenericMutableRow(1) + val row = new GenericInternalRow(1) row.update(0, CatalystTypeConverters.convertToCatalyst(value)) val proj = UnsafeProjection.create(Array[DataType](columnType.dataType)) columnType.actualSize(proj(row), 0) @@ -101,14 +101,15 @@ class ColumnTypeSuite extends SparkFunSuite with Logging { def testColumnType[JvmType](columnType: ColumnType[JvmType]): Unit = { - val buffer = ByteBuffer.allocate(DEFAULT_BUFFER_SIZE).order(ByteOrder.nativeOrder()) val proj = UnsafeProjection.create(Array[DataType](columnType.dataType)) val converter = CatalystTypeConverters.createToScalaConverter(columnType.dataType) val seq = (0 until 4).map(_ => proj(makeRandomRow(columnType)).copy()) + val totalSize = seq.map(_.getSizeInBytes).sum + val bufferSize = Math.max(DEFAULT_BUFFER_SIZE, totalSize) test(s"$columnType append/extract") { - buffer.rewind() - seq.foreach(columnType.append(_, 0, buffer)) + val buffer = ByteBuffer.allocate(bufferSize).order(ByteOrder.nativeOrder()) + seq.foreach(r => columnType.append(columnType.getField(r, 0), buffer)) buffer.rewind() seq.foreach { row => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/ColumnarTestUtils.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/ColumnarTestUtils.scala index 1529313dfbd51..686c8fa6f5fa9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/ColumnarTestUtils.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/ColumnarTestUtils.scala @@ -21,14 +21,14 @@ import scala.collection.immutable.HashSet import scala.util.Random import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, GenericMutableRow} +import org.apache.spark.sql.catalyst.expressions.GenericInternalRow import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData} import org.apache.spark.sql.types.{AtomicType, Decimal} import org.apache.spark.unsafe.types.UTF8String object ColumnarTestUtils { - def makeNullRow(length: Int): GenericMutableRow = { - val row = new GenericMutableRow(length) + def makeNullRow(length: Int): GenericInternalRow = { + val row = new GenericInternalRow(length) (0 until length).foreach(row.setNullAt) row } @@ -86,7 +86,7 @@ object ColumnarTestUtils { tail: ColumnType[_]*): InternalRow = makeRandomRow(Seq(head) ++ tail) def makeRandomRow(columnTypes: Seq[ColumnType[_]]): InternalRow = { - val row = new GenericMutableRow(columnTypes.length) + val row = new GenericInternalRow(columnTypes.length) makeRandomValues(columnTypes).zipWithIndex.foreach { case (value, index) => row(index) = value } @@ -95,11 +95,11 @@ object ColumnarTestUtils { def makeUniqueValuesAndSingleValueRows[T <: AtomicType]( columnType: NativeColumnType[T], - count: Int): (Seq[T#InternalType], Seq[GenericMutableRow]) = { + count: Int): (Seq[T#InternalType], Seq[GenericInternalRow]) = { val values = makeUniqueRandomValues(columnType, count) val rows = values.map { value => - val row = new GenericMutableRow(1) + val row = new GenericInternalRow(1) row(0) = value row } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/NullableColumnAccessorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/NullableColumnAccessorSuite.scala index dc22d3e8e4d3a..8f4ca3cea77a5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/NullableColumnAccessorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/NullableColumnAccessorSuite.scala @@ -21,7 +21,7 @@ import java.nio.ByteBuffer import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.CatalystTypeConverters -import org.apache.spark.sql.catalyst.expressions.{GenericMutableRow, UnsafeProjection} +import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeProjection} import org.apache.spark.sql.types._ class TestNullableColumnAccessor[JvmType]( @@ -72,7 +72,7 @@ class NullableColumnAccessorSuite extends SparkFunSuite { } val accessor = TestNullableColumnAccessor(builder.build(), columnType) - val row = new GenericMutableRow(1) + val row = new GenericInternalRow(1) val converter = CatalystTypeConverters.createToScalaConverter(columnType.dataType) (0 until 4).foreach { _ => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/NullableColumnBuilderSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/NullableColumnBuilderSuite.scala index cdd4551d64b50..b2b6e92e9a056 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/NullableColumnBuilderSuite.scala +++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/NullableColumnBuilderSuite.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.execution.columnar import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.CatalystTypeConverters -import org.apache.spark.sql.catalyst.expressions.{GenericMutableRow, UnsafeProjection} +import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeProjection} import org.apache.spark.sql.types._ class TestNullableColumnBuilder[JvmType](columnType: ColumnType[JvmType]) @@ -94,7 +94,7 @@ class NullableColumnBuilderSuite extends SparkFunSuite { (1 to 7 by 2).foreach(assertResult(_, "Wrong null position")(buffer.getInt())) // For non-null values - val actual = new GenericMutableRow(new Array[Any](1)) + val actual = new GenericInternalRow(new Array[Any](1)) (0 until 4).foreach { _ => columnType.extract(buffer, actual, 0) assert(converter(actual.get(0, dataType)) === converter(randomRow.get(0, dataType)), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/BooleanBitSetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/BooleanBitSetSuite.scala index f67e9c7dae278..d01bf911e3a77 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/BooleanBitSetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/BooleanBitSetSuite.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.execution.columnar.compression import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.GenericMutableRow +import org.apache.spark.sql.catalyst.expressions.GenericInternalRow import org.apache.spark.sql.execution.columnar.{BOOLEAN, NoopColumnStats} import org.apache.spark.sql.execution.columnar.ColumnarTestUtils._ @@ -72,7 +72,7 @@ class BooleanBitSetSuite extends SparkFunSuite { buffer.rewind().position(headerSize + 4) val decoder = BooleanBitSet.decoder(buffer, BOOLEAN) - val mutableRow = new GenericMutableRow(1) + val mutableRow = new GenericInternalRow(1) if (values.nonEmpty) { values.foreach { assert(decoder.hasNext) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/CompressionSchemeBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/CompressionSchemeBenchmark.scala index babf944e6aa8e..9005ec93e786e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/CompressionSchemeBenchmark.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/CompressionSchemeBenchmark.scala @@ -23,7 +23,7 @@ import java.nio.charset.StandardCharsets import org.apache.commons.lang3.RandomStringUtils import org.apache.commons.math3.distribution.LogNormalDistribution -import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, GenericMutableRow} +import org.apache.spark.sql.catalyst.expressions.GenericInternalRow import org.apache.spark.sql.execution.columnar.{BOOLEAN, INT, LONG, NativeColumnType, SHORT, STRING} import org.apache.spark.sql.types.AtomicType import org.apache.spark.util.Benchmark @@ -111,7 +111,7 @@ object CompressionSchemeBenchmark extends AllCompressionSchemes { input.rewind() benchmark.addCase(label)({ i: Int => - val rowBuf = new GenericMutableRow(1) + val rowBuf = new GenericInternalRow(1) for (n <- 0L until iters) { compressedBuf.rewind.position(4) diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/DictionaryEncodingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/DictionaryEncodingSuite.scala index 830ca0294e1b8..67139b13d7882 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/DictionaryEncodingSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/DictionaryEncodingSuite.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.execution.columnar.compression import java.nio.ByteBuffer import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.catalyst.expressions.GenericMutableRow +import org.apache.spark.sql.catalyst.expressions.GenericInternalRow import org.apache.spark.sql.execution.columnar._ import org.apache.spark.sql.execution.columnar.ColumnarTestUtils._ import org.apache.spark.sql.types.AtomicType @@ -97,7 +97,7 @@ class DictionaryEncodingSuite extends SparkFunSuite { buffer.rewind().position(headerSize + 4) val decoder = DictionaryEncoding.decoder(buffer, columnType) - val mutableRow = new GenericMutableRow(1) + val mutableRow = new GenericInternalRow(1) if (inputSeq.nonEmpty) { inputSeq.foreach { i => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/IntegralDeltaSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/IntegralDeltaSuite.scala index a530e270746c5..411d31fa0e29b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/IntegralDeltaSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/IntegralDeltaSuite.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.execution.columnar.compression import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.catalyst.expressions.GenericMutableRow +import org.apache.spark.sql.catalyst.expressions.GenericInternalRow import org.apache.spark.sql.execution.columnar._ import org.apache.spark.sql.execution.columnar.ColumnarTestUtils._ import org.apache.spark.sql.types.IntegralType @@ -48,7 +48,7 @@ class IntegralDeltaSuite extends SparkFunSuite { } input.foreach { value => - val row = new GenericMutableRow(1) + val row = new GenericInternalRow(1) columnType.setField(row, 0, value) builder.appendFrom(row, 0) } @@ -95,7 +95,7 @@ class IntegralDeltaSuite extends SparkFunSuite { buffer.rewind().position(headerSize + 4) val decoder = scheme.decoder(buffer, columnType) - val mutableRow = new GenericMutableRow(1) + val mutableRow = new GenericInternalRow(1) if (input.nonEmpty) { input.foreach{ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/RunLengthEncodingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/RunLengthEncodingSuite.scala index 95642e93ae9f0..dffa9b364ebfe 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/RunLengthEncodingSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/RunLengthEncodingSuite.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.execution.columnar.compression import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.catalyst.expressions.GenericMutableRow +import org.apache.spark.sql.catalyst.expressions.GenericInternalRow import org.apache.spark.sql.execution.columnar._ import org.apache.spark.sql.execution.columnar.ColumnarTestUtils._ import org.apache.spark.sql.types.AtomicType @@ -80,7 
+80,7 @@ class RunLengthEncodingSuite extends SparkFunSuite { buffer.rewind().position(headerSize + 4) val decoder = RunLengthEncoding.decoder(buffer, columnType) - val mutableRow = new GenericMutableRow(1) + val mutableRow = new GenericInternalRow(1) if (inputSeq.nonEmpty) { inputSeq.foreach { i => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala index b5499f2884c61..097dc2441351f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala @@ -22,7 +22,6 @@ import java.io.File import org.apache.hadoop.fs.Path import org.scalatest.BeforeAndAfterEach -import org.apache.spark.internal.config._ import org.apache.spark.sql.{AnalysisException, QueryTest, Row, SaveMode} import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.analysis.{DatabaseAlreadyExistsException, FunctionRegistry, NoSuchPartitionException, NoSuchTableException, TempTableAlreadyExistsException} @@ -31,6 +30,7 @@ import org.apache.spark.sql.catalyst.catalog.{CatalogTable, CatalogTableType} import org.apache.spark.sql.catalyst.catalog.{CatalogTablePartition, SessionCatalog} import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.internal.StaticSQLConf.CATALOG_IMPLEMENTATION import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ import org.apache.spark.util.Utils @@ -642,7 +642,7 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { val csvFile = Thread.currentThread().getContextClassLoader.getResource("test-data/cars.csv").toString withView("testview") { - sql(s"CREATE OR REPLACE TEMPORARY VIEW testview (c1: String, c2: String) USING " + + sql(s"CREATE OR REPLACE TEMPORARY VIEW testview (c1 String, c2 String) USING " + "org.apache.spark.sql.execution.datasources.csv.CSVFileFormat " + s"OPTIONS (PATH '$csvFile')") @@ -969,17 +969,17 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { """.stripMargin) checkAnswer( sql("SHOW TABLES IN default 'show1*'"), - Row("show1a", true) :: Nil) + Row("", "show1a", true) :: Nil) checkAnswer( sql("SHOW TABLES IN default 'show1*|show2*'"), - Row("show1a", true) :: - Row("show2b", true) :: Nil) + Row("", "show1a", true) :: + Row("", "show2b", true) :: Nil) checkAnswer( sql("SHOW TABLES 'show1*|show2*'"), - Row("show1a", true) :: - Row("show2b", true) :: Nil) + Row("", "show1a", true) :: + Row("", "show2b", true) :: Nil) assert( sql("SHOW TABLES").count() >= 2) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala index 45411fa0656cd..c5deb31fec183 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala @@ -17,8 +17,9 @@ package org.apache.spark.sql.execution.datasources -import java.io.File +import java.io._ import java.util.concurrent.atomic.AtomicInteger +import java.util.zip.GZIPOutputStream import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{BlockLocation, FileStatus, Path, RawLocalFileSystem} @@ -441,6 +442,40 @@ 
class FileSourceStrategySuite extends QueryTest with SharedSQLContext with Predi } } + test("spark.files.ignoreCorruptFiles should work in SQL") { + val inputFile = File.createTempFile("input-", ".gz") + try { + // Create a corrupt gzip file + val byteOutput = new ByteArrayOutputStream() + val gzip = new GZIPOutputStream(byteOutput) + try { + gzip.write(Array[Byte](1, 2, 3, 4)) + } finally { + gzip.close() + } + val bytes = byteOutput.toByteArray + val o = new FileOutputStream(inputFile) + try { + // It's corrupt since we only write half of bytes into the file. + o.write(bytes.take(bytes.length / 2)) + } finally { + o.close() + } + withSQLConf(SQLConf.IGNORE_CORRUPT_FILES.key -> "false") { + val e = intercept[SparkException] { + spark.read.text(inputFile.toURI.toString).collect() + } + assert(e.getCause.isInstanceOf[EOFException]) + assert(e.getCause.getMessage === "Unexpected end of input stream") + } + withSQLConf(SQLConf.IGNORE_CORRUPT_FILES.key -> "true") { + assert(spark.read.text(inputFile.toURI.toString).collect().isEmpty) + } + } finally { + inputFile.delete() + } + } + // Helpers for checking the arguments passed to the FileFormat. protected val checkPartitionSchema = diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala index 29aac9def6924..f7c22c6c93f7a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala @@ -28,6 +28,7 @@ import org.apache.hadoop.io.compress.GzipCodec import org.apache.spark.SparkException import org.apache.spark.sql.{DataFrame, QueryTest, Row, UDT} +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.{SharedSQLContext, SQLTestUtils} import org.apache.spark.sql.types._ @@ -856,4 +857,36 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { checkAnswer(stringTimestampsWithFormat, expectedStringTimestampsWithFormat) } } + + test("load duplicated field names consistently with null or empty strings - case sensitive") { + withSQLConf(SQLConf.CASE_SENSITIVE.key -> "true") { + withTempPath { path => + Seq("a,a,c,A,b,B").toDF().write.text(path.getAbsolutePath) + val actualSchema = spark.read + .format("csv") + .option("header", true) + .load(path.getAbsolutePath) + .schema + val fields = Seq("a0", "a1", "c", "A", "b", "B").map(StructField(_, StringType, true)) + val expectedSchema = StructType(fields) + assert(actualSchema == expectedSchema) + } + } + } + + test("load duplicated field names consistently with null or empty strings - case insensitive") { + withSQLConf(SQLConf.CASE_SENSITIVE.key -> "false") { + withTempPath { path => + Seq("a,A,c,A,b,B").toDF().write.text(path.getAbsolutePath) + val actualSchema = spark.read + .format("csv") + .option("header", true) + .load(path.getAbsolutePath) + .schema + val fields = Seq("a0", "A1", "c", "A3", "b4", "B5").map(StructField(_, StringType, true)) + val expectedSchema = StructType(fields) + assert(actualSchema == expectedSchema) + } + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVTypeCastSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVTypeCastSuite.scala index dae92f626c225..51832a13cfe0b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVTypeCastSuite.scala +++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVTypeCastSuite.scala @@ -18,8 +18,6 @@ package org.apache.spark.sql.execution.datasources.csv import java.math.BigDecimal -import java.sql.{Date, Timestamp} -import java.text.SimpleDateFormat import java.util.Locale import org.apache.spark.SparkFunSuite diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala index 3161a630af0f1..580eade4b1412 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala @@ -38,7 +38,7 @@ import org.apache.parquet.schema.{MessageType, MessageTypeParser} import org.apache.spark.SparkException import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.{InternalRow, ScalaReflection} -import org.apache.spark.sql.catalyst.expressions.{GenericMutableRow, UnsafeRow} +import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeRow} import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext @@ -716,7 +716,7 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext { dataTypes.zip(constantValues).foreach { case (dt, v) => val schema = StructType(StructField("pcol", dt) :: Nil) val vectorizedReader = new VectorizedParquetRecordReader - val partitionValues = new GenericMutableRow(Array(v)) + val partitionValues = new GenericInternalRow(Array(v)) val file = SpecificParquetRecordReaderBase.listDirectory(dir).get(0) try { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala index 2621c12655d1c..1a8d92868febf 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala @@ -24,7 +24,7 @@ import org.apache.parquet.hadoop.ParquetOutputFormat import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.{InternalRow, TableIdentifier} -import org.apache.spark.sql.catalyst.expressions.SpecificMutableRow +import org.apache.spark.sql.catalyst.expressions.SpecificInternalRow import org.apache.spark.sql.execution.FileSourceScanExec import org.apache.spark.sql.execution.datasources.parquet.TestingUDT.{NestedStruct, NestedStructUDT, SingleElement} import org.apache.spark.sql.internal.SQLConf @@ -729,7 +729,7 @@ object TestingUDT { .add("c", DoubleType, nullable = false) override def serialize(n: NestedStruct): Any = { - val row = new SpecificMutableRow(sqlType.asInstanceOf[StructType].map(_.dataType)) + val row = new SpecificInternalRow(sqlType.asInstanceOf[StructType].map(_.dataType)) row.setInt(0, n.a) row.setLong(1, n.b) row.setDouble(2, n.c) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala index 97adffa8ce101..83db81ea3f1c2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala @@ -21,11 
+21,13 @@ import scala.reflect.ClassTag import org.apache.spark.AccumulatorSuite import org.apache.spark.sql.{Dataset, QueryTest, Row, SparkSession} +import org.apache.spark.sql.catalyst.expressions.{BitwiseAnd, BitwiseOr, Cast, Literal, ShiftLeft} import org.apache.spark.sql.execution.exchange.EnsureRequirements import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SQLTestUtils +import org.apache.spark.sql.types.{LongType, ShortType} /** * Test various broadcast join operators. @@ -153,4 +155,49 @@ class BroadcastJoinSuite extends QueryTest with SQLTestUtils { cases.foreach(assertBroadcastJoin) } } + + test("join key rewritten") { + val l = Literal(1L) + val i = Literal(2) + val s = Literal.create(3, ShortType) + val ss = Literal("hello") + + assert(HashJoin.rewriteKeyExpr(l :: Nil) === l :: Nil) + assert(HashJoin.rewriteKeyExpr(l :: l :: Nil) === l :: l :: Nil) + assert(HashJoin.rewriteKeyExpr(l :: i :: Nil) === l :: i :: Nil) + + assert(HashJoin.rewriteKeyExpr(i :: Nil) === Cast(i, LongType) :: Nil) + assert(HashJoin.rewriteKeyExpr(i :: l :: Nil) === i :: l :: Nil) + assert(HashJoin.rewriteKeyExpr(i :: i :: Nil) === + BitwiseOr(ShiftLeft(Cast(i, LongType), Literal(32)), + BitwiseAnd(Cast(i, LongType), Literal((1L << 32) - 1))) :: Nil) + assert(HashJoin.rewriteKeyExpr(i :: i :: i :: Nil) === i :: i :: i :: Nil) + + assert(HashJoin.rewriteKeyExpr(s :: Nil) === Cast(s, LongType) :: Nil) + assert(HashJoin.rewriteKeyExpr(s :: l :: Nil) === s :: l :: Nil) + assert(HashJoin.rewriteKeyExpr(s :: s :: Nil) === + BitwiseOr(ShiftLeft(Cast(s, LongType), Literal(16)), + BitwiseAnd(Cast(s, LongType), Literal((1L << 16) - 1))) :: Nil) + assert(HashJoin.rewriteKeyExpr(s :: s :: s :: Nil) === + BitwiseOr(ShiftLeft( + BitwiseOr(ShiftLeft(Cast(s, LongType), Literal(16)), + BitwiseAnd(Cast(s, LongType), Literal((1L << 16) - 1))), + Literal(16)), + BitwiseAnd(Cast(s, LongType), Literal((1L << 16) - 1))) :: Nil) + assert(HashJoin.rewriteKeyExpr(s :: s :: s :: s :: Nil) === + BitwiseOr(ShiftLeft( + BitwiseOr(ShiftLeft( + BitwiseOr(ShiftLeft(Cast(s, LongType), Literal(16)), + BitwiseAnd(Cast(s, LongType), Literal((1L << 16) - 1))), + Literal(16)), + BitwiseAnd(Cast(s, LongType), Literal((1L << 16) - 1))), + Literal(16)), + BitwiseAnd(Cast(s, LongType), Literal((1L << 16) - 1))) :: Nil) + assert(HashJoin.rewriteKeyExpr(s :: s :: s :: s :: s :: Nil) === + s :: s :: s :: s :: s :: Nil) + + assert(HashJoin.rewriteKeyExpr(ss :: Nil) === ss :: Nil) + assert(HashJoin.rewriteKeyExpr(l :: ss :: Nil) === l :: ss :: Nil) + assert(HashJoin.rewriteKeyExpr(i :: ss :: Nil) === i :: ss :: Nil) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/RowQueueSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/RowQueueSuite.scala new file mode 100644 index 0000000000000..ffda33cf906c5 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/RowQueueSuite.scala @@ -0,0 +1,127 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.python + +import java.io.File + +import org.apache.spark.{SparkConf, SparkFunSuite} +import org.apache.spark.memory.{MemoryManager, TaskMemoryManager, TestMemoryManager} +import org.apache.spark.sql.catalyst.expressions.UnsafeRow +import org.apache.spark.unsafe.memory.MemoryBlock +import org.apache.spark.util.Utils + +class RowQueueSuite extends SparkFunSuite { + + test("in-memory queue") { + val page = MemoryBlock.fromLongArray(new Array[Long](1<<10)) + val queue = new InMemoryRowQueue(page, 1) { + override def close() {} + } + val row = new UnsafeRow(1) + row.pointTo(new Array[Byte](16), 16) + val n = page.size() / (4 + row.getSizeInBytes) + var i = 0 + while (i < n) { + row.setLong(0, i) + assert(queue.add(row), "fail to add") + i += 1 + } + assert(!queue.add(row), "should not add more") + i = 0 + while (i < n) { + val row = queue.remove() + assert(row != null, "fail to poll") + assert(row.getLong(0) == i, "does not match") + i += 1 + } + assert(queue.remove() == null, "should be empty") + queue.close() + } + + test("disk queue") { + val dir = Utils.createTempDir().getCanonicalFile + dir.mkdirs() + val queue = DiskRowQueue(new File(dir, "buffer"), 1) + val row = new UnsafeRow(1) + row.pointTo(new Array[Byte](16), 16) + val n = 1000 + var i = 0 + while (i < n) { + row.setLong(0, i) + assert(queue.add(row), "fail to add") + i += 1 + } + val first = queue.remove() + assert(first != null, "first should not be null") + assert(first.getLong(0) == 0, "first should be 0") + assert(!queue.add(row), "should not add more") + i = 1 + while (i < n) { + val row = queue.remove() + assert(row != null, "fail to poll") + assert(row.getLong(0) == i, "does not match") + i += 1 + } + assert(queue.remove() == null, "should be empty") + queue.close() + } + + test("hybrid queue") { + val mem = new TestMemoryManager(new SparkConf()) + mem.limit(4<<10) + val taskM = new TaskMemoryManager(mem, 0) + val queue = HybridRowQueue(taskM, Utils.createTempDir().getCanonicalFile, 1) + val row = new UnsafeRow(1) + row.pointTo(new Array[Byte](16), 16) + val n = (4<<10) / 16 * 3 + var i = 0 + while (i < n) { + row.setLong(0, i) + assert(queue.add(row), "fail to add") + i += 1 + } + assert(queue.numQueues() > 1, "should have more than one queue") + queue.spill(1<<20, null) + i = 0 + while (i < n) { + val row = queue.remove() + assert(row != null, "fail to poll") + assert(row.getLong(0) == i, "does not match") + i += 1 + } + + // fill again and spill + i = 0 + while (i < n) { + row.setLong(0, i) + assert(queue.add(row), "fail to add") + i += 1 + } + assert(queue.numQueues() > 1, "should have more than one queue") + queue.spill(1<<20, null) + assert(queue.numQueues() > 1, "should have more than one queue") + i = 0 + while (i < n) { + val row = queue.remove() + assert(row != null, "fail to poll") + assert(row.getLong(0) == i, "does not match") + i += 1 + } + queue.close() + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/FileStreamSinkLogSuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/FileStreamSinkLogSuite.scala index 41a8cc2400dff..e1bc674a28071 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/FileStreamSinkLogSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/FileStreamSinkLogSuite.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.execution.streaming +import java.io.{ByteArrayInputStream, ByteArrayOutputStream} import java.nio.charset.StandardCharsets.UTF_8 import org.apache.spark.SparkFunSuite @@ -133,9 +134,12 @@ class FileStreamSinkLogSuite extends SparkFunSuite with SharedSQLContext { |{"path":"/a/b/y","size":200,"isDir":false,"modificationTime":2000,"blockReplication":2,"blockSize":20000,"action":"delete"} |{"path":"/a/b/z","size":300,"isDir":false,"modificationTime":3000,"blockReplication":3,"blockSize":30000,"action":"add"}""".stripMargin // scalastyle:on - assert(expected === new String(sinkLog.serialize(logs), UTF_8)) - - assert(VERSION === new String(sinkLog.serialize(Array()), UTF_8)) + val baos = new ByteArrayOutputStream() + sinkLog.serialize(logs, baos) + assert(expected === baos.toString(UTF_8.name())) + baos.reset() + sinkLog.serialize(Array(), baos) + assert(VERSION === baos.toString(UTF_8.name())) } } @@ -174,9 +178,9 @@ class FileStreamSinkLogSuite extends SparkFunSuite with SharedSQLContext { blockSize = 30000L, action = FileStreamSinkLog.ADD_ACTION)) - assert(expected === sinkLog.deserialize(logs.getBytes(UTF_8))) + assert(expected === sinkLog.deserialize(new ByteArrayInputStream(logs.getBytes(UTF_8)))) - assert(Nil === sinkLog.deserialize(VERSION.getBytes(UTF_8))) + assert(Nil === sinkLog.deserialize(new ByteArrayInputStream(VERSION.getBytes(UTF_8)))) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala index 6e60b0e4fad15..19b6d2603129c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala @@ -23,6 +23,7 @@ import org.mockito.Mockito.mock import org.apache.spark._ import org.apache.spark.executor.TaskMetrics +import org.apache.spark.internal.config import org.apache.spark.rdd.RDD import org.apache.spark.scheduler._ import org.apache.spark.sql.{DataFrame, SparkSession} @@ -446,7 +447,7 @@ class SQLListenerMemoryLeakSuite extends SparkFunSuite { val conf = new SparkConf() .setMaster("local") .setAppName("test") - .set("spark.task.maxFailures", "1") // Don't retry the tasks to run this test quickly + .set(config.MAX_TASK_FAILURES, 1) // Don't retry the tasks to run this test quickly .set("spark.sql.ui.retainedExecutions", "50") // Set it to 50 to run this test quickly val sc = new SparkContext(conf) try { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/internal/SQLConfSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/internal/SQLConfSuite.scala index 3c60b233c2b04..df640ffab91de 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/internal/SQLConfSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/internal/SQLConfSuite.scala @@ -19,11 +19,15 @@ package org.apache.spark.sql.internal import org.apache.hadoop.fs.Path -import org.apache.spark.sql.{QueryTest, Row, SparkSession, SQLContext} +import org.apache.spark.SparkContext +import org.apache.spark.sql._ import org.apache.spark.sql.execution.WholeStageCodegenExec +import 
org.apache.spark.sql.internal.StaticSQLConf._ import org.apache.spark.sql.test.{SharedSQLContext, TestSQLContext} class SQLConfSuite extends QueryTest with SharedSQLContext { + import testImplicits._ + private val testKey = "test.key.0" private val testVal = "test.val.0" @@ -250,4 +254,25 @@ class SQLConfSuite extends QueryTest with SharedSQLContext { } } } + + test("static SQL conf comes from SparkConf") { + val previousValue = sparkContext.conf.get(SCHEMA_STRING_LENGTH_THRESHOLD) + try { + sparkContext.conf.set(SCHEMA_STRING_LENGTH_THRESHOLD, 2000) + val newSession = new SparkSession(sparkContext) + assert(newSession.conf.get(SCHEMA_STRING_LENGTH_THRESHOLD) == 2000) + checkAnswer( + newSession.sql(s"SET ${SCHEMA_STRING_LENGTH_THRESHOLD.key}"), + Row(SCHEMA_STRING_LENGTH_THRESHOLD.key, "2000")) + } finally { + sparkContext.conf.set(SCHEMA_STRING_LENGTH_THRESHOLD, previousValue) + } + } + + test("cannot set/unset static SQL conf") { + val e1 = intercept[AnalysisException](sql(s"SET ${SCHEMA_STRING_LENGTH_THRESHOLD.key}=10")) + assert(e1.message.contains("Cannot modify the value of a static config")) + val e2 = intercept[AnalysisException](spark.conf.unset(SCHEMA_STRING_LENGTH_THRESHOLD.key)) + assert(e2.message.contains("Cannot modify the value of a static config")) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala index 10f15ca280689..71cf5e6a22916 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala @@ -29,8 +29,7 @@ import org.apache.spark.sql.{DataFrame, Row} import org.apache.spark.sql.execution.DataSourceScanExec import org.apache.spark.sql.execution.command.ExplainCommand import org.apache.spark.sql.execution.datasources.LogicalRelation -import org.apache.spark.sql.execution.datasources.jdbc.JDBCRDD -import org.apache.spark.sql.execution.datasources.jdbc.JdbcUtils +import org.apache.spark.sql.execution.datasources.jdbc.{JDBCOptions, JDBCRDD, JdbcUtils} import org.apache.spark.sql.sources._ import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ @@ -84,7 +83,7 @@ class JDBCSuite extends SparkFunSuite |CREATE TEMPORARY TABLE fetchtwo |USING org.apache.spark.sql.jdbc |OPTIONS (url '$url', dbtable 'TEST.PEOPLE', user 'testUser', password 'testPass', - | ${JdbcUtils.JDBC_BATCH_FETCH_SIZE} '2') + | ${JDBCOptions.JDBC_BATCH_FETCH_SIZE} '2') """.stripMargin.replaceAll("\n", " ")) sql( @@ -354,8 +353,8 @@ class JDBCSuite extends SparkFunSuite test("Basic API with illegal fetchsize") { val properties = new Properties() - properties.setProperty(JdbcUtils.JDBC_BATCH_FETCH_SIZE, "-1") - val e = intercept[SparkException] { + properties.setProperty(JDBCOptions.JDBC_BATCH_FETCH_SIZE, "-1") + val e = intercept[IllegalArgumentException] { spark.read.jdbc(urlWithUserAndPass, "TEST.PEOPLE", properties).collect() }.getMessage assert(e.contains("Invalid value `-1` for parameter `fetchsize`")) @@ -364,7 +363,7 @@ class JDBCSuite extends SparkFunSuite test("Basic API with FetchSize") { (0 to 4).foreach { size => val properties = new Properties() - properties.setProperty(JdbcUtils.JDBC_BATCH_FETCH_SIZE, size.toString) + properties.setProperty(JDBCOptions.JDBC_BATCH_FETCH_SIZE, size.toString) assert(spark.read.jdbc( urlWithUserAndPass, "TEST.PEOPLE", properties).collect().length === 3) } @@ -788,7 +787,7 @@ class JDBCSuite extends SparkFunSuite test("SPARK-16387: Reserved SQL words 
are not escaped by JDBC writer") { val df = spark.createDataset(Seq("a", "b", "c")).toDF("order") - val schema = JdbcUtils.schemaString(df, "jdbc:mysql://localhost:3306/temp") + val schema = JdbcUtils.schemaString(df.schema, "jdbc:mysql://localhost:3306/temp") assert(schema.contains("`order` TEXT")) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala index 506971362f867..96540ec92da73 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala @@ -26,7 +26,7 @@ import org.scalatest.BeforeAndAfter import org.apache.spark.SparkException import org.apache.spark.sql.{Row, SaveMode} -import org.apache.spark.sql.execution.datasources.jdbc.JdbcUtils +import org.apache.spark.sql.execution.datasources.jdbc.JDBCOptions import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ import org.apache.spark.util.Utils @@ -113,8 +113,8 @@ class JDBCWriteSuite extends SharedSQLContext with BeforeAndAfter { (-1 to 0).foreach { size => val properties = new Properties() - properties.setProperty(JdbcUtils.JDBC_BATCH_INSERT_SIZE, size.toString) - val e = intercept[SparkException] { + properties.setProperty(JDBCOptions.JDBC_BATCH_INSERT_SIZE, size.toString) + val e = intercept[IllegalArgumentException] { df.write.mode(SaveMode.Overwrite).jdbc(url, "TEST.BASICCREATETEST", properties) }.getMessage assert(e.contains(s"Invalid value `$size` for parameter `batchsize`")) @@ -126,12 +126,25 @@ class JDBCWriteSuite extends SharedSQLContext with BeforeAndAfter { (1 to 3).foreach { size => val properties = new Properties() - properties.setProperty(JdbcUtils.JDBC_BATCH_INSERT_SIZE, size.toString) + properties.setProperty(JDBCOptions.JDBC_BATCH_INSERT_SIZE, size.toString) df.write.mode(SaveMode.Overwrite).jdbc(url, "TEST.BASICCREATETEST", properties) assert(2 === spark.read.jdbc(url, "TEST.BASICCREATETEST", new Properties()).count()) } } + test("CREATE with ignore") { + val df = spark.createDataFrame(sparkContext.parallelize(arr2x3), schema3) + val df2 = spark.createDataFrame(sparkContext.parallelize(arr1x2), schema2) + + df.write.mode(SaveMode.Ignore).jdbc(url1, "TEST.DROPTEST", properties) + assert(2 === spark.read.jdbc(url1, "TEST.DROPTEST", properties).count()) + assert(3 === spark.read.jdbc(url1, "TEST.DROPTEST", properties).collect()(0).length) + + df2.write.mode(SaveMode.Ignore).jdbc(url1, "TEST.DROPTEST", properties) + assert(2 === spark.read.jdbc(url1, "TEST.DROPTEST", properties).count()) + assert(3 === spark.read.jdbc(url1, "TEST.DROPTEST", properties).collect()(0).length) + } + test("CREATE with overwrite") { val df = spark.createDataFrame(sparkContext.parallelize(arr2x3), schema3) val df2 = spark.createDataFrame(sparkContext.parallelize(arr1x2), schema2) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala index 3157afe5a56c0..7f9c981a4e9c9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala @@ -342,6 +342,24 @@ class FileStreamSourceSuite extends FileStreamSourceTest { } } + test("read from textfile") { + withTempDirs { case (src, tmp) => + val textStream = spark.readStream.textFile(src.getCanonicalPath) + val filtered = 
textStream.filter(_.contains("keep")) + + testStream(filtered)( + AddTextFileData("drop1\nkeep2\nkeep3", src, tmp), + CheckAnswer("keep2", "keep3"), + StopStream, + AddTextFileData("drop4\nkeep5\nkeep6", src, tmp), + StartStream(), + CheckAnswer("keep2", "keep3", "keep5", "keep6"), + AddTextFileData("drop7\nkeep8\nkeep9", src, tmp), + CheckAnswer("keep2", "keep3", "keep5", "keep6", "keep8", "keep9") + ) + } + } + test("SPARK-17165 should not track the list of seen files indefinitely") { // This test works by: // 1. Create a file diff --git a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/thrift/ThriftHttpCLIService.java b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/thrift/ThriftHttpCLIService.java index 37e4845cceb9e..341a7fdbb59b8 100644 --- a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/thrift/ThriftHttpCLIService.java +++ b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/thrift/ThriftHttpCLIService.java @@ -37,11 +37,15 @@ import org.apache.thrift.protocol.TBinaryProtocol; import org.apache.thrift.protocol.TProtocolFactory; import org.apache.thrift.server.TServlet; +import org.eclipse.jetty.server.AbstractConnectionFactory; +import org.eclipse.jetty.server.ConnectionFactory; +import org.eclipse.jetty.server.HttpConnectionFactory; import org.eclipse.jetty.server.ServerConnector; import org.eclipse.jetty.servlet.ServletContextHandler; import org.eclipse.jetty.servlet.ServletHolder; import org.eclipse.jetty.util.ssl.SslContextFactory; import org.eclipse.jetty.util.thread.ExecutorThreadPool; +import org.eclipse.jetty.util.thread.ScheduledExecutorScheduler; public class ThriftHttpCLIService extends ThriftCLIService { @@ -70,7 +74,8 @@ public void run() { httpServer = new org.eclipse.jetty.server.Server(threadPool); // Connector configs - ServerConnector connector = new ServerConnector(httpServer); + + ConnectionFactory[] connectionFactories; boolean useSsl = hiveConf.getBoolVar(ConfVars.HIVE_SERVER2_USE_SSL); String schemeName = useSsl ? 
"https" : "http"; // Change connector if SSL is used @@ -90,8 +95,21 @@ public void run() { Arrays.toString(sslContextFactory.getExcludeProtocols())); sslContextFactory.setKeyStorePath(keyStorePath); sslContextFactory.setKeyStorePassword(keyStorePassword); - connector = new ServerConnector(httpServer, sslContextFactory); + connectionFactories = AbstractConnectionFactory.getFactories( + sslContextFactory, new HttpConnectionFactory()); + } else { + connectionFactories = new ConnectionFactory[] { new HttpConnectionFactory() }; } + ServerConnector connector = new ServerConnector( + httpServer, + null, + // Call this full constructor to set this, which forces daemon threads: + new ScheduledExecutorScheduler("HiveServer2-HttpHandler-JettyScheduler", true), + null, + -1, + -1, + connectionFactories); + connector.setPort(portNum); // Linux:yes, Windows:no connector.setReuseAddress(!Shell.WINDOWS); diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala index 261cc6feff090..237b829da882f 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala @@ -37,6 +37,7 @@ import org.apache.spark.sql.execution.command.{ColumnStatStruct, DDLUtils} import org.apache.spark.sql.execution.datasources.CaseInsensitiveMap import org.apache.spark.sql.hive.client.HiveClient import org.apache.spark.sql.internal.HiveSerDe +import org.apache.spark.sql.internal.StaticSQLConf._ import org.apache.spark.sql.types.{DataType, StructType} @@ -111,6 +112,11 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat s"as table property keys may not start with '$DATASOURCE_PREFIX' or '$STATISTICS_PREFIX':" + s" ${invalidKeys.mkString("[", ", ", "]")}") } + // External users are not allowed to set/switch the table type. In Hive metastore, the table + // type can be switched by changing the value of a case-sensitive table property `EXTERNAL`. + if (table.properties.contains("EXTERNAL")) { + throw new AnalysisException("Cannot set or change the preserved property key: 'EXTERNAL'") + } } // -------------------------------------------------------------------------- @@ -201,11 +207,7 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat // Serialized JSON schema string may be too long to be stored into a single metastore table // property. In this case, we split the JSON string and store each part as a separate table // property. - // TODO: the threshold should be set by `spark.sql.sources.schemaStringLengthThreshold`, - // however the current SQLConf is session isolated, which is not applicable to external - // catalog. We should re-enable this conf instead of hard code the value here, after we have - // global SQLConf. - val threshold = 4000 + val threshold = conf.get(SCHEMA_STRING_LENGTH_THRESHOLD) val schemaJsonString = tableDefinition.schema.json // Split the JSON string. 
val parts = schemaJsonString.grouped(threshold).toSeq @@ -464,13 +466,18 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat } else { table.storage } + val tableProps = if (conf.get(DEBUG_MODE)) { + table.properties + } else { + getOriginalTableProperties(table) + } table.copy( storage = storage, schema = getSchemaFromTableProperties(table), provider = Some(provider), partitionColumnNames = getPartitionColumnsFromTableProperties(table), bucketSpec = getBucketSpecFromTableProperties(table), - properties = getOriginalTableProperties(table)) + properties = tableProps) } getOrElse { table.copy(provider = Some("hive")) } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala index fe34caa0a3e48..1625116803505 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala @@ -688,25 +688,25 @@ private[hive] trait HiveInspectors { * @return A function that performs in-place updating of a MutableRow. * Use the overloaded ObjectInspector version for assignments. */ - def unwrapperFor(field: HiveStructField): (Any, MutableRow, Int) => Unit = + def unwrapperFor(field: HiveStructField): (Any, InternalRow, Int) => Unit = field.getFieldObjectInspector match { case oi: BooleanObjectInspector => - (value: Any, row: MutableRow, ordinal: Int) => row.setBoolean(ordinal, oi.get(value)) + (value: Any, row: InternalRow, ordinal: Int) => row.setBoolean(ordinal, oi.get(value)) case oi: ByteObjectInspector => - (value: Any, row: MutableRow, ordinal: Int) => row.setByte(ordinal, oi.get(value)) + (value: Any, row: InternalRow, ordinal: Int) => row.setByte(ordinal, oi.get(value)) case oi: ShortObjectInspector => - (value: Any, row: MutableRow, ordinal: Int) => row.setShort(ordinal, oi.get(value)) + (value: Any, row: InternalRow, ordinal: Int) => row.setShort(ordinal, oi.get(value)) case oi: IntObjectInspector => - (value: Any, row: MutableRow, ordinal: Int) => row.setInt(ordinal, oi.get(value)) + (value: Any, row: InternalRow, ordinal: Int) => row.setInt(ordinal, oi.get(value)) case oi: LongObjectInspector => - (value: Any, row: MutableRow, ordinal: Int) => row.setLong(ordinal, oi.get(value)) + (value: Any, row: InternalRow, ordinal: Int) => row.setLong(ordinal, oi.get(value)) case oi: FloatObjectInspector => - (value: Any, row: MutableRow, ordinal: Int) => row.setFloat(ordinal, oi.get(value)) + (value: Any, row: InternalRow, ordinal: Int) => row.setFloat(ordinal, oi.get(value)) case oi: DoubleObjectInspector => - (value: Any, row: MutableRow, ordinal: Int) => row.setDouble(ordinal, oi.get(value)) + (value: Any, row: InternalRow, ordinal: Int) => row.setDouble(ordinal, oi.get(value)) case oi => val unwrapper = unwrapperFor(oi) - (value: Any, row: MutableRow, ordinal: Int) => row(ordinal) = unwrapper(value) + (value: Any, row: InternalRow, ordinal: Int) => row(ordinal) = unwrapper(value) } def wrap(a: Any, oi: ObjectInspector, dataType: DataType): AnyRef = { diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala index 85c509847d8ef..85ecf0ce70756 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala @@ -29,7 +29,7 @@ import org.apache.spark.sql.{AnalysisException, 
SparkSession} import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier} import org.apache.spark.sql.catalyst.analysis.FunctionRegistry import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder -import org.apache.spark.sql.catalyst.catalog.{FunctionResourceLoader, SessionCatalog} +import org.apache.spark.sql.catalyst.catalog.{FunctionResourceLoader, GlobalTempViewManager, SessionCatalog} import org.apache.spark.sql.catalyst.expressions.{Cast, Expression, ExpressionInfo} import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, SubqueryAlias} import org.apache.spark.sql.catalyst.rules.Rule @@ -41,6 +41,7 @@ import org.apache.spark.util.Utils private[sql] class HiveSessionCatalog( externalCatalog: HiveExternalCatalog, + globalTempViewManager: GlobalTempViewManager, sparkSession: SparkSession, functionResourceLoader: FunctionResourceLoader, functionRegistry: FunctionRegistry, @@ -48,6 +49,7 @@ private[sql] class HiveSessionCatalog( hadoopConf: Configuration) extends SessionCatalog( externalCatalog, + globalTempViewManager, functionResourceLoader, functionRegistry, conf, diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionState.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionState.scala index eb10c11382e83..6d4fe1a941a98 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionState.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionState.scala @@ -45,6 +45,7 @@ private[hive] class HiveSessionState(sparkSession: SparkSession) override lazy val catalog = { new HiveSessionCatalog( sparkSession.sharedState.externalCatalog.asInstanceOf[HiveExternalCatalog], + sparkSession.sharedState.globalTempViewManager, sparkSession, functionResourceLoader, functionRegistry, diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveUtils.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveUtils.scala index 39d71e164bf51..a5ef8723c8b6f 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveUtils.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveUtils.scala @@ -23,7 +23,6 @@ import java.nio.charset.StandardCharsets import java.sql.Timestamp import java.util.concurrent.TimeUnit -import scala.collection.JavaConverters._ import scala.collection.mutable.HashMap import scala.language.implicitConversions @@ -36,11 +35,11 @@ import org.apache.hadoop.util.VersionInfo import org.apache.spark.{SparkConf, SparkContext} import org.apache.spark.internal.Logging -import org.apache.spark.internal.config.CATALOG_IMPLEMENTATION import org.apache.spark.sql._ import org.apache.spark.sql.hive.client._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.internal.SQLConf._ +import org.apache.spark.sql.internal.StaticSQLConf.CATALOG_IMPLEMENTATION import org.apache.spark.sql.types._ import org.apache.spark.util.Utils diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala index ec7e53efc87f9..2a54163a04e9b 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala @@ -120,7 +120,7 @@ class HadoopTableReader( val hadoopRDD = createHadoopRdd(tableDesc, inputPathStr, ifc) val attrsWithIndex = attributes.zipWithIndex - val mutableRow = new SpecificMutableRow(attributes.map(_.dataType)) + val mutableRow = new SpecificInternalRow(attributes.map(_.dataType)) val 
deserializedHadoopRDD = hadoopRDD.mapPartitions { iter => val hconf = broadcastedHadoopConf.value.value @@ -215,7 +215,7 @@ class HadoopTableReader( val tableDesc = relation.tableDesc val broadcastedHiveConf = _broadcastedHadoopConf val localDeserializer = partDeserializer - val mutableRow = new SpecificMutableRow(attributes.map(_.dataType)) + val mutableRow = new SpecificInternalRow(attributes.map(_.dataType)) // Splits all attributes into two groups, partition key attributes and those that are not. // Attached indices indicate the position of each attribute in the output schema. @@ -224,7 +224,7 @@ class HadoopTableReader( relation.partitionKeys.contains(attr) } - def fillPartitionKeys(rawPartValues: Array[String], row: MutableRow): Unit = { + def fillPartitionKeys(rawPartValues: Array[String], row: InternalRow): Unit = { partitionKeyAttrs.foreach { case (attr, ordinal) => val partOrdinal = relation.partitionKeys.indexOf(attr) row(ordinal) = Cast(Literal(rawPartValues(partOrdinal)), attr.dataType).eval(null) @@ -360,7 +360,7 @@ private[hive] object HadoopTableReader extends HiveInspectors with Logging { iterator: Iterator[Writable], rawDeser: Deserializer, nonPartitionKeyAttrs: Seq[(Attribute, Int)], - mutableRow: MutableRow, + mutableRow: InternalRow, tableDeser: Deserializer): Iterator[InternalRow] = { val soi = if (rawDeser.getObjectInspector.equals(tableDeser.getObjectInspector)) { @@ -381,43 +381,43 @@ private[hive] object HadoopTableReader extends HiveInspectors with Logging { * Builds specific unwrappers ahead of time according to object inspector * types to avoid pattern matching and branching costs per row. */ - val unwrappers: Seq[(Any, MutableRow, Int) => Unit] = fieldRefs.map { + val unwrappers: Seq[(Any, InternalRow, Int) => Unit] = fieldRefs.map { _.getFieldObjectInspector match { case oi: BooleanObjectInspector => - (value: Any, row: MutableRow, ordinal: Int) => row.setBoolean(ordinal, oi.get(value)) + (value: Any, row: InternalRow, ordinal: Int) => row.setBoolean(ordinal, oi.get(value)) case oi: ByteObjectInspector => - (value: Any, row: MutableRow, ordinal: Int) => row.setByte(ordinal, oi.get(value)) + (value: Any, row: InternalRow, ordinal: Int) => row.setByte(ordinal, oi.get(value)) case oi: ShortObjectInspector => - (value: Any, row: MutableRow, ordinal: Int) => row.setShort(ordinal, oi.get(value)) + (value: Any, row: InternalRow, ordinal: Int) => row.setShort(ordinal, oi.get(value)) case oi: IntObjectInspector => - (value: Any, row: MutableRow, ordinal: Int) => row.setInt(ordinal, oi.get(value)) + (value: Any, row: InternalRow, ordinal: Int) => row.setInt(ordinal, oi.get(value)) case oi: LongObjectInspector => - (value: Any, row: MutableRow, ordinal: Int) => row.setLong(ordinal, oi.get(value)) + (value: Any, row: InternalRow, ordinal: Int) => row.setLong(ordinal, oi.get(value)) case oi: FloatObjectInspector => - (value: Any, row: MutableRow, ordinal: Int) => row.setFloat(ordinal, oi.get(value)) + (value: Any, row: InternalRow, ordinal: Int) => row.setFloat(ordinal, oi.get(value)) case oi: DoubleObjectInspector => - (value: Any, row: MutableRow, ordinal: Int) => row.setDouble(ordinal, oi.get(value)) + (value: Any, row: InternalRow, ordinal: Int) => row.setDouble(ordinal, oi.get(value)) case oi: HiveVarcharObjectInspector => - (value: Any, row: MutableRow, ordinal: Int) => + (value: Any, row: InternalRow, ordinal: Int) => row.update(ordinal, UTF8String.fromString(oi.getPrimitiveJavaObject(value).getValue)) case oi: HiveCharObjectInspector => - (value: Any, row: 
MutableRow, ordinal: Int) => + (value: Any, row: InternalRow, ordinal: Int) => row.update(ordinal, UTF8String.fromString(oi.getPrimitiveJavaObject(value).getValue)) case oi: HiveDecimalObjectInspector => - (value: Any, row: MutableRow, ordinal: Int) => + (value: Any, row: InternalRow, ordinal: Int) => row.update(ordinal, HiveShim.toCatalystDecimal(oi, value)) case oi: TimestampObjectInspector => - (value: Any, row: MutableRow, ordinal: Int) => + (value: Any, row: InternalRow, ordinal: Int) => row.setLong(ordinal, DateTimeUtils.fromJavaTimestamp(oi.getPrimitiveJavaObject(value))) case oi: DateObjectInspector => - (value: Any, row: MutableRow, ordinal: Int) => + (value: Any, row: InternalRow, ordinal: Int) => row.setInt(ordinal, DateTimeUtils.fromJavaDate(oi.getPrimitiveJavaObject(value))) case oi: BinaryObjectInspector => - (value: Any, row: MutableRow, ordinal: Int) => + (value: Any, row: InternalRow, ordinal: Int) => row.update(ordinal, oi.getPrimitiveJavaObject(value)) case oi => val unwrapper = unwrapperFor(oi) - (value: Any, row: MutableRow, ordinal: Int) => row(ordinal) = unwrapper(value) + (value: Any, row: InternalRow, ordinal: Int) => row(ordinal) = unwrapper(value) } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala index c553c03a9b708..1025b8f70d9ff 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala @@ -124,7 +124,7 @@ case class ScriptTransformation( } else { null } - val mutableRow = new SpecificMutableRow(output.map(_.dataType)) + val mutableRow = new SpecificInternalRow(output.map(_.dataType)) @transient lazy val unwrappers = outputSoi.getAllStructFieldRefs.asScala.map(unwrapperFor) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala index d54913518bb33..42033080dc34b 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala @@ -329,17 +329,17 @@ private[hive] case class HiveUDAFFunction( // buffer for it. 
override def aggBufferSchema: StructType = StructType(Nil) - override def update(_buffer: MutableRow, input: InternalRow): Unit = { + override def update(_buffer: InternalRow, input: InternalRow): Unit = { val inputs = inputProjection(input) function.iterate(buffer, wrap(inputs, wrappers, cached, inputDataTypes)) } - override def merge(buffer1: MutableRow, buffer2: InternalRow): Unit = { + override def merge(buffer1: InternalRow, buffer2: InternalRow): Unit = { throw new UnsupportedOperationException( "Hive UDAF doesn't support partial aggregate") } - override def initialize(_buffer: MutableRow): Unit = { + override def initialize(_buffer: InternalRow): Unit = { buffer = function.getNewAggregationBuffer } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala index 15b72d8d2179f..e94f49ea81177 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala @@ -281,7 +281,7 @@ private[orc] object OrcRelation extends HiveInspectors { maybeStructOI: Option[StructObjectInspector], iterator: Iterator[Writable]): Iterator[InternalRow] = { val deserializer = new OrcSerde - val mutableRow = new SpecificMutableRow(dataSchema.map(_.dataType)) + val mutableRow = new SpecificInternalRow(dataSchema.map(_.dataType)) val unsafeProjection = UnsafeProjection.create(dataSchema) def unwrap(oi: StructObjectInspector): Iterator[InternalRow] = { diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala index 163f210802b53..6eb571b91ffab 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala @@ -30,7 +30,6 @@ import org.apache.hadoop.hive.serde2.`lazy`.LazySimpleSerDe import org.apache.spark.{SparkConf, SparkContext} import org.apache.spark.internal.Logging -import org.apache.spark.internal.config.CATALOG_IMPLEMENTATION import org.apache.spark.sql.{SparkSession, SQLContext} import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder @@ -40,6 +39,7 @@ import org.apache.spark.sql.execution.QueryExecution import org.apache.spark.sql.execution.command.CacheTableCommand import org.apache.spark.sql.hive._ import org.apache.spark.sql.internal.{SharedState, SQLConf} +import org.apache.spark.sql.internal.StaticSQLConf.CATALOG_IMPLEMENTATION import org.apache.spark.util.{ShutdownHookManager, Utils} // SPARK-3729: Test key required to check for initialization errors with config. 
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/LogicalPlanToSQLSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/LogicalPlanToSQLSuite.scala index 9ac1e86fc82cb..c7f10e569fa4d 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/LogicalPlanToSQLSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/LogicalPlanToSQLSuite.scala @@ -45,7 +45,7 @@ class LogicalPlanToSQLSuite extends SQLBuilderTest with SQLTestUtils { // Used for generating new query answer files by saving private val regenerateGoldenFiles: Boolean = System.getenv("SPARK_GENERATE_GOLDEN_FILES") == "1" - private val goldenSQLPath = "src/test/resources/sqlgen/" + private val goldenSQLPath = getTestResourcePath("sqlgen") protected override def beforeAll(): Unit = { super.beforeAll() diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveContextCompatibilitySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveContextCompatibilitySuite.scala index 57363b7259c61..939fd71b4f1ed 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveContextCompatibilitySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveContextCompatibilitySuite.scala @@ -87,11 +87,11 @@ class HiveContextCompatibilitySuite extends SparkFunSuite with BeforeAndAfterEac assert( hc.sql("SELECT * FROM moo_table order by name").collect().toSeq == df.collect().toSeq.sortBy(_.getString(0))) - val tables = hc.sql("SHOW TABLES IN mee_db").collect().map(_.getString(0)) + val tables = hc.sql("SHOW TABLES IN mee_db").select("tableName").collect().map(_.getString(0)) assert(tables.toSet == Set("moo_table", "mee_table")) hc.sql("DROP TABLE moo_table") hc.sql("DROP TABLE mee_table") - val tables2 = hc.sql("SHOW TABLES IN mee_db").collect().map(_.getString(0)) + val tables2 = hc.sql("SHOW TABLES IN mee_db").select("tableName").collect().map(_.getString(0)) assert(tables2.isEmpty) hc.sql("USE default") hc.sql("DROP DATABASE mee_db CASCADE") diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDDLCommandSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDDLCommandSuite.scala index 54e27b6f73502..9ce3338647398 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDDLCommandSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDDLCommandSuite.scala @@ -243,7 +243,7 @@ class HiveDDLCommandSuite extends PlanTest { .asInstanceOf[ScriptTransformation].copy(ioschema = null) val plan2 = parser.parsePlan("map a, b using 'func' as c, d from e") .asInstanceOf[ScriptTransformation].copy(ioschema = null) - val plan3 = parser.parsePlan("reduce a, b using 'func' as (c: int, d decimal(10, 0)) from e") + val plan3 = parser.parsePlan("reduce a, b using 'func' as (c int, d decimal(10, 0)) from e") .asInstanceOf[ScriptTransformation].copy(ioschema = null) val p = ScriptTransformation( diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala index 29317e2887861..d3873cf6c8231 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala @@ -152,7 +152,8 @@ class HiveSparkSubmitSuite case v if v.startsWith("2.10") || v.startsWith("2.11") => v.substring(0, 4) case x => throw new Exception(s"Unsupported Scala Version: $x") } - val testJar = 
s"sql/hive/src/test/resources/regression-test-SPARK-8489/test-$version.jar" + val jarDir = getTestResourcePath("regression-test-SPARK-8489") + val testJar = s"$jarDir/test-$version.jar" val args = Seq( "--conf", "spark.ui.enabled=false", "--conf", "spark.master.rest.enabled=false", diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ListTablesSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ListTablesSuite.scala index 6eeb67510c735..15ba61646d03f 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ListTablesSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ListTablesSuite.scala @@ -58,10 +58,10 @@ class ListTablesSuite extends QueryTest with TestHiveSingleton with BeforeAndAft // We are using default DB. checkAnswer( allTables.filter("tableName = 'listtablessuitetable'"), - Row("listtablessuitetable", true)) + Row("", "listtablessuitetable", true)) checkAnswer( allTables.filter("tableName = 'hivelisttablessuitetable'"), - Row("hivelisttablessuitetable", false)) + Row("default", "hivelisttablessuitetable", false)) assert(allTables.filter("tableName = 'hiveindblisttablessuitetable'").count() === 0) } } @@ -71,11 +71,11 @@ class ListTablesSuite extends QueryTest with TestHiveSingleton with BeforeAndAft case allTables => checkAnswer( allTables.filter("tableName = 'listtablessuitetable'"), - Row("listtablessuitetable", true)) + Row("", "listtablessuitetable", true)) assert(allTables.filter("tableName = 'hivelisttablessuitetable'").count() === 0) checkAnswer( allTables.filter("tableName = 'hiveindblisttablessuitetable'"), - Row("hiveindblisttablessuitetable", false)) + Row("listtablessuitedb", "hiveindblisttablessuitetable", false)) } } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala index 8ae6868c9848a..7cc6179d44977 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala @@ -23,15 +23,16 @@ import scala.collection.mutable.ArrayBuffer import org.apache.hadoop.fs.Path +import org.apache.spark.SparkContext import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.catalog.{CatalogStorageFormat, CatalogTable, CatalogTableType} -import org.apache.spark.sql.catalyst.parser.ParseException import org.apache.spark.sql.execution.datasources.{HadoopFsRelation, LogicalRelation} import org.apache.spark.sql.hive.HiveExternalCatalog._ import org.apache.spark.sql.hive.client.HiveClient import org.apache.spark.sql.hive.test.TestHiveSingleton import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.internal.StaticSQLConf._ import org.apache.spark.sql.test.SQLTestUtils import org.apache.spark.sql.types._ import org.apache.spark.util.Utils @@ -699,28 +700,27 @@ class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with TestHiv } test("SPARK-6024 wide schema support") { - withSQLConf(SQLConf.SCHEMA_STRING_LENGTH_THRESHOLD.key -> "4000") { - withTable("wide_schema") { - withTempDir { tempDir => - // We will need 80 splits for this schema if the threshold is 4000. 
-          val schema = StructType((1 to 5000).map(i => StructField(s"c_$i", StringType)))
-
-          val tableDesc = CatalogTable(
-            identifier = TableIdentifier("wide_schema"),
-            tableType = CatalogTableType.EXTERNAL,
-            storage = CatalogStorageFormat.empty.copy(
-              properties = Map("path" -> tempDir.getCanonicalPath)
-            ),
-            schema = schema,
-            provider = Some("json")
-          )
-          spark.sessionState.catalog.createTable(tableDesc, ignoreIfExists = false)
-
-          sessionState.refreshTable("wide_schema")
-
-          val actualSchema = table("wide_schema").schema
-          assert(schema === actualSchema)
-        }
+    assert(spark.sparkContext.conf.get(SCHEMA_STRING_LENGTH_THRESHOLD) == 4000)
+    withTable("wide_schema") {
+      withTempDir { tempDir =>
+        // We will need 80 splits for this schema if the threshold is 4000.
+        val schema = StructType((1 to 5000).map(i => StructField(s"c_$i", StringType)))
+
+        val tableDesc = CatalogTable(
+          identifier = TableIdentifier("wide_schema"),
+          tableType = CatalogTableType.EXTERNAL,
+          storage = CatalogStorageFormat.empty.copy(
+            properties = Map("path" -> tempDir.getCanonicalPath)
+          ),
+          schema = schema,
+          provider = Some("json")
+        )
+        spark.sessionState.catalog.createTable(tableDesc, ignoreIfExists = false)
+
+        sessionState.refreshTable("wide_schema")
+
+        val actualSchema = table("wide_schema").schema
+        assert(schema === actualSchema)
       }
     }
   }
@@ -984,7 +984,7 @@ class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with TestHiv
       checkAnswer(
         spark.sql("show TABLES in testdb8156").filter("tableName = 'ttt3'"),
-        Row("ttt3", false))
+        Row("testdb8156", "ttt3", false))
 
       spark.sql("""use default""")
       spark.sql("""drop database if exists testdb8156 CASCADE""")
     }
@@ -1325,4 +1325,18 @@ class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with TestHiv
       hiveClient.dropTable("default", "t", ignoreIfNotExists = true, purge = true)
     }
   }
+
+  test("should keep data source entries in table properties when debug mode is on") {
+    val previousValue = sparkSession.sparkContext.conf.get(DEBUG_MODE)
+    try {
+      sparkSession.sparkContext.conf.set(DEBUG_MODE, true)
+      val newSession = sparkSession.newSession()
+      newSession.sql("CREATE TABLE abc(i int) USING json")
+      val tableMeta = newSession.sessionState.catalog.getTableMetadata(TableIdentifier("abc"))
+      assert(tableMeta.properties(DATASOURCE_SCHEMA_NUMPARTS).toInt == 1)
+      assert(tableMeta.properties(DATASOURCE_PROVIDER) == "json")
+    } finally {
+      sparkSession.sparkContext.conf.set(DEBUG_MODE, previousValue)
+    }
+  }
 }
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveCommandSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveCommandSuite.scala
index b2103b3bfc36c..2c772ce2155ef 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveCommandSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveCommandSuite.scala
@@ -94,15 +94,15 @@ class HiveCommandSuite extends QueryTest with SQLTestUtils with TestHiveSingleto
     sql("CREATE TABLE show2b(c2 int)")
     checkAnswer(
       sql("SHOW TABLES IN default 'show1*'"),
-      Row("show1a", false) :: Nil)
+      Row("default", "show1a", false) :: Nil)
     checkAnswer(
       sql("SHOW TABLES IN default 'show1*|show2*'"),
-      Row("show1a", false) ::
-      Row("show2b", false) :: Nil)
+      Row("default", "show1a", false) ::
+      Row("default", "show2b", false) :: Nil)
     checkAnswer(
       sql("SHOW TABLES 'show1*|show2*'"),
-      Row("show1a", false) ::
-      Row("show2b", false) :: Nil)
+      Row("default", "show1a", false) ::
+      Row("default", "show2b", false) :: Nil)
TABLES").count() >= 2) assert( diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala index 751e976c7b908..3d1712e4354c0 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala @@ -22,7 +22,6 @@ import java.io.File import org.apache.hadoop.fs.Path import org.scalatest.BeforeAndAfterEach -import org.apache.spark.internal.config._ import org.apache.spark.sql.{AnalysisException, QueryTest, Row, SaveMode} import org.apache.spark.sql.catalyst.analysis.TableAlreadyExistsException import org.apache.spark.sql.catalyst.catalog.{CatalogDatabase, CatalogTable, CatalogTableType} @@ -32,6 +31,7 @@ import org.apache.spark.sql.execution.datasources.CaseInsensitiveMap import org.apache.spark.sql.hive.HiveExternalCatalog import org.apache.spark.sql.hive.test.TestHiveSingleton import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.internal.StaticSQLConf.CATALOG_IMPLEMENTATION import org.apache.spark.sql.test.SQLTestUtils class HiveDDLSuite @@ -315,6 +315,38 @@ class HiveDDLSuite assert(message.contains("Cannot alter a table with ALTER VIEW. Please use ALTER TABLE instead")) } + test("create table - SET TBLPROPERTIES EXTERNAL to TRUE") { + val tabName = "tab1" + withTable(tabName) { + val message = intercept[AnalysisException] { + sql(s"CREATE TABLE $tabName (height INT, length INT) TBLPROPERTIES('EXTERNAL'='TRUE')") + }.getMessage + assert(message.contains("Cannot set or change the preserved property key: 'EXTERNAL'")) + } + } + + test("alter table - SET TBLPROPERTIES EXTERNAL to TRUE") { + val tabName = "tab1" + withTable(tabName) { + val catalog = spark.sessionState.catalog + sql(s"CREATE TABLE $tabName (height INT, length INT)") + assert( + catalog.getTableMetadata(TableIdentifier(tabName)).tableType == CatalogTableType.MANAGED) + val message = intercept[AnalysisException] { + sql(s"ALTER TABLE $tabName SET TBLPROPERTIES ('EXTERNAL' = 'TRUE')") + }.getMessage + assert(message.contains("Cannot set or change the preserved property key: 'EXTERNAL'")) + // The table type is not changed to external + assert( + catalog.getTableMetadata(TableIdentifier(tabName)).tableType == CatalogTableType.MANAGED) + // The table property is case sensitive. 
+      // The table property is case sensitive. Thus, external is allowed
+      sql(s"ALTER TABLE $tabName SET TBLPROPERTIES ('external' = 'TRUE')")
+      // The table type is not changed to external
+      assert(
+        catalog.getTableMetadata(TableIdentifier(tabName)).tableType == CatalogTableType.MANAGED)
+    }
+  }
+
   test("alter views and alter table - misuse") {
     val tabName = "tab1"
     withTable(tabName) {
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala
index 6c77a0deb52a4..6f2a16662bf10 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala
@@ -66,13 +66,14 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton {
   import spark.implicits._
 
   test("script") {
+    val scriptFilePath = getTestResourcePath("test_script.sh")
     if (testCommandAvailable("bash") && testCommandAvailable("echo | sed")) {
       val df = Seq(("x1", "y1", "z1"), ("x2", "y2", "z2")).toDF("c1", "c2", "c3")
       df.createOrReplaceTempView("script_table")
       val query1 = sql(
-        """
+        s"""
           |SELECT col1 FROM (from(SELECT c1, c2, c3 FROM script_table) tempt_table
-          |REDUCE c1, c2, c3 USING 'bash src/test/resources/test_script.sh' AS
+          |REDUCE c1, c2, c3 USING 'bash $scriptFilePath' AS
           |(col1 STRING, col2 STRING)) script_test_table""".stripMargin)
       checkAnswer(query1, Row("x1_y1") :: Row("x2_y2") :: Nil)
     }
@@ -1290,11 +1291,12 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton {
       .selectExpr("id AS a", "id AS b")
       .createOrReplaceTempView("test")
 
+    val scriptFilePath = getTestResourcePath("data")
     checkAnswer(
       sql(
-        """FROM(
+        s"""FROM(
          |  FROM test SELECT TRANSFORM(a, b)
-          |  USING 'python src/test/resources/data/scripts/test_transform.py "\t"'
+          |  USING 'python $scriptFilePath/scripts/test_transform.py "\t"'
          |  AS (c STRING, d STRING)
          |) t
          |SELECT c
@@ -1308,12 +1310,13 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton {
       .selectExpr("id AS a", "id AS b")
       .createOrReplaceTempView("test")
 
+    val scriptFilePath = getTestResourcePath("data")
     val df = sql(
-      """FROM test
+      s"""FROM test
        |SELECT TRANSFORM(a, b)
        |ROW FORMAT SERDE 'org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe'
        |WITH SERDEPROPERTIES('field.delim' = '|')
-        |USING 'python src/test/resources/data/scripts/test_transform.py "|"'
+        |USING 'python $scriptFilePath/scripts/test_transform.py "|"'
        |AS (c STRING, d STRING)
        |ROW FORMAT SERDE 'org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe'
        |WITH SERDEPROPERTIES('field.delim' = '|')
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLViewSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLViewSuite.scala
index f5c605fe5e2fa..2af935da689c9 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLViewSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLViewSuite.scala
@@ -62,15 +62,15 @@ class SQLViewSuite extends QueryTest with SQLTestUtils with TestHiveSingleton {
       var e = intercept[AnalysisException] {
         sql("CREATE OR REPLACE VIEW tab1 AS SELECT * FROM jt")
       }.getMessage
-      assert(e.contains("`default`.`tab1` is not a view"))
+      assert(e.contains("`tab1` is not a view"))
       e = intercept[AnalysisException] {
         sql("CREATE VIEW tab1 AS SELECT * FROM jt")
       }.getMessage
-      assert(e.contains("`default`.`tab1` is not a view"))
+      assert(e.contains("`tab1` is not a view"))
       e = intercept[AnalysisException] {
         sql("ALTER VIEW tab1 AS SELECT * FROM jt")
       }.getMessage
-      assert(e.contains("`default`.`tab1` is not a view"))
+      assert(e.contains("`tab1` is not a view"))
     }
   }
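Editor's note: the SHOW TABLES assertions updated above (HiveContextCompatibilitySuite, ListTablesSuite, MetastoreDataSourcesSuite, HiveCommandSuite) all reflect one behavioural change: the command now returns (database, tableName, isTemporary) instead of (tableName, isTemporary). The sketch below illustrates what calling code has to do after that change; it is not part of the patch, and the app name, master URL, and temp view name are made up.

// Editor's illustrative sketch -- not part of this diff.
// After this change, callers that relied on getString(0) being the table name
// must project the tableName column explicitly, as the updated suites do.
import org.apache.spark.sql.SparkSession

object ShowTablesColumnsSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .appName("show-tables-columns-sketch")   // hypothetical app name
      .master("local[*]")
      .getOrCreate()

    spark.range(1).createOrReplaceTempView("demo_view")   // hypothetical temp view

    // Full three-column output: database is empty for temporary views.
    spark.sql("SHOW TABLES").show()

    // Name-only extraction, matching the pattern used in the updated tests.
    val names = spark.sql("SHOW TABLES")
      .select("tableName")
      .collect()
      .map(_.getString(0))
    names.foreach(println)

    spark.stop()
  }
}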