Skip to content

Commit

Permalink
Let current task finish before shutdown
Browse files Browse the repository at this point in the history
The shutdown event blocks all incoming requests, therefore we cannot wait for a get request for the current task to be received first before shutting down the app. So, the alternative for now is to at least let the current task finish, then program something in the Airflow triggerer that checks in S3 whether a specific program has been transcribed already or not.
  • Loading branch information
greenw0lf committed Jan 21, 2025
1 parent 20a0a24 commit a3441df
Showing 1 changed file with 30 additions and 6 deletions.
36 changes: 30 additions & 6 deletions api.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
from typing import Optional
from uuid import uuid4
from fastapi import BackgroundTasks, FastAPI, HTTPException, status, Response
import asyncio
from contextlib import asynccontextmanager
from asr import run
from whisper import load_model
from enum import Enum
Expand All @@ -12,14 +14,33 @@
W_MODEL,
)


logger = logging.getLogger(__name__)
api = FastAPI()

logger.info(f"Loading model on device {W_DEVICE}")
shutdown_current_task_done = False


@asynccontextmanager
async def lifespan(app: FastAPI):
# Load the Whisper model
global model
logger.info(f"Loading model on device {W_DEVICE}")
model = load_model(MODEL_BASE_DIR, W_MODEL, W_DEVICE)
# Run the API
yield
# Shutdown call
global current_task
# If there is a task still being processed at shutdown time
if current_task and current_task.status == Status.PROCESSING:
global shutdown_current_task_done
shutdown_current_task_done = False
# Wait until the task is complete
while not shutdown_current_task_done:
await asyncio.sleep(0.1)
yield
# Clean up
logger.info("Current task is done. Shutting down...")
del model

# load the model in memory on API startup
model = load_model(MODEL_BASE_DIR, W_MODEL, W_DEVICE)
api = FastAPI(lifespan=lifespan)


class Status(Enum):
Expand Down Expand Up @@ -84,6 +105,9 @@ def try_whisper(task: Task):
task.error_msg = str(e)
update_task(task)
logger.info(f"Task {task.id} has been updated")
# In case the application is shutting down
global shutdown_current_task_done
shutdown_current_task_done = True


@api.get("/tasks")
Expand Down

0 comments on commit a3441df

Please sign in to comment.