Skip to content

Commit

Permalink
Add support for http proxy (#450) (#451)
Browse files Browse the repository at this point in the history
Adds support for an HTTP proxy when reading data from storage.

Sample configuration:

```
spark.delta.sharing.network.proxyHost=1.2.3.4
spark.delta.sharing.network.proxyPort=3128
spark.delta.sharing.network.noProxyHosts=5.6.7.8,12.13.14.15
```
  • Loading branch information
moderakh authored Jan 11, 2024
1 parent 9512548 commit b3af8b9
Show file tree
Hide file tree
Showing 5 changed files with 418 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -22,37 +22,69 @@ import java.util.concurrent.TimeUnit
import org.apache.hadoop.fs._
import org.apache.hadoop.fs.permission.FsPermission
import org.apache.hadoop.util.Progressable
import org.apache.http.{HttpHost, HttpRequest}
import org.apache.http.auth.{AuthScope, UsernamePasswordCredentials}
import org.apache.http.client.config.RequestConfig
import org.apache.http.impl.client.HttpClientBuilder
import org.apache.http.conn.routing.HttpRoute
import org.apache.http.impl.client.{BasicCredentialsProvider, HttpClientBuilder}
import org.apache.http.impl.conn.{DefaultRoutePlanner, DefaultSchemePortResolver}
import org.apache.http.protocol.HttpContext
import org.apache.spark.SparkEnv
import org.apache.spark.delta.sharing.{PreSignedUrlCache, PreSignedUrlFetcher}
import org.apache.spark.network.util.JavaUtils

import io.delta.sharing.client.model.FileAction
import io.delta.sharing.client.util.{ConfUtils, RetryUtils}
import io.delta.sharing.client.util.ConfUtils

/** Read-only file system for delta paths. */
private[sharing] class DeltaSharingFileSystem extends FileSystem {

import DeltaSharingFileSystem._

// The settings below are resolved lazily, on first access, from the Hadoop
// configuration returned by `getConf`.
lazy private val numRetries = ConfUtils.numRetries(getConf)
lazy private val maxRetryDurationMillis = ConfUtils.maxRetryDurationMillis(getConf)
// Used by `createHttpClient()` for the connect, connection-request and socket timeouts.
lazy private val timeoutInSeconds = ConfUtils.timeoutInSeconds(getConf)
// HTTP client built on first use by `createHttpClient()`.
lazy private val httpClient = createHttpClient()

lazy private val httpClient = {
/**
 * Builds the HTTP client used by this file system.
 *
 * The client applies `timeoutInSeconds` to the connect, connection-request and socket
 * timeouts, caps total and per-route connections at `ConfUtils.maxConnections`, and has
 * automatic retries disabled.
 *
 * When `ConfUtils.getProxyConfig` yields a proxy, requests are routed through it. Hosts
 * listed in `noProxyHosts` (matched case-insensitively, since host names are not case
 * sensitive) are contacted directly instead.
 *
 * @return a newly built client; the caller owns it and is responsible for closing it.
 * @throws IllegalArgumentException if the proxy or connection settings are invalid.
 */
private[sharing] def createHttpClient() = {
  val proxyConfigOpt = ConfUtils.getProxyConfig(getConf)
  val maxConnections = ConfUtils.maxConnections(getConf)
  val timeoutMillis = timeoutInSeconds * 1000
  val config = RequestConfig.custom()
    .setConnectTimeout(timeoutMillis)
    .setConnectionRequestTimeout(timeoutMillis)
    .setSocketTimeout(timeoutMillis)
    .build()

  val clientBuilder = HttpClientBuilder.create()
    .setMaxConnTotal(maxConnections)
    .setMaxConnPerRoute(maxConnections)
    .setDefaultRequestConfig(config)
    // Disable the default retry behavior because we have our own retry logic.
    // See `RetryUtils.runWithExponentialBackoff`.
    .disableAutomaticRetries()

  // Set proxy if provided.
  proxyConfigOpt.foreach { proxyConfig =>
    val proxy = new HttpHost(proxyConfig.host, proxyConfig.port)
    clientBuilder.setProxy(proxy)

    if (proxyConfig.noProxyHosts.nonEmpty) {
      // Only the proxy decision is customised. Route construction is left to
      // DefaultRoutePlanner so that HTTPS targets are marked secure (and therefore
      // tunnelled through the proxy) and default ports / local address are honoured.
      // Building `new HttpRoute(target, proxy)` by hand always yields a plain,
      // non-secure route.
      val routePlanner = new DefaultRoutePlanner(DefaultSchemePortResolver.INSTANCE) {
        override protected def determineProxy(
            target: HttpHost,
            request: HttpRequest,
            context: HttpContext): HttpHost = {
          val bypassProxy =
            proxyConfig.noProxyHosts.exists(_.equalsIgnoreCase(target.getHostName))
          // The HttpClient API uses null to mean "no proxy, connect directly".
          if (bypassProxy) null else proxy
        }
      }
      clientBuilder.setRoutePlanner(routePlanner)
    }
  }
  clientBuilder.build()
}

private lazy val refreshThresholdMs = getConf.getLong(
Expand Down
38 changes: 38 additions & 0 deletions client/src/main/scala/io/delta/sharing/client/util/ConfUtils.scala
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,27 @@ object ConfUtils {
val LIMIT_PUSHDOWN_ENABLED_CONF = "spark.delta.sharing.limitPushdown.enabled"
val LIMIT_PUSHDOWN_ENABLED_DEFAULT = "true"

// Proxy settings, read by `getProxyConfig`. Host and port must be set together:
// defining only one of them is rejected with an IllegalArgumentException.
// Host of the HTTP proxy.
val PROXY_HOST = "spark.delta.sharing.network.proxyHost"
// Port of the HTTP proxy; must be within 1..65535.
val PROXY_PORT = "spark.delta.sharing.network.proxyPort"
// Hosts that bypass the proxy, e.g. "5.6.7.8,12.13.14.15"; entries are trimmed.
val NO_PROXY_HOSTS = "spark.delta.sharing.network.noProxyHosts"

/**
 * Reads the optional HTTP proxy settings from `conf`.
 *
 * @param conf Hadoop configuration holding the `PROXY_HOST`, `PROXY_PORT` and
 *             `NO_PROXY_HOSTS` keys.
 * @return `None` when neither proxy host nor proxy port is configured; otherwise the
 *         validated proxy settings.
 * @throws IllegalArgumentException if only one of host/port is set, if either is blank,
 *                                  or if the port is not an integer within 1..65535.
 */
def getProxyConfig(conf: Configuration): Option[ProxyConfig] = {
  // getTrimmed returns null for an unset key, and strips surrounding whitespace so that
  // values such as " 3128 " are not rejected.
  val proxyHost = conf.getTrimmed(PROXY_HOST)
  val proxyPortAsString = conf.getTrimmed(PROXY_PORT)

  if (proxyHost == null && proxyPortAsString == null) {
    None
  } else {
    validateNonEmpty(proxyHost, PROXY_HOST)
    validateNonEmpty(proxyPortAsString, PROXY_PORT)
    val proxyPort = try {
      proxyPortAsString.toInt
    } catch {
      // Report which setting is malformed instead of a bare NumberFormatException.
      case e: NumberFormatException =>
        throw new IllegalArgumentException(PROXY_PORT + " must be a valid port number", e)
    }
    validatePortNumber(proxyPort, PROXY_PORT)

    val noProxyList = conf.getTrimmedStrings(NO_PROXY_HOSTS).toSeq
    Some(ProxyConfig(proxyHost, proxyPort, noProxyHosts = noProxyList))
  }
}

def numRetries(conf: Configuration): Int = {
val numRetries = conf.getInt(NUM_RETRIES_CONF, NUM_RETRIES_DEFAULT)
validateNonNeg(numRetries, NUM_RETRIES_CONF)
Expand Down Expand Up @@ -170,4 +191,21 @@ object ConfUtils {
throw new IllegalArgumentException(conf + " must be positive")
}
}

/**
 * Ensures a configuration value is present.
 *
 * @param value the configured value; may be null when the key is unset.
 * @param conf  the configuration key, used in the error message.
 * @throws IllegalArgumentException if `value` is null or empty.
 */
private def validateNonEmpty(value: String, conf: String): Unit = {
  val isDefined = Option(value).exists(_.nonEmpty)
  if (!isDefined) {
    throw new IllegalArgumentException(s"$conf must be defined")
  }
}

/**
 * Ensures a configured port lies within the range 1..65535.
 *
 * @param value the port number to check.
 * @param conf  the configuration key, used in the error message.
 * @throws IllegalArgumentException if `value` is outside 1..65535.
 */
private def validatePortNumber(value: Int, conf: String): Unit = {
  val inRange = value >= 1 && value <= 65535
  if (!inRange) {
    throw new IllegalArgumentException(s"$conf must be a valid port number")
  }
}

/**
 * HTTP proxy settings.
 *
 * @param host         host of the proxy.
 * @param port         port of the proxy.
 * @param noProxyHosts hosts that should bypass the proxy; empty by default.
 */
case class ProxyConfig(
    host: String,
    port: Int,
    noProxyHosts: Seq[String] = Seq.empty)
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,17 @@

package io.delta.sharing.client

import javax.servlet.http.{HttpServlet, HttpServletRequest, HttpServletResponse}

import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.Path
import org.apache.http.client.methods.HttpGet
import org.apache.http.util.EntityUtils
import org.apache.spark.SparkFunSuite
import org.sparkproject.jetty.server.Server
import org.sparkproject.jetty.servlet.{ServletHandler, ServletHolder}

import io.delta.sharing.client.model.{AddCDCFile, AddFile, AddFileForCDF, FileAction, RemoveFile}
import io.delta.sharing.client.model._
import io.delta.sharing.client.util.{ConfUtils, ProxyServer}

class DeltaSharingFileSystemSuite extends SparkFunSuite {
import DeltaSharingFileSystem._
Expand Down Expand Up @@ -58,4 +64,173 @@ class DeltaSharingFileSystemSuite extends SparkFunSuite {
assert(fs eq path.getFileSystem(conf))
})
}

// With only a proxy host and port configured, a request to a local origin server must
// succeed and must be observed by the proxy.
test("traffic goes through a proxy when a proxy configured") {
  // Create a local HTTP server that answers every GET with a fixed body.
  val server = new Server(0)
  val handler = new ServletHandler()
  server.setHandler(handler)
  handler.addServletWithMapping(new ServletHolder(new HttpServlet {
    override def doGet(req: HttpServletRequest, resp: HttpServletResponse): Unit = {
      resp.setContentType("text/plain")
      resp.setStatus(HttpServletResponse.SC_OK)

      // scalastyle:off println
      resp.getWriter.println("Hello, World!")
      // scalastyle:on println
    }
  }), "/*")

  try {
    // Jetty's start() returns only once the server is started, so no polling is needed.
    server.start()

    // Create a local HTTP proxy server.
    val proxyServer = new ProxyServer(0)
    proxyServer.initialize()
    try {
      // Point the file system configuration at the local proxy server.
      val conf = new Configuration
      conf.set(ConfUtils.PROXY_HOST, proxyServer.getHost())
      conf.set(ConfUtils.PROXY_PORT, proxyServer.getPort().toString)

      val fs = new DeltaSharingFileSystem() {
        override def getConf: Configuration = conf
      }

      // Build the client under test; it is closed below so connections are not leaked.
      val httpClient = fs.createHttpClient()
      try {
        // Send a request to the local server through the httpClient.
        val response = httpClient.execute(new HttpGet(server.getURI.toString))
        try {
          // Assert that the request is successful.
          assert(response.getStatusLine.getStatusCode == HttpServletResponse.SC_OK)
          val content = EntityUtils.toString(response.getEntity)
          assert(content.trim == "Hello, World!")
        } finally {
          response.close()
        }

        // Assert that the request is passed through proxy.
        assert(proxyServer.getCapturedRequests().size == 1)
      } finally {
        httpClient.close()
      }
    } finally {
      proxyServer.stop()
    }
  } finally {
    server.stop()
  }
}

// When the destination host is listed in noProxyHosts, the request must succeed and the
// proxy must not observe it.
test("traffic skips the proxy when a noProxyHosts configured") {
  // Create a local HTTP server that answers every GET with a fixed body.
  val server = new Server(0)
  val handler = new ServletHandler()
  server.setHandler(handler)
  handler.addServletWithMapping(new ServletHolder(new HttpServlet {
    override def doGet(req: HttpServletRequest, resp: HttpServletResponse): Unit = {
      resp.setContentType("text/plain")
      resp.setStatus(HttpServletResponse.SC_OK)

      // scalastyle:off println
      resp.getWriter.println("Hello, World!")
      // scalastyle:on println
    }
  }), "/*")

  try {
    // Jetty's start() returns only once the server is started, so no polling is needed.
    server.start()

    // Create a local HTTP proxy server.
    val proxyServer = new ProxyServer(0)
    proxyServer.initialize()
    try {
      // Configure the local proxy, but exempt the origin server's host from it.
      val conf = new Configuration
      conf.set(ConfUtils.PROXY_HOST, proxyServer.getHost())
      conf.set(ConfUtils.PROXY_PORT, proxyServer.getPort().toString)
      conf.set(ConfUtils.NO_PROXY_HOSTS, server.getURI.getHost)

      val fs = new DeltaSharingFileSystem() {
        override def getConf: Configuration = conf
      }

      // Build the client under test; it is closed below so connections are not leaked.
      val httpClient = fs.createHttpClient()
      try {
        // Send a request to the local server through the httpClient.
        val response = httpClient.execute(new HttpGet(server.getURI.toString))
        try {
          // Assert that the request is successful.
          assert(response.getStatusLine.getStatusCode == HttpServletResponse.SC_OK)
          val content = EntityUtils.toString(response.getEntity)
          assert(content.trim == "Hello, World!")
        } finally {
          response.close()
        }

        // Assert that the request is not passed through proxy.
        assert(proxyServer.getCapturedRequests().isEmpty)
      } finally {
        httpClient.close()
      }
    } finally {
      proxyServer.stop()
    }
  } finally {
    server.stop()
  }
}

// When noProxyHosts is set but does not list the destination host, the request must
// succeed and must still be observed by the proxy.
test("traffic goes through the proxy when noProxyHosts does not include destination") {
  // Create a local HTTP server that answers every GET with a fixed body.
  val server = new Server(0)
  val handler = new ServletHandler()
  server.setHandler(handler)
  handler.addServletWithMapping(new ServletHolder(new HttpServlet {
    override def doGet(req: HttpServletRequest, resp: HttpServletResponse): Unit = {
      resp.setContentType("text/plain")
      resp.setStatus(HttpServletResponse.SC_OK)

      // scalastyle:off println
      resp.getWriter.println("Hello, World!")
      // scalastyle:on println
    }
  }), "/*")

  try {
    // Jetty's start() returns only once the server is started, so no polling is needed.
    server.start()

    // Create a local HTTP proxy server.
    val proxyServer = new ProxyServer(0)
    proxyServer.initialize()
    try {
      // Configure the local proxy with a noProxyHosts entry unrelated to the origin server.
      val conf = new Configuration
      conf.set(ConfUtils.PROXY_HOST, proxyServer.getHost())
      conf.set(ConfUtils.PROXY_PORT, proxyServer.getPort().toString)
      conf.set(ConfUtils.NO_PROXY_HOSTS, "1.2.3.4")

      val fs = new DeltaSharingFileSystem() {
        override def getConf: Configuration = conf
      }

      // Build the client under test; it is closed below so connections are not leaked.
      val httpClient = fs.createHttpClient()
      try {
        // Send a request to the local server through the httpClient.
        val response = httpClient.execute(new HttpGet(server.getURI.toString))
        try {
          // Assert that the request is successful.
          assert(response.getStatusLine.getStatusCode == HttpServletResponse.SC_OK)
          val content = EntityUtils.toString(response.getEntity)
          assert(content.trim == "Hello, World!")
        } finally {
          response.close()
        }

        // Assert that the request is passed through proxy.
        assert(proxyServer.getCapturedRequests().size == 1)
      } finally {
        httpClient.close()
      }
    } finally {
      proxyServer.stop()
    }
  } finally {
    server.stop()
  }
}
}
Loading

0 comments on commit b3af8b9

Please sign in to comment.