88import pytest
99
1010import confluent_kafka
11- from confluent_kafka import Consumer , Producer
11+ from confluent_kafka import Consumer , KafkaException , Producer
1212from confluent_kafka .admin import AdminClient
1313from tests .common import TestConsumer
1414
@@ -140,9 +140,7 @@ def oauth_cb(oauth_config):
140140 }
141141
142142 kc = TestConsumer (conf )
143-
144- while not seen_oauth_cb :
145- kc .poll (timeout = 0.1 )
143+ assert seen_oauth_cb # callback is expected to happen during client init
146144 kc .close ()
147145
148146
@@ -160,29 +158,86 @@ def oauth_cb(oauth_config):
160158 'group.id' : 'test' ,
161159 'security.protocol' : 'sasl_plaintext' ,
162160 'sasl.mechanisms' : 'OAUTHBEARER' ,
163- 'session.timeout.ms' : 100 , # Avoid close() blocking too long
161+ 'session.timeout.ms' : 100 ,
164162 'sasl.oauthbearer.config' : 'oauth_cb' ,
165163 'oauth_cb' : oauth_cb ,
166164 }
167165
168166 kc = TestConsumer (conf )
169-
170- while not seen_oauth_cb :
171- kc .poll (timeout = 0.1 )
167+ assert seen_oauth_cb # callback is expected to happen during client init
172168 kc .close ()
173169
174170
175171def test_oauth_cb_failure ():
176- """Tests oauth_cb."""
172+ """
173+ Tests oauth_cb for a case when it fails to return a token.
174+ We expect the client init to fail
175+ """
176+
177+ def oauth_cb (oauth_config ):
178+ raise Exception
179+
180+ conf = {
181+ 'group.id' : 'test' ,
182+ 'security.protocol' : 'sasl_plaintext' ,
183+ 'sasl.mechanisms' : 'OAUTHBEARER' ,
184+ 'session.timeout.ms' : 1000 ,
185+ 'sasl.oauthbearer.config' : 'oauth_cb' ,
186+ 'oauth_cb' : oauth_cb ,
187+ }
188+
189+ with pytest .raises (KafkaException ):
190+ TestConsumer (conf )
191+
192+
193+ def test_oauth_cb_token_refresh_success ():
194+ """
195+ Tests whether oauth callback gets called multiple times by the background thread
196+ """
197+ oauth_cb_count = 0
198+
199+ def oauth_cb (oauth_config ):
200+ nonlocal oauth_cb_count
201+ oauth_cb_count += 1
202+ assert oauth_config == 'oauth_cb'
203+ return 'token' , time .time () + 3 # token is returned with an expiry of 3 seconds
204+
205+ conf = {
206+ 'group.id' : 'test' ,
207+ 'security.protocol' : 'sasl_plaintext' ,
208+ 'sasl.mechanisms' : 'OAUTHBEARER' ,
209+ 'session.timeout.ms' : 1000 ,
210+ 'sasl.oauthbearer.config' : 'oauth_cb' ,
211+ 'oauth_cb' : oauth_cb ,
212+ }
213+
214+ kc = TestConsumer (conf ) # callback is expected to happen during client init
215+ assert oauth_cb_count == 1
216+
217+ # Check every 1 second for up to 5 seconds for callback count to increase
218+ max_wait_sec = 5
219+ elapsed_sec = 0
220+ while oauth_cb_count == 1 and elapsed_sec < max_wait_sec :
221+ time .sleep (1 )
222+ elapsed_sec += 1
223+
224+ kc .close ()
225+ assert oauth_cb_count > 1
226+
227+
228+ def test_oauth_cb_token_refresh_failure ():
229+ """
230+ Tests whether oauth callback gets called again if token refresh failed in one of the calls after init
231+ """
177232 oauth_cb_count = 0
178233
179234 def oauth_cb (oauth_config ):
180235 nonlocal oauth_cb_count
181236 oauth_cb_count += 1
182237 assert oauth_config == 'oauth_cb'
183238 if oauth_cb_count == 2 :
184- return 'token' , time . time () + 100.0 , oauth_config , { "extthree" : "extthreeval" }
185- raise Exception
239+ raise Exception
240+ return 'token' , time . time () + 3 # token is returned with an expiry of 3 seconds
186241
187242 conf = {
188243 'group.id' : 'test' ,
@@ -193,11 +248,19 @@ def oauth_cb(oauth_config):
193248 'oauth_cb' : oauth_cb ,
194249 }
195250
196- kc = TestConsumer (conf )
251+ kc = TestConsumer (conf ) # callback is expected to happen during client init
252+ assert oauth_cb_count == 1
253+
254+ # Check every 1 second for up to 15 seconds for callback count to increase
255+ # Call back failure causes a refresh attempt after 10 secs, so ideally 2 callbacks should happen within 15 secs
256+ max_wait_sec = 15
257+ elapsed_sec = 0
258+ while oauth_cb_count <= 2 and elapsed_sec < max_wait_sec :
259+ time .sleep (1 )
260+ elapsed_sec += 1
197261
198- while oauth_cb_count < 2 :
199- kc .poll (timeout = 0.1 )
200262 kc .close ()
263+ assert oauth_cb_count > 2
201264
202265
203266def skip_interceptors ():
0 commit comments