diff --git a/src/kernelbot/api/api_utils.py b/src/kernelbot/api/api_utils.py index a90b0d39..97ff4009 100644 --- a/src/kernelbot/api/api_utils.py +++ b/src/kernelbot/api/api_utils.py @@ -18,6 +18,7 @@ SubmissionRequest, prepare_submission, ) +from libkernelbot.utils import KernelBotError async def _handle_discord_oauth(code: str, redirect_uri: str) -> tuple[str, str]: @@ -147,6 +148,8 @@ async def _run_submission( ): try: req = prepare_submission(submission, backend) + except KernelBotError as e: + raise HTTPException(status_code=e.http_code, detail=str(e)) from e except Exception as e: raise HTTPException(status_code=400, detail=str(e)) from e diff --git a/src/kernelbot/api/main.py b/src/kernelbot/api/main.py index d9d0ae9b..ecc23efd 100644 --- a/src/kernelbot/api/main.py +++ b/src/kernelbot/api/main.py @@ -36,6 +36,7 @@ app = FastAPI() + def json_serializer(obj): """JSON serializer for objects not serializable by default json code""" if isinstance(obj, (datetime.datetime, datetime.date, datetime.time)): @@ -255,10 +256,16 @@ async def cli_auth(auth_provider: str, code: str, state: str, db_context=Depends raise e except Exception as e: # Catch unexpected errors during OAuth handling - raise HTTPException(status_code=500, detail=f"Error during {auth_provider} OAuth flow: {e}") from e + raise HTTPException( + status_code=500, + detail=f"Error during {auth_provider} OAuth flow: {e}", + ) from e if not user_id or not user_name: - raise HTTPException(status_code=500,detail="Failed to retrieve user ID or username from provider.",) + raise HTTPException( + status_code=500, + detail="Failed to retrieve user ID or username from provider.", + ) try: with db_context as db: @@ -268,7 +275,10 @@ async def cli_auth(auth_provider: str, code: str, state: str, db_context=Depends db.create_user_from_cli(user_id, user_name, cli_id, auth_provider) except AttributeError as e: - raise HTTPException(status_code=500, detail=f"Database interface error during update: {e}") from e + raise 
HTTPException( + status_code=500, + detail=f"Database interface error during update: {e}", + ) from e except Exception as e: raise HTTPException(status_code=400, detail=f"Database update failed: {e}") from e @@ -280,6 +290,7 @@ async def cli_auth(auth_provider: str, code: str, state: str, db_context=Depends "is_reset": is_reset, } + async def _stream_submission_response( submission_request: SubmissionRequest, submission_mode_enum: SubmissionMode, @@ -298,18 +309,22 @@ async def _stream_submission_response( while not task.done(): elapsed_time = time.time() - start_time - yield f"event: status\ndata: {json.dumps({'status': 'processing', - 'elapsed_time': round(elapsed_time, 2)}, - default=json_serializer)}\n\n" + status_data = json.dumps( + {"status": "processing", "elapsed_time": round(elapsed_time, 2)}, + default=json_serializer, + ) + yield f"event: status\ndata: {status_data}\n\n" try: await asyncio.wait_for(asyncio.shield(task), timeout=15.0) except asyncio.TimeoutError: continue except asyncio.CancelledError: - yield f"event: error\ndata: {json.dumps( - {'status': 'error', 'detail': 'Submission cancelled'}, - default=json_serializer)}\n\n" + error_data = json.dumps( + {"status": "error", "detail": "Submission cancelled"}, + default=json_serializer, + ) + yield f"event: error\ndata: {error_data}\n\n" return result, reports = await task @@ -343,6 +358,7 @@ async def _stream_submission_response( except asyncio.CancelledError: pass + @app.post("/{leaderboard_name}/{gpu_type}/{submission_mode}") async def run_submission( # noqa: C901 leaderboard_name: str, @@ -381,13 +397,13 @@ async def run_submission( # noqa: C901 ) return StreamingResponse(generator, media_type="text/event-stream") + async def enqueue_background_job( req: ProcessedSubmissionRequest, mode: SubmissionMode, backend: KernelBackend, manager: BackgroundSubmissionManager, ): - # pre-create the submission for api returns with backend.db as db: sub_id = db.create_submission( @@ -401,7 +417,8 @@ async def 
enqueue_background_job( job_id = db.upsert_submission_job_status(sub_id, "initial", None) # put submission request in queue await manager.enqueue(req, mode, sub_id) - return sub_id,job_id + return sub_id, job_id + @app.post("/submission/{leaderboard_name}/{gpu_type}/{submission_mode}") async def run_submission_async( @@ -425,37 +442,49 @@ async def run_submission_async( Raises: HTTPException: If the kernelbot is not initialized, or header/input is invalid. Returns: - JSONResponse: A JSON response containing job_id and and submission_id for the client to poll for status. + JSONResponse: A JSON response containing job_id and submission_id. + The client can poll for status using these ids. """ try: - await simple_rate_limit() - logger.info(f"Received submission request for {leaderboard_name} {gpu_type} {submission_mode}") - + logger.info( + "Received submission request for %s %s %s", + leaderboard_name, + gpu_type, + submission_mode, + ) # throw error if submission request is invalid try: submission_request, submission_mode_enum = await to_submit_info( - user_info, submission_mode, file, leaderboard_name, gpu_type, db_context + user_info, submission_mode, file, leaderboard_name, gpu_type, db_context ) req = prepare_submission(submission_request, backend_instance) + except KernelBotError as e: + raise HTTPException(status_code=e.http_code, detail=str(e)) from e except Exception as e: - raise HTTPException(status_code=400, detail=f"failed to prepare submission request: {str(e)}") from e + raise HTTPException( + status_code=400, + detail=f"failed to prepare submission request: {str(e)}", + ) from e # prepare submission request before the submission is started if not req.gpus or len(req.gpus) != 1: raise HTTPException(status_code=400, detail="Invalid GPU type") # put submission request to background manager to run in background - sub_id,job_status_id = await enqueue_background_job( + sub_id, job_status_id = await enqueue_background_job( req, submission_mode_enum, 
backend_instance, background_submission_manager ) return JSONResponse( status_code=202, - content={"details":{"id": sub_id, "job_status_id": job_status_id}, "status": "accepted"}, + content={ + "details": {"id": sub_id, "job_status_id": job_status_id}, + "status": "accepted", + }, ) # Preserve FastAPI HTTPException as-is except HTTPException: @@ -470,6 +499,7 @@ async def run_submission_async( logger.error(f"Unexpected error in api submissoin: {e}") raise HTTPException(status_code=500, detail="Internal server error") from e + @app.get("/leaderboards") async def get_leaderboards(db_context=Depends(get_db)): """An endpoint that returns all leaderboards. diff --git a/src/kernelbot/cogs/admin_cog.py b/src/kernelbot/cogs/admin_cog.py index d2d01115..7b68961a 100644 --- a/src/kernelbot/cogs/admin_cog.py +++ b/src/kernelbot/cogs/admin_cog.py @@ -122,6 +122,16 @@ def __init__(self, bot: "ClusterBot"): name="set-forum-ids", description="Sets forum IDs" )(self.set_forum_ids) + self.set_submission_rate_limit = bot.admin_group.command( + name="set-submission-rate-limit", + description="Set default or per-user submission rate limit (submissions/minute).", + )(self.set_submission_rate_limit) + + self.get_submission_rate_limit = bot.admin_group.command( + name="get-submission-rate-limit", + description="Get default or per-user submission rate limit (submissions/minute).", + )(self.get_submission_rate_limit) + self._scheduled_cleanup_temp_users.start() # -------------------------------------------------------------------------- @@ -512,6 +522,153 @@ async def start(self, interaction: discord.Interaction): interaction, "Bot will accept submissions again!", ephemeral=True ) + def _parse_user_id_arg(self, user_id: str) -> str: + """Accepts a raw id or a discord mention and returns the id string.""" + s = (user_id or "").strip() + if s.startswith("<@") and s.endswith(">"): + s = s[2:-1].strip() + if s.startswith("!"): + s = s[1:].strip() + return s + + def _format_rate(self, rate: 
@app_commands.describe(
    rate_per_minute="Rate in submissions/minute. Use 'none' for unlimited; 'default' clears a user override.",  # noqa: E501
    user_id="Optional user id or mention. If omitted, sets the default.",
)
@with_error_handling
async def set_submission_rate_limit(
    self,
    interaction: discord.Interaction,
    rate_per_minute: str,
    user_id: Optional[str] = None,
):
    """Admin command: set the default or a per-user submission rate limit.

    Args:
        interaction: The Discord interaction that triggered the command.
        rate_per_minute: One of
            * ``none`` / ``unlimited`` / ``off`` - store ``None`` as the rate,
            * ``default`` - remove the per-user override (only valid together
              with ``user_id``),
            * a finite number >= 0 (submissions per minute).
        user_id: Raw id or mention. When omitted or blank, the default limit
            is changed instead of a per-user one.

    Every outcome (including validation failures) is reported back to the
    caller through an ephemeral message; nothing is returned.
    """
    is_admin = await self.admin_check(interaction)
    if not is_admin:
        await send_discord_message(
            interaction,
            "You need to be Admin to use this command.",
            ephemeral=True,
        )
        return

    rate_s = (rate_per_minute or "").strip().lower()
    if rate_s in {"none", "unlimited", "off"}:
        parsed_rate: float | None = None
        clear_override = False
    elif rate_s == "default":
        parsed_rate = None
        clear_override = True
    else:
        try:
            parsed_rate = float(rate_s)
        except ValueError:
            await send_discord_message(
                interaction,
                "Invalid rate. Use a number like `1` or `0.5`, or `none`.",
                ephemeral=True,
            )
            return
        # float() also parses "nan", "inf" and "infinity". A plain `< 0`
        # check is False for all of them, so they would be stored as a rate.
        # The chained comparison is False for NaN and for +inf, rejecting both.
        if not (0 <= parsed_rate < float("inf")):
            await send_discord_message(
                interaction,
                "Invalid rate. Must be a finite number >= 0 (or `none`).",
                ephemeral=True,
            )
            return
        clear_override = False

    with self.bot.leaderboard_db as db:
        # No user given: operate on the default limit.
        if user_id is None or user_id.strip() == "":
            if clear_override:
                # "default" only makes sense as "fall back to the default"
                # for a specific user.
                await send_discord_message(
                    interaction,
                    "For default limit, use a number or `none` (not `default`).",
                    ephemeral=True,
                )
                return
            db.set_default_submission_rate_limit(parsed_rate)
            await send_discord_message(
                interaction,
                f"Default submission rate limit set to **{self._format_rate(parsed_rate)}**.",
                ephemeral=True,
            )
            return

        uid = self._parse_user_id_arg(user_id)
        if uid == "":
            await send_discord_message(
                interaction,
                "Invalid user id.",
                ephemeral=True,
            )
            return

        if clear_override:
            db.clear_user_submission_rate_limit(uid)
            await send_discord_message(
                interaction,
                f"Cleared per-user submission rate limit override for `{uid}` (default now applies).",  # noqa: E501
                ephemeral=True,
            )
            return

        db.set_user_submission_rate_limit(uid, parsed_rate)
        await send_discord_message(
            interaction,
            f"Submission rate limit for `{uid}` set to **{self._format_rate(parsed_rate)}**.",
            ephemeral=True,
        )
If omitted, shows the default.", + ) + @with_error_handling + async def get_submission_rate_limit( + self, + interaction: discord.Interaction, + user_id: Optional[str] = None, + ): + is_admin = await self.admin_check(interaction) + if not is_admin: + await send_discord_message( + interaction, + "You need to be Admin to use this command.", + ephemeral=True, + ) + return + + with self.bot.leaderboard_db as db: + if user_id is None or user_id.strip() == "": + default_rate, capacity = db.get_default_submission_rate_limit() + msg = ( + f"Default submission rate limit: **{self._format_rate(default_rate)}**\n" + f"Bucket capacity: **{capacity:g}**" + ) + await send_discord_message(interaction, msg, ephemeral=True) + return + + uid = self._parse_user_id_arg(user_id) + effective, has_override, user_rate, default_rate, capacity = ( + db.get_submission_rate_limits(uid) + ) + override_text = "default" if not has_override else self._format_rate(user_rate) + msg = ( + f"User `{uid}` submission rate limit:\n" + f"- Effective: **{self._format_rate(effective)}**\n" + f"- User override: **{override_text}**\n" + f"- Default: **{self._format_rate(default_rate)}**\n" + f"- Bucket capacity: **{capacity:g}**" + ) + await send_discord_message(interaction, msg, ephemeral=True) + @app_commands.describe( problem_set="Which problem set to load.", repository_name="Name of the repository to load problems from (in format: user/repo)", diff --git a/src/libkernelbot/leaderboard_db.py b/src/libkernelbot/leaderboard_db.py index 599a0d79..e07a3e7a 100644 --- a/src/libkernelbot/leaderboard_db.py +++ b/src/libkernelbot/leaderboard_db.py @@ -1,7 +1,7 @@ import dataclasses import datetime import json -from typing import Dict, List, Optional +from typing import Dict, List, Optional, Tuple import psycopg2 @@ -267,6 +267,192 @@ def validate_identity( logger.exception("Error validating %s %s", human_label, identifier, exc_info=e) raise KernelBotError(f"Error validating {human_label}") from e + def 
def get_default_submission_rate_limit(self) -> Tuple[Optional[float], float]:
    """Read the default submission rate limit settings.

    Returns:
        ``(default_rate_per_minute, capacity)`` from the settings row with
        ``id = TRUE``. If that row does not exist, ``(None, 1.0)``.

    Raises:
        KernelBotError: If the query fails with a ``psycopg2.Error``; the
            current transaction is rolled back first.
    """
    query = """
        SELECT default_rate_per_minute, capacity
        FROM leaderboard.submission_rate_limit_settings
        WHERE id = TRUE
        """
    try:
        self.cursor.execute(query)
        settings_row = self.cursor.fetchone()
    except psycopg2.Error as e:
        self.connection.rollback()
        logger.exception("Could not read default submission rate limit.", exc_info=e)
        raise KernelBotError("Database error while reading submission rate limit") from e

    if settings_row:
        default_rate, capacity = settings_row[0], settings_row[1]
        return default_rate, float(capacity)
    # No settings row: fall back to no default rate and a capacity of 1.
    return None, 1.0
def enforce_submission_rate_limit(self, user_id: str) -> None:
    """Check the submission token bucket for ``user_id`` and consume a token.

    The effective rate (per-user override, otherwise the default) and the
    bucket capacity come from ``get_submission_rate_limits``. A rate of
    ``None`` means no limit and nothing is recorded. Otherwise the user's
    bucket row is created on first use, locked with ``FOR UPDATE``, refilled
    in proportion to the time since ``last_refill`` (capped at the
    capacity), and written back together with the current timestamp.

    Raises:
        KernelBotError: with ``code=429`` when the rate is <= 0 or when fewer
            than one token is available after refilling; without an explicit
            code when a ``psycopg2.Error`` occurs (after a rollback).
    """
    user_id = str(user_id)
    now = datetime.datetime.now(datetime.timezone.utc)

    (
        effective_rate,
        _has_override,
        _user_rate,
        _default_rate,
        capacity,
    ) = self.get_submission_rate_limits(user_id)
    if effective_rate is None:
        # No limit applies to this user.
        return

    rate = float(effective_rate)
    if rate <= 0:
        raise KernelBotError(
            "You are currently rate-limited from submitting. Please contact an admin.",
            code=429,
        )

    bucket_size = float(capacity)
    # The same statement persists the bucket whether or not a token is spent.
    save_bucket_sql = """
        UPDATE leaderboard.submission_rate_limit_state
        SET tokens = %s, last_refill = %s
        WHERE user_id = %s;
        """

    try:
        # First submission for this user: start with a full bucket.
        self.cursor.execute(
            """
            INSERT INTO leaderboard.submission_rate_limit_state (user_id, tokens, last_refill)
            VALUES (%s, %s, %s)
            ON CONFLICT (user_id) DO NOTHING;
            """,
            (user_id, bucket_size, now),
        )

        # Lock the row so the read-modify-write below is not interleaved.
        self.cursor.execute(
            """
            SELECT tokens, last_refill
            FROM leaderboard.submission_rate_limit_state
            WHERE user_id = %s
            FOR UPDATE;
            """,
            (user_id,),
        )
        stored_tokens, last_refill = self.cursor.fetchone()
        stored_tokens = float(stored_tokens)

        # Refill for the elapsed time; never count negative elapsed time.
        elapsed_seconds = max(0.0, (now - last_refill).total_seconds())
        refill = (elapsed_seconds / 60.0) * rate
        tokens = min(bucket_size, stored_tokens + refill)

        denied = tokens < 1.0
        if not denied:
            tokens -= 1.0

        # Persist the refilled (and possibly decremented) bucket either way.
        self.cursor.execute(save_bucket_sql, (tokens, now, user_id))
        self.connection.commit()

        if denied:
            # Whole seconds until one full token is available, plus one.
            wait_seconds = max(0, int(((1.0 - tokens) * 60.0) / rate) + 1)
            raise KernelBotError(
                f"Rate limit exceeded. Please wait {wait_seconds}s before submitting again.",
                code=429,
            )
    except psycopg2.Error as e:
        self.connection.rollback()
        logger.exception("Error enforcing submission rate limit.", exc_info=e)
        raise KernelBotError("Database error while enforcing submission rate limit") from e
rate_per_minute DOUBLE PRECISION DEFAULT NULL, + updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW() + ); + """ + ), + step( + """ + CREATE TABLE IF NOT EXISTS leaderboard.submission_rate_limit_state ( + user_id TEXT PRIMARY KEY, + tokens DOUBLE PRECISION NOT NULL, + last_refill TIMESTAMPTZ NOT NULL + ); + """ + ), +]