Merged · 23 commits
4ff613c  Add progress  (ashwinvaidya17, Oct 23, 2025)
12e5c52  Merge branch 'feature/geti-inspect' into ashwin/feat/progress_bar_sse  (ashwinvaidya17, Oct 23, 2025)
66b2f46  Merge branch 'feature/geti-inspect' into ashwin/feat/progress_bar_sse  (ashwinvaidya17, Oct 31, 2025)
f3855a2  Merge fixes  (ashwinvaidya17, Oct 31, 2025)
12130d2  Fix progress bar  (ashwinvaidya17, Oct 31, 2025)
4e05737  Update application/backend/src/services/job_service.py  (ashwinvaidya17, Oct 31, 2025)
004e880  Update application/backend/src/pydantic_models/job.py  (ashwinvaidya17, Nov 3, 2025)
9776eb4  Use finally block  (ashwinvaidya17, Nov 3, 2025)
e29cb09  Add Mark's changes  (ashwinvaidya17, Nov 4, 2025)
89ebcc8  Use job.message for informing training stage  (ashwinvaidya17, Nov 4, 2025)
965f0fc  Merge branch 'feature/geti-inspect' into ashwin/feat/progress_bar_sse  (ashwinvaidya17, Nov 5, 2025)
0cc0844  Restore callback  (ashwinvaidya17, Nov 5, 2025)
5e25b9a  Update application/backend/src/pydantic_models/job.py  (ashwinvaidya17, Nov 5, 2025)
873e5de  Add check for synchronization task  (ashwinvaidya17, Nov 5, 2025)
00a9b71  Merge branch 'feature/geti-inspect' into ashwin/feat/progress_bar_sse  (ashwinvaidya17, Nov 5, 2025)
6791180  Fix conflicts  (ashwinvaidya17, Nov 5, 2025)
eb057e2  Fix prek issues  (ashwinvaidya17, Nov 5, 2025)
b5a87d6  Fix mypy issues  (ashwinvaidya17, Nov 5, 2025)
48d762a  Fix tests  (ashwinvaidya17, Nov 6, 2025)
ddc1152  Cosmetics  (ashwinvaidya17, Nov 6, 2025)
422a38c  Update playwright in action  (ashwinvaidya17, Nov 6, 2025)
b19bfb0  Revert playwright changes  (ashwinvaidya17, Nov 6, 2025)
db05d27  Target package-lock to feature branch  (ashwinvaidya17, Nov 6, 2025)
22 changes: 20 additions & 2 deletions application/backend/src/api/endpoints/job_endpoints.py
@@ -4,13 +4,13 @@
from typing import Annotated
from uuid import UUID

from fastapi import APIRouter, Body, Depends
from fastapi import APIRouter, Body, Depends, status
from sse_starlette import EventSourceResponse

from api.dependencies import get_job_id, get_job_service
from api.endpoints import API_PREFIX
from pydantic_models import JobList
from pydantic_models.job import JobSubmitted, TrainJobPayload
from pydantic_models.job import JobCancelled, JobSubmitted, TrainJobPayload
from services import JobService

job_api_prefix_url = API_PREFIX + "/jobs"
@@ -42,3 +42,21 @@ async def get_job_logs(
) -> EventSourceResponse:
"""Endpoint to get the logs of a job by its ID"""
return EventSourceResponse(job_service.stream_logs(job_id=job_id))


@job_router.get("/{job_id}/progress")
async def get_job_progress(
job_id: Annotated[UUID, Depends(get_job_id)],
job_service: Annotated[JobService, Depends(get_job_service)],
) -> EventSourceResponse:
"""Endpoint to get the progress of a job by its ID"""
return EventSourceResponse(job_service.stream_progress(job_id=job_id))


@job_router.post("/{job_id}:cancel", status_code=status.HTTP_202_ACCEPTED)
async def cancel_job(
job_id: Annotated[UUID, Depends(get_job_id)],
job_service: Annotated[JobService, Depends(get_job_service)],
) -> JobCancelled:
"""Endpoint to cancel a job by its ID"""
return await job_service.cancel_job(job_id=job_id)
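For reference, the two new endpoints pair naturally: a client follows the SSE progress feed and can cancel the job at any time. A minimal consumption sketch (hypothetical host and prefix; the diff does not show what API_PREFIX resolves to):

import json

import httpx  # third-party HTTP client, used here only for illustration

BASE = "http://localhost:8000/api/jobs"  # assumed host and API_PREFIX

def follow_progress(job_id: str) -> None:
    """Print each SSE progress event until the server closes the stream."""
    with httpx.stream("GET", f"{BASE}/{job_id}/progress", timeout=None) as response:
        for line in response.iter_lines():
            if line.startswith("data:"):
                event = json.loads(line.removeprefix("data:").strip())
                print(f"{event['progress']}% - {event['message']}")

def cancel(job_id: str) -> None:
    """POST the :cancel action; the server replies 202 with a JobCancelled body."""
    response = httpx.post(f"{BASE}/{job_id}:cancel")
    print(response.status_code, response.json()["message"])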
10 changes: 9 additions & 1 deletion application/backend/src/pydantic_models/job.py
@@ -5,7 +5,7 @@
from typing import Any
from uuid import UUID

from pydantic import BaseModel, Field, field_serializer
from pydantic import BaseModel, Field, computed_field, field_serializer

from pydantic_models.base import BaseIDModel

@@ -46,6 +46,14 @@ class JobSubmitted(BaseModel):
job_id: UUID


class JobCancelled(BaseModel):
job_id: UUID

@computed_field
def message(self) -> str:
return f"Job with ID `{self.job_id}` marked as cancelled."


class TrainJobPayload(BaseModel):
project_id: UUID = Field(exclude=True)
model_name: str
...
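A note on the @computed_field: pydantic includes computed fields whenever the model is serialized, so message is derived from job_id on every dump rather than stored. A quick sketch of the resulting payload:

from uuid import uuid4

from pydantic_models.job import JobCancelled  # module path as in this diff

payload = JobCancelled(job_id=uuid4())
print(payload.model_dump_json())
# {"job_id":"<uuid>","message":"Job with ID `<uuid>` marked as cancelled."}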
52 changes: 42 additions & 10 deletions application/backend/src/services/job_service.py
@@ -2,6 +2,8 @@
# SPDX-License-Identifier: Apache-2.0
import asyncio
import datetime
import json
import logging
from collections.abc import AsyncGenerator
from uuid import UUID

@@ -12,9 +14,11 @@
from db import get_async_db_session_ctx
from exceptions import DuplicateJobException, ResourceNotFoundException
from pydantic_models import Job, JobList, JobType
from pydantic_models.job import JobStatus, JobSubmitted, TrainJobPayload
from pydantic_models.job import JobCancelled, JobStatus, JobSubmitted, TrainJobPayload
from repositories import JobRepository

logger = logging.getLogger(__name__)


class JobService:
@staticmethod
@@ -24,7 +28,7 @@ async def get_job_list(extra_filters: dict | None = None) -> JobList:
return JobList(jobs=await repo.get_all(extra_filters=extra_filters))

@staticmethod
async def get_job_by_id(job_id: UUID) -> Job | None:
async def get_job_by_id(job_id: UUID | str) -> Job | None:
Copilot AI (Nov 6, 2025): The signature now accepts UUID | str, but string inputs are never validated or converted to UUID, so a string that is not a valid UUID would surface as a runtime error.

Suggested change:
    async def get_job_by_id(job_id: UUID | str) -> Job | None:
        if isinstance(job_id, str):
            try:
                job_id = UUID(job_id)
            except (ValueError, AttributeError, TypeError):
                raise ResourceNotFoundException(resource_id=job_id, resource_name="job")

Copilot AI (Nov 6, 2025): Changing the parameter type from UUID to UUID | str is an API change. Ensure this is intentional and that all callers handle string UUIDs properly, as invalid strings could lead to validation issues.
async with get_async_db_session_ctx() as session:
repo = JobRepository(session)
return await repo.get_by_id(job_id)
@@ -56,7 +60,10 @@ async def get_pending_train_job() -> Job | None:

@staticmethod
async def update_job_status(
job_id: UUID, status: JobStatus, message: str | None = None, progress: int | None = None
job_id: UUID,
status: JobStatus,
message: str | None = None,
progress: int | None = None,
) -> None:
async with get_async_db_session_ctx() as session:
repo = JobRepository(session)
@@ -75,6 +82,13 @@ async def update_job_status(
updates["progress"] = progress_
await repo.update(job, updates)

@classmethod
async def is_job_still_running(cls, job_id: UUID | str) -> bool:
job = await cls.get_job_by_id(job_id=job_id)
if job is None:
raise ResourceNotFoundException(resource_id=job_id, resource_name="job")
return job.status == JobStatus.RUNNING

@classmethod
async def stream_logs(cls, job_id: UUID | str) -> AsyncGenerator[ServerSentEvent]:
from core.logging.utils import get_job_logs_path # noqa: PLC0415
@@ -83,12 +97,6 @@ async def stream_logs(cls, job_id: UUID | str) -> AsyncGenerator[ServerSentEvent
if not await anyio.Path(log_file).exists():
raise ResourceNotFoundException(resource_id=job_id, resource_name="job_logs")

async def is_job_still_running():
job = await cls.get_job_by_id(job_id=job_id)
if job is None:
raise ResourceNotFoundException(resource_id=job_id, resource_name="job")
return job.status == JobStatus.RUNNING

# Cache job status and only check every 2 seconds
status_check_interval = 2.0 # seconds
last_status_check = 0.0
@@ -101,7 +109,7 @@ async def is_job_still_running():
now = loop.time()
# Only check job status every status_check_interval seconds
if now - last_status_check > status_check_interval:
cached_still_running = await is_job_still_running()
cached_still_running = await cls.is_job_still_running(job_id=job_id)
Copilot AI (Nov 5, 2025): is_job_still_running now raises ResourceNotFoundException when the job is not found, but it is called from a streaming context where the job might be deleted mid-stream. Consider handling the exception or documenting the expected behavior when a job is deleted while its logs are being streamed.

Suggested change:
    try:
        cached_still_running = await cls.is_job_still_running(job_id=job_id)
    except ResourceNotFoundException:
        # Job was deleted mid-stream; terminate gracefully
        break
last_status_check = now
still_running = cached_still_running
if not line:
@@ -113,3 +121,27 @@
else:
break
yield ServerSentEvent(data=line.rstrip())

@classmethod
async def stream_progress(cls, job_id: UUID | str) -> AsyncGenerator[ServerSentEvent]:
"""Stream the progress of a job by its ID"""
still_running = True
while still_running:
job = await cls.get_job_by_id(job_id=job_id)
if job is None:
raise ResourceNotFoundException(resource_id=job_id, resource_name="job")
yield ServerSentEvent(data=json.dumps({"progress": job.progress, "message": job.message}))
still_running = job.status in {JobStatus.RUNNING, JobStatus.PENDING}
await asyncio.sleep(0.5)

@classmethod
async def cancel_job(cls, job_id: UUID | str) -> JobCancelled:
"""Cancel a job by its ID"""
async with get_async_db_session_ctx() as session:
repo = JobRepository(session)
job = await repo.get_by_id(job_id)
if job is None:
raise ResourceNotFoundException(resource_id=job_id, resource_name="job")

await repo.update(job, {"status": JobStatus.CANCELED})
return JobCancelled(job_id=job.id)
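Worth noting: stream_progress polls the database every 0.5 s, and because the status check happens after the yield, it emits one final event once the job leaves RUNNING/PENDING before the loop exits. A sketch of driving the generator directly, assuming a seeded job row and a configured DB session:

import asyncio

from services.job_service import JobService  # module path as in this diff

async def follow(job_id: str) -> None:
    # Each event carries JSON like {"progress": 42, "message": "Epoch 5/10"}
    async for event in JobService.stream_progress(job_id=job_id):
        print(event.data)

asyncio.run(follow("00000000-0000-0000-0000-000000000000"))  # placeholder ID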
54 changes: 51 additions & 3 deletions application/backend/src/services/training_service.py
@@ -2,6 +2,7 @@
# SPDX-License-Identifier: Apache-2.0
import asyncio
from contextlib import redirect_stdout
from uuid import UUID

from anomalib.data import Folder
from anomalib.data.utils import TestSplitMode
@@ -15,6 +16,7 @@
from repositories.binary_repo import ImageBinaryRepository, ModelBinaryRepository
from services import ModelService
from services.job_service import JobService
from utils.callbacks import GetiInspectProgressCallback, ProgressSyncParams
from utils.devices import Devices
from utils.experiment_loggers import TrackioLogger

@@ -70,12 +72,24 @@ async def _run_training_job(cls, job: Job, job_service: JobService) -> Model:
name=str(model_name),
train_job_id=job.id,
)
synchronization_parameters = ProgressSyncParams()
logger.info(f"Training model `{model_name}` for job `{job.id}`")

synchronization_task: asyncio.Task[None] | None = None
try:
synchronization_task = asyncio.create_task(
cls._sync_progress_with_db(
job_service=job_service, job_id=job.id, synchronization_parameters=synchronization_parameters
)
)
# Use asyncio.to_thread to keep event loop responsive
# TODO: Consider ProcessPoolExecutor for true parallelism with multiple jobs
trained_model = await asyncio.to_thread(cls._train_model, model=model, device=device)
trained_model = await asyncio.to_thread(
cls._train_model,
model=model,
device=device,
synchronization_parameters=synchronization_parameters,
)
if trained_model is None:
raise ValueError("Training failed - model is None")

@@ -94,9 +108,15 @@ async def _run_training_job(cls, job: Job, job_service: JobService) -> Model:
await model_binary_repo.delete_model_folder()
await model_service.delete_model(project_id=project_id, model_id=model.id)
raise e
finally:
logger.debug("Syncing progress with db stopped")
if synchronization_task is not None and not synchronization_task.done():
synchronization_task.cancel()

@staticmethod
def _train_model(model: Model, device: str | None = None) -> Model | None:
def _train_model(
model: Model, synchronization_parameters: ProgressSyncParams, device: str | None = None
) -> Model | None:
"""
Execute CPU-intensive model training using anomalib.

@@ -106,6 +126,7 @@ def _train_model(model: Model, device: str | None = None) -> Model | None:

Args:
model: Model object with training configuration
synchronization_parameters: Parameters for synchronization between the main process and the training process
device: Device to train on

Returns:
@@ -145,7 +166,9 @@
engine = Engine(
default_root_dir=model.export_path,
logger=[trackio, tensorboard],
devices=[0], # Only single GPU training is supported for now
Copilot AI (Nov 4, 2025): The hardcoded device index [0] should be derived from the device parameter or made configurable; as written it is inconsistent with the device parameter handling.

Copilot AI (Nov 5, 2025), on lines 145 to 169: Hardcoding devices=[0] overrides the device-selection logic that was previously used. The inline comment notes the limitation, but the implementation ignores the device parameter passed to this function, which could confuse users who expect that parameter to work.

Suggested change:
    # Use training_device to set devices dynamically:
    # if training_device is 'cpu', set devices=['cpu'], else devices=[0] or as appropriate
    devices = ["cpu"] if training_device == "cpu" else [0]
    engine = Engine(
        default_root_dir=model.export_path,
        logger=[trackio, tensorboard],
        devices=devices,  # devices now set based on training_device
max_epochs=10,
callbacks=[GetiInspectProgressCallback(synchronization_parameters)],
accelerator=training_device,
)

Expand All @@ -154,7 +177,7 @@ def _train_model(model: Model, device: str | None = None) -> Model | None:

# Capture pytorch stdout logs into logger
with redirect_stdout(LoggerStdoutWriter()): # type: ignore[type-var]
Copilot AI (Nov 6, 2025): The change from engine.train() to engine.fit() suggests an API update in the anomalib library. Consider adding a comment explaining the change, or updating related documentation, to help future maintainers understand the transition.

Suggested change:
    with redirect_stdout(LoggerStdoutWriter()):  # type: ignore[type-var]
        # Note: anomalib's API changed from `engine.train()` to `engine.fit()`.
        # See the anomalib release notes for details; this keeps compatibility
        # with the latest version.
engine.train(model=anomalib_model, datamodule=datamodule)
engine.fit(model=anomalib_model, datamodule=datamodule)
Copilot AI (Nov 3, 2025): Changed from engine.train() to engine.fit(). This appears to be an API update, but verify that the method exists and provides the same functionality.

Copilot AI (Nov 5, 2025): Changed from engine.train() to engine.fit(). Verify that this API change is intentional and that both methods exist in the anomalib library, as this could be a breaking change if train() was the correct method.

# Find and set threshold metric
for callback in engine.trainer.callbacks: # type: ignore[attr-defined]
@@ -172,6 +195,31 @@ def _train_model(model: Model, device: str | None = None) -> Model | None:
model.is_ready = True
return model

@classmethod
async def _sync_progress_with_db(
cls,
job_service: JobService,
job_id: UUID,
synchronization_parameters: ProgressSyncParams,
) -> None:
try:
while True:
progress: int = synchronization_parameters.progress
message = synchronization_parameters.message
if not await job_service.is_job_still_running(job_id=job_id):
logger.debug("Job cancelled, stopping progress sync")
synchronization_parameters.set_cancel_training_event()
break
logger.debug(f"Syncing progress with db: {progress}% - {message}")
await job_service.update_job_status(
job_id=job_id, status=JobStatus.RUNNING, progress=progress, message=message
)
await asyncio.sleep(0.5)
except Exception as e:
logger.exception("Failed to sync progress with db: %s", e)
await job_service.update_job_status(job_id=job_id, status=JobStatus.FAILED, message="Training failed")
raise

@staticmethod
async def abort_orphan_jobs() -> None:
"""
...
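The diff imports GetiInspectProgressCallback and ProgressSyncParams from utils.callbacks but never shows them. Since training runs in a worker thread via asyncio.to_thread while _sync_progress_with_db runs on the event loop, the helpers plausibly reduce to a small shared-state object plus a Lightning callback. A hypothetical sketch (the attribute names progress and message and the method set_cancel_training_event come from the diff; everything else is assumed):

import threading

from lightning.pytorch.callbacks import Callback

class ProgressSyncParams:
    """Shared state between the training thread and the async sync task."""

    def __init__(self) -> None:
        self.progress: int = 0
        self.message: str = ""
        self._cancel_training = threading.Event()

    def set_cancel_training_event(self) -> None:
        self._cancel_training.set()

    def should_cancel(self) -> bool:
        return self._cancel_training.is_set()

class GetiInspectProgressCallback(Callback):
    """Lightning callback that publishes progress and honors cancellation."""

    def __init__(self, params: ProgressSyncParams) -> None:
        self._params = params

    def on_train_epoch_end(self, trainer, pl_module) -> None:
        total = trainer.max_epochs or 1
        self._params.progress = int(100 * (trainer.current_epoch + 1) / total)
        self._params.message = f"Epoch {trainer.current_epoch + 1}/{total}"
        if self._params.should_cancel():
            trainer.should_stop = True  # ask Lightning for a graceful stop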