Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add asynchronous support to ip restriction middleware #3831

Open
wants to merge 3 commits into
base: main
Choose a base branch
from

Conversation

sentry-autofix[bot]
Copy link
Contributor

@sentry-autofix sentry-autofix bot commented Mar 6, 2025

User description

👋 Hi there! This PR was automatically generated by Autofix 🤖

This fix was triggered by DB

Fixes BLT-DJANGO-1EK

  • Implemented asynchronous versions of key functions using asgiref.sync.sync_to_async for compatibility with asynchronous views.
  • Added an asynchronous __acall__ method to handle requests asynchronously.
  • Created helper functions to record IP information and increment block counts.
  • Modified the middleware to support both synchronous and asynchronous request handling.

If you have any questions or feedback for the Sentry team about this fix, please email [email protected] with the Run ID: 8427.


PR Type

Enhancement


Description

  • Added asynchronous support to IP restriction middleware.

  • Implemented __acall__ method for async request handling.

  • Created async helper functions for IP recording and block count increment.

  • Refactored synchronous logic for better compatibility and maintainability.


Changes walkthrough 📝

Relevant files
Enhancement
ip_restrict.py
Added async support and refactored middleware logic           

blt/middleware/ip_restrict.py

  • Introduced asynchronous methods for IP recording and block count
    increment.
  • Added __acall__ method for handling async requests.
  • Refactored synchronous logic into helper methods for better reuse.
  • Integrated sync_to_async for compatibility with async views.
  • +99/-24 

    Need help?
  • Type /help how to ... in the comments thread for any questions about PR-Agent usage.
  • Check out the documentation for more information.
  • @DonnieBLT DonnieBLT marked this pull request as ready for review March 10, 2025 02:28
    Copy link
    Contributor

    PR Reviewer Guide 🔍

    Here are some key observations to aid the review process:

    ⏱️ Estimated effort to review: 4 🔵🔵🔵🔵⚪
    🧪 No relevant tests
    🔒 No security concerns identified
    ⚡ Recommended focus areas for review

    Potential Performance Concern

    The use of sync_to_async in multiple places, especially in critical middleware logic, could introduce performance bottlenecks. It is important to validate whether these calls are necessary or if native asynchronous implementations can be used instead.

    async def increment_block_count_async(self, ip=None, network=None, user_agent=None):
        """
        Asynchronous version of increment_block_count
        """
        await sync_to_async(self.increment_block_count)(ip, network, user_agent)
    
    def increment_block_count(self, ip=None, network=None, user_agent=None):
        """
        Increment the block count for a specific IP, network, or user agent in the Blocked model.
        """
        with transaction.atomic():
            if ip:
                blocked_entry = Blocked.objects.select_for_update().filter(address=ip).first()
            elif network:
                blocked_entry = Blocked.objects.select_for_update().filter(ip_network=network).first()
            elif user_agent:
                # Correct lookup: find if any user_agent_string is a substring of the user_agent
                blocked_entry = (
                    Blocked.objects.select_for_update()
                    .filter(
                        user_agent_string__in=[
                            agent
                            for agent in Blocked.objects.values_list("user_agent_string", flat=True)
                            if agent is not None and user_agent is not None and agent.lower() in user_agent.lower()
                        ]
                    )
                    .first()
                )
            else:
                return  # Nothing to increment
    
            if blocked_entry:
                blocked_entry.count = models.F("count") + 1
                blocked_entry.save(update_fields=["count"])
    
    async def record_ip_async(self, ip, agent, path):
        """
        Asynchronous version of IP record creation/update logic
        """
        if not ip:
            return
    
        await sync_to_async(self._record_ip)(ip, agent, path)
    
    def _record_ip(self, ip, agent, path):
        """
        Helper method to record IP information
        """
        with transaction.atomic():
            # create unique entry for every unique (ip,path) tuple
            # if this tuple already exists, we just increment the count.
            ip_records = IP.objects.select_for_update().filter(address=ip, path=path)
            if ip_records.exists():
                ip_record = ip_records.first()
    
                # Calculate the new count and ensure it doesn't exceed the MAX_COUNT
                new_count = ip_record.count + 1
                if new_count > MAX_COUNT:
                    new_count = MAX_COUNT
    
                ip_record.agent = agent
                ip_record.count = new_count
                if ip_record.pk:
                    ip_record.save(update_fields=["agent", "count"])
    
                # Check if a transaction is already active before starting a new one
                if not transaction.get_autocommit():
                    ip_records.exclude(pk=ip_record.pk).delete()
            else:
                # If no record exists, create a new one
                IP.objects.create(address=ip, agent=agent, count=1, path=path)
    
    def __call__(self, request):
        return self.process_request_sync(request)
    
    async def __acall__(self, request):
        """
        Asynchronous version of the middleware call method
        """
        # Get client information
        ip = request.META.get("HTTP_X_FORWARDED_FOR", "").split(",")[0].strip() or request.META.get("REMOTE_ADDR", "")
        agent = request.META.get("HTTP_USER_AGENT", "").strip()
    
        # Check cache for blocked items
        blocked_ips = await sync_to_async(self.blocked_ips)()
        blocked_ip_network = await sync_to_async(self.blocked_ip_network)()
        blocked_agents = await sync_to_async(self.blocked_agents)()
    
        # Check if IP is blocked directly
        if await sync_to_async(self.ip_in_ips)(ip, blocked_ips):
            await self.increment_block_count_async(ip=ip)
            return HttpResponseForbidden()
    
        # Check if IP is in a blocked network
        if await sync_to_async(self.ip_in_range)(ip, blocked_ip_network):
            # Find the specific network that caused the block and increment its count
            for network in blocked_ip_network:
                if ipaddress.ip_address(ip) in network:
                    await self.increment_block_count_async(network=str(network))
                    break
            return HttpResponseForbidden()
    
        # Check if user agent is blocked
        if await sync_to_async(self.is_user_agent_blocked)(agent, blocked_agents):
            await self.increment_block_count_async(user_agent=agent)
            return HttpResponseForbidden()
    
        # Record IP information
        await self.record_ip_async(ip, agent, request.path)
    
        # Continue with the request
        response = await self.get_response(request)
        return response
    Transaction Handling Issue

    The _record_ip method uses transaction.atomic() but also checks transaction.get_autocommit(). This logic should be reviewed to ensure that transactions are handled correctly and do not cause unexpected behavior in concurrent scenarios.

    def _record_ip(self, ip, agent, path):
        """
        Helper method to record IP information
        """
        with transaction.atomic():
            # create unique entry for every unique (ip,path) tuple
            # if this tuple already exists, we just increment the count.
            ip_records = IP.objects.select_for_update().filter(address=ip, path=path)
            if ip_records.exists():
                ip_record = ip_records.first()
    
                # Calculate the new count and ensure it doesn't exceed the MAX_COUNT
                new_count = ip_record.count + 1
                if new_count > MAX_COUNT:
                    new_count = MAX_COUNT
    
                ip_record.agent = agent
                ip_record.count = new_count
                if ip_record.pk:
                    ip_record.save(update_fields=["agent", "count"])
    
                # Check if a transaction is already active before starting a new one
                if not transaction.get_autocommit():
                    ip_records.exclude(pk=ip_record.pk).delete()
            else:
                # If no record exists, create a new one
                IP.objects.create(address=ip, agent=agent, count=1, path=path)
    Code Duplication

    There is significant duplication between the synchronous and asynchronous request handling logic (__call__ and __acall__). Consider refactoring to reduce redundancy and improve maintainability.

    def __call__(self, request):
        return self.process_request_sync(request)
    
    async def __acall__(self, request):
        """
        Asynchronous version of the middleware call method
        """
        # Get client information
        ip = request.META.get("HTTP_X_FORWARDED_FOR", "").split(",")[0].strip() or request.META.get("REMOTE_ADDR", "")
        agent = request.META.get("HTTP_USER_AGENT", "").strip()
    
        # Check cache for blocked items
        blocked_ips = await sync_to_async(self.blocked_ips)()
        blocked_ip_network = await sync_to_async(self.blocked_ip_network)()
        blocked_agents = await sync_to_async(self.blocked_agents)()
    
        # Check if IP is blocked directly
        if await sync_to_async(self.ip_in_ips)(ip, blocked_ips):
            await self.increment_block_count_async(ip=ip)
            return HttpResponseForbidden()
    
        # Check if IP is in a blocked network
        if await sync_to_async(self.ip_in_range)(ip, blocked_ip_network):
            # Find the specific network that caused the block and increment its count
            for network in blocked_ip_network:
                if ipaddress.ip_address(ip) in network:
                    await self.increment_block_count_async(network=str(network))
                    break
            return HttpResponseForbidden()
    
        # Check if user agent is blocked
        if await sync_to_async(self.is_user_agent_blocked)(agent, blocked_agents):
            await self.increment_block_count_async(user_agent=agent)
            return HttpResponseForbidden()
    
        # Record IP information
        await self.record_ip_async(ip, agent, request.path)
    
        # Continue with the request
        response = await self.get_response(request)
        return response
    
    def process_request_sync(self, request):
        """
        Synchronous version of the middleware logic
        """
        # Get client information
        ip = request.META.get("HTTP_X_FORWARDED_FOR", "").split(",")[0].strip() or request.META.get("REMOTE_ADDR", "")
        agent = request.META.get("HTTP_USER_AGENT", "").strip()
    
        # Check cache for blocked items
        blocked_ips = self.blocked_ips()
        blocked_ip_network = self.blocked_ip_network()
        blocked_agents = self.blocked_agents()
    
        # Check if IP is blocked directly
        if self.ip_in_ips(ip, blocked_ips):
            self.increment_block_count(ip=ip)
            return HttpResponseForbidden()
    
        # Check if IP is in a blocked network
        if self.ip_in_range(ip, blocked_ip_network):
            # Find the specific network that caused the block and increment its count
            for network in blocked_ip_network:
                if ipaddress.ip_address(ip) in network:
                    self.increment_block_count(network=str(network))
                    break
            return HttpResponseForbidden()
    
        # Check if user agent is blocked
        if self.is_user_agent_blocked(agent, blocked_agents):
            self.increment_block_count(user_agent=agent)
            return HttpResponseForbidden()
    
        # Record IP information
        if ip:
            self._record_ip(ip, agent, request.path)
    
        # Continue with the request
        return self.get_response(request)

    Copy link
    Contributor

    PR Code Suggestions ✨

    Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
    Projects
    None yet
    Development

    Successfully merging this pull request may close these issues.

    1 participant