Skip to content

Commit 14e9c33

Browse files
authored
Handle OAuth Token Refreshes Using Background Thread For Admin, Producer and Consumer Clients (#2130)
* dev changes to handle oauth callbacks from background thread * temporarily comment out test case * add test cases for background oauth callbacks * remove test file * minor fixes in test assertions and lint * update documentation * increase timeout from 5s to 10s * fix code formatting in test_misc * fix formatting according to clang-format v18
1 parent 6af4e6f commit 14e9c33

File tree

7 files changed

+171
-20
lines changed

7 files changed

+171
-20
lines changed

docs/index.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1057,7 +1057,7 @@ addition to the properties dictated by the underlying librdkafka C library:
10571057
where ``expiry_time`` is the time in seconds since the epoch as a floating point number.
10581058
This callback is useful only when ``sasl.mechanisms=OAUTHBEARER`` is set and
10591059
is served to get the initial token before a successful broker connection can be made.
1060-
The callback can be triggered by calling ``client.poll()`` or ``producer.flush()``.
1060+
The callback is triggered asynchronously by the background thread to maintain token validity.
10611061

10621062
* ``on_delivery(kafka.KafkaError, kafka.Message)`` (**Producer**): value is a Python function reference
10631063
that is called once for each produced message to indicate the final

src/confluent_kafka/src/Admin.c

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5526,10 +5526,22 @@ static int Admin_init(PyObject *selfobj, PyObject *args, PyObject *kwargs) {
55265526
return -1;
55275527
}
55285528

5529+
/* Enable SASL callbacks on background thread for AdminClient since
5530+
* applications typically don't call poll() regularly on AdminClient. */
5531+
if (self->oauth_cb) {
5532+
rd_kafka_sasl_background_callbacks_enable(self->rk);
5533+
}
5534+
55295535
/* Forward log messages to poll queue */
55305536
if (self->logger)
55315537
rd_kafka_set_log_queue(self->rk, NULL);
55325538

5539+
5540+
/* Wait for the background thread to set the token */
5541+
if (self->oauth_cb) {
5542+
return wait_for_oauth_token_set(self);
5543+
}
5544+
55335545
return 0;
55345546
}
55355547

src/confluent_kafka/src/Consumer.c

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1664,6 +1664,12 @@ static int Consumer_init(PyObject *selfobj, PyObject *args, PyObject *kwargs) {
16641664
return -1;
16651665
}
16661666

1667+
/* Enable Token Refresh to be handled by background thread if OAuth
1668+
* callback is provided */
1669+
if (self->oauth_cb) {
1670+
rd_kafka_sasl_background_callbacks_enable(self->rk);
1671+
}
1672+
16671673
/* Forward log messages to main queue which is then forwarded
16681674
* to the consumer queue */
16691675
if (self->logger)
@@ -1674,6 +1680,12 @@ static int Consumer_init(PyObject *selfobj, PyObject *args, PyObject *kwargs) {
16741680
self->u.Consumer.rkqu = rd_kafka_queue_get_consumer(self->rk);
16751681
assert(self->u.Consumer.rkqu);
16761682

1683+
1684+
/* Wait for the background thread to set the token */
1685+
if (self->oauth_cb) {
1686+
return wait_for_oauth_token_set(self);
1687+
}
1688+
16771689
return 0;
16781690
}
16791691

src/confluent_kafka/src/Producer.c

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1331,10 +1331,21 @@ static int Producer_init(PyObject *selfobj, PyObject *args, PyObject *kwargs) {
13311331
return -1;
13321332
}
13331333

1334+
/* Enable Token Refresh to be handled by background thread if OAuth
1335+
* callback is provided */
1336+
if (self->oauth_cb) {
1337+
rd_kafka_sasl_background_callbacks_enable(self->rk);
1338+
}
1339+
13341340
/* Forward log messages to poll queue */
13351341
if (self->logger)
13361342
rd_kafka_set_log_queue(self->rk, NULL);
13371343

1344+
/* Wait for the background thread to set the token */
1345+
if (self->oauth_cb) {
1346+
return wait_for_oauth_token_set(self);
1347+
}
1348+
13381349
return 0;
13391350
}
13401351

src/confluent_kafka/src/confluent_kafka.c

Lines changed: 56 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2037,11 +2037,59 @@ static int py_extensions_to_c(char **extensions,
20372037
return 1;
20382038
}
20392039

2040+
2041+
/**
2042+
* @brief Waits for OAuth callback to set the token
2043+
*
2044+
* Useful during client init as we want to ensure we have the token before we
2045+
* return back
2046+
*
2047+
* Returns 0 if token was set within the timeout period, -1 otherwise.
2048+
*/
2049+
int wait_for_oauth_token_set(Handle *h) {
2050+
2051+
if (!h->oauth_cb)
2052+
return 0;
2053+
2054+
int max_wait_sec = 10;
2055+
int retry_interval_sec = 1; /* Check every 1 sec */
2056+
int elapsed_sec = 0;
2057+
while (!h->oauth_token_set && elapsed_sec < max_wait_sec) {
2058+
CallState cs;
2059+
CallState_begin(h, &cs);
2060+
sleep(retry_interval_sec);
2061+
CallState_end(h, &cs);
2062+
elapsed_sec += retry_interval_sec;
2063+
}
2064+
2065+
if (!h->oauth_token_set) {
2066+
/* Token not set within timeout */
2067+
cfl_PyErr_Format(
2068+
RD_KAFKA_RESP_ERR_SASL_AUTHENTICATION_FAILED,
2069+
"OAuth token not set within %d seconds timeout",
2070+
max_wait_sec);
2071+
CallState cs;
2072+
CallState_begin(h, &cs);
2073+
rd_kafka_destroy(h->rk);
2074+
h->rk = NULL;
2075+
CallState_end(h, &cs);
2076+
return -1;
2077+
}
2078+
return 0;
2079+
}
2080+
2081+
/**
2082+
* @brief Callback invoked when a OAuth token needs to be refreshed.
2083+
*
2084+
* Note that this callback will be invoked by the background thread as
2085+
* all client types have been configured to use background threads for sasl
2086+
* events.
2087+
*/
20402088
static void
20412089
oauth_cb(rd_kafka_t *rk, const char *oauthbearer_config, void *opaque) {
20422090
Handle *h = opaque;
20432091
PyObject *eo, *result;
2044-
CallState *cs;
2092+
PyGILState_STATE gstate;
20452093
const char *token;
20462094
double expiry;
20472095
const char *principal = "";
@@ -2051,7 +2099,7 @@ oauth_cb(rd_kafka_t *rk, const char *oauthbearer_config, void *opaque) {
20512099
char err_msg[2048];
20522100
rd_kafka_resp_err_t err_code;
20532101

2054-
cs = CallState_get(h);
2102+
gstate = PyGILState_Ensure();
20552103

20562104
eo = Py_BuildValue("s", oauthbearer_config);
20572105
result = PyObject_CallFunctionObjArgs(h->oauth_cb, eo, NULL);
@@ -2103,6 +2151,7 @@ oauth_cb(rd_kafka_t *rk, const char *oauthbearer_config, void *opaque) {
21032151
PyErr_Format(PyExc_ValueError, "%s", err_msg);
21042152
goto fail;
21052153
}
2154+
h->oauth_token_set = 1;
21062155
goto done;
21072156

21082157
fail:
@@ -2116,10 +2165,10 @@ oauth_cb(rd_kafka_t *rk, const char *oauthbearer_config, void *opaque) {
21162165
PyErr_Clear();
21172166
goto done;
21182167
err:
2119-
CallState_crash(cs);
2168+
PyGILState_Release(gstate);
21202169
rd_kafka_yield(h->rk);
21212170
done:
2122-
CallState_resume(cs);
2171+
PyGILState_Release(gstate);
21232172
}
21242173

21252174
/****************************************************************************
@@ -2649,8 +2698,10 @@ rd_kafka_conf_t *common_conf_setup(rd_kafka_type_t ktype,
26492698
rd_kafka_conf_set_log_cb(conf, log_cb);
26502699
}
26512700

2652-
if (h->oauth_cb)
2701+
if (h->oauth_cb) {
26532702
rd_kafka_conf_set_oauthbearer_token_refresh_cb(conf, oauth_cb);
2703+
rd_kafka_conf_enable_sasl_queue(conf, 1);
2704+
}
26542705

26552706
rd_kafka_conf_set_opaque(conf, h);
26562707

src/confluent_kafka/src/confluent_kafka.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -239,6 +239,7 @@ typedef struct {
239239

240240
PyObject *logger;
241241
PyObject *oauth_cb;
242+
int oauth_token_set;
242243

243244
union {
244245
/**
@@ -465,6 +466,7 @@ PyObject *c_topic_partition_result_to_py_dict(
465466
PyObject *list_topics(Handle *self, PyObject *args, PyObject *kwargs);
466467
PyObject *list_groups(Handle *self, PyObject *args, PyObject *kwargs);
467468
PyObject *set_sasl_credentials(Handle *self, PyObject *args, PyObject *kwargs);
469+
int wait_for_oauth_token_set(Handle *self);
468470

469471

470472
extern const char list_topics_doc[];

tests/test_misc.py

Lines changed: 77 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
import pytest
99

1010
import confluent_kafka
11-
from confluent_kafka import Consumer, Producer
11+
from confluent_kafka import Consumer, KafkaException, Producer
1212
from confluent_kafka.admin import AdminClient
1313
from tests.common import TestConsumer
1414

@@ -140,9 +140,7 @@ def oauth_cb(oauth_config):
140140
}
141141

142142
kc = TestConsumer(conf)
143-
144-
while not seen_oauth_cb:
145-
kc.poll(timeout=0.1)
143+
assert seen_oauth_cb # callback is expected to happen during client init
146144
kc.close()
147145

148146

@@ -160,29 +158,86 @@ def oauth_cb(oauth_config):
160158
'group.id': 'test',
161159
'security.protocol': 'sasl_plaintext',
162160
'sasl.mechanisms': 'OAUTHBEARER',
163-
'session.timeout.ms': 100, # Avoid close() blocking too long
161+
'session.timeout.ms': 100,
164162
'sasl.oauthbearer.config': 'oauth_cb',
165163
'oauth_cb': oauth_cb,
166164
}
167165

168166
kc = TestConsumer(conf)
169-
170-
while not seen_oauth_cb:
171-
kc.poll(timeout=0.1)
167+
assert seen_oauth_cb # callback is expected to happen during client init
172168
kc.close()
173169

174170

175171
def test_oauth_cb_failure():
176-
"""Tests oauth_cb."""
172+
"""
173+
Tests oauth_cb for a case when it fails to return a token.
174+
We expect the client init to fail
175+
"""
176+
177+
def oauth_cb(oauth_config):
178+
raise Exception
179+
180+
conf = {
181+
'group.id': 'test',
182+
'security.protocol': 'sasl_plaintext',
183+
'sasl.mechanisms': 'OAUTHBEARER',
184+
'session.timeout.ms': 1000,
185+
'sasl.oauthbearer.config': 'oauth_cb',
186+
'oauth_cb': oauth_cb,
187+
}
188+
189+
with pytest.raises(KafkaException):
190+
TestConsumer(conf)
191+
192+
193+
def test_oauth_cb_token_refresh_success():
    """
    Tests whether oauth callback gets called multiple times by the background thread.

    The callback hands out tokens that expire after 3 seconds, so a second
    invocation (the refresh) is expected within the 5 second wait below,
    without this test ever calling poll() on the consumer.
    """
    oauth_cb_count = 0

    def oauth_cb(oauth_config):
        nonlocal oauth_cb_count
        oauth_cb_count += 1
        assert oauth_config == 'oauth_cb'
        return 'token', time.time() + 3  # token is returned with an expiry of 3 seconds

    conf = {
        'group.id': 'test',
        'security.protocol': 'sasl_plaintext',
        'sasl.mechanisms': 'OAUTHBEARER',
        'session.timeout.ms': 1000,
        'sasl.oauthbearer.config': 'oauth_cb',
        'oauth_cb': oauth_cb,
    }

    kc = TestConsumer(conf)  # callback is expected to happen during client init
    try:
        assert oauth_cb_count == 1

        # Check every 1 second for up to 5 seconds for callback count to increase
        max_wait_sec = 5
        elapsed_sec = 0
        while oauth_cb_count == 1 and elapsed_sec < max_wait_sec:
            time.sleep(1)
            elapsed_sec += 1
    finally:
        # Close the consumer even when an assertion above fails, so a failing
        # run does not leave a live client behind.
        kc.close()

    assert oauth_cb_count > 1
226+
227+
228+
def test_oauth_cb_token_refresh_failure():
229+
"""
230+
Tests whether oauth callback gets called again if token refresh failed in one of the calls after init
231+
"""
177232
oauth_cb_count = 0
178233

179234
def oauth_cb(oauth_config):
180235
nonlocal oauth_cb_count
181236
oauth_cb_count += 1
182237
assert oauth_config == 'oauth_cb'
183238
if oauth_cb_count == 2:
184-
return 'token', time.time() + 100.0, oauth_config, {"extthree": "extthreeval"}
185-
raise Exception
239+
raise Exception
240+
return 'token', time.time() + 3 # token is returned with an expiry of 3 seconds
186241

187242
conf = {
188243
'group.id': 'test',
@@ -193,11 +248,19 @@ def oauth_cb(oauth_config):
193248
'oauth_cb': oauth_cb,
194249
}
195250

196-
kc = TestConsumer(conf)
251+
kc = TestConsumer(conf) # callback is expected to happen during client init
252+
assert oauth_cb_count == 1
253+
254+
# Check every 1 second for up to 15 seconds for callback count to increase
255+
# Call back failure causes a refresh attempt after 10 secs, so ideally 2 callbacks should happen within 15 secs
256+
max_wait_sec = 15
257+
elapsed_sec = 0
258+
while oauth_cb_count <= 2 and elapsed_sec < max_wait_sec:
259+
time.sleep(1)
260+
elapsed_sec += 1
197261

198-
while oauth_cb_count < 2:
199-
kc.poll(timeout=0.1)
200262
kc.close()
263+
assert oauth_cb_count > 2
201264

202265

203266
def skip_interceptors():

0 commit comments

Comments
 (0)