diff --git a/application/backend/src/api/endpoints/job_endpoints.py b/application/backend/src/api/endpoints/job_endpoints.py index cd2841ba7e..bb01e187fb 100644 --- a/application/backend/src/api/endpoints/job_endpoints.py +++ b/application/backend/src/api/endpoints/job_endpoints.py @@ -4,13 +4,13 @@ from typing import Annotated from uuid import UUID -from fastapi import APIRouter, Body, Depends +from fastapi import APIRouter, Body, Depends, status from sse_starlette import EventSourceResponse from api.dependencies import get_job_id, get_job_service from api.endpoints import API_PREFIX from pydantic_models import JobList -from pydantic_models.job import JobSubmitted, TrainJobPayload +from pydantic_models.job import JobCancelled, JobSubmitted, TrainJobPayload from services import JobService job_api_prefix_url = API_PREFIX + "/jobs" @@ -42,3 +42,21 @@ async def get_job_logs( ) -> EventSourceResponse: """Endpoint to get the logs of a job by its ID""" return EventSourceResponse(job_service.stream_logs(job_id=job_id)) + + +@job_router.get("/{job_id}/progress") +async def get_job_progress( + job_id: Annotated[UUID, Depends(get_job_id)], + job_service: Annotated[JobService, Depends(get_job_service)], +) -> EventSourceResponse: + """Endpoint to get the progress of a job by its ID""" + return EventSourceResponse(job_service.stream_progress(job_id=job_id)) + + +@job_router.post("/{job_id}:cancel", status_code=status.HTTP_202_ACCEPTED) +async def cancel_job( + job_id: Annotated[UUID, Depends(get_job_id)], + job_service: Annotated[JobService, Depends(get_job_service)], +) -> JobCancelled: + """Endpoint to cancel a job by its ID""" + return await job_service.cancel_job(job_id=job_id) diff --git a/application/backend/src/pydantic_models/job.py b/application/backend/src/pydantic_models/job.py index 051119a453..d701c7900b 100644 --- a/application/backend/src/pydantic_models/job.py +++ b/application/backend/src/pydantic_models/job.py @@ -5,7 +5,7 @@ from typing import Any from uuid import UUID -from pydantic import BaseModel, Field, field_serializer +from pydantic import BaseModel, Field, computed_field, field_serializer from pydantic_models.base import BaseIDModel @@ -46,6 +46,14 @@ class JobSubmitted(BaseModel): job_id: UUID +class JobCancelled(BaseModel): + job_id: UUID + + @computed_field + def message(self) -> str: + return f"Job with ID `{self.job_id}` marked as cancelled." + + class TrainJobPayload(BaseModel): project_id: UUID = Field(exclude=True) model_name: str diff --git a/application/backend/src/services/job_service.py b/application/backend/src/services/job_service.py index 2ccd4b220d..d98203124f 100644 --- a/application/backend/src/services/job_service.py +++ b/application/backend/src/services/job_service.py @@ -2,6 +2,8 @@ # SPDX-License-Identifier: Apache-2.0 import asyncio import datetime +import json +import logging from collections.abc import AsyncGenerator from uuid import UUID @@ -12,9 +14,11 @@ from db import get_async_db_session_ctx from exceptions import DuplicateJobException, ResourceNotFoundException from pydantic_models import Job, JobList, JobType -from pydantic_models.job import JobStatus, JobSubmitted, TrainJobPayload +from pydantic_models.job import JobCancelled, JobStatus, JobSubmitted, TrainJobPayload from repositories import JobRepository +logger = logging.getLogger(__name__) + class JobService: @staticmethod @@ -24,7 +28,7 @@ async def get_job_list(extra_filters: dict | None = None) -> JobList: return JobList(jobs=await repo.get_all(extra_filters=extra_filters)) @staticmethod - async def get_job_by_id(job_id: UUID) -> Job | None: + async def get_job_by_id(job_id: UUID | str) -> Job | None: async with get_async_db_session_ctx() as session: repo = JobRepository(session) return await repo.get_by_id(job_id) @@ -56,7 +60,10 @@ async def get_pending_train_job() -> Job | None: @staticmethod async def update_job_status( - job_id: UUID, status: JobStatus, message: str | None = None, progress: int | None = None + job_id: UUID, + status: JobStatus, + message: str | None = None, + progress: int | None = None, ) -> None: async with get_async_db_session_ctx() as session: repo = JobRepository(session) @@ -75,6 +82,13 @@ async def update_job_status( updates["progress"] = progress_ await repo.update(job, updates) + @classmethod + async def is_job_still_running(cls, job_id: UUID | str) -> bool: + job = await cls.get_job_by_id(job_id=job_id) + if job is None: + raise ResourceNotFoundException(resource_id=job_id, resource_name="job") + return job.status == JobStatus.RUNNING + @classmethod async def stream_logs(cls, job_id: UUID | str) -> AsyncGenerator[ServerSentEvent]: from core.logging.utils import get_job_logs_path # noqa: PLC0415 @@ -83,12 +97,6 @@ async def stream_logs(cls, job_id: UUID | str) -> AsyncGenerator[ServerSentEvent if not await anyio.Path(log_file).exists(): raise ResourceNotFoundException(resource_id=job_id, resource_name="job_logs") - async def is_job_still_running(): - job = await cls.get_job_by_id(job_id=job_id) - if job is None: - raise ResourceNotFoundException(resource_id=job_id, resource_name="job") - return job.status == JobStatus.RUNNING - # Cache job status and only check every 2 seconds status_check_interval = 2.0 # seconds last_status_check = 0.0 @@ -101,7 +109,7 @@ async def is_job_still_running(): now = loop.time() # Only check job status every status_check_interval seconds if now - last_status_check > status_check_interval: - cached_still_running = await is_job_still_running() + cached_still_running = await cls.is_job_still_running(job_id=job_id) last_status_check = now still_running = cached_still_running if not line: @@ -113,3 +121,27 @@ async def is_job_still_running(): else: break yield ServerSentEvent(data=line.rstrip()) + + @classmethod + async def stream_progress(cls, job_id: UUID | str) -> AsyncGenerator[ServerSentEvent]: + """Stream the progress of a job by its ID""" + still_running = True + while still_running: + job = await cls.get_job_by_id(job_id=job_id) + if job is None: + raise ResourceNotFoundException(resource_id=job_id, resource_name="job") + yield ServerSentEvent(data=json.dumps({"progress": job.progress, "message": job.message})) + still_running = job.status in {JobStatus.RUNNING, JobStatus.PENDING} + await asyncio.sleep(0.5) + + @classmethod + async def cancel_job(cls, job_id: UUID | str) -> JobCancelled: + """Cancel a job by its ID""" + async with get_async_db_session_ctx() as session: + repo = JobRepository(session) + job = await repo.get_by_id(job_id) + if job is None: + raise ResourceNotFoundException(resource_id=job_id, resource_name="job") + + await repo.update(job, {"status": JobStatus.CANCELED}) + return JobCancelled(job_id=job.id) diff --git a/application/backend/src/services/training_service.py b/application/backend/src/services/training_service.py index 21bff1f7ca..84ce5f5232 100644 --- a/application/backend/src/services/training_service.py +++ b/application/backend/src/services/training_service.py @@ -2,6 +2,7 @@ # SPDX-License-Identifier: Apache-2.0 import asyncio from contextlib import redirect_stdout +from uuid import UUID from anomalib.data import Folder from anomalib.data.utils import TestSplitMode @@ -15,6 +16,7 @@ from repositories.binary_repo import ImageBinaryRepository, ModelBinaryRepository from services import ModelService from services.job_service import JobService +from utils.callbacks import GetiInspectProgressCallback, ProgressSyncParams from utils.devices import Devices from utils.experiment_loggers import TrackioLogger @@ -70,12 +72,24 @@ async def _run_training_job(cls, job: Job, job_service: JobService) -> Model: name=str(model_name), train_job_id=job.id, ) + synchronization_parameters = ProgressSyncParams() logger.info(f"Training model `{model_name}` for job `{job.id}`") + synchronization_task: asyncio.Task[None] | None = None try: + synchronization_task = asyncio.create_task( + cls._sync_progress_with_db( + job_service=job_service, job_id=job.id, synchronization_parameters=synchronization_parameters + ) + ) # Use asyncio.to_thread to keep event loop responsive # TODO: Consider ProcessPoolExecutor for true parallelism with multiple jobs - trained_model = await asyncio.to_thread(cls._train_model, model=model, device=device) + trained_model = await asyncio.to_thread( + cls._train_model, + model=model, + device=device, + synchronization_parameters=synchronization_parameters, + ) if trained_model is None: raise ValueError("Training failed - model is None") @@ -94,9 +108,15 @@ async def _run_training_job(cls, job: Job, job_service: JobService) -> Model: await model_binary_repo.delete_model_folder() await model_service.delete_model(project_id=project_id, model_id=model.id) raise e + finally: + logger.debug("Syncing progress with db stopped") + if synchronization_task is not None and not synchronization_task.done(): + synchronization_task.cancel() @staticmethod - def _train_model(model: Model, device: str | None = None) -> Model | None: + def _train_model( + model: Model, synchronization_parameters: ProgressSyncParams, device: str | None = None + ) -> Model | None: """ Execute CPU-intensive model training using anomalib. @@ -106,6 +126,7 @@ def _train_model(model: Model, device: str | None = None) -> Model | None: Args: model: Model object with training configuration + synchronization_parameters: Parameters for synchronization between the main process and the training process device: Device to train on Returns: @@ -145,7 +166,9 @@ def _train_model(model: Model, device: str | None = None) -> Model | None: engine = Engine( default_root_dir=model.export_path, logger=[trackio, tensorboard], + devices=[0], # Only single GPU training is supported for now max_epochs=10, + callbacks=[GetiInspectProgressCallback(synchronization_parameters)], accelerator=training_device, ) @@ -154,7 +177,7 @@ def _train_model(model: Model, device: str | None = None) -> Model | None: # Capture pytorch stdout logs into logger with redirect_stdout(LoggerStdoutWriter()): # type: ignore[type-var] - engine.train(model=anomalib_model, datamodule=datamodule) + engine.fit(model=anomalib_model, datamodule=datamodule) # Find and set threshold metric for callback in engine.trainer.callbacks: # type: ignore[attr-defined] @@ -172,6 +195,31 @@ def _train_model(model: Model, device: str | None = None) -> Model | None: model.is_ready = True return model + @classmethod + async def _sync_progress_with_db( + cls, + job_service: JobService, + job_id: UUID, + synchronization_parameters: ProgressSyncParams, + ) -> None: + try: + while True: + progress: int = synchronization_parameters.progress + message = synchronization_parameters.message + if not await job_service.is_job_still_running(job_id=job_id): + logger.debug("Job cancelled, stopping progress sync") + synchronization_parameters.set_cancel_training_event() + break + logger.debug(f"Syncing progress with db: {progress}% - {message}") + await job_service.update_job_status( + job_id=job_id, status=JobStatus.RUNNING, progress=progress, message=message + ) + await asyncio.sleep(0.5) + except Exception as e: + logger.exception("Failed to sync progress with db: %s", e) + await job_service.update_job_status(job_id=job_id, status=JobStatus.FAILED, message="Training failed") + raise + @staticmethod async def abort_orphan_jobs() -> None: """ diff --git a/application/backend/src/utils/callbacks.py b/application/backend/src/utils/callbacks.py new file mode 100644 index 0000000000..d241066d18 --- /dev/null +++ b/application/backend/src/utils/callbacks.py @@ -0,0 +1,226 @@ +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +"""Lightning callback for sending progress to the frontend via the Plugin API.""" + +from __future__ import annotations + +import threading +from typing import TYPE_CHECKING, Any + +from lightning.pytorch.callbacks import Callback +from loguru import logger + +if TYPE_CHECKING: + from lightning.pytorch import LightningModule, Trainer + + +class ProgressSyncParams: + def __init__(self) -> None: + self._progress = 0 + self._message: str = "Initializing" + self._lock = threading.Lock() + self.cancel_training_event = threading.Event() + + @property + def message(self) -> str: + with self._lock: + return self._message + + @message.setter + def message(self, stage: str) -> None: + with self._lock: + self._message = f"Stage: {stage}" + logger.debug("Message updated: %s", self._message) + + @property + def progress(self) -> int: + with self._lock: + return self._progress + + @progress.setter + def progress(self, progress: int) -> None: + with self._lock: + self._progress = progress + logger.debug("Progress updated: %s", progress) + + def set_cancel_training_event(self) -> None: + with self._lock: + self.cancel_training_event.set() + logger.debug("Set cancel training event") + + +class GetiInspectProgressCallback(Callback): + """Callback for displaying training/validation/testing progress in the Geti Inspect UI. + + This callback sends progress events through a multiprocessing queue that the + main process polls and broadcasts via WebSocket to connected frontend clients. + + Args: + synchronization_parameters: Parameters for synchronization between the main process and the training process + + Example: + trainer = Trainer(callbacks=[GetiInspectProgressCallback(synchronization_parameters=ProgressSyncParams())]) + """ + + def __init__(self, synchronization_parameters: ProgressSyncParams) -> None: + """Initialize the callback with synchronization parameters. + Args: + synchronization_parameters: Parameters for synchronization between the main process and the training process + """ + self.synchronization_parameters = synchronization_parameters + + def _check_cancel_training(self, trainer: Trainer) -> None: + """Check if training should be canceled.""" + if self.synchronization_parameters.cancel_training_event.is_set(): + trainer.should_stop = True + + def _send_progress(self, progress: float, message: str) -> None: + """Send progress update to frontend via event queue. + Puts a generic event message into the multiprocessing queue which will + be picked up by the main process and broadcast via WebSocket. + Args: + progress: Progress value between 0.0 and 1.0 + message: The current training message + """ + # Convert progress to percentage (0-100) + progress_percent = int(progress * 100) + + try: + logger.debug("Sent progress: %s - %d%%", message, progress_percent) + self.synchronization_parameters.progress = progress_percent + self.synchronization_parameters.message = message + except Exception as e: + logger.warning("Failed to send progress to event queue: %s", e) + + # Training callbacks + def on_train_start(self, trainer: Trainer, pl_module: LightningModule) -> None: + """Called when training starts.""" + del pl_module # unused + if trainer.state.stage is not None: + self._send_progress(0, trainer.state.stage.value) + else: + self._send_progress(0, "Training started") + self._check_cancel_training(trainer) + + def on_train_batch_start(self, trainer: Trainer, pl_module: LightningModule, batch: Any, batch_idx: int) -> None: + """Called when a training batch starts.""" + del pl_module, batch, batch_idx # unused + self._check_cancel_training(trainer) + + def on_train_epoch_end(self, trainer: Trainer, pl_module: LightningModule) -> None: + """Called when a training epoch ends.""" + del pl_module # unused + # If max_epochs is not available, set progress to 0.5 + if trainer.state.stage is not None: + progress = ( + (trainer.current_epoch + 1) / trainer.max_epochs + if (trainer.max_epochs is not None and trainer.max_epochs > 0) + else 0.5 + ) + self._send_progress(progress, trainer.state.stage.value) + self._check_cancel_training(trainer) + + def on_train_end(self, trainer: Trainer, pl_module: LightningModule) -> None: + """Called when training ends.""" + del pl_module # unused + if trainer.state.stage is not None: + self._send_progress(1.0, trainer.state.stage.value) + self._check_cancel_training(trainer) + + # Validation callbacks + def on_validation_start(self, trainer: Trainer, pl_module: LightningModule) -> None: + """Called when validation starts.""" + del pl_module # unused + self._check_cancel_training(trainer) + + def on_validation_batch_start( + self, trainer: Trainer, pl_module: LightningModule, batch: Any, batch_idx: int, dataloader_idx: int = 0 + ) -> None: + """Called when a validation batch starts.""" + del pl_module, batch, batch_idx, dataloader_idx # unused + self._check_cancel_training(trainer) + + def on_validation_epoch_end(self, trainer: Trainer, pl_module: LightningModule) -> None: + """Called when a validation epoch ends.""" + del pl_module # unused + self._check_cancel_training(trainer) + + def on_validation_end(self, trainer: Trainer, pl_module: LightningModule) -> None: + """Called when validation ends.""" + del pl_module # unused + self._check_cancel_training(trainer) + + # Test callbacks + def on_test_start(self, trainer: Trainer, pl_module: LightningModule) -> None: + """Called when testing starts.""" + del pl_module # unused + if trainer.state.stage is not None: + self._send_progress(0, trainer.state.stage.value) + else: + self._send_progress(0, "Testing started") + self._check_cancel_training(trainer) + + def on_test_batch_start( + self, trainer: Trainer, pl_module: LightningModule, batch: Any, batch_idx: int, dataloader_idx: int = 0 + ) -> None: + """Called when a test batch starts.""" + del pl_module, batch, batch_idx, dataloader_idx # unused + self._check_cancel_training(trainer) + + def on_test_epoch_end(self, trainer: Trainer, pl_module: LightningModule) -> None: + """Called when a test epoch ends.""" + del pl_module # unused + # If max_epochs is not available, set progress to 0.5 + if trainer.state.stage is not None: + progress = ( + (trainer.current_epoch + 1) / trainer.max_epochs + if (trainer.max_epochs is not None and trainer.max_epochs > 0) + else 0.5 + ) + self._send_progress(progress, trainer.state.stage.value) + self._check_cancel_training(trainer) + + def on_test_end(self, trainer: Trainer, pl_module: LightningModule) -> None: + """Called when testing ends.""" + del pl_module # unused + if trainer.state.stage is not None: + self._send_progress(1.0, trainer.state.stage.value) + self._check_cancel_training(trainer) + + # Predict callbacks + def on_predict_start(self, trainer: Trainer, pl_module: LightningModule) -> None: + """Called when prediction starts.""" + del pl_module # unused + if trainer.state.stage is not None: + self._send_progress(0, trainer.state.stage.value) + else: + self._send_progress(0, "Prediction started") + self._check_cancel_training(trainer) + + def on_predict_batch_start( + self, trainer: Trainer, pl_module: LightningModule, batch: Any, batch_idx: int, dataloader_idx: int = 0 + ) -> None: + """Called when a prediction batch starts.""" + del pl_module, batch, batch_idx, dataloader_idx # unused + self._check_cancel_training(trainer) + + def on_predict_epoch_end(self, trainer: Trainer, pl_module: LightningModule) -> None: + """Called when a prediction epoch ends.""" + del pl_module # unused + # If max_epochs is not available, set progress to 0.5 + if trainer.state.stage is not None: + progress = ( + (trainer.current_epoch + 1) / trainer.max_epochs + if (trainer.max_epochs is not None and trainer.max_epochs > 0) + else 0.5 + ) + self._send_progress(progress, trainer.state.stage.value) + self._check_cancel_training(trainer) + + def on_predict_end(self, trainer: Trainer, pl_module: LightningModule) -> None: + """Called when prediction ends.""" + del pl_module # unused + if trainer.state.stage is not None: + self._send_progress(1.0, trainer.state.stage.value) + self._check_cancel_training(trainer) diff --git a/application/backend/tests/unit/services/test_training_service.py b/application/backend/tests/unit/services/test_training_service.py index f7ece5e8a8..6007751fda 100644 --- a/application/backend/tests/unit/services/test_training_service.py +++ b/application/backend/tests/unit/services/test_training_service.py @@ -8,6 +8,7 @@ from pydantic_models import JobStatus from repositories.binary_repo import ImageBinaryRepository, ModelBinaryRepository from services import TrainingService +from utils.callbacks import ProgressSyncParams @pytest.fixture @@ -191,7 +192,7 @@ def test_train_pending_job_cleanup_on_failure( with patch("services.training_service.asyncio.to_thread") as mock_to_thread: # Mock the training to succeed first, setting export_path, then fail - def mock_train_model(cls, model, device=None): + def mock_train_model(cls, model, synchronization_parameters: ProgressSyncParams, device=None): model.export_path = "/path/to/model" raise Exception("Training failed") @@ -221,7 +222,7 @@ def test_train_model_success( fxt_model_binary_repo.model_folder_path = "/path/to/model" # Call the method - result = TrainingService._train_model(fxt_model) + result = TrainingService._train_model(fxt_model, synchronization_parameters=ProgressSyncParams()) # Verify the result assert result == fxt_model @@ -240,7 +241,7 @@ def test_train_model_success( assert len(call_args[1]["logger"]) == 2 # trackio and tensorboard assert call_args[1]["max_epochs"] == 10 - fxt_mock_anomalib_components["engine"].train.assert_called_once_with( + fxt_mock_anomalib_components["engine"].fit.assert_called_once_with( model=fxt_mock_anomalib_components["anomalib_model"], datamodule=fxt_mock_anomalib_components["folder"] ) fxt_mock_anomalib_components["engine"].export.assert_called_once() diff --git a/application/ui/src/api/fetch-sse.ts b/application/ui/src/api/fetch-sse.ts new file mode 100644 index 0000000000..5b64d39eab --- /dev/null +++ b/application/ui/src/api/fetch-sse.ts @@ -0,0 +1,46 @@ +// Connect to an SSE endpoint and yield its messages +export function fetchSSE(url: string) { + return { + async *[Symbol.asyncIterator]() { + const eventSource = new EventSource(url); + + try { + let { promise, resolve, reject } = Promise.withResolvers(); + + eventSource.onmessage = (event) => { + if (event.data === 'DONE' || event.data.includes('COMPLETED')) { + eventSource.close(); + resolve('DONE'); + return; + } + resolve(event.data); + }; + + eventSource.onerror = (error) => { + eventSource.close(); + reject(new Error('EventSource failed: ' + error)); + }; + + // Keep yielding data as it comes in + while (true) { + const message = await promise; + + // If server sends 'DONE' message or similar, break the loop + if (message === 'DONE') { + break; + } + + try { + yield JSON.parse(message); + } catch { + console.error('Could not parse message:', message); + } + + ({ promise, resolve, reject } = Promise.withResolvers()); + } + } finally { + eventSource.close(); + } + }, + }; +} diff --git a/application/ui/src/features/inspect/footer/footer.component.tsx b/application/ui/src/features/inspect/footer/footer.component.tsx new file mode 100644 index 0000000000..7de19e4025 --- /dev/null +++ b/application/ui/src/features/inspect/footer/footer.component.tsx @@ -0,0 +1,174 @@ +// Copyright (C) 2025 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +import { Suspense } from 'react'; + +import { $api } from '@geti-inspect/api'; +import { SchemaJob as Job, SchemaJob } from '@geti-inspect/api/spec'; +import { useProjectIdentifier } from '@geti-inspect/hooks'; +import { Flex, ProgressBar, Text, View } from '@geti/ui'; +import { CanceledIcon, WaitingIcon } from '@geti/ui/icons'; +import { queryOptions, experimental_streamedQuery as streamedQuery, useQuery } from '@tanstack/react-query'; +import { fetchSSE } from 'src/api/fetch-sse'; + +const IdleItem = () => { + return ( + + + + Idle + + + ); +}; + +const getStyleForMessage = (message: string) => { + if (message.toLowerCase().includes('valid')) { + return { + backgroundColor: 'var(--spectrum-global-color-yellow-600)', + color: '#000', + }; + } else if (message.toLowerCase().includes('test')) { + return { + backgroundColor: 'var(--spectrum-global-color-green-600)', + color: '#fff', + }; + } else if (message.toLowerCase().includes('train') || message.toLowerCase().includes('fit')) { + return { + backgroundColor: 'var(--spectrum-global-color-blue-600)', + color: '#fff', + }; + } + + return { + backgroundColor: 'var(--spectrum-global-color-blue-600)', + color: '#fff', + }; +}; + +const TrainingStatusItem = ({ trainingJob }: { trainingJob: SchemaJob }) => { + // Cancel training job + const cancelJobMutation = $api.useMutation('post', '/api/jobs/{job_id}:cancel'); + const handleCancel = async () => { + try { + if (trainingJob.id === undefined) { + throw Error('TODO: jobs should always have an ID'); + } + + console.info('Cancel training'); + await cancelJobMutation.mutateAsync({ + params: { + path: { + job_id: trainingJob.id, + }, + }, + }); + console.info('Job cancelled successfully'); + } catch (error) { + console.error('Failed to cancel job:', error); + } + }; + + const progressQuery = useQuery( + queryOptions({ + queryKey: ['get', '/api/jobs/{job_id}/progress', trainingJob.id], + queryFn: streamedQuery({ + queryFn: () => fetchSSE(`/api/jobs/${trainingJob.id}/progress`), + maxChunks: 1, + }), + staleTime: Infinity, + }) + ); + + // Get the job progress and message from the last SSE message, or fallback + const lastJobProgress = progressQuery.data?.at(-1); + const progress = lastJobProgress?.progress ?? trainingJob.progress; + const message = lastJobProgress?.message ?? trainingJob.message; + + const { backgroundColor, color } = getStyleForMessage(message); + + return ( +
+ + + + {message} + + + +
+ ); +}; + +const useCurrentJob = () => { + const { data: jobsData } = $api.useSuspenseQuery('get', '/api/jobs', undefined, { + refetchInterval: 5000, + }); + + const { projectId } = useProjectIdentifier(); + const runningJob = jobsData.jobs.find( + (job: Job) => job.project_id === projectId && (job.status === 'running' || job.status === 'pending') + ); + + return runningJob; +}; + +export const ProgressBarItem = () => { + const trainingJob = useCurrentJob(); + + if (trainingJob !== undefined) { + return ; + } + + return ; +}; + +export const Footer = () => { + return ( + + + + + + ); +}; diff --git a/application/ui/src/features/inspect/jobs/show-job-logs.component.tsx b/application/ui/src/features/inspect/jobs/show-job-logs.component.tsx index 347bbc9253..b3ef175f2e 100644 --- a/application/ui/src/features/inspect/jobs/show-job-logs.component.tsx +++ b/application/ui/src/features/inspect/jobs/show-job-logs.component.tsx @@ -17,56 +17,7 @@ import { } from '@geti/ui'; import { LogsIcon } from '@geti/ui/icons'; import { queryOptions, experimental_streamedQuery as streamedQuery, useQuery } from '@tanstack/react-query'; - -// Connect to an SSE endpoint and yield its messages -function fetchSSE(url: string) { - return { - async *[Symbol.asyncIterator]() { - const eventSource = new EventSource(url); - - try { - let { promise, resolve, reject } = Promise.withResolvers(); - - eventSource.onmessage = (event) => { - if (event.data === 'DONE' || event.data.includes('COMPLETED')) { - eventSource.close(); - resolve('DONE'); - return; - } - resolve(event.data); - }; - - eventSource.onerror = (error) => { - eventSource.close(); - reject(new Error('EventSource failed: ' + error)); - }; - - // Keep yielding data as it comes in - while (true) { - const message = await promise; - - // If server sends 'DONE' message or similar, break the loop - if (message === 'DONE') { - break; - } - - try { - const data = JSON.parse(message); - if (data['text']) { - yield data['text']; - } - } catch { - console.error('Could not parse message:', message); - } - - ({ promise, resolve, reject } = Promise.withResolvers()); - } - } finally { - eventSource.close(); - } - }, - }; -} +import { fetchSSE } from 'src/api/fetch-sse'; const JobLogsDialogContent = ({ jobId }: { jobId: string }) => { const query = useQuery( @@ -81,7 +32,7 @@ const JobLogsDialogContent = ({ jobId }: { jobId: string }) => { return ( - {query.data?.map((line, idx) => {line})} + {query.data?.map((line, idx) => {line.text})} ); }; diff --git a/application/ui/src/routes/inspect/inspect.tsx b/application/ui/src/routes/inspect/inspect.tsx index f391a051fa..c8789d8767 100644 --- a/application/ui/src/routes/inspect/inspect.tsx +++ b/application/ui/src/routes/inspect/inspect.tsx @@ -4,6 +4,7 @@ import { useProjectIdentifier } from '@geti-inspect/hooks'; import { Grid } from '@geti/ui'; +import { Footer } from '../../features/inspect/footer/footer.component'; import { InferenceProvider } from '../../features/inspect/inference-provider.component'; import { InferenceResult } from '../../features/inspect/inference-result.component'; import { SelectedMediaItemProvider } from '../../features/inspect/selected-media-item-provider.component'; @@ -15,8 +16,8 @@ export const Inspect = () => { return ( { +