diff --git a/.test.env b/.test.env
index def9d16..914120e 100644
--- a/.test.env
+++ b/.test.env
@@ -32,6 +32,8 @@ CHANNEL_SPOILER=2769521890099371011
CHANNEL_BOT_LOGS=1105517088266788925
# Roles
+ROLE_VERIFIED=1389344463884652554
+
ROLE_BIZCTF2022=7629466241011276950
ROLE_NOAH_GANG=6706800691011276950
ROLE_BUDDY_GANG=6706800681011276950
diff --git a/src/cmds/automation/auto_verify.py b/src/cmds/automation/auto_verify.py
index 606ad3e..98c538d 100644
--- a/src/cmds/automation/auto_verify.py
+++ b/src/cmds/automation/auto_verify.py
@@ -2,12 +2,8 @@
from discord import Member, Message, User
from discord.ext import commands
-from sqlalchemy import select
from src.bot import Bot
-from src.database.models import HtbDiscordLink
-from src.database.session import AsyncSessionLocal
-from src.helpers.verification import get_user_details, process_identification
logger = logging.getLogger(__name__)
@@ -19,31 +15,11 @@ def __init__(self, bot: Bot):
self.bot = bot
async def process_reverification(self, member: Member | User) -> None:
- """Re-verifation process for a member."""
- async with AsyncSessionLocal() as session:
- stmt = (
- select(HtbDiscordLink)
- .where(HtbDiscordLink.discord_user_id == member.id)
- .order_by(HtbDiscordLink.id)
- .limit(1)
- )
- result = await session.scalars(stmt)
- htb_discord_link: HtbDiscordLink = result.first()
-
- if not htb_discord_link:
- raise VerificationError(f"HTB Discord link for user {member.name} with ID {member}")
-
- member_token: str = htb_discord_link.account_identifier
-
- if member_token is None:
- raise VerificationError(f"HTB account identifier for user {member.name} with ID {member.id} not found")
-
- logger.debug(f"Processing re-verify of member {member.name} ({member.id}).")
- htb_details = await get_user_details(member_token)
- if htb_details is None:
- raise VerificationError(f"Retrieving user details for user {member.name} with ID {member.id} failed")
-
- await process_identification(htb_details, user=member, bot=self.bot)
+ """Re-verifation process for a member.
+
+ TODO: Reimplement once it's possible to fetch link state from the HTB Account.
+ """
+ raise VerificationError("Not implemented")
@commands.Cog.listener()
@commands.cooldown(1, 60, commands.BucketType.user)
@@ -74,4 +50,5 @@ class VerificationError(Exception):
def setup(bot: Bot) -> None:
"""Load the `MessageHandler` cog."""
- bot.add_cog(MessageHandler(bot))
+ # bot.add_cog(MessageHandler(bot))
+ pass
diff --git a/src/cmds/core/identify.py b/src/cmds/core/identify.py
index ef52ed7..34a378a 100644
--- a/src/cmds/core/identify.py
+++ b/src/cmds/core/identify.py
@@ -1,17 +1,12 @@
import logging
-from typing import Sequence
-import discord
from discord import ApplicationContext, Interaction, WebhookMessage, slash_command
from discord.ext import commands
from discord.ext.commands import cooldown
-from sqlalchemy import select
from src.bot import Bot
from src.core import settings
-from src.database.models import HtbDiscordLink
-from src.database.session import AsyncSessionLocal
-from src.helpers.verification import get_user_details, process_identification
+from src.helpers.verification import send_verification_instructions
logger = logging.getLogger(__name__)
@@ -25,121 +20,15 @@ def __init__(self, bot: Bot):
@slash_command(
guild_ids=settings.guild_ids,
description="Identify yourself on the HTB Discord server by linking your HTB account ID to your Discord user "
- "ID.", guild_only=False
+ "ID.",
+ guild_only=False,
)
@cooldown(1, 60, commands.BucketType.user)
- async def identify(self, ctx: ApplicationContext, account_identifier: str) -> Interaction | WebhookMessage:
- """Identify yourself on the HTB Discord server by linking your HTB account ID to your Discord user ID."""
- if len(account_identifier) != 60:
- return await ctx.respond(
- "This Account Identifier does not appear to be the right length (must be 60 characters long).",
- ephemeral=True
- )
-
- await ctx.respond("Identification initiated, please wait...", ephemeral=True)
- htb_user_details = await get_user_details(account_identifier)
- if htb_user_details is None:
- embed = discord.Embed(title="Error: Invalid account identifier.", color=0xFF0000)
- return await ctx.respond(embed=embed, ephemeral=True)
-
- json_htb_user_id = htb_user_details["user_id"]
-
- author = ctx.user
- member = await self.bot.get_or_fetch_user(author.id)
- if not member:
- return await ctx.respond(f"Error getting guild member with id: {author.id}.")
-
- # Step 1: Check if the Account Identifier has already been recorded and if they are the previous owner.
- # Scenario:
- # - I create a new Discord account.
- # - I reuse my previous Account Identifier.
- # - I now have an "alt account" with the same roles.
- async with AsyncSessionLocal() as session:
- stmt = (
- select(HtbDiscordLink)
- .filter(HtbDiscordLink.account_identifier == account_identifier)
- .order_by(HtbDiscordLink.id.desc())
- .limit(1)
- )
- result = await session.scalars(stmt)
- most_recent_rec: HtbDiscordLink = result.first()
-
- if most_recent_rec and most_recent_rec.discord_user_id_as_int != member.id:
- error_desc = (
- f"Verified user {member.mention} tried to identify as another identified user.\n"
- f"Current Discord UID: {member.id}\n"
- f"Other Discord UID: {most_recent_rec.discord_user_id}\n"
- f"Related HTB UID: {most_recent_rec.htb_user_id}"
- )
- embed = discord.Embed(title="Identification error", description=error_desc, color=0xFF2429)
- await self.bot.get_channel(settings.channels.VERIFY_LOGS).send(embed=embed)
-
- return await ctx.respond(
- "Identification error: please contact an online Moderator or Administrator for help.", ephemeral=True
- )
-
- # Step 2: Given the htb_user_id from JSON, check if each discord_user_id are different from member.id.
- # Scenario:
- # - I have a Discord account that is linked already to a "Hacker" role.
- # - I create a new HTB account.
- # - I identify with the new account.
- # - `SELECT * FROM htb_discord_link WHERE htb_user_id = %s` will be empty,
- # because the new account has not been verified before. All is good.
- # - I am now "Noob" rank.
- async with AsyncSessionLocal() as session:
- stmt = select(HtbDiscordLink).filter(HtbDiscordLink.htb_user_id == json_htb_user_id)
- result = await session.scalars(stmt)
- user_links: Sequence[HtbDiscordLink] = result.all()
-
- discord_user_ids = {u_link.discord_user_id_as_int for u_link in user_links}
- if discord_user_ids and member.id not in discord_user_ids:
- orig_discord_ids = ", ".join([f"<@{id_}>" for id_ in discord_user_ids])
- error_desc = (f"The HTB account {json_htb_user_id} attempted to be identified by user <@{member.id}>, "
- f"but is tied to another Discord account.\n"
- f"Originally linked to Discord UID {orig_discord_ids}.")
- embed = discord.Embed(title="Identification error", description=error_desc, color=0xFF2429)
- await self.bot.get_channel(settings.channels.VERIFY_LOGS).send(embed=embed)
-
- return await ctx.respond(
- "Identification error: please contact an online Moderator or Administrator for help.", ephemeral=True
- )
-
- # Step 3: Check if discord_user_id already linked to an htb_user_id, and if JSON/db HTB IDs are the same.
- # Scenario:
- # - I have a new, unlinked Discord account.
- # - Clubby generates a new token and gives it to me.
- # - `SELECT * FROM htb_discord_link WHERE discord_user_id = %s`
- # will be empty because I have not identified before.
- # - I am now Clubby.
- async with AsyncSessionLocal() as session:
- stmt = select(HtbDiscordLink).filter(HtbDiscordLink.discord_user_id == member.id)
- result = await session.scalars(stmt)
- user_links: Sequence[HtbDiscordLink] = result.all()
-
- user_htb_ids = {u_link.htb_user_id_as_int for u_link in user_links}
- if user_htb_ids and json_htb_user_id not in user_htb_ids:
- error_desc = (f"User {member.mention} ({member.id}) tried to identify with a new HTB account.\n"
- f"Original HTB UIDs: {', '.join([str(i) for i in user_htb_ids])}, new HTB UID: "
- f"{json_htb_user_id}.")
- embed = discord.Embed(title="Identification error", description=error_desc, color=0xFF2429)
- await self.bot.get_channel(settings.channels.VERIFY_LOGS).send(embed=embed)
-
- return await ctx.respond(
- "Identification error: please contact an online Moderator or Administrator for help.", ephemeral=True
- )
-
- htb_discord_link = HtbDiscordLink(
- account_identifier=account_identifier, discord_user_id=member.id, htb_user_id=json_htb_user_id
- )
- async with AsyncSessionLocal() as session:
- session.add(htb_discord_link)
- await session.commit()
-
- await process_identification(htb_user_details, user=member, bot=self.bot)
-
- return await ctx.respond(
- f"Your Discord user has been successfully identified as HTB user {json_htb_user_id}.", ephemeral=True
- )
+ async def identify(
+ self, ctx: ApplicationContext, account_identifier: str
+ ) -> Interaction | WebhookMessage:
+ """Legacy command. Now sends instructions to identify with HTB account."""
+ await send_verification_instructions(ctx, ctx.author)
def setup(bot: Bot) -> None:
diff --git a/src/cmds/core/verify.py b/src/cmds/core/verify.py
index 23088c1..919f872 100644
--- a/src/cmds/core/verify.py
+++ b/src/cmds/core/verify.py
@@ -1,14 +1,12 @@
import logging
-import discord
from discord import ApplicationContext, Interaction, WebhookMessage, slash_command
-from discord.errors import Forbidden, HTTPException
from discord.ext import commands
from discord.ext.commands import cooldown
from src.bot import Bot
from src.core import settings
-from src.helpers.verification import process_certification
+from src.helpers.verification import process_certification, send_verification_instructions
logger = logging.getLogger(__name__)
@@ -47,57 +45,7 @@ async def verifycertification(self, ctx: ApplicationContext, certid: str, fullna
@cooldown(1, 60, commands.BucketType.user)
async def verify(self, ctx: ApplicationContext) -> Interaction | WebhookMessage:
"""Receive instructions in a DM on how to identify yourself with your HTB account."""
- member = ctx.user
-
- # Step one
- embed_step1 = discord.Embed(color=0x9ACC14)
- embed_step1.add_field(
- name="Step 1: Log in at Hack The Box",
- value="Go to the Hack The Box website at "
- " and navigate to **Login > HTB Labs**. Log in to your HTB Account."
- , inline=False, )
- embed_step1.set_image(
- url="https://media.discordapp.net/attachments/724587782755844098/839871275627315250/unknown.png"
- )
-
- # Step two
- embed_step2 = discord.Embed(color=0x9ACC14)
- embed_step2.add_field(
- name="Step 2: Locate the Account Identifier",
- value='Click on your profile name, then select **My Profile**. '
- 'In the Profile Settings tab, find the field labeled **Account Identifier**. () '
- "Click the green button to copy your secret identifier.", inline=False, )
- embed_step2.set_image(
- url="https://media.discordapp.net/attachments/724587782755844098/839871332963188766/unknown.png"
- )
-
- # Step three
- embed_step3 = discord.Embed(color=0x9ACC14)
- embed_step3.add_field(
- name="Step 3: Identification",
- value="Now type `/identify IDENTIFIER_HERE` in the bot-commands channel.\n\nYour roles will be "
- "applied automatically.", inline=False
- )
- embed_step3.set_image(
- url="https://media.discordapp.net/attachments/709907130102317093/904744444539076618/unknown.png"
- )
-
- try:
- await member.send(embed=embed_step1)
- await member.send(embed=embed_step2)
- await member.send(embed=embed_step3)
- except Forbidden as ex:
- logger.error("Exception during verify call", exc_info=ex)
- return await ctx.respond(
- "Whoops! I cannot DM you after all due to your privacy settings. Please allow DMs from other server "
- "members and try again in 1 minute."
- )
- except HTTPException as ex:
- logger.error("Exception during verify call.", exc_info=ex)
- return await ctx.respond(
- "An unexpected error happened (HTTP 400, bad request). Please contact an Administrator."
- )
- return await ctx.respond("Please check your DM for instructions.", ephemeral=True)
+ await send_verification_instructions(ctx, ctx.author)
def setup(bot: Bot) -> None:
diff --git a/src/core/config.py b/src/core/config.py
index 6b3025e..b91644d 100644
--- a/src/core/config.py
+++ b/src/core/config.py
@@ -91,6 +91,7 @@ class AcademyCertificates(BaseSettings):
class Roles(BaseSettings):
"""The roles settings."""
+ VERIFIED: int
# Moderation
COMMUNITY_MANAGER: int
diff --git a/src/helpers/ban.py b/src/helpers/ban.py
index 8d2ddd0..08a6895 100644
--- a/src/helpers/ban.py
+++ b/src/helpers/ban.py
@@ -2,6 +2,8 @@
import asyncio
import logging
from datetime import datetime, timezone
+from enum import Enum
+
import arrow
import discord
@@ -22,6 +24,12 @@
logger = logging.getLogger(__name__)
+class BanCodes(Enum):
+ SUCCESS = "SUCCESS"
+ ALREADY_EXISTS = "ALREADY_EXISTS"
+ FAILED = "FAILED"
+
+
async def _check_member(bot: Bot, guild: Guild, member: Member | User, author: Member = None) -> SimpleResponse | None:
if isinstance(member, Member):
if member_is_staff(member):
@@ -29,9 +37,9 @@ async def _check_member(bot: Bot, guild: Guild, member: Member | User, author: M
elif isinstance(member, User):
member = await bot.get_member_or_user(guild, member.id)
if member.bot:
- return SimpleResponse(message="You cannot ban a bot.", delete_after=None)
+ return SimpleResponse(message="You cannot ban a bot.", delete_after=None, code=BanCodes.FAILED)
if author and author.id == member.id:
- return SimpleResponse(message="You cannot ban yourself.", delete_after=None)
+ return SimpleResponse(message="You cannot ban yourself.", delete_after=None, code=BanCodes.FAILED)
async def _get_ban_or_create(member: Member, ban: Ban, infraction: Infraction) -> tuple[int, bool]:
@@ -50,6 +58,29 @@ async def _get_ban_or_create(member: Member, ban: Ban, infraction: Infraction) -
return ban_id, False
+async def _create_ban_response(member: Member | User, end_date: str, dm_banned_member: bool, needs_approval: bool) -> SimpleResponse:
+ """Create a SimpleResponse for ban operations."""
+ if needs_approval:
+ if member:
+ message = f"{member.display_name} ({member.id}) has been banned until {end_date} (UTC)."
+ else:
+ message = f"{member.id} has been banned until {end_date} (UTC)."
+ else:
+ if member:
+ message = f"Member {member.display_name} has been banned permanently."
+ else:
+ message = f"Member {member.id} has been banned permanently."
+
+ if not dm_banned_member:
+ message += "\n Could not DM banned member due to permission error."
+
+ return SimpleResponse(
+ message=message,
+ delete_after=0 if not needs_approval else None,
+ code=BanCodes.SUCCESS
+ )
+
+
async def ban_member(
bot: Bot, guild: Guild, member: Member | User, duration: str, reason: str, evidence: str, author: Member = None,
needs_approval: bool = True
@@ -93,7 +124,8 @@ async def ban_member(
if is_existing:
return SimpleResponse(
message=f"A ban with id: {ban_id} already exists for member {member}",
- delete_after=None
+ delete_after=None,
+ code=BanCodes.ALREADY_EXISTS
)
# DM member, before we ban, else we cannot dm since we do not share a guild
@@ -107,27 +139,20 @@ async def ban_member(
extra={"ban_requestor": author.name, "ban_receiver": member.id}
)
if author:
- return SimpleResponse(message="You do not have the proper permissions to ban.", delete_after=None)
+ return SimpleResponse(message="You do not have the proper permissions to ban.", delete_after=None, code=BanCodes.FAILED)
return
except HTTPException as ex:
logger.warning(f"HTTPException when trying to ban user with ID {member.id}", exc_info=ex)
if author:
return SimpleResponse(
message="Here's a 400 Bad Request for you. Just like when you tried to ask me out, last week.",
- delete_after=None
+ delete_after=None,
+ code=BanCodes.FAILED
)
return
# If approval is required, send a message to the moderator channel about the ban
if not needs_approval:
- if member:
- message = f"Member {member.display_name} has been banned permanently."
- else:
- message = f"Member {member.id} has been banned permanently."
-
- if not dm_banned_member:
- message += "\n Could not DM banned member due to permission error."
-
logger.info(
"Member has been banned permanently.",
extra={"ban_requestor": author.name, "ban_receiver": member.id, "dm_banned_member": dm_banned_member}
@@ -136,16 +161,7 @@ async def ban_member(
unban_task = schedule(unban_member(guild, member), run_at=datetime.fromtimestamp(ban.unban_time))
asyncio.create_task(unban_task)
logger.debug("Unbanned sceduled for ban", extra={"ban_id": ban_id, "unban_time": ban.unban_time})
- return SimpleResponse(message=message, delete_after=0)
else:
- if member:
- message = f"{member.display_name} ({member.id}) has been banned until {end_date} (UTC)."
- else:
- message = f"{member.id} has been banned until {end_date} (UTC)."
-
- if not dm_banned_member:
- message += " Could not DM banned member due to permission error."
-
member_name = f"{member.display_name} ({member.name})"
embed = discord.Embed(
title=f"Ban request #{ban_id}",
@@ -156,7 +172,8 @@ async def ban_member(
embed.set_thumbnail(url=f"{settings.HTB_URL}/images/logo600.png")
view = BanDecisionView(ban_id, bot, guild, member, end_date, reason)
await guild.get_channel(settings.channels.SR_MOD).send(embed=embed, view=view)
- return SimpleResponse(message=message)
+
+ return await _create_ban_response(member, end_date, dm_banned_member, needs_approval)
async def _dm_banned_member(end_date: str, guild: Guild, member: Member, reason: str) -> bool:
diff --git a/src/helpers/responses.py b/src/helpers/responses.py
index 6513bfd..91a8213 100644
--- a/src/helpers/responses.py
+++ b/src/helpers/responses.py
@@ -4,9 +4,10 @@
class SimpleResponse(object):
"""A simple response object."""
- def __init__(self, message: str, delete_after: int | None = None):
+ def __init__(self, message: str, delete_after: int | None = None, code: str | None = None):
self.message = message
self.delete_after = delete_after
+ self.code = code
def __str__(self):
return json.dumps(dict(self), ensure_ascii=False)
diff --git a/src/helpers/verification.py b/src/helpers/verification.py
index ddb4be9..ae02038 100644
--- a/src/helpers/verification.py
+++ b/src/helpers/verification.py
@@ -1,19 +1,86 @@
import logging
-from datetime import datetime
-from typing import Dict, List, Optional, cast
+from typing import Dict, List, Optional
import aiohttp
import discord
-from discord import Forbidden, Member, Role, User
+from discord import ApplicationContext, Forbidden, Member, Role, User, HTTPException
from discord.ext.commands import GuildNotFound, MemberNotFound
from src.bot import Bot
from src.core import settings
-from src.helpers.ban import ban_member
+from src.helpers.ban import BanCodes, ban_member
logger = logging.getLogger(__name__)
+async def send_verification_instructions(
+ ctx: ApplicationContext, member: Member
+) -> discord.Interaction | discord.WebhookMessage:
+ """Send instructions via DM on how to identify with HTB account.
+
+ Args:
+ ctx (ApplicationContext): The context of the command.
+ member (Member): The member to send the instructions to.
+
+ Returns:
+ discord.Interaction | discord.WebhookMessage: The response message.
+ """
+ member = ctx.user
+
+ # Create step-by-step instruction embeds
+ embed_step1 = discord.Embed(color=0x9ACC14)
+ embed_step1.add_field(
+ name="Step 1: Login to your HTB Account",
+ value="Go to and login.",
+ inline=False,
+ )
+ embed_step1.set_image(
+ url="https://media.discordapp.net/attachments/1102700815493378220/1384587341338902579/image.png"
+ )
+
+ embed_step2 = discord.Embed(color=0x9ACC14)
+ embed_step2.add_field(
+ name="Step 2: Navigate to your Security Settings",
+ value="In the navigation bar, click on **Security Settings** and scroll down to the **Discord Account** section. "
+ "()",
+ inline=False,
+ )
+ embed_step2.set_image(
+ url="https://media.discordapp.net/attachments/1102700815493378220/1384587813760270392/image.png"
+ )
+
+ embed_step3 = discord.Embed(color=0x9ACC14)
+ embed_step3.add_field(
+ name="Step 3: Link your Discord Account",
+ value="Click **Connect** and you will be redirected to login to your Discord account via oauth. "
+ "After logging in, you will be redirected back to the HTB Account page. "
+ "Your Discord account will now be linked. Discord may take a few minutes to update. "
+ "If you have any issues, please contact a Moderator.",
+ inline=False,
+ )
+ embed_step3.set_image(
+ url="https://media.discordapp.net/attachments/1102700815493378220/1384586811384402042/image.png"
+ )
+
+ try:
+ await member.send(embed=embed_step1)
+ await member.send(embed=embed_step2)
+ await member.send(embed=embed_step3)
+ except Forbidden as ex:
+ logger.error("Exception during verify call", exc_info=ex)
+ return await ctx.respond(
+ "Whoops! I cannot DM you after all due to your privacy settings. Please allow DMs from other server "
+ "members and try again in 1 minute."
+ )
+ except HTTPException as ex:
+ logger.error("Exception during verify call.", exc_info=ex)
+ return await ctx.respond(
+ "An unexpected error happened (HTTP 400, bad request). Please contact an Administrator."
+ )
+
+ return await ctx.respond("Please check your DM for instructions.", ephemeral=True)
+
+
async def get_user_details(account_identifier: str) -> Optional[Dict]:
"""Get user details from HTB."""
acc_id_url = f"{settings.API_URL}/discord/identifier/{account_identifier}?secret={settings.HTB_API_SECRET}"
@@ -23,10 +90,14 @@ async def get_user_details(account_identifier: str) -> Optional[Dict]:
if r.status == 200:
response = await r.json()
elif r.status == 404:
- logger.debug("Account identifier has been regenerated since last identification. Cannot re-verify.")
+ logger.debug(
+ "Account identifier has been regenerated since last identification. Cannot re-verify."
+ )
response = None
else:
- logger.error(f"Non-OK HTTP status code returned from identifier lookup: {r.status}.")
+ logger.error(
+ f"Non-OK HTTP status code returned from identifier lookup: {r.status}."
+ )
response = None
return response
@@ -45,7 +116,9 @@ async def get_season_rank(htb_uid: int) -> str | None:
logger.error("Invalid Season ID.")
response = None
else:
- logger.error(f"Non-OK HTTP status code returned from identifier lookup: {r.status}.")
+ logger.error(
+ f"Non-OK HTTP status code returned from identifier lookup: {r.status}."
+ )
response = None
if not response["data"]:
@@ -59,26 +132,19 @@ async def get_season_rank(htb_uid: int) -> str | None:
return rank
-async def _check_for_ban(uid: str) -> Optional[Dict]:
- async with aiohttp.ClientSession() as session:
- token_url = f"{settings.API_URL}/discord/{uid}/banned?secret={settings.HTB_API_SECRET}"
- async with session.get(token_url) as r:
- if r.status == 200:
- ban_details = await r.json()
- else:
- logger.error(
- f"Could not fetch ban details for uid {uid}: "
- f"non-OK status code returned ({r.status}). Body: {r.content}"
- )
- ban_details = None
-
- return ban_details
+async def _check_for_ban(member: Member) -> Optional[Dict]:
+ """Check if the member is banned."""
+ try:
+ member.guild.get_role(settings.roles.BANNED)
+ return True
+ except Forbidden:
+ return False
async def process_certification(certid: str, name: str):
"""Process certifications."""
cert_api_url = f"{settings.API_V4_URL}/certificate/lookup"
- params = {'id': certid, 'name': name}
+ params = {"id": certid, "name": name}
async with aiohttp.ClientSession() as session:
async with session.get(cert_api_url, params=params) as r:
if r.status == 200:
@@ -86,7 +152,9 @@ async def process_certification(certid: str, name: str):
elif r.status == 404:
return False
else:
- logger.error(f"Non-OK HTTP status code returned from identifier lookup: {r.status}.")
+ logger.error(
+ f"Non-OK HTTP status code returned from identifier lookup: {r.status}."
+ )
response = None
try:
certRawName = response["certificates"][0]["name"]
@@ -107,7 +175,90 @@ async def process_certification(certid: str, name: str):
return cert
-async def process_identification(
+async def _handle_banned_user(member: Member, bot: Bot):
+ """Handle banned trait during account linking.
+
+ Args:
+ member (Member): The member to process.
+ bot (Bot): The bot instance.
+ """
+ resp = await ban_member(
+ bot,
+ member.guild,
+ member,
+ "1337w",
+ (
+ "Banned on the HTB Platform. Ban duration could not be determined. "
+ "Please login to confirm ban details and contact HTB Support to appeal."
+ ),
+ None,
+ needs_approval=False,
+ )
+ if resp.code == BanCodes.SUCCESS:
+ embed = discord.Embed(
+ title="Identification error",
+ description=f"User {member.mention} ({member.id}) was platform banned HTB and thus also here.",
+ color=0xFF2429,
+ )
+ await member.guild.get_channel(settings.channels.VERIFY_LOGS).send(embed=embed)
+
+
+async def _set_nickname(member: Member, nickname: str) -> bool:
+ """Set the nickname of the member.
+
+ Args:
+ member (Member): The member to set the nickname for.
+ nickname (str): The nickname to set.
+
+ Returns:
+ bool: True if the nickname was set, False otherwise.
+ """
+ try:
+ await member.edit(nick=nickname)
+ return True
+ except Forbidden as e:
+ logger.error(f"Exception whe trying to edit the nick-name of the user: {e}")
+ return False
+
+
+async def process_account_identification(
+ member: Member, bot: Bot, traits: dict[str, str] | None = None
+) -> None:
+ """Process HTB account identification, to be called during account linking.
+
+ Args:
+ member (Member): The member to process.
+ bot (Bot): The bot instance.
+ traits (dict[str, str] | None): Optional user traits to process.
+ """
+ await member.add_roles(member.guild.get_role(settings.roles.VERIFIED), atomic=True)
+
+ nickname_changed = False
+
+ if traits.get("username") and traits.get("username") != member.name:
+ nickname_changed = await _set_nickname(member, traits.get("username"))
+
+ if not nickname_changed:
+ logger.warning(
+ f"No username provided for {member.name} with ID {member.id} during identification."
+ )
+
+ if traits.get("mp_user_id"):
+ htb_user_details = await get_user_details(traits.get("mp_user_id"))
+ await process_labs_identification(htb_user_details, member, bot)
+
+ if not nickname_changed:
+ logger.debug(
+ f"Falling back on HTB username to set nickname for {member.name} with ID {member.id}."
+ )
+ await _set_nickname(member, htb_user_details["username"])
+
+ if traits.get("banned", False) == True: # noqa: E712 - explicit bool only, no truthiness
+ await _handle_banned_user(member, bot)
+ return
+
+
+async def process_labs_identification(
htb_user_details: Dict[str, str], user: Optional[Member | User], bot: Bot
) -> Optional[List[Role]]:
"""Returns roles to assign if identification was successfully processed."""
@@ -123,54 +274,33 @@ async def process_identification(
raise MemberNotFound(str(user.id))
else:
raise GuildNotFound(f"Could not identify member {user} in guild.")
- season_rank = await get_season_rank(htb_uid)
- banned_details = await _check_for_ban(htb_uid)
-
- if banned_details is not None and banned_details["banned"]:
- # If user is banned, this field must be a string
- # Strip date e.g. from "2022-01-31T11:00:00.000000Z"
- banned_until: str = cast(str, banned_details["ends_at"])[:10]
- banned_until_dt: datetime = datetime.strptime(banned_until, "%Y-%m-%d")
- ban_duration: str = f"{(banned_until_dt - datetime.now()).days}d"
- reason = "Banned on the HTB Platform. Please contact HTB Support to appeal."
- logger.info(f"Discord user {member.name} ({member.id}) is platform banned. Banning from Discord...")
- await ban_member(bot, guild, member, ban_duration, reason, None, needs_approval=False)
-
- embed = discord.Embed(
- title="Identification error",
- description=f"User {member.mention} ({member.id}) was platform banned HTB and thus also here.",
- color=0xFF2429, )
-
- await guild.get_channel(settings.channels.VERIFY_LOGS).send(embed=embed)
- return None
to_remove = []
-
for role in member.roles:
- if role.id in settings.role_groups.get("ALL_RANKS") + settings.role_groups.get("ALL_POSITIONS"):
+ if role.id in settings.role_groups.get("ALL_RANKS") + settings.role_groups.get(
+ "ALL_POSITIONS"
+ ):
to_remove.append(guild.get_role(role.id))
to_assign = []
- logger.debug(
- "Getting role 'rank':", extra={
- "role_id": settings.get_post_or_rank(htb_user_details["rank"]),
- "role_obj": guild.get_role(settings.get_post_or_rank(htb_user_details["rank"])),
- "htb_rank": htb_user_details["rank"],
- }, )
- if htb_user_details["rank"] not in ["Deleted", "Moderator", "Ambassador", "Admin", "Staff"]:
- to_assign.append(guild.get_role(settings.get_post_or_rank(htb_user_details["rank"])))
+ if htb_user_details["rank"] not in [
+ "Deleted",
+ "Moderator",
+ "Ambassador",
+ "Admin",
+ "Staff",
+ ]:
+ to_assign.append(
+ guild.get_role(settings.get_post_or_rank(htb_user_details["rank"]))
+ )
+
+ season_rank = await get_season_rank(htb_uid)
if season_rank:
to_assign.append(guild.get_role(settings.get_season(season_rank)))
+
if htb_user_details["vip"]:
- logger.debug(
- 'Getting role "VIP":', extra={"role_id": settings.roles.VIP, "role_obj": guild.get_role(settings.roles.VIP)}
- )
to_assign.append(guild.get_role(settings.roles.VIP))
if htb_user_details["dedivip"]:
- logger.debug(
- 'Getting role "VIP+":',
- extra={"role_id": settings.roles.VIP_PLUS, "role_obj": guild.get_role(settings.roles.VIP_PLUS)}
- )
to_assign.append(guild.get_role(settings.roles.VIP_PLUS))
if htb_user_details["hof_position"] != "unranked":
position = int(htb_user_details["hof_position"])
@@ -180,45 +310,17 @@ async def process_identification(
elif position <= 10:
pos_top = "10"
if pos_top:
- logger.debug(f"User is Hall of Fame rank {position}. Assigning role Top-{pos_top}...")
- logger.debug(
- 'Getting role "HoF role":', extra={
- "role_id": settings.get_post_or_rank(pos_top),
- "role_obj": guild.get_role(settings.get_post_or_rank(pos_top)), "hof_val": pos_top,
- }, )
to_assign.append(guild.get_role(settings.get_post_or_rank(pos_top)))
- else:
- logger.debug(f"User is position {position}. No Hall of Fame roles for them.")
if htb_user_details["machines"]:
- logger.debug(
- 'Getting role "BOX_CREATOR":',
- extra={"role_id": settings.roles.BOX_CREATOR, "role_obj": guild.get_role(settings.roles.BOX_CREATOR)}, )
to_assign.append(guild.get_role(settings.roles.BOX_CREATOR))
if htb_user_details["challenges"]:
- logger.debug(
- 'Getting role "CHALLENGE_CREATOR":', extra={
- "role_id": settings.roles.CHALLENGE_CREATOR,
- "role_obj": guild.get_role(settings.roles.CHALLENGE_CREATOR),
- }, )
to_assign.append(guild.get_role(settings.roles.CHALLENGE_CREATOR))
- if member.nick != htb_user_details["user_name"]:
- try:
- await member.edit(nick=htb_user_details["user_name"])
- except Forbidden as e:
- logger.error(f"Exception whe trying to edit the nick-name of the user: {e}")
-
- logger.debug("All roles to_assign:", extra={"to_assign": to_assign})
# We don't need to remove any roles that are going to be assigned again
to_remove = list(set(to_remove) - set(to_assign))
- logger.debug("All roles to_remove:", extra={"to_remove": to_remove})
if to_remove:
await member.remove_roles(*to_remove, atomic=True)
- else:
- logger.debug("No roles need to be removed")
if to_assign:
await member.add_roles(*to_assign, atomic=True)
- else:
- logger.debug("No roles need to be assigned")
return to_assign
diff --git a/src/webhooks/handlers/__init__.py b/src/webhooks/handlers/__init__.py
index 02e0edf..601943c 100644
--- a/src/webhooks/handlers/__init__.py
+++ b/src/webhooks/handlers/__init__.py
@@ -1,9 +1,9 @@
from discord import Bot
-from src.webhooks.handlers.academy import handler as academy_handler
+from src.webhooks.handlers.account import AccountHandler
from src.webhooks.types import Platform, WebhookBody
-handlers = {Platform.ACADEMY: academy_handler}
+handlers = {Platform.ACCOUNT: AccountHandler().handle}
def can_handle(platform: Platform) -> bool:
diff --git a/src/webhooks/handlers/academy.py b/src/webhooks/handlers/academy.py
deleted file mode 100644
index 5aa3568..0000000
--- a/src/webhooks/handlers/academy.py
+++ /dev/null
@@ -1,76 +0,0 @@
-import logging
-
-from discord import Bot
-from discord.errors import NotFound
-from fastapi import HTTPException
-
-from src.core import settings
-from src.webhooks.types import WebhookBody, WebhookEvent
-
-logger = logging.getLogger(__name__)
-
-
-async def handler(body: WebhookBody, bot: Bot) -> dict:
- """
- Handles incoming webhook events and performs actions accordingly.
-
- This function processes different webhook events related to account linking,
- certificate awarding, and account unlinking. It updates the member's roles
- based on the received event.
-
- Args:
- body (WebhookBody): The data received from the webhook.
- bot (Bot): The instance of the Discord bot.
-
- Returns:
- dict: A dictionary with a "success" key indicating whether the operation was successful.
-
- Raises:
- HTTPException: If an error occurs while processing the webhook event.
- """
- # TODO: Change it here so we pass the guild instead of the bot # noqa: T000
- guild = await bot.fetch_guild(settings.guild_ids[0])
-
- try:
- discord_id = int(body.data["discord_id"])
- member = await guild.fetch_member(discord_id)
- except ValueError as exc:
- logger.debug("Invalid Discord ID", exc_info=exc)
- raise HTTPException(status_code=400, detail="Invalid Discord ID") from exc
- except NotFound as exc:
- logger.debug("User is not in the Discord server", exc_info=exc)
- raise HTTPException(status_code=400, detail="User is not in the Discord server") from exc
-
- if body.event == WebhookEvent.ACCOUNT_LINKED:
- roles_to_add = {settings.roles.ACADEMY_USER}
- roles_to_add.update(settings.get_academy_cert_role(cert["id"]) for cert in body.data["certifications"])
-
- # Filter out invalid role IDs
- role_ids_to_add = {role_id for role_id in roles_to_add if role_id is not None}
- roles_to_add = {guild.get_role(role_id) for role_id in role_ids_to_add}
-
- await member.add_roles(*roles_to_add, atomic=True)
- elif body.event == WebhookEvent.CERTIFICATE_AWARDED:
- cert_id = body.data["certification"]["id"]
-
- role = settings.get_academy_cert_role(cert_id)
- if not role:
- logger.debug(f"Role for certification: {cert_id} does not exist")
- raise HTTPException(status_code=400, detail=f"Role for certification: {cert_id} does not exist")
-
- await member.add_roles(role, atomic=True)
- elif body.event == WebhookEvent.ACCOUNT_UNLINKED:
- current_role_ids = {role.id for role in member.roles}
- cert_role_ids = {settings.get_academy_cert_role(cert_id) for _, cert_id in settings.academy_certificates}
-
- common_role_ids = current_role_ids.intersection(cert_role_ids)
-
- role_ids_to_remove = {settings.roles.ACADEMY_USER}.union(common_role_ids)
- roles_to_remove = {guild.get_role(role_id) for role_id in role_ids_to_remove}
-
- await member.remove_roles(*roles_to_remove, atomic=True)
- else:
- logger.debug(f"Event {body.event} not implemented")
- raise HTTPException(status_code=501, detail=f"Event {body.event} not implemented")
-
- return {"success": True}
diff --git a/src/webhooks/handlers/account.py b/src/webhooks/handlers/account.py
new file mode 100644
index 0000000..895cc9f
--- /dev/null
+++ b/src/webhooks/handlers/account.py
@@ -0,0 +1,50 @@
+import logging
+
+from discord import Bot
+
+from src.core import settings
+from src.helpers.verification import process_account_identification
+from src.webhooks.types import WebhookBody, WebhookEvent
+
+from src.webhooks.handlers.base import BaseHandler
+
+
+
+class AccountHandler(BaseHandler):
+ async def handle(self, body: WebhookBody, bot: Bot) -> dict:
+ """
+ Handles incoming webhook events and performs actions accordingly.
+
+ This function processes different webhook events originating from the
+ HTB Account.
+ """
+ if body.event == WebhookEvent.ACCOUNT_LINKED:
+ await self.handle_account_linked(body, bot)
+ elif body.event == WebhookEvent.ACCOUNT_UNLINKED:
+ await self.handle_account_unlinked(body, bot)
+ elif body.event == WebhookEvent.ACCOUNT_DELETED:
+ await self.handle_account_deleted(body, bot)
+
+ async def handle_account_linked(self, body: WebhookBody, bot: Bot) -> dict:
+ """
+ Handles the account linked event.
+ """
+ discord_id = self.validate_discord_id(body.properties.get("discord_id"))
+ account_id = self.validate_account_id(body.properties.get("account_id"))
+
+ member = await self.get_guild_member(discord_id, bot)
+ await process_account_identification(member, bot, traits=self.merge_properties_and_traits(body.properties, body.traits))
+ await bot.send_message(settings.channels.VERIFY_LOGS, f"Account linked: {account_id} -> ({member.mention} ({member.id})")
+
+ self.logger.info(f"Account {account_id} linked to {member.id}", extra={"account_id": account_id, "discord_id": discord_id})
+
+ async def handle_account_unlinked(self, body: WebhookBody, bot: Bot) -> dict:
+ """
+ Handles the account unlinked event.
+ """
+ discord_id = self.validate_discord_id(body.properties.get("discord_id"))
+ account_id = self.validate_account_id(body.properties.get("account_id"))
+
+ member = await self.get_guild_member(discord_id, bot)
+
+ await member.remove_roles(settings.roles.VERIFIED, atomic=True)
diff --git a/src/webhooks/handlers/base.py b/src/webhooks/handlers/base.py
new file mode 100644
index 0000000..cfb63d9
--- /dev/null
+++ b/src/webhooks/handlers/base.py
@@ -0,0 +1,114 @@
+import logging
+
+from abc import ABC, abstractmethod
+from typing import TypeVar
+
+from discord import Bot, Member
+from discord.errors import NotFound
+from fastapi import HTTPException
+
+from src.core import settings
+from src.webhooks.types import WebhookBody
+
+T = TypeVar("T")
+
+
+class BaseHandler(ABC):
+ ACADEMY_USER_ID = "academy_user_id"
+ MP_USER_ID = "mp_user_id"
+ EP_USER_ID = "ep_user_id"
+ CTF_USER_ID = "ctf_user_id"
+ ACCOUNT_ID = "account_id"
+ DISCORD_ID = "discord_id"
+
+
+ def __init__(self):
+ self.logger = logging.getLogger(self.__class__.__name__)
+
+ @abstractmethod
+ async def handle(self, body: WebhookBody, bot: Bot) -> dict:
+ pass
+
+ async def get_guild_member(self, discord_id: int, bot: Bot) -> Member:
+ """
+ Fetches a guild member from the Discord server.
+
+ Args:
+ discord_id (int): The Discord ID of the user.
+ bot (Bot): The Discord bot instance.
+
+ Returns:
+ Member: The guild member.
+
+ Raises:
+ HTTPException: If the user is not in the Discord server (400)
+ """
+ try:
+ guild = await bot.fetch_guild(settings.guild_ids[0])
+ member = await guild.fetch_member(discord_id)
+ return member
+
+ except NotFound as exc:
+ self.logger.debug("User is not in the Discord server", exc_info=exc)
+ raise HTTPException(
+ status_code=400, detail="User is not in the Discord server"
+ ) from exc
+
+ def validate_property(self, property: T | None, name: str) -> T:
+ """
+ Validates a property is not None.
+
+ Args:
+ property (T | None): The property to validate.
+ name (str): The name of the property.
+
+ Returns:
+ T: The validated property.
+
+ Raises:
+ HTTPException: If the property is None (400)
+ """
+ if property is None:
+ msg = f"Invalid {name}"
+ self.logger.debug(msg)
+ raise HTTPException(status_code=400, detail=msg)
+
+ return property
+
+ def validate_discord_id(self, discord_id: str | int) -> int:
+ """
+ Validates the Discord ID. See validate_property function.
+ """
+ return self.validate_property(discord_id, "Discord ID")
+
+ def validate_account_id(self, account_id: str | int) -> int:
+ """
+ Validates the Account ID. See validate_property function.
+ """
+ return self.validate_property(account_id, "Account ID")
+
+ def get_property_or_trait(self, body: WebhookBody, name: str) -> int | None:
+ """
+ Gets a trait or property from the webhook body.
+ """
+ return body.properties.get(name) or body.traits.get(name)
+
+ def merge_properties_and_traits(self, properties: dict[str, int | None], traits: dict[str, int | None]) -> dict[str, int | None]:
+ """
+ Merges the properties and traits from the webhook body without duplicates.
+ If a property and trait have the same name but different values, the property value will be used.
+ """
+ return {**properties, **{k: v for k, v in traits.items() if k not in properties}}
+
+ def get_platform_properties(self, body: WebhookBody) -> dict[str, int | None]:
+ """
+ Gets the platform properties from the webhook body.
+ """
+ properties = {
+ self.ACCOUNT_ID: self.get_property_or_trait(body, self.ACCOUNT_ID),
+ self.MP_USER_ID: self.get_property_or_trait(body, self.MP_USER_ID),
+ self.EP_USER_ID: self.get_property_or_trait(body, self.EP_USER_ID),
+ self.CTF_USER_ID: self.get_property_or_trait(body, self.CTF_USER_ID),
+ self.ACADEMY_USER_ID: self.get_property_or_trait(body, self.ACADEMY_USER_ID),
+ }
+ return properties
\ No newline at end of file
diff --git a/src/webhooks/server.py b/src/webhooks/server.py
index 92222a7..06d02af 100644
--- a/src/webhooks/server.py
+++ b/src/webhooks/server.py
@@ -1,10 +1,13 @@
+import hashlib
import hmac
import logging
-from typing import Any, Dict, Union
+import json
+from typing import Any, Dict
-from fastapi import FastAPI, Header, HTTPException
+from fastapi import FastAPI, HTTPException, Request
from hypercorn.asyncio import serve as hypercorn_serve
from hypercorn.config import Config as HypercornConfig
+from pydantic import ValidationError
from src.bot import bot
from src.core import settings
@@ -17,20 +20,36 @@
app = FastAPI()
+def verify_signature(body: dict, signature: str, secret: str) -> bool:
+ """
+ HMAC SHA1 signature verification.
+
+ Args:
+ body (dict): The raw body of the webhook request.
+ signature (str): The X-Signature header of the webhook request.
+ secret (str): The webhook secret.
+
+ Returns:
+ bool: True if the signature is valid, False otherwise.
+ """
+ if not signature:
+ return False
+
+ digest = hmac.new(secret.encode(), body, hashlib.sha1).hexdigest()
+ return hmac.compare_digest(signature, digest)
+
+
@app.post("/webhook")
-async def webhook_handler(
- body: WebhookBody, authorization: Union[str, None] = Header(default=None)
-) -> Dict[str, Any]:
+async def webhook_handler(request: Request) -> Dict[str, Any]:
"""
Handles incoming webhook requests and forwards them to the appropriate handler.
- This function first checks the provided authorization token in the request header.
- If the token is valid, it checks if the platform can be handled and then forwards
+ This function first verifies the provided HMAC signature in the request header.
+ If the signature is valid, it checks if the platform can be handled and then forwards
the request to the corresponding handler.
Args:
- body (WebhookBody): The data received from the webhook.
- authorization (Union[str, None]): The authorization header containing the Bearer token.
+ request (Request): The incoming webhook request.
Returns:
Dict[str, Any]: The response from the corresponding handler. The dictionary contains
@@ -39,17 +58,21 @@ async def webhook_handler(
Raises:
HTTPException: If an error occurs while processing the webhook event or if unauthorized.
"""
- if authorization is None or not authorization.strip().startswith("Bearer"):
- logger.warning("Unauthorized webhook request")
- raise HTTPException(status_code=401, detail="Unauthorized")
+ body = await request.body()
+ signature = request.headers.get("X-Signature")
- token = authorization[6:].strip()
- if hmac.compare_digest(token, settings.WEBHOOK_TOKEN):
+ if not verify_signature(body, signature, settings.WEBHOOK_TOKEN):
logger.warning("Unauthorized webhook request")
raise HTTPException(status_code=401, detail="Unauthorized")
+ try:
+ body = WebhookBody.model_validate(json.loads(body))
+ except ValidationError as e:
+ logger.warning("Invalid webhook request: %s", e.errors())
+ raise HTTPException(status_code=400, detail="Invalid webhook request body")
+
if not handlers.can_handle(body.platform):
- logger.warning("Webhook request not handled by platform")
+ logger.warning("Webhook request not handled by platform: %s", body.platform)
raise HTTPException(status_code=501, detail="Platform not implemented")
return await handlers.handle(body, bot)
diff --git a/src/webhooks/types.py b/src/webhooks/types.py
index 3885dda..9812ccb 100644
--- a/src/webhooks/types.py
+++ b/src/webhooks/types.py
@@ -1,11 +1,12 @@
from enum import Enum
-from pydantic import BaseModel
+from pydantic import BaseModel, ConfigDict
class WebhookEvent(Enum):
- ACCOUNT_LINKED = "AccountLinked"
- ACCOUNT_UNLINKED = "AccountUnlinked"
+ ACCOUNT_LINKED = "DiscordAccountLinked"
+ ACCOUNT_UNLINKED = "DiscordAccountUnlinked"
+ ACCOUNT_DELETED = "UserAccountDeleted"
CERTIFICATE_AWARDED = "CertificateAwarded"
RANK_UP = "RankUp"
HOF_CHANGE = "HofChange"
@@ -19,9 +20,13 @@ class Platform(Enum):
ACADEMY = "academy"
CTF = "ctf"
ENTERPRISE = "enterprise"
+ ACCOUNT = "account"
class WebhookBody(BaseModel):
+ model_config = ConfigDict(extra="allow")
+
platform: Platform
event: WebhookEvent
- data: dict
+ properties: dict | None
+ traits: dict | None
diff --git a/tests/src/webhooks/handlers/test_account.py b/tests/src/webhooks/handlers/test_account.py
new file mode 100644
index 0000000..91cf195
--- /dev/null
+++ b/tests/src/webhooks/handlers/test_account.py
@@ -0,0 +1,364 @@
+import logging
+from unittest.mock import AsyncMock, MagicMock, patch
+
+import pytest
+from discord import Bot, Member
+from discord.errors import NotFound
+from fastapi import HTTPException
+
+from src.webhooks.handlers.account import AccountHandler
+from src.webhooks.types import WebhookBody, Platform, WebhookEvent
+from tests import helpers
+
+
+class TestAccountHandler:
+ """Test the `AccountHandler` class."""
+
+ def test_initialization(self):
+ """Test that AccountHandler initializes correctly."""
+ handler = AccountHandler()
+
+ assert isinstance(handler.logger, logging.Logger)
+ assert handler.logger.name == "AccountHandler"
+
+ @pytest.mark.asyncio
+ async def test_handle_account_linked_event(self, bot):
+ """Test handle method routes ACCOUNT_LINKED event correctly."""
+ handler = AccountHandler()
+ body = WebhookBody(
+ platform=Platform.ACCOUNT,
+ event=WebhookEvent.ACCOUNT_LINKED,
+ properties={"discord_id": 123456789, "account_id": 987654321},
+ traits={},
+ )
+
+ with patch.object(
+ handler, "handle_account_linked", new_callable=AsyncMock
+ ) as mock_handle:
+ await handler.handle(body, bot)
+ mock_handle.assert_called_once_with(body, bot)
+
+ @pytest.mark.asyncio
+ async def test_handle_account_unlinked_event(self, bot):
+ """Test handle method routes ACCOUNT_UNLINKED event correctly."""
+ handler = AccountHandler()
+ body = WebhookBody(
+ platform=Platform.ACCOUNT,
+ event=WebhookEvent.ACCOUNT_UNLINKED,
+ properties={"discord_id": 123456789, "account_id": 987654321},
+ traits={},
+ )
+
+ with patch.object(
+ handler, "handle_account_unlinked", new_callable=AsyncMock
+ ) as mock_handle:
+ await handler.handle(body, bot)
+ mock_handle.assert_called_once_with(body, bot)
+
+ @pytest.mark.asyncio
+ async def test_handle_account_deleted_event(self, bot):
+ """Test handle method with ACCOUNT_DELETED event (method not implemented)."""
+ handler = AccountHandler()
+ body = WebhookBody(
+ platform=Platform.ACCOUNT,
+ event=WebhookEvent.ACCOUNT_DELETED,
+ properties={"discord_id": 123456789, "account_id": 987654321},
+ traits={},
+ )
+
+ # The handle_account_deleted method is not implemented, so this should raise AttributeError
+ with pytest.raises(AttributeError):
+ await handler.handle(body, bot)
+
+ @pytest.mark.asyncio
+ async def test_handle_unknown_event(self, bot):
+ """Test handle method with unknown event does nothing."""
+ handler = AccountHandler()
+ body = WebhookBody(
+ platform=Platform.ACCOUNT,
+ event=WebhookEvent.CERTIFICATE_AWARDED, # Not handled by AccountHandler
+ properties={"discord_id": 123456789, "account_id": 987654321},
+ traits={},
+ )
+
+ # Should not raise any exceptions, just do nothing
+ await handler.handle(body, bot)
+
+ @pytest.mark.asyncio
+ async def test_handle_account_linked_success(self, bot):
+ """Test successful account linking."""
+ handler = AccountHandler()
+ discord_id = 123456789
+ account_id = 987654321
+ mock_member = helpers.MockMember(id=discord_id, mention="@testuser")
+
+ body = WebhookBody(
+ platform=Platform.ACCOUNT,
+ event=WebhookEvent.ACCOUNT_LINKED,
+ properties={"discord_id": discord_id, "account_id": account_id},
+ traits={"htb_user_id": 555},
+ )
+
+ # Create a custom bot mock without spec_set restrictions for this test
+ custom_bot = MagicMock()
+ custom_bot.send_message = AsyncMock()
+
+ with (
+ patch.object(
+ handler, "validate_discord_id", return_value=discord_id
+ ) as mock_validate_discord,
+ patch.object(
+ handler, "validate_account_id", return_value=account_id
+ ) as mock_validate_account,
+ patch.object(
+ handler,
+ "get_guild_member",
+ new_callable=AsyncMock,
+ return_value=mock_member,
+ ) as mock_get_member,
+ patch.object(
+ handler,
+ "merge_properties_and_traits",
+ return_value={
+ "discord_id": discord_id,
+ "account_id": account_id,
+ "htb_user_id": 555,
+ },
+ ) as mock_merge,
+ patch(
+ "src.webhooks.handlers.account.process_account_identification",
+ new_callable=AsyncMock,
+ ) as mock_process,
+ patch("src.webhooks.handlers.account.settings") as mock_settings,
+ patch.object(handler.logger, "info") as mock_log,
+ ):
+ mock_settings.channels.VERIFY_LOGS = 12345
+
+ await handler.handle_account_linked(body, custom_bot)
+
+ # Verify all method calls
+ mock_validate_discord.assert_called_once_with(discord_id)
+ mock_validate_account.assert_called_once_with(account_id)
+ mock_get_member.assert_called_once_with(discord_id, custom_bot)
+ mock_merge.assert_called_once_with(body.properties, body.traits)
+ mock_process.assert_called_once_with(
+ mock_member,
+ custom_bot,
+ traits={
+ "discord_id": discord_id,
+ "account_id": account_id,
+ "htb_user_id": 555,
+ },
+ )
+ custom_bot.send_message.assert_called_once_with(
+ 12345, f"Account linked: {account_id} -> (@testuser ({discord_id})"
+ )
+ mock_log.assert_called_once_with(
+ f"Account {account_id} linked to {discord_id}",
+ extra={"account_id": account_id, "discord_id": discord_id},
+ )
+
+ @pytest.mark.asyncio
+ async def test_handle_account_linked_invalid_discord_id(self, bot):
+ """Test account linking with invalid Discord ID."""
+ handler = AccountHandler()
+
+ body = WebhookBody(
+ platform=Platform.ACCOUNT,
+ event=WebhookEvent.ACCOUNT_LINKED,
+ properties={"discord_id": None, "account_id": 987654321},
+ traits={},
+ )
+
+ with patch.object(
+ handler,
+ "validate_discord_id",
+ side_effect=HTTPException(status_code=400, detail="Invalid Discord ID"),
+ ):
+ with pytest.raises(HTTPException) as exc_info:
+ await handler.handle_account_linked(body, bot)
+
+ assert exc_info.value.status_code == 400
+ assert exc_info.value.detail == "Invalid Discord ID"
+
+ @pytest.mark.asyncio
+ async def test_handle_account_linked_invalid_account_id(self, bot):
+ """Test account linking with invalid Account ID."""
+ handler = AccountHandler()
+
+ body = WebhookBody(
+ platform=Platform.ACCOUNT,
+ event=WebhookEvent.ACCOUNT_LINKED,
+ properties={"discord_id": 123456789, "account_id": None},
+ traits={},
+ )
+
+ with (
+ patch.object(handler, "validate_discord_id", return_value=123456789),
+ patch.object(
+ handler,
+ "validate_account_id",
+ side_effect=HTTPException(status_code=400, detail="Invalid Account ID"),
+ ),
+ ):
+ with pytest.raises(HTTPException) as exc_info:
+ await handler.handle_account_linked(body, bot)
+
+ assert exc_info.value.status_code == 400
+ assert exc_info.value.detail == "Invalid Account ID"
+
+ @pytest.mark.asyncio
+ async def test_handle_account_linked_user_not_in_guild(self, bot):
+ """Test account linking when user is not in the Discord guild."""
+ handler = AccountHandler()
+ discord_id = 123456789
+ account_id = 987654321
+
+ body = WebhookBody(
+ platform=Platform.ACCOUNT,
+ event=WebhookEvent.ACCOUNT_LINKED,
+ properties={"discord_id": discord_id, "account_id": account_id},
+ traits={},
+ )
+
+ with (
+ patch.object(handler, "validate_discord_id", return_value=discord_id),
+ patch.object(handler, "validate_account_id", return_value=account_id),
+ patch.object(
+ handler,
+ "get_guild_member",
+ new_callable=AsyncMock,
+ side_effect=HTTPException(
+ status_code=400, detail="User is not in the Discord server"
+ ),
+ ),
+ ):
+ with pytest.raises(HTTPException) as exc_info:
+ await handler.handle_account_linked(body, bot)
+
+ assert exc_info.value.status_code == 400
+ assert exc_info.value.detail == "User is not in the Discord server"
+
+ @pytest.mark.asyncio
+ async def test_handle_account_unlinked_success(self, bot):
+ """Test successful account unlinking."""
+ handler = AccountHandler()
+ discord_id = 123456789
+ account_id = 987654321
+ mock_member = helpers.MockMember(id=discord_id)
+
+ body = WebhookBody(
+ platform=Platform.ACCOUNT,
+ event=WebhookEvent.ACCOUNT_UNLINKED,
+ properties={"discord_id": discord_id, "account_id": account_id},
+ traits={},
+ )
+
+ with (
+ patch.object(
+ handler, "validate_discord_id", return_value=discord_id
+ ) as mock_validate_discord,
+ patch.object(
+ handler, "validate_account_id", return_value=account_id
+ ) as mock_validate_account,
+ patch.object(
+ handler,
+ "get_guild_member",
+ new_callable=AsyncMock,
+ return_value=mock_member,
+ ) as mock_get_member,
+ patch("src.webhooks.handlers.account.settings") as mock_settings,
+ ):
+ mock_settings.roles.VERIFIED = helpers.MockRole(id=99999, name="Verified")
+ mock_member.remove_roles = AsyncMock()
+
+ await handler.handle_account_unlinked(body, bot)
+
+ # Verify all method calls
+ mock_validate_discord.assert_called_once_with(discord_id)
+ mock_validate_account.assert_called_once_with(account_id)
+ mock_get_member.assert_called_once_with(discord_id, bot)
+ mock_member.remove_roles.assert_called_once_with(
+ mock_settings.roles.VERIFIED, atomic=True
+ )
+
+ @pytest.mark.asyncio
+ async def test_handle_account_unlinked_invalid_discord_id(self, bot):
+ """Test account unlinking with invalid Discord ID."""
+ handler = AccountHandler()
+
+ body = WebhookBody(
+ platform=Platform.ACCOUNT,
+ event=WebhookEvent.ACCOUNT_UNLINKED,
+ properties={"discord_id": None, "account_id": 987654321},
+ traits={},
+ )
+
+ with patch.object(
+ handler,
+ "validate_discord_id",
+ side_effect=HTTPException(status_code=400, detail="Invalid Discord ID"),
+ ):
+ with pytest.raises(HTTPException) as exc_info:
+ await handler.handle_account_unlinked(body, bot)
+
+ assert exc_info.value.status_code == 400
+ assert exc_info.value.detail == "Invalid Discord ID"
+
+ @pytest.mark.asyncio
+ async def test_handle_account_unlinked_invalid_account_id(self, bot):
+ """Test account unlinking with invalid Account ID."""
+ handler = AccountHandler()
+
+ body = WebhookBody(
+ platform=Platform.ACCOUNT,
+ event=WebhookEvent.ACCOUNT_UNLINKED,
+ properties={"discord_id": 123456789, "account_id": None},
+ traits={},
+ )
+
+ with (
+ patch.object(handler, "validate_discord_id", return_value=123456789),
+ patch.object(
+ handler,
+ "validate_account_id",
+ side_effect=HTTPException(status_code=400, detail="Invalid Account ID"),
+ ),
+ ):
+ with pytest.raises(HTTPException) as exc_info:
+ await handler.handle_account_unlinked(body, bot)
+
+ assert exc_info.value.status_code == 400
+ assert exc_info.value.detail == "Invalid Account ID"
+
+ @pytest.mark.asyncio
+ async def test_handle_account_unlinked_user_not_in_guild(self, bot):
+ """Test account unlinking when user is not in the Discord guild."""
+ handler = AccountHandler()
+ discord_id = 123456789
+ account_id = 987654321
+
+ body = WebhookBody(
+ platform=Platform.ACCOUNT,
+ event=WebhookEvent.ACCOUNT_UNLINKED,
+ properties={"discord_id": discord_id, "account_id": account_id},
+ traits={},
+ )
+
+ with (
+ patch.object(handler, "validate_discord_id", return_value=discord_id),
+ patch.object(handler, "validate_account_id", return_value=account_id),
+ patch.object(
+ handler,
+ "get_guild_member",
+ new_callable=AsyncMock,
+ side_effect=HTTPException(
+ status_code=400, detail="User is not in the Discord server"
+ ),
+ ),
+ ):
+ with pytest.raises(HTTPException) as exc_info:
+ await handler.handle_account_unlinked(body, bot)
+
+ assert exc_info.value.status_code == 400
+ assert exc_info.value.detail == "User is not in the Discord server"
diff --git a/tests/src/webhooks/handlers/test_base.py b/tests/src/webhooks/handlers/test_base.py
new file mode 100644
index 0000000..ba774cf
--- /dev/null
+++ b/tests/src/webhooks/handlers/test_base.py
@@ -0,0 +1,315 @@
+import logging
+from unittest.mock import AsyncMock, MagicMock, patch
+
+import pytest
+from discord import Bot
+from discord.errors import NotFound
+from fastapi import HTTPException
+
+from src.webhooks.handlers.base import BaseHandler
+from src.webhooks.types import WebhookBody, Platform, WebhookEvent
+from tests import helpers
+
+
+class ConcreteHandler(BaseHandler):
+ """Concrete implementation of BaseHandler for testing purposes."""
+
+ async def handle(self, body: WebhookBody, bot: Bot) -> dict:
+ return {"status": "handled"}
+
+
+class TestBaseHandler:
+ """Test the `BaseHandler` class."""
+
+ def test_initialization(self):
+ """Test that BaseHandler initializes correctly."""
+ handler = ConcreteHandler()
+
+ assert isinstance(handler.logger, logging.Logger)
+ assert handler.logger.name == "ConcreteHandler"
+
+ def test_constants(self):
+ """Test that all required constants are defined."""
+ handler = ConcreteHandler()
+
+ assert handler.ACADEMY_USER_ID == "academy_user_id"
+ assert handler.MP_USER_ID == "mp_user_id"
+ assert handler.EP_USER_ID == "ep_user_id"
+ assert handler.CTF_USER_ID == "ctf_user_id"
+ assert handler.ACCOUNT_ID == "account_id"
+ assert handler.DISCORD_ID == "discord_id"
+
+ @pytest.mark.asyncio
+ async def test_get_guild_member_success(self, bot):
+ """Test successful guild member retrieval."""
+ handler = ConcreteHandler()
+ discord_id = 123456789
+ mock_guild = helpers.MockGuild(id=12345)
+ mock_member = helpers.MockMember(id=discord_id)
+
+ bot.fetch_guild = AsyncMock(return_value=mock_guild)
+ mock_guild.fetch_member = AsyncMock(return_value=mock_member)
+
+ with patch("src.webhooks.handlers.base.settings") as mock_settings:
+ mock_settings.guild_ids = [12345]
+
+ result = await handler.get_guild_member(discord_id, bot)
+
+ assert result == mock_member
+ bot.fetch_guild.assert_called_once_with(12345)
+ mock_guild.fetch_member.assert_called_once_with(discord_id)
+
+ @pytest.mark.asyncio
+ async def test_get_guild_member_not_found(self, bot):
+ """Test guild member retrieval when user is not in server."""
+ handler = ConcreteHandler()
+ discord_id = 123456789
+ mock_guild = helpers.MockGuild(id=12345)
+
+ bot.fetch_guild = AsyncMock(return_value=mock_guild)
+ mock_guild.fetch_member = AsyncMock(
+ side_effect=NotFound(MagicMock(), "User not found")
+ )
+
+ with patch("src.webhooks.handlers.base.settings") as mock_settings:
+ mock_settings.guild_ids = [12345]
+
+ with pytest.raises(HTTPException) as exc_info:
+ await handler.get_guild_member(discord_id, bot)
+
+ assert exc_info.value.status_code == 400
+ assert exc_info.value.detail == "User is not in the Discord server"
+
+ def test_validate_property_success(self):
+ """Test successful property validation."""
+ handler = ConcreteHandler()
+
+ result = handler.validate_property("valid_value", "test_property")
+ assert result == "valid_value"
+
+ result = handler.validate_property(123, "test_number")
+ assert result == 123
+
+ def test_validate_property_none(self):
+ """Test property validation with None value."""
+ handler = ConcreteHandler()
+
+ with pytest.raises(HTTPException) as exc_info:
+ handler.validate_property(None, "test_property")
+
+ assert exc_info.value.status_code == 400
+ assert exc_info.value.detail == "Invalid test_property"
+
+ def test_validate_discord_id_success(self):
+ """Test successful Discord ID validation."""
+ handler = ConcreteHandler()
+
+ result = handler.validate_discord_id(123456789)
+ assert result == 123456789
+
+ result = handler.validate_discord_id("987654321")
+ assert result == "987654321"
+
+ def test_validate_discord_id_none(self):
+ """Test Discord ID validation with None value."""
+ handler = ConcreteHandler()
+
+ with pytest.raises(HTTPException) as exc_info:
+ handler.validate_discord_id(None)
+
+ assert exc_info.value.status_code == 400
+ assert exc_info.value.detail == "Invalid Discord ID"
+
+ def test_validate_account_id_success(self):
+ """Test successful Account ID validation."""
+ handler = ConcreteHandler()
+
+ result = handler.validate_account_id(123456789)
+ assert result == 123456789
+
+ result = handler.validate_account_id("987654321")
+ assert result == "987654321"
+
+ def test_validate_account_id_none(self):
+ """Test Account ID validation with None value."""
+ handler = ConcreteHandler()
+
+ with pytest.raises(HTTPException) as exc_info:
+ handler.validate_account_id(None)
+
+ assert exc_info.value.status_code == 400
+ assert exc_info.value.detail == "Invalid Account ID"
+
+ def test_get_property_or_trait_from_properties(self):
+ """Test getting value from properties."""
+ handler = ConcreteHandler()
+ body = WebhookBody(
+ platform=Platform.MAIN,
+ event=WebhookEvent.ACCOUNT_LINKED,
+ properties={"test_key": 123},
+ traits={"test_key": 456, "other_key": 789},
+ )
+
+ result = handler.get_property_or_trait(body, "test_key")
+ assert result == 123 # Should prioritize properties over traits
+
+ def test_get_property_or_trait_from_traits(self):
+ """Test getting value from traits when not in properties."""
+ handler = ConcreteHandler()
+ body = WebhookBody(
+ platform=Platform.MAIN,
+ event=WebhookEvent.ACCOUNT_LINKED,
+ properties={},
+ traits={"test_key": 456},
+ )
+
+ result = handler.get_property_or_trait(body, "test_key")
+ assert result == 456
+
+ def test_get_property_or_trait_not_found(self):
+ """Test getting value when key is not found."""
+ handler = ConcreteHandler()
+ body = WebhookBody(
+ platform=Platform.MAIN,
+ event=WebhookEvent.ACCOUNT_LINKED,
+ properties={},
+ traits={},
+ )
+
+ result = handler.get_property_or_trait(body, "missing_key")
+ assert result is None
+
+ def test_merge_properties_and_traits_no_duplicates(self):
+ """Test merging properties and traits without duplicates."""
+ handler = ConcreteHandler()
+ properties = {"key1": 1, "key2": 2}
+ traits = {"key3": 3, "key4": 4}
+
+ result = handler.merge_properties_and_traits(properties, traits)
+
+ expected = {"key1": 1, "key2": 2, "key3": 3, "key4": 4}
+ assert result == expected
+
+ def test_merge_properties_and_traits_with_duplicates(self):
+ """Test merging properties and traits with duplicate keys."""
+ handler = ConcreteHandler()
+ properties = {"key1": 1, "key2": 2}
+ traits = {"key2": 99, "key3": 3} # key2 is duplicate
+
+ result = handler.merge_properties_and_traits(properties, traits)
+
+ expected = {"key1": 1, "key2": 2, "key3": 3} # Properties value should win
+ assert result == expected
+
+ def test_merge_properties_and_traits_empty_properties(self):
+ """Test merging when properties is empty."""
+ handler = ConcreteHandler()
+ properties = {}
+ traits = {"key1": 1, "key2": 2}
+
+ result = handler.merge_properties_and_traits(properties, traits)
+
+ assert result == traits
+
+ def test_merge_properties_and_traits_empty_traits(self):
+ """Test merging when traits is empty."""
+ handler = ConcreteHandler()
+ properties = {"key1": 1, "key2": 2}
+ traits = {}
+
+ result = handler.merge_properties_and_traits(properties, traits)
+
+ assert result == properties
+
+ def test_get_platform_properties_all_present(self):
+ """Test getting platform properties when all are present."""
+ handler = ConcreteHandler()
+ body = WebhookBody(
+ platform=Platform.MAIN,
+ event=WebhookEvent.ACCOUNT_LINKED,
+ properties={
+ "account_id": 1,
+ "mp_user_id": 2,
+ "ep_user_id": 3,
+ "ctf_user_id": 4,
+ "academy_user_id": 5,
+ },
+ traits={},
+ )
+
+ result = handler.get_platform_properties(body)
+
+ expected = {
+ "account_id": 1,
+ "mp_user_id": 2,
+ "ep_user_id": 3,
+ "ctf_user_id": 4,
+ "academy_user_id": 5,
+ }
+ assert result == expected
+
+ def test_get_platform_properties_mixed_sources(self):
+ """Test getting platform properties from both properties and traits."""
+ handler = ConcreteHandler()
+ body = WebhookBody(
+ platform=Platform.MAIN,
+ event=WebhookEvent.ACCOUNT_LINKED,
+ properties={"account_id": 1, "mp_user_id": 2},
+ traits={"ep_user_id": 3, "ctf_user_id": 4, "academy_user_id": 5},
+ )
+
+ result = handler.get_platform_properties(body)
+
+ expected = {
+ "account_id": 1,
+ "mp_user_id": 2,
+ "ep_user_id": 3,
+ "ctf_user_id": 4,
+ "academy_user_id": 5,
+ }
+ assert result == expected
+
+ def test_get_platform_properties_missing_values(self):
+ """Test getting platform properties when some are missing."""
+ handler = ConcreteHandler()
+ body = WebhookBody(
+ platform=Platform.MAIN,
+ event=WebhookEvent.ACCOUNT_LINKED,
+ properties={"account_id": 1},
+ traits={"mp_user_id": 2},
+ )
+
+ result = handler.get_platform_properties(body)
+
+ expected = {
+ "account_id": 1,
+ "mp_user_id": 2,
+ "ep_user_id": None,
+ "ctf_user_id": None,
+ "academy_user_id": None,
+ }
+ assert result == expected
+
+ def test_get_platform_properties_properties_override_traits(self):
+ """Test that properties override traits for the same key."""
+ handler = ConcreteHandler()
+ body = WebhookBody(
+ platform=Platform.MAIN,
+ event=WebhookEvent.ACCOUNT_LINKED,
+ properties={"account_id": 1, "mp_user_id": 2},
+ traits={
+ "mp_user_id": 999, # Should be overridden
+ "ep_user_id": 3,
+ },
+ )
+
+ result = handler.get_platform_properties(body)
+
+ expected = {
+ "account_id": 1,
+ "mp_user_id": 2, # Properties value should win
+ "ep_user_id": 3,
+ "ctf_user_id": None,
+ "academy_user_id": None,
+ }
+ assert result == expected
diff --git a/tests/src/webhooks/test_handlers_init.py b/tests/src/webhooks/test_handlers_init.py
new file mode 100644
index 0000000..ce0b436
--- /dev/null
+++ b/tests/src/webhooks/test_handlers_init.py
@@ -0,0 +1,35 @@
+from unittest import mock
+from typing import Callable
+
+import pytest
+
+from src.webhooks.handlers import handlers, can_handle, handle
+from src.webhooks.types import Platform, WebhookBody, WebhookEvent
+from tests.conftest import bot
+
+class TestHandlersInit:
+ def test_handler_init(self):
+ assert handlers is not None
+ assert isinstance(handlers, dict)
+ assert len(handlers) > 0
+ assert all(isinstance(handler, Callable) for handler in handlers.values())
+
+ def test_can_handle_unknown_platform(self):
+ assert not can_handle("UNKNOWN")
+
+ def test_can_handle_success(self):
+ with mock.patch("src.webhooks.handlers.handlers", {Platform.MAIN: lambda x, y: True}):
+ assert can_handle(Platform.MAIN)
+
+ def test_handle_success(self):
+ with mock.patch("src.webhooks.handlers.handlers", {Platform.MAIN: lambda x, y: 1337}):
+ assert handle(WebhookBody(platform=Platform.MAIN, event=WebhookEvent.ACCOUNT_LINKED, properties={}, traits={}), bot) == 1337
+
+ def test_handle_unknown_platform(self):
+ with pytest.raises(ValueError):
+ handle(WebhookBody(platform="UNKNOWN", event=WebhookEvent.ACCOUNT_LINKED, properties={}, traits={}), bot)
+
+ def test_handle_unknown_event(self):
+ with mock.patch("src.webhooks.handlers.handlers", {Platform.MAIN: lambda x, y: 1337}):
+ with pytest.raises(ValueError):
+ handle(WebhookBody(platform=Platform.MAIN, event="UNKNOWN", properties={}, traits={}), bot)
\ No newline at end of file