diff --git a/.github/workflows/deploy.yml b/.github/workflows/deploy.yml index c9df105..2569bef 100644 --- a/.github/workflows/deploy.yml +++ b/.github/workflows/deploy.yml @@ -2,7 +2,7 @@ name: ML Pipeline - Train and Publish on: push: - branches: [ main, master, dev ] # Триггер на push в main/master/dev (новый датасет) + branches: [ main, dev ] # Триггер на push в main/master/dev (новый датасет) workflow_dispatch: # Позволяет запускать вручную через GitHub UI inputs: run_training: @@ -69,7 +69,7 @@ jobs: env: HF_TOKEN: ${{ secrets.HF_TOKEN }} run: | - docker compose -f docker-compose.yml -f docker-compose.dev.yml run --rm cvc-dev pytest tests/ -v --tb=short --cov=commands_classifier --cov-report=term-missing + docker compose -f docker-compose.yml -f docker-compose.dev.yml run --rm cvc-dev pytest tests/ -v --tb=short --cov=app --cov-report=term-missing # Запускается только при push с меткой [retrain] в сообщении коммита или при ручном запуске (с опцией run_training) train-and-publish: @@ -262,7 +262,7 @@ jobs: disable_notification: true message: | *Пайплайн прошёл успешно* - Репо: `${{ github.repository }}` + Репозиторий: `${{ github.repository }}` Ветка: `${{ github.ref_name }}` [Открыть run](${{ github.server_url }}/${{ github.repository }}/actions/runs/${{ github.run_id }}) @@ -280,6 +280,6 @@ jobs: format: markdown message: | *Пайплайн упал* - Репо: `${{ github.repository }}` + Репозиторий: `${{ github.repository }}` Ветка: `${{ github.ref_name }}` [Открыть run](${{ github.server_url }}/${{ github.repository }}/actions/runs/${{ github.run_id }}) diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml new file mode 100644 index 0000000..335ecd4 --- /dev/null +++ b/.github/workflows/publish.yml @@ -0,0 +1,69 @@ +# Публикация образа в GHCR после успешного прохождения основного пайплайна на main + +name: Publish Docker image + +on: + workflow_run: + workflows: ["ML Pipeline - Train and Publish"] + types: [completed] + branches: [main] + +permissions: + contents: read + packages: write + +jobs: + publish: + if: github.event.workflow_run.conclusion == 'success' + name: Build and push to GHCR + runs-on: ubuntu-latest + steps: + - name: Checkout + uses: actions/checkout@v4 + with: + ref: ${{ github.event.workflow_run.head_sha }} + + - name: Set up Docker Buildx + uses: docker/setup-buildx-action@v3 + + - name: Login to GHCR + uses: docker/login-action@v3 + with: + registry: ghcr.io + username: ${{ github.actor }} + password: ${{ secrets.GITHUB_TOKEN }} + + - name: Set image name + run: echo "IMAGE_BASE=ghcr.io/${GITHUB_REPOSITORY_OWNER,,}/cvc-api" >> $GITHUB_ENV + + - name: Build and push + uses: docker/build-push-action@v6 + with: + context: . + file: ./Dockerfile + platforms: linux/amd64,linux/arm64 + tags: | + ${{ env.IMAGE_BASE }}:main + ${{ env.IMAGE_BASE }}:${{ github.event.workflow_run.head_sha }} + push: true + + notify-telegram-on-publish: + name: Notify Telegram on image published + if: always() && needs.publish.result == 'success' + needs: [publish] + runs-on: ubuntu-latest + steps: + - name: Send Telegram notification (silent) + uses: appleboy/telegram-action@v1.0.1 + continue-on-error: true + with: + to: ${{ secrets.TELEGRAM_TO }} + token: ${{ secrets.TELEGRAM_TOKEN }} + format: markdown + disable_notification: true + message: | + *Образ CVC успешно опубликован в GHCR* + Репозиторий: `${{ github.repository }}` + Ветка: `${{ github.event.workflow_run.head_branch }}` + Образ: `ghcr.io/${{ github.repository_owner }}/cvc-api:main` + [Открыть run](${{ github.server_url }}/${{ github.repository }}/actions/runs/${{ github.run_id }}) diff --git a/Dockerfile b/Dockerfile index 359c410..d9260de 100644 --- a/Dockerfile +++ b/Dockerfile @@ -14,7 +14,7 @@ RUN pip install --root-user-action=ignore --upgrade pip setuptools wheel COPY requirements-docker.txt . RUN pip install --root-user-action=ignore -r requirements-docker.txt -COPY commands_classifier/ ./commands_classifier/ +COPY app/ ./app/ COPY config.yaml . COPY pytest.ini . COPY data/ ./data/ @@ -24,4 +24,4 @@ RUN mkdir -p models checkpoints EXPOSE 20001 -CMD ["python", "-m", "commands_classifier.cli", "serve", "--host", "0.0.0.0", "--port", "20001"] +CMD ["python", "-m", "app.main", "--host", "0.0.0.0", "--port", "20001"] diff --git a/README.md b/README.md index 8322bc6..376b795 100644 --- a/README.md +++ b/README.md @@ -1,12 +1,12 @@ # CVC - Classification of Voice Commands -[![ML Pipeline](https://github.com/ShiWarai/CVC/actions/workflows/deploy.yml/badge.svg)](https://github.com/ShiWarai/CVC/actions/workflows/deploy.yml) +[![ML Pipeline](https://github.com/ShiWarai/CVC/actions/workflows/deploy.yml/badge.svg?branch=main)](https://github.com/ShiWarai/CVC/actions/workflows/deploy.yml) [![License: MIT](https://img.shields.io/github/license/ShiWarai/CVC)](https://opensource.org/licenses/MIT) ![Python Version](https://img.shields.io/badge/python-3.10%20%7C%203.11%20%7C%203.12-blue) ![Docker Ready](https://img.shields.io/badge/docker-ready-blue?logo=docker) [![CVC-Panda on Hugging Face](https://img.shields.io/badge/%F0%9F%A4%97%20CVC--Panda-Model-yellow)](https://huggingface.co/ShiWarai/CVC-Panda) -Мини-сервис для классификации голосовых команд (SetFit). Обучает модель на малом датасете и классифицирует текстовые команды. Создан для использования в проекте навыка для Sber Salute. +Мини-сервис для классификации голосовых команд (SetFit). Обучает модель на малом датасете и классифицирует текстовые команды. Создан для использования в проекте навыка для команд роботу-собаке. ## Стек технологий @@ -29,7 +29,7 @@ | [Использование](#использование) | CLI, Python-клиент, библиотека | | [Конфигурация и API](#конфигурация-и-api) | config.yaml, эндпоинты | | [Данные](#данные) | Формат датасета, параметры обучения | -| [Разработка](#разработка) | Тесты, линт, структура проекта | +| [Разработка](#разработка) | Тесты, линт, архитектура, структура проекта | | [CI/CD](#cicd) | Пайплайн и ссылка на настройку | | [Лицензия](#лицензия) | MIT | @@ -73,10 +73,10 @@ HF_REPO_ID=your-username/model-name ### Локальный запуск ```bash -python -m commands_classifier.cli serve +python -m app.main ``` -Опции: `--host`, `--port`, `--config`. БД создаётся при первом запуске, данные из `data/` или CSV из `config.yaml`. +Опции: `--host`, `--port`, `--config`, `--reload`. БД создаётся при первом запуске, данные из `data/` или CSV из `config.yaml`. ## Использование @@ -85,19 +85,19 @@ python -m commands_classifier.cli serve После запуска сервера (Docker или локально): ```bash -python -m commands_classifier.client predict --text "равняйся" [--show-confidence] -python -m commands_classifier.client predict --file commands.txt -python -m commands_classifier.client train [--batch-size 32 --iterations 30] -python -m commands_classifier.client train-status -python -m commands_classifier.client examples list -python -m commands_classifier.client examples add --text "команда" --command "label" -python -m commands_classifier.client examples delete --id 1 -python -m commands_classifier.client health -python -m commands_classifier.client metrics -python -m commands_classifier.client reset -python -m commands_classifier.client load-from-hf [--repo-id "username/model-name"] -python -m commands_classifier.client load-from-hf-status -python -m commands_classifier.client command-feedback # репорт «исправить команду» из RDS-2P-Salute +python -m app.client predict --text "равняйся" [--show-confidence] +python -m app.client predict --file commands.txt +python -m app.client train [--batch-size 32 --iterations 30] +python -m app.client train-status +python -m app.client examples list +python -m app.client examples add --text "команда" --command "label" +python -m app.client examples delete --id 1 +python -m app.client health +python -m app.client metrics +python -m app.client reset +python -m app.client load-from-hf [--repo-id "username/model-name"] +python -m app.client load-from-hf-status +python -m app.client command-feedback # репорт «исправить команду» из RDS-2P-Salute ``` По умолчанию клиент подключается к `http://localhost:20001` (флаг `--url` для другого адреса). @@ -185,16 +185,35 @@ docker compose -f docker-compose.yml build cvc-api docker compose -f docker-compose.yml -f docker-compose.dev.yml build cvc-dev docker compose -f docker-compose.yml -f docker-compose.dev.yml run --rm cvc-dev ruff check . -docker compose -f docker-compose.yml -f docker-compose.dev.yml run --rm cvc-dev pytest tests/ -v --tb=short --cov=commands_classifier --cov-report=term-missing +docker compose -f docker-compose.yml -f docker-compose.dev.yml run --rm cvc-dev pytest tests/ -v --tb=short --cov=app --cov-report=term-missing ``` +### Архитектура + +Проект построен по принципам чистой архитектуры (слои не зависят от деталей доставки и инфраструктуры). + +| Слой | Назначение | +|------|------------| +| **domain** | Сущности (`Example`, `PredictionResult`, `TrainingStatus`), порты (`IClassifier`, `IExampleRepository`), утилиты (`text_utils`). Без внешних зависимостей. | +| **application** | Сценарии (use cases): предсказание (`PredictUseCase`), работа с примерами (`ExamplesUseCase`). Получают зависимости через конструктор. | +| **adapters** | Реализации портов: **persistence** — SQLite-репозиторий примеров; **ml** — SetFit-классификатор и retry для HF; **data_loading** — загрузка датасета из CSV/JSON. | +| **api** | FastAPI-приложение, роуты, глобальное состояние (state). В `init_app()` собираются use cases и адаптеры (composition root). | + +Точка входа сервера: `main.py` → `app.api.server`; CLI к API: `client.py`. + ### Структура проекта ``` CVC/ ├── config.yaml ├── requirements-docker.txt | requirements-cuda.txt | requirements-rocm.txt -├── commands_classifier/ # Код: model, dataset, db, cli, client, api/ +├── app/ # Точка входа: python -m app.main +│ ├── main.py # Запуск сервера +│ ├── domain/ # Сущности, порты, text_utils +│ ├── application/ # Use cases +│ ├── adapters/ # persistence (SQLite), ml (SetFit), data_loading +│ ├── api/ # FastAPI, роуты, state +│ └── client.py # HTTP-клиент и библиотека ├── data/ # CSV/JSON для миграции ├── models/ # Сохранённые модели ├── db/ # SQLite (training_data.db) diff --git a/commands_classifier/__init__.py b/app/__init__.py similarity index 100% rename from commands_classifier/__init__.py rename to app/__init__.py diff --git a/app/adapters/__init__.py b/app/adapters/__init__.py new file mode 100644 index 0000000..b88e852 --- /dev/null +++ b/app/adapters/__init__.py @@ -0,0 +1,38 @@ +"""Адаптеры: реализация портов domain (persistence, ml, загрузка данных).""" + +from app.adapters.data_loading import load_dataset +from app.adapters.ml import CommandsClassifier, retry_hf +from app.adapters.persistence import ( + SqliteExampleRepository, + add_example, + check_connection, + count_examples, + delete_example, + get_all_examples, + get_example_by_id, + get_examples_for_training, + get_trained_examples_by_labels, + get_training_stats, + init_db, + mark_examples_as_trained, + reset_training_status, +) + +__all__ = [ + "SqliteExampleRepository", + "init_db", + "add_example", + "get_all_examples", + "get_example_by_id", + "delete_example", + "count_examples", + "get_examples_for_training", + "get_trained_examples_by_labels", + "mark_examples_as_trained", + "get_training_stats", + "reset_training_status", + "check_connection", + "CommandsClassifier", + "retry_hf", + "load_dataset", +] diff --git a/app/adapters/data_loading/__init__.py b/app/adapters/data_loading/__init__.py new file mode 100644 index 0000000..d6d8096 --- /dev/null +++ b/app/adapters/data_loading/__init__.py @@ -0,0 +1,5 @@ +"""Загрузка датасетов с диска (CSV/JSON).""" + +from app.adapters.data_loading.dataset import load_dataset + +__all__ = ["load_dataset"] diff --git a/commands_classifier/dataset.py b/app/adapters/data_loading/dataset.py similarity index 85% rename from commands_classifier/dataset.py rename to app/adapters/data_loading/dataset.py index 1e2ad73..e199dd0 100644 --- a/commands_classifier/dataset.py +++ b/app/adapters/data_loading/dataset.py @@ -1,4 +1,4 @@ -"""Утилиты для загрузки и подготовки датасетов.""" +"""Загрузка датасета из CSV или JSON файла.""" import json from pathlib import Path @@ -21,30 +21,22 @@ def load_dataset(dataset_path: str) -> Tuple[List[str], List[str]]: ValueError: Если формат файла не поддерживается """ path = Path(dataset_path) - if not path.exists(): raise FileNotFoundError(f"Файл датасета не найден: {dataset_path}") if path.suffix.lower() == ".csv": df = pd.read_csv(dataset_path) - - # Проверяем наличие нужных колонок if "text" not in df.columns or "command" not in df.columns: raise ValueError( "CSV файл должен содержать колонки 'text' и 'command'. " f"Найдены колонки: {list(df.columns)}" ) - texts = df["text"].astype(str).tolist() labels = df["command"].astype(str).tolist() elif path.suffix.lower() == ".json": with open(dataset_path, "r", encoding="utf-8") as f: data = json.load(f) - - # Поддерживаем два формата JSON: - # 1. Список объектов: [{"text": "...", "command": "..."}, ...] - # 2. Объект с ключами: {"texts": [...], "commands": [...]} if isinstance(data, list): texts = [item["text"] for item in data] labels = [item["command"] for item in data] diff --git a/app/adapters/ml/__init__.py b/app/adapters/ml/__init__.py new file mode 100644 index 0000000..d2b3f3f --- /dev/null +++ b/app/adapters/ml/__init__.py @@ -0,0 +1,6 @@ +"""ML-адаптер: SetFit-классификатор и retry для Hugging Face.""" + +from app.adapters.ml.hf_retry import retry_hf +from app.adapters.ml.setfit_classifier import CommandsClassifier + +__all__ = ["retry_hf", "CommandsClassifier"] diff --git a/commands_classifier/hf_retry.py b/app/adapters/ml/hf_retry.py similarity index 90% rename from commands_classifier/hf_retry.py rename to app/adapters/ml/hf_retry.py index 887743b..3c191bf 100644 --- a/commands_classifier/hf_retry.py +++ b/app/adapters/ml/hf_retry.py @@ -5,7 +5,6 @@ T = TypeVar("T") -# Задержки в секундах: 1, 2, 4 DEFAULT_BACKOFF = (1.0, 2.0, 4.0) @@ -20,7 +19,7 @@ def retry_hf( Args: fn: Безаргументный callable (например, lambda: from_pretrained(...)). max_retries: Максимальное число попыток (включая первую). - backoff: Кортеж задержек в секундах между попытками (длина должна быть >= max_retries - 1). + backoff: Кортеж задержек в секундах между попытками. Returns: Результат вызова fn(). diff --git a/app/adapters/ml/setfit_classifier.py b/app/adapters/ml/setfit_classifier.py new file mode 100644 index 0000000..94fb112 --- /dev/null +++ b/app/adapters/ml/setfit_classifier.py @@ -0,0 +1,208 @@ +"""Классификатор команд на основе SetFit. Реализует IClassifier (domain.ports).""" + +import os +from pathlib import Path +from typing import List, Optional, Tuple + +import numpy as np + +try: + import torch + if not hasattr(torch.distributed, "is_initialized"): + def _is_initialized(): + return False + torch.distributed.is_initialized = _is_initialized +except ImportError: + pass + +from datasets import Dataset +from setfit import SetFitModel, SetFitTrainer + +from app.adapters.ml.hf_retry import retry_hf + + +def _get_hf_token() -> Optional[str]: + return os.getenv("HF_TOKEN") + + +class CommandsClassifier: + """Классификатор команд на основе SetFit для few-shot learning. Реализует порт IClassifier.""" + + def __init__( + self, model_name: str, confidence_threshold: float = 0.5, cache_dir: Optional[str] = None + ): + self.model_name = model_name + self.model: Optional[SetFitModel] = None + self.is_trained = False + self.confidence_threshold = float(confidence_threshold) + self.cache_dir = cache_dir + + def train( + self, + texts: List[str], + labels: List[str], + num_iterations: int = 20, + num_epochs: int = 1, + batch_size: int = 16, + learning_rate: float = 2e-5, + device: Optional[str] = None, + ) -> None: + if len(texts) != len(labels): + raise ValueError( + f"Количество текстов ({len(texts)}) не совпадает с количеством меток ({len(labels)})" + ) + import torch + device = device or ("cuda" if torch.cuda.is_available() else "cpu") + cache_dir_path = None + if self.cache_dir: + cache_dir_path = Path(self.cache_dir) + cache_dir_path.mkdir(parents=True, exist_ok=True) + cache_dir_path = str(cache_dir_path) + hf_token = _get_hf_token() + + def _load_base_model(): + try: + return SetFitModel.from_pretrained( + self.model_name, + cache_dir=cache_dir_path, + use_safetensors=True, + token=hf_token, + ) + except Exception: + return SetFitModel.from_pretrained( + self.model_name, + cache_dir=cache_dir_path, + token=hf_token, + ) + + self.model = retry_hf(_load_base_model) + self.model = self.model.to(device) + train_dataset = Dataset.from_dict({"text": texts, "label": labels}) + learning_rate_float = float(learning_rate) + trainer = SetFitTrainer( + model=self.model, + train_dataset=train_dataset, + num_iterations=num_iterations, + num_epochs=num_epochs, + batch_size=batch_size, + learning_rate=learning_rate_float, + column_mapping={"text": "text", "label": "label"}, + ) + trainer.train() + self.is_trained = True + + def predict(self, text: str, return_confidence: bool = False) -> str | Tuple[str, float]: + if not self.is_trained or self.model is None: + raise ValueError("Модель не обучена. Сначала вызовите метод train().") + predictions, confidences = self._predict_with_confidence([text]) + command = predictions[0] + confidence = confidences[0] + if confidence < self.confidence_threshold: + command = "unknown" + if return_confidence: + return command, confidence + return command + + def _predict_with_confidence(self, texts: List[str]) -> Tuple[List[str], List[float]]: + probs = self.model.predict_proba(texts) + preds = self.model.predict(texts) + predictions = [] + confidences = [] + if hasattr(probs, "tolist"): + probs = probs.tolist() + if hasattr(preds, "tolist"): + preds = preds.tolist() + else: + preds = list(preds) + for i, prob in enumerate(probs): + if isinstance(prob, (list, np.ndarray)): + max_idx = np.argmax(prob) + max_prob = float(prob[max_idx]) + else: + max_prob = float(prob) + predictions.append(str(preds[i])) + confidences.append(float(max_prob)) + return predictions, confidences + + def predict_batch( + self, texts: List[str], return_confidence: bool = False + ) -> List[str] | Tuple[List[str], List[float]]: + if not self.is_trained or self.model is None: + raise ValueError("Модель не обучена. Сначала вызовите метод train().") + predictions, confidences = self._predict_with_confidence(texts) + commands = [] + for pred, conf in zip(predictions, confidences): + if conf < self.confidence_threshold: + commands.append("unknown") + else: + commands.append(pred) + if return_confidence: + return commands, confidences + return commands + + def save(self, model_path: str): + if not self.is_trained or self.model is None: + raise ValueError("Модель не обучена. Нечего сохранять.") + import shutil + import tempfile + path = Path(model_path) + path.parent.mkdir(parents=True, exist_ok=True) + with tempfile.TemporaryDirectory() as temp_dir: + temp_path = Path(temp_dir) / path.name + self.model.save_pretrained(str(temp_path)) + if path.exists(): + shutil.rmtree(path) + shutil.move(str(temp_path), str(path)) + + def load(self, model_path: str, confidence_threshold: Optional[float] = None): + import warnings + path = Path(model_path) + if not path.exists(): + raise FileNotFoundError(f"Модель не найдена: {model_path}") + hf_token = _get_hf_token() + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", message=".*mistral.*regex.*", category=UserWarning) + try: + self.model = SetFitModel.from_pretrained( + str(path), use_safetensors=True, token=hf_token + ) + except Exception: + self.model = SetFitModel.from_pretrained(str(path), token=hf_token) + self.is_trained = True + if confidence_threshold is not None: + self.confidence_threshold = float(confidence_threshold) + + def get_embeddings(self, texts: List[str]) -> List[List[float]]: + hf_token = _get_hf_token() + if self.model is None: + def _load_base(): + try: + return SetFitModel.from_pretrained( + self.model_name, + use_safetensors=True, + token=hf_token, + ) + except Exception: + return SetFitModel.from_pretrained(self.model_name, token=hf_token) + self.model = retry_hf(_load_base) + if hasattr(self.model, "model_body"): + embedding_model = self.model.model_body + elif hasattr(self.model, "model"): + embedding_model = self.model.model + else: + embedding_model = self.model + if hasattr(embedding_model, "encode"): + embeddings = embedding_model.encode(texts, convert_to_numpy=True) + else: + from sentence_transformers import SentenceTransformer + base_model = SentenceTransformer(self.model_name, token=hf_token) + embeddings = base_model.encode(texts, convert_to_numpy=True) + if hasattr(embeddings, "tolist"): + embeddings = embeddings.tolist() + result = [] + for emb in embeddings: + if isinstance(emb, (list, np.ndarray)): + result.append([float(x) for x in emb]) + else: + result.append([float(emb)]) + return result diff --git a/app/adapters/persistence/__init__.py b/app/adapters/persistence/__init__.py new file mode 100644 index 0000000..2546341 --- /dev/null +++ b/app/adapters/persistence/__init__.py @@ -0,0 +1,37 @@ +"""Persistence-адаптер: SQLite-репозиторий примеров.""" + +from app.adapters.persistence.sqlite_repository import ( + SqliteExampleRepository, + _default_repo, + _normalize_db_path, + add_example, + check_connection, + count_examples, + delete_example, + get_all_examples, + get_example_by_id, + get_examples_for_training, + get_trained_examples_by_labels, + get_training_stats, + init_db, + mark_examples_as_trained, + reset_training_status, +) + +__all__ = [ + "SqliteExampleRepository", + "_default_repo", + "_normalize_db_path", + "init_db", + "add_example", + "get_all_examples", + "get_example_by_id", + "delete_example", + "count_examples", + "get_examples_for_training", + "get_trained_examples_by_labels", + "mark_examples_as_trained", + "get_training_stats", + "reset_training_status", + "check_connection", +] diff --git a/app/adapters/persistence/sqlite_repository.py b/app/adapters/persistence/sqlite_repository.py new file mode 100644 index 0000000..d109ffd --- /dev/null +++ b/app/adapters/persistence/sqlite_repository.py @@ -0,0 +1,271 @@ +"""Реализация IExampleRepository для SQLite.""" + +import sqlite3 +from pathlib import Path +from typing import Any, Dict, List, Optional, Tuple + +import pandas as pd + +from app.domain.text_utils import remove_punctuation + + +def _normalize_db_path(db_path: str) -> str: + path = Path(db_path) + if path.exists() and path.is_dir(): + try: + if not any(path.iterdir()): + path.rmdir() + return db_path + return str(path / "training_data.db") + except OSError: + return str(path / "training_data.db") + return db_path + + +def check_connection(db_path: str) -> bool: + path = _normalize_db_path(db_path) + try: + conn = sqlite3.connect(path, timeout=2.0) + conn.execute("SELECT 1") + conn.close() + return True + except Exception: + return False + + +def _example_exists(cursor: sqlite3.Cursor, text: str, command: str) -> bool: + cursor.execute("SELECT COUNT(*) FROM examples WHERE text = ? AND command = ?", (text, command)) + return cursor.fetchone()[0] > 0 + + +class SqliteExampleRepository: + """Реализация IExampleRepository для SQLite.""" + + def init(self, db_path: str, csv_path: Optional[str] = None) -> None: + db_path = _normalize_db_path(db_path) + path = Path(db_path) + path.parent.mkdir(parents=True, exist_ok=True) + conn = None + try: + conn = sqlite3.connect(db_path) + cursor = conn.cursor() + cursor.execute(""" + CREATE TABLE IF NOT EXISTS examples ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + text TEXT NOT NULL, + command TEXT NOT NULL, + is_trained INTEGER DEFAULT 0 + ) + """) + cursor.execute("PRAGMA table_info(examples)") + columns = [column[1] for column in cursor.fetchall()] + if "is_trained" not in columns: + cursor.execute("ALTER TABLE examples ADD COLUMN is_trained INTEGER DEFAULT 0") + cursor.execute("UPDATE examples SET is_trained = 0 WHERE is_trained IS NULL") + conn.commit() + if csv_path: + csv_path_obj = Path(csv_path) + if csv_path_obj.exists(): + csv_files = list(csv_path_obj.glob("*.csv")) if csv_path_obj.is_dir() else ( + [csv_path_obj] if csv_path_obj.suffix.lower() == ".csv" else [] + ) + if not csv_files and csv_path_obj.is_dir(): + print(f"В директории {csv_path} не найдено CSV файлов") + for csv_file in csv_files: + try: + df = pd.read_csv(csv_file) + if "text" in df.columns and "command" in df.columns: + for _, row in df.iterrows(): + cleaned_text = remove_punctuation(str(row["text"])) + command = str(row["command"]) + if not _example_exists(cursor, cleaned_text, command): + cursor.execute( + "INSERT INTO examples (text, command, is_trained) VALUES (?, ?, 0)", + (cleaned_text, command), + ) + conn.commit() + except Exception as e: + print(f"Ошибка при синхронизации {csv_file.name}: {e}") + except sqlite3.OperationalError as e: + raise RuntimeError( + f"Не удалось создать/открыть базу данных по пути: {db_path}\nОшибка: {e}" + ) from e + finally: + if conn: + conn.close() + + def get_all(self, db_path: str) -> List[Tuple[int, str, str]]: + db_path = _normalize_db_path(db_path) + conn = sqlite3.connect(db_path) + cursor = conn.cursor() + cursor.execute("SELECT id, text, command FROM examples ORDER BY id") + results = cursor.fetchall() + conn.close() + return results + + def get_by_id(self, db_path: str, example_id: int) -> Optional[Tuple[int, str, str]]: + db_path = _normalize_db_path(db_path) + conn = sqlite3.connect(db_path) + cursor = conn.cursor() + cursor.execute("SELECT id, text, command FROM examples WHERE id = ?", (example_id,)) + result = cursor.fetchone() + conn.close() + return result + + def add(self, db_path: str, text: str, command: str) -> int: + db_path = _normalize_db_path(db_path) + conn = sqlite3.connect(db_path) + cursor = conn.cursor() + cursor.execute( + "INSERT INTO examples (text, command, is_trained) VALUES (?, ?, 0)", (text, command) + ) + example_id = cursor.lastrowid + conn.commit() + conn.close() + return example_id + + def delete(self, db_path: str, example_id: int) -> bool: + db_path = _normalize_db_path(db_path) + conn = sqlite3.connect(db_path) + cursor = conn.cursor() + cursor.execute("DELETE FROM examples WHERE id = ?", (example_id,)) + deleted = cursor.rowcount > 0 + conn.commit() + conn.close() + return deleted + + def get_examples_for_training(self, db_path: str) -> Tuple[List[str], List[str], List[int]]: + db_path = _normalize_db_path(db_path) + conn = sqlite3.connect(db_path) + cursor = conn.cursor() + cursor.execute("SELECT id, text, command FROM examples WHERE is_trained = 0 ORDER BY id") + examples = cursor.fetchall() + conn.close() + return [ex[1] for ex in examples], [ex[2] for ex in examples], [ex[0] for ex in examples] + + def get_trained_examples_by_labels( + self, db_path: str, labels: List[str], limit_per_label: int + ) -> Tuple[List[str], List[str], List[int]]: + if not labels: + return [], [], [] + db_path = _normalize_db_path(db_path) + conn = sqlite3.connect(db_path) + cursor = conn.cursor() + texts, result_labels, ids = [], [], [] + for label in labels: + cursor.execute( + "SELECT id, text, command FROM examples WHERE command = ? AND is_trained = 1 ORDER BY id LIMIT ?", + (label, limit_per_label), + ) + for ex in cursor.fetchall(): + ids.append(ex[0]) + texts.append(ex[1]) + result_labels.append(ex[2]) + conn.close() + return texts, result_labels, ids + + def mark_as_trained(self, db_path: str, example_ids: List[int]) -> None: + if not example_ids: + return + validated_ids = [int(ex_id) for ex_id in example_ids] + db_path = _normalize_db_path(db_path) + conn = sqlite3.connect(db_path) + cursor = conn.cursor() + if len(validated_ids) > 10000: + conn.close() + raise ValueError("Слишком много ID для одной операции (максимум 10000)") + placeholders = ",".join("?" * len(validated_ids)) + cursor.execute(f"UPDATE examples SET is_trained = 1 WHERE id IN ({placeholders})", validated_ids) + conn.commit() + conn.close() + + def count(self, db_path: str) -> int: + db_path = _normalize_db_path(db_path) + conn = sqlite3.connect(db_path) + cursor = conn.cursor() + cursor.execute("SELECT COUNT(*) FROM examples") + count = cursor.fetchone()[0] + conn.close() + return count + + def get_training_stats(self, db_path: str) -> Dict[str, Any]: + db_path = _normalize_db_path(db_path) + conn = sqlite3.connect(db_path) + cursor = conn.cursor() + cursor.execute("SELECT COUNT(*) FROM examples") + total = cursor.fetchone()[0] + cursor.execute("SELECT COUNT(*) FROM examples WHERE is_trained = 1") + trained = cursor.fetchone()[0] + cursor.execute("SELECT COUNT(*) FROM examples WHERE is_trained = 0") + untrained = cursor.fetchone()[0] + conn.close() + return {"total": total, "trained": trained, "untrained": untrained} + + def reset_training_status(self, db_path: str) -> int: + db_path = _normalize_db_path(db_path) + conn = sqlite3.connect(db_path) + cursor = conn.cursor() + cursor.execute("UPDATE examples SET is_trained = 0 WHERE is_trained = 1") + reset_count = cursor.rowcount + conn.commit() + conn.close() + return reset_count + + def check_connection(self, db_path: str) -> bool: + path = _normalize_db_path(db_path) + try: + conn = sqlite3.connect(path, timeout=2.0) + conn.execute("SELECT 1") + conn.close() + return True + except Exception: + return False + + +_default_repo = SqliteExampleRepository() + + +def init_db(db_path: str, csv_path: Optional[str] = None) -> None: + _default_repo.init(db_path, csv_path) + + +def get_all_examples(db_path: str) -> List[Tuple[int, str, str]]: + return _default_repo.get_all(db_path) + + +def get_examples_for_training(db_path: str) -> Tuple[List[str], List[str], List[int]]: + return _default_repo.get_examples_for_training(db_path) + + +def get_trained_examples_by_labels( + db_path: str, labels: List[str], limit_per_label: int +) -> Tuple[List[str], List[str], List[int]]: + return _default_repo.get_trained_examples_by_labels(db_path, labels, limit_per_label) + + +def add_example(db_path: str, text: str, command: str) -> int: + return _default_repo.add(db_path, text, command) + + +def delete_example(db_path: str, example_id: int) -> bool: + return _default_repo.delete(db_path, example_id) + + +def count_examples(db_path: str) -> int: + return _default_repo.count(db_path) + + +def get_example_by_id(db_path: str, example_id: int) -> Optional[Tuple[int, str, str]]: + return _default_repo.get_by_id(db_path, example_id) + + +def mark_examples_as_trained(db_path: str, example_ids: List[int]) -> None: + _default_repo.mark_as_trained(db_path, example_ids) + + +def get_training_stats(db_path: str) -> Dict[str, Any]: + return _default_repo.get_training_stats(db_path) + + +def reset_training_status(db_path: str) -> int: + return _default_repo.reset_training_status(db_path) diff --git a/commands_classifier/api/__init__.py b/app/api/__init__.py similarity index 100% rename from commands_classifier/api/__init__.py rename to app/api/__init__.py diff --git a/app/api/routes/__init__.py b/app/api/routes/__init__.py new file mode 100644 index 0000000..5b27909 --- /dev/null +++ b/app/api/routes/__init__.py @@ -0,0 +1,17 @@ +"""Routes для API сервера CVC.""" + +from app.api.routes.command_feedback import router as command_feedback_router +from app.api.routes.examples import router as examples_router +from app.api.routes.health import router as health_router +from app.api.routes.load_from_hf import router as load_from_hf_router +from app.api.routes.predict import router as predict_router +from app.api.routes.training import router as training_router + +__all__ = [ + "predict_router", + "training_router", + "examples_router", + "load_from_hf_router", + "health_router", + "command_feedback_router", +] diff --git a/commands_classifier/api/routes/command_feedback.py b/app/api/routes/command_feedback.py similarity index 98% rename from commands_classifier/api/routes/command_feedback.py rename to app/api/routes/command_feedback.py index 187a6ae..a6b5b34 100644 --- a/commands_classifier/api/routes/command_feedback.py +++ b/app/api/routes/command_feedback.py @@ -6,7 +6,7 @@ import requests from fastapi import APIRouter, HTTPException -from commands_classifier.api.state import get_config +from app.api.state import get_config logger = logging.getLogger(__name__) diff --git a/commands_classifier/api/routes/examples.py b/app/api/routes/examples.py similarity index 72% rename from commands_classifier/api/routes/examples.py rename to app/api/routes/examples.py index 0e24fce..2281f10 100644 --- a/commands_classifier/api/routes/examples.py +++ b/app/api/routes/examples.py @@ -5,9 +5,9 @@ from fastapi import APIRouter, HTTPException from pydantic import BaseModel, Field, field_validator -from commands_classifier import db -from commands_classifier.api.state import get_config -from commands_classifier.api.utils import remove_punctuation +from app.adapters import persistence as db +from app.api.state import get_config, get_examples_use_case +from app.api.utils import remove_punctuation router = APIRouter(tags=["examples"]) @@ -44,9 +44,12 @@ async def get_examples(): Returns: Список всех примеров """ - config = get_config() - db_path = config["database"]["path"] - examples = db.get_all_examples(db_path) + examples_uc = get_examples_use_case() + if examples_uc is not None: + examples = examples_uc.get_all() + else: + config = get_config() + examples = db.get_all_examples(config["database"]["path"]) return [ExampleResponse(id=ex[0], text=ex[1], command=ex[2]) for ex in examples] @@ -61,13 +64,19 @@ async def add_example(request: ExampleRequest): Returns: Созданный пример с ID """ - config = get_config() - db_path = config["database"]["path"] + examples_uc = get_examples_use_case() try: - # Очищаем знаки препинания из текста перед сохранением + if examples_uc is not None: + cleaned_text = remove_punctuation(request.text) + if len(cleaned_text) == 0: + raise HTTPException( + status_code=400, detail="После очистки строка оказалась пустой" + ) + example_id = examples_uc.add(request.text, request.command) + return ExampleResponse(id=example_id, text=cleaned_text, command=request.command) + config = get_config() + db_path = config["database"]["path"] cleaned_text = remove_punctuation(request.text) - - # Проверяем, что после очистки текст не пустой if len(cleaned_text) == 0: raise HTTPException(status_code=400, detail="Текст после очистки не может быть пустым") @@ -94,9 +103,12 @@ async def delete_example(example_id: int): if example_id <= 0: raise HTTPException(status_code=400, detail="ID примера должен быть положительным числом") - config = get_config() - db_path = config["database"]["path"] - deleted = db.delete_example(db_path, example_id) + examples_uc = get_examples_use_case() + if examples_uc is not None: + deleted = examples_uc.delete(example_id) + else: + config = get_config() + deleted = db.delete_example(config["database"]["path"], example_id) if not deleted: raise HTTPException(status_code=404, detail=f"Пример с ID {example_id} не найден") return {"message": f"Пример {example_id} успешно удален"} @@ -117,9 +129,12 @@ async def get_example(example_id: int): if example_id <= 0: raise HTTPException(status_code=400, detail="ID примера должен быть положительным числом") - config = get_config() - db_path = config["database"]["path"] - example = db.get_example_by_id(db_path, example_id) + examples_uc = get_examples_use_case() + if examples_uc is not None: + example = examples_uc.get_by_id(example_id) + else: + config = get_config() + example = db.get_example_by_id(config["database"]["path"], example_id) if example is None: raise HTTPException(status_code=404, detail=f"Пример с ID {example_id} не найден") return ExampleResponse(id=example[0], text=example[1], command=example[2]) diff --git a/commands_classifier/api/routes/health.py b/app/api/routes/health.py similarity index 93% rename from commands_classifier/api/routes/health.py rename to app/api/routes/health.py index 4dd3bf2..bea10f8 100644 --- a/commands_classifier/api/routes/health.py +++ b/app/api/routes/health.py @@ -2,8 +2,8 @@ from fastapi import APIRouter, Response -from commands_classifier import db -from commands_classifier.api.state import get_classifier, get_config, get_training_manager +from app.adapters import persistence as db +from app.api.state import get_classifier, get_config, get_training_manager router = APIRouter(tags=["health"]) diff --git a/commands_classifier/api/routes/load_from_hf.py b/app/api/routes/load_from_hf.py similarity index 98% rename from commands_classifier/api/routes/load_from_hf.py rename to app/api/routes/load_from_hf.py index 1d0a60e..fa25597 100644 --- a/commands_classifier/api/routes/load_from_hf.py +++ b/app/api/routes/load_from_hf.py @@ -11,8 +11,8 @@ from fastapi import APIRouter, HTTPException from pydantic import BaseModel, Field, field_validator -from commands_classifier.api.state import get_config, get_training_manager, load_model -from commands_classifier.hf_retry import retry_hf +from app.adapters.ml import retry_hf +from app.api.state import get_config, get_training_manager, load_model router = APIRouter(tags=["load_from_hf"]) diff --git a/commands_classifier/api/routes/predict.py b/app/api/routes/predict.py similarity index 97% rename from commands_classifier/api/routes/predict.py rename to app/api/routes/predict.py index c9ceea5..e20ef81 100644 --- a/commands_classifier/api/routes/predict.py +++ b/app/api/routes/predict.py @@ -5,9 +5,9 @@ from fastapi import APIRouter, HTTPException from pydantic import BaseModel, Field, field_validator -from commands_classifier.api.state import get_classifier, get_config -from commands_classifier.api.utils import remove_punctuation -from commands_classifier.model import CommandsClassifier +from app.adapters.ml import CommandsClassifier +from app.api.state import get_classifier, get_config +from app.api.utils import remove_punctuation router = APIRouter(tags=["predict"]) diff --git a/commands_classifier/api/routes/training.py b/app/api/routes/training.py similarity index 97% rename from commands_classifier/api/routes/training.py rename to app/api/routes/training.py index 1e96613..07daee7 100644 --- a/commands_classifier/api/routes/training.py +++ b/app/api/routes/training.py @@ -7,8 +7,8 @@ from fastapi import APIRouter, HTTPException from pydantic import BaseModel, Field -from commands_classifier import db -from commands_classifier.api.state import get_config, get_training_manager, unload_classifier +from app.adapters import persistence as db +from app.api.state import get_config, get_training_manager, unload_classifier router = APIRouter(tags=["training"]) diff --git a/commands_classifier/api/server.py b/app/api/server.py similarity index 86% rename from commands_classifier/api/server.py rename to app/api/server.py index 6e163a1..d9cf92b 100644 --- a/commands_classifier/api/server.py +++ b/app/api/server.py @@ -8,8 +8,8 @@ import yaml from fastapi import FastAPI -from commands_classifier import db -from commands_classifier.api.routes import ( +from app.adapters import persistence as db +from app.api.routes import ( command_feedback_router, examples_router, health_router, @@ -17,14 +17,20 @@ predict_router, training_router, ) -from commands_classifier.api.state import ( +from app.api.state import ( + get_classifier, + get_config, get_default_device, load_model, set_config, set_default_device, + set_examples_use_case, + set_predict_use_case, set_training_manager, ) -from commands_classifier.api.training import TrainingManager +from app.api.training import TrainingManager +from app.application.examples_use_case import ExamplesUseCase +from app.application.predict_use_case import PredictUseCase # Настраиваем логирование logging.basicConfig( @@ -100,7 +106,13 @@ def init_app(): csv_path = config["database"].get("csv_migration_path") db.init_db(db_path, csv_path) - # Инициализируем менеджер обучения с callback для перезагрузки модели + # Сценарии (use cases) для роутов + set_predict_use_case(PredictUseCase(get_classifier)) + set_examples_use_case( + ExamplesUseCase(db._default_repo, lambda: get_config()["database"]["path"]) + ) + + # Менеджер обучения с callback для перезагрузки модели model_path = config["model"]["path"] model_name = config["model"]["name"] confidence_threshold = float(config["model"].get("confidence_threshold", 0.5)) @@ -114,6 +126,7 @@ def init_app(): on_training_complete=load_model, default_device=get_default_device(), cache_dir=cache_dir, + example_repository=db._default_repo, ) set_training_manager(training_manager) diff --git a/commands_classifier/api/state.py b/app/api/state.py similarity index 82% rename from commands_classifier/api/state.py rename to app/api/state.py index a9c3940..7b1b744 100644 --- a/commands_classifier/api/state.py +++ b/app/api/state.py @@ -3,14 +3,16 @@ from pathlib import Path from typing import Any, Dict, Optional -from commands_classifier.api.training import TrainingManager -from commands_classifier.model import CommandsClassifier +from app.adapters.ml import CommandsClassifier +from app.api.training import TrainingManager # Глобальные переменные состояния _classifier: Optional[CommandsClassifier] = None _training_manager: Optional[TrainingManager] = None _config: Dict[str, Any] = {} _default_device: str = "cpu" +_predict_use_case: Optional[Any] = None +_examples_use_case: Optional[Any] = None def get_classifier() -> Optional[CommandsClassifier]: @@ -97,6 +99,28 @@ def set_default_device(device: str) -> None: _default_device = device +def get_predict_use_case() -> Optional[Any]: + """Возвращает сценарий предсказания (PredictUseCase) или None.""" + return _predict_use_case + + +def set_predict_use_case(uc: Optional[Any]) -> None: + """Устанавливает сценарий предсказания.""" + global _predict_use_case + _predict_use_case = uc + + +def get_examples_use_case() -> Optional[Any]: + """Возвращает сценарий примеров (ExamplesUseCase) или None.""" + return _examples_use_case + + +def set_examples_use_case(uc: Optional[Any]) -> None: + """Устанавливает сценарий примеров.""" + global _examples_use_case + _examples_use_case = uc + + def load_model() -> bool: """Загружает модель из файла.""" config = get_config() diff --git a/commands_classifier/api/training.py b/app/api/training.py similarity index 88% rename from commands_classifier/api/training.py rename to app/api/training.py index 54404a4..d146d76 100644 --- a/commands_classifier/api/training.py +++ b/app/api/training.py @@ -4,24 +4,30 @@ import threading import uuid from datetime import datetime -from enum import Enum from pathlib import Path from typing import Any, Callable, Dict, Optional -from commands_classifier import db -from commands_classifier.model import CommandsClassifier +from app.adapters import persistence as db +from app.adapters.ml import CommandsClassifier +from app.domain.entities import TrainingStatus # Настраиваем логгер для обучения -logger = logging.getLogger("commands_classifier.training") +logger = logging.getLogger("app.training") -class TrainingStatus(str, Enum): - """Статусы обучения.""" +def _default_classifier_factory( + model_name: str, confidence_threshold: float, cache_dir: Optional[str] +): + """Фабрика по умолчанию: создаёт CommandsClassifier с заданными параметрами.""" - IDLE = "idle" - RUNNING = "running" - COMPLETED = "completed" - FAILED = "failed" + def factory(): + return CommandsClassifier( + model_name=model_name, + confidence_threshold=confidence_threshold, + cache_dir=cache_dir, + ) + + return factory class TrainingManager: @@ -36,6 +42,8 @@ def __init__( on_training_complete: Optional[Callable[[], None]] = None, default_device: str = "cpu", cache_dir: Optional[str] = None, + example_repository: Optional[Any] = None, + classifier_factory: Optional[Callable[[], Any]] = None, ): """ Инициализирует менеджер обучения. @@ -48,14 +56,18 @@ def __init__( on_training_complete: Callback функция, вызываемая после успешного обучения default_device: Устройство для обучения (определяется автоматически при старте) cache_dir: Путь для кэширования базовой модели (опционально) + example_repository: Реализация IExampleRepository (если None — app.adapters.persistence._default_repo) + classifier_factory: Callable[[], IClassifier] (если None — создаётся CommandsClassifier внутри) """ self.db_path = db_path self.model_path = model_path self.model_name = model_name - self.confidence_threshold = float(confidence_threshold) # Убеждаемся, что это float + self.confidence_threshold = float(confidence_threshold) self.on_training_complete = on_training_complete - self.default_device = default_device # Устройство определяется автоматически при старте + self.default_device = default_device self.cache_dir = cache_dir + self._example_repository = example_repository + self._classifier_factory = classifier_factory self.lock = threading.Lock() self.training_thread: Optional[threading.Thread] = None @@ -124,7 +136,12 @@ def _train_in_background( self.progress = 0.1 # Загружаем данные из БД (только необученные) - texts, labels, example_ids = db.get_examples_for_training(self.db_path) + if self._example_repository is not None: + texts, labels, example_ids = self._example_repository.get_examples_for_training( + self.db_path + ) + else: + texts, labels, example_ids = db.get_examples_for_training(self.db_path) if len(texts) == 0: raise ValueError("Нет необученных данных для обучения в базе данных") @@ -211,11 +228,14 @@ def _train_in_background( # Создаем и обучаем модель threshold_float = float(self.confidence_threshold) - classifier = CommandsClassifier( - model_name=self.model_name, - confidence_threshold=threshold_float, - cache_dir=self.cache_dir, - ) + if self._classifier_factory is not None: + classifier = self._classifier_factory() + else: + classifier = CommandsClassifier( + model_name=self.model_name, + confidence_threshold=threshold_float, + cache_dir=self.cache_dir, + ) # Обновляем прогресс: начало обучения self.progress = 0.3 @@ -291,7 +311,10 @@ def _train_in_background( # Отмечаем использованные строки как обученные if example_ids: - db.mark_examples_as_trained(self.db_path, example_ids) + if self._example_repository is not None: + self._example_repository.mark_as_trained(self.db_path, example_ids) + else: + db.mark_examples_as_trained(self.db_path, example_ids) # Обучение завершено успешно with self.lock: diff --git a/app/api/utils.py b/app/api/utils.py new file mode 100644 index 0000000..3478cca --- /dev/null +++ b/app/api/utils.py @@ -0,0 +1,5 @@ +"""Утилиты для API. Реэкспорт из domain для обратной совместимости.""" + +from app.domain.text_utils import remove_punctuation + +__all__ = ["remove_punctuation"] diff --git a/app/application/__init__.py b/app/application/__init__.py new file mode 100644 index 0000000..bc84281 --- /dev/null +++ b/app/application/__init__.py @@ -0,0 +1,6 @@ +"""Application layer: use cases and training orchestration.""" + +from app.application.examples_use_case import ExamplesUseCase +from app.application.predict_use_case import PredictUseCase + +__all__ = ["ExamplesUseCase", "PredictUseCase"] diff --git a/app/application/examples_use_case.py b/app/application/examples_use_case.py new file mode 100644 index 0000000..3156632 --- /dev/null +++ b/app/application/examples_use_case.py @@ -0,0 +1,37 @@ +"""Сценарий CRUD примеров обучения.""" + +from typing import Callable, List, Optional, Tuple + +from app.domain.ports import IExampleRepository +from app.domain.text_utils import remove_punctuation + + +class ExamplesUseCase: + """Сценарий работы с примерами: get_all, get_by_id, add, delete. Нормализация текста при add.""" + + def __init__( + self, + repository: IExampleRepository, + get_db_path: Callable[[], str], + normalizer: Callable[[str], str] = remove_punctuation, + ): + self._repo = repository + self._get_db_path = get_db_path + self._normalize = normalizer + + def get_all(self) -> List[Tuple[int, str, str]]: + """Возвращает все примеры (id, text, command).""" + return self._repo.get_all(self._get_db_path()) + + def get_by_id(self, example_id: int) -> Optional[Tuple[int, str, str]]: + """Возвращает пример по ID или None.""" + return self._repo.get_by_id(self._get_db_path(), example_id) + + def add(self, text: str, command: str) -> int: + """Добавляет пример после нормализации текста. Возвращает ID.""" + cleaned = self._normalize(text) + return self._repo.add(self._get_db_path(), cleaned, command) + + def delete(self, example_id: int) -> bool: + """Удаляет пример по ID. Возвращает True если удалён.""" + return self._repo.delete(self._get_db_path(), example_id) diff --git a/app/application/predict_use_case.py b/app/application/predict_use_case.py new file mode 100644 index 0000000..2eda2fa --- /dev/null +++ b/app/application/predict_use_case.py @@ -0,0 +1,48 @@ +"""Сценарий предсказания и эмбеддингов.""" + +from typing import Callable, List, Optional, Tuple, Union + +from app.domain.ports import IClassifier +from app.domain.text_utils import remove_punctuation + + +class PredictUseCase: + """Сценарий предсказания: нормализация текста и вызов классификатора.""" + + def __init__( + self, + get_classifier: Callable[[], Optional[IClassifier]], + normalizer: Callable[[str], str] = remove_punctuation, + ): + self._get_classifier = get_classifier + self._normalize = normalizer + + def execute_single( + self, text: str, return_confidence: bool = False + ) -> Union[str, Tuple[str, float]]: + """Классифицирует один текст.""" + classifier = self._get_classifier() + if classifier is None: + return ("unknown", 0.0) if return_confidence else "unknown" + cleaned = self._normalize(text) + return classifier.predict(cleaned, return_confidence=return_confidence) + + def execute_batch( + self, texts: List[str], return_confidence: bool = False + ) -> Union[List[str], Tuple[List[str], List[float]]]: + """Классифицирует список текстов.""" + classifier = self._get_classifier() + if classifier is None: + if return_confidence: + return (["unknown"] * len(texts), [0.0] * len(texts)) + return ["unknown"] * len(texts) + cleaned = [self._normalize(t) for t in texts] + return classifier.predict_batch(cleaned, return_confidence=return_confidence) + + def get_embeddings(self, texts: List[str]) -> List[List[float]]: + """Возвращает эмбеддинги для текстов. При отсутствии классификатора создаётся временный по config (вызывающая сторона).""" + classifier = self._get_classifier() + if classifier is None: + raise ValueError("Классификатор не доступен для эмбеддингов") + cleaned = [self._normalize(t) for t in texts] + return classifier.get_embeddings(cleaned) diff --git a/commands_classifier/client.py b/app/client.py similarity index 100% rename from commands_classifier/client.py rename to app/client.py diff --git a/app/domain/__init__.py b/app/domain/__init__.py new file mode 100644 index 0000000..edd96f7 --- /dev/null +++ b/app/domain/__init__.py @@ -0,0 +1,14 @@ +"""Domain layer: entities, ports, text utilities. No external dependencies.""" + +from app.domain.entities import Example, PredictionResult, TrainingStatus +from app.domain.ports import IClassifier, IExampleRepository +from app.domain.text_utils import remove_punctuation + +__all__ = [ + "Example", + "PredictionResult", + "TrainingStatus", + "IClassifier", + "IExampleRepository", + "remove_punctuation", +] diff --git a/app/domain/entities.py b/app/domain/entities.py new file mode 100644 index 0000000..7e516a9 --- /dev/null +++ b/app/domain/entities.py @@ -0,0 +1,31 @@ +"""Доменные сущности. Без зависимостей от фреймворков.""" + +from dataclasses import dataclass +from enum import Enum + + +class TrainingStatus(str, Enum): + """Статусы обучения.""" + + IDLE = "idle" + RUNNING = "running" + COMPLETED = "completed" + FAILED = "failed" + + +@dataclass(frozen=True) +class Example: + """Пример для обучения: текст команды и метка.""" + + id: int + text: str + command: str + is_trained: bool = False + + +@dataclass(frozen=True) +class PredictionResult: + """Результат предсказания: команда и уверенность.""" + + command: str + confidence: float diff --git a/app/domain/ports.py b/app/domain/ports.py new file mode 100644 index 0000000..91ac853 --- /dev/null +++ b/app/domain/ports.py @@ -0,0 +1,48 @@ +"""Порты (интерфейсы) для адаптеров. Без внешних зависимостей.""" + +from typing import Any, Dict, List, Optional, Protocol, Tuple, Union + + +class IClassifier(Protocol): + """Порт классификатора команд (предсказание, эмбеддинги, обучение, сохранение/загрузка).""" + + def predict( + self, text: str, return_confidence: bool = False + ) -> Union[str, Tuple[str, float]]: ... + def predict_batch( + self, texts: List[str], return_confidence: bool = False + ) -> Union[List[str], Tuple[List[str], List[float]]]: ... + def get_embeddings(self, texts: List[str]) -> List[List[float]]: ... + def train( + self, + texts: List[str], + labels: List[str], + num_iterations: int = 20, + num_epochs: int = 1, + batch_size: int = 16, + learning_rate: float = 2e-5, + device: Optional[str] = None, + ) -> None: ... + def save(self, model_path: str) -> None: ... + def load(self, model_path: str, confidence_threshold: Optional[float] = None) -> None: ... + + +class IExampleRepository(Protocol): + """Порт репозитория примеров (хранение и выборка для обучения).""" + + def init(self, db_path: str, csv_path: Optional[str] = None) -> None: ... + def get_all(self, db_path: str) -> List[Tuple[int, str, str]]: ... + def get_by_id(self, db_path: str, example_id: int) -> Optional[Tuple[int, str, str]]: ... + def add(self, db_path: str, text: str, command: str) -> int: ... + def delete(self, db_path: str, example_id: int) -> bool: ... + def get_examples_for_training( + self, db_path: str + ) -> Tuple[List[str], List[str], List[int]]: ... + def get_trained_examples_by_labels( + self, db_path: str, labels: List[str], limit_per_label: int + ) -> Tuple[List[str], List[str], List[int]]: ... + def mark_as_trained(self, db_path: str, example_ids: List[int]) -> None: ... + def count(self, db_path: str) -> int: ... + def get_training_stats(self, db_path: str) -> Dict[str, Any]: ... + def reset_training_status(self, db_path: str) -> int: ... + def check_connection(self, db_path: str) -> bool: ... diff --git a/commands_classifier/api/utils.py b/app/domain/text_utils.py similarity index 60% rename from commands_classifier/api/utils.py rename to app/domain/text_utils.py index 921ea7d..17a9299 100644 --- a/commands_classifier/api/utils.py +++ b/app/domain/text_utils.py @@ -1,4 +1,4 @@ -"""Утилиты для API.""" +"""Утилиты для работы с текстом. Без внешних зависимостей.""" import re @@ -13,8 +13,6 @@ def remove_punctuation(text: str) -> str: Returns: Текст без знаков препинания """ - # Удаляем все знаки препинания, оставляя только буквы, цифры и пробелы text = re.sub(r"[^\w\s]", "", text) - # Удаляем множественные пробелы и обрезаем text = re.sub(r"\s+", " ", text).strip() return text diff --git a/app/main.py b/app/main.py new file mode 100644 index 0000000..9bce222 --- /dev/null +++ b/app/main.py @@ -0,0 +1,62 @@ +"""Точка входа: запуск API сервера CVC.""" + +import argparse +import sys + + +def main(): + parser = argparse.ArgumentParser( + description="CVC API — сервер классификации голосовых команд", + formatter_class=argparse.RawDescriptionHelpFormatter, + ) + parser.add_argument( + "--host", type=str, default="0.0.0.0", help="Хост (по умолчанию: 0.0.0.0)" + ) + parser.add_argument( + "--port", type=int, default=20001, help="Порт (по умолчанию: 20001)" + ) + parser.add_argument( + "--config", + type=str, + default="config.yaml", + help="Путь к config.yaml (по умолчанию: config.yaml)", + ) + parser.add_argument( + "--reload", + action="store_true", + help="Автоперезагрузка при изменении кода (для разработки)", + ) + args = parser.parse_args() + + try: + import uvicorn + except ImportError: + print("Ошибка: uvicorn не установлен. Установите: pip install uvicorn", file=sys.stderr) + sys.exit(1) + + print(f"Запуск API на {args.host}:{args.port}") + print(f"Конфигурация: {args.config}") + print(f"Документация: http://{args.host}:{args.port}/docs") + + try: + if args.reload: + uvicorn.run( + "app.api.server:app", + host=args.host, + port=args.port, + reload=args.reload, + ) + else: + from app.api.server import app + + uvicorn.run(app, host=args.host, port=args.port, reload=False) + except Exception as e: + print(f"Ошибка при запуске: {e}", file=sys.stderr) + import traceback + + traceback.print_exc(file=sys.stderr) + sys.exit(1) + + +if __name__ == "__main__": + main() diff --git a/ci/train_via_api.py b/ci/train_via_api.py index 1e99ca1..5e1ec14 100755 --- a/ci/train_via_api.py +++ b/ci/train_via_api.py @@ -20,8 +20,8 @@ import yaml -from commands_classifier import db as db_module -from commands_classifier.model import CommandsClassifier +from app.adapters import persistence as db_module +from app.adapters.ml import CommandsClassifier def load_config(config_path: str = "config.yaml") -> dict: @@ -155,7 +155,7 @@ def main(): # Помечаем примеры как обученные print("\nОбновление статуса примеров в БД...") - from commands_classifier.db import mark_examples_as_trained + from app.adapters.persistence import mark_examples_as_trained mark_examples_as_trained(training_db_path, example_ids) print(f"✓ {len(example_ids)} примеров помечено как обученные") diff --git a/commands_classifier/api/routes/__init__.py b/commands_classifier/api/routes/__init__.py deleted file mode 100644 index ed6884e..0000000 --- a/commands_classifier/api/routes/__init__.py +++ /dev/null @@ -1,17 +0,0 @@ -"""Routes для API сервера CVC.""" - -from commands_classifier.api.routes.command_feedback import router as command_feedback_router -from commands_classifier.api.routes.examples import router as examples_router -from commands_classifier.api.routes.health import router as health_router -from commands_classifier.api.routes.load_from_hf import router as load_from_hf_router -from commands_classifier.api.routes.predict import router as predict_router -from commands_classifier.api.routes.training import router as training_router - -__all__ = [ - "predict_router", - "training_router", - "examples_router", - "load_from_hf_router", - "health_router", - "command_feedback_router", -] diff --git a/commands_classifier/cli.py b/commands_classifier/cli.py deleted file mode 100644 index 4db5fc6..0000000 --- a/commands_classifier/cli.py +++ /dev/null @@ -1,82 +0,0 @@ -"""CLI для запуска API сервера.""" - -import argparse -import sys - - -def serve_command(args): - """Команда для запуска API сервера.""" - # Проверяем импорт uvicorn отдельно - try: - import uvicorn - except ImportError: - print("Ошибка: uvicorn не установлен. Установите его: pip install uvicorn", file=sys.stderr) - sys.exit(1) - - print(f"Запуск API сервера на {args.host}:{args.port}") - print(f"Конфигурация: {args.config}") - print(f"Документация API: http://{args.host}:{args.port}/docs") - - try: - # Для reload нужно передавать строку импорта, а не объект - if args.reload: - uvicorn.run( - "commands_classifier.api.server:app", - host=args.host, - port=args.port, - reload=args.reload, - ) - else: - # Без reload можно использовать объект напрямую - from commands_classifier.api.server import app - - uvicorn.run(app, host=args.host, port=args.port, reload=False) - except ImportError as e: - print(f"Ошибка импорта модуля: {e}", file=sys.stderr) - import traceback - - traceback.print_exc(file=sys.stderr) - sys.exit(1) - except Exception as e: - print(f"Ошибка при запуске сервера: {e}", file=sys.stderr) - import traceback - - traceback.print_exc(file=sys.stderr) - sys.exit(1) - - -def main(): - """Главная функция CLI для запуска сервера.""" - parser = argparse.ArgumentParser( - description="CVC API сервер - запуск сервера для классификации голосовых команд", - formatter_class=argparse.RawDescriptionHelpFormatter, - ) - - # Игнорируем команду "serve" если она передана (для обратной совместимости) - if len(sys.argv) > 1 and sys.argv[1] == "serve": - sys.argv.pop(1) - - parser.add_argument( - "--host", type=str, default="0.0.0.0", help="Хост для сервера (по умолчанию: 0.0.0.0)" - ) - parser.add_argument( - "--port", type=int, default=20001, help="Порт для сервера (по умолчанию: 20001)" - ) - parser.add_argument( - "--config", - type=str, - default="config.yaml", - help="Путь к конфигурационному файлу (по умолчанию: config.yaml)", - ) - parser.add_argument( - "--reload", - action="store_true", - help="Включить автоматическую перезагрузку при изменении кода (для разработки)", - ) - - args = parser.parse_args() - serve_command(args) - - -if __name__ == "__main__": - main() diff --git a/commands_classifier/db.py b/commands_classifier/db.py deleted file mode 100644 index 8492f48..0000000 --- a/commands_classifier/db.py +++ /dev/null @@ -1,457 +0,0 @@ -"""Утилиты для работы с базой данных SQLite для хранения обучающих данных.""" - -import re -import sqlite3 -from pathlib import Path -from typing import List, Optional, Tuple - -import pandas as pd - - -def remove_punctuation(text: str) -> str: - """ - Удаляет все знаки препинания из текста. - - Args: - text: Исходный текст - - Returns: - Текст без знаков препинания - """ - # Удаляем все знаки препинания, оставляя только буквы, цифры и пробелы - # Используем регулярное выражение для удаления всех знаков препинания - text = re.sub(r"[^\w\s]", "", text) - # Удаляем множественные пробелы и обрезаем - text = re.sub(r"\s+", " ", text).strip() - return text - - -def _normalize_db_path(db_path: str) -> str: - """ - Нормализует путь к базе данных. - Если путь указывает на директорию, пытается исправить это. - - Args: - db_path: Путь к файлу базы данных - - Returns: - Нормализованный путь к файлу базы данных - """ - path = Path(db_path) - - # Если путь указывает на директорию - if path.exists() and path.is_dir(): - # Пытаемся удалить директорию, если она пустая - try: - if not any(path.iterdir()): - path.rmdir() - # После удаления директории, путь свободен для создания файла - return db_path - else: - # Если директория не пустая, создаем файл внутри неё - return str(path / "training_data.db") - except OSError: - # Если не удалось удалить, создаем файл внутри - return str(path / "training_data.db") - - return db_path - - -def check_connection(db_path: str) -> bool: - """ - Проверяет доступность базы данных. - - Args: - db_path: Путь к файлу базы данных SQLite (или директории с training_data.db) - - Returns: - True если соединение установлено и запрос выполняется, False иначе. - """ - path = _normalize_db_path(db_path) - try: - conn = sqlite3.connect(path, timeout=2.0) - conn.execute("SELECT 1") - conn.close() - return True - except Exception: - return False - - -def _example_exists(cursor: sqlite3.Cursor, text: str, command: str) -> bool: - """ - Проверяет, существует ли пример с указанным text и command в БД. - - Args: - cursor: Курсор базы данных - text: Текст команды - command: Метка команды - - Returns: - True если пример существует, False иначе - """ - cursor.execute("SELECT COUNT(*) FROM examples WHERE text = ? AND command = ?", (text, command)) - return cursor.fetchone()[0] > 0 - - -def init_db(db_path: str, csv_path: Optional[str] = None) -> None: - """ - Инициализирует базу данных и создает таблицу examples. - При каждом запуске проверяет CSV файлы и добавляет отсутствующие строки в БД. - - Args: - db_path: Путь к файлу базы данных SQLite - csv_path: Опциональный путь к CSV файлу или директории с CSV файлами для миграции. - Если указана директория, загружаются все CSV файлы из неё. - При каждом запуске проверяются все CSV файлы и добавляются только новые строки. - """ - # Нормализуем путь к базе данных - db_path = _normalize_db_path(db_path) - path = Path(db_path) - - # Создаем родительскую директорию, если её нет - path.parent.mkdir(parents=True, exist_ok=True) - - # Убеждаемся, что файл может быть создан (проверяем права доступа) - conn = None - try: - conn = sqlite3.connect(db_path) - cursor = conn.cursor() - - # Создаем таблицу examples - cursor.execute(""" - CREATE TABLE IF NOT EXISTS examples ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - text TEXT NOT NULL, - command TEXT NOT NULL, - is_trained INTEGER DEFAULT 0 - ) - """) - - # Миграция: добавляем поле is_trained если его нет в существующей таблице - cursor.execute("PRAGMA table_info(examples)") - columns = [column[1] for column in cursor.fetchall()] - if "is_trained" not in columns: - cursor.execute("ALTER TABLE examples ADD COLUMN is_trained INTEGER DEFAULT 0") - # Устанавливаем is_trained = 0 для всех существующих записей - cursor.execute("UPDATE examples SET is_trained = 0 WHERE is_trained IS NULL") - - conn.commit() - - # Если указан CSV путь, выполняем синхронизацию (проверяем при каждом запуске) - if csv_path: - csv_path_obj = Path(csv_path) - if csv_path_obj.exists(): - csv_files = [] - - # Если это директория, находим все CSV файлы в ней - if csv_path_obj.is_dir(): - csv_files = list(csv_path_obj.glob("*.csv")) - if not csv_files: - print(f"В директории {csv_path} не найдено CSV файлов") - # Если это файл, используем его - elif csv_path_obj.is_file() and csv_path_obj.suffix.lower() == ".csv": - csv_files = [csv_path_obj] - else: - print(f"Путь {csv_path} не является директорией или CSV файлом") - - # Синхронизируем данные из всех найденных CSV файлов - total_added = 0 - total_skipped = 0 - for csv_file in csv_files: - try: - df = pd.read_csv(csv_file) - if "text" in df.columns and "command" in df.columns: - added_count = 0 - skipped_count = 0 - for _, row in df.iterrows(): - # Очищаем знаки препинания из текста перед сохранением - cleaned_text = remove_punctuation(str(row["text"])) - command = str(row["command"]) - - # Проверяем, существует ли уже такая строка - if not _example_exists(cursor, cleaned_text, command): - cursor.execute( - "INSERT INTO examples (text, command, is_trained) VALUES (?, ?, 0)", - (cleaned_text, command), - ) - added_count += 1 - else: - skipped_count += 1 - - conn.commit() - total_added += added_count - total_skipped += skipped_count - except Exception as e: - print(f"Ошибка при синхронизации {csv_file.name}: {e}") - - if total_added > 0: - print( - f"Синхронизация CSV: добавлено {total_added} новых примеров из {len(csv_files)} файл(ов)" - ) - except sqlite3.OperationalError as e: - error_msg = ( - f"Не удалось создать/открыть базу данных по пути: {db_path}\n" - f"Ошибка: {e}\n" - f"Возможные причины:\n" - f" 1. Нет прав на запись в директорию {path.parent}\n" - f" 2. Путь указывает на директорию вместо файла (проблема Docker volume)\n" - f" 3. Директория не существует и не может быть создана" - ) - raise RuntimeError(error_msg) from e - finally: - if conn: - conn.close() - - -def get_all_examples(db_path: str) -> List[Tuple[int, str, str]]: - """ - Получает все примеры из базы данных. - - Args: - db_path: Путь к файлу базы данных - - Returns: - Список кортежей (id, text, command) - """ - db_path = _normalize_db_path(db_path) - conn = sqlite3.connect(db_path) - cursor = conn.cursor() - cursor.execute("SELECT id, text, command FROM examples ORDER BY id") - results = cursor.fetchall() - conn.close() - return results - - -def get_examples_for_training(db_path: str) -> Tuple[List[str], List[str], List[int]]: - """ - Получает необученные примеры в формате для обучения (только text и command). - - Args: - db_path: Путь к файлу базы данных - - Returns: - Кортеж (texts, labels, ids) - списки текстов, команд и ID строк - """ - db_path = _normalize_db_path(db_path) - conn = sqlite3.connect(db_path) - cursor = conn.cursor() - # Получаем только необученные примеры (is_trained = 0) - cursor.execute("SELECT id, text, command FROM examples WHERE is_trained = 0 ORDER BY id") - examples = cursor.fetchall() - conn.close() - - texts = [ex[1] for ex in examples] - labels = [ex[2] for ex in examples] - ids = [ex[0] for ex in examples] - return texts, labels, ids - - -def get_trained_examples_by_labels( - db_path: str, labels: List[str], limit_per_label: int -) -> Tuple[List[str], List[str], List[int]]: - """ - Получает обученные примеры из указанных классов для дополнения недостающих примеров. - - Args: - db_path: Путь к файлу базы данных - labels: Список меток классов, для которых нужно получить примеры - limit_per_label: Максимальное количество примеров на класс - - Returns: - Кортеж (texts, labels, ids) - списки текстов, команд и ID строк - """ - if not labels: - return [], [], [] - - db_path = _normalize_db_path(db_path) - conn = sqlite3.connect(db_path) - cursor = conn.cursor() - - texts = [] - result_labels = [] - ids = [] - - for label in labels: - # Получаем обученные примеры из этого класса (is_trained = 1) - cursor.execute( - "SELECT id, text, command FROM examples WHERE command = ? AND is_trained = 1 ORDER BY id LIMIT ?", - (label, limit_per_label), - ) - examples = cursor.fetchall() - - for ex in examples: - ids.append(ex[0]) - texts.append(ex[1]) - result_labels.append(ex[2]) - - conn.close() - return texts, result_labels, ids - - -def add_example(db_path: str, text: str, command: str) -> int: - """ - Добавляет новый пример в базу данных. - - Args: - db_path: Путь к файлу базы данных - text: Текст команды - command: Метка команды - - Returns: - ID добавленного примера - """ - db_path = _normalize_db_path(db_path) - conn = sqlite3.connect(db_path) - cursor = conn.cursor() - cursor.execute( - "INSERT INTO examples (text, command, is_trained) VALUES (?, ?, 0)", (text, command) - ) - example_id = cursor.lastrowid - conn.commit() - conn.close() - return example_id - - -def delete_example(db_path: str, example_id: int) -> bool: - """ - Удаляет пример по ID. - - Args: - db_path: Путь к файлу базы данных - example_id: ID примера для удаления - - Returns: - True если пример был удален, False если не найден - """ - db_path = _normalize_db_path(db_path) - conn = sqlite3.connect(db_path) - cursor = conn.cursor() - cursor.execute("DELETE FROM examples WHERE id = ?", (example_id,)) - deleted = cursor.rowcount > 0 - conn.commit() - conn.close() - return deleted - - -def count_examples(db_path: str) -> int: - """ - Возвращает количество примеров в базе данных. - - Args: - db_path: Путь к файлу базы данных - - Returns: - Количество примеров - """ - db_path = _normalize_db_path(db_path) - conn = sqlite3.connect(db_path) - cursor = conn.cursor() - cursor.execute("SELECT COUNT(*) FROM examples") - count = cursor.fetchone()[0] - conn.close() - return count - - -def get_example_by_id(db_path: str, example_id: int) -> Optional[Tuple[int, str, str]]: - """ - Получает пример по ID. - - Args: - db_path: Путь к файлу базы данных - example_id: ID примера - - Returns: - Кортеж (id, text, command) или None если не найден - """ - db_path = _normalize_db_path(db_path) - conn = sqlite3.connect(db_path) - cursor = conn.cursor() - cursor.execute("SELECT id, text, command FROM examples WHERE id = ?", (example_id,)) - result = cursor.fetchone() - conn.close() - return result - - -def mark_examples_as_trained(db_path: str, example_ids: List[int]) -> None: - """ - Отмечает примеры как обученные (устанавливает is_trained = 1). - - Args: - db_path: Путь к файлу базы данных - example_ids: Список ID примеров для отметки - """ - if not example_ids: - return - - # Валидация: убеждаемся, что все ID являются целыми числами - try: - validated_ids = [int(ex_id) for ex_id in example_ids] - except (ValueError, TypeError) as e: - raise ValueError(f"Некорректные ID примеров: {e}") - - db_path = _normalize_db_path(db_path) - conn = sqlite3.connect(db_path) - cursor = conn.cursor() - - # Используем параметризованный запрос для безопасности - # Ограничиваем количество ID для предотвращения DoS - if len(validated_ids) > 10000: - raise ValueError("Слишком много ID для одной операции (максимум 10000)") - - placeholders = ",".join("?" * len(validated_ids)) - query = f"UPDATE examples SET is_trained = 1 WHERE id IN ({placeholders})" - cursor.execute(query, validated_ids) - - conn.commit() - conn.close() - - -def get_training_stats(db_path: str) -> dict: - """ - Получает статистику по обученным и необученным примерам. - - Args: - db_path: Путь к файлу базы данных - - Returns: - Словарь со статистикой: total, trained, untrained - """ - db_path = _normalize_db_path(db_path) - conn = sqlite3.connect(db_path) - cursor = conn.cursor() - - cursor.execute("SELECT COUNT(*) FROM examples") - total = cursor.fetchone()[0] - - cursor.execute("SELECT COUNT(*) FROM examples WHERE is_trained = 1") - trained = cursor.fetchone()[0] - - cursor.execute("SELECT COUNT(*) FROM examples WHERE is_trained = 0") - untrained = cursor.fetchone()[0] - - conn.close() - - return {"total": total, "trained": trained, "untrained": untrained} - - -def reset_training_status(db_path: str) -> int: - """ - Сбрасывает статус обучения для всех примеров (устанавливает is_trained = 0). - - Args: - db_path: Путь к файлу базы данных - - Returns: - Количество сброшенных записей - """ - db_path = _normalize_db_path(db_path) - conn = sqlite3.connect(db_path) - cursor = conn.cursor() - - cursor.execute("UPDATE examples SET is_trained = 0 WHERE is_trained = 1") - reset_count = cursor.rowcount - - conn.commit() - conn.close() - - return reset_count diff --git a/commands_classifier/model.py b/commands_classifier/model.py deleted file mode 100644 index 6cf4c2a..0000000 --- a/commands_classifier/model.py +++ /dev/null @@ -1,390 +0,0 @@ -"""Класс для работы с SetFit моделью классификации команд.""" - -import os -from pathlib import Path -from typing import List, Optional, Tuple - -import numpy as np - -# Патч для совместимости с PyTorch ROCm на Windows -# sentence-transformers использует torch.distributed.is_initialized(), -# которая может быть недоступна в ROCm версии PyTorch -try: - import torch - - if not hasattr(torch.distributed, "is_initialized"): - - def _is_initialized(): - return False - - torch.distributed.is_initialized = _is_initialized -except ImportError: - pass - -from datasets import Dataset -from setfit import SetFitModel, SetFitTrainer - -from commands_classifier.hf_retry import retry_hf - - -def _get_hf_token() -> Optional[str]: - """ - Получает токен Hugging Face из переменных окружения. - - Returns: - Токен Hugging Face или None, если не найден - """ - return os.getenv("HF_TOKEN") - - -class CommandsClassifier: - """Классификатор команд на основе SetFit для few-shot learning.""" - - def __init__( - self, model_name: str, confidence_threshold: float = 0.5, cache_dir: Optional[str] = None - ): - """ - Инициализирует классификатор. - - Args: - model_name: Имя предобученной модели - confidence_threshold: Порог уверенности для отбраковки (0.0-1.0). - Если уверенность ниже порога, возвращается "unknown" - cache_dir: Путь для кэширования базовой модели (опционально) - """ - self.model_name = model_name - self.model: Optional[SetFitModel] = None - self.is_trained = False - # Убеждаемся, что confidence_threshold всегда float - self.confidence_threshold = float(confidence_threshold) - self.cache_dir = cache_dir - - def train( - self, - texts: List[str], - labels: List[str], - num_iterations: int = 20, - num_epochs: int = 1, - batch_size: int = 16, - learning_rate: float = 2e-5, # Может быть float или str, будет преобразовано - device: Optional[str] = None, - ) -> None: - """ - Обучает модель на предоставленных данных. - - Args: - texts: Список текстов для обучения - labels: Список меток (команд) для каждого текста - num_iterations: Количество итераций контрастного обучения (используется как num_epochs для body) - num_epochs: Количество эпох для fine-tuning head - batch_size: Размер батча (больше = быстрее, но требует больше памяти) - learning_rate: Скорость обучения - device: Устройство для обучения ('cpu', 'cuda' или None - определяется автоматически) - """ - if len(texts) != len(labels): - raise ValueError( - f"Количество текстов ({len(texts)}) не совпадает с количеством меток ({len(labels)})" - ) - - # Определяем устройство - import torch - - device = device or ("cuda" if torch.cuda.is_available() else "cpu") - - # Создаем модель с кэшированием в указанную директорию - cache_dir_path = None - if self.cache_dir: - cache_dir_path = Path(self.cache_dir) - cache_dir_path.mkdir(parents=True, exist_ok=True) - cache_dir_path = str(cache_dir_path) - - # Получаем токен Hugging Face для доступа к gated моделям - hf_token = _get_hf_token() - - def _load_base_model(): - try: - return SetFitModel.from_pretrained( - self.model_name, - cache_dir=cache_dir_path, - use_safetensors=True, - token=hf_token, - ) - except Exception: - return SetFitModel.from_pretrained( - self.model_name, - cache_dir=cache_dir_path, - token=hf_token, - ) - - self.model = retry_hf(_load_base_model) - - # Перемещаем модель на устройство (SetFitModel автоматически обрабатывает это) - self.model = self.model.to(device) - - # Создаем датасет для обучения - train_dataset = Dataset.from_dict({"text": texts, "label": labels}) - - # Убеждаемся, что learning_rate это float (может прийти как str из config) - learning_rate_float = float(learning_rate) - - # Создаем тренер с параметрами напрямую - # Модель уже перемещена на нужное устройство выше - trainer = SetFitTrainer( - model=self.model, - train_dataset=train_dataset, - num_iterations=num_iterations, - num_epochs=num_epochs, - batch_size=batch_size, - learning_rate=learning_rate_float, - column_mapping={"text": "text", "label": "label"}, - ) - - # Обучаем модель - trainer.train() - - self.is_trained = True - - def predict(self, text: str, return_confidence: bool = False) -> str | Tuple[str, float]: - """ - Классифицирует один текст. - - Args: - text: Текст для классификации - return_confidence: Если True, возвращает (команда, уверенность) - - Returns: - Предсказанная команда или (команда, уверенность) если return_confidence=True - - Raises: - ValueError: Если модель не обучена - """ - if not self.is_trained or self.model is None: - raise ValueError("Модель не обучена. Сначала вызовите метод train().") - - # Получаем предсказания с вероятностями - predictions, confidences = self._predict_with_confidence([text]) - command = predictions[0] - confidence = confidences[0] - - # Применяем порог уверенности - if confidence < self.confidence_threshold: - command = "unknown" - - if return_confidence: - return command, confidence - return command - - def _predict_with_confidence(self, texts: List[str]) -> Tuple[List[str], List[float]]: - """ - Внутренний метод для получения предсказаний с уверенностью. - - Args: - texts: Список текстов для классификации - - Returns: - Кортеж (предсказания, уверенности) - """ - # Получаем вероятности для всех классов - probs = self.model.predict_proba(texts) - - # Получаем предсказания (классы) - preds = self.model.predict(texts) - - # Находим максимальную вероятность для каждого предсказания - predictions = [] - confidences = [] - - # Обрабатываем probs (может быть numpy array или список) - if hasattr(probs, "tolist"): - probs = probs.tolist() - - # Обрабатываем preds (может быть numpy array или список) - if hasattr(preds, "tolist"): - preds = preds.tolist() - else: - preds = list(preds) - - for i, prob in enumerate(probs): - # Находим индекс максимальной вероятности - if isinstance(prob, (list, np.ndarray)): - max_idx = np.argmax(prob) - max_prob = float(prob[max_idx]) - else: - max_prob = float(prob) - - # Используем предсказание модели - predictions.append(str(preds[i])) - # Убеждаемся, что confidence всегда float - confidences.append(float(max_prob)) - - return predictions, confidences - - def predict_batch( - self, texts: List[str], return_confidence: bool = False - ) -> List[str] | Tuple[List[str], List[float]]: - """ - Классифицирует список текстов. - - Args: - texts: Список текстов для классификации - return_confidence: Если True, возвращает (команды, уверенности) - - Returns: - Список предсказанных команд или (команды, уверенности) если return_confidence=True - - Raises: - ValueError: Если модель не обучена - """ - if not self.is_trained or self.model is None: - raise ValueError("Модель не обучена. Сначала вызовите метод train().") - - predictions, confidences = self._predict_with_confidence(texts) - - # Применяем порог уверенности - commands = [] - for pred, conf in zip(predictions, confidences): - if conf < self.confidence_threshold: - commands.append("unknown") - else: - commands.append(pred) - - if return_confidence: - return commands, confidences - return commands - - def save(self, model_path: str): - """ - Сохраняет обученную модель. - - Args: - model_path: Путь для сохранения модели - - Raises: - ValueError: Если модель не обучена - """ - if not self.is_trained or self.model is None: - raise ValueError("Модель не обучена. Нечего сохранять.") - - import shutil - import tempfile - - path = Path(model_path) - path.parent.mkdir(parents=True, exist_ok=True) - - # Сохраняем во временную директорию, чтобы избежать проблем с открытыми файлами - with tempfile.TemporaryDirectory() as temp_dir: - temp_path = Path(temp_dir) / path.name - self.model.save_pretrained(str(temp_path)) - - # Если целевая директория существует, удаляем её - if path.exists(): - shutil.rmtree(path) - - # Перемещаем из временной директории в целевую - shutil.move(str(temp_path), str(path)) - - def load(self, model_path: str, confidence_threshold: Optional[float] = None): - """ - Загружает сохраненную модель. - - Args: - model_path: Путь к сохраненной модели - confidence_threshold: Порог уверенности (если None, используется текущий) - """ - import warnings - - path = Path(model_path) - if not path.exists(): - raise FileNotFoundError(f"Модель не найдена: {model_path}") - - # Получаем токен Hugging Face (может понадобиться для загрузки базовой модели) - hf_token = _get_hf_token() - - # Подавляем предупреждение о токенизаторе Mistral (если модель была обучена на Mistral) - with warnings.catch_warnings(): - warnings.filterwarnings("ignore", message=".*mistral.*regex.*", category=UserWarning) - # Пытаемся загрузить с safetensors, если доступно - try: - self.model = SetFitModel.from_pretrained( - str(path), use_safetensors=True, token=hf_token - ) - except Exception: - # Если safetensors не доступны, пробуем без них - self.model = SetFitModel.from_pretrained(str(path), token=hf_token) - - self.is_trained = True - - if confidence_threshold is not None: - # Убеждаемся, что confidence_threshold всегда float - self.confidence_threshold = float(confidence_threshold) - - def get_embeddings(self, texts: List[str]) -> List[List[float]]: - """ - Получает эмбеддинги для списка текстов. - Использует базовую модель эмбеддингов (без классификатора). - - Args: - texts: Список текстов для получения эмбеддингов - - Returns: - Список эмбеддингов (каждый эмбеддинг - список float) - - Raises: - ValueError: Если модель не инициализирована - """ - # Получаем токен Hugging Face для доступа к gated моделям - hf_token = _get_hf_token() - - # Если модель не загружена, загружаем базовую модель (с retry для HF) - if self.model is None: - - def _load_base(): - try: - return SetFitModel.from_pretrained( - self.model_name, - use_safetensors=True, - token=hf_token, - ) - except Exception: - return SetFitModel.from_pretrained( - self.model_name, - token=hf_token, - ) - - self.model = retry_hf(_load_base) - - # Получаем базовую модель эмбеддингов (sentence-transformers) - # SetFitModel имеет атрибут model_body для доступа к базовой модели - if hasattr(self.model, "model_body"): - embedding_model = self.model.model_body - elif hasattr(self.model, "model"): - # Альтернативный способ доступа - embedding_model = self.model.model - else: - # Если нет доступа к базовой модели, используем весь model - embedding_model = self.model - - # Получаем эмбеддинги через encode (стандартный метод sentence-transformers) - if hasattr(embedding_model, "encode"): - embeddings = embedding_model.encode(texts, convert_to_numpy=True) - else: - # Fallback: если encode недоступен, создаем новую базовую модель - from sentence_transformers import SentenceTransformer - - hf_token = _get_hf_token() - base_model = SentenceTransformer(self.model_name, token=hf_token) - embeddings = base_model.encode(texts, convert_to_numpy=True) - - # Преобразуем в список списков float - if hasattr(embeddings, "tolist"): - embeddings = embeddings.tolist() - - # Убеждаемся, что все элементы - списки float - result = [] - for emb in embeddings: - if isinstance(emb, (list, np.ndarray)): - result.append([float(x) for x in emb]) - else: - result.append([float(emb)]) - - return result diff --git a/docker-compose.prod.yml b/docker-compose.prod.yml new file mode 100644 index 0000000..8c9e2bd --- /dev/null +++ b/docker-compose.prod.yml @@ -0,0 +1,17 @@ +# Запуск: docker compose -f docker-compose.yml -f docker-compose.prod.yml up -d + +services: + cvc-api: + image: ghcr.io/shiwarai/cvc-api:main + build: !reset null + volumes: + - ./models:/app/models + - ./checkpoints:/app/checkpoints + - ./cache/huggingface:/app/.cache/huggingface + - ./db:/app/db + - ./config.yaml:/app/config.yaml:ro + - ./data:/app/data:ro + - cvc_logs:/app/logs + +volumes: + cvc_logs: diff --git a/docker-compose.yml b/docker-compose.yml index 95dbf64..2170ed2 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -10,7 +10,7 @@ services: ports: - "20001:20001" volumes: - - ./commands_classifier:/app/commands_classifier:ro + - ./app:/app/app:ro - ./tests:/app/tests:ro - ./models:/app/models - ./checkpoints:/app/checkpoints diff --git a/tests/conftest.py b/tests/conftest.py index 2b3c198..af23d14 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -6,8 +6,8 @@ import pytest from fastapi import FastAPI -from commands_classifier import db -from commands_classifier.api.routes import ( +from app.adapters import persistence as db +from app.api.routes import ( command_feedback_router, examples_router, health_router, @@ -15,7 +15,7 @@ predict_router, training_router, ) -from commands_classifier.api.state import ( +from app.api.state import ( set_classifier, set_config, set_default_device, diff --git a/tests/integration/test_e2e_docker.py b/tests/integration/test_e2e_docker.py index 95f9a02..74095bb 100644 --- a/tests/integration/test_e2e_docker.py +++ b/tests/integration/test_e2e_docker.py @@ -9,8 +9,8 @@ import pytest from fastapi import FastAPI -from commands_classifier import db -from commands_classifier.api.routes import ( +from app.adapters import persistence as db +from app.api.routes import ( command_feedback_router, examples_router, health_router, @@ -18,14 +18,14 @@ predict_router, training_router, ) -from commands_classifier.api.state import ( +from app.api.state import ( load_model, set_classifier, set_config, set_default_device, set_training_manager, ) -from commands_classifier.api.training import TrainingManager +from app.api.training import TrainingManager # Путь к фикстурам с тестовым датасетом (3 класса: lie_down, dismiss, unknown) FIXTURES_DIR = Path(__file__).resolve().parent.parent / "fixtures" diff --git a/tests/unit/test_dataset.py b/tests/unit/test_dataset.py index 69fdb10..12a885a 100644 --- a/tests/unit/test_dataset.py +++ b/tests/unit/test_dataset.py @@ -4,7 +4,7 @@ import pytest # noqa: F401 - используется для фикстур (tmp_path) -from commands_classifier.dataset import load_dataset +from app.adapters.data_loading import load_dataset def test_load_dataset_csv(tmp_path): diff --git a/tests/unit/test_db.py b/tests/unit/test_db.py index 745c63d..72c599f 100644 --- a/tests/unit/test_db.py +++ b/tests/unit/test_db.py @@ -1,6 +1,6 @@ """Unit-тесты для модуля БД (init_db, add_example, count_examples, get_all_examples).""" -from commands_classifier import db +from app.adapters import persistence as db def test_init_db_and_count_examples(temp_db_path): diff --git a/tests/unit/test_schemas.py b/tests/unit/test_schemas.py index 333ac26..cc5ea35 100644 --- a/tests/unit/test_schemas.py +++ b/tests/unit/test_schemas.py @@ -3,12 +3,12 @@ import pytest from pydantic import ValidationError -from commands_classifier.api.routes.predict import ( +from app.api.routes.predict import ( EmbedRequest, PredictBatchRequest, PredictRequest, ) -from commands_classifier.api.routes.training import TrainRequest +from app.api.routes.training import TrainRequest # --- EmbedRequest --- diff --git a/tests/unit/test_utils.py b/tests/unit/test_utils.py index 50889dc..9ddf641 100644 --- a/tests/unit/test_utils.py +++ b/tests/unit/test_utils.py @@ -1,6 +1,6 @@ """Unit-тесты для утилит (remove_punctuation).""" -from commands_classifier.api.utils import remove_punctuation +from app.api.utils import remove_punctuation def test_remove_punctuation_empty_string():