Skip to content

Commit ae14233

Browse files
🐛 Ensure proper Redis client shutdown in Celery (#8237)
1 parent 19b60cc commit ae14233

File tree

16 files changed

+212
-130
lines changed

16 files changed

+212
-130
lines changed

packages/celery-library/src/celery_library/backends/_redis.py renamed to packages/celery-library/src/celery_library/backends/redis.py

Lines changed: 25 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
TaskMetadata,
1313
TaskUUID,
1414
)
15-
from servicelib.redis import RedisClientSDK
15+
from servicelib.redis import RedisClientSDK, handle_redis_returns_union_types
1616

1717
from ..utils import build_task_id_prefix
1818

@@ -41,18 +41,24 @@ async def create_task(
4141
expiry: timedelta,
4242
) -> None:
4343
task_key = _build_key(task_id)
44-
await self._redis_client_sdk.redis.hset(
45-
name=task_key,
46-
key=_CELERY_TASK_METADATA_KEY,
47-
value=task_metadata.model_dump_json(),
48-
) # type: ignore
44+
await handle_redis_returns_union_types(
45+
self._redis_client_sdk.redis.hset(
46+
name=task_key,
47+
key=_CELERY_TASK_METADATA_KEY,
48+
value=task_metadata.model_dump_json(),
49+
)
50+
)
4951
await self._redis_client_sdk.redis.expire(
5052
task_key,
5153
expiry,
5254
)
5355

5456
async def get_task_metadata(self, task_id: TaskID) -> TaskMetadata | None:
55-
raw_result = await self._redis_client_sdk.redis.hget(_build_key(task_id), _CELERY_TASK_METADATA_KEY) # type: ignore
57+
raw_result = await handle_redis_returns_union_types(
58+
self._redis_client_sdk.redis.hget(
59+
_build_key(task_id), _CELERY_TASK_METADATA_KEY
60+
)
61+
)
5662
if not raw_result:
5763
return None
5864

@@ -65,7 +71,11 @@ async def get_task_metadata(self, task_id: TaskID) -> TaskMetadata | None:
6571
return None
6672

6773
async def get_task_progress(self, task_id: TaskID) -> ProgressReport | None:
68-
raw_result = await self._redis_client_sdk.redis.hget(_build_key(task_id), _CELERY_TASK_PROGRESS_KEY) # type: ignore
74+
raw_result = await handle_redis_returns_union_types(
75+
self._redis_client_sdk.redis.hget(
76+
_build_key(task_id), _CELERY_TASK_PROGRESS_KEY
77+
)
78+
)
6979
if not raw_result:
7080
return None
7181

@@ -121,11 +131,13 @@ async def remove_task(self, task_id: TaskID) -> None:
121131
await self._redis_client_sdk.redis.delete(_build_key(task_id))
122132

123133
async def set_task_progress(self, task_id: TaskID, report: ProgressReport) -> None:
124-
await self._redis_client_sdk.redis.hset(
125-
name=_build_key(task_id),
126-
key=_CELERY_TASK_PROGRESS_KEY,
127-
value=report.model_dump_json(),
128-
) # type: ignore
134+
await handle_redis_returns_union_types(
135+
self._redis_client_sdk.redis.hset(
136+
name=_build_key(task_id),
137+
key=_CELERY_TASK_PROGRESS_KEY,
138+
value=report.model_dump_json(),
139+
)
140+
)
129141

130142
async def task_exists(self, task_id: TaskID) -> bool:
131143
n = await self._redis_client_sdk.redis.exists(_build_key(task_id))

packages/celery-library/src/celery_library/common.py

Lines changed: 0 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,9 @@
22
from typing import Any
33

44
from celery import Celery # type: ignore[import-untyped]
5-
from servicelib.redis import RedisClientSDK
65
from settings_library.celery import CelerySettings
76
from settings_library.redis import RedisDatabase
87

9-
from .backends._redis import RedisTaskInfoStore
10-
from .task_manager import CeleryTaskManager
11-
128

139
def _celery_configure(celery_settings: CelerySettings) -> dict[str, Any]:
1410
base_config = {
@@ -36,22 +32,3 @@ def create_app(settings: CelerySettings) -> Celery:
3632
),
3733
**_celery_configure(settings),
3834
)
39-
40-
41-
async def create_task_manager(
42-
app: Celery, settings: CelerySettings
43-
) -> CeleryTaskManager:
44-
redis_client_sdk = RedisClientSDK(
45-
settings.CELERY_REDIS_RESULT_BACKEND.build_redis_dsn(
46-
RedisDatabase.CELERY_TASKS
47-
),
48-
client_name="celery_tasks",
49-
)
50-
await redis_client_sdk.setup()
51-
# GCR please address https://github.com/ITISFoundation/osparc-simcore/issues/8159
52-
53-
return CeleryTaskManager(
54-
app,
55-
settings,
56-
RedisTaskInfoStore(redis_client_sdk),
57-
)

packages/celery-library/src/celery_library/signals.py

Lines changed: 5 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -6,18 +6,15 @@
66
from celery.worker.worker import WorkController # type: ignore[import-untyped]
77
from servicelib.celery.app_server import BaseAppServer
88
from servicelib.logging_utils import log_context
9-
from settings_library.celery import CelerySettings
109

11-
from .common import create_task_manager
1210
from .utils import get_app_server, set_app_server
1311

1412
_logger = logging.getLogger(__name__)
1513

1614

1715
def on_worker_init(
18-
app_server: BaseAppServer,
19-
celery_settings: CelerySettings,
2016
sender: WorkController,
17+
app_server: BaseAppServer,
2118
**_kwargs,
2219
) -> None:
2320
startup_complete_event = threading.Event()
@@ -26,21 +23,14 @@ def _init(startup_complete_event: threading.Event) -> None:
2623
loop = asyncio.new_event_loop()
2724
asyncio.set_event_loop(loop)
2825

29-
async def _setup_task_manager():
30-
assert sender.app # nosec
31-
assert isinstance(sender.app, Celery) # nosec
32-
33-
app_server.task_manager = await create_task_manager(
34-
sender.app,
35-
celery_settings,
36-
)
26+
assert sender.app # nosec
27+
assert isinstance(sender.app, Celery) # nosec
3728

38-
set_app_server(sender.app, app_server)
29+
set_app_server(sender.app, app_server)
3930

4031
app_server.event_loop = loop
4132

42-
loop.run_until_complete(_setup_task_manager())
43-
loop.run_until_complete(app_server.lifespan(startup_complete_event))
33+
loop.run_until_complete(app_server.run_until_shutdown(startup_complete_event))
4434

4535
thread = threading.Thread(
4636
group=None,

packages/celery-library/tests/conftest.py

Lines changed: 67 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -2,25 +2,30 @@
22
# pylint: disable=unused-argument
33

44
import datetime
5+
import logging
56
import threading
67
from collections.abc import AsyncIterator, Callable
7-
from functools import partial
88
from typing import Any
99

1010
import pytest
1111
from celery import Celery # type: ignore[import-untyped]
12-
from celery.contrib.testing.worker import TestWorkController, start_worker
12+
from celery.contrib.testing.worker import (
13+
TestWorkController,
14+
start_worker,
15+
)
1316
from celery.signals import worker_init, worker_shutdown
1417
from celery.worker.worker import WorkController
15-
from celery_library.common import create_task_manager
18+
from celery_library.backends.redis import RedisTaskInfoStore
1619
from celery_library.signals import on_worker_init, on_worker_shutdown
1720
from celery_library.task_manager import CeleryTaskManager
1821
from celery_library.types import register_celery_types
1922
from pytest_simcore.helpers.monkeypatch_envs import setenvs_from_dict
2023
from pytest_simcore.helpers.typing_env import EnvVarsDict
2124
from servicelib.celery.app_server import BaseAppServer
25+
from servicelib.celery.task_manager import TaskManager
26+
from servicelib.redis import RedisClientSDK
2227
from settings_library.celery import CelerySettings
23-
from settings_library.redis import RedisSettings
28+
from settings_library.redis import RedisDatabase, RedisSettings
2429

2530
pytest_plugins = [
2631
"pytest_simcore.docker_compose",
@@ -33,11 +38,42 @@
3338
]
3439

3540

41+
_logger = logging.getLogger(__name__)
42+
43+
3644
class FakeAppServer(BaseAppServer):
37-
async def lifespan(self, startup_completed_event: threading.Event) -> None:
45+
def __init__(self, app: Celery, settings: CelerySettings):
46+
super().__init__(app)
47+
self._settings = settings
48+
self._task_manager: CeleryTaskManager | None = None
49+
50+
@property
51+
def task_manager(self) -> TaskManager:
52+
assert self._task_manager, "Task manager is not initialized"
53+
return self._task_manager
54+
55+
async def run_until_shutdown(
56+
self, startup_completed_event: threading.Event
57+
) -> None:
58+
redis_client_sdk = RedisClientSDK(
59+
self._settings.CELERY_REDIS_RESULT_BACKEND.build_redis_dsn(
60+
RedisDatabase.CELERY_TASKS
61+
),
62+
client_name="pytest_celery_tasks",
63+
)
64+
await redis_client_sdk.setup()
65+
66+
self._task_manager = CeleryTaskManager(
67+
self._app,
68+
self._settings,
69+
RedisTaskInfoStore(redis_client_sdk),
70+
)
71+
3872
startup_completed_event.set()
3973
await self.shutdown_event.wait() # wait for shutdown
4074

75+
await redis_client_sdk.shutdown()
76+
4177

4278
@pytest.fixture
4379
def register_celery_tasks() -> Callable[[Celery], None]:
@@ -51,17 +87,12 @@ def _(celery_app: Celery) -> None: ...
5187
@pytest.fixture
5288
def app_environment(
5389
monkeypatch: pytest.MonkeyPatch,
54-
redis_service: RedisSettings,
5590
env_devel_dict: EnvVarsDict,
5691
) -> EnvVarsDict:
5792
return setenvs_from_dict(
5893
monkeypatch,
5994
{
6095
**env_devel_dict,
61-
"REDIS_SECURE": redis_service.REDIS_SECURE,
62-
"REDIS_HOST": redis_service.REDIS_HOST,
63-
"REDIS_PORT": f"{redis_service.REDIS_PORT}",
64-
"REDIS_PASSWORD": redis_service.REDIS_PASSWORD.get_secret_value(),
6596
},
6697
)
6798

@@ -74,8 +105,8 @@ def celery_settings(
74105

75106

76107
@pytest.fixture
77-
def app_server() -> BaseAppServer:
78-
return FakeAppServer(app=None)
108+
def app_server(celery_app: Celery, celery_settings: CelerySettings) -> BaseAppServer:
109+
return FakeAppServer(app=celery_app, settings=celery_settings)
79110

80111

81112
@pytest.fixture(scope="session")
@@ -98,11 +129,10 @@ def celery_config() -> dict[str, Any]:
98129
async def with_celery_worker(
99130
celery_app: Celery,
100131
app_server: BaseAppServer,
101-
celery_settings: CelerySettings,
102132
register_celery_tasks: Callable[[Celery], None],
103133
) -> AsyncIterator[TestWorkController]:
104134
def _on_worker_init_wrapper(sender: WorkController, **_kwargs):
105-
return partial(on_worker_init, app_server, celery_settings)(sender, **_kwargs)
135+
return on_worker_init(sender, app_server, **_kwargs)
106136

107137
worker_init.connect(_on_worker_init_wrapper)
108138
worker_shutdown.connect(on_worker_shutdown)
@@ -111,24 +141,39 @@ def _on_worker_init_wrapper(sender: WorkController, **_kwargs):
111141

112142
with start_worker(
113143
celery_app,
114-
pool="threads",
115144
concurrency=1,
145+
pool="threads",
116146
loglevel="info",
117147
perform_ping_check=False,
118148
queues="default",
119149
) as worker:
120150
yield worker
121151

122152

153+
@pytest.fixture
154+
async def mock_celery_app(celery_config: dict[str, Any]) -> Celery:
155+
return Celery(**celery_config)
156+
157+
123158
@pytest.fixture
124159
async def celery_task_manager(
125-
celery_app: Celery,
160+
mock_celery_app: Celery,
126161
celery_settings: CelerySettings,
127-
with_celery_worker: TestWorkController,
128-
) -> CeleryTaskManager:
162+
use_in_memory_redis: RedisSettings,
163+
) -> AsyncIterator[CeleryTaskManager]:
129164
register_celery_types()
130165

131-
return await create_task_manager(
132-
celery_app,
133-
celery_settings,
134-
)
166+
try:
167+
redis_client_sdk = RedisClientSDK(
168+
use_in_memory_redis.build_redis_dsn(RedisDatabase.CELERY_TASKS),
169+
client_name="pytest_celery_tasks",
170+
)
171+
await redis_client_sdk.setup()
172+
173+
yield CeleryTaskManager(
174+
mock_celery_app,
175+
celery_settings,
176+
RedisTaskInfoStore(redis_client_sdk),
177+
)
178+
finally:
179+
await redis_client_sdk.shutdown()

packages/celery-library/tests/unit/test_tasks.py

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212

1313
import pytest
1414
from celery import Celery, Task # pylint: disable=no-name-in-module
15+
from celery.worker.worker import WorkController # pylint: disable=no-name-in-module
1516
from celery_library.errors import TaskNotFoundError, TransferrableCeleryError
1617
from celery_library.task import register_task
1718
from celery_library.task_manager import CeleryTaskManager
@@ -33,6 +34,10 @@
3334
pytest_simcore_ops_services_selection = []
3435

3536

37+
class MyTaskFilter(TaskFilter):
38+
user_id: int
39+
40+
3641
async def _fake_file_processor(
3742
celery_app: Celery, task_name: str, task_id: str, files: list[str]
3843
) -> str:
@@ -91,8 +96,9 @@ def _(celery_app: Celery) -> None:
9196

9297
async def test_submitting_task_calling_async_function_results_with_success_state(
9398
celery_task_manager: CeleryTaskManager,
99+
with_celery_worker: WorkController,
94100
):
95-
task_filter = TaskFilter(user_id=42)
101+
task_filter = MyTaskFilter(user_id=42)
96102

97103
task_uuid = await celery_task_manager.submit_task(
98104
TaskMetadata(
@@ -121,8 +127,9 @@ async def test_submitting_task_calling_async_function_results_with_success_state
121127

122128
async def test_submitting_task_with_failure_results_with_error(
123129
celery_task_manager: CeleryTaskManager,
130+
with_celery_worker: WorkController,
124131
):
125-
task_filter = TaskFilter(user_id=42)
132+
task_filter = MyTaskFilter(user_id=42)
126133

127134
task_uuid = await celery_task_manager.submit_task(
128135
TaskMetadata(
@@ -149,8 +156,9 @@ async def test_submitting_task_with_failure_results_with_error(
149156

150157
async def test_cancelling_a_running_task_aborts_and_deletes(
151158
celery_task_manager: CeleryTaskManager,
159+
with_celery_worker: WorkController,
152160
):
153-
task_filter = TaskFilter(user_id=42)
161+
task_filter = MyTaskFilter(user_id=42)
154162

155163
task_uuid = await celery_task_manager.submit_task(
156164
TaskMetadata(
@@ -171,8 +179,9 @@ async def test_cancelling_a_running_task_aborts_and_deletes(
171179

172180
async def test_listing_task_uuids_contains_submitted_task(
173181
celery_task_manager: CeleryTaskManager,
182+
with_celery_worker: WorkController,
174183
):
175-
task_filter = TaskFilter(user_id=42)
184+
task_filter = MyTaskFilter(user_id=42)
176185

177186
task_uuid = await celery_task_manager.submit_task(
178187
TaskMetadata(
@@ -190,5 +199,5 @@ async def test_listing_task_uuids_contains_submitted_task(
190199
tasks = await celery_task_manager.list_tasks(task_filter)
191200
assert any(task.uuid == task_uuid for task in tasks)
192201

193-
tasks = await celery_task_manager.list_tasks(task_filter)
194-
assert any(task.uuid == task_uuid for task in tasks)
202+
tasks = await celery_task_manager.list_tasks(task_filter)
203+
assert any(task.uuid == task_uuid for task in tasks)

0 commit comments

Comments
 (0)