feat(inspect): Add progressbar #3045
Changes from all commits (4ff613c … db05d27)
@@ -2,6 +2,8 @@
 # SPDX-License-Identifier: Apache-2.0
 import asyncio
 import datetime
+import json
+import logging
 from collections.abc import AsyncGenerator
 from uuid import UUID

@@ -12,9 +14,11 @@
 from db import get_async_db_session_ctx
 from exceptions import DuplicateJobException, ResourceNotFoundException
 from pydantic_models import Job, JobList, JobType
-from pydantic_models.job import JobStatus, JobSubmitted, TrainJobPayload
+from pydantic_models.job import JobCancelled, JobStatus, JobSubmitted, TrainJobPayload
 from repositories import JobRepository

+logger = logging.getLogger(__name__)


 class JobService:
     @staticmethod

@@ -24,7 +28,7 @@ async def get_job_list(extra_filters: dict | None = None) -> JobList:
         return JobList(jobs=await repo.get_all(extra_filters=extra_filters))

     @staticmethod
-    async def get_job_by_id(job_id: UUID) -> Job | None:
+    async def get_job_by_id(job_id: UUID | str) -> Job | None:
     async def get_job_by_id(job_id: UUID | str) -> Job | None:
         if isinstance(job_id, str):
             try:
                 job_id = UUID(job_id)
             except (ValueError, AttributeError, TypeError):
                 raise ResourceNotFoundException(resource_id=job_id, resource_name="job")
Copilot (AI) commented on Nov 6, 2025:

Changing the parameter type from UUID to UUID | str is an API change. Ensure this is intentional and that all callers properly handle string UUIDs, as this could lead to validation issues if invalid strings are passed.
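The concern is easier to see with a small, self-contained sketch of how a caller might now pass a raw string id straight through. `ResourceNotFoundException` is stubbed here and `handler` is a hypothetical caller; only the parsing logic mirrors the diff.

```python
from uuid import UUID


class ResourceNotFoundException(Exception):
    """Stand-in for the project's exception; the real one lives in `exceptions`."""

    def __init__(self, resource_id: object, resource_name: str) -> None:
        super().__init__(f"{resource_name} `{resource_id}` not found")


async def get_job_by_id(job_id: UUID | str) -> dict | None:
    # Mirrors the new signature: string ids are parsed, malformed ones map to "not found".
    if isinstance(job_id, str):
        try:
            job_id = UUID(job_id)
        except (ValueError, AttributeError, TypeError):
            raise ResourceNotFoundException(resource_id=job_id, resource_name="job")
    return {"id": job_id}  # placeholder for the real repository lookup


async def handler(raw_id: str) -> dict:
    # Hypothetical caller (e.g. a route handler) relying on the exception for bad ids.
    try:
        return await get_job_by_id(raw_id) or {}
    except ResourceNotFoundException:
        return {"status": 404, "detail": "job not found"}
```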
Copilot (AI) commented on Nov 5, 2025:

The is_job_still_running method now raises ResourceNotFoundException if the job is not found, but it is called within a streaming context where the job might be deleted. Consider handling this exception or documenting the expected behavior when a job is deleted mid-stream.
-            cached_still_running = await cls.is_job_still_running(job_id=job_id)
+            try:
+                cached_still_running = await cls.is_job_still_running(job_id=job_id)
+            except ResourceNotFoundException:
+                # Job was deleted mid-stream; terminate gracefully
+                break
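For context, here is a minimal sketch of the kind of streaming loop the suggestion targets, assuming the endpoint yields JSON lines and polls `is_job_still_running`. The function names, stub bodies, and polling interval are illustrative, not the repository's actual code.

```python
import asyncio
import json
from collections.abc import AsyncGenerator
from uuid import UUID


class ResourceNotFoundException(Exception):
    pass


async def is_job_still_running(job_id: UUID) -> bool:
    return False  # stand-in for JobService.is_job_still_running


async def stream_job_status(job_id: UUID) -> AsyncGenerator[str, None]:
    while True:
        try:
            still_running = await is_job_still_running(job_id)
        except ResourceNotFoundException:
            # Job was deleted mid-stream: end the stream instead of erroring out.
            break
        yield json.dumps({"job_id": str(job_id), "running": still_running}) + "\n"
        if not still_running:
            break
        await asyncio.sleep(1.0)  # illustrative polling interval
```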
@@ -2,6 +2,7 @@
 # SPDX-License-Identifier: Apache-2.0
 import asyncio
 from contextlib import redirect_stdout
 from uuid import UUID

 from anomalib.data import Folder
 from anomalib.data.utils import TestSplitMode

@@ -15,6 +16,7 @@
 from repositories.binary_repo import ImageBinaryRepository, ModelBinaryRepository
 from services import ModelService
 from services.job_service import JobService
+from utils.callbacks import GetiInspectProgressCallback, ProgressSyncParams
 from utils.devices import Devices
 from utils.experiment_loggers import TrackioLogger

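`ProgressSyncParams` and `GetiInspectProgressCallback` come from the repository's `utils.callbacks` module, and their implementation is not part of this diff. As a rough mental model only, they plausibly amount to a thread-safe progress holder plus a Lightning callback that writes into it; the sketch below is an assumption, not the actual code.

```python
import threading
from dataclasses import dataclass, field

from lightning.pytorch import Callback, LightningModule, Trainer


@dataclass
class ProgressSyncParams:
    """Assumed shape: progress shared between the training thread and the event loop."""

    _lock: threading.Lock = field(default_factory=threading.Lock, repr=False)
    progress: float = 0.0  # fraction in [0.0, 1.0]

    def set_progress(self, value: float) -> None:
        with self._lock:
            self.progress = min(1.0, max(0.0, value))

    def get_progress(self) -> float:
        with self._lock:
            return self.progress


class GetiInspectProgressCallback(Callback):
    """Assumed shape: mirror epoch progress into the shared parameters."""

    def __init__(self, sync_params: ProgressSyncParams) -> None:
        self.sync_params = sync_params

    def on_train_epoch_end(self, trainer: Trainer, pl_module: LightningModule) -> None:
        if trainer.max_epochs:
            self.sync_params.set_progress((trainer.current_epoch + 1) / trainer.max_epochs)
```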
@@ -70,12 +72,24 @@ async def _run_training_job(cls, job: Job, job_service: JobService) -> Model:
             name=str(model_name),
             train_job_id=job.id,
         )
+        synchronization_parameters = ProgressSyncParams()
         logger.info(f"Training model `{model_name}` for job `{job.id}`")

+        synchronization_task: asyncio.Task[None] | None = None
         try:
+            synchronization_task = asyncio.create_task(
+                cls._sync_progress_with_db(
+                    job_service=job_service, job_id=job.id, synchronization_parameters=synchronization_parameters
+                )
+            )
             # Use asyncio.to_thread to keep event loop responsive
             # TODO: Consider ProcessPoolExecutor for true parallelism with multiple jobs
-            trained_model = await asyncio.to_thread(cls._train_model, model=model, device=device)
+            trained_model = await asyncio.to_thread(
+                cls._train_model,
+                model=model,
+                device=device,
+                synchronization_parameters=synchronization_parameters,
+            )
             if trained_model is None:
                 raise ValueError("Training failed - model is None")

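The body of `_sync_progress_with_db` is not shown in this excerpt. A plausible sketch, assuming it polls the shared progress and persists it until the task is cancelled; `update_job_progress` is a hypothetical method name, not necessarily the real JobService API.

```python
import asyncio
from uuid import UUID


async def _sync_progress_with_db(
    job_service,                    # JobService-like object; duck-typed here
    job_id: UUID,
    synchronization_parameters,     # ProgressSyncParams-like object
    interval: float = 2.0,          # illustrative polling period
) -> None:
    last_reported = -1.0
    try:
        while True:
            progress = synchronization_parameters.get_progress()
            if progress != last_reported:
                # Hypothetical persistence call; the real method name may differ.
                await job_service.update_job_progress(job_id=job_id, progress=progress)
                last_reported = progress
            await asyncio.sleep(interval)
    except asyncio.CancelledError:
        # One final flush so the stored progress is not stale when training ends.
        await job_service.update_job_progress(
            job_id=job_id, progress=synchronization_parameters.get_progress()
        )
        raise
```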
@@ -94,9 +108,15 @@ async def _run_training_job(cls, job: Job, job_service: JobService) -> Model:
             await model_binary_repo.delete_model_folder()
             await model_service.delete_model(project_id=project_id, model_id=model.id)
             raise e
+        finally:
+            logger.debug("Syncing progress with db stopped")
+            if synchronization_task is not None and not synchronization_task.done():
+                synchronization_task.cancel()

     @staticmethod
-    def _train_model(model: Model, device: str | None = None) -> Model | None:
+    def _train_model(
+        model: Model, synchronization_parameters: ProgressSyncParams, device: str | None = None
+    ) -> Model | None:
         """
         Execute CPU-intensive model training using anomalib.

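One detail worth noting about the `finally` block: `Task.cancel()` only requests cancellation. If the coroutine should be certain the sync loop has stopped before returning, the usual pattern is to also await the task and absorb the `CancelledError`. A minimal, generic sketch (not the PR's code):

```python
import asyncio
import contextlib


async def stop_sync_task(task: asyncio.Task | None) -> None:
    """Cancel a background task and wait until it has actually finished."""
    if task is None or task.done():
        return
    task.cancel()
    with contextlib.suppress(asyncio.CancelledError):
        await task
```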
@@ -106,6 +126,7 @@ def _train_model(model: Model, device: str | None = None) -> Model | None:

         Args:
             model: Model object with training configuration
+            synchronization_parameters: Parameters for synchronization between the main process and the training process
             device: Device to train on

         Returns:

@@ -145,7 +166,9 @@ def _train_model(model: Model, device: str | None = None) -> Model | None:
         engine = Engine(
             default_root_dir=model.export_path,
             logger=[trackio, tensorboard],
             devices=[0],  # Only single GPU training is supported for now
A review suggestion replaces the hard-coded device list:

Current:
        engine = Engine(
            default_root_dir=model.export_path,
            logger=[trackio, tensorboard],
            devices=[0],  # Only single GPU training is supported for now

Suggested:
        # Use training_device to set devices dynamically
        # If training_device is 'cpu', set devices=['cpu'], else set devices=[0] or as appropriate
        devices = ['cpu'] if training_device == 'cpu' else [0]
        engine = Engine(
            default_root_dir=model.export_path,
            logger=[trackio, tensorboard],
            devices=devices,  # Devices now set based on training_device
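A caveat on that suggestion: anomalib's Engine wraps a Lightning Trainer, where CPU vs. GPU selection is normally expressed with `accelerator` plus `devices` rather than a `'cpu'` entry in the devices list, so the suggested code may need adjusting. The sketch below assumes a `training_device` string exists in scope (it is not visible in this excerpt) and shows the Lightning-style split.

```python
def engine_device_kwargs(training_device: str | None) -> dict:
    """Map a simple device string to Lightning-style Trainer arguments (assumption)."""
    if training_device == "cpu":
        return {"accelerator": "cpu", "devices": 1}
    # Default: first GPU only, matching the "single GPU training" comment in the diff.
    return {"accelerator": "gpu", "devices": [0]}


# Hypothetical use at the Engine construction site:
# engine = Engine(default_root_dir=..., logger=[...], **engine_device_kwargs(training_device))
```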
Copilot (AI) commented on Nov 6, 2025:

The change from engine.train() to engine.fit() suggests an API update in the anomalib library. Consider adding a comment explaining this API change or updating related documentation to help future maintainers understand the transition.
             with redirect_stdout(LoggerStdoutWriter()):  # type: ignore[type-var]
+                # Note: anomalib API changed from `engine.train()` to `engine.fit()`.
+                # See anomalib release notes for details. This ensures compatibility with the latest version.
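For reference, a small, self-contained sketch of the call as used in current anomalib releases, where `Engine` mirrors Lightning's `fit`/`test`/`predict` API. The `Folder` datamodule is the one imported in this file; the dataset name, paths, and model choice below are placeholders, not values from this PR.

```python
from anomalib.data import Folder
from anomalib.engine import Engine
from anomalib.models import Padim

# Placeholder dataset layout; point `root`/`normal_dir` at a real folder dataset.
datamodule = Folder(name="example", root="./datasets/example", normal_dir="good")
model = Padim()
engine = Engine()

# The PR calls `fit` (training only) instead of `train`; see the review comments above.
engine.fit(model=model, datamodule=datamodule)
```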
Copilot (AI) commented on Nov 3, 2025:

Changed from engine.train() to engine.fit(). This appears to be an API update, but it should be verified that the new method exists and provides the same functionality.
Copilot (AI) commented on Nov 5, 2025:

Changed from engine.train() to engine.fit(). Verify that this API change is intentional and that both methods exist in the anomalib library, as this could be a breaking change if train() was the correct method.