From 1fe6ed1f6de4c49f608784cae0e87ad5955a65b2 Mon Sep 17 00:00:00 2001 From: Moe Derakhshani Date: Mon, 30 Dec 2024 16:21:23 -0800 Subject: [PATCH] Improvement spark delta-sharing client: convert expires_in as string to int, if returned as string --- .../sharing/client/auth/OAuthClient.scala | 31 ++++++- .../client/auth/OAuthClientSuite.scala | 87 ++++++++++++------- 2 files changed, 86 insertions(+), 32 deletions(-) diff --git a/client/src/main/scala/io/delta/sharing/client/auth/OAuthClient.scala b/client/src/main/scala/io/delta/sharing/client/auth/OAuthClient.scala index 930bec6c0..5703a671e 100644 --- a/client/src/main/scala/io/delta/sharing/client/auth/OAuthClient.scala +++ b/client/src/main/scala/io/delta/sharing/client/auth/OAuthClient.scala @@ -77,6 +77,8 @@ private[client] class OAuthClient(httpClient: } private def parseOAuthTokenResponse(response: String): OAuthClientCredentials = { + // Parse the token response as described by the OAuth 2.0 spec: + // https://datatracker.ietf.org/doc/html/rfc6749#section-5.1 if (response == null || response.isEmpty) { throw new RuntimeException("Empty response from OAuth token endpoint") } @@ -84,13 +86,38 @@ private[client] class OAuthClient(httpClient: if (!jsonNode.has("access_token") || !jsonNode.get("access_token").isTextual) { throw new RuntimeException("Missing 'access_token' field in OAuth token response") } - if (!jsonNode.has("expires_in") || !jsonNode.get("expires_in").isNumber) { + if (!jsonNode.has("expires_in")) { throw new RuntimeException("Missing 'expires_in' field in OAuth token response") } + // OAuth spec requires 'expires_in' to be an integer, e.g., 3600. + // See https://datatracker.ietf.org/doc/html/rfc6749#section-5.1 + // But some token endpoints return `expires_in` as a string e.g., "3600". + // Accept both integer and numeric-string values for the 'expires_in' field. 
+ // Example request resulting in 'expires_in' as a string: + // curl -X POST \ + // https://login.windows.net/$TENANT_ID/oauth2/token \ + // -H "Content-Type: application/x-www-form-urlencoded" \ + // -d "grant_type=client_credentials" \ + // -d "client_id=$CLIENT_ID" \ + // -d "client_secret=$CLIENT_SECRET" \ + // -d "scope=https://graph.microsoft.com/.default" + val expiresIn: Long = jsonNode.get("expires_in") match { + case n if n.isNumber => n.asLong() + case n if n.isTextual => + try { + n.asText().toLong + } catch { + case _: NumberFormatException => + throw new RuntimeException("Invalid 'expires_in' field in OAuth token response") + } + case _ => + throw new RuntimeException("Invalid 'expires_in' field in OAuth token response") + } + OAuthClientCredentials( jsonNode.get("access_token").asText(), - jsonNode.get("expires_in").asLong(), + expiresIn, System.currentTimeMillis() ) } diff --git a/client/src/test/scala/io/delta/sharing/client/auth/OAuthClientSuite.scala b/client/src/test/scala/io/delta/sharing/client/auth/OAuthClientSuite.scala index c57f75297..d6274e305 100644 --- a/client/src/test/scala/io/delta/sharing/client/auth/OAuthClientSuite.scala +++ b/client/src/test/scala/io/delta/sharing/client/auth/OAuthClientSuite.scala @@ -24,8 +24,9 @@ import org.apache.http.impl.bootstrap.{HttpServer, ServerBootstrap} import org.apache.http.impl.client.{CloseableHttpClient, HttpClients} import org.apache.http.protocol.{HttpContext, HttpRequestHandler} import org.apache.spark.SparkFunSuite +import org.scalatest.prop.TableDrivenPropertyChecks -class OAuthClientSuite extends SparkFunSuite { +class OAuthClientSuite extends SparkFunSuite with TableDrivenPropertyChecks { var server: HttpServer = _ def startServer(handler: HttpRequestHandler): Unit = { @@ -58,40 +59,66 @@ class OAuthClientSuite extends SparkFunSuite { throw new RuntimeException(s"Port $port is not released after $timeoutMillis milliseconds") } - test("OAuthClient should parse token response correctly") 
{ - val handler = new HttpRequestHandler { - @throws[HttpException] - @throws[IOException] - override def handle(request: HttpRequest, - response: HttpResponse, - context: HttpContext): Unit = { - val responseBody = - """{ - | "access_token": "test-access-token", - | "expires_in": 3600, - | "token_type": "bearer" - |}""".stripMargin - response.setEntity(new StringEntity(responseBody, ContentType.APPLICATION_JSON)) - response.setStatusCode(200) + case class TokenExchangeSuccessScenario(responseBody: String, + expectedAccessToken: String, + expectedExpiresIn: Long) + + // OAuth spec requires 'expires_in' to be an integer, e.g., 3600. + // See https://datatracker.ietf.org/doc/html/rfc6749#section-5.1 + // But some token endpoints return `expires_in` as a string e.g., "3600". + // The rows below cover both forms. forAll runs inside a single test() because + // ScalaTest rejects registering two tests under the same name. + private val tokenExchangeSuccessScenarios = Table( + "testScenario", + TokenExchangeSuccessScenario( + responseBody = """{ + | "access_token": "test-access-token", + | "expires_in": 3600, + | "token_type": "bearer" + |}""".stripMargin, + expectedAccessToken = "test-access-token", + expectedExpiresIn = 3600 + ), + TokenExchangeSuccessScenario( + responseBody = """{ + | "access_token": "test-access-token", + | "expires_in": "3600", + | "token_type": "bearer" + |}""".stripMargin, + expectedAccessToken = "test-access-token", + expectedExpiresIn = 3600 + ) + ) + + test("OAuthClient should parse token response correctly") { + forAll(tokenExchangeSuccessScenarios) { testScenario => + val handler = new HttpRequestHandler { + @throws[HttpException] + @throws[IOException] + override def handle(request: HttpRequest, + response: HttpResponse, + context: HttpContext): Unit = { + response.setEntity( + new StringEntity(testScenario.responseBody, ContentType.APPLICATION_JSON)) + response.setStatusCode(200) + } } - } - startServer(handler) + 
startServer(handler) - val httpClient: CloseableHttpClient = HttpClients.createDefault() - val oauthClient = new OAuthClient(httpClient, AuthConfig(), - "http://localhost:1080/token", "client-id", "client-secret") - - val start = System.currentTimeMillis() + val httpClient: CloseableHttpClient = HttpClients.createDefault() + val oauthClient = new OAuthClient(httpClient, AuthConfig(), + "http://localhost:1080/token", "client-id", "client-secret") - val token = oauthClient.clientCredentials() + val start = System.currentTimeMillis() + val token = oauthClient.clientCredentials() + val end = System.currentTimeMillis() - val end = System.currentTimeMillis() + assert(token.accessToken == testScenario.expectedAccessToken) + assert(token.expiresIn == testScenario.expectedExpiresIn) + assert(token.creationTimestamp >= start && token.creationTimestamp <= end) - assert(token.accessToken == "test-access-token") - assert(token.expiresIn == 3600) - assert(token.creationTimestamp >= start && token.creationTimestamp <= end) - - stopServer() + stopServer() + } } test("OAuthClient should handle 401 Unauthorized response") {