@@ -594,7 +594,6 @@ class BuildPlugin implements Plugin<Project> {

// Main variant needs the least configuration on its own, since it is the default publication created above.
sparkVariants.defaultVariant { SparkVariant variant ->
project.publishing.publications.main.setAlias(true)
updateVariantArtifactId(project, project.publishing.publications.main, variant)
}

@@ -651,6 +650,7 @@ class BuildPlugin implements Plugin<Project> {
}
configurePom(project, variantPublication)
updateVariantArtifactId(project, variantPublication, variant)
variantPublication.setAlias(true)
}
}
if (signingKey.isPresent()) {
@@ -72,6 +72,8 @@

public class SparkVariantPlugin implements Plugin<Project> {

public static final String ITEST_SOURCE_SET_NAME = "itest";

public static class SparkVariant {

private final CharSequence name;
@@ -311,6 +313,23 @@ private static void configureDefaultVariant(Project project, SparkVariant sparkV
runtimeElements.getOutgoing().capability(capability);

configureScalaJarClassifiers(project, sparkVariant);
// Extend the main, test, and itest source sets for the default variant - this enables the possibility of having diverging code between variants
SourceSetContainer sourceSets = javaPluginExtension.getSourceSets();
ScalaSourceDirectorySet scalaSourceSet = getScalaSourceSet(sourceSets.getByName(MAIN_SOURCE_SET_NAME));
scalaSourceSet.setSrcDirs(Arrays.asList(
"src/" + MAIN_SOURCE_SET_NAME + "/scala",
"src/" + MAIN_SOURCE_SET_NAME + "/" + sparkVariant.getName()
));
ScalaSourceDirectorySet scalaTestSourceSet = getScalaSourceSet(sourceSets.getByName(TEST_SOURCE_SET_NAME));
scalaTestSourceSet.setSrcDirs(Arrays.asList(
"src/" + TEST_SOURCE_SET_NAME + "/scala",
"src/" + TEST_SOURCE_SET_NAME + "/" + sparkVariant.getName()
));
ScalaSourceDirectorySet scalaITestSourceSet = getScalaSourceSet(sourceSets.getByName(ITEST_SOURCE_SET_NAME));
scalaITestSourceSet.setSrcDirs(Arrays.asList(
"src/" + ITEST_SOURCE_SET_NAME + "/scala",
"src/" + ITEST_SOURCE_SET_NAME + "/" + sparkVariant.getName()
));
}

private static void configureVariant(Project project, SparkVariant sparkVariant, JavaPluginExtension javaPluginExtension) {
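For illustration only (assuming MAIN_SOURCE_SET_NAME and TEST_SOURCE_SET_NAME resolve to the conventional "main" and "test"; ITEST_SOURCE_SET_NAME is "itest" as defined above), a default variant named spark30scala213 would then compile Scala sources from a shared directory plus a variant-specific one per source set:

src/main/scala (shared) and src/main/spark30scala213 (variant-specific)
src/test/scala (shared) and src/test/spark30scala213 (variant-specific)
src/itest/scala (shared) and src/itest/spark30scala213 (variant-specific)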
1 change: 1 addition & 0 deletions gradle.properties
@@ -30,6 +30,7 @@ spark20Version = 2.3.0
spark22Version = 2.2.3
spark24Version = 2.4.4
spark30Version = 3.4.3
spark35Version = 3.5.6

# same as Spark's
scala210Version = 2.10.7
27 changes: 27 additions & 0 deletions spark/core/build.gradle
@@ -14,6 +14,8 @@ sparkVariants {
// We should maybe move these to a separate config file that can be read from both this file and the pipeline script in the future if it creates issues
setCoreDefaultVariant "spark30scala213", spark30Version, scala213Version
addCoreFeatureVariant "spark30scala212", spark30Version, scala212Version
addCoreFeatureVariant "spark35scala212", spark35Version, scala212Version
addCoreFeatureVariant "spark35scala213", spark35Version, scala213Version

all { SparkVariantPlugin.SparkVariant variant ->

@@ -44,6 +46,11 @@
add(variant.configuration('api'), "org.apache.spark:spark-core_${variant.scalaMajorVersion}:${variant.sparkVersion}") {
exclude group: 'org.apache.hadoop'
}
if (variant.sparkVersion >= "3.5.0") {
add(variant.configuration('implementation'), "org.apache.spark:spark-common-utils_${variant.scalaMajorVersion}:$variant.sparkVersion") {
exclude group: 'org.apache.hadoop'
}
}

add(variant.configuration('implementation'), project(":elasticsearch-hadoop-mr"))
add(variant.configuration('implementation'), "commons-logging:commons-logging:1.1.1")
@@ -126,3 +133,23 @@ tasks.withType(ScalaCompile) { ScalaCompile task ->
task.targetCompatibility = project.ext.minimumRuntimeVersion
task.options.forkOptions.executable = new File(project.ext.runtimeJavaHome, 'bin/java').absolutePath
}

tasks.register('copyPoms', Copy) {
from(tasks.named('generatePomFileForMainPublication')) {
rename 'pom-default.xml', "elasticsearch-spark-30_2.13-${project.getVersion()}.pom"
}
from(tasks.named('generatePomFileForSpark30scala212Publication')) {
rename 'pom-default.xml', "elasticsearch-spark-30_2.12-${project.getVersion()}.pom"
}
from(tasks.named('generatePomFileForSpark35scala212Publication')) {
rename 'pom-default.xml', "elasticsearch-spark-35_2.12-${project.getVersion()}.pom"
}
from(tasks.named('generatePomFileForSpark35scala213Publication')) {
rename 'pom-default.xml', "elasticsearch-spark-35_2.13-${project.getVersion()}.pom"
}
into(new File(project.buildDir, 'distributions'))
}

tasks.named('distribution').configure {
dependsOn 'copyPoms'
}
20 changes: 18 additions & 2 deletions spark/sql-30/build.gradle
@@ -9,8 +9,10 @@ apply plugin: 'spark.variants'

sparkVariants {
capabilityGroup 'org.elasticsearch.spark.sql.variant'
setDefaultVariant "spark30scala213", spark30Version, scala213Version
addFeatureVariant "spark30scala212", spark30Version, scala212Version
setCoreDefaultVariant "spark30scala213", spark30Version, scala213Version
addCoreFeatureVariant "spark30scala212", spark30Version, scala212Version
addCoreFeatureVariant "spark35scala213", spark35Version, scala213Version
addCoreFeatureVariant "spark35scala212", spark35Version, scala212Version

all { SparkVariantPlugin.SparkVariant variant ->
String scalaCompileTaskName = project.sourceSets
@@ -58,6 +60,14 @@
exclude group: 'javax.servlet'
exclude group: 'org.apache.hadoop'
}
if (variant.sparkVersion >= "3.5.0") {
add(variant.configuration('implementation'), "org.apache.spark:spark-common-utils_${variant.scalaMajorVersion}:$variant.sparkVersion") {
exclude group: 'org.apache.hadoop'
}
add(variant.configuration('implementation'), "org.apache.spark:spark-sql-api_${variant.scalaMajorVersion}:$variant.sparkVersion") {
exclude group: 'org.apache.hadoop'
}
}

add(variant.configuration('implementation'), "org.apache.spark:spark-sql_${variant.scalaMajorVersion}:$variant.sparkVersion") {
exclude group: 'org.apache.hadoop'
@@ -198,6 +208,12 @@ tasks.register('copyPoms', Copy) {
from(tasks.named('generatePomFileForSpark30scala212Publication')) {
rename 'pom-default.xml', "elasticsearch-spark-30_2.12-${project.getVersion()}.pom"
}
from(tasks.named('generatePomFileForSpark35scala212Publication')) {
rename 'pom-default.xml', "elasticsearch-spark-35_2.12-${project.getVersion()}.pom"
}
from(tasks.named('generatePomFileForSpark35scala213Publication')) {
rename 'pom-default.xml', "elasticsearch-spark-35_2.13-${project.getVersion()}.pom"
}
into(new File(project.buildDir, 'distributions'))
}

@@ -0,0 +1,127 @@
package org.elasticsearch.spark.sql.streaming

import org.apache.spark.sql.streaming.StreamingQueryListener
import org.elasticsearch.hadoop.util.unit.TimeValue
import org.junit.Assert

import java.util.UUID
import java.util.concurrent.{CountDownLatch, TimeUnit}

// Listener to 1) ensure no more than a single stream is running at a time, 2) know when we're done processing inputs
// and 3) capture any Exceptions encountered during the execution of the stream.
class StreamingQueryLifecycleListener extends StreamingQueryListener {

private var uuid: Option[UUID] = None

private var inputsRequired = 0L
private var inputsSeen = 0L

private var expectingToThrow: Option[Class[_]] = None
private var foundExpectedException: Boolean = false
private var encounteredException: Option[String] = None

private var latch = new CountDownLatch(1) // expects just a single batch

def incrementExpected(): Unit = inputsRequired = inputsRequired + 1

def setExpectedException(clazz: Class[_]): Unit = {
expectingToThrow match {
case Some(cls) => throw new IllegalArgumentException(s"Already expecting exception of type [$cls]!")
case None => expectingToThrow = Some(clazz)
}
}

// Make sure we only ever watch one query at a time.
private def captureQueryID(eventId: UUID): Unit = {
uuid match {
case Some(id) if eventId != id => throw new IllegalStateException("Multiple queries are not supported")
case None => uuid = Some(eventId)
case _ => // No problem for now
}
}

override def onQueryStarted(event: StreamingQueryListener.QueryStartedEvent): Unit = {
captureQueryID(event.id)
}

override def onQueryProgress(event: StreamingQueryListener.QueryProgressEvent): Unit = {
captureQueryID(event.progress.id)

// keep track of the number of input rows seen. Once we reach the number of expected inputs,
// wait for a progress update with zero new rows before signaling to close

val rows = event.progress.numInputRows
inputsSeen = inputsSeen + rows

if (inputsSeen == inputsRequired) {
if (rows == 0) {
// Don't close after meeting the first input level. Wait to make sure we get
// one last pass with no new rows processed before signalling.
latch.countDown()
}
} else if (inputsSeen > inputsRequired) {
throw new IllegalStateException("Too many inputs encountered. Expected [" + inputsRequired +
"] but found [" + inputsSeen + "]")
}
}

protected def onQueryIdle(eventId: UUID): Unit = {
captureQueryID(eventId)
if (inputsSeen == inputsRequired) {
latch.countDown()
}
}

override def onQueryTerminated(event: StreamingQueryListener.QueryTerminatedEvent): Unit = {
try {
captureQueryID(event.id)

encounteredException = event.exception match {
case Some(value) =>
// This is a whole trace, get everything after the enclosing SparkException (i.e. the original exception + trace)
val messageParts = value.split("\\): ")
if (messageParts.size > 1) {
val nonSparkMessage = messageParts(1)
// Ditch the full trace after the first newline
val removedNewLine = nonSparkMessage.substring(0, nonSparkMessage.indexOf("\n"))
// Split the class name from the exception message and take the class name
Some(removedNewLine.substring(0, removedNewLine.indexOf(":")))
} else {
// Return the original framework error
Some(value.substring(0, value.indexOf(":")))
}
case None => None
}

val expectedExceptionName = expectingToThrow.map(_.getCanonicalName).getOrElse("None")

foundExpectedException = encounteredException.exists(_.equals(expectedExceptionName))
} finally {
// signal no matter what to avoid deadlock
latch.countDown()
}
}

def waitOnComplete(timeValue: TimeValue): Boolean = latch.await(timeValue.millis, TimeUnit.MILLISECONDS)

def expectAnotherBatch(): Unit = {
latch = new CountDownLatch(1)
}

def assertExpectedExceptions(message: Option[String]): Unit = {
expectingToThrow match {
case Some(exceptionClass) =>
if (!foundExpectedException) {
encounteredException match {
case Some(s) => Assert.fail(s"Expected ${exceptionClass.getCanonicalName} but got $s")
case None => Assert.fail(message.getOrElse(s"Expected ${exceptionClass.getCanonicalName} but no Exceptions were thrown"))
}
}
case None =>
encounteredException match {
case Some(exception) => Assert.fail(s"Expected no exceptions but got $exception")
case None => ()
}
}
}
}
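
A hedged usage sketch for the listener above (not part of this change): it assumes an existing SparkSession named spark, a streaming Dataset named inputStream feeding three expected input rows, an Elasticsearch sink registered under the "es" format, and that org.elasticsearch.hadoop.util.unit.TimeValue offers a (duration, TimeUnit) constructor; the index name, checkpoint path, and timeout are illustrative.

import java.util.concurrent.TimeUnit
import org.elasticsearch.hadoop.util.unit.TimeValue
import org.junit.Assert

val listener = new StreamingQueryLifecycleListener()
(1 to 3).foreach(_ => listener.incrementExpected()) // three rows are expected to reach the sink
spark.streams.addListener(listener)

val query = inputStream.writeStream
  .format("es")                                               // assumed Elasticsearch streaming sink
  .option("checkpointLocation", "/tmp/streaming-checkpoint")  // illustrative path
  .start("streaming-test")                                    // illustrative index name

try {
  // Block until a zero-row progress event follows the expected inputs (or the query terminates), then verify.
  Assert.assertTrue("Timed out waiting for the stream", listener.waitOnComplete(new TimeValue(2, TimeUnit.MINUTES)))
  listener.assertExpectedExceptions(None)
} finally {
  query.stop()
  spark.streams.removeListener(listener)
}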