diff --git a/alembic/versions/402666fa2440_align_db.py b/alembic/versions/402666fa2440_align_db.py new file mode 100644 index 0000000..48f12d9 --- /dev/null +++ b/alembic/versions/402666fa2440_align_db.py @@ -0,0 +1,66 @@ +"""Align DB + +Revision ID: 402666fa2440 +Revises: 4fc1c39216c9 +Create Date: 2026-04-03 17:49:17.151261 + +""" +from alembic import op +import sqlalchemy as sa +from sqlalchemy.dialects import mysql + +# revision identifiers, used by Alembic. +revision = '402666fa2440' +down_revision = '4fc1c39216c9' +branch_labels = None +depends_on = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.alter_column('ban', 'unban_time', + existing_type=mysql.BIGINT(display_width=18), + type_=mysql.BIGINT(display_width=11, unsigned=True), + existing_nullable=False) + op.drop_index(op.f('ix_ctf_id'), table_name='ctf') + op.alter_column('htb_discord_link', 'account_identifier', + existing_type=mysql.VARCHAR(length=255), + nullable=False) + op.alter_column('htb_discord_link', 'discord_user_id', + existing_type=mysql.VARCHAR(length=42), + type_=mysql.BIGINT(display_width=18), + nullable=False) + op.alter_column('htb_discord_link', 'htb_user_id', + existing_type=mysql.VARCHAR(length=255), + type_=mysql.BIGINT(), + nullable=False) + op.alter_column('mute', 'unmute_time', + existing_type=mysql.BIGINT(display_width=18), + type_=mysql.BIGINT(display_width=11, unsigned=True), + existing_nullable=False) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.alter_column('mute', 'unmute_time', + existing_type=mysql.BIGINT(display_width=11, unsigned=True), + type_=mysql.BIGINT(display_width=18), + existing_nullable=False) + op.alter_column('htb_discord_link', 'htb_user_id', + existing_type=mysql.BIGINT(), + type_=mysql.VARCHAR(length=255), + nullable=True) + op.alter_column('htb_discord_link', 'discord_user_id', + existing_type=mysql.BIGINT(display_width=18), + type_=mysql.VARCHAR(length=42), + nullable=True) + op.alter_column('htb_discord_link', 'account_identifier', + existing_type=mysql.VARCHAR(length=255), + nullable=True) + op.create_index(op.f('ix_ctf_id'), 'ctf', ['id'], unique=False) + op.alter_column('ban', 'unban_time', + existing_type=mysql.BIGINT(display_width=11, unsigned=True), + type_=mysql.BIGINT(display_width=18), + existing_nullable=False) + # ### end Alembic commands ### diff --git a/alembic/versions/9aa21aede2ec_add_dynamic_role_table.py b/alembic/versions/9aa21aede2ec_add_dynamic_role_table.py new file mode 100644 index 0000000..e8bdfaa --- /dev/null +++ b/alembic/versions/9aa21aede2ec_add_dynamic_role_table.py @@ -0,0 +1,39 @@ +"""Add dynamic_role table + +Revision ID: 9aa21aede2ec +Revises: 402666fa2440 +Create Date: 2026-04-03 17:52:08.412419 + +""" +from alembic import op +import sqlalchemy as sa +from sqlalchemy.dialects import mysql + +# revision identifiers, used by Alembic. +revision = '9aa21aede2ec' +down_revision = '402666fa2440' +branch_labels = None +depends_on = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.create_table('dynamic_role', + sa.Column('id', sa.Integer(), autoincrement=True, nullable=False), + sa.Column('key', sa.String(length=64), nullable=False), + sa.Column('discord_role_id', mysql.BIGINT(unsigned=True), nullable=False), + sa.Column('category', sa.Enum('RANK', 'SEASON', 'SUBSCRIPTION_LABS', 'SUBSCRIPTION_ACADEMY', 'CREATOR', 'POSITION', 'ACADEMY_CERT', 'JOINABLE', name='rolecategory'), nullable=False), + sa.Column('display_name', sa.String(length=128), nullable=False), + sa.Column('description', sa.String(length=256), nullable=True), + sa.Column('cert_full_name', sa.String(length=128), nullable=True), + sa.Column('cert_integer_id', sa.Integer(), nullable=True), + sa.PrimaryKeyConstraint('id'), + sa.UniqueConstraint('key', 'category', name='uq_dynamic_role_key_category') + ) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_table('dynamic_role') + # ### end Alembic commands ### diff --git a/docker-compose.yml b/docker-compose.yml index c966863..cf04c71 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -1,8 +1,35 @@ -version: '3.8' - services: + db: + image: mariadb:10.11.2 + environment: + MYSQL_ROOT_PASSWORD: rootpassword + MYSQL_DATABASE: hackster + MYSQL_USER: hackster + MYSQL_PASSWORD: hackster + ports: + - "3306:3306" + volumes: + - db_data:/var/lib/mysql + healthcheck: + test: ["CMD-SHELL", "mariadb-admin -uroot -prootpassword ping || exit 1"] + start_period: 30s + interval: 10s + timeout: 5s + retries: 3 bot: build: . env_file: - .env + depends_on: + db: + condition: service_healthy + environment: + MYSQL_HOST: db + MYSQL_PORT: 3306 + MYSQL_DATABASE: hackster + MYSQL_USER: hackster + MYSQL_PASSWORD: hackster + +volumes: + db_data: diff --git a/scripts/seed_dynamic_roles.py b/scripts/seed_dynamic_roles.py new file mode 100644 index 0000000..7f5d525 --- /dev/null +++ b/scripts/seed_dynamic_roles.py @@ -0,0 +1,176 @@ +""" +One-time seed script to populate the dynamic_role table from environment variables. + +Usage: + ENV_PATH=.env python -m scripts.seed_dynamic_roles + # or for test env: + python -m scripts.seed_dynamic_roles (defaults to .test.env) +""" + +import asyncio +import logging +import os +import sys + +from dotenv import dotenv_values +from sqlalchemy.dialects.mysql import insert + +# Add project root to path +sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + +from src.database.models.dynamic_role import DynamicRole, RoleCategory # noqa: E402 +from src.database.session import AsyncSessionLocal # noqa: E402 + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + + +# (env_var_suffix, category, key, display_name, extra_kwargs) +SEED_DATA = [ + # Ranks + ("OMNISCIENT", RoleCategory.RANK, "Omniscient", "Omniscient", {}), + ("GURU", RoleCategory.RANK, "Guru", "Guru", {}), + ("ELITE_HACKER", RoleCategory.RANK, "Elite Hacker", "Elite Hacker", {}), + ("PRO_HACKER", RoleCategory.RANK, "Pro Hacker", "Pro Hacker", {}), + ("HACKER", RoleCategory.RANK, "Hacker", "Hacker", {}), + ("SCRIPT_KIDDIE", RoleCategory.RANK, "Script Kiddie", "Script Kiddie", {}), + ("NOOB", RoleCategory.RANK, "Noob", "Noob", {}), + # Subscriptions - Labs + ("VIP", RoleCategory.SUBSCRIPTION_LABS, "vip", "VIP", {}), + ("VIP_PLUS", RoleCategory.SUBSCRIPTION_LABS, "dedivip", "VIP+", {}), + # Subscriptions - Academy + ("SILVER_ANNUAL", RoleCategory.SUBSCRIPTION_ACADEMY, "Silver Annual", "Silver Annual", {}), + ("GOLD_ANNUAL", RoleCategory.SUBSCRIPTION_ACADEMY, "Gold Annual", "Gold Annual", {}), + # Creators + ("BOX_CREATOR", RoleCategory.CREATOR, "Box Creator", "Box Creator", {}), + ("CHALLENGE_CREATOR", RoleCategory.CREATOR, "Challenge Creator", "Challenge Creator", {}), + ("SHERLOCK_CREATOR", RoleCategory.CREATOR, "Sherlock Creator", "Sherlock Creator", {}), + # Positions + ("RANK_ONE", RoleCategory.POSITION, "1", "Top 1", {}), + ("RANK_TEN", RoleCategory.POSITION, "10", "Top 10", {}), + # Seasons + ("SEASON_HOLO", RoleCategory.SEASON, "Holo", "Holo", {}), + ("SEASON_PLATINUM", RoleCategory.SEASON, "Platinum", "Platinum", {}), + ("SEASON_RUBY", RoleCategory.SEASON, "Ruby", "Ruby", {}), + ("SEASON_SILVER", RoleCategory.SEASON, "Silver", "Silver", {}), + ("SEASON_BRONZE", RoleCategory.SEASON, "Bronze", "Bronze", {}), + # Academy Certs (with cert_full_name and cert_integer_id) + ("ACADEMY_CWES", RoleCategory.ACADEMY_CERT, "CWES", "Certified Web Exploitation Specialist", { + "cert_full_name": "HTB Certified Web Exploitation Specialist", + "cert_integer_id": 2, + }), + ("ACADEMY_CPTS", RoleCategory.ACADEMY_CERT, "CPTS", "Certified Penetration Testing Specialist", { + "cert_full_name": "HTB Certified Penetration Testing Specialist", + "cert_integer_id": 3, + }), + ("ACADEMY_CDSA", RoleCategory.ACADEMY_CERT, "CDSA", "Certified Defensive Security Analyst", { + "cert_full_name": "HTB Certified Defensive Security Analyst", + "cert_integer_id": 4, + }), + ("ACADEMY_CWEE", RoleCategory.ACADEMY_CERT, "CWEE", "Certified Web Exploitation Expert", { + "cert_full_name": "HTB Certified Web Exploitation Expert", + "cert_integer_id": 5, + }), + ("ACADEMY_CAPE", RoleCategory.ACADEMY_CERT, "CAPE", "Certified Active Directory Pentesting Expert", { + "cert_full_name": "HTB Certified Active Directory Pentesting Expert", + "cert_integer_id": 6, + }), + ("ACADEMY_CJCA", RoleCategory.ACADEMY_CERT, "CJCA", "Certified Junior Cybersecurity Associate", { + "cert_full_name": "HTB Certified Junior Cybersecurity Associate", + "cert_integer_id": 7, + }), + ("ACADEMY_CWPE", RoleCategory.ACADEMY_CERT, "CWPE", "Certified Wi-Fi Pentesting Expert", { + "cert_full_name": "HTB Certified Wi-Fi Pentesting Expert", + "cert_integer_id": 8, + }), +] + +# Joinable roles: multiple display names can share the same env var / discord_role_id. +# Format: (env_var_suffix, key, display_name, description) +JOINABLE_SEED_DATA = [ + ("UNICTF2022", "Cyber Apocalypse", "Cyber Apocalypse", "Pinged for CTF Announcements"), + ("UNICTF2022", "Business CTF", "Business CTF", "Pinged for CTF Announcements"), + ("UNICTF2022", "University CTF", "University CTF", "Pinged for CTF Announcements"), + ("NOAH_GANG", "Noah Gang", "Noah Gang", "Get pinged when Fugl posts pictures of his cute bird"), + ("BUDDY_GANG", "Buddy Gang", "Buddy Gang", "Get pinged when Legacyy posts pictures of his cute dog"), + ("RED_TEAM", "Red Team", "Red Team", "Red team fans. Also gives access to the Red and Blue team channels"), + ("BLUE_TEAM", "Blue Team", "Blue Team", "Blue team fans. Also gives access to the Red and Blue team channels"), +] + + +async def _upsert_role(session, values: dict) -> None: + """Insert or update a dynamic role using MariaDB upsert.""" + stmt = insert(DynamicRole).values(values) + # On duplicate key, update all fields except the unique key (key, category) + stmt = stmt.on_duplicate_key_update( + discord_role_id=stmt.inserted.discord_role_id, + display_name=stmt.inserted.display_name, + description=stmt.inserted.description, + cert_full_name=stmt.inserted.cert_full_name, + cert_integer_id=stmt.inserted.cert_integer_id, + ) + await session.execute(stmt) + + +async def seed(env_file: str) -> None: + env_values = dotenv_values(env_file) + logger.info(f"Loaded env from {env_file} ({len(env_values)} values)") + + upserted = 0 + skipped = 0 + + async with AsyncSessionLocal() as session: + # Seed standard dynamic roles + for env_suffix, category, key, display_name, extra in SEED_DATA: + env_var = f"ROLE_{env_suffix}" + role_id_str = env_values.get(env_var) + if not role_id_str: + logger.warning(f"Skipping {env_var}: not found in {env_file}") + skipped += 1 + continue + + values = { + "key": key, + "discord_role_id": int(role_id_str), + "category": category, + "display_name": display_name, + "description": None, + "cert_full_name": extra.get("cert_full_name"), + "cert_integer_id": extra.get("cert_integer_id"), + } + + await _upsert_role(session, values) + upserted += 1 + logger.info(f" {category.value}/{key} = {role_id_str}") + + # Seed joinable roles + for env_suffix, key, display_name, description in JOINABLE_SEED_DATA: + env_var = f"ROLE_{env_suffix}" + role_id_str = env_values.get(env_var) + if not role_id_str: + logger.warning(f"Skipping joinable {env_var}/{key}: not found in {env_file}") + skipped += 1 + continue + + values = { + "key": key, + "discord_role_id": int(role_id_str), + "category": RoleCategory.JOINABLE, + "display_name": display_name, + "description": description, + "cert_full_name": None, + "cert_integer_id": None, + } + + await _upsert_role(session, values) + upserted += 1 + logger.info(f" joinable/{key} = {role_id_str}") + + await session.commit() + + logger.info(f"Seeding complete: {upserted} upserted, {skipped} skipped") + + +if __name__ == "__main__": + env_file = os.environ.get("ENV_PATH", ".test.env") + asyncio.run(seed(env_file)) diff --git a/src/__main__.py b/src/__main__.py index 83f73eb..96e27b4 100644 --- a/src/__main__.py +++ b/src/__main__.py @@ -1,12 +1,19 @@ +import asyncio import logging from src.bot import bot from src.core import settings +from src.services.role_manager import RoleManager from src.utils.extensions import walk_extensions from src.webhooks.server import serve logger = logging.getLogger(__name__) +# Load dynamic roles from DB BEFORE starting servers (eliminates race condition). +role_manager = RoleManager(fallback_roles=settings.roles) +asyncio.get_event_loop().run_until_complete(role_manager.load()) +bot.role_manager = role_manager + # Load all cogs extensions. for ext in walk_extensions(): if ext == "src.cmds.automation.scheduled_tasks": diff --git a/src/bot.py b/src/bot.py index c54d329..6f004f5 100644 --- a/src/bot.py +++ b/src/bot.py @@ -53,6 +53,7 @@ def __init__(self, mock: bool = False, **kwargs): mock (bool): Whether to mock the client or not. """ super().__init__(**kwargs) + self.role_manager = None if not mock: logger.debug("HTTP session will be initialized in an asynchronous context") self.http_session = None @@ -75,6 +76,13 @@ async def on_ready(self) -> None: self.send_log(devlog_msg, colour=constants.colours.bright_green) ) + # Refresh dynamic roles cache on (re)connect + if self.role_manager: + try: + await self.role_manager.reload() + except Exception: + logger.warning("Failed to reload dynamic roles on reconnect, keeping previous cache", exc_info=True) + logger.info(f"Started bot as {name}") print("Loading ScheduledTasks cog...") try: diff --git a/src/cmds/core/admin.py b/src/cmds/core/admin.py new file mode 100644 index 0000000..63dc32b --- /dev/null +++ b/src/cmds/core/admin.py @@ -0,0 +1,39 @@ +"""Admin command group for bot administration commands.""" + +import logging + +import discord +from discord.ext import commands + +from src.bot import Bot +from src.core import settings + +logger = logging.getLogger(__name__) + +# Top-level admin command group +# Subcommands (like 'role') are registered by other cogs using admin.create_subgroup() +admin = discord.SlashCommandGroup( + "admin", + "Bot administration commands", + guild_ids=settings.guild_ids, +) + + +class AdminCog(commands.Cog): + """Admin commands placeholder cog. + + This cog doesn't define any commands directly - it just ensures the + /admin group is registered. Subcommands are added by other cogs + via admin.create_subgroup(). + """ + + def __init__(self, bot: Bot): + self.bot = bot + + +def setup(bot: Bot) -> None: + """Load the AdminCog and register the admin command group.""" + cog = AdminCog(bot) + bot.add_cog(cog) + # Register the admin group at module level so other cogs can add subgroups + bot.add_application_command(admin) diff --git a/src/cmds/core/role_admin.py b/src/cmds/core/role_admin.py new file mode 100644 index 0000000..bbe53a2 --- /dev/null +++ b/src/cmds/core/role_admin.py @@ -0,0 +1,149 @@ +import logging + +import discord +from discord import ApplicationContext, Interaction, Option, WebhookMessage, slash_command +from discord.ext import commands +from discord.ext.commands import has_any_role + +from src.bot import Bot +from src.cmds.core.admin import admin # Import the admin group from admin.py +from src.core import settings +from src.database.models.dynamic_role import RoleCategory + +logger = logging.getLogger(__name__) + +CATEGORY_CHOICES = [c.value for c in RoleCategory] + +# Create the role subgroup under the admin group +role = admin.create_subgroup( + "role", + "Manage dynamic Discord roles", +) + + +class RoleAdminCog(commands.Cog): + """Admin commands for managing dynamic Discord roles.""" + + def __init__(self, bot: Bot): + self.bot = bot + + @role.command(description="Add a new dynamic role.") + @has_any_role(*settings.role_groups.get("ALL_ADMINS")) + async def add( + self, + ctx: ApplicationContext, + category: Option(str, "Role category", choices=CATEGORY_CHOICES), + key: Option(str, "Lookup key (e.g. 'Omniscient', 'CWPE')"), + role: Option(discord.Role, "The Discord role"), + display_name: Option(str, "Human-readable name"), + description: Option(str, "Description (for joinable roles)", required=False), + cert_full_name: Option(str, "Full cert name from HTB platform (academy_cert only)", required=False), + cert_integer_id: Option(int, "Platform cert ID (academy_cert only)", required=False), + ) -> Interaction | WebhookMessage: + """Add a new dynamic role to the database.""" + try: + cat = RoleCategory(category) + except ValueError: + return await ctx.respond(f"Invalid category: {category}", ephemeral=True) + + await self.bot.role_manager.add_role( + key=key, + category=cat, + discord_role_id=role.id, + display_name=display_name, + description=description, + cert_full_name=cert_full_name, + cert_integer_id=cert_integer_id, + ) + return await ctx.respond( + f"Added dynamic role: `{category}/{key}` -> {role.mention}", + ephemeral=True, + ) + + @role.command(description="Remove a dynamic role.") + @has_any_role(*settings.role_groups.get("ALL_ADMINS")) + async def remove( + self, + ctx: ApplicationContext, + category: Option(str, "Role category", choices=CATEGORY_CHOICES), + key: Option(str, "Lookup key to remove"), + ) -> Interaction | WebhookMessage: + """Remove a dynamic role from the database.""" + try: + cat = RoleCategory(category) + except ValueError: + return await ctx.respond(f"Invalid category: {category}", ephemeral=True) + + deleted = await self.bot.role_manager.remove_role(cat, key) + if deleted: + return await ctx.respond(f"Removed dynamic role: `{category}/{key}`", ephemeral=True) + return await ctx.respond(f"No role found for `{category}/{key}`", ephemeral=True) + + @role.command(description="Update a dynamic role's Discord role.") + @has_any_role(*settings.role_groups.get("ALL_ADMINS")) + async def update( + self, + ctx: ApplicationContext, + category: Option(str, "Role category", choices=CATEGORY_CHOICES), + key: Option(str, "Lookup key to update"), + role: Option(discord.Role, "The new Discord role"), + ) -> Interaction | WebhookMessage: + """Update a dynamic role's Discord ID.""" + try: + cat = RoleCategory(category) + except ValueError: + return await ctx.respond(f"Invalid category: {category}", ephemeral=True) + + updated = await self.bot.role_manager.update_role(cat, key, role.id) + if updated: + return await ctx.respond( + f"Updated `{category}/{key}` -> {role.mention}", + ephemeral=True, + ) + return await ctx.respond(f"No role found for `{category}/{key}`", ephemeral=True) + + @role.command(description="List configured dynamic roles.") + @has_any_role(*settings.role_groups.get("ALL_ADMINS"), *settings.role_groups.get("ALL_MODS")) + async def list( + self, + ctx: ApplicationContext, + category: Option(str, "Filter by category", choices=CATEGORY_CHOICES, required=False), + ) -> Interaction | WebhookMessage: + """List all configured dynamic roles.""" + cat = RoleCategory(category) if category else None + roles = await self.bot.role_manager.list_roles(cat) + + if not roles: + return await ctx.respond("No dynamic roles configured.", ephemeral=True) + + # Group by category for display + grouped: dict[str, list[str]] = {} + for r in roles: + cat_name = r.category.value + if cat_name not in grouped: + grouped[cat_name] = [] + guild_role = ctx.guild.get_role(r.discord_role_id) + role_mention = guild_role.mention if guild_role else f"`{r.discord_role_id}`" + grouped[cat_name].append(f"`{r.key}` -> {role_mention} ({r.display_name})") + + embed = discord.Embed(title="Dynamic Roles", color=0x9ACC14) + for cat_name, entries in grouped.items(): + embed.add_field( + name=cat_name, + value="\n".join(entries[:10]) + (f"\n... and {len(entries) - 10} more" if len(entries) > 10 else ""), + inline=False, + ) + + return await ctx.respond(embed=embed, ephemeral=True) + + @role.command(description="Force reload dynamic roles from database.") + @has_any_role(*settings.role_groups.get("ALL_ADMINS")) + async def reload(self, ctx: ApplicationContext) -> Interaction | WebhookMessage: + """Force reload the role manager cache from the database.""" + await self.bot.role_manager.reload() + return await ctx.respond("Dynamic roles reloaded from database.", ephemeral=True) + + +def setup(bot: Bot) -> None: + """Load the `RoleAdminCog` cog.""" + bot.add_cog(RoleAdminCog(bot)) diff --git a/src/cmds/core/user.py b/src/cmds/core/user.py index 4337212..bd5218c 100644 --- a/src/cmds/core/user.py +++ b/src/cmds/core/user.py @@ -21,6 +21,16 @@ logger = logging.getLogger(__name__) +async def joinable_role_autocomplete(ctx: discord.AutocompleteContext): + """Autocomplete for joinable roles, fetched from DB via role manager.""" + role_manager = ctx.bot.role_manager + if not role_manager: + return [] + joinable = role_manager.get_joinable_roles() + current = (ctx.value or "").lower() + return [name for name in joinable.keys() if current in name.lower()][:25] + + class UserCog(commands.Cog): """Ban related commands.""" @@ -98,10 +108,10 @@ async def kick(self, ctx: ApplicationContext, user: Member, reason: str, evidenc await add_infraction(ctx.guild, member, 0, infraction_reason, ctx.user) return await ctx.respond(f"{member.name} got the boot!") - @staticmethod - def _match_role(role_name: str) -> Tuple[Union[int, None], Union[str, None]]: - role_name = role_name.lower() - matches = [v[0] for k, v in settings.roles_to_join.items() if role_name in k.lower()] + def _match_role(self, role_name: str) -> Tuple[Union[int, None], Union[str, None]]: + joinable = self.bot.role_manager.get_joinable_roles() if self.bot.role_manager else {} + role_name_lower = role_name.lower() + matches = [v[0] for k, v in joinable.items() if role_name_lower in k.lower()] if not matches: return None, "I don't know what role that is. Did you spell it right?" if len(matches) > 1: @@ -117,15 +127,16 @@ def _match_role(role_name: str) -> Tuple[Union[int, None], Union[str, None]]: async def join( self, ctx: ApplicationContext, - role_name: Option(str, "Choose the role!", choices=settings.roles_to_join.keys()), + role_name: Option(str, "Choose the role!", autocomplete=joinable_role_autocomplete), ) -> Interaction | WebhookMessage: """Join a vanity role if such is specified, otherwise list the vanity roles available to join.""" # No role or empty role name passed if not role_name or role_name.isspace(): + joinable = self.bot.role_manager.get_joinable_roles() if self.bot.role_manager else {} embed = discord.Embed(title=" ", color=0x3D85C6) embed.set_author(name="Join-able Roles") embed.set_footer(text="Use the command /join to join a role.") - for role, value in settings.roles_to_join.items(): + for role, value in joinable.items(): embed.add_field(name=role, value=value[1], inline=True) return await ctx.respond(embed=embed) @@ -143,7 +154,7 @@ async def join( async def leave( self, ctx: ApplicationContext, - role_name: Option(str, "Choose the role!", choices=settings.roles_to_join.keys()), + role_name: Option(str, "Choose the role!", autocomplete=joinable_role_autocomplete), ) -> Interaction | WebhookMessage: """Removes the vanity role from your user.""" role_id, exc = self._match_role(role_name) diff --git a/src/cmds/core/verify.py b/src/cmds/core/verify.py index 919f872..3c78d22 100644 --- a/src/cmds/core/verify.py +++ b/src/cmds/core/verify.py @@ -30,9 +30,15 @@ async def verifycertification(self, ctx: ApplicationContext, certid: str, fullna if not certid.startswith("HTBCERT-"): await ctx.respond("CertID must start with HTBCERT-", ephemeral=True) return - cert = await process_certification(certid, fullname) + cert = await process_certification(certid, fullname, self.bot.role_manager) if cert: - to_add = settings.get_cert(cert) + to_add = self.bot.role_manager.get_cert_role_id(cert) + if not to_add: + await ctx.respond( + "This certification role is not yet configured. Please contact an admin.", + ephemeral=True, + ) + return await ctx.author.add_roles(ctx.guild.get_role(to_add)) await ctx.respond(f"Added {cert}!", ephemeral=True) else: diff --git a/src/core/config.py b/src/core/config.py index 43674c4..4a93a83 100644 --- a/src/core/config.py +++ b/src/core/config.py @@ -6,6 +6,8 @@ import toml from pydantic import BaseSettings, validator +# AcademyCertificates is removed; cert mappings are now in the dynamic_role DB table. + class Bot(BaseSettings): """The API settings.""" @@ -86,21 +88,14 @@ class Config: env_prefix = "CHANNEL_" -class AcademyCertificates(BaseSettings): - CERTIFIED_WEB_EXPLOITATION_SPECIALIST = 2 - CERTIFIED_PENETRATION_TESTING_SPECIALIST = 3 - CERTIFIED_DEFENSIVE_SECURITY_ANALYST = 4 - CERTIFIED_WEB_EXPLOITATION_EXPERT = 5 - CERTIFIED_ACTIVEDIRECTORY_PENTESTING_EXPERT = 6 - CERTIFIED_JUNIOR_CYBERSECURITY_ASSOCIATE = 7 - CERTIFIED_WIFI_PENTESTING_EXPERT=8 - - class Roles(BaseSettings): - """The roles settings.""" - VERIFIED: int + """The roles settings. - # Moderation + Core roles (required): used in decorators at import time for permission checks. + Dynamic roles (optional): managed via DB, kept here as fallback during transition. + """ + # ── Core roles (required, used in decorators) ──────────────────── + VERIFIED: int COMMUNITY_MANAGER: int COMMUNITY_TEAM: int ADMINISTRATOR: int @@ -110,54 +105,54 @@ class Roles(BaseSettings): HTB_STAFF: int HTB_SUPPORT: int MUTED: int + ACADEMY_USER: int + # ── Dynamic roles (optional, DB-backed, env var fallback) ──────── # Ranks - OMNISCIENT: int - GURU: int - ELITE_HACKER: int - PRO_HACKER: int - HACKER: int - SCRIPT_KIDDIE: int - NOOB: int - + OMNISCIENT: Optional[int] = None + GURU: Optional[int] = None + ELITE_HACKER: Optional[int] = None + PRO_HACKER: Optional[int] = None + HACKER: Optional[int] = None + SCRIPT_KIDDIE: Optional[int] = None + NOOB: Optional[int] = None # Subscriptions - VIP: int - VIP_PLUS: int - SILVER_ANNUAL: int - GOLD_ANNUAL: int - + VIP: Optional[int] = None + VIP_PLUS: Optional[int] = None + SILVER_ANNUAL: Optional[int] = None + GOLD_ANNUAL: Optional[int] = None # Content Creation - CHALLENGE_CREATOR: int - BOX_CREATOR: int - SHERLOCK_CREATOR: int - + CHALLENGE_CREATOR: Optional[int] = None + BOX_CREATOR: Optional[int] = None + SHERLOCK_CREATOR: Optional[int] = None # Positions - RANK_ONE: int - RANK_TEN: int - - # Season Tiers: - SEASON_HOLO: int - SEASON_PLATINUM: int - SEASON_RUBY: int - SEASON_SILVER: int - SEASON_BRONZE: int - # Academy - ACADEMY_USER: int - ACADEMY_CWES: int - ACADEMY_CPTS: int - ACADEMY_CDSA: int - ACADEMY_CWEE: int - ACADEMY_CAPE: int - ACADEMY_CJCA: int - ACADEMY_CWPE: int + RANK_ONE: Optional[int] = None + RANK_TEN: Optional[int] = None + # Season Tiers + SEASON_HOLO: Optional[int] = None + SEASON_PLATINUM: Optional[int] = None + SEASON_RUBY: Optional[int] = None + SEASON_SILVER: Optional[int] = None + SEASON_BRONZE: Optional[int] = None + # Academy Certs + ACADEMY_CWES: Optional[int] = None + ACADEMY_CPTS: Optional[int] = None + ACADEMY_CDSA: Optional[int] = None + ACADEMY_CWEE: Optional[int] = None + ACADEMY_CAPE: Optional[int] = None + ACADEMY_CJCA: Optional[int] = None + ACADEMY_CWPE: Optional[int] = None # Joinable roles - UNICTF2022: int - BIZCTF2022: int - NOAH_GANG: int - BUDDY_GANG: int - RED_TEAM: int - BLUE_TEAM: int - @validator("*", pre=True, each_item=True) + UNICTF2022: Optional[int] = None + BIZCTF2022: Optional[int] = None + NOAH_GANG: Optional[int] = None + BUDDY_GANG: Optional[int] = None + RED_TEAM: Optional[int] = None + BLUE_TEAM: Optional[int] = None + + @validator("VERIFIED", "COMMUNITY_MANAGER", "COMMUNITY_TEAM", "ADMINISTRATOR", + "SR_MODERATOR", "MODERATOR", "JR_MODERATOR", "HTB_STAFF", "HTB_SUPPORT", + "MUTED", "ACADEMY_USER", pre=True, each_item=True) def check_length(cls, value: str | int) -> str | int: value_str = str(value) if not 17 <= len(value_str) <= 20: @@ -180,10 +175,6 @@ class Global(BaseSettings): roles: Roles = None HTB_API_KEY: str - # Collections are defined using lowercase - academy_certificates: AcademyCertificates = AcademyCertificates() - - roles_to_join: dict[str, tuple[int | str, str]] = {} role_groups: dict[str, list[int | str]] = {} guild_ids: list[int] @@ -234,57 +225,8 @@ def check_ids_format(cls, v: list[int]) -> list[int]: assert len(str(discord_id)) > 17, "Discord ids must have a length of 19." return v - def get_academy_cert_role(self, certificate: int) -> int: - return { - self.academy_certificates.CERTIFIED_WEB_EXPLOITATION_SPECIALIST: self.roles.ACADEMY_CWES, - self.academy_certificates.CERTIFIED_PENETRATION_TESTING_SPECIALIST: self.roles.ACADEMY_CPTS, - self.academy_certificates.CERTIFIED_DEFENSIVE_SECURITY_ANALYST: self.roles.ACADEMY_CDSA, - self.academy_certificates.CERTIFIED_WEB_EXPLOITATION_EXPERT: self.roles.ACADEMY_CWEE, - self.academy_certificates.CERTIFIED_ACTIVEDIRECTORY_PENTESTING_EXPERT: self.roles.ACADEMY_CAPE, - self.academy_certificates.CERTIFIED_JUNIOR_CYBERSECURITY_ASSOCIATE: self.roles.ACADEMY_CJCA, - self.academy_certificates.CERTIFIED_WIFI_PENTESTING_EXPERT: self.roles.ACADEMY_CWPE - }.get(certificate) - - def get_post_or_rank(self, what: str) -> Optional[int]: - return { - "1": self.roles.RANK_ONE, - "10": self.roles.RANK_TEN, - "Omniscient": self.roles.OMNISCIENT, - "Guru": self.roles.GURU, - "Elite Hacker": self.roles.ELITE_HACKER, - "Pro Hacker": self.roles.PRO_HACKER, - "Hacker": self.roles.HACKER, - "Script Kiddie": self.roles.SCRIPT_KIDDIE, - "Noob": self.roles.NOOB, - "vip": self.roles.VIP, - "dedivip": self.roles.VIP_PLUS, - "Challenge Creator": self.roles.CHALLENGE_CREATOR, - "Box Creator": self.roles.BOX_CREATOR, - "Sherlock Creator": self.roles.SHERLOCK_CREATOR, - "Silver Annual": self.roles.SILVER_ANNUAL, - "Gold Annual": self.roles.GOLD_ANNUAL, - }.get(what) - - def get_season(self, what: str): - return { - "Holo": self.roles.SEASON_HOLO, - "Platinum": self.roles.SEASON_PLATINUM, - "Ruby": self.roles.SEASON_RUBY, - "Silver": self.roles.SEASON_SILVER, - "Bronze": self.roles.SEASON_BRONZE, - }.get(what) - - def get_cert(self, what: str): - return { - "CPTS": self.roles.ACADEMY_CPTS, - "CWES": self.roles.ACADEMY_CWES, - "CBBH": self.roles.ACADEMY_CWES, # For legacy reasons - "CDSA": self.roles.ACADEMY_CDSA, - "CWEE": self.roles.ACADEMY_CWEE, - "CAPE": self.roles.ACADEMY_CAPE, - "CJCA": self.roles.ACADEMY_CJCA, - "CWPE": self.roles.ACADEMY_CWPE - }.get(what) + # Helper methods (get_post_or_rank, get_season, get_cert, get_academy_cert_role) + # have been moved to RoleManager (src/services/role_manager.py). class Config: """The Pydantic settings configuration.""" @@ -299,37 +241,9 @@ def load_settings(env_file: str | None = None): global_settings.channels = Channels(_env_file=env_file) global_settings.roles = Roles(_env_file=env_file) - global_settings.roles_to_join = { - "Cyber Apocalypse": ( - global_settings.roles.UNICTF2022, - "Pinged for CTF Announcements", - ), - "Business CTF": ( - global_settings.roles.UNICTF2022, - "Pinged for CTF Announcements", - ), - "University CTF": ( - global_settings.roles.UNICTF2022, - "Pinged for CTF Announcements", - ), - "Noah Gang": ( - global_settings.roles.NOAH_GANG, - "Get pinged when Fugl posts pictures of his cute bird", - ), - "Buddy Gang": ( - global_settings.roles.BUDDY_GANG, - "Get pinged when Legacyy posts pictures of his cute dog", - ), - "Red Team": ( - global_settings.roles.RED_TEAM, - "Red team fans. Also gives access to the Red and Blue team channels", - ), - "Blue Team": ( - global_settings.roles.BLUE_TEAM, - "Blue team fans. Also gives access to the Red and Blue team channels", - ), - } - + # Core role groups (used in decorators at import time for permission checks). + # Dynamic role groups (ALL_RANKS, ALL_SEASON_RANKS, etc.) are now served by + # RoleManager.get_group_ids() from the database. global_settings.role_groups = { "ALL_ADMINS": [ global_settings.roles.ADMINISTRATOR, @@ -343,39 +257,6 @@ def load_settings(env_file: str | None = None): ], "ALL_HTB_STAFF": [global_settings.roles.HTB_STAFF], "ALL_HTB_SUPPORT": [global_settings.roles.HTB_SUPPORT], - "ALL_RANKS": [ - global_settings.roles.OMNISCIENT, - global_settings.roles.GURU, - global_settings.roles.ELITE_HACKER, - global_settings.roles.PRO_HACKER, - global_settings.roles.HACKER, - global_settings.roles.SCRIPT_KIDDIE, - global_settings.roles.NOOB, - ], - "ALL_SEASON_RANKS": [ - global_settings.roles.SEASON_HOLO, - global_settings.roles.SEASON_PLATINUM, - global_settings.roles.SEASON_RUBY, - global_settings.roles.SEASON_SILVER, - global_settings.roles.SEASON_BRONZE, - ], - "ALL_CREATORS": [ - global_settings.roles.BOX_CREATOR, - global_settings.roles.CHALLENGE_CREATOR, - global_settings.roles.SHERLOCK_CREATOR, - ], - "ALL_POSITIONS": [ - global_settings.roles.RANK_ONE, - global_settings.roles.RANK_TEN, - ], - "ALL_LABS_SUBSCRIPTIONS": [ - global_settings.roles.VIP, - global_settings.roles.VIP_PLUS, - ], - "ALL_ACADEMY_SUBSCRIPTIONS": [ - global_settings.roles.SILVER_ANNUAL, - global_settings.roles.GOLD_ANNUAL, - ] } return global_settings diff --git a/src/database/models/__init__.py b/src/database/models/__init__.py index 216cf22..15ab602 100644 --- a/src/database/models/__init__.py +++ b/src/database/models/__init__.py @@ -3,6 +3,7 @@ from .ban import Ban from .ctf import Ctf +from .dynamic_role import DynamicRole, RoleCategory from .htb_discord_link import HtbDiscordLink from .infraction import Infraction from .macro import Macro diff --git a/src/database/models/dynamic_role.py b/src/database/models/dynamic_role.py new file mode 100644 index 0000000..bf1341d --- /dev/null +++ b/src/database/models/dynamic_role.py @@ -0,0 +1,48 @@ +# flake8: noqa: D101 +import enum + +from sqlalchemy import Enum, Integer, String, UniqueConstraint +from sqlalchemy.dialects.mysql import BIGINT +from sqlalchemy.orm import Mapped, mapped_column + +from . import Base + + +class RoleCategory(str, enum.Enum): + """Categories for dynamic Discord roles.""" + RANK = "rank" + SEASON = "season" + SUBSCRIPTION_LABS = "subscription_labs" + SUBSCRIPTION_ACADEMY = "subscription_academy" + CREATOR = "creator" + POSITION = "position" + ACADEMY_CERT = "academy_cert" + JOINABLE = "joinable" + + +class DynamicRole(Base): + """ + A dynamically configured Discord role, managed via DB instead of env vars. + + Attributes: + id: Primary key. + key: Lookup key (e.g. "Omniscient", "CWPE", "vip"). + discord_role_id: The Discord snowflake ID for this role. + category: Which group this role belongs to. + display_name: Human-readable name shown in commands and embeds. + description: Optional description (used for joinable roles). + cert_full_name: Full certificate name from HTB platform (academy_cert only). + cert_integer_id: Platform certificate ID (academy_cert only). + """ + id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True) + key: Mapped[str] = mapped_column(String(64), nullable=False) + discord_role_id: Mapped[int] = mapped_column(BIGINT(unsigned=True), nullable=False) + category: Mapped[RoleCategory] = mapped_column(Enum(RoleCategory), nullable=False) + display_name: Mapped[str] = mapped_column(String(128), nullable=False) + description: Mapped[str] = mapped_column(String(256), nullable=True, default=None) + cert_full_name: Mapped[str] = mapped_column(String(128), nullable=True, default=None) + cert_integer_id: Mapped[int] = mapped_column(Integer, nullable=True, default=None) + + __table_args__ = ( + UniqueConstraint("key", "category", name="uq_dynamic_role_key_category"), + ) diff --git a/src/helpers/verification.py b/src/helpers/verification.py index fa1efdc..b3cf7d7 100644 --- a/src/helpers/verification.py +++ b/src/helpers/verification.py @@ -158,8 +158,17 @@ async def get_season_rank(htb_uid: int) -> str | None: return rank -async def process_certification(certid: str, name: str): - """Process certifications.""" +async def process_certification(certid: str, name: str, role_manager): + """Process certifications. + + Args: + certid: The certificate ID string (e.g. "HTBCERT-xxxx"). + name: The full name on the certificate. + role_manager: RoleManager instance for cert name lookups. + + Returns: + str | False: The cert abbreviation (e.g. "CPTS") or False if not found. + """ cert_api_url = f"{settings.API_V4_URL}/certificate/lookup" params = {"id": certid, "name": name} async with aiohttp.ClientSession() as session: @@ -174,27 +183,19 @@ async def process_certification(certid: str, name: str): ) response = {} try: - certRawName = response["certificates"][0]["name"] - except IndexError: + cert_raw_name = response["certificates"][0]["name"] + except (IndexError, KeyError): return False - if certRawName == "HTB Certified Bug Bounty Hunter": - cert = "CBBH" - elif certRawName == "HTB Certified Web Exploitation Specialist": - cert = "CBBH" - elif certRawName == "HTB Certified Penetration Testing Specialist": - cert = "CPTS" - elif certRawName == "HTB Certified Defensive Security Analyst": - cert = "CDSA" - elif certRawName == "HTB Certified Web Exploitation Expert": - cert = "CWEE" - elif certRawName == "HTB Certified Active Directory Pentesting Expert": - cert = "CAPE" - elif certRawName == "HTB Certified Junior Cybersecurity Associate": - cert = "CJCA" - elif certRawName == "HTB Certified Wi-Fi Pentesting Expert": - cert = "CWPE" - else: - cert = False + + # Look up cert abbreviation from the DB via full platform name + cert = role_manager.get_cert_abbrev_by_full_name(cert_raw_name) + if not cert: + # Handle legacy "Bug Bounty Hunter" name (renamed to CWES on the platform) + if "Bug Bounty Hunter" in cert_raw_name: + cert = "CBBH" + else: + logger.warning(f"Unknown certificate name: {cert_raw_name}") + return False return cert @@ -306,20 +307,21 @@ async def process_labs_identification( htb_user_details: dict, user: Optional[Member | User], bot: Bot ) -> Optional[List[Role]]: """Returns roles to assign if identification was successfully processed.""" - + # Resolve member and guild member, guild = await _resolve_member_and_guild(user, bot) - + role_manager = bot.role_manager + # Get roles to remove and assign - to_remove = _get_roles_to_remove(member, guild) - to_assign = await _process_role_assignments(htb_user_details, guild) - + to_remove = _get_roles_to_remove(member, guild, role_manager) + to_assign = await _process_role_assignments(htb_user_details, guild, role_manager) + # Remove roles that will be reassigned to_remove = list(set(to_remove) - set(to_assign)) - + # Apply role changes await _apply_role_changes(member, to_remove, to_assign) - + return to_assign @@ -340,12 +342,12 @@ async def _resolve_member_and_guild( raise GuildNotFound(f"Could not identify member {user} in guild.") -def _get_roles_to_remove(member: Member, guild: Guild) -> list[Role]: +def _get_roles_to_remove(member: Member, guild: Guild, role_manager) -> list[Role]: """Get existing roles that should be removed.""" to_remove = [] try: - all_ranks = settings.role_groups.get("ALL_RANKS", []) - all_positions = settings.role_groups.get("ALL_POSITIONS", []) + all_ranks = role_manager.get_group_ids("rank") if role_manager else [] + all_positions = role_manager.get_group_ids("position") if role_manager else [] removable_role_ids = all_ranks + all_positions for role in member.roles: @@ -359,35 +361,35 @@ def _get_roles_to_remove(member: Member, guild: Guild) -> list[Role]: async def _process_role_assignments( - htb_user_details: dict, guild: Guild + htb_user_details: dict, guild: Guild, role_manager ) -> list[Role]: """Process role assignments based on HTB user details.""" to_assign = [] - + # Process rank roles - to_assign.extend(_process_rank_roles(htb_user_details.get("rank", ""), guild)) - + to_assign.extend(_process_rank_roles(htb_user_details.get("rank", ""), guild, role_manager)) + # Process season rank roles - to_assign.extend(await _process_season_rank_roles(htb_user_details.get("id", ""), guild)) - + to_assign.extend(await _process_season_rank_roles(htb_user_details.get("id", ""), guild, role_manager)) + # Process VIP roles - to_assign.extend(_process_vip_roles(htb_user_details, guild)) - + to_assign.extend(_process_vip_roles(htb_user_details, guild, role_manager)) + # Process HOF position roles - to_assign.extend(_process_hof_position_roles(htb_user_details.get("ranking", "unranked"), guild)) - + to_assign.extend(_process_hof_position_roles(htb_user_details.get("ranking", "unranked"), guild, role_manager)) + # Process creator roles - to_assign.extend(_process_creator_roles(htb_user_details.get("content", {}), guild)) - + to_assign.extend(_process_creator_roles(htb_user_details.get("content", {}), guild, role_manager)) + return to_assign -def _process_rank_roles(rank: str, guild: Guild) -> list[Role]: +def _process_rank_roles(rank: str, guild: Guild, role_manager) -> list[Role]: """Process rank-based role assignments.""" roles = [] - + if rank and rank not in ["Deleted", "Moderator", "Ambassador", "Admin", "Staff"]: - role_id = settings.get_post_or_rank(rank) + role_id = role_manager.get_post_or_rank(rank) if role_manager else None if role_id: role = guild.get_role(role_id) if role: @@ -396,13 +398,13 @@ def _process_rank_roles(rank: str, guild: Guild) -> list[Role]: return roles -async def _process_season_rank_roles(mp_user_id: int, guild: Guild) -> list[Role]: +async def _process_season_rank_roles(mp_user_id: int, guild: Guild, role_manager) -> list[Role]: """Process season rank role assignments.""" roles = [] try: season_rank = await get_season_rank(mp_user_id) if isinstance(season_rank, str): - season_role_id = settings.get_season(season_rank) + season_role_id = role_manager.get_season_role_id(season_rank) if role_manager else None if season_role_id: season_role = guild.get_role(season_role_id) if season_role: @@ -412,17 +414,19 @@ async def _process_season_rank_roles(mp_user_id: int, guild: Guild) -> list[Role return roles -def _process_vip_roles(htb_user_details: dict, guild: Guild) -> list[Role]: +def _process_vip_roles(htb_user_details: dict, guild: Guild, role_manager) -> list[Role]: """Process VIP role assignments.""" roles = [] try: if htb_user_details.get("isVip", False): - vip_role = guild.get_role(settings.roles.VIP) + vip_id = role_manager.get_role_id("subscription_labs", "vip") if role_manager else None + vip_role = guild.get_role(vip_id) if vip_id else None if vip_role: roles.append(vip_role) - + if htb_user_details.get("isDedicatedVip", False): - vip_plus_role = guild.get_role(settings.roles.VIP_PLUS) + vip_plus_id = role_manager.get_role_id("subscription_labs", "dedivip") if role_manager else None + vip_plus_role = guild.get_role(vip_plus_id) if vip_plus_id else None if vip_plus_role: roles.append(vip_plus_role) except Exception as e: @@ -430,7 +434,7 @@ def _process_vip_roles(htb_user_details: dict, guild: Guild) -> list[Role]: return roles -def _process_hof_position_roles(htb_user_ranking: str | int, guild: Guild) -> list[Role]: +def _process_hof_position_roles(htb_user_ranking: str | int, guild: Guild, role_manager) -> list[Role]: """Process Hall of Fame position role assignments.""" roles = [] try: @@ -439,9 +443,9 @@ def _process_hof_position_roles(htb_user_ranking: str | int, guild: Guild) -> li if hof_position != "unranked": position = int(hof_position) pos_top = _get_position_tier(position) - + if pos_top: - pos_role_id = settings.get_post_or_rank(pos_top) + pos_role_id = role_manager.get_post_or_rank(pos_top) if role_manager else None if pos_role_id: pos_role = guild.get_role(pos_role_id) if pos_role: @@ -460,24 +464,27 @@ def _get_position_tier(position: int) -> Optional[str]: return None -def _process_creator_roles(htb_user_content: dict, guild: Guild) -> list[Role]: +def _process_creator_roles(htb_user_content: dict, guild: Guild, role_manager) -> list[Role]: """Process creator role assignments.""" roles = [] try: if htb_user_content.get("machines"): - box_creator_role = guild.get_role(settings.roles.BOX_CREATOR) + box_id = role_manager.get_role_id("creator", "Box Creator") if role_manager else None + box_creator_role = guild.get_role(box_id) if box_id else None if box_creator_role: logger.debug("Adding box creator role to user.") roles.append(box_creator_role) - + if htb_user_content.get("challenges"): - challenge_creator_role = guild.get_role(settings.roles.CHALLENGE_CREATOR) + challenge_id = role_manager.get_role_id("creator", "Challenge Creator") if role_manager else None + challenge_creator_role = guild.get_role(challenge_id) if challenge_id else None if challenge_creator_role: logger.debug("Adding challenge creator role to user.") roles.append(challenge_creator_role) - + if htb_user_content.get("sherlocks"): - sherlock_creator_role = guild.get_role(settings.roles.SHERLOCK_CREATOR) + sherlock_id = role_manager.get_role_id("creator", "Sherlock Creator") if role_manager else None + sherlock_creator_role = guild.get_role(sherlock_id) if sherlock_id else None if sherlock_creator_role: logger.debug("Adding sherlock creator role to user.") roles.append(sherlock_creator_role) diff --git a/src/services/__init__.py b/src/services/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/services/role_manager.py b/src/services/role_manager.py new file mode 100644 index 0000000..3167d8f --- /dev/null +++ b/src/services/role_manager.py @@ -0,0 +1,267 @@ +import logging +from typing import Optional + +from sqlalchemy import delete, select + +from src.database.models.dynamic_role import DynamicRole, RoleCategory +from src.database.session import AsyncSessionLocal + +logger = logging.getLogger(__name__) + +# Legacy alias: CBBH was renamed to CWES +_CERT_ALIASES = {"CBBH": "CWES"} + +# Explicit search order for get_post_or_rank to avoid ambiguity on key collisions. +_POST_OR_RANK_SEARCH_ORDER = [ + RoleCategory.POSITION, + RoleCategory.RANK, + RoleCategory.SUBSCRIPTION_LABS, + RoleCategory.SUBSCRIPTION_ACADEMY, + RoleCategory.CREATOR, +] + +# Mapping from env var field names to (category, key) for dual-read fallback. +# Only used during transition period; will be removed in follow-up PR. +_ENV_FALLBACK_MAP: dict[str, tuple[RoleCategory, str]] = { + "OMNISCIENT": (RoleCategory.RANK, "Omniscient"), + "GURU": (RoleCategory.RANK, "Guru"), + "ELITE_HACKER": (RoleCategory.RANK, "Elite Hacker"), + "PRO_HACKER": (RoleCategory.RANK, "Pro Hacker"), + "HACKER": (RoleCategory.RANK, "Hacker"), + "SCRIPT_KIDDIE": (RoleCategory.RANK, "Script Kiddie"), + "NOOB": (RoleCategory.RANK, "Noob"), + "VIP": (RoleCategory.SUBSCRIPTION_LABS, "vip"), + "VIP_PLUS": (RoleCategory.SUBSCRIPTION_LABS, "dedivip"), + "SILVER_ANNUAL": (RoleCategory.SUBSCRIPTION_ACADEMY, "Silver Annual"), + "GOLD_ANNUAL": (RoleCategory.SUBSCRIPTION_ACADEMY, "Gold Annual"), + "BOX_CREATOR": (RoleCategory.CREATOR, "Box Creator"), + "CHALLENGE_CREATOR": (RoleCategory.CREATOR, "Challenge Creator"), + "SHERLOCK_CREATOR": (RoleCategory.CREATOR, "Sherlock Creator"), + "RANK_ONE": (RoleCategory.POSITION, "1"), + "RANK_TEN": (RoleCategory.POSITION, "10"), + "SEASON_HOLO": (RoleCategory.SEASON, "Holo"), + "SEASON_PLATINUM": (RoleCategory.SEASON, "Platinum"), + "SEASON_RUBY": (RoleCategory.SEASON, "Ruby"), + "SEASON_SILVER": (RoleCategory.SEASON, "Silver"), + "SEASON_BRONZE": (RoleCategory.SEASON, "Bronze"), + "ACADEMY_CWES": (RoleCategory.ACADEMY_CERT, "CWES"), + "ACADEMY_CPTS": (RoleCategory.ACADEMY_CERT, "CPTS"), + "ACADEMY_CDSA": (RoleCategory.ACADEMY_CERT, "CDSA"), + "ACADEMY_CWEE": (RoleCategory.ACADEMY_CERT, "CWEE"), + "ACADEMY_CAPE": (RoleCategory.ACADEMY_CERT, "CAPE"), + "ACADEMY_CJCA": (RoleCategory.ACADEMY_CERT, "CJCA"), + "ACADEMY_CWPE": (RoleCategory.ACADEMY_CERT, "CWPE"), +} + + +class RoleManager: + """Manages dynamic Discord roles backed by the database with an in-memory cache.""" + + def __init__(self, fallback_roles=None): + self._roles: dict[str, dict[str, DynamicRole]] = {} + self._cert_by_integer_id: dict[int, DynamicRole] = {} + self._cert_by_full_name: dict[str, DynamicRole] = {} + self._fallback_roles = fallback_roles + self._loaded = False + + async def load(self) -> None: + """Load all dynamic roles from DB into memory. + + On failure: if a previous cache exists, keeps it and logs a warning. + If no previous cache (first load), raises the exception. + """ + try: + async with AsyncSessionLocal() as session: + result = await session.scalars(select(DynamicRole)) + roles = result.all() + + new_roles: dict[str, dict[str, DynamicRole]] = {} + new_cert_by_id: dict[int, DynamicRole] = {} + new_cert_by_name: dict[str, DynamicRole] = {} + + for role in roles: + cat = role.category.value + if cat not in new_roles: + new_roles[cat] = {} + new_roles[cat][role.key] = role + + if role.category == RoleCategory.ACADEMY_CERT: + if role.cert_integer_id is not None: + new_cert_by_id[role.cert_integer_id] = role + if role.cert_full_name: + new_cert_by_name[role.cert_full_name] = role + + # Apply env var fallbacks for any missing roles (transition period) + if self._fallback_roles: + self._apply_fallbacks(new_roles) + + self._roles = new_roles + self._cert_by_integer_id = new_cert_by_id + self._cert_by_full_name = new_cert_by_name + self._loaded = True + + total = sum(len(v) for v in self._roles.values()) + logger.info(f"Loaded {total} dynamic roles from database") + + except Exception: + if self._loaded: + logger.warning("Failed to reload dynamic roles from DB, keeping previous cache", exc_info=True) + else: + logger.error("Failed to load dynamic roles from DB on first attempt", exc_info=True) + raise + + def _apply_fallbacks(self, roles: dict[str, dict[str, DynamicRole]]) -> None: + """Fill in missing roles from env var fallback values (transition period only).""" + for field_name, (category, key) in _ENV_FALLBACK_MAP.items(): + cat = category.value + if cat in roles and key in roles[cat]: + continue + + value = getattr(self._fallback_roles, field_name, None) + if value is None: + continue + + if cat not in roles: + roles[cat] = {} + + # Create a lightweight stand-in (not persisted to DB) + fallback = DynamicRole() + fallback.key = key + fallback.discord_role_id = value + fallback.category = category + fallback.display_name = key + roles[cat][key] = fallback + logger.debug(f"Using env var fallback for dynamic role: {category.value}/{key}") + + async def reload(self) -> None: + """Reload roles from DB. Alias for load(), used after CRUD operations.""" + await self.load() + + # ── Single role lookups ────────────────────────────────────────── + + def get_role_id(self, category: str, key: str) -> Optional[int]: + """Get a single role's Discord ID by category and key. Returns None if not configured.""" + role = self._roles.get(category, {}).get(key) + return role.discord_role_id if role else None + + def get_cert_role_id(self, cert_abbrev: str) -> Optional[int]: + """Get academy cert role by abbreviation (CPTS, CWES, etc.). Handles CBBH legacy alias.""" + resolved = _CERT_ALIASES.get(cert_abbrev, cert_abbrev) + return self.get_role_id(RoleCategory.ACADEMY_CERT.value, resolved) + + def get_academy_cert_role(self, certificate_id: int) -> Optional[int]: + """Get academy cert role by platform integer ID (2, 3, 4, etc.).""" + role = self._cert_by_integer_id.get(certificate_id) + return role.discord_role_id if role else None + + def get_cert_abbrev_by_full_name(self, full_name: str) -> Optional[str]: + """Get cert abbreviation by full platform name. Replaces the hardcoded process_certification() mapping.""" + role = self._cert_by_full_name.get(full_name) + return role.key if role else None + + def get_rank_role_id(self, rank_name: str) -> Optional[int]: + """Get rank role by name (Omniscient, Guru, etc.).""" + return self.get_role_id(RoleCategory.RANK.value, rank_name) + + def get_season_role_id(self, tier: str) -> Optional[int]: + """Get season tier role (Holo, Platinum, etc.).""" + return self.get_role_id(RoleCategory.SEASON.value, tier) + + # ── Cross-category lookup ──────────────────────────────────────── + + def get_post_or_rank(self, what: str) -> Optional[int]: + """Replaces settings.get_post_or_rank(). Searches position, rank, subscriptions, creator.""" + for category in _POST_OR_RANK_SEARCH_ORDER: + role_id = self.get_role_id(category.value, what) + if role_id is not None: + return role_id + return None + + # ── Group lookups ──────────────────────────────────────────────── + + def get_group_ids(self, category: str) -> list[int]: + """Get all Discord role IDs for a category. Returns [] if none configured.""" + return [r.discord_role_id for r in self._roles.get(category, {}).values()] + + # ── Joinable roles ─────────────────────────────────────────────── + + def get_joinable_roles(self) -> dict[str, tuple[int, str]]: + """Get joinable roles as {display_name: (discord_role_id, description)}.""" + joinable = self._roles.get(RoleCategory.JOINABLE.value, {}) + return { + r.display_name: (r.discord_role_id, r.description or "") + for r in joinable.values() + } + + # ── CRUD operations ────────────────────────────────────────────── + + async def add_role( + self, + key: str, + category: RoleCategory, + discord_role_id: int, + display_name: str, + description: str | None = None, + cert_full_name: str | None = None, + cert_integer_id: int | None = None, + ) -> DynamicRole: + """Add a new dynamic role to the database and reload cache.""" + role = DynamicRole() + role.key = key + role.category = category + role.discord_role_id = discord_role_id + role.display_name = display_name + role.description = description + role.cert_full_name = cert_full_name + role.cert_integer_id = cert_integer_id + + async with AsyncSessionLocal() as session: + session.add(role) + await session.commit() + await session.refresh(role) + + await self.reload() + return role + + async def remove_role(self, category: RoleCategory, key: str) -> bool: + """Remove a dynamic role from the database and reload cache. Returns True if found.""" + async with AsyncSessionLocal() as session: + stmt = delete(DynamicRole).where( + DynamicRole.key == key, + DynamicRole.category == category, + ) + result = await session.execute(stmt) + await session.commit() + deleted = result.rowcount > 0 + + if deleted: + await self.reload() + return deleted + + async def update_role(self, category: RoleCategory, key: str, discord_role_id: int) -> Optional[DynamicRole]: + """Update a role's Discord ID. Returns updated role or None if not found.""" + async with AsyncSessionLocal() as session: + stmt = select(DynamicRole).where( + DynamicRole.key == key, + DynamicRole.category == category, + ) + result = await session.scalars(stmt) + role = result.first() + if not role: + return None + role.discord_role_id = discord_role_id + await session.commit() + await session.refresh(role) + + await self.reload() + return role + + async def list_roles(self, category: RoleCategory | None = None) -> list[DynamicRole]: + """List all dynamic roles, optionally filtered by category.""" + async with AsyncSessionLocal() as session: + stmt = select(DynamicRole) + if category: + stmt = stmt.where(DynamicRole.category == category) + stmt = stmt.order_by(DynamicRole.category, DynamicRole.key) + result = await session.scalars(stmt) + return list(result.all()) diff --git a/src/webhooks/handlers/academy.py b/src/webhooks/handlers/academy.py index 73b2a89..46acb0b 100644 --- a/src/webhooks/handlers/academy.py +++ b/src/webhooks/handlers/academy.py @@ -33,24 +33,21 @@ async def _handle_certificate_awarded(self, body: WebhookBody, bot: Bot) -> dict self.logger.info(f"Handling certificate awarded event for {discord_id} with certificate {certificate_id}") member = await self.get_guild_member(discord_id, bot) - certificate_role_id = settings.get_academy_cert_role(int(certificate_id)) + certificate_role_id = bot.role_manager.get_academy_cert_role(int(certificate_id)) if not certificate_role_id: self.logger.warning(f"No certificate role found for certificate {certificate_id}") return self.fail() - if certificate_role_id: - self.logger.info(f"Adding certificate role {certificate_role_id} to member {member.id}") - try: - await member.add_roles( - bot.guilds[0].get_role(certificate_role_id), atomic=True # type: ignore - ) # type: ignore - - except Exception as e: - self.logger.error(f"Error adding certificate role {certificate_role_id} to member {member.id}: {e}") - raise e + self.logger.info(f"Adding certificate role {certificate_role_id} to member {member.id}") + try: + await member.add_roles( + bot.guilds[0].get_role(certificate_role_id), atomic=True # type: ignore + ) # type: ignore + except Exception as e: + self.logger.error(f"Error adding certificate role {certificate_role_id} to member {member.id}: {e}") + raise e - # Safely attempt to send verification log only after role addition try: verify_channel = bot.guilds[0].get_channel(settings.channels.VERIFY_LOGS) if verify_channel: @@ -67,7 +64,6 @@ async def _handle_certificate_awarded(self, body: WebhookBody, bot: Bot) -> dict self.logger.error(f"Failed to send certificate verification log: {e}") raise e - return self.success() async def _handle_subscription_change(self, body: WebhookBody, bot: Bot) -> dict: @@ -80,13 +76,12 @@ async def _handle_subscription_change(self, body: WebhookBody, bot: Bot) -> dict self.logger.info(f"Handling subscription change event for {discord_id} with plan {plan}") member = await self.get_guild_member(discord_id, bot) - subscription_role_id = settings.get_post_or_rank(plan) + subscription_role_id = bot.role_manager.get_post_or_rank(plan) if not subscription_role_id: self.logger.warning(f"No subscription role found for plan {plan}") return self.fail() - # Use the base handler's role swapping method - role_group = [int(r) for r in settings.role_groups["ALL_ACADEMY_SUBSCRIPTIONS"]] + role_group = bot.role_manager.get_group_ids("subscription_academy") await self.swap_role_in_group(member, subscription_role_id, role_group, bot) - return self.success() \ No newline at end of file + return self.success() diff --git a/src/webhooks/handlers/mp.py b/src/webhooks/handlers/mp.py index bc8a842..17d2de5 100644 --- a/src/webhooks/handlers/mp.py +++ b/src/webhooks/handlers/mp.py @@ -2,7 +2,6 @@ from typing import Literal -from src.core import settings from src.webhooks.handlers.base import BaseHandler from src.webhooks.types import WebhookBody, WebhookEvent @@ -37,12 +36,11 @@ async def _handle_subscription_change(self, body: WebhookBody, bot: Bot) -> dict member = await self.get_guild_member(discord_id, bot) - subscription_id = settings.get_post_or_rank(subscription_name) + subscription_id = bot.role_manager.get_post_or_rank(subscription_name) if not subscription_id: raise ValueError(f"Invalid subscription name: {subscription_name}") - # Use the base handler's role swapping method - role_group = [int(r) for r in settings.role_groups["ALL_LABS_SUBSCRIPTIONS"]] + role_group = bot.role_manager.get_group_ids("subscription_labs") await self.swap_role_in_group(member, subscription_id, role_group, bot) return self.success() @@ -57,9 +55,12 @@ async def _handle_hof_change(self, body: WebhookBody, bot: Bot) -> dict: self.get_property_or_trait(body, "hof_tier"), "hof_tier", # type: ignore ) + + rank_one_id = bot.role_manager.get_role_id("position", "1") + rank_ten_id = bot.role_manager.get_role_id("position", "10") hof_roles = { - "1": bot.guilds[0].get_role(settings.roles.RANK_ONE), - "10": bot.guilds[0].get_role(settings.roles.RANK_TEN), + "1": bot.guilds[0].get_role(rank_one_id) if rank_one_id else None, + "10": bot.guilds[0].get_role(rank_ten_id) if rank_ten_id else None, } member = await self.get_guild_member(discord_id, bot) @@ -124,7 +125,7 @@ async def _handle_rank_up(self, body: WebhookBody, bot: Bot) -> dict: member = await self.get_guild_member(discord_id, bot) - rank_id = settings.get_post_or_rank(rank) + rank_id = bot.role_manager.get_post_or_rank(rank) if not rank_id: err = ValueError(f"Cannot find role for '{rank}'") self.logger.error( @@ -137,10 +138,9 @@ async def _handle_rank_up(self, body: WebhookBody, bot: Bot) -> dict: ) raise err - # Use the base handler's role swapping method - role_group = [int(r) for r in settings.role_groups["ALL_RANKS"]] + role_group = bot.role_manager.get_group_ids("rank") changes_made = await self.swap_role_in_group(member, rank_id, role_group, bot) - + if not changes_made: return self.success() # No changes needed @@ -155,7 +155,7 @@ async def _handle_season_rank(self, body: WebhookBody, bot: Bot) -> dict: self.get_property_or_trait(body, "season_rank"), "season_rank" ) - season_role_id = settings.get_season(season_rank) + season_role_id = bot.role_manager.get_season_role_id(season_rank) if not season_role_id: err = ValueError(f"Cannot find role for '{season_rank}'") self.logger.error( @@ -170,8 +170,7 @@ async def _handle_season_rank(self, body: WebhookBody, bot: Bot) -> dict: member = await self.get_guild_member(discord_id, bot) - # Use the base handler's role swapping method - role_group = [int(r) for r in settings.role_groups["ALL_SEASON_RANKS"]] + role_group = bot.role_manager.get_group_ids("season") await self.swap_role_in_group(member, season_role_id, role_group, bot) return self.success() \ No newline at end of file diff --git a/tests/helpers.py b/tests/helpers.py index b64c709..62cb583 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -166,7 +166,7 @@ def __init__(self, roles: Optional[Iterable[MockRole]] = None, **kwargs) -> None # Create a Role instance to get a realistic Mock of `discord.Role` -role_data = {'name': 'role', 'id': 1} +role_data = {'name': 'role', 'id': 1, 'colors': {'primary_color': 0}} # noinspection PyTypeChecker role_instance = discord.Role(guild=guild_instance, state=mock.MagicMock(), data=role_data) @@ -277,6 +277,58 @@ def mock_create_task(coroutine): return loop +class MockRoleManager: + """A mock RoleManager that returns sensible defaults for all lookups.""" + + def __init__(self): + self._loaded = True + + def get_role_id(self, category, key): + return None + + def get_cert_role_id(self, cert_abbrev): + return None + + def get_academy_cert_role(self, certificate_id): + return None + + def get_cert_abbrev_by_full_name(self, full_name): + return None + + def get_rank_role_id(self, rank_name): + return None + + def get_season_role_id(self, tier): + return None + + def get_post_or_rank(self, what): + return None + + def get_group_ids(self, category): + return [] + + def get_joinable_roles(self): + return {} + + async def reload(self): + pass + + async def load(self): + pass + + async def add_role(self, **kwargs): + return None + + async def remove_role(self, category, key): + return False + + async def update_role(self, category, key, discord_role_id): + return None + + async def list_roles(self, category=None): + return [] + + class MockBot(CustomMockMixin, mock.MagicMock): """ A MagicMock subclass to mock Bot objects. @@ -292,6 +344,7 @@ def __init__(self, **kwargs) -> None: self.loop = _get_mock_loop() self.http_session = mock.create_autospec(spec=ClientSession, spec_set=True) + self.role_manager = MockRoleManager() # Create a TextChannel instance to get a realistic MagicMock of `discord.TextChannel` diff --git a/tests/src/cmds/core/test_role_admin.py b/tests/src/cmds/core/test_role_admin.py new file mode 100644 index 0000000..0014648 --- /dev/null +++ b/tests/src/cmds/core/test_role_admin.py @@ -0,0 +1,227 @@ +from unittest.mock import AsyncMock, MagicMock + +import pytest + +from src.cmds.core import role_admin +from src.cmds.core.role_admin import RoleAdminCog +from src.database.models.dynamic_role import RoleCategory +from tests import helpers + + +class TestRoleAdminCog: + """Test the `RoleAdminCog` cog.""" + + # ── add command ────────────────────────────────────────────────── + + @pytest.mark.asyncio + async def test_add_success(self, ctx, bot): + """Test adding a dynamic role with valid parameters.""" + bot.role_manager.add_role = AsyncMock() + discord_role = helpers.MockRole(id=9999, name="TestRole") + + cog = RoleAdminCog(bot) + await cog.add.callback( + cog, + ctx, + category="rank", + key="Omniscient", + role=discord_role, + display_name="Omniscient Rank", + description=None, + cert_full_name=None, + cert_integer_id=None, + ) + + bot.role_manager.add_role.assert_awaited_once_with( + key="Omniscient", + category=RoleCategory.RANK, + discord_role_id=discord_role.id, + display_name="Omniscient Rank", + description=None, + cert_full_name=None, + cert_integer_id=None, + ) + ctx.respond.assert_awaited_once() + assert "Added dynamic role" in ctx.respond.call_args[0][0] + + @pytest.mark.asyncio + async def test_add_invalid_category(self, ctx, bot): + """Test adding a role with an invalid category returns an error.""" + discord_role = helpers.MockRole(id=9999, name="TestRole") + + cog = RoleAdminCog(bot) + await cog.add.callback( + cog, + ctx, + category="not_a_real_category", + key="Key", + role=discord_role, + display_name="Display", + description=None, + cert_full_name=None, + cert_integer_id=None, + ) + + ctx.respond.assert_awaited_once() + assert "Invalid category" in ctx.respond.call_args[0][0] + + # ── remove command ─────────────────────────────────────────────── + + @pytest.mark.asyncio + async def test_remove_success(self, ctx, bot): + """Test removing an existing dynamic role.""" + bot.role_manager.remove_role = AsyncMock(return_value=True) + + cog = RoleAdminCog(bot) + await cog.remove.callback(cog, ctx, category="rank", key="Omniscient") + + bot.role_manager.remove_role.assert_awaited_once_with(RoleCategory.RANK, "Omniscient") + ctx.respond.assert_awaited_once() + assert "Removed dynamic role" in ctx.respond.call_args[0][0] + + @pytest.mark.asyncio + async def test_remove_not_found(self, ctx, bot): + """Test removing a non-existent dynamic role.""" + bot.role_manager.remove_role = AsyncMock(return_value=False) + + cog = RoleAdminCog(bot) + await cog.remove.callback(cog, ctx, category="rank", key="NonExistent") + + ctx.respond.assert_awaited_once() + assert "No role found" in ctx.respond.call_args[0][0] + + @pytest.mark.asyncio + async def test_remove_invalid_category(self, ctx, bot): + """Test removing with an invalid category returns an error.""" + cog = RoleAdminCog(bot) + await cog.remove.callback(cog, ctx, category="bad_cat", key="Key") + + ctx.respond.assert_awaited_once() + assert "Invalid category" in ctx.respond.call_args[0][0] + + # ── update command ─────────────────────────────────────────────── + + @pytest.mark.asyncio + async def test_update_success(self, ctx, bot): + """Test updating an existing dynamic role.""" + bot.role_manager.update_role = AsyncMock(return_value=True) + new_role = helpers.MockRole(id=7777, name="NewRole") + + cog = RoleAdminCog(bot) + await cog.update.callback(cog, ctx, category="rank", key="Omniscient", role=new_role) + + bot.role_manager.update_role.assert_awaited_once_with(RoleCategory.RANK, "Omniscient", new_role.id) + ctx.respond.assert_awaited_once() + assert "Updated" in ctx.respond.call_args[0][0] + + @pytest.mark.asyncio + async def test_update_not_found(self, ctx, bot): + """Test updating a non-existent dynamic role.""" + bot.role_manager.update_role = AsyncMock(return_value=None) + new_role = helpers.MockRole(id=7777, name="NewRole") + + cog = RoleAdminCog(bot) + await cog.update.callback(cog, ctx, category="rank", key="NonExistent", role=new_role) + + ctx.respond.assert_awaited_once() + assert "No role found" in ctx.respond.call_args[0][0] + + @pytest.mark.asyncio + async def test_update_invalid_category(self, ctx, bot): + """Test updating with an invalid category returns an error.""" + new_role = helpers.MockRole(id=7777, name="NewRole") + + cog = RoleAdminCog(bot) + await cog.update.callback(cog, ctx, category="bad_cat", key="Key", role=new_role) + + ctx.respond.assert_awaited_once() + assert "Invalid category" in ctx.respond.call_args[0][0] + + # ── list command ───────────────────────────────────────────────── + + @pytest.mark.asyncio + async def test_list_no_roles(self, ctx, bot): + """Test listing when no dynamic roles are configured.""" + bot.role_manager.list_roles = AsyncMock(return_value=[]) + + cog = RoleAdminCog(bot) + await cog.list.callback(cog, ctx, category=None) + + ctx.respond.assert_awaited_once() + assert "No dynamic roles configured" in ctx.respond.call_args[0][0] + + @pytest.mark.asyncio + async def test_list_with_roles(self, ctx, bot): + """Test listing dynamic roles grouped by category.""" + mock_role_entry = MagicMock() + mock_role_entry.category = RoleCategory.RANK + mock_role_entry.key = "Omniscient" + mock_role_entry.discord_role_id = 1234 + mock_role_entry.display_name = "Omniscient Rank" + + guild_role = helpers.MockRole(id=1234, name="Omniscient") + ctx.guild.get_role = MagicMock(return_value=guild_role) + + bot.role_manager.list_roles = AsyncMock(return_value=[mock_role_entry]) + + cog = RoleAdminCog(bot) + await cog.list.callback(cog, ctx, category=None) + + ctx.respond.assert_awaited_once() + call_kwargs = ctx.respond.call_args[1] + embed = call_kwargs["embed"] + assert embed.title == "Dynamic Roles" + assert len(embed.fields) == 1 + assert embed.fields[0].name == "rank" + + @pytest.mark.asyncio + async def test_list_with_category_filter(self, ctx, bot): + """Test listing dynamic roles filtered by a specific category.""" + bot.role_manager.list_roles = AsyncMock(return_value=[]) + + cog = RoleAdminCog(bot) + await cog.list.callback(cog, ctx, category="rank") + + bot.role_manager.list_roles.assert_awaited_once_with(RoleCategory.RANK) + + @pytest.mark.asyncio + async def test_list_role_not_in_guild(self, ctx, bot): + """Test listing when a role ID does not exist in the guild (shows raw ID).""" + mock_role_entry = MagicMock() + mock_role_entry.category = RoleCategory.RANK + mock_role_entry.key = "Ghost" + mock_role_entry.discord_role_id = 99999 + mock_role_entry.display_name = "Ghost Rank" + + ctx.guild.get_role = MagicMock(return_value=None) + bot.role_manager.list_roles = AsyncMock(return_value=[mock_role_entry]) + + cog = RoleAdminCog(bot) + await cog.list.callback(cog, ctx, category=None) + + ctx.respond.assert_awaited_once() + call_kwargs = ctx.respond.call_args[1] + embed = call_kwargs["embed"] + # When guild role is not found, the raw ID should be shown + assert "99999" in embed.fields[0].value + + # ── reload command ─────────────────────────────────────────────── + + @pytest.mark.asyncio + async def test_reload(self, ctx, bot): + """Test the reload command calls role_manager.reload.""" + bot.role_manager.reload = AsyncMock() + + cog = RoleAdminCog(bot) + await cog.reload.callback(cog, ctx) + + bot.role_manager.reload.assert_awaited_once() + ctx.respond.assert_awaited_once() + assert "reloaded" in ctx.respond.call_args[0][0].lower() + + # ── setup ──────────────────────────────────────────────────────── + + def test_setup(self, bot): + """Test the setup function registers the cog.""" + role_admin.setup(bot) + bot.add_cog.assert_called_once() diff --git a/tests/src/cmds/core/test_user.py b/tests/src/cmds/core/test_user.py index bf3fb18..350de82 100644 --- a/tests/src/cmds/core/test_user.py +++ b/tests/src/cmds/core/test_user.py @@ -90,6 +90,177 @@ async def test_user_stats(self, ctx, bot): assert fields["Bots"] == "1" assert "66.67%" in fields["Verified Members"] + # ── _match_role tests ────────────────────────────────────────────── + + def test_match_role_no_role_manager(self, bot): + """Test _match_role when role_manager is None.""" + bot.role_manager = None + cog = user.UserCog(bot) + role_id, err = cog._match_role("anything") + assert role_id is None + assert "don't know what role" in err + + def test_match_role_no_match(self, bot): + """Test _match_role when no joinable role matches.""" + bot.role_manager.get_joinable_roles = lambda: {"Alpha": (111, "Alpha desc")} + cog = user.UserCog(bot) + role_id, err = cog._match_role("zzz_nonexistent") + assert role_id is None + assert "don't know what role" in err + + def test_match_role_single_match(self, bot): + """Test _match_role with exactly one match.""" + bot.role_manager.get_joinable_roles = lambda: {"PenTesting": (222, "Pen desc")} + cog = user.UserCog(bot) + role_id, err = cog._match_role("pentest") + assert role_id == 222 + assert err is None + + def test_match_role_multiple_matches(self, bot): + """Test _match_role when multiple roles match.""" + bot.role_manager.get_joinable_roles = lambda: { + "Red Team": (111, "desc"), + "Red Alert": (222, "desc"), + } + cog = user.UserCog(bot) + role_id, err = cog._match_role("Red") + assert role_id is None + assert "multiple roles" in err.lower() + + # ── join command tests ─────────────────────────────────────────── + + @pytest.mark.asyncio + async def test_join_empty_role_name_shows_list(self, ctx, bot): + """Test /join with empty role name shows the joinable roles embed.""" + bot.role_manager.get_joinable_roles = lambda: {"Alpha": (111, "Alpha desc")} + cog = user.UserCog(bot) + await cog.join.callback(cog, ctx, role_name="") + + ctx.respond.assert_awaited_once() + call_kwargs = ctx.respond.call_args[1] + embed = call_kwargs["embed"] + assert embed.author.name == "Join-able Roles" + + @pytest.mark.asyncio + async def test_join_whitespace_role_name_shows_list(self, ctx, bot): + """Test /join with whitespace-only role name shows the joinable roles embed.""" + bot.role_manager.get_joinable_roles = lambda: {} + cog = user.UserCog(bot) + await cog.join.callback(cog, ctx, role_name=" ") + + ctx.respond.assert_awaited_once() + call_kwargs = ctx.respond.call_args[1] + assert "embed" in call_kwargs + + @pytest.mark.asyncio + async def test_join_unknown_role(self, ctx, bot): + """Test /join with an unknown role name returns an error.""" + bot.role_manager.get_joinable_roles = lambda: {"Alpha": (111, "Alpha desc")} + cog = user.UserCog(bot) + await cog.join.callback(cog, ctx, role_name="zzz_no_match") + + ctx.respond.assert_awaited_once() + assert "don't know what role" in ctx.respond.call_args[0][0] + + @pytest.mark.asyncio + async def test_join_success(self, ctx, bot): + """Test /join successfully adds the role.""" + guild_role = helpers.MockRole(id=111, name="Alpha") + ctx.guild.get_role = lambda rid: guild_role if rid == 111 else None + ctx.user = helpers.MockMember() + ctx.user.add_roles = AsyncMock() + bot.role_manager.get_joinable_roles = lambda: {"Alpha": (111, "Alpha desc")} + + cog = user.UserCog(bot) + await cog.join.callback(cog, ctx, role_name="Alpha") + + ctx.user.add_roles.assert_awaited_once_with(guild_role) + ctx.respond.assert_awaited_once() + assert "Welcome to Alpha" in ctx.respond.call_args[0][0] + + # ── leave command tests ────────────────────────────────────────── + + @pytest.mark.asyncio + async def test_leave_unknown_role(self, ctx, bot): + """Test /leave with an unknown role name returns an error.""" + bot.role_manager.get_joinable_roles = lambda: {"Alpha": (111, "Alpha desc")} + cog = user.UserCog(bot) + await cog.leave.callback(cog, ctx, role_name="zzz_no_match") + + ctx.respond.assert_awaited_once() + assert "don't know what role" in ctx.respond.call_args[0][0] + + @pytest.mark.asyncio + async def test_leave_success(self, ctx, bot): + """Test /leave successfully removes the role.""" + guild_role = helpers.MockRole(id=111, name="Alpha") + ctx.guild.get_role = lambda rid: guild_role if rid == 111 else None + ctx.user = helpers.MockMember() + ctx.user.remove_roles = AsyncMock() + bot.role_manager.get_joinable_roles = lambda: {"Alpha": (111, "Alpha desc")} + + cog = user.UserCog(bot) + await cog.leave.callback(cog, ctx, role_name="Alpha") + + ctx.user.remove_roles.assert_awaited_once_with(guild_role) + ctx.respond.assert_awaited_once() + assert "left Alpha" in ctx.respond.call_args[0][0] + + # ── joinable_role_autocomplete tests ───────────────────────────── + + @pytest.mark.asyncio + async def test_joinable_role_autocomplete_no_role_manager(self): + """Test autocomplete returns empty list when role_manager is None.""" + ac_ctx = helpers.MockContext() + ac_ctx.bot = helpers.MockBot() + ac_ctx.bot.role_manager = None + ac_ctx.value = "test" + + result = await user.joinable_role_autocomplete(ac_ctx) + assert result == [] + + @pytest.mark.asyncio + async def test_joinable_role_autocomplete_filters(self): + """Test autocomplete filters results based on typed value.""" + ac_ctx = helpers.MockContext() + ac_ctx.bot = helpers.MockBot() + ac_ctx.bot.role_manager.get_joinable_roles = lambda: { + "Red Team": (1, "desc"), + "Blue Team": (2, "desc"), + "Red Alert": (3, "desc"), + } + ac_ctx.value = "red" + + result = await user.joinable_role_autocomplete(ac_ctx) + assert "Red Team" in result + assert "Red Alert" in result + assert "Blue Team" not in result + + @pytest.mark.asyncio + async def test_joinable_role_autocomplete_empty_value(self): + """Test autocomplete returns all roles when value is empty.""" + ac_ctx = helpers.MockContext() + ac_ctx.bot = helpers.MockBot() + ac_ctx.bot.role_manager.get_joinable_roles = lambda: { + "Alpha": (1, "desc"), + "Beta": (2, "desc"), + } + ac_ctx.value = "" + + result = await user.joinable_role_autocomplete(ac_ctx) + assert len(result) == 2 + + @pytest.mark.asyncio + async def test_joinable_role_autocomplete_none_value(self): + """Test autocomplete handles None value gracefully.""" + ac_ctx = helpers.MockContext() + ac_ctx.bot = helpers.MockBot() + ac_ctx.bot.role_manager.get_joinable_roles = lambda: {"Alpha": (1, "desc")} + ac_ctx.value = None + + result = await user.joinable_role_autocomplete(ac_ctx) + assert len(result) == 1 + def test_setup(self, bot): """Test the setup method of the cog.""" # Invoke the command diff --git a/tests/src/cmds/core/test_verify.py b/tests/src/cmds/core/test_verify.py index b771723..5fae603 100644 --- a/tests/src/cmds/core/test_verify.py +++ b/tests/src/cmds/core/test_verify.py @@ -1,9 +1,83 @@ +from unittest.mock import AsyncMock, patch + +import pytest + from src.cmds.core import verify +from src.cmds.core.verify import VerifyCog +from tests import helpers class TestVerifyCog: """Test the `Verify` cog.""" + @pytest.mark.asyncio + async def test_verifycertification_success(self, ctx, bot): + """Test successful certification verification adds the role.""" + bot.role_manager.get_cert_role_id = lambda abbrev: 5555 + guild_role = helpers.MockRole(id=5555, name="CPTS") + ctx.guild.get_role = lambda rid: guild_role if rid == 5555 else None + ctx.author = helpers.MockMember() + ctx.author.add_roles = AsyncMock() + + with patch( + "src.cmds.core.verify.process_certification", + new_callable=AsyncMock, + return_value="CPTS", + ): + cog = VerifyCog(bot) + await cog.verifycertification.callback(cog, ctx, certid="HTBCERT-123", fullname="John Doe") + + ctx.author.add_roles.assert_awaited_once_with(guild_role) + ctx.respond.assert_awaited_once() + assert "Added CPTS" in ctx.respond.call_args[0][0] + + @pytest.mark.asyncio + async def test_verifycertification_cert_role_not_configured(self, ctx, bot): + """Test that when get_cert_role_id returns None, user is told the role is not configured.""" + # Default MockRoleManager.get_cert_role_id returns None + with patch( + "src.cmds.core.verify.process_certification", + new_callable=AsyncMock, + return_value="NEWCERT", + ): + cog = VerifyCog(bot) + await cog.verifycertification.callback(cog, ctx, certid="HTBCERT-456", fullname="Jane Doe") + + ctx.respond.assert_awaited_once() + assert "not yet configured" in ctx.respond.call_args[0][0] + + @pytest.mark.asyncio + async def test_verifycertification_cert_not_found(self, ctx, bot): + """Test that when process_certification returns False, user is told cert was not found.""" + with patch( + "src.cmds.core.verify.process_certification", + new_callable=AsyncMock, + return_value=False, + ): + cog = VerifyCog(bot) + await cog.verifycertification.callback(cog, ctx, certid="HTBCERT-789", fullname="Unknown") + + ctx.respond.assert_awaited_once() + assert "Unable to find certification" in ctx.respond.call_args[0][0] + + @pytest.mark.asyncio + async def test_verifycertification_empty_certid(self, ctx, bot): + """Test that empty certid returns an error.""" + cog = VerifyCog(bot) + await cog.verifycertification.callback(cog, ctx, certid="", fullname="John") + + ctx.respond.assert_awaited_once() + assert "must supply a cert id" in ctx.respond.call_args[0][0] + + @pytest.mark.asyncio + async def test_verifycertification_bad_prefix(self, ctx, bot): + """Test that certid without HTBCERT- prefix returns an error.""" + cog = VerifyCog(bot) + await cog.verifycertification.callback(cog, ctx, certid="BADPREFIX-123", fullname="John") + + ctx.respond.assert_awaited_once() + assert "must start with HTBCERT-" in ctx.respond.call_args[0][0] + def test_setup(self, bot): """Test the setup method of the cog.""" # Invoke the command diff --git a/tests/src/core/test_config.py b/tests/src/core/test_config.py index 567e0e4..284d7cb 100644 --- a/tests/src/core/test_config.py +++ b/tests/src/core/test_config.py @@ -4,37 +4,32 @@ class TestConfig(unittest.TestCase): - def test_sherlock_creator_role_in_all_creators(self): - """Test that SHERLOCK_CREATOR role is included in ALL_CREATORS group.""" - all_creators = settings.role_groups.get("ALL_CREATORS", []) - self.assertIn(settings.roles.SHERLOCK_CREATOR, all_creators) - self.assertIn(settings.roles.CHALLENGE_CREATOR, all_creators) - self.assertIn(settings.roles.BOX_CREATOR, all_creators) - - def test_get_post_or_rank_sherlock_creator(self): - """Test that get_post_or_rank returns correct role for Sherlock Creator.""" - result = settings.get_post_or_rank("Sherlock Creator") - self.assertEqual(result, settings.roles.SHERLOCK_CREATOR) - - def test_get_post_or_rank_other_creators(self): - """Test that get_post_or_rank works for all creator types.""" - test_cases = [ - ("Challenge Creator", settings.roles.CHALLENGE_CREATOR), - ("Box Creator", settings.roles.BOX_CREATOR), - ("Sherlock Creator", settings.roles.SHERLOCK_CREATOR), - ] - - for role_name, expected_role in test_cases: - with self.subTest(role_name=role_name): - result = settings.get_post_or_rank(role_name) - self.assertEqual(result, expected_role) - - def test_get_post_or_rank_invalid_role(self): - """Test that get_post_or_rank returns None for invalid role.""" - result = settings.get_post_or_rank("Invalid Role") - self.assertIsNone(result) - - def test_sherlock_creator_role_configured(self): - """Test that SHERLOCK_CREATOR role is properly configured.""" - self.assertIsNotNone(settings.roles.SHERLOCK_CREATOR) - self.assertIsInstance(settings.roles.SHERLOCK_CREATOR, int) + def test_core_roles_required(self): + """Test that core roles are still loaded from env vars.""" + self.assertIsNotNone(settings.roles.VERIFIED) + self.assertIsInstance(settings.roles.VERIFIED, int) + self.assertIsNotNone(settings.roles.ADMINISTRATOR) + self.assertIsInstance(settings.roles.ADMINISTRATOR, int) + + def test_core_role_groups_present(self): + """Test that core role groups are populated.""" + self.assertIn("ALL_ADMINS", settings.role_groups) + self.assertIn("ALL_MODS", settings.role_groups) + self.assertIn("ALL_HTB_STAFF", settings.role_groups) + self.assertIn("ALL_SR_MODS", settings.role_groups) + self.assertIn("ALL_HTB_SUPPORT", settings.role_groups) + + def test_dynamic_role_groups_removed(self): + """Test that dynamic role groups are no longer in settings.""" + self.assertNotIn("ALL_RANKS", settings.role_groups) + self.assertNotIn("ALL_SEASON_RANKS", settings.role_groups) + self.assertNotIn("ALL_CREATORS", settings.role_groups) + self.assertNotIn("ALL_POSITIONS", settings.role_groups) + + def test_dynamic_roles_are_optional(self): + """Test that dynamic role fields default to None when env vars are missing.""" + # These should be Optional[int] = None if not in env + # In test env they may still be set via .test.env, so just check the field exists + self.assertTrue(hasattr(settings.roles, "OMNISCIENT")) + self.assertTrue(hasattr(settings.roles, "VIP")) + self.assertTrue(hasattr(settings.roles, "BOX_CREATOR")) diff --git a/tests/src/helpers/test_verification.py b/tests/src/helpers/test_verification.py index 225bd2d..bb5810c 100644 --- a/tests/src/helpers/test_verification.py +++ b/tests/src/helpers/test_verification.py @@ -7,6 +7,7 @@ from src.core import settings from src.helpers.verification import get_user_details, process_labs_identification +from tests.helpers import MockRoleManager class TestGetUserDetails(unittest.IsolatedAsyncioTestCase): @@ -75,6 +76,32 @@ async def test_get_user_details_other_status(self): self.assertEqual(result, {"content": {}}) +# Test role IDs used by the role manager mock +_SHERLOCK_ROLE_ID = 100 +_CHALLENGE_ROLE_ID = 200 +_BOX_ROLE_ID = 300 +_HACKER_ROLE_ID = 400 + + +def _make_test_role_manager(): + """Create a role manager that returns test role IDs for creator and rank lookups.""" + rm = MockRoleManager() + rm.get_role_id = lambda cat, key: { + ("creator", "Sherlock Creator"): _SHERLOCK_ROLE_ID, + ("creator", "Challenge Creator"): _CHALLENGE_ROLE_ID, + ("creator", "Box Creator"): _BOX_ROLE_ID, + }.get((cat, key)) + rm.get_post_or_rank = lambda what: { + "Hacker": _HACKER_ROLE_ID, + }.get(what) + rm.get_group_ids = lambda cat: { + "rank": [_HACKER_ROLE_ID], + "position": [], + }.get(cat, []) + rm.get_season_role_id = lambda tier: None + return rm + + class TestProcessLabsIdentification(unittest.IsolatedAsyncioTestCase): def setUp(self): """Set up test fixtures.""" @@ -90,14 +117,17 @@ def setUp(self): self.box_role = MagicMock(spec=discord.Role) self.rank_role = MagicMock(spec=discord.Role) - # Set up guild.get_role to return appropriate roles + # Set up guild.get_role to return appropriate roles by test IDs self.guild.get_role.side_effect = lambda role_id: { - settings.roles.SHERLOCK_CREATOR: self.sherlock_role, - settings.roles.CHALLENGE_CREATOR: self.challenge_role, - settings.roles.BOX_CREATOR: self.box_role, - settings.roles.HACKER: self.rank_role, + _SHERLOCK_ROLE_ID: self.sherlock_role, + _CHALLENGE_ROLE_ID: self.challenge_role, + _BOX_ROLE_ID: self.box_role, + _HACKER_ROLE_ID: self.rank_role, }.get(role_id) + # Set up role manager on bot + self.bot.role_manager = _make_test_role_manager() + @pytest.mark.asyncio async def test_process_identification_with_sherlocks(self): """Test that Sherlock creator role is assigned when user has sherlocks.""" @@ -115,15 +145,13 @@ async def test_process_identification_with_sherlocks(self): "ranking": "unranked", } - # Mock the member edit method self.member.edit = AsyncMock() - self.member.nick = "test_user" # Same as username, so no edit needed + self.member.nick = "test_user" result = await process_labs_identification(htb_user_details, self.member, self.bot) - # Verify that the Sherlock creator role is in the result self.assertIn(self.sherlock_role, result) - self.guild.get_role.assert_any_call(settings.roles.SHERLOCK_CREATOR) + self.guild.get_role.assert_any_call(_SHERLOCK_ROLE_ID) @pytest.mark.asyncio async def test_process_identification_with_challenges_and_sherlocks(self): @@ -142,17 +170,15 @@ async def test_process_identification_with_challenges_and_sherlocks(self): "ranking": "unranked", } - # Mock the member edit method self.member.edit = AsyncMock() self.member.nick = "test_user" result = await process_labs_identification(htb_user_details, self.member, self.bot) - # Verify that both roles are in the result self.assertIn(self.sherlock_role, result) self.assertIn(self.challenge_role, result) - self.guild.get_role.assert_any_call(settings.roles.SHERLOCK_CREATOR) - self.guild.get_role.assert_any_call(settings.roles.CHALLENGE_CREATOR) + self.guild.get_role.assert_any_call(_SHERLOCK_ROLE_ID) + self.guild.get_role.assert_any_call(_CHALLENGE_ROLE_ID) @pytest.mark.asyncio async def test_process_identification_without_sherlocks(self): @@ -171,17 +197,14 @@ async def test_process_identification_without_sherlocks(self): "ranking": "unranked", } - # Mock the member edit method self.member.edit = AsyncMock() self.member.nick = "test_user" result = await process_labs_identification(htb_user_details, self.member, self.bot) - # Verify that the Sherlock creator role is not in the result self.assertNotIn(self.sherlock_role, result) - # Verify get_role was not called for Sherlock creator calls = [call[0][0] for call in self.guild.get_role.call_args_list] - self.assertNotIn(settings.roles.SHERLOCK_CREATOR, calls) + self.assertNotIn(_SHERLOCK_ROLE_ID, calls) @pytest.mark.asyncio async def test_process_identification_all_creator_roles(self): @@ -200,18 +223,15 @@ async def test_process_identification_all_creator_roles(self): "ranking": "unranked", } - # Mock the member edit method self.member.edit = AsyncMock() self.member.nick = "test_user" result = await process_labs_identification(htb_user_details, self.member, self.bot) - # Verify that all three creator roles are in the result self.assertIn(self.sherlock_role, result) self.assertIn(self.challenge_role, result) self.assertIn(self.box_role, result) - # Verify all get_role calls were made - self.guild.get_role.assert_any_call(settings.roles.SHERLOCK_CREATOR) - self.guild.get_role.assert_any_call(settings.roles.CHALLENGE_CREATOR) - self.guild.get_role.assert_any_call(settings.roles.BOX_CREATOR) + self.guild.get_role.assert_any_call(_SHERLOCK_ROLE_ID) + self.guild.get_role.assert_any_call(_CHALLENGE_ROLE_ID) + self.guild.get_role.assert_any_call(_BOX_ROLE_ID) diff --git a/tests/src/webhooks/handlers/test_academy.py b/tests/src/webhooks/handlers/test_academy.py index 013ef39..de6f150 100644 --- a/tests/src/webhooks/handlers/test_academy.py +++ b/tests/src/webhooks/handlers/test_academy.py @@ -1,10 +1,19 @@ import pytest -from unittest.mock import AsyncMock, patch, MagicMock +from unittest.mock import AsyncMock, MagicMock, patch from src.webhooks.handlers.academy import AcademyHandler from src.webhooks.types import WebhookBody, Platform, WebhookEvent from tests import helpers + +def _make_role_manager(**overrides): + """Create a MockRoleManager with optional method overrides.""" + rm = helpers.MockRoleManager() + for attr, val in overrides.items(): + setattr(rm, attr, val if callable(val) else lambda *a, v=val, **kw: v) + return rm + + class TestAcademyHandler: @pytest.mark.asyncio async def test_handle_certificate_awarded_success(self, bot): @@ -25,24 +34,25 @@ async def test_handle_certificate_awarded_success(self, bot): }, traits={}, ) - with ( - patch.object(handler, "validate_common_properties", return_value=(discord_id, account_id)), - patch.object(handler, "validate_property", side_effect=[certificate_id]), - patch.object(handler, "get_guild_member", new_callable=AsyncMock, return_value=mock_member), - patch("src.webhooks.handlers.academy.settings") as mock_settings, - patch.object(handler.logger, "info") as mock_log, - ): - mock_settings.get_academy_cert_role.return_value = 555 + handler.validate_common_properties = MagicMock(return_value=(discord_id, account_id)) + handler.validate_property = MagicMock(return_value=certificate_id) + handler.get_guild_member = AsyncMock(return_value=mock_member) + + bot.role_manager = _make_role_manager(get_academy_cert_role=lambda cid: 555) + + with patch("src.webhooks.handlers.academy.settings") as mock_settings: mock_settings.channels.VERIFY_LOGS = 777 mock_guild = helpers.MockGuild(id=1) mock_guild.get_role.return_value = MagicMock() mock_channel = AsyncMock() mock_guild.get_channel.return_value = mock_channel bot.guilds = [mock_guild] + result = await handler._handle_certificate_awarded(body, bot) mock_member.add_roles.assert_awaited() - mock_channel.send.assert_awaited_once_with("Certification linked: Test Certificate with Certificate ID: 42 -> @123456789 (123456789)") - mock_log.assert_called() + mock_channel.send.assert_awaited_once_with( + "Certification linked: Test Certificate with Certificate ID: 42 -> @123456789 (123456789)" + ) assert result == handler.success() @pytest.mark.asyncio @@ -60,28 +70,28 @@ async def test_handle_certificate_awarded_success_no_name(self, bot): "discord_id": discord_id, "account_id": account_id, "certificate_id": certificate_id, - # certificate_name missing }, traits={}, ) - with ( - patch.object(handler, "validate_common_properties", return_value=(discord_id, account_id)), - patch.object(handler, "validate_property", side_effect=[certificate_id]), - patch.object(handler, "get_guild_member", new_callable=AsyncMock, return_value=mock_member), - patch("src.webhooks.handlers.academy.settings") as mock_settings, - patch.object(handler.logger, "info") as mock_log, - ): - mock_settings.get_academy_cert_role.return_value = 555 + handler.validate_common_properties = MagicMock(return_value=(discord_id, account_id)) + handler.validate_property = MagicMock(return_value=certificate_id) + handler.get_guild_member = AsyncMock(return_value=mock_member) + + bot.role_manager = _make_role_manager(get_academy_cert_role=lambda cid: 555) + + with patch("src.webhooks.handlers.academy.settings") as mock_settings: mock_settings.channels.VERIFY_LOGS = 777 mock_guild = helpers.MockGuild(id=1) mock_guild.get_role.return_value = MagicMock() mock_channel = AsyncMock() mock_guild.get_channel.return_value = mock_channel bot.guilds = [mock_guild] + result = await handler._handle_certificate_awarded(body, bot) mock_member.add_roles.assert_awaited() - mock_channel.send.assert_awaited_once_with("Certification linked: Certificate ID: 42 -> @123456789 (123456789)") - mock_log.assert_called() + mock_channel.send.assert_awaited_once_with( + "Certification linked: Certificate ID: 42 -> @123456789 (123456789)" + ) assert result == handler.success() @pytest.mark.asyncio @@ -102,17 +112,14 @@ async def test_handle_certificate_awarded_no_role(self, bot): }, traits={}, ) - with ( - patch.object(handler, "validate_common_properties", return_value=(discord_id, account_id)), - patch.object(handler, "validate_property", side_effect=[certificate_id]), - patch.object(handler, "get_guild_member", new_callable=AsyncMock, return_value=mock_member), - patch("src.webhooks.handlers.academy.settings") as mock_settings, - patch.object(handler.logger, "warning") as mock_log, - ): - mock_settings.get_academy_cert_role.return_value = None - result = await handler._handle_certificate_awarded(body, bot) - mock_log.assert_called() - assert result == handler.fail() + handler.validate_common_properties = MagicMock(return_value=(discord_id, account_id)) + handler.validate_property = MagicMock(return_value=certificate_id) + handler.get_guild_member = AsyncMock(return_value=mock_member) + + bot.role_manager = _make_role_manager(get_academy_cert_role=lambda cid: None) + + result = await handler._handle_certificate_awarded(body, bot) + assert result == handler.fail() @pytest.mark.asyncio async def test_handle_certificate_awarded_add_roles_error(self, bot): @@ -133,24 +140,22 @@ async def test_handle_certificate_awarded_add_roles_error(self, bot): }, traits={}, ) - with ( - patch.object(handler, "validate_common_properties", return_value=(discord_id, account_id)), - patch.object(handler, "validate_property", side_effect=[certificate_id]), - patch.object(handler, "get_guild_member", new_callable=AsyncMock, return_value=mock_member), - patch("src.webhooks.handlers.academy.settings") as mock_settings, - patch.object(handler.logger, "error") as mock_log, - ): - mock_settings.get_academy_cert_role.return_value = 555 + handler.validate_common_properties = MagicMock(return_value=(discord_id, account_id)) + handler.validate_property = MagicMock(return_value=certificate_id) + handler.get_guild_member = AsyncMock(return_value=mock_member) + + bot.role_manager = _make_role_manager(get_academy_cert_role=lambda cid: 555) + + with patch("src.webhooks.handlers.academy.settings") as mock_settings: mock_settings.channels.VERIFY_LOGS = 777 mock_guild = helpers.MockGuild(id=1) mock_guild.get_role.return_value = MagicMock() mock_channel = AsyncMock() mock_guild.get_channel.return_value = mock_channel bot.guilds = [mock_guild] + with pytest.raises(Exception, match="add_roles error"): await handler._handle_certificate_awarded(body, bot) - mock_log.assert_called() - mock_channel.send.assert_not_called() @pytest.mark.asyncio async def test_handle_certificate_awarded_no_verify_channel(self, bot): @@ -159,6 +164,7 @@ async def test_handle_certificate_awarded_no_verify_channel(self, bot): account_id = 987654321 certificate_id = 42 mock_member = helpers.MockMember(id=discord_id, name="123456789") + mock_member.add_roles = AsyncMock() body = WebhookBody( platform=Platform.ACADEMY, event=WebhookEvent.CERTIFICATE_AWARDED, @@ -170,23 +176,24 @@ async def test_handle_certificate_awarded_no_verify_channel(self, bot): }, traits={}, ) + handler.validate_common_properties = MagicMock(return_value=(discord_id, account_id)) + handler.validate_property = MagicMock(return_value=certificate_id) + handler.get_guild_member = AsyncMock(return_value=mock_member) + + bot.role_manager = _make_role_manager(get_academy_cert_role=lambda cid: 555) + with ( - patch.object(handler, "validate_common_properties", return_value=(discord_id, account_id)), - patch.object(handler, "validate_property", side_effect=[certificate_id]), - patch.object(handler, "get_guild_member", new_callable=AsyncMock, return_value=mock_member), patch("src.webhooks.handlers.academy.settings") as mock_settings, patch.object(handler.logger, "warning") as mock_log, ): - mock_settings.get_academy_cert_role.return_value = 555 mock_settings.channels.VERIFY_LOGS = 777 mock_guild = helpers.MockGuild(id=1) mock_guild.get_role.return_value = MagicMock() - # Return None for get_channel to trigger the missing branch mock_guild.get_channel.return_value = None bot.guilds = [mock_guild] - + result = await handler._handle_certificate_awarded(body, bot) - + mock_member.add_roles.assert_awaited() mock_log.assert_called_with("Verify logs channel 777 not found") assert result == handler.success() @@ -213,22 +220,21 @@ async def test_handle_subscription_change_success(self, bot): ) mock_role = MagicMock() mock_role.id = 555 - with ( - patch.object(handler, "validate_common_properties", return_value=(discord_id, account_id)), - patch.object(handler, "validate_property", return_value=plan), - patch.object(handler, "get_guild_member", new_callable=AsyncMock, return_value=mock_member), - patch("src.webhooks.handlers.academy.settings") as mock_settings, - patch.object(handler.logger, "info") as mock_log, - ): - mock_settings.get_post_or_rank.return_value = 555 - mock_settings.role_groups = {"ALL_ACADEMY_SUBSCRIPTIONS": [555, 666, 777]} - mock_guild = helpers.MockGuild(id=1) - mock_guild.get_role.return_value = mock_role - bot.guilds = [mock_guild] - result = await handler._handle_subscription_change(body, bot) - mock_member.add_roles.assert_awaited_once() - mock_log.assert_called() - assert result == handler.success() + handler.validate_common_properties = MagicMock(return_value=(discord_id, account_id)) + handler.validate_property = MagicMock(return_value=plan) + handler.get_guild_member = AsyncMock(return_value=mock_member) + + bot.role_manager = _make_role_manager( + get_post_or_rank=lambda what: 555, + get_group_ids=lambda cat: [555, 666, 777], + ) + mock_guild = helpers.MockGuild(id=1) + mock_guild.get_role.return_value = mock_role + bot.guilds = [mock_guild] + + result = await handler._handle_subscription_change(body, bot) + mock_member.add_roles.assert_awaited_once() + assert result == handler.success() @pytest.mark.asyncio async def test_handle_subscription_change_no_role(self, bot): @@ -247,17 +253,14 @@ async def test_handle_subscription_change_no_role(self, bot): }, traits={}, ) - with ( - patch.object(handler, "validate_common_properties", return_value=(discord_id, account_id)), - patch.object(handler, "validate_property", return_value=plan), - patch.object(handler, "get_guild_member", new_callable=AsyncMock, return_value=mock_member), - patch("src.webhooks.handlers.academy.settings") as mock_settings, - patch.object(handler.logger, "warning") as mock_log, - ): - mock_settings.get_post_or_rank.return_value = None - result = await handler._handle_subscription_change(body, bot) - mock_log.assert_called() - assert result == handler.fail() + handler.validate_common_properties = MagicMock(return_value=(discord_id, account_id)) + handler.validate_property = MagicMock(return_value=plan) + handler.get_guild_member = AsyncMock(return_value=mock_member) + + bot.role_manager = _make_role_manager(get_post_or_rank=lambda what: None) + + result = await handler._handle_subscription_change(body, bot) + assert result == handler.fail() @pytest.mark.asyncio async def test_handle_subscription_change_role_swap(self, bot): @@ -266,15 +269,14 @@ async def test_handle_subscription_change_role_swap(self, bot): discord_id = 123456789 account_id = 987654321 plan = "professional" - - # Mock member with an existing academy subscription role + old_role = MagicMock() old_role.id = 666 mock_member = helpers.MockMember(id=discord_id, name="123456789") mock_member.roles = [old_role] mock_member.add_roles = AsyncMock() mock_member.remove_roles = AsyncMock() - + body = WebhookBody( platform=Platform.ACADEMY, event=WebhookEvent.SUBSCRIPTION_CHANGE, @@ -285,36 +287,34 @@ async def test_handle_subscription_change_role_swap(self, bot): }, traits={}, ) - + new_role = MagicMock() new_role.id = 555 - - with ( - patch.object(handler, "validate_common_properties", return_value=(discord_id, account_id)), - patch.object(handler, "validate_property", return_value=plan), - patch.object(handler, "get_guild_member", new_callable=AsyncMock, return_value=mock_member), - patch("src.webhooks.handlers.academy.settings") as mock_settings, - ): - mock_settings.get_post_or_rank.return_value = 555 - mock_settings.role_groups = {"ALL_ACADEMY_SUBSCRIPTIONS": [555, 666, 777]} - mock_guild = helpers.MockGuild(id=1) - - def get_role_mock(role_id): - if role_id == 555: - return new_role - elif role_id == 666: - return old_role - return None - - mock_guild.get_role.side_effect = get_role_mock - bot.guilds = [mock_guild] - - result = await handler._handle_subscription_change(body, bot) - - # Verify old role was removed and new role was added - mock_member.remove_roles.assert_awaited_once_with(old_role, atomic=True) - mock_member.add_roles.assert_awaited_once_with(new_role, atomic=True) - assert result == handler.success() + handler.validate_common_properties = MagicMock(return_value=(discord_id, account_id)) + handler.validate_property = MagicMock(return_value=plan) + handler.get_guild_member = AsyncMock(return_value=mock_member) + + bot.role_manager = _make_role_manager( + get_post_or_rank=lambda what: 555, + get_group_ids=lambda cat: [555, 666, 777], + ) + mock_guild = helpers.MockGuild(id=1) + + def get_role_mock(role_id): + if role_id == 555: + return new_role + elif role_id == 666: + return old_role + return None + + mock_guild.get_role.side_effect = get_role_mock + bot.guilds = [mock_guild] + + result = await handler._handle_subscription_change(body, bot) + + mock_member.remove_roles.assert_awaited_once_with(old_role, atomic=True) + mock_member.add_roles.assert_awaited_once_with(new_role, atomic=True) + assert result == handler.success() @pytest.mark.asyncio async def test_handle_invalid_event(self, bot): @@ -326,4 +326,4 @@ async def test_handle_invalid_event(self, bot): traits={}, ) with pytest.raises(ValueError, match="Invalid event"): - await handler.handle(body, bot) \ No newline at end of file + await handler.handle(body, bot) diff --git a/tests/src/webhooks/handlers/test_mp.py b/tests/src/webhooks/handlers/test_mp.py index 2dad297..37d291c 100644 --- a/tests/src/webhooks/handlers/test_mp.py +++ b/tests/src/webhooks/handlers/test_mp.py @@ -1,11 +1,19 @@ import pytest -from unittest.mock import AsyncMock, patch, MagicMock -from fastapi import HTTPException +from unittest.mock import AsyncMock, MagicMock from src.webhooks.handlers.mp import MPHandler from src.webhooks.types import WebhookBody, Platform, WebhookEvent from tests import helpers + +def _make_role_manager(**overrides): + """Create a MockRoleManager with optional method overrides.""" + rm = helpers.MockRoleManager() + for attr, val in overrides.items(): + setattr(rm, attr, val if callable(val) else lambda *a, v=val, **kw: v) + return rm + + class TestMPHandler: @pytest.mark.asyncio async def test_handle_invalid_event(self, bot): @@ -37,22 +45,23 @@ async def test_handle_subscription_change_success(self, bot): }, traits={}, ) - with ( - patch.object(handler, "validate_common_properties", return_value=(discord_id, account_id)), - patch.object(handler, "validate_property", return_value=subscription_name), - patch.object(handler, "get_guild_member", new_callable=AsyncMock, return_value=mock_member), - patch("src.webhooks.handlers.mp.settings") as mock_settings, - ): - mock_settings.get_post_or_rank.return_value = 555 - mock_settings.role_groups = {"ALL_LABS_SUBSCRIPTIONS": [555, 666, 777]} - mock_role = MagicMock() - mock_role.id = 555 - mock_guild = helpers.MockGuild(id=1) - mock_guild.get_role.return_value = mock_role - bot.guilds = [mock_guild] - result = await handler._handle_subscription_change(body, bot) - mock_member.add_roles.assert_awaited_once() - assert result == handler.success() + handler.validate_common_properties = MagicMock(return_value=(discord_id, account_id)) + handler.validate_property = MagicMock(return_value=subscription_name) + handler.get_guild_member = AsyncMock(return_value=mock_member) + + bot.role_manager = _make_role_manager( + get_post_or_rank=lambda what: 555, + get_group_ids=lambda cat: [555, 666, 777], + ) + mock_role = MagicMock() + mock_role.id = 555 + mock_guild = helpers.MockGuild(id=1) + mock_guild.get_role.return_value = mock_role + bot.guilds = [mock_guild] + + result = await handler._handle_subscription_change(body, bot) + mock_member.add_roles.assert_awaited_once() + assert result == handler.success() @pytest.mark.asyncio async def test_handle_subscription_change_invalid_role(self, bot): @@ -71,15 +80,13 @@ async def test_handle_subscription_change_invalid_role(self, bot): }, traits={}, ) - with ( - patch.object(handler, "validate_common_properties", return_value=(discord_id, account_id)), - patch.object(handler, "validate_property", return_value=subscription_name), - patch.object(handler, "get_guild_member", new_callable=AsyncMock, return_value=mock_member), - patch("src.webhooks.handlers.mp.settings") as mock_settings, - ): - mock_settings.get_post_or_rank.return_value = None - with pytest.raises(ValueError, match="Invalid subscription name"): - await handler._handle_subscription_change(body, bot) + handler.validate_common_properties = MagicMock(return_value=(discord_id, account_id)) + handler.validate_property = MagicMock(return_value=subscription_name) + handler.get_guild_member = AsyncMock(return_value=mock_member) + + bot.role_manager = _make_role_manager(get_post_or_rank=lambda what: None) + with pytest.raises(ValueError, match="Invalid subscription name"): + await handler._handle_subscription_change(body, bot) @pytest.mark.asyncio async def test_handle_hof_change_success_top1(self, bot): @@ -103,21 +110,21 @@ async def test_handle_hof_change_success_top1(self, bot): ) mock_role_1 = MagicMock() mock_role_10 = MagicMock() - with ( - patch.object(handler, "validate_common_properties", return_value=(discord_id, account_id)), - patch.object(handler, "validate_property", return_value=hof_tier), - patch.object(handler, "get_guild_member", new_callable=AsyncMock, return_value=mock_member), - patch("src.webhooks.handlers.mp.settings") as mock_settings, - patch.object(handler, "_find_user_with_role", new_callable=AsyncMock, return_value=None), - ): - mock_settings.roles.RANK_ONE = 1 - mock_settings.roles.RANK_TEN = 10 - mock_guild = helpers.MockGuild(id=1) - mock_guild.get_role.side_effect = lambda rid: mock_role_1 if rid == 1 else mock_role_10 - bot.guilds = [mock_guild] - result = await handler._handle_hof_change(body, bot) - mock_member.add_roles.assert_awaited_with(mock_role_1, atomic=True) - assert result == handler.success() + handler.validate_common_properties = MagicMock(return_value=(discord_id, account_id)) + handler.validate_property = MagicMock(return_value=hof_tier) + handler.get_guild_member = AsyncMock(return_value=mock_member) + handler._find_user_with_role = AsyncMock(return_value=None) + + bot.role_manager = _make_role_manager( + get_role_id=lambda cat, key: 1 if key == "1" else (10 if key == "10" else None), + ) + mock_guild = helpers.MockGuild(id=1) + mock_guild.get_role.side_effect = lambda rid: mock_role_1 if rid == 1 else mock_role_10 + bot.guilds = [mock_guild] + + result = await handler._handle_hof_change(body, bot) + mock_member.add_roles.assert_awaited_with(mock_role_1, atomic=True) + assert result == handler.success() @pytest.mark.asyncio async def test_handle_hof_change_invalid_tier(self, bot): @@ -137,19 +144,19 @@ async def test_handle_hof_change_invalid_tier(self, bot): }, traits={}, ) - with ( - patch.object(handler, "validate_common_properties", return_value=(discord_id, account_id)), - patch.object(handler, "validate_property", return_value=hof_tier), - patch.object(handler, "get_guild_member", new_callable=AsyncMock, return_value=mock_member), - patch("src.webhooks.handlers.mp.settings") as mock_settings, - ): - mock_settings.roles.RANK_ONE = 1 - mock_settings.roles.RANK_TEN = 10 - mock_guild = helpers.MockGuild(id=1) - mock_guild.get_role.side_effect = lambda rid: None - bot.guilds = [mock_guild] - with pytest.raises(ValueError, match="Invalid HOF tier"): - await handler._handle_hof_change(body, bot) + handler.validate_common_properties = MagicMock(return_value=(discord_id, account_id)) + handler.validate_property = MagicMock(return_value=hof_tier) + handler.get_guild_member = AsyncMock(return_value=mock_member) + + bot.role_manager = _make_role_manager( + get_role_id=lambda cat, key: None, + ) + mock_guild = helpers.MockGuild(id=1) + mock_guild.get_role.side_effect = lambda rid: None + bot.guilds = [mock_guild] + + with pytest.raises(ValueError, match="Invalid HOF tier"): + await handler._handle_hof_change(body, bot) @pytest.mark.asyncio async def test_handle_rank_up_success(self, bot): @@ -173,20 +180,21 @@ async def test_handle_rank_up_success(self, bot): ) mock_role = MagicMock() mock_role.id = 555 - with ( - patch.object(handler, "validate_common_properties", return_value=(discord_id, account_id)), - patch.object(handler, "validate_property", return_value=rank), - patch.object(handler, "get_guild_member", new_callable=AsyncMock, return_value=mock_member), - patch("src.webhooks.handlers.mp.settings") as mock_settings, - ): - mock_settings.get_post_or_rank.return_value = 555 - mock_settings.role_groups = {"ALL_RANKS": [555, 666, 777]} - mock_guild = helpers.MockGuild(id=1) - mock_guild.get_role.return_value = mock_role - bot.guilds = [mock_guild] - result = await handler._handle_rank_up(body, bot) - mock_member.add_roles.assert_awaited_with(mock_role, atomic=True) - assert result == handler.success() + handler.validate_common_properties = MagicMock(return_value=(discord_id, account_id)) + handler.validate_property = MagicMock(return_value=rank) + handler.get_guild_member = AsyncMock(return_value=mock_member) + + bot.role_manager = _make_role_manager( + get_post_or_rank=lambda what: 555, + get_group_ids=lambda cat: [555, 666, 777], + ) + mock_guild = helpers.MockGuild(id=1) + mock_guild.get_role.return_value = mock_role + bot.guilds = [mock_guild] + + result = await handler._handle_rank_up(body, bot) + mock_member.add_roles.assert_awaited_with(mock_role, atomic=True) + assert result == handler.success() @pytest.mark.asyncio async def test_handle_rank_up_invalid_role(self, bot): @@ -208,15 +216,14 @@ async def test_handle_rank_up_invalid_role(self, bot): }, traits={}, ) - with ( - patch.object(handler, "validate_common_properties", return_value=(discord_id, account_id)), - patch.object(handler, "validate_property", return_value=rank), - patch.object(handler, "get_guild_member", new_callable=AsyncMock, return_value=mock_member), - patch("src.webhooks.handlers.mp.settings") as mock_settings, - ): - mock_settings.get_post_or_rank.return_value = None - mock_guild = helpers.MockGuild(id=1) - mock_guild.get_role.return_value = None - bot.guilds = [mock_guild] - with pytest.raises(ValueError, match="Cannot find role for"): - await handler._handle_rank_up(body, bot) \ No newline at end of file + handler.validate_common_properties = MagicMock(return_value=(discord_id, account_id)) + handler.validate_property = MagicMock(return_value=rank) + handler.get_guild_member = AsyncMock(return_value=mock_member) + + bot.role_manager = _make_role_manager(get_post_or_rank=lambda what: None) + mock_guild = helpers.MockGuild(id=1) + mock_guild.get_role.return_value = None + bot.guilds = [mock_guild] + + with pytest.raises(ValueError, match="Cannot find role for"): + await handler._handle_rank_up(body, bot)