Skip to content

Commit fc22d6f

Browse files
bmartelmcanuWesley Limahlomziknick-skriabin
authored
chore: FIT-671: Add state transitions to the current application flows (#8534)
Co-authored-by: Marcel Canu <[email protected]> Co-authored-by: bmartel <[email protected]> Co-authored-by: Wesley Lima <[email protected]> Co-authored-by: Andrew <[email protected]> Co-authored-by: hlomzik <[email protected]> Co-authored-by: Nick Skriabin <[email protected]> Co-authored-by: robot-ci-heartex <[email protected]> Co-authored-by: nick-skriabin <[email protected]>
1 parent 6884c9d commit fc22d6f

35 files changed

+4108
-568
lines changed
Lines changed: 91 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from threading import local
2+
from typing import Any
23

34
from django.core.signals import request_finished
45
from django.dispatch import receiver
@@ -7,18 +8,104 @@
78
_thread_locals = local()
89

910

11+
class CurrentContext:
12+
@classmethod
13+
def set(cls, key: str, value: Any, shared: bool = True) -> None:
14+
if not hasattr(_thread_locals, 'data'):
15+
_thread_locals.data = {}
16+
if not hasattr(_thread_locals, 'job_data'):
17+
_thread_locals.job_data = {}
18+
19+
if shared:
20+
_thread_locals.job_data[key] = value
21+
else:
22+
_thread_locals.data[key] = value
23+
24+
@classmethod
25+
def get(cls, key: str, default=None):
26+
return getattr(_thread_locals, 'job_data', {}).get(key, getattr(_thread_locals, 'data', {}).get(key, default))
27+
28+
@classmethod
29+
def set_request(cls, request):
30+
_thread_locals.request = request
31+
if request.user:
32+
cls.set_user(request.user)
33+
34+
@classmethod
35+
def get_organization_id(cls):
36+
return cls.get('organization_id')
37+
38+
@classmethod
39+
def set_organization_id(cls, organization_id: int):
40+
cls.set('organization_id', organization_id)
41+
42+
@classmethod
43+
def get_user(cls):
44+
return cls.get('user')
45+
46+
@classmethod
47+
def set_user(cls, user):
48+
cls.set('user', user)
49+
if getattr(user, 'active_organization_id', None):
50+
cls.set_organization_id(user.active_organization_id)
51+
52+
@classmethod
53+
def set_fsm_disabled(cls, disabled: bool):
54+
"""
55+
Temporarily disable/enable FSM for the current thread.
56+
57+
This is useful for test cleanup and bulk operations where FSM state
58+
tracking is not needed and would cause performance issues.
59+
60+
Args:
61+
disabled: True to disable FSM, False to enable it
62+
"""
63+
cls.set('fsm_disabled', disabled)
64+
65+
@classmethod
66+
def is_fsm_disabled(cls) -> bool:
67+
"""
68+
Check if FSM is disabled for the current thread.
69+
70+
Returns:
71+
True if FSM is disabled, False otherwise
72+
"""
73+
return cls.get('fsm_disabled', False)
74+
75+
@classmethod
76+
def get_job_data(cls) -> dict:
77+
"""
78+
This data will be shared to jobs spawned by the current thread.
79+
"""
80+
return getattr(_thread_locals, 'job_data', {})
81+
82+
@classmethod
83+
def clear(cls) -> None:
84+
if hasattr(_thread_locals, 'data'):
85+
delattr(_thread_locals, 'data')
86+
87+
if hasattr(_thread_locals, 'job_data'):
88+
delattr(_thread_locals, 'job_data')
89+
90+
if hasattr(_thread_locals, 'request'):
91+
del _thread_locals.request
92+
93+
@classmethod
94+
def get_request(cls):
95+
return getattr(_thread_locals, 'request', None)
96+
97+
1098
def get_current_request():
1199
"""returns the request object for this thread"""
12-
result = getattr(_thread_locals, 'request', None)
100+
result = CurrentContext.get_request()
13101
return result
14102

15103

16104
class ThreadLocalMiddleware(CommonMiddleware):
17105
def process_request(self, request):
18-
_thread_locals.request = request
106+
CurrentContext.set_request(request)
19107

20108

21109
@receiver(request_finished)
22110
def clean_request(sender, **kwargs):
23-
if hasattr(_thread_locals, 'request'):
24-
del _thread_locals.request
111+
CurrentContext.clear()

label_studio/core/redis.py

Lines changed: 52 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,11 @@
44
import sys
55
from datetime import timedelta
66
from functools import partial
7+
from typing import Any
78

89
import django_rq
910
import redis
11+
from core.current_request import CurrentContext
1012
from django.conf import settings
1113
from django_rq import get_connection
1214
from rq.command import send_stop_job_command
@@ -80,6 +82,40 @@ def redis_connected():
8082
return redis_healthcheck()
8183

8284

85+
def _is_serializable(value: Any) -> bool:
86+
"""Check if a value can be serialized for job context."""
87+
return isinstance(value, (str, int, float, bool, list, dict, type(None)))
88+
89+
90+
def _capture_context() -> dict:
91+
"""
92+
Capture the current context for passing to a job.
93+
Returns a dictionary of context data that can be serialized.
94+
"""
95+
context_data = {}
96+
97+
# Get user information
98+
if user := CurrentContext.get_user():
99+
context_data['user_id'] = user.id
100+
101+
# Get organization if set separately
102+
if org_id := CurrentContext.get_organization_id():
103+
context_data['organization_id'] = org_id
104+
105+
# If organization_id is not set, try to get it from the user, this ensures that we have an organization_id for the job
106+
# And it prefers the original requesting user's organization_id over the current active organization_id of the user which could change during async jobs
107+
if not org_id and user and hasattr(user, 'active_organization_id') and user.active_organization_id:
108+
context_data['organization_id'] = user.active_organization_id
109+
110+
# Get any custom context values (exclude non-serializable objects)
111+
job_data = CurrentContext.get_job_data()
112+
for key, value in job_data.items():
113+
if key not in ['user', 'request'] and _is_serializable(value):
114+
context_data[key] = value
115+
116+
return context_data
117+
118+
83119
def redis_get(key):
84120
if not redis_healthcheck():
85121
return
@@ -112,7 +148,9 @@ def redis_delete(key):
112148

113149
def start_job_async_or_sync(job, *args, in_seconds=0, **kwargs):
114150
"""
115-
Start job async with redis or sync if redis is not connected
151+
Start job async with redis or sync if redis is not connected.
152+
Automatically preserves context for async jobs and clears it after completion.
153+
116154
:param job: Job function
117155
:param args: Function arguments
118156
:param in_seconds: Job will be delayed for in_seconds
@@ -122,28 +160,29 @@ def start_job_async_or_sync(job, *args, in_seconds=0, **kwargs):
122160

123161
redis = redis_connected() and kwargs.get('redis', True)
124162
queue_name = kwargs.get('queue_name', 'default')
163+
125164
if 'queue_name' in kwargs:
126165
del kwargs['queue_name']
127166
if 'redis' in kwargs:
128167
del kwargs['redis']
168+
129169
job_timeout = None
130170
if 'job_timeout' in kwargs:
131171
job_timeout = kwargs['job_timeout']
132172
del kwargs['job_timeout']
173+
133174
if redis:
134-
# Auto-capture request_id from thread local and pass it via job meta
175+
# Async execution with Redis - wrap job for context management
135176
try:
136-
from label_studio.core.current_request import _thread_locals
177+
context_data = _capture_context()
137178

138-
request_id = getattr(_thread_locals, 'request_id', None)
139-
if request_id:
140-
# Store in job meta for worker access
179+
if context_data:
141180
meta = kwargs.get('meta', {})
142-
meta['request_id'] = request_id
181+
# Store context data in job meta for worker access
182+
meta.update(context_data)
143183
kwargs['meta'] = meta
144184
except Exception:
145-
# Fail silently if no request context
146-
pass
185+
logger.info(f'Failed to capture context for job {job.__name__} on queue {queue_name}')
147186

148187
try:
149188
args_info = _truncate_args_for_logging(args, kwargs)
@@ -154,6 +193,7 @@ def start_job_async_or_sync(job, *args, in_seconds=0, **kwargs):
154193
enqueue_method = queue.enqueue
155194
if in_seconds > 0:
156195
enqueue_method = partial(queue.enqueue_in, timedelta(seconds=in_seconds))
196+
157197
job = enqueue_method(
158198
job,
159199
*args,
@@ -164,8 +204,10 @@ def start_job_async_or_sync(job, *args, in_seconds=0, **kwargs):
164204
return job
165205
else:
166206
on_failure = kwargs.pop('on_failure', None)
207+
167208
try:
168-
return job(*args, **kwargs)
209+
result = job(*args, **kwargs)
210+
return result
169211
except Exception:
170212
exc_info = sys.exc_info()
171213
if on_failure:

label_studio/core/settings/base.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -216,6 +216,7 @@
216216
'rest_framework.authtoken',
217217
'rest_framework_simplejwt.token_blacklist',
218218
'drf_generators',
219+
'fsm', # MUST be before apps that register FSM transitions (projects, tasks)
219220
'core',
220221
'users',
221222
'organizations',
@@ -232,7 +233,6 @@
232233
'ml_model_providers',
233234
'jwt_auth',
234235
'session_policy',
235-
'fsm',
236236
]
237237

238238
MIDDLEWARE = [

label_studio/core/tests/test_models.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,9 @@ def _assert_delete_and_restore_equal(self, drow, original):
1818
original_dict.pop('_state')
1919
original_created_at = original_dict.pop('created_at')
2020
original_updated_at = original_dict.pop('updated_at')
21+
# Pop _original_values - this is an internal FSM field that's recreated on __init__
22+
# and shouldn't be compared
23+
original_dict.pop('_original_values', None)
2124
original.delete()
2225

2326
for deserialized_object in serializers.deserialize('json', json.dumps([drow.data])):
@@ -28,6 +31,9 @@ def _assert_delete_and_restore_equal(self, drow, original):
2831
new_dict.pop('_state')
2932
new_created_at = new_dict.pop('created_at')
3033
new_updated_at = new_dict.pop('updated_at')
34+
# Pop _original_values - this is an internal FSM field that's recreated on __init__
35+
# and shouldn't be compared
36+
new_dict.pop('_original_values', None)
3137

3238
assert new_dict == original_dict
3339
# Datetime loses microsecond precision, so we can't compare them directly

label_studio/data_manager/actions/basic.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,9 @@ def delete_tasks_annotations(project, queryset, **kwargs):
9999
drafts = drafts.filter(user=int(annotator_id))
100100
project.summary.remove_created_drafts_and_labels(drafts)
101101

102-
count, _ = annotations.delete()
102+
# count before delete to return the number of deleted items, not including cascade deletions
103+
count = annotations.count()
104+
annotations.delete()
103105
drafts.delete() # since task-level annotation drafts will not have been deleted by CASCADE
104106
emit_webhooks_for_instance(project.organization, project, WebhookAction.ANNOTATIONS_DELETED, annotations_ids)
105107
request = kwargs['request']

label_studio/fsm/README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ class OrderStateChoices(models.TextChoices):
5656
### 2. Create State Model
5757

5858
```python
59-
from fsm.models import BaseState
59+
from fsm.state_models import BaseState
6060
from fsm.registry import register_state_model
6161

6262
@register_state_model('order')

0 commit comments

Comments
 (0)