diff --git a/django-backend/soroscan/ingest/decorators.py b/django-backend/soroscan/ingest/decorators.py new file mode 100644 index 00000000..1ef974a5 --- /dev/null +++ b/django-backend/soroscan/ingest/decorators.py @@ -0,0 +1,127 @@ +import hashlib +import hmac +import json +import logging +from functools import wraps + +from rest_framework import status +from rest_framework.response import Response + +from .models import WebhookSubscription + +logger = logging.getLogger(__name__) + + +def webhook_hmac_required(header_name="X-SoroScan-Signature"): + """ + DRF view decorator that validates an incoming webhook's HMAC signature. + + The decorator expects the JSON payload to contain a 'contract_id' field, + which is used to look up the corresponding WebhookSubscription(s) and their secrets. + The signature is verified against the serialized JSON body using hmac.compare_digest. + """ + + def decorator(view_func): + @wraps(view_func) + def _wrapped_view(request, *args, **kwargs): + signature_header = request.headers.get(header_name) + if not signature_header: + logger.warning("Missing webhook signature header: %s", header_name) + return Response( + {"detail": f"Missing signature header {header_name}"}, + status=status.HTTP_401_UNAUTHORIZED, + ) + + if "=" not in signature_header: + logger.warning("Invalid signature header format: %s", signature_header) + return Response( + {"detail": "Invalid signature header format"}, + status=status.HTTP_401_UNAUTHORIZED, + ) + + prefix, sig_hex = signature_header.split("=", 1) + prefix = prefix.lower() + + try: + # DRF parses the JSON body into request.data + payload = request.data + contract_id = payload.get("contract_id") + except Exception: + logger.warning("Failed to parse request data for HMAC validation") + return Response( + {"detail": "Invalid JSON payload"}, + status=status.HTTP_400_BAD_REQUEST, + ) + + if not contract_id: + logger.warning("Missing contract_id in webhook payload") + return Response( + {"detail": "Missing contract_id in payload"}, + status=status.HTTP_400_BAD_REQUEST, + ) + + # Find active subscriptions for this contract + subscriptions = WebhookSubscription.objects.filter( + contract__contract_id=contract_id, is_active=True + ) + + if not subscriptions.exists(): + logger.warning( + "No active webhook subscription found for contract: %s", contract_id + ) + return Response( + {"detail": "Subscription not found or inactive"}, + status=status.HTTP_401_UNAUTHORIZED, + ) + + # Re-serialize exactly as the dispatcher does (sort_keys=True, utf-8) + # This ensures we are testing against the same byte sequence used for signing. + try: + payload_bytes = json.dumps(payload, sort_keys=True).encode("utf-8") + except (TypeError, ValueError) as exc: + logger.error("Failed to re-serialize payload for HMAC check: %s", exc) + return Response( + {"detail": "Payload serialization failed"}, + status=status.HTTP_400_BAD_REQUEST, + ) + + if prefix == "sha256": + digestmod = hashlib.sha256 + elif prefix == "sha1": + digestmod = hashlib.sha1 + else: + logger.warning("Unsupported signature algorithm: %s", prefix) + return Response( + {"detail": f"Unsupported signature algorithm: {prefix}"}, + status=status.HTTP_401_UNAUTHORIZED, + ) + + verified = False + for sub in subscriptions: + # Verify secret exists + if not sub.secret: + continue + + expected_sig = hmac.new( + sub.secret.encode("utf-8"), + msg=payload_bytes, + digestmod=digestmod, + ).hexdigest() + + if hmac.compare_digest(expected_sig, sig_hex): + verified = True + request.webhook_subscription = sub + break + + if not verified: + logger.warning("HMAC signature mismatch for contract: %s", contract_id) + return Response( + {"detail": "Invalid HMAC signature"}, + status=status.HTTP_401_UNAUTHORIZED, + ) + + return view_func(request, *args, **kwargs) + + return _wrapped_view + + return decorator diff --git a/django-backend/soroscan/ingest/tests/test_decorators.py b/django-backend/soroscan/ingest/tests/test_decorators.py new file mode 100644 index 00000000..8a211b28 --- /dev/null +++ b/django-backend/soroscan/ingest/tests/test_decorators.py @@ -0,0 +1,113 @@ +import hashlib +import hmac +import json + +import pytest +from django.urls import reverse +from rest_framework import status +from rest_framework.test import APIClient + +from soroscan.ingest.tests.factories import ( + TrackedContractFactory, + WebhookSubscriptionFactory, +) + + +@pytest.fixture +def api_client(): + return APIClient() + + +@pytest.mark.django_db +class TestWebhookHMACDecorator: + def setup_method(self): + self.contract = TrackedContractFactory() + self.webhook = WebhookSubscriptionFactory( + contract=self.contract, + secret="test-secret-123", + signature_algorithm="sha256", + ) + self.url = reverse("webhook-receiver-example") + + def test_valid_sha256_signature(self, api_client): + payload = { + "contract_id": self.contract.contract_id, + "event_type": "transfer", + "payload": {"amount": 100}, + } + payload_bytes = json.dumps(payload, sort_keys=True).encode("utf-8") + sig = hmac.new(b"test-secret-123", payload_bytes, hashlib.sha256).hexdigest() + + headers = {"X-SoroScan-Signature": f"sha256={sig}"} + response = api_client.post(self.url, payload, format="json", headers=headers) + + assert response.status_code == status.HTTP_200_OK + assert response.data["status"] == "verified" + + def test_valid_sha1_signature(self, api_client): + self.webhook.signature_algorithm = "sha1" + self.webhook.save() + + payload = { + "contract_id": self.contract.contract_id, + "event_type": "transfer", + "payload": {"amount": 100}, + } + payload_bytes = json.dumps(payload, sort_keys=True).encode("utf-8") + sig = hmac.new(b"test-secret-123", payload_bytes, hashlib.sha1).hexdigest() + + headers = {"X-SoroScan-Signature": f"sha1={sig}"} + response = api_client.post(self.url, payload, format="json", headers=headers) + + assert response.status_code == status.HTTP_200_OK + assert response.data["status"] == "verified" + + def test_invalid_signature(self, api_client): + payload = {"contract_id": self.contract.contract_id, "event_type": "transfer"} + headers = {"X-SoroScan-Signature": "sha256=invalid-sig"} + response = api_client.post(self.url, payload, format="json", headers=headers) + + assert response.status_code == status.HTTP_401_UNAUTHORIZED + assert "Invalid HMAC signature" in response.data["detail"] + + def test_missing_signature_header(self, api_client): + payload = {"contract_id": self.contract.contract_id} + response = api_client.post(self.url, payload, format="json") + + assert response.status_code == status.HTTP_401_UNAUTHORIZED + assert "Missing signature header" in response.data["detail"] + + def test_missing_contract_id(self, api_client): + payload = {"event_type": "transfer"} + headers = {"X-SoroScan-Signature": "sha256=some-sig"} + response = api_client.post(self.url, payload, format="json", headers=headers) + + assert response.status_code == status.HTTP_400_BAD_REQUEST + assert "Missing contract_id" in response.data["detail"] + + def test_inactive_subscription(self, api_client): + self.webhook.is_active = False + self.webhook.save() + + payload = {"contract_id": self.contract.contract_id} + headers = {"X-SoroScan-Signature": "sha256=some-sig"} + response = api_client.post(self.url, payload, format="json", headers=headers) + + assert response.status_code == status.HTTP_401_UNAUTHORIZED + assert "Subscription not found or inactive" in response.data["detail"] + + def test_multiple_subscriptions_one_matches(self, api_client): + # Create another webhook for same contract with different secret + WebhookSubscriptionFactory( + contract=self.contract, secret="other-secret", signature_algorithm="sha256" + ) + + payload = {"contract_id": self.contract.contract_id, "event_type": "transfer"} + payload_bytes = json.dumps(payload, sort_keys=True).encode("utf-8") + # Sign with the first secret + sig = hmac.new(b"test-secret-123", payload_bytes, hashlib.sha256).hexdigest() + + headers = {"X-SoroScan-Signature": f"sha256={sig}"} + response = api_client.post(self.url, payload, format="json", headers=headers) + + assert response.status_code == status.HTTP_200_OK diff --git a/django-backend/soroscan/ingest/urls.py b/django-backend/soroscan/ingest/urls.py index c3eb709d..a7bcdbae 100644 --- a/django-backend/soroscan/ingest/urls.py +++ b/django-backend/soroscan/ingest/urls.py @@ -1,6 +1,7 @@ """ URL patterns for SoroScan ingest API. """ + from django.urls import include, path from rest_framework.routers import DefaultRouter @@ -23,6 +24,7 @@ health_check, record_event_view, restore_archived_events, + webhook_receiver_example, transaction_events_view, vulnerability_impact_view, ) @@ -36,7 +38,11 @@ router.register(r"teams", TeamViewSet, basename="team") urlpatterns = [ - path("contracts//timeline/", contract_timeline_view, name="contract-timeline"), + path( + "contracts//timeline/", + contract_timeline_view, + name="contract-timeline", + ), path( "contracts//events/explorer/", contract_event_explorer_view, @@ -64,6 +70,11 @@ path("events/restore-archive/", restore_archived_events, name="restore-archive"), path("audit-trail/", audit_trail_view, name="audit-trail"), path("admin/ingest-errors/", admin_ingest_errors_view, name="admin-ingest-errors"), + path( + "webhooks/receiver-example/", + webhook_receiver_example, + name="webhook-receiver-example", + ), path( "admin/organization-costs/", organization_cost_breakdown_view, diff --git a/django-backend/soroscan/ingest/views.py b/django-backend/soroscan/ingest/views.py index 0d05c2fc..fb939ff6 100644 --- a/django-backend/soroscan/ingest/views.py +++ b/django-backend/soroscan/ingest/views.py @@ -1,6 +1,7 @@ """ API Views for SoroScan event ingestion. """ + import hashlib import hmac import json @@ -16,7 +17,12 @@ from django_filters.rest_framework import DjangoFilterBackend from drf_spectacular.utils import extend_schema, inline_serializer from rest_framework import serializers, status, viewsets -from rest_framework.decorators import action, api_view, permission_classes, throttle_classes +from rest_framework.decorators import ( + action, + api_view, + permission_classes, + throttle_classes, +) from rest_framework.filters import OrderingFilter, SearchFilter from rest_framework.permissions import AllowAny, IsAuthenticated from rest_framework.response import Response @@ -26,7 +32,14 @@ from soroscan.throttles import IngestRateThrottle -from .cache_utils import cache_result, get_or_set_json, query_cache_ttl, stable_cache_key +from .decorators import webhook_hmac_required + +from .cache_utils import ( + cache_result, + get_or_set_json, + query_cache_ttl, + stable_cache_key, +) from .models import ( APIKey, AdminAction, @@ -142,7 +155,9 @@ def get_queryset(self): user = self.request.user if self.request.method in ["GET", "HEAD", "OPTIONS"]: if user.is_authenticated: - return qs.filter(Q(owner=user) | Q(team__memberships__user=user)).distinct() + return qs.filter( + Q(owner=user) | Q(team__memberships__user=user) + ).distinct() return qs return qs.filter(owner=self.request.user) @@ -323,7 +338,17 @@ def get_queryset(self): "payload_contains": serializers.CharField(required=False), "payload_field": serializers.CharField(required=False), "payload_op": serializers.ChoiceField( - choices=["eq", "neq", "gte", "lte", "gt", "lt", "contains", "startswith", "in"], + choices=[ + "eq", + "neq", + "gte", + "lte", + "gt", + "lt", + "contains", + "startswith", + "in", + ], required=False, ), "payload_value": serializers.CharField(required=False), @@ -367,6 +392,7 @@ def search(self, request): # The GIN index speeds up JSON containment (@>) queries; for plain text # search we rely on PostgreSQL's icontains on the cast. from django.db.models import TextField + qs = qs.annotate( _payload_text=Cast("payload", output_field=TextField()) ).filter(_payload_text__icontains=q) @@ -376,6 +402,7 @@ def search(self, request): if payload_contains: # Simple text containment inside the JSON; works with GIN index from django.db.models import TextField + if not q: # avoid double annotation qs = qs.annotate( _payload_text=Cast("payload", output_field=TextField()) @@ -471,7 +498,9 @@ def get_queryset(self): def get_serializer_context(self): """Add include_events flag from query params.""" context = super().get_serializer_context() - context["include_events"] = self.request.query_params.get("include_events") == "true" + context["include_events"] = ( + self.request.query_params.get("include_events") == "true" + ) return context def list(self, request, *args, **kwargs): @@ -504,7 +533,6 @@ def list(self, request, *args, **kwargs): return Response(serializer.data) - class WebhookSubscriptionViewSet(viewsets.ModelViewSet): """ ViewSet for managing webhook subscriptions. @@ -523,7 +551,7 @@ class WebhookSubscriptionViewSet(viewsets.ModelViewSet): def get_queryset(self): # Public read access, but filter by owner for write operations - if self.request.method in ['GET', 'HEAD', 'OPTIONS']: + if self.request.method in ["GET", "HEAD", "OPTIONS"]: return WebhookSubscription.objects.all() return WebhookSubscription.objects.filter(contract__owner=self.request.user) @@ -554,7 +582,9 @@ def test(self, request, pk=None): "timestamp": timezone.now().isoformat(), } payload_bytes = json.dumps(test_payload, sort_keys=True).encode("utf-8") - algorithm = (webhook.signature_algorithm or WebhookSubscription.SIGNATURE_SHA256).lower() + algorithm = ( + webhook.signature_algorithm or WebhookSubscription.SIGNATURE_SHA256 + ).lower() if algorithm == WebhookSubscription.SIGNATURE_SHA1: digestmod = hashlib.sha1 prefix = "sha1" @@ -673,7 +703,9 @@ def members(self, request, pk=None): try: new_user = User.objects.get(pk=ser.validated_data["user_id"]) except User.DoesNotExist: - return Response({"detail": "User not found."}, status=status.HTTP_404_NOT_FOUND) + return Response( + {"detail": "User not found."}, status=status.HTTP_404_NOT_FOUND + ) _, created = TeamMembership.objects.get_or_create( team=team, user=new_user, @@ -992,9 +1024,9 @@ def contract_event_explorer_view(request, contract_id: str): def contract_event_types_view(request, contract_id: str): """Get event types and their counts for a specific contract.""" contract = get_object_or_404(TrackedContract, contract_id=contract_id) - + cache_key = stable_cache_key("contract_event_types", {"contract_id": contract_id}) - + def _build(): return list( ContractEvent.objects.filter(contract=contract) @@ -1002,11 +1034,11 @@ def _build(): .annotate( count=Count("id"), first_seen=Min("timestamp"), - last_seen=Max("timestamp") + last_seen=Max("timestamp"), ) .order_by("-count") ) - + result = get_or_set_json(cache_key, 60, _build) return Response(result) @@ -1049,7 +1081,9 @@ def restore_archived_events(request): """ batch_id = request.query_params.get("batch_id") or request.data.get("batch_id") if not batch_id: - return Response({"detail": "batch_id is required."}, status=status.HTTP_400_BAD_REQUEST) + return Response( + {"detail": "batch_id is required."}, status=status.HTTP_400_BAD_REQUEST + ) batch = get_object_or_404(ArchivedEventBatch, id=batch_id) @@ -1086,7 +1120,9 @@ def restore_archived_events(request): restored_count = 0 for row in rows: try: - contract = TrackedContract.objects.get(contract_id=row["contract__contract_id"]) + contract = TrackedContract.objects.get( + contract_id=row["contract__contract_id"] + ) ContractEvent.objects.get_or_create( contract=contract, ledger=row["ledger"], @@ -1101,12 +1137,15 @@ def restore_archived_events(request): ) restored_count += 1 except Exception: - logger.warning("Skipped row during restore: %s", row.get("id"), exc_info=True) + logger.warning( + "Skipped row during restore: %s", row.get("id"), exc_info=True + ) batch.status = ArchivedEventBatch.STATUS_RESTORED batch.save(update_fields=["status"]) from .models import ArchivalAuditLog # noqa: PLC0415 + ArchivalAuditLog.objects.create( action=ArchivalAuditLog.ACTION_RESTORE, batch=batch, @@ -1179,11 +1218,13 @@ def audit_trail_view(request): def admin_ingest_errors_view(request): """Get recent ingest errors (admin only).""" if not request.user.is_staff: - return Response({"error": "Admin access required"}, status=status.HTTP_403_FORBIDDEN) - + return Response( + {"error": "Admin access required"}, status=status.HTTP_403_FORBIDDEN + ) + # Last 24 hours since = timezone.now() - timezone.timedelta(hours=24) - + # Group by error_type + contract_id and aggregate errors = ( IngestError.objects.filter(created_at__gte=since) @@ -1191,11 +1232,11 @@ def admin_ingest_errors_view(request): .annotate( count=Count("id"), last_occurrence=Max("created_at"), - sample_error=Max("sample_error") # Get one sample error message + sample_error=Max("sample_error"), # Get one sample error message ) .order_by("-count") ) - + return Response(list(errors)) @@ -1263,6 +1304,39 @@ def rate_limit_analytics_view(request): ) +@extend_schema( + request=inline_serializer( + name="WebhookReceiverRequest", + fields={ + "contract_id": serializers.CharField(), + "event_type": serializers.CharField(), + "payload": serializers.JSONField(), + }, + ), + responses={ + 200: inline_serializer( + name="WebhookReceiverResponse", fields={"status": serializers.CharField()} + ) + }, +) +@api_view(["POST"]) +@permission_classes([AllowAny]) # Secured by HMAC decorator +@webhook_hmac_required() +def webhook_receiver_example(request): + """ + Example webhook receiver that uses the HMAC validation decorator. + + This endpoint is public (AllowAny) but requires a valid X-SoroScan-Signature + header to proceed. The decorator verifies the signature using the secret + associated with the 'contract_id' in the payload. + """ + data = request.data + logger.info( + "Received verified webhook for contract %s, event %s", + data.get("contract_id"), + data.get("event_type"), + ) + return Response({"status": "verified"}) # --------------------------------------------------------------------------- # Issue #280: GDPR — deletion requests & compliance export # ---------------------------------------------------------------------------