diff --git a/.test.env b/.test.env index def9d16..914120e 100644 --- a/.test.env +++ b/.test.env @@ -32,6 +32,8 @@ CHANNEL_SPOILER=2769521890099371011 CHANNEL_BOT_LOGS=1105517088266788925 # Roles +ROLE_VERIFIED=1389344463884652554 + ROLE_BIZCTF2022=7629466241011276950 ROLE_NOAH_GANG=6706800691011276950 ROLE_BUDDY_GANG=6706800681011276950 diff --git a/src/cmds/automation/auto_verify.py b/src/cmds/automation/auto_verify.py index 606ad3e..98c538d 100644 --- a/src/cmds/automation/auto_verify.py +++ b/src/cmds/automation/auto_verify.py @@ -2,12 +2,8 @@ from discord import Member, Message, User from discord.ext import commands -from sqlalchemy import select from src.bot import Bot -from src.database.models import HtbDiscordLink -from src.database.session import AsyncSessionLocal -from src.helpers.verification import get_user_details, process_identification logger = logging.getLogger(__name__) @@ -19,31 +15,11 @@ def __init__(self, bot: Bot): self.bot = bot async def process_reverification(self, member: Member | User) -> None: - """Re-verifation process for a member.""" - async with AsyncSessionLocal() as session: - stmt = ( - select(HtbDiscordLink) - .where(HtbDiscordLink.discord_user_id == member.id) - .order_by(HtbDiscordLink.id) - .limit(1) - ) - result = await session.scalars(stmt) - htb_discord_link: HtbDiscordLink = result.first() - - if not htb_discord_link: - raise VerificationError(f"HTB Discord link for user {member.name} with ID {member}") - - member_token: str = htb_discord_link.account_identifier - - if member_token is None: - raise VerificationError(f"HTB account identifier for user {member.name} with ID {member.id} not found") - - logger.debug(f"Processing re-verify of member {member.name} ({member.id}).") - htb_details = await get_user_details(member_token) - if htb_details is None: - raise VerificationError(f"Retrieving user details for user {member.name} with ID {member.id} failed") - - await process_identification(htb_details, user=member, bot=self.bot) + """Re-verifation process for a member. + + TODO: Reimplement once it's possible to fetch link state from the HTB Account. + """ + raise VerificationError("Not implemented") @commands.Cog.listener() @commands.cooldown(1, 60, commands.BucketType.user) @@ -74,4 +50,5 @@ class VerificationError(Exception): def setup(bot: Bot) -> None: """Load the `MessageHandler` cog.""" - bot.add_cog(MessageHandler(bot)) + # bot.add_cog(MessageHandler(bot)) + pass diff --git a/src/cmds/core/identify.py b/src/cmds/core/identify.py index ef52ed7..34a378a 100644 --- a/src/cmds/core/identify.py +++ b/src/cmds/core/identify.py @@ -1,17 +1,12 @@ import logging -from typing import Sequence -import discord from discord import ApplicationContext, Interaction, WebhookMessage, slash_command from discord.ext import commands from discord.ext.commands import cooldown -from sqlalchemy import select from src.bot import Bot from src.core import settings -from src.database.models import HtbDiscordLink -from src.database.session import AsyncSessionLocal -from src.helpers.verification import get_user_details, process_identification +from src.helpers.verification import send_verification_instructions logger = logging.getLogger(__name__) @@ -25,121 +20,15 @@ def __init__(self, bot: Bot): @slash_command( guild_ids=settings.guild_ids, description="Identify yourself on the HTB Discord server by linking your HTB account ID to your Discord user " - "ID.", guild_only=False + "ID.", + guild_only=False, ) @cooldown(1, 60, commands.BucketType.user) - async def identify(self, ctx: ApplicationContext, account_identifier: str) -> Interaction | WebhookMessage: - """Identify yourself on the HTB Discord server by linking your HTB account ID to your Discord user ID.""" - if len(account_identifier) != 60: - return await ctx.respond( - "This Account Identifier does not appear to be the right length (must be 60 characters long).", - ephemeral=True - ) - - await ctx.respond("Identification initiated, please wait...", ephemeral=True) - htb_user_details = await get_user_details(account_identifier) - if htb_user_details is None: - embed = discord.Embed(title="Error: Invalid account identifier.", color=0xFF0000) - return await ctx.respond(embed=embed, ephemeral=True) - - json_htb_user_id = htb_user_details["user_id"] - - author = ctx.user - member = await self.bot.get_or_fetch_user(author.id) - if not member: - return await ctx.respond(f"Error getting guild member with id: {author.id}.") - - # Step 1: Check if the Account Identifier has already been recorded and if they are the previous owner. - # Scenario: - # - I create a new Discord account. - # - I reuse my previous Account Identifier. - # - I now have an "alt account" with the same roles. - async with AsyncSessionLocal() as session: - stmt = ( - select(HtbDiscordLink) - .filter(HtbDiscordLink.account_identifier == account_identifier) - .order_by(HtbDiscordLink.id.desc()) - .limit(1) - ) - result = await session.scalars(stmt) - most_recent_rec: HtbDiscordLink = result.first() - - if most_recent_rec and most_recent_rec.discord_user_id_as_int != member.id: - error_desc = ( - f"Verified user {member.mention} tried to identify as another identified user.\n" - f"Current Discord UID: {member.id}\n" - f"Other Discord UID: {most_recent_rec.discord_user_id}\n" - f"Related HTB UID: {most_recent_rec.htb_user_id}" - ) - embed = discord.Embed(title="Identification error", description=error_desc, color=0xFF2429) - await self.bot.get_channel(settings.channels.VERIFY_LOGS).send(embed=embed) - - return await ctx.respond( - "Identification error: please contact an online Moderator or Administrator for help.", ephemeral=True - ) - - # Step 2: Given the htb_user_id from JSON, check if each discord_user_id are different from member.id. - # Scenario: - # - I have a Discord account that is linked already to a "Hacker" role. - # - I create a new HTB account. - # - I identify with the new account. - # - `SELECT * FROM htb_discord_link WHERE htb_user_id = %s` will be empty, - # because the new account has not been verified before. All is good. - # - I am now "Noob" rank. - async with AsyncSessionLocal() as session: - stmt = select(HtbDiscordLink).filter(HtbDiscordLink.htb_user_id == json_htb_user_id) - result = await session.scalars(stmt) - user_links: Sequence[HtbDiscordLink] = result.all() - - discord_user_ids = {u_link.discord_user_id_as_int for u_link in user_links} - if discord_user_ids and member.id not in discord_user_ids: - orig_discord_ids = ", ".join([f"<@{id_}>" for id_ in discord_user_ids]) - error_desc = (f"The HTB account {json_htb_user_id} attempted to be identified by user <@{member.id}>, " - f"but is tied to another Discord account.\n" - f"Originally linked to Discord UID {orig_discord_ids}.") - embed = discord.Embed(title="Identification error", description=error_desc, color=0xFF2429) - await self.bot.get_channel(settings.channels.VERIFY_LOGS).send(embed=embed) - - return await ctx.respond( - "Identification error: please contact an online Moderator or Administrator for help.", ephemeral=True - ) - - # Step 3: Check if discord_user_id already linked to an htb_user_id, and if JSON/db HTB IDs are the same. - # Scenario: - # - I have a new, unlinked Discord account. - # - Clubby generates a new token and gives it to me. - # - `SELECT * FROM htb_discord_link WHERE discord_user_id = %s` - # will be empty because I have not identified before. - # - I am now Clubby. - async with AsyncSessionLocal() as session: - stmt = select(HtbDiscordLink).filter(HtbDiscordLink.discord_user_id == member.id) - result = await session.scalars(stmt) - user_links: Sequence[HtbDiscordLink] = result.all() - - user_htb_ids = {u_link.htb_user_id_as_int for u_link in user_links} - if user_htb_ids and json_htb_user_id not in user_htb_ids: - error_desc = (f"User {member.mention} ({member.id}) tried to identify with a new HTB account.\n" - f"Original HTB UIDs: {', '.join([str(i) for i in user_htb_ids])}, new HTB UID: " - f"{json_htb_user_id}.") - embed = discord.Embed(title="Identification error", description=error_desc, color=0xFF2429) - await self.bot.get_channel(settings.channels.VERIFY_LOGS).send(embed=embed) - - return await ctx.respond( - "Identification error: please contact an online Moderator or Administrator for help.", ephemeral=True - ) - - htb_discord_link = HtbDiscordLink( - account_identifier=account_identifier, discord_user_id=member.id, htb_user_id=json_htb_user_id - ) - async with AsyncSessionLocal() as session: - session.add(htb_discord_link) - await session.commit() - - await process_identification(htb_user_details, user=member, bot=self.bot) - - return await ctx.respond( - f"Your Discord user has been successfully identified as HTB user {json_htb_user_id}.", ephemeral=True - ) + async def identify( + self, ctx: ApplicationContext, account_identifier: str + ) -> Interaction | WebhookMessage: + """Legacy command. Now sends instructions to identify with HTB account.""" + await send_verification_instructions(ctx, ctx.author) def setup(bot: Bot) -> None: diff --git a/src/cmds/core/verify.py b/src/cmds/core/verify.py index 23088c1..919f872 100644 --- a/src/cmds/core/verify.py +++ b/src/cmds/core/verify.py @@ -1,14 +1,12 @@ import logging -import discord from discord import ApplicationContext, Interaction, WebhookMessage, slash_command -from discord.errors import Forbidden, HTTPException from discord.ext import commands from discord.ext.commands import cooldown from src.bot import Bot from src.core import settings -from src.helpers.verification import process_certification +from src.helpers.verification import process_certification, send_verification_instructions logger = logging.getLogger(__name__) @@ -47,57 +45,7 @@ async def verifycertification(self, ctx: ApplicationContext, certid: str, fullna @cooldown(1, 60, commands.BucketType.user) async def verify(self, ctx: ApplicationContext) -> Interaction | WebhookMessage: """Receive instructions in a DM on how to identify yourself with your HTB account.""" - member = ctx.user - - # Step one - embed_step1 = discord.Embed(color=0x9ACC14) - embed_step1.add_field( - name="Step 1: Log in at Hack The Box", - value="Go to the Hack The Box website at " - " and navigate to **Login > HTB Labs**. Log in to your HTB Account." - , inline=False, ) - embed_step1.set_image( - url="https://media.discordapp.net/attachments/724587782755844098/839871275627315250/unknown.png" - ) - - # Step two - embed_step2 = discord.Embed(color=0x9ACC14) - embed_step2.add_field( - name="Step 2: Locate the Account Identifier", - value='Click on your profile name, then select **My Profile**. ' - 'In the Profile Settings tab, find the field labeled **Account Identifier**. () ' - "Click the green button to copy your secret identifier.", inline=False, ) - embed_step2.set_image( - url="https://media.discordapp.net/attachments/724587782755844098/839871332963188766/unknown.png" - ) - - # Step three - embed_step3 = discord.Embed(color=0x9ACC14) - embed_step3.add_field( - name="Step 3: Identification", - value="Now type `/identify IDENTIFIER_HERE` in the bot-commands channel.\n\nYour roles will be " - "applied automatically.", inline=False - ) - embed_step3.set_image( - url="https://media.discordapp.net/attachments/709907130102317093/904744444539076618/unknown.png" - ) - - try: - await member.send(embed=embed_step1) - await member.send(embed=embed_step2) - await member.send(embed=embed_step3) - except Forbidden as ex: - logger.error("Exception during verify call", exc_info=ex) - return await ctx.respond( - "Whoops! I cannot DM you after all due to your privacy settings. Please allow DMs from other server " - "members and try again in 1 minute." - ) - except HTTPException as ex: - logger.error("Exception during verify call.", exc_info=ex) - return await ctx.respond( - "An unexpected error happened (HTTP 400, bad request). Please contact an Administrator." - ) - return await ctx.respond("Please check your DM for instructions.", ephemeral=True) + await send_verification_instructions(ctx, ctx.author) def setup(bot: Bot) -> None: diff --git a/src/core/config.py b/src/core/config.py index 6b3025e..b91644d 100644 --- a/src/core/config.py +++ b/src/core/config.py @@ -91,6 +91,7 @@ class AcademyCertificates(BaseSettings): class Roles(BaseSettings): """The roles settings.""" + VERIFIED: int # Moderation COMMUNITY_MANAGER: int diff --git a/src/helpers/ban.py b/src/helpers/ban.py index 8d2ddd0..08a6895 100644 --- a/src/helpers/ban.py +++ b/src/helpers/ban.py @@ -2,6 +2,8 @@ import asyncio import logging from datetime import datetime, timezone +from enum import Enum + import arrow import discord @@ -22,6 +24,12 @@ logger = logging.getLogger(__name__) +class BanCodes(Enum): + SUCCESS = "SUCCESS" + ALREADY_EXISTS = "ALREADY_EXISTS" + FAILED = "FAILED" + + async def _check_member(bot: Bot, guild: Guild, member: Member | User, author: Member = None) -> SimpleResponse | None: if isinstance(member, Member): if member_is_staff(member): @@ -29,9 +37,9 @@ async def _check_member(bot: Bot, guild: Guild, member: Member | User, author: M elif isinstance(member, User): member = await bot.get_member_or_user(guild, member.id) if member.bot: - return SimpleResponse(message="You cannot ban a bot.", delete_after=None) + return SimpleResponse(message="You cannot ban a bot.", delete_after=None, code=BanCodes.FAILED) if author and author.id == member.id: - return SimpleResponse(message="You cannot ban yourself.", delete_after=None) + return SimpleResponse(message="You cannot ban yourself.", delete_after=None, code=BanCodes.FAILED) async def _get_ban_or_create(member: Member, ban: Ban, infraction: Infraction) -> tuple[int, bool]: @@ -50,6 +58,29 @@ async def _get_ban_or_create(member: Member, ban: Ban, infraction: Infraction) - return ban_id, False +async def _create_ban_response(member: Member | User, end_date: str, dm_banned_member: bool, needs_approval: bool) -> SimpleResponse: + """Create a SimpleResponse for ban operations.""" + if needs_approval: + if member: + message = f"{member.display_name} ({member.id}) has been banned until {end_date} (UTC)." + else: + message = f"{member.id} has been banned until {end_date} (UTC)." + else: + if member: + message = f"Member {member.display_name} has been banned permanently." + else: + message = f"Member {member.id} has been banned permanently." + + if not dm_banned_member: + message += "\n Could not DM banned member due to permission error." + + return SimpleResponse( + message=message, + delete_after=0 if not needs_approval else None, + code=BanCodes.SUCCESS + ) + + async def ban_member( bot: Bot, guild: Guild, member: Member | User, duration: str, reason: str, evidence: str, author: Member = None, needs_approval: bool = True @@ -93,7 +124,8 @@ async def ban_member( if is_existing: return SimpleResponse( message=f"A ban with id: {ban_id} already exists for member {member}", - delete_after=None + delete_after=None, + code=BanCodes.ALREADY_EXISTS ) # DM member, before we ban, else we cannot dm since we do not share a guild @@ -107,27 +139,20 @@ async def ban_member( extra={"ban_requestor": author.name, "ban_receiver": member.id} ) if author: - return SimpleResponse(message="You do not have the proper permissions to ban.", delete_after=None) + return SimpleResponse(message="You do not have the proper permissions to ban.", delete_after=None, code=BanCodes.FAILED) return except HTTPException as ex: logger.warning(f"HTTPException when trying to ban user with ID {member.id}", exc_info=ex) if author: return SimpleResponse( message="Here's a 400 Bad Request for you. Just like when you tried to ask me out, last week.", - delete_after=None + delete_after=None, + code=BanCodes.FAILED ) return # If approval is required, send a message to the moderator channel about the ban if not needs_approval: - if member: - message = f"Member {member.display_name} has been banned permanently." - else: - message = f"Member {member.id} has been banned permanently." - - if not dm_banned_member: - message += "\n Could not DM banned member due to permission error." - logger.info( "Member has been banned permanently.", extra={"ban_requestor": author.name, "ban_receiver": member.id, "dm_banned_member": dm_banned_member} @@ -136,16 +161,7 @@ async def ban_member( unban_task = schedule(unban_member(guild, member), run_at=datetime.fromtimestamp(ban.unban_time)) asyncio.create_task(unban_task) logger.debug("Unbanned sceduled for ban", extra={"ban_id": ban_id, "unban_time": ban.unban_time}) - return SimpleResponse(message=message, delete_after=0) else: - if member: - message = f"{member.display_name} ({member.id}) has been banned until {end_date} (UTC)." - else: - message = f"{member.id} has been banned until {end_date} (UTC)." - - if not dm_banned_member: - message += " Could not DM banned member due to permission error." - member_name = f"{member.display_name} ({member.name})" embed = discord.Embed( title=f"Ban request #{ban_id}", @@ -156,7 +172,8 @@ async def ban_member( embed.set_thumbnail(url=f"{settings.HTB_URL}/images/logo600.png") view = BanDecisionView(ban_id, bot, guild, member, end_date, reason) await guild.get_channel(settings.channels.SR_MOD).send(embed=embed, view=view) - return SimpleResponse(message=message) + + return await _create_ban_response(member, end_date, dm_banned_member, needs_approval) async def _dm_banned_member(end_date: str, guild: Guild, member: Member, reason: str) -> bool: diff --git a/src/helpers/responses.py b/src/helpers/responses.py index 6513bfd..91a8213 100644 --- a/src/helpers/responses.py +++ b/src/helpers/responses.py @@ -4,9 +4,10 @@ class SimpleResponse(object): """A simple response object.""" - def __init__(self, message: str, delete_after: int | None = None): + def __init__(self, message: str, delete_after: int | None = None, code: str | None = None): self.message = message self.delete_after = delete_after + self.code = code def __str__(self): return json.dumps(dict(self), ensure_ascii=False) diff --git a/src/helpers/verification.py b/src/helpers/verification.py index ddb4be9..ae02038 100644 --- a/src/helpers/verification.py +++ b/src/helpers/verification.py @@ -1,19 +1,86 @@ import logging -from datetime import datetime -from typing import Dict, List, Optional, cast +from typing import Dict, List, Optional import aiohttp import discord -from discord import Forbidden, Member, Role, User +from discord import ApplicationContext, Forbidden, Member, Role, User, HTTPException from discord.ext.commands import GuildNotFound, MemberNotFound from src.bot import Bot from src.core import settings -from src.helpers.ban import ban_member +from src.helpers.ban import BanCodes, ban_member logger = logging.getLogger(__name__) +async def send_verification_instructions( + ctx: ApplicationContext, member: Member +) -> discord.Interaction | discord.WebhookMessage: + """Send instructions via DM on how to identify with HTB account. + + Args: + ctx (ApplicationContext): The context of the command. + member (Member): The member to send the instructions to. + + Returns: + discord.Interaction | discord.WebhookMessage: The response message. + """ + member = ctx.user + + # Create step-by-step instruction embeds + embed_step1 = discord.Embed(color=0x9ACC14) + embed_step1.add_field( + name="Step 1: Login to your HTB Account", + value="Go to and login.", + inline=False, + ) + embed_step1.set_image( + url="https://media.discordapp.net/attachments/1102700815493378220/1384587341338902579/image.png" + ) + + embed_step2 = discord.Embed(color=0x9ACC14) + embed_step2.add_field( + name="Step 2: Navigate to your Security Settings", + value="In the navigation bar, click on **Security Settings** and scroll down to the **Discord Account** section. " + "()", + inline=False, + ) + embed_step2.set_image( + url="https://media.discordapp.net/attachments/1102700815493378220/1384587813760270392/image.png" + ) + + embed_step3 = discord.Embed(color=0x9ACC14) + embed_step3.add_field( + name="Step 3: Link your Discord Account", + value="Click **Connect** and you will be redirected to login to your Discord account via oauth. " + "After logging in, you will be redirected back to the HTB Account page. " + "Your Discord account will now be linked. Discord may take a few minutes to update. " + "If you have any issues, please contact a Moderator.", + inline=False, + ) + embed_step3.set_image( + url="https://media.discordapp.net/attachments/1102700815493378220/1384586811384402042/image.png" + ) + + try: + await member.send(embed=embed_step1) + await member.send(embed=embed_step2) + await member.send(embed=embed_step3) + except Forbidden as ex: + logger.error("Exception during verify call", exc_info=ex) + return await ctx.respond( + "Whoops! I cannot DM you after all due to your privacy settings. Please allow DMs from other server " + "members and try again in 1 minute." + ) + except HTTPException as ex: + logger.error("Exception during verify call.", exc_info=ex) + return await ctx.respond( + "An unexpected error happened (HTTP 400, bad request). Please contact an Administrator." + ) + + return await ctx.respond("Please check your DM for instructions.", ephemeral=True) + + async def get_user_details(account_identifier: str) -> Optional[Dict]: """Get user details from HTB.""" acc_id_url = f"{settings.API_URL}/discord/identifier/{account_identifier}?secret={settings.HTB_API_SECRET}" @@ -23,10 +90,14 @@ async def get_user_details(account_identifier: str) -> Optional[Dict]: if r.status == 200: response = await r.json() elif r.status == 404: - logger.debug("Account identifier has been regenerated since last identification. Cannot re-verify.") + logger.debug( + "Account identifier has been regenerated since last identification. Cannot re-verify." + ) response = None else: - logger.error(f"Non-OK HTTP status code returned from identifier lookup: {r.status}.") + logger.error( + f"Non-OK HTTP status code returned from identifier lookup: {r.status}." + ) response = None return response @@ -45,7 +116,9 @@ async def get_season_rank(htb_uid: int) -> str | None: logger.error("Invalid Season ID.") response = None else: - logger.error(f"Non-OK HTTP status code returned from identifier lookup: {r.status}.") + logger.error( + f"Non-OK HTTP status code returned from identifier lookup: {r.status}." + ) response = None if not response["data"]: @@ -59,26 +132,19 @@ async def get_season_rank(htb_uid: int) -> str | None: return rank -async def _check_for_ban(uid: str) -> Optional[Dict]: - async with aiohttp.ClientSession() as session: - token_url = f"{settings.API_URL}/discord/{uid}/banned?secret={settings.HTB_API_SECRET}" - async with session.get(token_url) as r: - if r.status == 200: - ban_details = await r.json() - else: - logger.error( - f"Could not fetch ban details for uid {uid}: " - f"non-OK status code returned ({r.status}). Body: {r.content}" - ) - ban_details = None - - return ban_details +async def _check_for_ban(member: Member) -> Optional[Dict]: + """Check if the member is banned.""" + try: + member.guild.get_role(settings.roles.BANNED) + return True + except Forbidden: + return False async def process_certification(certid: str, name: str): """Process certifications.""" cert_api_url = f"{settings.API_V4_URL}/certificate/lookup" - params = {'id': certid, 'name': name} + params = {"id": certid, "name": name} async with aiohttp.ClientSession() as session: async with session.get(cert_api_url, params=params) as r: if r.status == 200: @@ -86,7 +152,9 @@ async def process_certification(certid: str, name: str): elif r.status == 404: return False else: - logger.error(f"Non-OK HTTP status code returned from identifier lookup: {r.status}.") + logger.error( + f"Non-OK HTTP status code returned from identifier lookup: {r.status}." + ) response = None try: certRawName = response["certificates"][0]["name"] @@ -107,7 +175,90 @@ async def process_certification(certid: str, name: str): return cert -async def process_identification( +async def _handle_banned_user(member: Member, bot: Bot): + """Handle banned trait during account linking. + + Args: + member (Member): The member to process. + bot (Bot): The bot instance. + """ + resp = await ban_member( + bot, + member.guild, + member, + "1337w", + ( + "Banned on the HTB Platform. Ban duration could not be determined. " + "Please login to confirm ban details and contact HTB Support to appeal." + ), + None, + needs_approval=False, + ) + if resp.code == BanCodes.SUCCESS: + embed = discord.Embed( + title="Identification error", + description=f"User {member.mention} ({member.id}) was platform banned HTB and thus also here.", + color=0xFF2429, + ) + await member.guild.get_channel(settings.channels.VERIFY_LOGS).send(embed=embed) + + +async def _set_nickname(member: Member, nickname: str) -> bool: + """Set the nickname of the member. + + Args: + member (Member): The member to set the nickname for. + nickname (str): The nickname to set. + + Returns: + bool: True if the nickname was set, False otherwise. + """ + try: + await member.edit(nick=nickname) + return True + except Forbidden as e: + logger.error(f"Exception whe trying to edit the nick-name of the user: {e}") + return False + + +async def process_account_identification( + member: Member, bot: Bot, traits: dict[str, str] | None = None +) -> None: + """Process HTB account identification, to be called during account linking. + + Args: + member (Member): The member to process. + bot (Bot): The bot instance. + traits (dict[str, str] | None): Optional user traits to process. + """ + await member.add_roles(member.guild.get_role(settings.roles.VERIFIED), atomic=True) + + nickname_changed = False + + if traits.get("username") and traits.get("username") != member.name: + nickname_changed = await _set_nickname(member, traits.get("username")) + + if not nickname_changed: + logger.warning( + f"No username provided for {member.name} with ID {member.id} during identification." + ) + + if traits.get("mp_user_id"): + htb_user_details = await get_user_details(traits.get("mp_user_id")) + await process_labs_identification(htb_user_details, member, bot) + + if not nickname_changed: + logger.debug( + f"Falling back on HTB username to set nickname for {member.name} with ID {member.id}." + ) + await _set_nickname(member, htb_user_details["username"]) + + if traits.get("banned", False) == True: # noqa: E712 - explicit bool only, no truthiness + await _handle_banned_user(member, bot) + return + + +async def process_labs_identification( htb_user_details: Dict[str, str], user: Optional[Member | User], bot: Bot ) -> Optional[List[Role]]: """Returns roles to assign if identification was successfully processed.""" @@ -123,54 +274,33 @@ async def process_identification( raise MemberNotFound(str(user.id)) else: raise GuildNotFound(f"Could not identify member {user} in guild.") - season_rank = await get_season_rank(htb_uid) - banned_details = await _check_for_ban(htb_uid) - - if banned_details is not None and banned_details["banned"]: - # If user is banned, this field must be a string - # Strip date e.g. from "2022-01-31T11:00:00.000000Z" - banned_until: str = cast(str, banned_details["ends_at"])[:10] - banned_until_dt: datetime = datetime.strptime(banned_until, "%Y-%m-%d") - ban_duration: str = f"{(banned_until_dt - datetime.now()).days}d" - reason = "Banned on the HTB Platform. Please contact HTB Support to appeal." - logger.info(f"Discord user {member.name} ({member.id}) is platform banned. Banning from Discord...") - await ban_member(bot, guild, member, ban_duration, reason, None, needs_approval=False) - - embed = discord.Embed( - title="Identification error", - description=f"User {member.mention} ({member.id}) was platform banned HTB and thus also here.", - color=0xFF2429, ) - - await guild.get_channel(settings.channels.VERIFY_LOGS).send(embed=embed) - return None to_remove = [] - for role in member.roles: - if role.id in settings.role_groups.get("ALL_RANKS") + settings.role_groups.get("ALL_POSITIONS"): + if role.id in settings.role_groups.get("ALL_RANKS") + settings.role_groups.get( + "ALL_POSITIONS" + ): to_remove.append(guild.get_role(role.id)) to_assign = [] - logger.debug( - "Getting role 'rank':", extra={ - "role_id": settings.get_post_or_rank(htb_user_details["rank"]), - "role_obj": guild.get_role(settings.get_post_or_rank(htb_user_details["rank"])), - "htb_rank": htb_user_details["rank"], - }, ) - if htb_user_details["rank"] not in ["Deleted", "Moderator", "Ambassador", "Admin", "Staff"]: - to_assign.append(guild.get_role(settings.get_post_or_rank(htb_user_details["rank"]))) + if htb_user_details["rank"] not in [ + "Deleted", + "Moderator", + "Ambassador", + "Admin", + "Staff", + ]: + to_assign.append( + guild.get_role(settings.get_post_or_rank(htb_user_details["rank"])) + ) + + season_rank = await get_season_rank(htb_uid) if season_rank: to_assign.append(guild.get_role(settings.get_season(season_rank))) + if htb_user_details["vip"]: - logger.debug( - 'Getting role "VIP":', extra={"role_id": settings.roles.VIP, "role_obj": guild.get_role(settings.roles.VIP)} - ) to_assign.append(guild.get_role(settings.roles.VIP)) if htb_user_details["dedivip"]: - logger.debug( - 'Getting role "VIP+":', - extra={"role_id": settings.roles.VIP_PLUS, "role_obj": guild.get_role(settings.roles.VIP_PLUS)} - ) to_assign.append(guild.get_role(settings.roles.VIP_PLUS)) if htb_user_details["hof_position"] != "unranked": position = int(htb_user_details["hof_position"]) @@ -180,45 +310,17 @@ async def process_identification( elif position <= 10: pos_top = "10" if pos_top: - logger.debug(f"User is Hall of Fame rank {position}. Assigning role Top-{pos_top}...") - logger.debug( - 'Getting role "HoF role":', extra={ - "role_id": settings.get_post_or_rank(pos_top), - "role_obj": guild.get_role(settings.get_post_or_rank(pos_top)), "hof_val": pos_top, - }, ) to_assign.append(guild.get_role(settings.get_post_or_rank(pos_top))) - else: - logger.debug(f"User is position {position}. No Hall of Fame roles for them.") if htb_user_details["machines"]: - logger.debug( - 'Getting role "BOX_CREATOR":', - extra={"role_id": settings.roles.BOX_CREATOR, "role_obj": guild.get_role(settings.roles.BOX_CREATOR)}, ) to_assign.append(guild.get_role(settings.roles.BOX_CREATOR)) if htb_user_details["challenges"]: - logger.debug( - 'Getting role "CHALLENGE_CREATOR":', extra={ - "role_id": settings.roles.CHALLENGE_CREATOR, - "role_obj": guild.get_role(settings.roles.CHALLENGE_CREATOR), - }, ) to_assign.append(guild.get_role(settings.roles.CHALLENGE_CREATOR)) - if member.nick != htb_user_details["user_name"]: - try: - await member.edit(nick=htb_user_details["user_name"]) - except Forbidden as e: - logger.error(f"Exception whe trying to edit the nick-name of the user: {e}") - - logger.debug("All roles to_assign:", extra={"to_assign": to_assign}) # We don't need to remove any roles that are going to be assigned again to_remove = list(set(to_remove) - set(to_assign)) - logger.debug("All roles to_remove:", extra={"to_remove": to_remove}) if to_remove: await member.remove_roles(*to_remove, atomic=True) - else: - logger.debug("No roles need to be removed") if to_assign: await member.add_roles(*to_assign, atomic=True) - else: - logger.debug("No roles need to be assigned") return to_assign diff --git a/src/webhooks/handlers/__init__.py b/src/webhooks/handlers/__init__.py index 02e0edf..601943c 100644 --- a/src/webhooks/handlers/__init__.py +++ b/src/webhooks/handlers/__init__.py @@ -1,9 +1,9 @@ from discord import Bot -from src.webhooks.handlers.academy import handler as academy_handler +from src.webhooks.handlers.account import AccountHandler from src.webhooks.types import Platform, WebhookBody -handlers = {Platform.ACADEMY: academy_handler} +handlers = {Platform.ACCOUNT: AccountHandler().handle} def can_handle(platform: Platform) -> bool: diff --git a/src/webhooks/handlers/academy.py b/src/webhooks/handlers/academy.py deleted file mode 100644 index 5aa3568..0000000 --- a/src/webhooks/handlers/academy.py +++ /dev/null @@ -1,76 +0,0 @@ -import logging - -from discord import Bot -from discord.errors import NotFound -from fastapi import HTTPException - -from src.core import settings -from src.webhooks.types import WebhookBody, WebhookEvent - -logger = logging.getLogger(__name__) - - -async def handler(body: WebhookBody, bot: Bot) -> dict: - """ - Handles incoming webhook events and performs actions accordingly. - - This function processes different webhook events related to account linking, - certificate awarding, and account unlinking. It updates the member's roles - based on the received event. - - Args: - body (WebhookBody): The data received from the webhook. - bot (Bot): The instance of the Discord bot. - - Returns: - dict: A dictionary with a "success" key indicating whether the operation was successful. - - Raises: - HTTPException: If an error occurs while processing the webhook event. - """ - # TODO: Change it here so we pass the guild instead of the bot # noqa: T000 - guild = await bot.fetch_guild(settings.guild_ids[0]) - - try: - discord_id = int(body.data["discord_id"]) - member = await guild.fetch_member(discord_id) - except ValueError as exc: - logger.debug("Invalid Discord ID", exc_info=exc) - raise HTTPException(status_code=400, detail="Invalid Discord ID") from exc - except NotFound as exc: - logger.debug("User is not in the Discord server", exc_info=exc) - raise HTTPException(status_code=400, detail="User is not in the Discord server") from exc - - if body.event == WebhookEvent.ACCOUNT_LINKED: - roles_to_add = {settings.roles.ACADEMY_USER} - roles_to_add.update(settings.get_academy_cert_role(cert["id"]) for cert in body.data["certifications"]) - - # Filter out invalid role IDs - role_ids_to_add = {role_id for role_id in roles_to_add if role_id is not None} - roles_to_add = {guild.get_role(role_id) for role_id in role_ids_to_add} - - await member.add_roles(*roles_to_add, atomic=True) - elif body.event == WebhookEvent.CERTIFICATE_AWARDED: - cert_id = body.data["certification"]["id"] - - role = settings.get_academy_cert_role(cert_id) - if not role: - logger.debug(f"Role for certification: {cert_id} does not exist") - raise HTTPException(status_code=400, detail=f"Role for certification: {cert_id} does not exist") - - await member.add_roles(role, atomic=True) - elif body.event == WebhookEvent.ACCOUNT_UNLINKED: - current_role_ids = {role.id for role in member.roles} - cert_role_ids = {settings.get_academy_cert_role(cert_id) for _, cert_id in settings.academy_certificates} - - common_role_ids = current_role_ids.intersection(cert_role_ids) - - role_ids_to_remove = {settings.roles.ACADEMY_USER}.union(common_role_ids) - roles_to_remove = {guild.get_role(role_id) for role_id in role_ids_to_remove} - - await member.remove_roles(*roles_to_remove, atomic=True) - else: - logger.debug(f"Event {body.event} not implemented") - raise HTTPException(status_code=501, detail=f"Event {body.event} not implemented") - - return {"success": True} diff --git a/src/webhooks/handlers/account.py b/src/webhooks/handlers/account.py new file mode 100644 index 0000000..895cc9f --- /dev/null +++ b/src/webhooks/handlers/account.py @@ -0,0 +1,50 @@ +import logging + +from discord import Bot + +from src.core import settings +from src.helpers.verification import process_account_identification +from src.webhooks.types import WebhookBody, WebhookEvent + +from src.webhooks.handlers.base import BaseHandler + + + +class AccountHandler(BaseHandler): + async def handle(self, body: WebhookBody, bot: Bot) -> dict: + """ + Handles incoming webhook events and performs actions accordingly. + + This function processes different webhook events originating from the + HTB Account. + """ + if body.event == WebhookEvent.ACCOUNT_LINKED: + await self.handle_account_linked(body, bot) + elif body.event == WebhookEvent.ACCOUNT_UNLINKED: + await self.handle_account_unlinked(body, bot) + elif body.event == WebhookEvent.ACCOUNT_DELETED: + await self.handle_account_deleted(body, bot) + + async def handle_account_linked(self, body: WebhookBody, bot: Bot) -> dict: + """ + Handles the account linked event. + """ + discord_id = self.validate_discord_id(body.properties.get("discord_id")) + account_id = self.validate_account_id(body.properties.get("account_id")) + + member = await self.get_guild_member(discord_id, bot) + await process_account_identification(member, bot, traits=self.merge_properties_and_traits(body.properties, body.traits)) + await bot.send_message(settings.channels.VERIFY_LOGS, f"Account linked: {account_id} -> ({member.mention} ({member.id})") + + self.logger.info(f"Account {account_id} linked to {member.id}", extra={"account_id": account_id, "discord_id": discord_id}) + + async def handle_account_unlinked(self, body: WebhookBody, bot: Bot) -> dict: + """ + Handles the account unlinked event. + """ + discord_id = self.validate_discord_id(body.properties.get("discord_id")) + account_id = self.validate_account_id(body.properties.get("account_id")) + + member = await self.get_guild_member(discord_id, bot) + + await member.remove_roles(settings.roles.VERIFIED, atomic=True) diff --git a/src/webhooks/handlers/base.py b/src/webhooks/handlers/base.py new file mode 100644 index 0000000..cfb63d9 --- /dev/null +++ b/src/webhooks/handlers/base.py @@ -0,0 +1,114 @@ +import logging + +from abc import ABC, abstractmethod +from typing import TypeVar + +from discord import Bot, Member +from discord.errors import NotFound +from fastapi import HTTPException + +from src.core import settings +from src.webhooks.types import WebhookBody + +T = TypeVar("T") + + +class BaseHandler(ABC): + ACADEMY_USER_ID = "academy_user_id" + MP_USER_ID = "mp_user_id" + EP_USER_ID = "ep_user_id" + CTF_USER_ID = "ctf_user_id" + ACCOUNT_ID = "account_id" + DISCORD_ID = "discord_id" + + + def __init__(self): + self.logger = logging.getLogger(self.__class__.__name__) + + @abstractmethod + async def handle(self, body: WebhookBody, bot: Bot) -> dict: + pass + + async def get_guild_member(self, discord_id: int, bot: Bot) -> Member: + """ + Fetches a guild member from the Discord server. + + Args: + discord_id (int): The Discord ID of the user. + bot (Bot): The Discord bot instance. + + Returns: + Member: The guild member. + + Raises: + HTTPException: If the user is not in the Discord server (400) + """ + try: + guild = await bot.fetch_guild(settings.guild_ids[0]) + member = await guild.fetch_member(discord_id) + return member + + except NotFound as exc: + self.logger.debug("User is not in the Discord server", exc_info=exc) + raise HTTPException( + status_code=400, detail="User is not in the Discord server" + ) from exc + + def validate_property(self, property: T | None, name: str) -> T: + """ + Validates a property is not None. + + Args: + property (T | None): The property to validate. + name (str): The name of the property. + + Returns: + T: The validated property. + + Raises: + HTTPException: If the property is None (400) + """ + if property is None: + msg = f"Invalid {name}" + self.logger.debug(msg) + raise HTTPException(status_code=400, detail=msg) + + return property + + def validate_discord_id(self, discord_id: str | int) -> int: + """ + Validates the Discord ID. See validate_property function. + """ + return self.validate_property(discord_id, "Discord ID") + + def validate_account_id(self, account_id: str | int) -> int: + """ + Validates the Account ID. See validate_property function. + """ + return self.validate_property(account_id, "Account ID") + + def get_property_or_trait(self, body: WebhookBody, name: str) -> int | None: + """ + Gets a trait or property from the webhook body. + """ + return body.properties.get(name) or body.traits.get(name) + + def merge_properties_and_traits(self, properties: dict[str, int | None], traits: dict[str, int | None]) -> dict[str, int | None]: + """ + Merges the properties and traits from the webhook body without duplicates. + If a property and trait have the same name but different values, the property value will be used. + """ + return {**properties, **{k: v for k, v in traits.items() if k not in properties}} + + def get_platform_properties(self, body: WebhookBody) -> dict[str, int | None]: + """ + Gets the platform properties from the webhook body. + """ + properties = { + self.ACCOUNT_ID: self.get_property_or_trait(body, self.ACCOUNT_ID), + self.MP_USER_ID: self.get_property_or_trait(body, self.MP_USER_ID), + self.EP_USER_ID: self.get_property_or_trait(body, self.EP_USER_ID), + self.CTF_USER_ID: self.get_property_or_trait(body, self.CTF_USER_ID), + self.ACADEMY_USER_ID: self.get_property_or_trait(body, self.ACADEMY_USER_ID), + } + return properties \ No newline at end of file diff --git a/src/webhooks/server.py b/src/webhooks/server.py index 92222a7..06d02af 100644 --- a/src/webhooks/server.py +++ b/src/webhooks/server.py @@ -1,10 +1,13 @@ +import hashlib import hmac import logging -from typing import Any, Dict, Union +import json +from typing import Any, Dict -from fastapi import FastAPI, Header, HTTPException +from fastapi import FastAPI, HTTPException, Request from hypercorn.asyncio import serve as hypercorn_serve from hypercorn.config import Config as HypercornConfig +from pydantic import ValidationError from src.bot import bot from src.core import settings @@ -17,20 +20,36 @@ app = FastAPI() +def verify_signature(body: dict, signature: str, secret: str) -> bool: + """ + HMAC SHA1 signature verification. + + Args: + body (dict): The raw body of the webhook request. + signature (str): The X-Signature header of the webhook request. + secret (str): The webhook secret. + + Returns: + bool: True if the signature is valid, False otherwise. + """ + if not signature: + return False + + digest = hmac.new(secret.encode(), body, hashlib.sha1).hexdigest() + return hmac.compare_digest(signature, digest) + + @app.post("/webhook") -async def webhook_handler( - body: WebhookBody, authorization: Union[str, None] = Header(default=None) -) -> Dict[str, Any]: +async def webhook_handler(request: Request) -> Dict[str, Any]: """ Handles incoming webhook requests and forwards them to the appropriate handler. - This function first checks the provided authorization token in the request header. - If the token is valid, it checks if the platform can be handled and then forwards + This function first verifies the provided HMAC signature in the request header. + If the signature is valid, it checks if the platform can be handled and then forwards the request to the corresponding handler. Args: - body (WebhookBody): The data received from the webhook. - authorization (Union[str, None]): The authorization header containing the Bearer token. + request (Request): The incoming webhook request. Returns: Dict[str, Any]: The response from the corresponding handler. The dictionary contains @@ -39,17 +58,21 @@ async def webhook_handler( Raises: HTTPException: If an error occurs while processing the webhook event or if unauthorized. """ - if authorization is None or not authorization.strip().startswith("Bearer"): - logger.warning("Unauthorized webhook request") - raise HTTPException(status_code=401, detail="Unauthorized") + body = await request.body() + signature = request.headers.get("X-Signature") - token = authorization[6:].strip() - if hmac.compare_digest(token, settings.WEBHOOK_TOKEN): + if not verify_signature(body, signature, settings.WEBHOOK_TOKEN): logger.warning("Unauthorized webhook request") raise HTTPException(status_code=401, detail="Unauthorized") + try: + body = WebhookBody.model_validate(json.loads(body)) + except ValidationError as e: + logger.warning("Invalid webhook request: %s", e.errors()) + raise HTTPException(status_code=400, detail="Invalid webhook request body") + if not handlers.can_handle(body.platform): - logger.warning("Webhook request not handled by platform") + logger.warning("Webhook request not handled by platform: %s", body.platform) raise HTTPException(status_code=501, detail="Platform not implemented") return await handlers.handle(body, bot) diff --git a/src/webhooks/types.py b/src/webhooks/types.py index 3885dda..9812ccb 100644 --- a/src/webhooks/types.py +++ b/src/webhooks/types.py @@ -1,11 +1,12 @@ from enum import Enum -from pydantic import BaseModel +from pydantic import BaseModel, ConfigDict class WebhookEvent(Enum): - ACCOUNT_LINKED = "AccountLinked" - ACCOUNT_UNLINKED = "AccountUnlinked" + ACCOUNT_LINKED = "DiscordAccountLinked" + ACCOUNT_UNLINKED = "DiscordAccountUnlinked" + ACCOUNT_DELETED = "UserAccountDeleted" CERTIFICATE_AWARDED = "CertificateAwarded" RANK_UP = "RankUp" HOF_CHANGE = "HofChange" @@ -19,9 +20,13 @@ class Platform(Enum): ACADEMY = "academy" CTF = "ctf" ENTERPRISE = "enterprise" + ACCOUNT = "account" class WebhookBody(BaseModel): + model_config = ConfigDict(extra="allow") + platform: Platform event: WebhookEvent - data: dict + properties: dict | None + traits: dict | None diff --git a/tests/src/webhooks/handlers/test_account.py b/tests/src/webhooks/handlers/test_account.py new file mode 100644 index 0000000..91cf195 --- /dev/null +++ b/tests/src/webhooks/handlers/test_account.py @@ -0,0 +1,364 @@ +import logging +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest +from discord import Bot, Member +from discord.errors import NotFound +from fastapi import HTTPException + +from src.webhooks.handlers.account import AccountHandler +from src.webhooks.types import WebhookBody, Platform, WebhookEvent +from tests import helpers + + +class TestAccountHandler: + """Test the `AccountHandler` class.""" + + def test_initialization(self): + """Test that AccountHandler initializes correctly.""" + handler = AccountHandler() + + assert isinstance(handler.logger, logging.Logger) + assert handler.logger.name == "AccountHandler" + + @pytest.mark.asyncio + async def test_handle_account_linked_event(self, bot): + """Test handle method routes ACCOUNT_LINKED event correctly.""" + handler = AccountHandler() + body = WebhookBody( + platform=Platform.ACCOUNT, + event=WebhookEvent.ACCOUNT_LINKED, + properties={"discord_id": 123456789, "account_id": 987654321}, + traits={}, + ) + + with patch.object( + handler, "handle_account_linked", new_callable=AsyncMock + ) as mock_handle: + await handler.handle(body, bot) + mock_handle.assert_called_once_with(body, bot) + + @pytest.mark.asyncio + async def test_handle_account_unlinked_event(self, bot): + """Test handle method routes ACCOUNT_UNLINKED event correctly.""" + handler = AccountHandler() + body = WebhookBody( + platform=Platform.ACCOUNT, + event=WebhookEvent.ACCOUNT_UNLINKED, + properties={"discord_id": 123456789, "account_id": 987654321}, + traits={}, + ) + + with patch.object( + handler, "handle_account_unlinked", new_callable=AsyncMock + ) as mock_handle: + await handler.handle(body, bot) + mock_handle.assert_called_once_with(body, bot) + + @pytest.mark.asyncio + async def test_handle_account_deleted_event(self, bot): + """Test handle method with ACCOUNT_DELETED event (method not implemented).""" + handler = AccountHandler() + body = WebhookBody( + platform=Platform.ACCOUNT, + event=WebhookEvent.ACCOUNT_DELETED, + properties={"discord_id": 123456789, "account_id": 987654321}, + traits={}, + ) + + # The handle_account_deleted method is not implemented, so this should raise AttributeError + with pytest.raises(AttributeError): + await handler.handle(body, bot) + + @pytest.mark.asyncio + async def test_handle_unknown_event(self, bot): + """Test handle method with unknown event does nothing.""" + handler = AccountHandler() + body = WebhookBody( + platform=Platform.ACCOUNT, + event=WebhookEvent.CERTIFICATE_AWARDED, # Not handled by AccountHandler + properties={"discord_id": 123456789, "account_id": 987654321}, + traits={}, + ) + + # Should not raise any exceptions, just do nothing + await handler.handle(body, bot) + + @pytest.mark.asyncio + async def test_handle_account_linked_success(self, bot): + """Test successful account linking.""" + handler = AccountHandler() + discord_id = 123456789 + account_id = 987654321 + mock_member = helpers.MockMember(id=discord_id, mention="@testuser") + + body = WebhookBody( + platform=Platform.ACCOUNT, + event=WebhookEvent.ACCOUNT_LINKED, + properties={"discord_id": discord_id, "account_id": account_id}, + traits={"htb_user_id": 555}, + ) + + # Create a custom bot mock without spec_set restrictions for this test + custom_bot = MagicMock() + custom_bot.send_message = AsyncMock() + + with ( + patch.object( + handler, "validate_discord_id", return_value=discord_id + ) as mock_validate_discord, + patch.object( + handler, "validate_account_id", return_value=account_id + ) as mock_validate_account, + patch.object( + handler, + "get_guild_member", + new_callable=AsyncMock, + return_value=mock_member, + ) as mock_get_member, + patch.object( + handler, + "merge_properties_and_traits", + return_value={ + "discord_id": discord_id, + "account_id": account_id, + "htb_user_id": 555, + }, + ) as mock_merge, + patch( + "src.webhooks.handlers.account.process_account_identification", + new_callable=AsyncMock, + ) as mock_process, + patch("src.webhooks.handlers.account.settings") as mock_settings, + patch.object(handler.logger, "info") as mock_log, + ): + mock_settings.channels.VERIFY_LOGS = 12345 + + await handler.handle_account_linked(body, custom_bot) + + # Verify all method calls + mock_validate_discord.assert_called_once_with(discord_id) + mock_validate_account.assert_called_once_with(account_id) + mock_get_member.assert_called_once_with(discord_id, custom_bot) + mock_merge.assert_called_once_with(body.properties, body.traits) + mock_process.assert_called_once_with( + mock_member, + custom_bot, + traits={ + "discord_id": discord_id, + "account_id": account_id, + "htb_user_id": 555, + }, + ) + custom_bot.send_message.assert_called_once_with( + 12345, f"Account linked: {account_id} -> (@testuser ({discord_id})" + ) + mock_log.assert_called_once_with( + f"Account {account_id} linked to {discord_id}", + extra={"account_id": account_id, "discord_id": discord_id}, + ) + + @pytest.mark.asyncio + async def test_handle_account_linked_invalid_discord_id(self, bot): + """Test account linking with invalid Discord ID.""" + handler = AccountHandler() + + body = WebhookBody( + platform=Platform.ACCOUNT, + event=WebhookEvent.ACCOUNT_LINKED, + properties={"discord_id": None, "account_id": 987654321}, + traits={}, + ) + + with patch.object( + handler, + "validate_discord_id", + side_effect=HTTPException(status_code=400, detail="Invalid Discord ID"), + ): + with pytest.raises(HTTPException) as exc_info: + await handler.handle_account_linked(body, bot) + + assert exc_info.value.status_code == 400 + assert exc_info.value.detail == "Invalid Discord ID" + + @pytest.mark.asyncio + async def test_handle_account_linked_invalid_account_id(self, bot): + """Test account linking with invalid Account ID.""" + handler = AccountHandler() + + body = WebhookBody( + platform=Platform.ACCOUNT, + event=WebhookEvent.ACCOUNT_LINKED, + properties={"discord_id": 123456789, "account_id": None}, + traits={}, + ) + + with ( + patch.object(handler, "validate_discord_id", return_value=123456789), + patch.object( + handler, + "validate_account_id", + side_effect=HTTPException(status_code=400, detail="Invalid Account ID"), + ), + ): + with pytest.raises(HTTPException) as exc_info: + await handler.handle_account_linked(body, bot) + + assert exc_info.value.status_code == 400 + assert exc_info.value.detail == "Invalid Account ID" + + @pytest.mark.asyncio + async def test_handle_account_linked_user_not_in_guild(self, bot): + """Test account linking when user is not in the Discord guild.""" + handler = AccountHandler() + discord_id = 123456789 + account_id = 987654321 + + body = WebhookBody( + platform=Platform.ACCOUNT, + event=WebhookEvent.ACCOUNT_LINKED, + properties={"discord_id": discord_id, "account_id": account_id}, + traits={}, + ) + + with ( + patch.object(handler, "validate_discord_id", return_value=discord_id), + patch.object(handler, "validate_account_id", return_value=account_id), + patch.object( + handler, + "get_guild_member", + new_callable=AsyncMock, + side_effect=HTTPException( + status_code=400, detail="User is not in the Discord server" + ), + ), + ): + with pytest.raises(HTTPException) as exc_info: + await handler.handle_account_linked(body, bot) + + assert exc_info.value.status_code == 400 + assert exc_info.value.detail == "User is not in the Discord server" + + @pytest.mark.asyncio + async def test_handle_account_unlinked_success(self, bot): + """Test successful account unlinking.""" + handler = AccountHandler() + discord_id = 123456789 + account_id = 987654321 + mock_member = helpers.MockMember(id=discord_id) + + body = WebhookBody( + platform=Platform.ACCOUNT, + event=WebhookEvent.ACCOUNT_UNLINKED, + properties={"discord_id": discord_id, "account_id": account_id}, + traits={}, + ) + + with ( + patch.object( + handler, "validate_discord_id", return_value=discord_id + ) as mock_validate_discord, + patch.object( + handler, "validate_account_id", return_value=account_id + ) as mock_validate_account, + patch.object( + handler, + "get_guild_member", + new_callable=AsyncMock, + return_value=mock_member, + ) as mock_get_member, + patch("src.webhooks.handlers.account.settings") as mock_settings, + ): + mock_settings.roles.VERIFIED = helpers.MockRole(id=99999, name="Verified") + mock_member.remove_roles = AsyncMock() + + await handler.handle_account_unlinked(body, bot) + + # Verify all method calls + mock_validate_discord.assert_called_once_with(discord_id) + mock_validate_account.assert_called_once_with(account_id) + mock_get_member.assert_called_once_with(discord_id, bot) + mock_member.remove_roles.assert_called_once_with( + mock_settings.roles.VERIFIED, atomic=True + ) + + @pytest.mark.asyncio + async def test_handle_account_unlinked_invalid_discord_id(self, bot): + """Test account unlinking with invalid Discord ID.""" + handler = AccountHandler() + + body = WebhookBody( + platform=Platform.ACCOUNT, + event=WebhookEvent.ACCOUNT_UNLINKED, + properties={"discord_id": None, "account_id": 987654321}, + traits={}, + ) + + with patch.object( + handler, + "validate_discord_id", + side_effect=HTTPException(status_code=400, detail="Invalid Discord ID"), + ): + with pytest.raises(HTTPException) as exc_info: + await handler.handle_account_unlinked(body, bot) + + assert exc_info.value.status_code == 400 + assert exc_info.value.detail == "Invalid Discord ID" + + @pytest.mark.asyncio + async def test_handle_account_unlinked_invalid_account_id(self, bot): + """Test account unlinking with invalid Account ID.""" + handler = AccountHandler() + + body = WebhookBody( + platform=Platform.ACCOUNT, + event=WebhookEvent.ACCOUNT_UNLINKED, + properties={"discord_id": 123456789, "account_id": None}, + traits={}, + ) + + with ( + patch.object(handler, "validate_discord_id", return_value=123456789), + patch.object( + handler, + "validate_account_id", + side_effect=HTTPException(status_code=400, detail="Invalid Account ID"), + ), + ): + with pytest.raises(HTTPException) as exc_info: + await handler.handle_account_unlinked(body, bot) + + assert exc_info.value.status_code == 400 + assert exc_info.value.detail == "Invalid Account ID" + + @pytest.mark.asyncio + async def test_handle_account_unlinked_user_not_in_guild(self, bot): + """Test account unlinking when user is not in the Discord guild.""" + handler = AccountHandler() + discord_id = 123456789 + account_id = 987654321 + + body = WebhookBody( + platform=Platform.ACCOUNT, + event=WebhookEvent.ACCOUNT_UNLINKED, + properties={"discord_id": discord_id, "account_id": account_id}, + traits={}, + ) + + with ( + patch.object(handler, "validate_discord_id", return_value=discord_id), + patch.object(handler, "validate_account_id", return_value=account_id), + patch.object( + handler, + "get_guild_member", + new_callable=AsyncMock, + side_effect=HTTPException( + status_code=400, detail="User is not in the Discord server" + ), + ), + ): + with pytest.raises(HTTPException) as exc_info: + await handler.handle_account_unlinked(body, bot) + + assert exc_info.value.status_code == 400 + assert exc_info.value.detail == "User is not in the Discord server" diff --git a/tests/src/webhooks/handlers/test_base.py b/tests/src/webhooks/handlers/test_base.py new file mode 100644 index 0000000..ba774cf --- /dev/null +++ b/tests/src/webhooks/handlers/test_base.py @@ -0,0 +1,315 @@ +import logging +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest +from discord import Bot +from discord.errors import NotFound +from fastapi import HTTPException + +from src.webhooks.handlers.base import BaseHandler +from src.webhooks.types import WebhookBody, Platform, WebhookEvent +from tests import helpers + + +class ConcreteHandler(BaseHandler): + """Concrete implementation of BaseHandler for testing purposes.""" + + async def handle(self, body: WebhookBody, bot: Bot) -> dict: + return {"status": "handled"} + + +class TestBaseHandler: + """Test the `BaseHandler` class.""" + + def test_initialization(self): + """Test that BaseHandler initializes correctly.""" + handler = ConcreteHandler() + + assert isinstance(handler.logger, logging.Logger) + assert handler.logger.name == "ConcreteHandler" + + def test_constants(self): + """Test that all required constants are defined.""" + handler = ConcreteHandler() + + assert handler.ACADEMY_USER_ID == "academy_user_id" + assert handler.MP_USER_ID == "mp_user_id" + assert handler.EP_USER_ID == "ep_user_id" + assert handler.CTF_USER_ID == "ctf_user_id" + assert handler.ACCOUNT_ID == "account_id" + assert handler.DISCORD_ID == "discord_id" + + @pytest.mark.asyncio + async def test_get_guild_member_success(self, bot): + """Test successful guild member retrieval.""" + handler = ConcreteHandler() + discord_id = 123456789 + mock_guild = helpers.MockGuild(id=12345) + mock_member = helpers.MockMember(id=discord_id) + + bot.fetch_guild = AsyncMock(return_value=mock_guild) + mock_guild.fetch_member = AsyncMock(return_value=mock_member) + + with patch("src.webhooks.handlers.base.settings") as mock_settings: + mock_settings.guild_ids = [12345] + + result = await handler.get_guild_member(discord_id, bot) + + assert result == mock_member + bot.fetch_guild.assert_called_once_with(12345) + mock_guild.fetch_member.assert_called_once_with(discord_id) + + @pytest.mark.asyncio + async def test_get_guild_member_not_found(self, bot): + """Test guild member retrieval when user is not in server.""" + handler = ConcreteHandler() + discord_id = 123456789 + mock_guild = helpers.MockGuild(id=12345) + + bot.fetch_guild = AsyncMock(return_value=mock_guild) + mock_guild.fetch_member = AsyncMock( + side_effect=NotFound(MagicMock(), "User not found") + ) + + with patch("src.webhooks.handlers.base.settings") as mock_settings: + mock_settings.guild_ids = [12345] + + with pytest.raises(HTTPException) as exc_info: + await handler.get_guild_member(discord_id, bot) + + assert exc_info.value.status_code == 400 + assert exc_info.value.detail == "User is not in the Discord server" + + def test_validate_property_success(self): + """Test successful property validation.""" + handler = ConcreteHandler() + + result = handler.validate_property("valid_value", "test_property") + assert result == "valid_value" + + result = handler.validate_property(123, "test_number") + assert result == 123 + + def test_validate_property_none(self): + """Test property validation with None value.""" + handler = ConcreteHandler() + + with pytest.raises(HTTPException) as exc_info: + handler.validate_property(None, "test_property") + + assert exc_info.value.status_code == 400 + assert exc_info.value.detail == "Invalid test_property" + + def test_validate_discord_id_success(self): + """Test successful Discord ID validation.""" + handler = ConcreteHandler() + + result = handler.validate_discord_id(123456789) + assert result == 123456789 + + result = handler.validate_discord_id("987654321") + assert result == "987654321" + + def test_validate_discord_id_none(self): + """Test Discord ID validation with None value.""" + handler = ConcreteHandler() + + with pytest.raises(HTTPException) as exc_info: + handler.validate_discord_id(None) + + assert exc_info.value.status_code == 400 + assert exc_info.value.detail == "Invalid Discord ID" + + def test_validate_account_id_success(self): + """Test successful Account ID validation.""" + handler = ConcreteHandler() + + result = handler.validate_account_id(123456789) + assert result == 123456789 + + result = handler.validate_account_id("987654321") + assert result == "987654321" + + def test_validate_account_id_none(self): + """Test Account ID validation with None value.""" + handler = ConcreteHandler() + + with pytest.raises(HTTPException) as exc_info: + handler.validate_account_id(None) + + assert exc_info.value.status_code == 400 + assert exc_info.value.detail == "Invalid Account ID" + + def test_get_property_or_trait_from_properties(self): + """Test getting value from properties.""" + handler = ConcreteHandler() + body = WebhookBody( + platform=Platform.MAIN, + event=WebhookEvent.ACCOUNT_LINKED, + properties={"test_key": 123}, + traits={"test_key": 456, "other_key": 789}, + ) + + result = handler.get_property_or_trait(body, "test_key") + assert result == 123 # Should prioritize properties over traits + + def test_get_property_or_trait_from_traits(self): + """Test getting value from traits when not in properties.""" + handler = ConcreteHandler() + body = WebhookBody( + platform=Platform.MAIN, + event=WebhookEvent.ACCOUNT_LINKED, + properties={}, + traits={"test_key": 456}, + ) + + result = handler.get_property_or_trait(body, "test_key") + assert result == 456 + + def test_get_property_or_trait_not_found(self): + """Test getting value when key is not found.""" + handler = ConcreteHandler() + body = WebhookBody( + platform=Platform.MAIN, + event=WebhookEvent.ACCOUNT_LINKED, + properties={}, + traits={}, + ) + + result = handler.get_property_or_trait(body, "missing_key") + assert result is None + + def test_merge_properties_and_traits_no_duplicates(self): + """Test merging properties and traits without duplicates.""" + handler = ConcreteHandler() + properties = {"key1": 1, "key2": 2} + traits = {"key3": 3, "key4": 4} + + result = handler.merge_properties_and_traits(properties, traits) + + expected = {"key1": 1, "key2": 2, "key3": 3, "key4": 4} + assert result == expected + + def test_merge_properties_and_traits_with_duplicates(self): + """Test merging properties and traits with duplicate keys.""" + handler = ConcreteHandler() + properties = {"key1": 1, "key2": 2} + traits = {"key2": 99, "key3": 3} # key2 is duplicate + + result = handler.merge_properties_and_traits(properties, traits) + + expected = {"key1": 1, "key2": 2, "key3": 3} # Properties value should win + assert result == expected + + def test_merge_properties_and_traits_empty_properties(self): + """Test merging when properties is empty.""" + handler = ConcreteHandler() + properties = {} + traits = {"key1": 1, "key2": 2} + + result = handler.merge_properties_and_traits(properties, traits) + + assert result == traits + + def test_merge_properties_and_traits_empty_traits(self): + """Test merging when traits is empty.""" + handler = ConcreteHandler() + properties = {"key1": 1, "key2": 2} + traits = {} + + result = handler.merge_properties_and_traits(properties, traits) + + assert result == properties + + def test_get_platform_properties_all_present(self): + """Test getting platform properties when all are present.""" + handler = ConcreteHandler() + body = WebhookBody( + platform=Platform.MAIN, + event=WebhookEvent.ACCOUNT_LINKED, + properties={ + "account_id": 1, + "mp_user_id": 2, + "ep_user_id": 3, + "ctf_user_id": 4, + "academy_user_id": 5, + }, + traits={}, + ) + + result = handler.get_platform_properties(body) + + expected = { + "account_id": 1, + "mp_user_id": 2, + "ep_user_id": 3, + "ctf_user_id": 4, + "academy_user_id": 5, + } + assert result == expected + + def test_get_platform_properties_mixed_sources(self): + """Test getting platform properties from both properties and traits.""" + handler = ConcreteHandler() + body = WebhookBody( + platform=Platform.MAIN, + event=WebhookEvent.ACCOUNT_LINKED, + properties={"account_id": 1, "mp_user_id": 2}, + traits={"ep_user_id": 3, "ctf_user_id": 4, "academy_user_id": 5}, + ) + + result = handler.get_platform_properties(body) + + expected = { + "account_id": 1, + "mp_user_id": 2, + "ep_user_id": 3, + "ctf_user_id": 4, + "academy_user_id": 5, + } + assert result == expected + + def test_get_platform_properties_missing_values(self): + """Test getting platform properties when some are missing.""" + handler = ConcreteHandler() + body = WebhookBody( + platform=Platform.MAIN, + event=WebhookEvent.ACCOUNT_LINKED, + properties={"account_id": 1}, + traits={"mp_user_id": 2}, + ) + + result = handler.get_platform_properties(body) + + expected = { + "account_id": 1, + "mp_user_id": 2, + "ep_user_id": None, + "ctf_user_id": None, + "academy_user_id": None, + } + assert result == expected + + def test_get_platform_properties_properties_override_traits(self): + """Test that properties override traits for the same key.""" + handler = ConcreteHandler() + body = WebhookBody( + platform=Platform.MAIN, + event=WebhookEvent.ACCOUNT_LINKED, + properties={"account_id": 1, "mp_user_id": 2}, + traits={ + "mp_user_id": 999, # Should be overridden + "ep_user_id": 3, + }, + ) + + result = handler.get_platform_properties(body) + + expected = { + "account_id": 1, + "mp_user_id": 2, # Properties value should win + "ep_user_id": 3, + "ctf_user_id": None, + "academy_user_id": None, + } + assert result == expected diff --git a/tests/src/webhooks/test_handlers_init.py b/tests/src/webhooks/test_handlers_init.py new file mode 100644 index 0000000..ce0b436 --- /dev/null +++ b/tests/src/webhooks/test_handlers_init.py @@ -0,0 +1,35 @@ +from unittest import mock +from typing import Callable + +import pytest + +from src.webhooks.handlers import handlers, can_handle, handle +from src.webhooks.types import Platform, WebhookBody, WebhookEvent +from tests.conftest import bot + +class TestHandlersInit: + def test_handler_init(self): + assert handlers is not None + assert isinstance(handlers, dict) + assert len(handlers) > 0 + assert all(isinstance(handler, Callable) for handler in handlers.values()) + + def test_can_handle_unknown_platform(self): + assert not can_handle("UNKNOWN") + + def test_can_handle_success(self): + with mock.patch("src.webhooks.handlers.handlers", {Platform.MAIN: lambda x, y: True}): + assert can_handle(Platform.MAIN) + + def test_handle_success(self): + with mock.patch("src.webhooks.handlers.handlers", {Platform.MAIN: lambda x, y: 1337}): + assert handle(WebhookBody(platform=Platform.MAIN, event=WebhookEvent.ACCOUNT_LINKED, properties={}, traits={}), bot) == 1337 + + def test_handle_unknown_platform(self): + with pytest.raises(ValueError): + handle(WebhookBody(platform="UNKNOWN", event=WebhookEvent.ACCOUNT_LINKED, properties={}, traits={}), bot) + + def test_handle_unknown_event(self): + with mock.patch("src.webhooks.handlers.handlers", {Platform.MAIN: lambda x, y: 1337}): + with pytest.raises(ValueError): + handle(WebhookBody(platform=Platform.MAIN, event="UNKNOWN", properties={}, traits={}), bot) \ No newline at end of file