Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 71 additions & 2 deletions src/lib/utils/logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,53 @@

# Module-level logger named after this module.
logger = logging.getLogger(__name__)

# Context-local identifiers consulted by JsonServiceFormatter.format when a log
# record carries no explicit user_id / job_id attribute. Both default to '',
# which the formatter treats as "unset" and leaves out of the JSON payload.
_user_id_context = contextvars.ContextVar('user_id', default='')
_job_id_context = contextvars.ContextVar('job_id', default='')


class UserLogContext:
    """
    Log context for adding user_id to JSON log records.

    Entering the context stores ``user_id`` in the ``_user_id_context``
    ContextVar; leaving it restores whatever value was there before.
    """

    def __init__(self, user_id: str = ''):
        """
        Args:
            user_id: Value to publish while the context is active. The empty
                string is a valid value and masks any outer user_id.
        """
        self.user_id = user_id
        # Token returned by ContextVar.set(); None whenever the context is not active.
        self._token: contextvars.Token[str] | None = None

    def __enter__(self):
        """Publish ``user_id`` and return this context object."""
        # Always set the ContextVar — including to '' — so a nested context
        # without a user_id masks any value from an outer scope rather than
        # leaking it into logs emitted inside this block.
        self._token = _user_id_context.set(self.user_id)
        return self

    def __exit__(self, ex_type, ex_value, ex_traceback):
        """Restore the previous user_id; exceptions are never suppressed."""
        # pylint: disable=unused-argument
        # A contextvars Token may be used for exactly one reset(). Drop our
        # reference before resetting so that a repeated __exit__ is a no-op
        # instead of raising RuntimeError on the already-used token.
        token, self._token = self._token, None
        if token is not None:
            _user_id_context.reset(token)


class JobLogContext:
    """
    Log context for adding job_id to JSON log records.

    Entering the context stores ``job_id`` in the ``_job_id_context``
    ContextVar; leaving it restores whatever value was there before.
    """

    def __init__(self, job_id: str = ''):
        """
        Args:
            job_id: Value to publish while the context is active. The empty
                string is a valid value and masks any outer job_id.
        """
        self.job_id = job_id
        # Token returned by ContextVar.set(); None whenever the context is not active.
        self._token: contextvars.Token[str] | None = None

    def __enter__(self):
        """Publish ``job_id`` and return this context object."""
        # Always set the ContextVar — including to '' — so a nested context
        # without a job_id masks any value from an outer scope rather than
        # leaking it into logs emitted inside this block.
        self._token = _job_id_context.set(self.job_id)
        return self

    def __exit__(self, ex_type, ex_value, ex_traceback):
        """Restore the previous job_id; exceptions are never suppressed."""
        # pylint: disable=unused-argument
        # A contextvars Token may be used for exactly one reset(). Drop our
        # reference before resetting so that a repeated __exit__ is a no-op
        # instead of raising RuntimeError on the already-used token.
        token, self._token = self._token, None
        if token is not None:
            _job_id_context.reset(token)


class LoggingLevel(enum.IntEnum):
"""
Expand Down Expand Up @@ -111,22 +158,28 @@ class WorkflowLogContext:
All logging, even within subfunctions, inside this context will have the workflow ID
attribute included with the log. Users should only use this for single threaded instances.
If 'extra' parameter is used when using logging inside the context, the workflow_uuid attribute
will not be overridden.
will not be overridden. user_id and job_id are emitted only by JsonServiceFormatter.
"""

def __init__(self, workflow_uuid: str):
def __init__(self, workflow_uuid: str, user_id: str = '', job_id: str = ''):
self.workflow_uuid = workflow_uuid
self._filter = WorkflowLogFilter(workflow_uuid)
self._user_context = UserLogContext(user_id)
self._job_context = JobLogContext(job_id)

def __enter__(self):
    """Attach the workflow filter (if a UUID is set) and enter the user and job contexts."""
    if self.workflow_uuid:
        root_logger = logging.getLogger()
        root_logger.addFilter(self._filter)
    # Entered user-first; __exit__ unwinds them in the opposite order.
    for nested_context in (self._user_context, self._job_context):
        nested_context.__enter__()
    return self

def __exit__(self, ex_type, ex_value, ex_traceback):
    """Detach the workflow filter (if a UUID is set) and leave the job and user contexts."""
    # pylint: disable=unused-argument
    if self.workflow_uuid:
        root_logger = logging.getLogger()
        root_logger.removeFilter(self._filter)
    # Job context first, then user context.
    for nested_context in (self._job_context, self._user_context):
        nested_context.__exit__(ex_type, ex_value, ex_traceback)


class LogFormat(str, enum.Enum):
Expand Down Expand Up @@ -235,6 +288,9 @@ class JsonServiceFormatter(logging.Formatter):
message: formatted message body
backend: set only for backend loggers (get_backend_logger)
workflow_uuid: set only when the WorkflowLogFilter or extra= adds it
user_id: set only in JSON logs when UserLogContext or extra= adds it
job_id: set only in JSON logs when JobLogContext or extra= adds it
status_code: set only in JSON logs when extra= adds it (HTTP responses)
exception: set only when exc_info is provided (formatted traceback)
stack: set only when stack_info is provided
"""
Expand Down Expand Up @@ -263,6 +319,19 @@ def format(self, record):
workflow_uuid = getattr(record, 'workflow_uuid', None)
if workflow_uuid:
payload['workflow_uuid'] = workflow_uuid
user_id = getattr(record, 'user_id', None)
if user_id is None:
user_id = _user_id_context.get()
if user_id:
payload['user_id'] = user_id
job_id = getattr(record, 'job_id', None)
if job_id is None:
job_id = _job_id_context.get()
if job_id:
payload['job_id'] = job_id
status_code = getattr(record, 'status_code', None)
if status_code is not None:
payload['status_code'] = status_code
if record.exc_info:
payload['exception'] = self.formatException(record.exc_info)
if record.stack_info:
Expand Down
70 changes: 70 additions & 0 deletions src/lib/utils/tests/test_logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,9 @@ def _make_record(
level: int = logging.INFO,
message: str = 'hello world',
workflow_uuid: str | None = None,
user_id: str | None = None,
job_id: str | None = None,
status_code: int | None = None,
exc_info=None,
) -> logging.LogRecord:
record = logging.LogRecord(
Expand All @@ -42,6 +45,12 @@ def _make_record(
record.module = 'test_logging'
if workflow_uuid is not None:
record.workflow_uuid = workflow_uuid
if user_id is not None:
record.user_id = user_id
if job_id is not None:
record.job_id = job_id
if status_code is not None:
record.status_code = status_code
return record


Expand Down Expand Up @@ -92,6 +101,7 @@ def test_required_fields_emitted(self):
self.assertIn('timestamp', payload)
self.assertNotIn('backend', payload)
self.assertNotIn('workflow_uuid', payload)
self.assertNotIn('user_id', payload)
self.assertNotIn('exception', payload)

def test_backend_field_emitted_for_backend_loggers(self):
Expand All @@ -106,6 +116,66 @@ def test_workflow_uuid_propagated_when_present(self):
payload = json.loads(formatter.format(_make_record(workflow_uuid='wf-123')))
self.assertEqual(payload['workflow_uuid'], 'wf-123')

def test_user_id_propagated_when_present(self):
    """A user_id attribute on the record is emitted in the JSON payload."""
    json_formatter = logging_utils.JsonServiceFormatter(service='osmo-test')
    record = _make_record(user_id='alice@example.com')
    decoded = json.loads(json_formatter.format(record))
    self.assertEqual(decoded['user_id'], 'alice@example.com')

def test_user_id_context_propagated_to_json(self):
    """A user_id bound through UserLogContext is emitted in the JSON payload."""
    json_formatter = logging_utils.JsonServiceFormatter(service='osmo-test')
    with logging_utils.UserLogContext('alice@example.com'):
        rendered = json_formatter.format(_make_record())
        decoded = json.loads(rendered)
    self.assertEqual(decoded['user_id'], 'alice@example.com')

def test_empty_user_id_context_masks_outer_user_id(self):
    """An inner empty UserLogContext hides the outer user_id until it exits."""
    json_formatter = logging_utils.JsonServiceFormatter(service='osmo-test')

    def emit():
        return json.loads(json_formatter.format(_make_record()))

    with logging_utils.UserLogContext('alice@example.com'):
        with logging_utils.UserLogContext(''):
            masked = emit()
        self.assertNotIn('user_id', masked)
        # Once the inner block is left, the outer user_id must be visible again.
        restored = emit()
    self.assertEqual(restored['user_id'], 'alice@example.com')

def test_job_id_propagated_when_present(self):
    """A job_id attribute on the record is emitted in the JSON payload."""
    json_formatter = logging_utils.JsonServiceFormatter(service='osmo-test')
    record = _make_record(job_id='job-123')
    decoded = json.loads(json_formatter.format(record))
    self.assertEqual(decoded['job_id'], 'job-123')

def test_job_id_context_propagated_to_json(self):
    """A job_id bound through JobLogContext is emitted in the JSON payload."""
    json_formatter = logging_utils.JsonServiceFormatter(service='osmo-test')
    with logging_utils.JobLogContext('job-123'):
        rendered = json_formatter.format(_make_record())
        decoded = json.loads(rendered)
    self.assertEqual(decoded['job_id'], 'job-123')

def test_empty_job_id_context_masks_outer_job_id(self):
    """An inner empty JobLogContext hides the outer job_id until it exits."""
    json_formatter = logging_utils.JsonServiceFormatter(service='osmo-test')

    def emit():
        return json.loads(json_formatter.format(_make_record()))

    with logging_utils.JobLogContext('job-123'):
        with logging_utils.JobLogContext(''):
            masked = emit()
        self.assertNotIn('job_id', masked)
        # Once the inner block is left, the outer job_id must be visible again.
        restored = emit()
    self.assertEqual(restored['job_id'], 'job-123')

def test_status_code_propagated_when_present(self):
    """A status_code attribute on the record is copied into the JSON payload."""
    json_formatter = logging_utils.JsonServiceFormatter(service='osmo-test')
    record = _make_record(status_code=404)
    decoded = json.loads(json_formatter.format(record))
    self.assertEqual(decoded['status_code'], 404)

def test_status_code_absent_when_not_provided(self):
    """Without a status_code on the record, the key is left out of the payload."""
    json_formatter = logging_utils.JsonServiceFormatter(service='osmo-test')
    record = _make_record()
    decoded = json.loads(json_formatter.format(record))
    self.assertNotIn('status_code', decoded)

def test_workflow_log_context_propagates_job_id(self):
    """WorkflowLogContext forwards its user_id and job_id to the JSON payload."""
    json_formatter = logging_utils.JsonServiceFormatter(service='osmo-test')
    workflow_context = logging_utils.WorkflowLogContext(
        'wf-123', user_id='alice@example.com', job_id='job-123')
    with workflow_context:
        decoded = json.loads(json_formatter.format(_make_record()))
    self.assertEqual(decoded['user_id'], 'alice@example.com')
    self.assertEqual(decoded['job_id'], 'job-123')

def test_exception_traceback_included(self):
formatter = logging_utils.JsonServiceFormatter(service='osmo-test')
try:
Expand Down
63 changes: 33 additions & 30 deletions src/operator/backend_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,36 +134,39 @@ async def run_job(self, job_spec: Dict, context: JobContext):

workflow_uuid = job.workflow_uuid if \
isinstance(job, backend_jobs.BackendWorkflowJob) else ''
logging.info('Starting job %s from the queue', job, extra={'workflow_uuid': workflow_uuid})
job_start_time = time.time()
try:
result = await asyncio.to_thread(
job.execute, context, self._progress_writer, self._progress_iter_freq)
if result.status != jobs_base.JobStatus.SUCCESS:
result.message = f'Backend execution failed: {result.message}'
message = backend_messages.MessageBody(
type=backend_messages.MessageType.JOB_STATUS, body=result)
except Exception as error: # pylint: disable=broad-except
error_message = f'{type(error).__name__}: {error}'
logging.exception('Fatal exception of type %s when running job %s',
error_message, job, extra={'workflow_uuid': workflow_uuid})
message = backend_messages.MessageBody(
type=backend_messages.MessageType.JOB_STATUS,
body=jobs_base.JobResult(
status=jobs_base.JobStatus.FAILED_NO_RETRY,
message=f'Got exception when running backend execute: {error_message}\n' + \
f'Traceback: {traceback.format_exc()}'))

await context.send_async_message(message)

logging.info('Completed job %s with status %s', job, message.body,
extra={'workflow_uuid': workflow_uuid})
job_duration = time.time() - job_start_time
self.backend_metrics.send_histogram(
name='backend_job_execution_time', value=job_duration, unit='seconds',
description=f'Job execution time for {job.job_type}',
tags={'job_type': job.job_type, 'namespace': self.config.namespace}
)
with src.lib.utils.logging.UserLogContext(job_spec.get('user', '')), \
src.lib.utils.logging.JobLogContext(job.job_id or ''):
logging.info(
'Starting job %s from the queue', job, extra={'workflow_uuid': workflow_uuid})
job_start_time = time.time()
try:
result = await asyncio.to_thread(
job.execute, context, self._progress_writer, self._progress_iter_freq)
if result.status != jobs_base.JobStatus.SUCCESS:
result.message = f'Backend execution failed: {result.message}'
message = backend_messages.MessageBody(
type=backend_messages.MessageType.JOB_STATUS, body=result)
except Exception as error: # pylint: disable=broad-except
error_message = f'{type(error).__name__}: {error}'
logging.exception('Fatal exception of type %s when running job %s',
error_message, job, extra={'workflow_uuid': workflow_uuid})
message = backend_messages.MessageBody(
type=backend_messages.MessageType.JOB_STATUS,
body=jobs_base.JobResult(
status=jobs_base.JobStatus.FAILED_NO_RETRY,
message=f'Got exception when running backend execute: {error_message}\n' + \
f'Traceback: {traceback.format_exc()}'))

await context.send_async_message(message)

logging.info('Completed job %s with status %s', job, message.body,
extra={'workflow_uuid': workflow_uuid})
job_duration = time.time() - job_start_time
self.backend_metrics.send_histogram(
name='backend_job_execution_time', value=job_duration, unit='seconds',
description=f'Job execution time for {job.job_type}',
tags={'job_type': job.job_type, 'namespace': self.config.namespace}
)
self._current_job = None

def _monitor_progress(self):
Expand Down
4 changes: 3 additions & 1 deletion src/service/agent/objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,9 +124,11 @@ def run_job(self, job_spec: Dict, message: kombu.transport.virtual.base.Message)
extra = {}
if 'workflow_uuid' in job_spec:
extra['workflow_uuid'] = job_spec['workflow_uuid']
if 'user' in job_spec:
extra['user_id'] = job_spec['user']
logging.info('Completed job (type=%s, id=%s) with status %s',
job_spec['job_type'], job_spec['job_id'],
result,extra=extra)
result, extra=extra)

async def run_jobs(self, backend_name: str):
worker_thread = threading.Thread(target=self.run, daemon=True)
Expand Down
12 changes: 12 additions & 0 deletions src/service/core/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,18 @@

@app.middleware('http')
async def check_client_version(request: fastapi.Request, call_next):
    """
    HTTP middleware that runs the client-version check under a user log context.

    The user id is read from the ``login.OSMO_USER_HEADER`` request header
    (empty string when absent) and bound with ``UserLogContext`` while the
    request is handled by ``_check_client_version``. One INFO line is then
    logged with the method, path and response status, and the status is also
    passed as the ``status_code`` extra field.
    """
    requesting_user = request.headers.get(login.OSMO_USER_HEADER, '')
    with src.lib.utils.logging.UserLogContext(requesting_user):
        response = await _check_client_version(request, call_next)
        status = response.status_code
        logging.info(
            '%s %s -> %d',
            request.method,
            request.url.path,
            status,
            extra={'status_code': status},
        )
        return response


async def _check_client_version(request: fastapi.Request, call_next):
client_version_str = request.headers.get(version.VERSION_HEADER)
token_name = request.headers.get(login.OSMO_TOKEN_NAME_HEADER)
if client_version_str is None:
Expand Down
1 change: 1 addition & 0 deletions src/service/core/workflow/objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -959,6 +959,7 @@ def convert_task_file_contents(curr_task_spec: Dict):
upload_spec_job = jobs.UploadWorkflowFiles(
workflow_id=workflow_id,
workflow_uuid=self.base32_id,
user=self.user,
files=files)
upload_spec_job.send_job_to_queue()

Expand Down
4 changes: 3 additions & 1 deletion src/service/logger/ctrl_websocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,9 @@ async def run_websocket(websocket: fastapi.WebSocket, name: str, task_name: str,
workflow_obj = workflow.Workflow.fetch_from_db(database, name)
group_name = task.Task.fetch_group_name(database, workflow_obj.workflow_id, task_name)

with src.lib.utils.logging.WorkflowLogContext(workflow_obj.workflow_uuid):
with src.lib.utils.logging.WorkflowLogContext(
workflow_obj.workflow_uuid, user_id=workflow_obj.user
):

task_cred_values = task.TaskGroup.fetch_task_secrets(database,
workflow_obj.workflow_id,
Expand Down
6 changes: 5 additions & 1 deletion src/service/worker/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,11 @@ def run_job(self, job_spec: Dict, message: kombu.transport.virtual.base.Message)
self._current_job = job

workflow_uuid = job.workflow_uuid if isinstance(job, jobs.WorkflowJob) else ''
with src.lib.utils.logging.WorkflowLogContext(workflow_uuid):
with src.lib.utils.logging.WorkflowLogContext(
workflow_uuid,
user_id=getattr(job, 'user', ''),
job_id=job.job_id or '',
):
logging.info('Starting job %s from the queue', job)
job_metadata = job.get_metadata()

Expand Down
8 changes: 7 additions & 1 deletion src/ui/src/lib/auth/user-context.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@

"use client";

import { createContext, useContext, type ReactNode } from "react";
import { createContext, useContext, useEffect, type ReactNode } from "react";
import { setLogUserId } from "@/lib/logger";

export interface User {
id: string;
Expand Down Expand Up @@ -56,6 +57,11 @@ export function UserProvider({ children, initialUser }: UserProviderProps) {
window.location.href = getLogoutUrl();
};

useEffect(() => {
setLogUserId(initialUser?.id ?? null);
return () => setLogUserId(null);
}, [initialUser?.id]);

return (
<UserContext.Provider value={{ user: initialUser, isLoading: false, logout }}>{children}</UserContext.Provider>
);
Expand Down
Loading
Loading