Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -1057,7 +1057,7 @@ addition to the properties dictated by the underlying librdkafka C library:
where ``expiry_time`` is the time in seconds since the epoch as a floating point number.
This callback is useful only when ``sasl.mechanisms=OAUTHBEARER`` is set and
is served to get the initial token before a successful broker connection can be made.
The callback can be triggered by calling ``client.poll()`` or ``producer.flush()``.
The callback is asynchronously triggered by the background thread to maintain token validity.``.

* ``on_delivery(kafka.KafkaError, kafka.Message)`` (**Producer**): value is a Python function reference
that is called once for each produced message to indicate the final
Expand Down
12 changes: 12 additions & 0 deletions src/confluent_kafka/src/Admin.c
Original file line number Diff line number Diff line change
Expand Up @@ -5526,10 +5526,22 @@ static int Admin_init(PyObject *selfobj, PyObject *args, PyObject *kwargs) {
return -1;
}

/* Enable SASL callbacks on background thread for AdminClient since
* applications typically don't call poll() regularly on AdminClient. */
if (self->oauth_cb) {
rd_kafka_sasl_background_callbacks_enable(self->rk);
}

/* Forward log messages to poll queue */
if (self->logger)
rd_kafka_set_log_queue(self->rk, NULL);


/* Wait for the background thread to set the token */
if (self->oauth_cb) {
return wait_for_oauth_token_set(self);
}

return 0;
}

Expand Down
12 changes: 12 additions & 0 deletions src/confluent_kafka/src/Consumer.c
Original file line number Diff line number Diff line change
Expand Up @@ -1664,6 +1664,12 @@ static int Consumer_init(PyObject *selfobj, PyObject *args, PyObject *kwargs) {
return -1;
}

/* Enable Token Refresh to be handled by background thread if OAuth
* callback is provided */
if (self->oauth_cb) {
rd_kafka_sasl_background_callbacks_enable(self->rk);
}

/* Forward log messages to main queue which is then forwarded
* to the consumer queue */
if (self->logger)
Expand All @@ -1674,6 +1680,12 @@ static int Consumer_init(PyObject *selfobj, PyObject *args, PyObject *kwargs) {
self->u.Consumer.rkqu = rd_kafka_queue_get_consumer(self->rk);
assert(self->u.Consumer.rkqu);


/* Wait for the background thread to set the token */
if (self->oauth_cb) {
return wait_for_oauth_token_set(self);
}

return 0;
}

Expand Down
11 changes: 11 additions & 0 deletions src/confluent_kafka/src/Producer.c
Original file line number Diff line number Diff line change
Expand Up @@ -1331,10 +1331,21 @@ static int Producer_init(PyObject *selfobj, PyObject *args, PyObject *kwargs) {
return -1;
}

/* Enable Token Refresh to be handled by background thread if OAuth
* callback is provided */
if (self->oauth_cb) {
rd_kafka_sasl_background_callbacks_enable(self->rk);
}

/* Forward log messages to poll queue */
if (self->logger)
rd_kafka_set_log_queue(self->rk, NULL);

/* Wait for the background thread to set the token */
if (self->oauth_cb) {
return wait_for_oauth_token_set(self);
}

return 0;
}

Expand Down
61 changes: 56 additions & 5 deletions src/confluent_kafka/src/confluent_kafka.c
Original file line number Diff line number Diff line change
Expand Up @@ -2037,11 +2037,59 @@ static int py_extensions_to_c(char **extensions,
return 1;
}


/**
* @brief Waits for OAuth callback to set the token
*
* Useful during client init as we want to ensure we have the token before we
* return back
*
* Returns 0 if token was set within the timeout period, -1 otherwise.
*/
int wait_for_oauth_token_set(Handle *h) {

if (!h->oauth_cb)
return 0;

int max_wait_sec = 10;
int retry_interval_sec = 1; /* Check every 1 sec */
int elapsed_sec = 0;
while (!h->oauth_token_set && elapsed_sec < max_wait_sec) {
CallState cs;
CallState_begin(h, &cs);
sleep(retry_interval_sec);
CallState_end(h, &cs);
elapsed_sec += retry_interval_sec;
}

if (!h->oauth_token_set) {
/* Token not set within timeout */
cfl_PyErr_Format(
RD_KAFKA_RESP_ERR_SASL_AUTHENTICATION_FAILED,
"OAuth token not set within %d seconds timeout",
max_wait_sec);
CallState cs;
CallState_begin(h, &cs);
rd_kafka_destroy(h->rk);
h->rk = NULL;
CallState_end(h, &cs);
return -1;
}
return 0;
}

/**
* @brief Callback invoked when a OAuth token needs to be refreshed.
*
* Note that this callback will be invoked by the background thread as
* all client types have been configured to use background threads for sasl
* events.
*/
static void
oauth_cb(rd_kafka_t *rk, const char *oauthbearer_config, void *opaque) {
Handle *h = opaque;
PyObject *eo, *result;
CallState *cs;
PyGILState_STATE gstate;
const char *token;
double expiry;
const char *principal = "";
Expand All @@ -2051,7 +2099,7 @@ oauth_cb(rd_kafka_t *rk, const char *oauthbearer_config, void *opaque) {
char err_msg[2048];
rd_kafka_resp_err_t err_code;

cs = CallState_get(h);
gstate = PyGILState_Ensure();

eo = Py_BuildValue("s", oauthbearer_config);
result = PyObject_CallFunctionObjArgs(h->oauth_cb, eo, NULL);
Expand Down Expand Up @@ -2103,6 +2151,7 @@ oauth_cb(rd_kafka_t *rk, const char *oauthbearer_config, void *opaque) {
PyErr_Format(PyExc_ValueError, "%s", err_msg);
goto fail;
}
h->oauth_token_set = 1;
goto done;

fail:
Expand All @@ -2116,10 +2165,10 @@ oauth_cb(rd_kafka_t *rk, const char *oauthbearer_config, void *opaque) {
PyErr_Clear();
goto done;
err:
CallState_crash(cs);
PyGILState_Release(gstate);
rd_kafka_yield(h->rk);
done:
CallState_resume(cs);
PyGILState_Release(gstate);
}

/****************************************************************************
Expand Down Expand Up @@ -2649,8 +2698,10 @@ rd_kafka_conf_t *common_conf_setup(rd_kafka_type_t ktype,
rd_kafka_conf_set_log_cb(conf, log_cb);
}

if (h->oauth_cb)
if (h->oauth_cb) {
rd_kafka_conf_set_oauthbearer_token_refresh_cb(conf, oauth_cb);
rd_kafka_conf_enable_sasl_queue(conf, 1);
}

rd_kafka_conf_set_opaque(conf, h);

Expand Down
2 changes: 2 additions & 0 deletions src/confluent_kafka/src/confluent_kafka.h
Original file line number Diff line number Diff line change
Expand Up @@ -239,6 +239,7 @@ typedef struct {

PyObject *logger;
PyObject *oauth_cb;
int oauth_token_set;

union {
/**
Expand Down Expand Up @@ -465,6 +466,7 @@ PyObject *c_topic_partition_result_to_py_dict(
PyObject *list_topics(Handle *self, PyObject *args, PyObject *kwargs);
PyObject *list_groups(Handle *self, PyObject *args, PyObject *kwargs);
PyObject *set_sasl_credentials(Handle *self, PyObject *args, PyObject *kwargs);
int wait_for_oauth_token_set(Handle *self);


extern const char list_topics_doc[];
Expand Down
91 changes: 77 additions & 14 deletions tests/test_misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import pytest

import confluent_kafka
from confluent_kafka import Consumer, Producer
from confluent_kafka import Consumer, KafkaException, Producer
from confluent_kafka.admin import AdminClient
from tests.common import TestConsumer

Expand Down Expand Up @@ -140,9 +140,7 @@ def oauth_cb(oauth_config):
}

kc = TestConsumer(conf)

while not seen_oauth_cb:
kc.poll(timeout=0.1)
assert seen_oauth_cb # callback is expected to happen during client init
kc.close()


Expand All @@ -160,29 +158,86 @@ def oauth_cb(oauth_config):
'group.id': 'test',
'security.protocol': 'sasl_plaintext',
'sasl.mechanisms': 'OAUTHBEARER',
'session.timeout.ms': 100, # Avoid close() blocking too long
'session.timeout.ms': 100,
'sasl.oauthbearer.config': 'oauth_cb',
'oauth_cb': oauth_cb,
}

kc = TestConsumer(conf)

while not seen_oauth_cb:
kc.poll(timeout=0.1)
assert seen_oauth_cb # callback is expected to happen during client init
kc.close()


def test_oauth_cb_failure():
"""Tests oauth_cb."""
"""
Tests oauth_cb for a case when it fails to return a token.
We expect the client init to fail
"""

def oauth_cb(oauth_config):
raise Exception

conf = {
'group.id': 'test',
'security.protocol': 'sasl_plaintext',
'sasl.mechanisms': 'OAUTHBEARER',
'session.timeout.ms': 1000,
'sasl.oauthbearer.config': 'oauth_cb',
'oauth_cb': oauth_cb,
}

with pytest.raises(KafkaException):
TestConsumer(conf)


def test_oauth_cb_token_refresh_success():
"""
Tests whether oauth callback gets called multiple times by the background thread
"""
oauth_cb_count = 0

def oauth_cb(oauth_config):
nonlocal oauth_cb_count
oauth_cb_count += 1
assert oauth_config == 'oauth_cb'
return 'token', time.time() + 3 # token is returned with an expiry of 3 seconds

conf = {
'group.id': 'test',
'security.protocol': 'sasl_plaintext',
'sasl.mechanisms': 'OAUTHBEARER',
'session.timeout.ms': 1000,
'sasl.oauthbearer.config': 'oauth_cb',
'oauth_cb': oauth_cb,
}

kc = TestConsumer(conf) # callback is expected to happen during client init
assert oauth_cb_count == 1

# Check every 1 second for up to 5 seconds for callback count to increase
max_wait_sec = 5
elapsed_sec = 0
while oauth_cb_count == 1 and elapsed_sec < max_wait_sec:
time.sleep(1)
elapsed_sec += 1

kc.close()
assert oauth_cb_count > 1


def test_oauth_cb_token_refresh_failure():
"""
Tests whether oauth callback gets called again if token refresh failed in one of the calls after init
"""
oauth_cb_count = 0

def oauth_cb(oauth_config):
nonlocal oauth_cb_count
oauth_cb_count += 1
assert oauth_config == 'oauth_cb'
if oauth_cb_count == 2:
return 'token', time.time() + 100.0, oauth_config, {"extthree": "extthreeval"}
raise Exception
raise Exception
return 'token', time.time() + 3 # token is returned with an expiry of 3 seconds

conf = {
'group.id': 'test',
Expand All @@ -193,11 +248,19 @@ def oauth_cb(oauth_config):
'oauth_cb': oauth_cb,
}

kc = TestConsumer(conf)
kc = TestConsumer(conf) # callback is expected to happen during client init
assert oauth_cb_count == 1

# Check every 1 second for up to 15 seconds for callback count to increase
# Call back failure causes a refresh attempt after 10 secs, so ideally 2 callbacks should happen within 15 secs
max_wait_sec = 15
elapsed_sec = 0
while oauth_cb_count <= 2 and elapsed_sec < max_wait_sec:
time.sleep(1)
elapsed_sec += 1

while oauth_cb_count < 2:
kc.poll(timeout=0.1)
kc.close()
assert oauth_cb_count > 2


def skip_interceptors():
Expand Down