Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable Support for Custom Session+Proxy Configurations #644

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -16,26 +16,30 @@

package io.delta.sharing.client

import java.io.File
import java.net.{URI, URLDecoder, URLEncoder}
import java.util.concurrent.TimeUnit

import org.apache.hadoop.fs._
import org.apache.hadoop.fs.permission.FsPermission
import org.apache.hadoop.util.Progressable
import org.apache.http.{HttpClientConnection, HttpHost, HttpRequest, HttpResponse}
import org.apache.http.auth.{AuthScope, UsernamePasswordCredentials}
import org.apache.http.{HttpClientConnection, HttpHost, HttpRequest, HttpRequestInterceptor, HttpResponse}
import org.apache.http.client.config.RequestConfig
import org.apache.http.client.utils.URIBuilder
import org.apache.http.conn.routing.HttpRoute
import org.apache.http.impl.client.{BasicCredentialsProvider, HttpClientBuilder, RequestWrapper}
import org.apache.http.conn.routing.{HttpRoute, HttpRoutePlanner}
import org.apache.http.conn.ssl.NoopHostnameVerifier
import org.apache.http.conn.ssl.TrustSelfSignedStrategy
import org.apache.http.impl.client.{CloseableHttpClient, HttpClientBuilder, RequestWrapper}
import org.apache.http.impl.conn.{DefaultRoutePlanner, DefaultSchemePortResolver}
import org.apache.http.protocol.{HttpContext, HttpRequestExecutor}
import org.apache.http.ssl.SSLContextBuilder
import org.apache.spark.SparkEnv
import org.apache.spark.delta.sharing.{PreSignedUrlCache, PreSignedUrlFetcher}
import org.apache.spark.internal.Logging

import io.delta.sharing.client.model.FileAction
import io.delta.sharing.client.util.ConfUtils
import io.delta.sharing.client.util.ConfUtils.ProxyConfig

/** Read-only file system for delta paths. */
private[sharing] class DeltaSharingFileSystem extends FileSystem with Logging {
Expand All @@ -44,74 +48,128 @@ private[sharing] class DeltaSharingFileSystem extends FileSystem with Logging {

lazy private val numRetries = ConfUtils.numRetries(getConf)
lazy private val maxRetryDurationMillis = ConfUtils.maxRetryDurationMillis(getConf)
lazy private val timeoutInSeconds = ConfUtils.timeoutInSeconds(getConf)
lazy private val httpClient = createHttpClient()

private[sharing] def createHttpClient() = {
val proxyConfigOpt = ConfUtils.getProxyConfig(getConf)
val maxConnections = ConfUtils.maxConnections(getConf)
val config = RequestConfig.custom()
.setConnectTimeout(timeoutInSeconds * 1000)
.setConnectionRequestTimeout(timeoutInSeconds * 1000)
.setSocketTimeout(timeoutInSeconds * 1000).build()
private[sharing] def createHttpClient(): CloseableHttpClient = {
val conf = getConf
val timeoutInMillis = ConfUtils.getTimeoutInMillis(conf)
val proxyConfigOpt = ConfUtils.getProxyConfig(conf)
val maxConnections = ConfUtils.maxConnections(conf)
val customHeadersOpt = ConfUtils.getCustomHeaders(conf)
val neverUseHttps = ConfUtils.getNeverUseHttps(conf)

val requestConfig = RequestConfig.custom()
.setConnectTimeout(timeoutInMillis)
.setConnectionRequestTimeout(timeoutInMillis)
.setSocketTimeout(timeoutInMillis)
.build()

logDebug(s"Creating HTTP client with timeoutInMillis: $timeoutInMillis")

logDebug(s"Creating delta sharing httpClient with timeoutInSeconds: $timeoutInSeconds.")
val clientBuilder = HttpClientBuilder.create()
.setMaxConnTotal(maxConnections)
.setMaxConnPerRoute(maxConnections)
.setDefaultRequestConfig(config)
// Disable the default retry behavior because we have our own retry logic.
// See `RetryUtils.runWithExponentialBackoff`.
.setDefaultRequestConfig(requestConfig)
.disableAutomaticRetries()

// Set proxy if provided.
proxyConfigOpt.foreach { proxyConfig =>
configureProxy(clientBuilder, proxyConfig, neverUseHttps)
}

customHeadersOpt.foreach { headers =>
addCustomHeaders(clientBuilder, headers)
}

clientBuilder.build()
}

private def configureProxy(clientBuilder: HttpClientBuilder, proxyConfig: ProxyConfig,
neverUseHttps: Boolean): Unit = {

val proxy = new HttpHost(proxyConfig.host, proxyConfig.port)
clientBuilder.setProxy(proxy)

val neverUseHttps = ConfUtils.getNeverUseHttps(getConf)
if (neverUseHttps) {
val httpRequestDowngradeExecutor = new HttpRequestExecutor {
override def execute(
request: HttpRequest,
connection: HttpClientConnection,
context: HttpContext): HttpResponse = {
try {
val modifiedUri: URI = {
new URIBuilder(request.getRequestLine.getUri).setScheme("http").build()
}
val wrappedRequest = new RequestWrapper(request)
wrappedRequest.setURI(modifiedUri)

return super.execute(wrappedRequest, connection, context)
} catch {
case e: Exception =>
logInfo("Failed to downgrade the request to http", e)
}
super.execute(request, connection, context)
}
val proxy = new HttpHost(proxyConfig.host, proxyConfig.port)
clientBuilder.setProxy(proxy)

proxyConfig.authToken.foreach { token =>
clientBuilder.addInterceptorFirst(new HttpRequestInterceptor {
override def process(request: HttpRequest, context: HttpContext): Unit = {
request.addHeader("Proxy-Authorization", s"Bearer $token")
}
clientBuilder.setRequestExecutor(httpRequestDowngradeExecutor)
})
}

configureSSL(clientBuilder, proxyConfig)

if (neverUseHttps) {
clientBuilder.setRequestExecutor(createHttpRequestDowngradeExecutor())
}

if (proxyConfig.noProxyHosts.nonEmpty || neverUseHttps) {
clientBuilder.setRoutePlanner(createRoutePlanner(proxy, proxyConfig.noProxyHosts))
}
}

private def configureSSL(clientBuilder: HttpClientBuilder, proxyConfig: ProxyConfig): Unit = {
if (proxyConfig.sslTrustAll) {
clientBuilder.setSSLHostnameVerifier(NoopHostnameVerifier.INSTANCE)
clientBuilder.setSSLContext(
new SSLContextBuilder()
.loadTrustMaterial(null, new TrustSelfSignedStrategy)
.build()
)
} else {
proxyConfig.caCertPath.foreach { path =>
clientBuilder.setSSLContext(
new SSLContextBuilder()
.loadTrustMaterial(new File(path), null)
.build()
)
}
if (proxyConfig.noProxyHosts.nonEmpty || neverUseHttps) {
val routePlanner = new DefaultRoutePlanner(DefaultSchemePortResolver.INSTANCE) {
override def determineRoute(target: HttpHost,
request: HttpRequest,
context: HttpContext): HttpRoute = {
if (proxyConfig.noProxyHosts.contains(target.getHostName)) {
// Direct route (no proxy)
new HttpRoute(target)
} else {
// Route via proxy
new HttpRoute(target, proxy)
}
}
}
}

private def createHttpRequestDowngradeExecutor(): HttpRequestExecutor = {
new HttpRequestExecutor() {
override def execute(request: HttpRequest, conn: HttpClientConnection,
context: HttpContext): HttpResponse = {
try {
val modifiedUri = new URIBuilder(request.getRequestLine.getUri)
.setScheme("http")
.build()
val wrappedRequest = new RequestWrapper(request)
wrappedRequest.setURI(modifiedUri)
super.execute(wrappedRequest, conn, context)
} catch {
case e: Exception =>
logInfo("Failed to downgrade the request to HTTP", e)
super.execute(request, conn, context)
}
clientBuilder.setRoutePlanner(routePlanner)
}
}
clientBuilder.build()
}

private def createRoutePlanner(proxy: HttpHost, noProxyHosts: Seq[String]): HttpRoutePlanner = {
new DefaultRoutePlanner(DefaultSchemePortResolver.INSTANCE) {
override def determineRoute(target: HttpHost, request: HttpRequest,
context: HttpContext): HttpRoute = {
if (noProxyHosts.contains(target.getHostName)) {
// Direct route (no proxy)
new HttpRoute(target)
} else {
// Route via proxy
new HttpRoute(target, proxy)
}
}
}
}

private def addCustomHeaders(clientBuilder: HttpClientBuilder,
headers: Map[String, String]): Unit = {
ConfUtils.validateCustomHeaders(headers)
clientBuilder.addInterceptorFirst(new HttpRequestInterceptor {
override def process(request: HttpRequest, context: HttpContext): Unit = {
headers.foreach { case (key, value) => request.addHeader(key, value) }
}
})
}

private lazy val refreshThresholdMs = getConf.getLong(
Expand All @@ -122,7 +180,7 @@ private[sharing] class DeltaSharingFileSystem extends FileSystem with Logging {

override def getScheme: String = SCHEME

override def getUri(): URI = URI.create(s"$SCHEME:///")
override def getUri: URI = URI.create(s"$SCHEME:///")

// open a file path with the format below:
// ```
Expand Down Expand Up @@ -204,7 +262,7 @@ private[sharing] class DeltaSharingFileSystem extends FileSystem with Logging {

private[sharing] object DeltaSharingFileSystem {

val SCHEME = "delta-sharing"
private val SCHEME = "delta-sharing"

case class DeltaSharingPath(tablePath: String, fileId: String, fileSize: Long) {

Expand Down
49 changes: 44 additions & 5 deletions client/src/main/scala/io/delta/sharing/client/util/ConfUtils.scala
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@

package io.delta.sharing.client.util

import com.fasterxml.jackson.databind.ObjectMapper
import java.util.concurrent.TimeUnit

import org.apache.hadoop.conf.Configuration
import org.apache.spark.network.util.JavaUtils
import org.apache.spark.sql.internal.SQLConf
Expand Down Expand Up @@ -81,6 +81,10 @@ object ConfUtils {
val PROXY_PORT = "spark.delta.sharing.network.proxyPort"
val NO_PROXY_HOSTS = "spark.delta.sharing.network.noProxyHosts"

val CUSTOM_HEADERS = "spark.delta.sharing.network.customHeaders"
val PROXY_AUTH_TOKEN = "spark.delta.sharing.network.proxyAuthToken"
val CA_CERT_PATH = "spark.delta.sharing.network.caCertPath"

val OAUTH_RETRIES_CONF = "spark.delta.sharing.oauth.tokenExchangeMaxRetries"
val OAUTH_RETRIES_DEFAULT = 5

Expand Down Expand Up @@ -108,8 +112,40 @@ object ConfUtils {
val proxyPort = proxyPortAsString.toInt
validatePortNumber(proxyPort, PROXY_PORT)

val noProxyList = conf.getTrimmedStrings(NO_PROXY_HOSTS).toSeq
Some(ProxyConfig(proxyHost, proxyPort, noProxyHosts = noProxyList))
Some(ProxyConfig(
host = proxyHost,
port = proxyPort,
noProxyHosts = conf.getTrimmedStrings(NO_PROXY_HOSTS).toSeq,
authToken = Option(conf.get(PROXY_AUTH_TOKEN, null)),
caCertPath = Option(conf.get(CA_CERT_PATH, null)),
sslTrustAll = conf.getBoolean(SSL_TRUST_ALL_CONF, SSL_TRUST_ALL_DEFAULT.toBoolean)
))
}


def getCustomHeaders(conf: Configuration): Option[Map[String, String]] = {
val headersString = conf.get(CUSTOM_HEADERS, null)
if (headersString != null && headersString.nonEmpty) {
val mapper = new ObjectMapper()
val headers = mapper.readValue(headersString, classOf[Map[String, String]])
Some(headers)
} else {
None
}
}

def validateCustomHeaders(headers: Map[String, String]): Unit = {
headers.foreach { case (key, value) =>
require(key != null && key.nonEmpty, "Custom header name must not be null or empty")
require(value != null, s"Custom header value for '$key' must not be null")
}
}

def getTimeoutInMillis(conf: Configuration): Int = {
val timeoutStr = conf.get(TIMEOUT_CONF, TIMEOUT_DEFAULT)
val timeoutMillis = JavaUtils.timeStringAs(timeoutStr, TimeUnit.MILLISECONDS)
validateNonNeg(timeoutMillis, TIMEOUT_CONF)
timeoutMillis.toInt
}

def getNeverUseHttps(conf: Configuration): Boolean = {
Expand Down Expand Up @@ -325,7 +361,10 @@ object ConfUtils {
}

case class ProxyConfig(host: String,
port: Int,
noProxyHosts: Seq[String] = Seq.empty
port: Int,
noProxyHosts: Seq[String] = Seq.empty,
authToken: Option[String] = None,
caCertPath: Option[String] = None,
sslTrustAll: Boolean = false
)
}
Original file line number Diff line number Diff line change
Expand Up @@ -92,11 +92,13 @@ class DeltaSharingFileSystemSuite extends SparkFunSuite {
proxyServer.initialize()

try {

// Create a ProxyConfig with the host and port of the local proxy server.
val conf = new Configuration
conf.set(ConfUtils.PROXY_HOST, proxyServer.getHost())
conf.set(ConfUtils.PROXY_PORT, proxyServer.getPort().toString)
conf.set(ConfUtils.PROXY_AUTH_TOKEN, "testAuthToken")
conf.set(ConfUtils.CA_CERT_PATH, "/path/to/ca_cert.pem")
conf.set(ConfUtils.SSL_TRUST_ALL_CONF, "true")

// Configure the httpClient to use the ProxyConfig.
val fs = new DeltaSharingFileSystem() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -113,13 +113,19 @@ class ConfUtilsSuite extends SparkFunSuite {
val conf = newConf(Map(
PROXY_HOST -> "1.2.3.4",
PROXY_PORT -> "8080",
NO_PROXY_HOSTS -> "localhost,127.0.0.1"
NO_PROXY_HOSTS -> "localhost,127.0.0.1",
PROXY_AUTH_TOKEN -> "testAuthToken",
CA_CERT_PATH -> "/path/to/ca_cert.pem",
SSL_TRUST_ALL_CONF -> "true"
))
val proxyConfig = getProxyConfig(conf)
assert(proxyConfig.isDefined)
assert(proxyConfig.get.host == "1.2.3.4")
assert(proxyConfig.get.port == 8080)
assert(proxyConfig.get.noProxyHosts == Seq("localhost", "127.0.0.1"))
assert(proxyConfig.get.authToken.contains("testAuthToken"))
assert(proxyConfig.get.caCertPath.contains("/path/to/ca_cert.pem"))
assert(proxyConfig.get.sslTrustAll)
}

test("getProxyConfig with only host and port") {
Expand All @@ -132,6 +138,9 @@ class ConfUtilsSuite extends SparkFunSuite {
assert(proxyConfig.get.host == "1.2.3.4")
assert(proxyConfig.get.port == 8080)
assert(proxyConfig.get.noProxyHosts.isEmpty)
assert(proxyConfig.get.authToken.isEmpty)
assert(proxyConfig.get.caCertPath.isEmpty)
assert(!proxyConfig.get.sslTrustAll)
}

test("getProxyConfig with no proxy settings") {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@ import scala.collection.mutable.ArrayBuffer

import org.apache.spark.SparkFunSuite

import io.delta.sharing.client.util.{RetryUtils, UnexpectedHttpStatus}
import io.delta.sharing.client.util.RetryUtils._
import io.delta.sharing.spark.MissingEndStreamActionException

Expand Down
Loading
Loading