Skip to content

Commit 2ece06e

Browse files
support tls for grpc websocket (#378)
Support TLS. E2E production test:
1. Against a server that does not support TLS:
   a. mode = disabled (default)
   b. enabling TLS or mTLS errors out --> triggers the retry and fallback logic
2. Against a server that requires TLS:
   a. TLS with valid certs -> connects
   b. mTLS with valid client certs -> connects
   Other cases fall back to an insecure channel, which will ultimately fail and trigger the retry and fallback logic.
Added a unit test for the error-handling behavior. More E2E integration tests can be found in statsig-io/long-running-sdk#29.
1 parent 475a643 commit 2ece06e

File tree

5 files changed

+190
-23
lines changed

5 files changed

+190
-23
lines changed

src/main/kotlin/com/statsig/sdk/SpecUpdater.kt

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,11 @@ internal class SpecUpdater(
3535
private inline fun <reified T> Gson.fromJson(json: String) = fromJson<T>(json, object : TypeToken<T>() {}.type)
3636

3737
fun initialize() {
38+
transport.setStreamingFallback(NetworkEndpoint.DOWNLOAD_CONFIG_SPECS) {
39+
this.transport.downloadConfigSpecsFromStatsig(
40+
this.lastUpdateTime,
41+
).first
42+
}
3843
transport.initialize()
3944
}
4045

@@ -72,11 +77,6 @@ internal class SpecUpdater(
7277
idListFlow.collect { idListCallback(it) }
7378
}
7479
}
75-
transport.setStreamingFallback(NetworkEndpoint.DOWNLOAD_CONFIG_SPECS) {
76-
this.transport.downloadConfigSpecsFromStatsig(
77-
this.lastUpdateTime,
78-
).first
79-
}
8080
}
8181

8282
fun getInitializeOrder(): List<DataSource> {

src/main/kotlin/com/statsig/sdk/StatsigOptions.kt

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ import com.statsig.sdk.datastore.IDataStore
55
import com.statsig.sdk.network.STATSIG_API_URL_BASE
66
import com.statsig.sdk.persistent_storage.IUserPersistentStorage
77
import com.statsig.sdk.persistent_storage.UserPersistedValues
8+
import java.io.InputStream
89
import java.time.Instant
910
import java.time.format.DateTimeFormatter
1011
import java.time.temporal.ChronoUnit
@@ -208,6 +209,17 @@ enum class NetworkProtocol {
208209
GRPC_WEBSOCKET,
209210
}
210211

212+
enum class AuthenticationMode {
213+
@SerializedName("disabled")
214+
DISABLED,
215+
216+
@SerializedName("tls")
217+
TLS,
218+
219+
@SerializedName("mTls")
220+
MTLS,
221+
}
222+
211223
data class ForwardProxyConfig(
212224
@SerializedName("proxyAddress") var proxyAddress: String,
213225
@SerializedName("protocol") val proxyProtocol: NetworkProtocol,
@@ -216,6 +228,12 @@ data class ForwardProxyConfig(
216228
@SerializedName("retry_backoff_multiplier") val retryBackoffMultiplier: Int? = null,
217229
@SerializedName("retry_backoff_base_ms") val retryBackoffBaseMs: Long? = null,
218230
@SerializedName("push_worker_failover_threshold") val pushWorkerFailoverThreshold: Int? = null,
231+
232+
// TLS Certification
233+
@SerializedName("authentication_mode") var authenticationMode: AuthenticationMode = AuthenticationMode.DISABLED,
234+
@SerializedName("tls_cert_chain") var tlsCertChain: InputStream? = null,
235+
@SerializedName("tls_private_key") var tlsPrivateKey: InputStream? = null,
236+
@SerializedName("tls_private_key_password") var tlsPrivateKeyPassword: InputStream? = null,
219237
)
220238

221239
data class ProxyConfig @JvmOverloads constructor(

src/main/kotlin/com/statsig/sdk/StatsigServer.kt

Lines changed: 19 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -286,19 +286,25 @@ private class StatsigServerImpl() :
286286
private var setupStartTime = 0L
287287

288288
override fun setup(serverSecret: String, options: StatsigOptions) {
289-
Thread.setDefaultUncaughtExceptionHandler(MainThreadExceptionHandler(this, Thread.currentThread()))
290-
setupStartTime = System.currentTimeMillis()
291-
errorBoundary = ErrorBoundary(serverSecret, options, statsigMetadata)
292-
coroutineExceptionHandler = CoroutineExceptionHandler { _, ex ->
293-
// no-op - supervisor job should not throw when a child fails
294-
errorBoundary.logException("coroutineExceptionHandler", ex)
295-
}
296-
statsigJob = SupervisorJob()
297-
statsigScope = CoroutineScope(statsigJob + coroutineExceptionHandler)
298-
transport = StatsigTransport(serverSecret, options, statsigMetadata, statsigScope, errorBoundary, sdkConfigs)
299-
logger = StatsigLogger(statsigScope, transport, statsigMetadata, options, sdkConfigs)
300-
options.customLogger.also { outputLogger = it }
301-
this.options = options
289+
try {
290+
Thread.setDefaultUncaughtExceptionHandler(MainThreadExceptionHandler(this, Thread.currentThread()))
291+
setupStartTime = System.currentTimeMillis()
292+
errorBoundary = ErrorBoundary(serverSecret, options, statsigMetadata)
293+
coroutineExceptionHandler = CoroutineExceptionHandler { _, ex ->
294+
// no-op - supervisor job should not throw when a child fails
295+
errorBoundary.logException("coroutineExceptionHandler", ex)
296+
}
297+
statsigJob = SupervisorJob()
298+
statsigScope = CoroutineScope(statsigJob + coroutineExceptionHandler)
299+
transport = StatsigTransport(serverSecret, options, statsigMetadata, statsigScope, errorBoundary, sdkConfigs)
300+
logger = StatsigLogger(statsigScope, transport, statsigMetadata, options, sdkConfigs)
301+
options.customLogger.also { outputLogger = it }
302+
this.options = options
303+
} catch (e: Throwable) {
304+
// No-op: swallow the error here and let other code paths handle it
305+
options.customLogger.warn("[STATSIG]Failed to setup sdk")
306+
options.customLogger.warn(e.stackTraceToString())
307+
}
302308
}
303309

304310
override suspend fun initialize(serverSecret: String, options: StatsigOptions): InitializationDetails? {

src/main/kotlin/com/statsig/sdk/network/GRPCWebsocketWorker.kt

Lines changed: 71 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,34 @@
11
package com.statsig.sdk.network
22

3-
import com.statsig.sdk.*
3+
import com.statsig.sdk.AuthenticationMode
4+
import com.statsig.sdk.Diagnostics
5+
import com.statsig.sdk.ErrorBoundary
6+
import com.statsig.sdk.FailureDetails
7+
import com.statsig.sdk.FailureReason
8+
import com.statsig.sdk.ForwardProxyConfig
9+
import com.statsig.sdk.KeyType
10+
import com.statsig.sdk.NetworkProtocol
11+
import com.statsig.sdk.StatsigEvent
12+
import com.statsig.sdk.StatsigOptions
413
import grpc.generated.statsig_forward_proxy.StatsigForwardProxyGrpc
514
import grpc.generated.statsig_forward_proxy.StatsigForwardProxyOuterClass.ConfigSpecRequest
615
import grpc.generated.statsig_forward_proxy.StatsigForwardProxyOuterClass.ConfigSpecResponse
716
import io.grpc.Channel
17+
import io.grpc.Grpc
818
import io.grpc.ManagedChannelBuilder
19+
import io.grpc.TlsChannelCredentials
920
import io.grpc.stub.StreamObserver
10-
import kotlinx.coroutines.*
21+
import kotlinx.coroutines.CoroutineScope
22+
import kotlinx.coroutines.Dispatchers
23+
import kotlinx.coroutines.Job
1124
import kotlinx.coroutines.channels.BufferOverflow
25+
import kotlinx.coroutines.delay
1226
import kotlinx.coroutines.flow.Flow
1327
import kotlinx.coroutines.flow.MutableSharedFlow
1428
import kotlinx.coroutines.flow.asSharedFlow
1529
import kotlinx.coroutines.flow.first
30+
import kotlinx.coroutines.launch
31+
import kotlinx.coroutines.withTimeoutOrNull
1632

1733
private const val RETRY_LIMIT = 10
1834
private const val INITIAL_RETRY_BACKOFF_MS: Long = 10 * 1000
@@ -30,8 +46,6 @@ internal class GRPCWebsocketWorker(
3046
override val isPullWorker: Boolean = false
3147

3248
private var diagnostics: Diagnostics? = null
33-
private val channel: Channel = ManagedChannelBuilder.forTarget(proxyConfig.proxyAddress).usePlaintext().build()
34-
private val stub = StatsigForwardProxyGrpc.newStub(channel)
3549
private val logger = options.customLogger
3650

3751
private val observer = object : StreamObserver<ConfigSpecResponse> {
@@ -64,6 +78,59 @@ internal class GRPCWebsocketWorker(
6478
private var connected: Boolean = true
6579
private var downloadConfigsJob: Job? = null
6680
var streamingFallback: StreamingFallback? = null
81+
private val stub: StatsigForwardProxyGrpc.StatsigForwardProxyStub
82+
83+
private var channel: Channel
84+
init {
85+
try {
86+
channel = setupChannel()
87+
} catch (e: Exception) {
88+
options.customLogger.warn("Failed setup channel falling back to insecure channel")
89+
channel = ManagedChannelBuilder.forTarget(proxyConfig.proxyAddress).usePlaintext().build()
90+
}
91+
92+
stub = StatsigForwardProxyGrpc.newStub(channel)
93+
}
94+
95+
private fun setupChannel(): Channel {
96+
when (proxyConfig.authenticationMode) {
97+
AuthenticationMode.DISABLED ->
98+
return ManagedChannelBuilder.forTarget(proxyConfig.proxyAddress).usePlaintext().build()
99+
AuthenticationMode.MTLS -> {
100+
val tlsBuilder = TlsChannelCredentials.newBuilder()
101+
proxyConfig.tlsCertChain.let {
102+
if (it != null) {
103+
tlsBuilder.trustManager(it)
104+
} else {
105+
options.customLogger.warn("Failed to get cert chain for tls")
106+
}
107+
proxyConfig.tlsPrivateKey.let { privateKey ->
108+
val privateKeyPassword = proxyConfig.tlsPrivateKeyPassword
109+
if (privateKey != null && privateKeyPassword != null) {
110+
tlsBuilder.keyManager(privateKey, privateKeyPassword)
111+
return Grpc.newChannelBuilder(proxyConfig.proxyAddress, tlsBuilder.build()).build()
112+
} else {
113+
options.customLogger.warn("Failed to get private key and password for tls")
114+
return Grpc.newChannelBuilder(proxyConfig.proxyAddress, tlsBuilder.build()).build()
115+
}
116+
}
117+
}
118+
}
119+
AuthenticationMode.TLS -> {
120+
val tlsBuilder = TlsChannelCredentials.newBuilder()
121+
proxyConfig.tlsCertChain.let {
122+
if (it != null) {
123+
tlsBuilder.trustManager(it)
124+
return Grpc.newChannelBuilder(proxyConfig.proxyAddress, tlsBuilder.build()).build()
125+
} else {
126+
options.customLogger.warn("Failed to get cert chain for tls")
127+
}
128+
}
129+
}
130+
}
131+
options.customLogger.warn("Falling back to insecure channel")
132+
return ManagedChannelBuilder.forTarget(proxyConfig.proxyAddress).usePlaintext().build()
133+
}
67134

68135
private val dcsFlowBacker = MutableSharedFlow<String>(
69136
replay = 1,

src/test/java/com/statsig/sdk/GRPCTest.kt

Lines changed: 77 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,21 @@
11
package com.statsig.sdk
22

3+
import com.statsig.sdk.network.GRPCWebsocketWorker
34
import com.statsig.sdk.network.GRPCWorker
45
import grpc.generated.statsig_forward_proxy.StatsigForwardProxyGrpc
56
import grpc.generated.statsig_forward_proxy.StatsigForwardProxyOuterClass.ConfigSpecRequest
67
import grpc.generated.statsig_forward_proxy.StatsigForwardProxyOuterClass.ConfigSpecResponse
78
import io.grpc.stub.StreamObserver
89
import io.grpc.testing.GrpcServerRule
910
import io.mockk.mockk
11+
import kotlinx.coroutines.CoroutineScope
12+
import kotlinx.coroutines.Dispatchers
13+
import kotlinx.coroutines.delay
1014
import kotlinx.coroutines.runBlocking
1115
import org.junit.Assert.assertEquals
1216
import org.junit.Rule
1317
import org.junit.Test
18+
import java.io.ByteArrayInputStream
1419

1520
class GRPCTest {
1621
@Rule
@@ -21,12 +26,28 @@ class GRPCTest {
2126
@Rule
2227
val retry = RetryRule(3)
2328

24-
private fun setupGRPC(spec: ConfigSpecResponse) {
29+
private fun setupGRPC(spec: ConfigSpecResponse, shouldThrowOnStreaming: Boolean = false, blockStreamTime: Long? = null) {
2530
val service = object : StatsigForwardProxyGrpc.StatsigForwardProxyImplBase() {
2631
override fun getConfigSpec(request: ConfigSpecRequest?, responseObserver: StreamObserver<ConfigSpecResponse>?) {
2732
responseObserver?.onNext(spec)
2833
responseObserver?.onCompleted()
2934
}
35+
36+
override fun streamConfigSpec(
37+
request: ConfigSpecRequest?,
38+
responseObserver: StreamObserver<ConfigSpecResponse>?,
39+
) {
40+
if (shouldThrowOnStreaming) {
41+
responseObserver?.onError(Exception("io exception"))
42+
}
43+
if (blockStreamTime != null) {
44+
runBlocking {
45+
delay(blockStreamTime)
46+
}
47+
}
48+
responseObserver?.onNext(spec)
49+
responseObserver?.onCompleted()
50+
}
3051
}
3152
grpcServerRule.serviceRegistry.addService(service)
3253
}
@@ -46,4 +67,59 @@ class GRPCTest {
4667

4768
assertEquals(Pair("spec-1", null), worker.downloadConfigSpecs(0))
4869
}
70+
71+
@Test
72+
fun testTLSConfigurationError() = runBlocking {
73+
setupGRPC(ConfigSpecResponse.newBuilder().setSpec("spec-1").setLastUpdated(123).build())
74+
val options = StatsigOptions()
75+
val boundary = mockk<ErrorBoundary>()
76+
val scope = CoroutineScope(Dispatchers.Default)
77+
val proxyConfig = ForwardProxyConfig(
78+
"proxy:8000",
79+
NetworkProtocol.GRPC_WEBSOCKET,
80+
retryBackoffMultiplier = 1,
81+
retryBackoffBaseMs = 10,
82+
authenticationMode = AuthenticationMode.TLS,
83+
// The invalid cert chain is caught with try/catch, and we fall back to an insecure channel
84+
tlsCertChain = ByteArrayInputStream("invalid".toByteArray()),
85+
)
86+
val worker = GRPCWebsocketWorker("sdk", options, scope, boundary, proxyConfig)
87+
val stub = StatsigForwardProxyGrpc.newStub(grpcServerRule.channel)
88+
89+
val stubField = GRPCWebsocketWorker::class.java.getDeclaredField("stub")
90+
stubField.isAccessible = true
91+
stubField.set(worker, stub)
92+
worker.initializeFlows()
93+
// In this test the stream will succeed, because the stub is mocked
94+
assertEquals(Pair("spec-1", null), worker.downloadConfigSpecs(0))
95+
}
96+
97+
/*
98+
* Test that when the server returns an error, we start retrying.
99+
* The error can be caused by an internal error, unavailability, or failed authentication.
100+
* */
101+
@Test
102+
fun testTimeout() = runBlocking {
103+
setupGRPC(ConfigSpecResponse.newBuilder().setSpec("spec-1").setLastUpdated(123).build(), blockStreamTime = 500)
104+
val options = StatsigOptions(initTimeoutMs = 100)
105+
val boundary = mockk<ErrorBoundary>()
106+
val scope = CoroutineScope(Dispatchers.Default)
107+
val proxyConfig = ForwardProxyConfig(
108+
"proxy:8000",
109+
NetworkProtocol.GRPC_WEBSOCKET,
110+
retryBackoffMultiplier = 1,
111+
retryBackoffBaseMs = 10,
112+
authenticationMode = AuthenticationMode.DISABLED,
113+
)
114+
val worker = GRPCWebsocketWorker("sdk", options, scope, boundary, proxyConfig)
115+
val stub = StatsigForwardProxyGrpc.newStub(grpcServerRule.channel)
116+
117+
val stubField = GRPCWebsocketWorker::class.java.getDeclaredField("stub")
118+
stubField.isAccessible = true
119+
stubField.set(worker, stub)
120+
worker.initializeFlows()
121+
val details = worker.downloadConfigSpecs(0)
122+
assert(details.first == null)
123+
assert(details.second?.exception?.message?.contains("failed to receive config spec within init timeout") == true)
124+
}
49125
}

0 commit comments

Comments
 (0)