${_("Pending Tasks")}
diff --git a/openedx/core/djangoapps/agreements/api.py b/openedx/core/djangoapps/agreements/api.py
index 2489cefb8533..11ad23dddd25 100644
--- a/openedx/core/djangoapps/agreements/api.py
+++ b/openedx/core/djangoapps/agreements/api.py
@@ -3,17 +3,15 @@
"""
import logging
+from datetime import datetime
+from typing import Iterable, Optional
from django.contrib.auth import get_user_model
from django.core.exceptions import ObjectDoesNotExist
from opaque_keys.edx.keys import CourseKey
-from openedx.core.djangoapps.agreements.models import IntegritySignature
-from openedx.core.djangoapps.agreements.models import LTIPIITool
-from openedx.core.djangoapps.agreements.models import LTIPIISignature
-
-from .data import LTIToolsReceivingPIIData
-from .data import LTIPIISignatureData
+from .data import LTIPIISignatureData, LTIToolsReceivingPIIData, UserAgreementRecordData
+from .models import IntegritySignature, LTIPIISignature, LTIPIITool, UserAgreementRecord
log = logging.getLogger(__name__)
User = get_user_model()
@@ -240,3 +238,48 @@ def _user_signature_out_of_date(username, course_id):
return False
else:
return user_lti_pii_signature_hash != course_lti_pii_tools_hash
+
+
+def get_user_agreements(user: User) -> Iterable[UserAgreementRecordData]:
+ """
+ Retrieves all the agreements that the specified user has acknowledged.
+ """
+ for agreement_record in UserAgreementRecord.objects.filter(user=user):
+ yield UserAgreementRecordData.from_model(agreement_record)
+
+
+def get_latest_user_agreement_record(
+ user: User,
+ agreement_type: str,
+ agreed_after: datetime = None,
+) -> Optional[UserAgreementRecordData]:
+ """
+ Retrieve the user agreement record for the specified user and agreement type.
+
+ An agreement update timestamp can be provided to return a record only if it
+ was signed after that timestamp.
+ """
+ try:
+ record_query = UserAgreementRecord.objects.filter(
+ user=user,
+ agreement_type=agreement_type,
+ )
+ if agreed_after:
+ record_query = record_query.filter(timestamp__gte=agreed_after)
+ record = record_query.latest("timestamp")
+ return UserAgreementRecordData.from_model(record)
+ except UserAgreementRecord.DoesNotExist:
+ return None
+
+
+def create_user_agreement_record(user: User, agreement_type: str) -> UserAgreementRecordData:
+ """
+ Creates a user agreement record if one doesn't already exist, or updates existing
+ record to current timestamp.
+ """
+ record = UserAgreementRecord.objects.create(
+ user=user,
+ agreement_type=agreement_type,
+ timestamp=datetime.now(),
+ )
+ return UserAgreementRecordData.from_model(record)
diff --git a/openedx/core/djangoapps/agreements/data.py b/openedx/core/djangoapps/agreements/data.py
index 9d843c73cb04..01d83665c009 100644
--- a/openedx/core/djangoapps/agreements/data.py
+++ b/openedx/core/djangoapps/agreements/data.py
@@ -1,8 +1,13 @@
"""
Public data structures for this app.
"""
+from dataclasses import dataclass
+from datetime import datetime
+
import attr
+from .models import UserAgreementRecord
+
@attr.s(frozen=True, auto_attribs=True)
class LTIToolsReceivingPIIData:
@@ -21,3 +26,21 @@ class LTIPIISignatureData:
course_id: str
lti_tools: str
lti_tools_hash: str
+
+
+@dataclass
+class UserAgreementRecordData:
+ """
+ Data for a single user agreement record.
+ """
+ username: str
+ agreement_type: str
+ accepted_at: datetime
+
+ @classmethod
+ def from_model(cls, model: UserAgreementRecord):
+ return UserAgreementRecordData(
+ username=model.user.username,
+ agreement_type=model.agreement_type,
+ accepted_at=model.timestamp,
+ )
diff --git a/openedx/core/djangoapps/agreements/migrations/0006_useragreementrecord.py b/openedx/core/djangoapps/agreements/migrations/0006_useragreementrecord.py
new file mode 100644
index 000000000000..2e0985adb6de
--- /dev/null
+++ b/openedx/core/djangoapps/agreements/migrations/0006_useragreementrecord.py
@@ -0,0 +1,25 @@
+# Generated by Django 4.2.16 on 2024-12-06 11:34
+
+from django.conf import settings
+from django.db import migrations, models
+import django.db.models.deletion
+
+
+class Migration(migrations.Migration):
+
+ dependencies = [
+ migrations.swappable_dependency(settings.AUTH_USER_MODEL),
+ ('agreements', '0005_timestampedmodels'),
+ ]
+
+ operations = [
+ migrations.CreateModel(
+ name='UserAgreementRecord',
+ fields=[
+ ('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')),
+ ('agreement_type', models.CharField(max_length=255)),
+ ('timestamp', models.DateTimeField(auto_now_add=True)),
+ ('user', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, to=settings.AUTH_USER_MODEL)),
+ ],
+ ),
+ ]
diff --git a/openedx/core/djangoapps/agreements/models.py b/openedx/core/djangoapps/agreements/models.py
index 2672a4f47b24..2ceeeb98109f 100644
--- a/openedx/core/djangoapps/agreements/models.py
+++ b/openedx/core/djangoapps/agreements/models.py
@@ -70,3 +70,20 @@ class ProctoringPIISignature(TimeStampedModel):
class Meta:
app_label = 'agreements'
+
+
+class UserAgreementRecord(models.Model):
+ """
+ This model stores the agreements a user has accepted or acknowledged.
+
+ Each record here represents a user agreeing to the agreement type represented
+ by `agreement_type` at a particular time.
+
+ .. no_pii:
+ """
+ user = models.ForeignKey(User, db_index=True, on_delete=models.CASCADE)
+ agreement_type = models.CharField(max_length=255)
+ timestamp = models.DateTimeField(auto_now_add=True)
+
+ class Meta:
+ app_label = 'agreements'
diff --git a/openedx/core/djangoapps/agreements/serializers.py b/openedx/core/djangoapps/agreements/serializers.py
index 11e9d57f4054..0ecade7afd97 100644
--- a/openedx/core/djangoapps/agreements/serializers.py
+++ b/openedx/core/djangoapps/agreements/serializers.py
@@ -3,9 +3,10 @@
"""
from rest_framework import serializers
-from openedx.core.djangoapps.agreements.models import IntegritySignature, LTIPIISignature
from openedx.core.lib.api.serializers import CourseKeyField
+from .models import IntegritySignature, LTIPIISignature
+
class IntegritySignatureSerializer(serializers.ModelSerializer):
"""
@@ -31,3 +32,12 @@ class LTIPIISignatureSerializer(serializers.ModelSerializer):
class Meta:
model = LTIPIISignature
fields = ('username', 'course_id', 'lti_tools', 'created_at')
+
+
+class UserAgreementsSerializer(serializers.Serializer):
+ """
+ Serializer for UserAgreementRecord model
+ """
+ username = serializers.CharField(read_only=True)
+ agreement_type = serializers.CharField(read_only=True)
+ accepted_at = serializers.DateTimeField()
diff --git a/openedx/core/djangoapps/agreements/tests/test_api.py b/openedx/core/djangoapps/agreements/tests/test_api.py
index c66065789939..eb1e02956dc5 100644
--- a/openedx/core/djangoapps/agreements/tests/test_api.py
+++ b/openedx/core/djangoapps/agreements/tests/test_api.py
@@ -2,25 +2,29 @@
Tests for the Agreements API
"""
import logging
+from datetime import datetime, timedelta
+from django.test import TestCase
+from opaque_keys.edx.keys import CourseKey
from testfixtures import LogCapture
from common.djangoapps.student.tests.factories import UserFactory
-from openedx.core.djangoapps.agreements.api import (
+from openedx.core.djangolib.testing.utils import skip_unless_lms
+from xmodule.modulestore.tests.django_utils import SharedModuleStoreTestCase
+from xmodule.modulestore.tests.factories import CourseFactory
+
+from ..api import (
create_integrity_signature,
+ create_lti_pii_signature,
+ create_user_agreement_record,
get_integrity_signature,
get_integrity_signatures_for_course,
+ get_lti_pii_signature,
get_pii_receiving_lti_tools,
- create_lti_pii_signature,
- get_lti_pii_signature
+ get_latest_user_agreement_record,
+ get_user_agreements
)
-from openedx.core.djangolib.testing.utils import skip_unless_lms
-from xmodule.modulestore.tests.django_utils import SharedModuleStoreTestCase # lint-amnesty, pylint: disable=wrong-import-order
-from xmodule.modulestore.tests.factories import CourseFactory # lint-amnesty, pylint: disable=wrong-import-order
-from ..models import (
- LTIPIITool,
-)
-from opaque_keys.edx.keys import CourseKey
+from ..models import LTIPIITool
LOGGER_NAME = "openedx.core.djangoapps.agreements.api"
@@ -186,3 +190,37 @@ def _assert_ltitools(self, lti_list):
Helper function to assert the returned list has the correct tools
"""
self.assertEqual(self.lti_tools, lti_list)
+
+
+@skip_unless_lms
+class UserAgreementsTests(TestCase):
+ """
+ Tests for the python APIs related to user agreements.
+ """
+ def setUp(self):
+ self.user = UserFactory()
+
+ def test_get_user_agreements(self, ):
+ result = list(get_user_agreements(self.user))
+ assert len(result) == 0
+
+ record = create_user_agreement_record(self.user, 'test_type')
+ result = list(get_user_agreements(self.user))
+
+ assert len(result) == 1
+ assert result[0].agreement_type == 'test_type'
+ assert result[0].username == self.user.username
+ assert result[0].accepted_at == record.accepted_at
+
+ def test_get_user_agreement_record(self):
+ record = create_user_agreement_record(self.user, 'test_type')
+ result = get_latest_user_agreement_record(self.user, 'test_type')
+
+ assert result == record
+
+ result = get_latest_user_agreement_record(self.user, 'test_type', datetime.now() + timedelta(days=1))
+
+ assert result is None
+
+ def tearDown(self):
+ self.user.delete()
diff --git a/openedx/core/djangoapps/agreements/tests/test_views.py b/openedx/core/djangoapps/agreements/tests/test_views.py
index 4c52e5853f05..61cc8661fb43 100644
--- a/openedx/core/djangoapps/agreements/tests/test_views.py
+++ b/openedx/core/djangoapps/agreements/tests/test_views.py
@@ -2,26 +2,28 @@
Tests for agreements views
"""
+import json
from datetime import datetime, timedelta
from unittest.mock import patch
from django.conf import settings
from django.urls import reverse
-from rest_framework.test import APITestCase
-from rest_framework import status
from freezegun import freeze_time
-import json
+from rest_framework import status
+from rest_framework.test import APITestCase
-from common.djangoapps.student.tests.factories import UserFactory, AdminFactory
from common.djangoapps.student.roles import CourseStaffRole
-from openedx.core.djangoapps.agreements.api import (
+from common.djangoapps.student.tests.factories import AdminFactory, UserFactory
+from openedx.core.djangolib.testing.utils import skip_unless_lms
+from xmodule.modulestore.tests.django_utils import ModuleStoreTestCase
+from xmodule.modulestore.tests.factories import CourseFactory
+
+from ..api import (
create_integrity_signature,
+ create_user_agreement_record,
get_integrity_signatures_for_course,
get_lti_pii_signature
)
-from openedx.core.djangolib.testing.utils import skip_unless_lms
-from xmodule.modulestore.tests.django_utils import ModuleStoreTestCase # lint-amnesty, pylint: disable=wrong-import-order
-from xmodule.modulestore.tests.factories import CourseFactory # lint-amnesty, pylint: disable=wrong-import-order
@skip_unless_lms
@@ -289,3 +291,54 @@ def test_post_lti_pii_signature(self):
signature = get_lti_pii_signature(self.user.username, self.course_id)
self.assertEqual(signature.user.username, self.user.username)
self.assertEqual(signature.lti_tools, self.lti_tools)
+
+
+@skip_unless_lms
+class UserAgreementsViewTests(APITestCase):
+ """
+ Tests for the UserAgreementsView
+ """
+
+ def setUp(self):
+ self.user = UserFactory(username="testuser", password="password")
+ self.url = reverse('user_agreements', kwargs={'agreement_type': 'sample_agreement'})
+ self.login()
+
+ def login(self):
+ self.client.login(username="testuser", password="password")
+
+ def test_get_user_agreement_record_no_data(self):
+ response = self.client.get(self.url)
+ assert response.status_code == status.HTTP_404_NOT_FOUND
+
+ def test_get_user_agreement_record_invalid_date(self):
+ response = self.client.get(self.url, {'after': 'invalid_date'})
+ assert response.status_code == status.HTTP_400_BAD_REQUEST
+
+ def test_get_user_agreement_record(self):
+ create_user_agreement_record(self.user, 'sample_agreement')
+ response = self.client.get(self.url)
+ assert response.status_code == status.HTTP_200_OK
+ assert 'accepted_at' in response.data
+
+ response = self.client.get(self.url, {"after": str(datetime.now() + timedelta(days=1))})
+ assert response.status_code == status.HTTP_404_NOT_FOUND
+
+ def test_post_user_agreement(self):
+ with freeze_time("2024-11-21 12:00:00"):
+ response = self.client.post(self.url)
+ assert response.status_code == status.HTTP_201_CREATED
+
+ self.login()
+
+ response = self.client.get(self.url)
+ assert response.status_code == status.HTTP_200_OK
+
+ response = self.client.get(self.url, {"after": "2024-11-21T13:00:00Z"})
+ assert response.status_code == status.HTTP_404_NOT_FOUND
+
+ response = self.client.post(self.url)
+ assert response.status_code == status.HTTP_201_CREATED
+
+ response = self.client.get(self.url, {"after": "2024-11-21T13:00:00Z"})
+ assert response.status_code == status.HTTP_200_OK
diff --git a/openedx/core/djangoapps/agreements/urls.py b/openedx/core/djangoapps/agreements/urls.py
index d9d009d65ac1..902f477a7087 100644
--- a/openedx/core/djangoapps/agreements/urls.py
+++ b/openedx/core/djangoapps/agreements/urls.py
@@ -3,9 +3,9 @@
"""
from django.conf import settings
-from django.urls import re_path
+from django.urls import path, re_path
-from .views import IntegritySignatureView, LTIPIISignatureView
+from .views import IntegritySignatureView, LTIPIISignatureView, UserAgreementsView
urlpatterns = [
re_path(r'^integrity_signature/{course_id}$'.format(
@@ -14,4 +14,5 @@
re_path(r'^lti_pii_signature/{course_id}$'.format(
course_id=settings.COURSE_ID_PATTERN
), LTIPIISignatureView.as_view(), name='lti_pii_signature'),
+ path("agreement/", UserAgreementsView.as_view(), name="user_agreements"),
]
diff --git a/openedx/core/djangoapps/agreements/views.py b/openedx/core/djangoapps/agreements/views.py
index cc928669ffdd..daf2ce09428c 100644
--- a/openedx/core/djangoapps/agreements/views.py
+++ b/openedx/core/djangoapps/agreements/views.py
@@ -2,21 +2,28 @@
Views served by the Agreements app
"""
+import edx_api_doc_tools as apidocs
+from django import forms
from django.conf import settings
+from drf_yasg import openapi
+from opaque_keys.edx.keys import CourseKey
from rest_framework import status
-from rest_framework.views import APIView
-from rest_framework.response import Response
from rest_framework.permissions import IsAuthenticated
-from opaque_keys.edx.keys import CourseKey
+from rest_framework.response import Response
+from rest_framework.views import APIView
from common.djangoapps.student import auth
from common.djangoapps.student.roles import CourseStaffRole
-from openedx.core.djangoapps.agreements.api import (
+
+from .api import (
create_integrity_signature,
create_lti_pii_signature,
+ create_user_agreement_record,
get_integrity_signature,
+ get_latest_user_agreement_record
)
-from openedx.core.djangoapps.agreements.serializers import IntegritySignatureSerializer, LTIPIISignatureSerializer
+from .serializers import IntegritySignatureSerializer, LTIPIISignatureSerializer, UserAgreementsSerializer
+from ...lib.api.view_utils import view_auth_classes
def is_user_course_or_global_staff(user, course_id):
@@ -159,3 +166,72 @@ def post(self, request, course_id):
else:
statusStr = status.HTTP_500_INTERNAL_SERVER_ERROR
return Response(data=serializer.data, status=statusStr)
+
+
+@view_auth_classes(is_authenticated=True)
+class UserAgreementsView(APIView):
+ """
+ Endpoint for the user agreements API.
+ """
+
+ class QueryFilterForm(forms.Form):
+ """
+ Query parameters for the GET method.
+ """
+ after = forms.DateTimeField(required=False)
+
+ @apidocs.schema(
+ parameters=[
+ apidocs.string_parameter(
+ 'agreement_type',
+ apidocs.ParameterLocation.PATH,
+ description="Agreement ID/Type",
+ ),
+ openapi.Parameter(
+ 'after',
+ apidocs.ParameterLocation.QUERY,
+ required=False,
+ type=openapi.TYPE_STRING,
+ format=openapi.FORMAT_DATETIME,
+ description="Return records after this date/time",
+ ),
+ ],
+ responses={
+ 200: UserAgreementsSerializer,
+ 400: "Bad Request",
+ 404: "Not Found",
+ },
+ )
+ def get(self, request, agreement_type):
+ """
+ Get a user's acknowledgement record for this agreement type.
+ """
+ params = UserAgreementsView.QueryFilterForm(request.query_params)
+ if not params.is_valid():
+ return Response(status=status.HTTP_400_BAD_REQUEST)
+ record = get_latest_user_agreement_record(request.user, agreement_type, params.cleaned_data.get('after'))
+ if record is None:
+ return Response(status=status.HTTP_404_NOT_FOUND)
+ serializer = UserAgreementsSerializer(record)
+ return Response(serializer.data)
+
+ @apidocs.schema(
+ parameters=[
+ apidocs.string_parameter(
+ 'agreement_type',
+ apidocs.ParameterLocation.PATH,
+ description="Agreement ID/Type",
+ ),
+ ],
+ responses={
+ 200: UserAgreementsSerializer,
+ 400: "Bad Request",
+ },
+ )
+ def post(self, request, agreement_type):
+ """
+ Marks a user's acknowledgement of this agreement type.
+ """
+ record = create_user_agreement_record(request.user, agreement_type)
+ serializer = UserAgreementsSerializer(record)
+ return Response(serializer.data, status=status.HTTP_201_CREATED)
diff --git a/openedx/core/djangoapps/courseware_api/views.py b/openedx/core/djangoapps/courseware_api/views.py
index 576f0472b770..8ad3d102f6fd 100644
--- a/openedx/core/djangoapps/courseware_api/views.py
+++ b/openedx/core/djangoapps/courseware_api/views.py
@@ -579,6 +579,7 @@ class SequenceMetadata(DeveloperErrorViewMixin, APIView):
authentication_classes = (
JwtAuthentication,
+ BearerAuthenticationAllowInactiveUser,
SessionAuthenticationAllowInactiveUser,
)
diff --git a/openedx/core/djangoapps/discussions/serializers.py b/openedx/core/djangoapps/discussions/serializers.py
index 88648a499598..6cc18605df9f 100644
--- a/openedx/core/djangoapps/discussions/serializers.py
+++ b/openedx/core/djangoapps/discussions/serializers.py
@@ -354,6 +354,12 @@ def _update_course_configuration(
key not in LegacySettingsSerializer.Meta.fields_cohorts
)
}
+ # toogle discussion tab is_hidden
+ for tab in course.tabs:
+ if tab.tab_id == 'discussion' and tab.is_hidden == instance.enabled:
+ tab.is_hidden = not instance.enabled
+ save = True
+ break
if save:
modulestore().update_item(course, self.context['user_id'])
return instance
diff --git a/openedx/core/djangoapps/discussions/tasks.py b/openedx/core/djangoapps/discussions/tasks.py
index fea20dc59bd4..4b08112f8266 100644
--- a/openedx/core/djangoapps/discussions/tasks.py
+++ b/openedx/core/djangoapps/discussions/tasks.py
@@ -196,7 +196,14 @@ def update_unit_discussion_state_from_discussion_blocks(course_key: CourseKey, u
"""
store = modulestore()
course = store.get_course(course_key)
- provider = course.discussions_settings.get('provider', None)
+ # The provider information has been written to both `provider_type` and `provider`.
+ # Both of these serve the same purpose and this is an accident of early development.
+ # The `provider_type` key is now treated as read-only to allow existing values
+ # to be respected while moving to the `provider` key in the future.
+ provider = course.discussions_settings.get(
+ 'provider_type',
+ course.discussions_settings.get('provider', None),
+ )
# Only migrate to the new discussion provider if the current provider is the legacy provider.
log.info(f"Current provider for {course_key} is {provider}")
if provider is not None and provider != Provider.LEGACY and not force:
diff --git a/openedx/core/djangoapps/django_comment_common/comment_client/user.py b/openedx/core/djangoapps/django_comment_common/comment_client/user.py
index eaac6b408659..ad259c937630 100644
--- a/openedx/core/djangoapps/django_comment_common/comment_client/user.py
+++ b/openedx/core/djangoapps/django_comment_common/comment_client/user.py
@@ -238,7 +238,7 @@ def _retrieve(self, *args, **kwargs):
if course_id:
course_id = str(course_id)
retrieve_params['course_id'] = course_id
- course_key = utils.get_course_key(course_id)
+ course_key = utils.get_course_key(course_id) or utils.get_course_key(kwargs.get("course_key"))
if is_forum_v2_enabled(course_key):
group_ids = [retrieve_params['group_id']] if 'group_id' in retrieve_params else []
@@ -251,6 +251,7 @@ def _retrieve(self, *args, **kwargs):
complete=is_complete
)
except ForumV2RequestError as e:
+ course_id = str(course_key)
self.save({"course_id": course_id})
response = forum_api.get_user(
self.attributes["id"],
diff --git a/openedx/core/djangoapps/schedules/management/commands/__init__.py b/openedx/core/djangoapps/schedules/management/commands/__init__.py
index bd0082f5331e..0b7255976563 100644
--- a/openedx/core/djangoapps/schedules/management/commands/__init__.py
+++ b/openedx/core/djangoapps/schedules/management/commands/__init__.py
@@ -29,12 +29,28 @@ def add_arguments(self, parser):
'--override-recipient-email',
help='Send all emails to this address instead of the actual recipient'
)
- parser.add_argument('site_domain_name')
+ parser.add_argument(
+ 'site_domain_name',
+ nargs='?',
+ default=None,
+ help=(
+ 'Domain name for the site to use. '
+ 'Do not provide a domain if you wish to run this for all sites'
+ )
+ )
parser.add_argument(
'--weeks',
type=int,
help='Number of weekly emails to be sent',
)
+ parser.add_argument(
+ '--override-middlewares',
+ action='append',
+ help=(
+ 'Use this middleware when emulating http requests. '
+ 'To use multiple middlewares, provide this argument multiple times'
+ )
+ )
def handle(self, *args, **options):
self.log_debug('Args = %r', options)
@@ -49,19 +65,26 @@ def handle(self, *args, **options):
tzinfo=pytz.UTC
)
self.log_debug('Current date = %s', current_date.isoformat())
+ override_recipient_email = options.get('override_recipient_email')
+ override_middlewares = options.get('override_middlewares')
- site = Site.objects.get(domain__iexact=options['site_domain_name'])
- self.log_debug('Running for site %s', site.domain)
+ site_domain_name = options['site_domain_name']
+ sites = Site.objects.filter(domain__iexact=site_domain_name) if site_domain_name else Site.objects.all()
- override_recipient_email = options.get('override_recipient_email')
- self.send_emails(site, current_date, override_recipient_email)
+ if sites:
+ for site in sites:
+ self.log_debug('Running for site %s', site.domain)
+ self.send_emails(site, current_date, override_recipient_email, override_middlewares)
+ else:
+ self.log_info("No matching site found")
- def enqueue(self, day_offset, site, current_date, override_recipient_email=None):
+ def enqueue(self, day_offset, site, current_date, override_recipient_email=None, override_middlewares=None):
self.async_send_task.enqueue(
site,
current_date,
day_offset,
override_recipient_email,
+ override_middlewares,
)
def send_emails(self, *args, **kwargs):
diff --git a/openedx/core/djangoapps/schedules/management/commands/tests/test_send_email_base_command.py b/openedx/core/djangoapps/schedules/management/commands/tests/test_send_email_base_command.py
index 47ba67cc0999..11b33d585195 100644
--- a/openedx/core/djangoapps/schedules/management/commands/tests/test_send_email_base_command.py
+++ b/openedx/core/djangoapps/schedules/management/commands/tests/test_send_email_base_command.py
@@ -10,6 +10,7 @@
import ddt
import pytz
from django.conf import settings
+from django.contrib.sites.models import Site
from openedx.core.djangoapps.schedules.management.commands import SendEmailBaseCommand
from openedx.core.djangoapps.site_configuration.tests.factories import SiteConfigurationFactory, SiteFactory
@@ -33,9 +34,23 @@ def test_handle(self):
send_emails.assert_called_once_with(
self.site,
datetime.datetime(2017, 9, 29, tzinfo=pytz.UTC),
+ None,
None
)
+ def test_handle_all_sites(self):
+ with patch.object(self.command, 'send_emails') as send_emails:
+ self.command.handle(site_domain_name=None, date='2017-09-29')
+ expected_sites = Site.objects.all()
+ for expected_site in expected_sites:
+ send_emails.assert_any_call(
+ expected_site,
+ datetime.datetime(2017, 9, 29, tzinfo=pytz.UTC),
+ None,
+ None
+ )
+ assert send_emails.call_count == len(expected_sites)
+
def test_weeks_option(self):
with patch.object(self.command, 'enqueue') as enqueue:
self.command.handle(site_domain_name=self.site.domain, date='2017-09-29', weeks=12)
diff --git a/openedx/core/djangoapps/schedules/tasks.py b/openedx/core/djangoapps/schedules/tasks.py
index 6a999d3dd8ed..bdb510a3348c 100644
--- a/openedx/core/djangoapps/schedules/tasks.py
+++ b/openedx/core/djangoapps/schedules/tasks.py
@@ -20,6 +20,7 @@
set_custom_attribute
)
from eventtracking import tracker
+from importlib import import_module
from opaque_keys.edx.keys import CourseKey
from openedx.core.djangoapps.content.course_overviews.models import CourseOverview
@@ -103,7 +104,7 @@ class BinnedScheduleMessageBaseTask(ScheduleMessageBaseTask):
task_instance = None
@classmethod
- def enqueue(cls, site, current_date, day_offset, override_recipient_email=None): # lint-amnesty, pylint: disable=missing-function-docstring
+ def enqueue(cls, site, current_date, day_offset, override_recipient_email=None, override_middlewares=None): # lint-amnesty, pylint: disable=missing-function-docstring
set_code_owner_attribute_from_module(__name__)
current_date = resolvers._get_datetime_beginning_of_day(current_date) # lint-amnesty, pylint: disable=protected-access
@@ -120,6 +121,7 @@ def enqueue(cls, site, current_date, day_offset, override_recipient_email=None):
day_offset,
bin,
override_recipient_email,
+ override_middlewares,
)
cls.log_info('Launching task with args = %r', task_args)
cls.task_instance.apply_async(
@@ -128,16 +130,17 @@ def enqueue(cls, site, current_date, day_offset, override_recipient_email=None):
)
def run( # lint-amnesty, pylint: disable=arguments-differ
- self, site_id, target_day_str, day_offset, bin_num, override_recipient_email=None,
+ self, site_id, target_day_str, day_offset, bin_num, override_recipient_email=None, override_middlewares=None,
):
set_code_owner_attribute_from_module(__name__)
site = Site.objects.select_related('configuration').get(id=site_id)
- with emulate_http_request(site=site):
+ middlewares = [self.class_from_classpath(cls) for cls in override_middlewares] if override_middlewares else None
+ with emulate_http_request(site=site, middleware_classes=middlewares) as request:
msg_type = self.make_message_type(day_offset)
- _annotate_for_monitoring(msg_type, site, bin_num, target_day_str, day_offset)
+ _annotate_for_monitoring(msg_type, request.site, bin_num, target_day_str, day_offset)
return self.resolver( # lint-amnesty, pylint: disable=not-callable
self.async_send_task,
- site,
+ request.site,
deserialize(target_day_str),
day_offset,
bin_num,
@@ -147,6 +150,11 @@ def run( # lint-amnesty, pylint: disable=arguments-differ
def make_message_type(self, day_offset):
raise NotImplementedError
+ def class_from_classpath(self, class_path):
+ module_name, klass = class_path.rsplit('.', 1)
+ module = import_module(module_name)
+ return getattr(module, klass)
+
@shared_task(base=LoggedTask, ignore_result=True)
@set_code_owner_attribute
diff --git a/openedx/core/djangoapps/user_authn/views/login.py b/openedx/core/djangoapps/user_authn/views/login.py
index 266dec73b755..a2141104505a 100644
--- a/openedx/core/djangoapps/user_authn/views/login.py
+++ b/openedx/core/djangoapps/user_authn/views/login.py
@@ -81,7 +81,7 @@ def _do_third_party_auth(request):
try:
return pipeline.get_authenticated_user(requested_provider, username, third_party_uid)
- except USER_MODEL.DoesNotExist:
+ except USER_MODEL.DoesNotExist as err:
AUDIT_LOG.info(
"Login failed - user with username {username} has no social auth "
"with backend_name {backend_name}".format(username=username, backend_name=backend_name)
@@ -102,7 +102,13 @@ def _do_third_party_auth(request):
register_label_strong=HTML("{register_text}").format(register_text=_("Register")),
)
- raise AuthFailedError(message, error_code="third-party-auth-with-no-linked-account") # lint-amnesty, pylint: disable=raise-missing-from
+ redirect_url = configuration_helpers.get_value('OC_REDIRECT_ON_TPA_UNLINKED_ACCOUNT', None)
+
+ raise AuthFailedError(
+ message,
+ error_code='third-party-auth-with-no-linked-account',
+ redirect_url=redirect_url
+ ) from err
def _get_user_by_email(email):
diff --git a/openedx/core/djangoapps/user_authn/views/logout.py b/openedx/core/djangoapps/user_authn/views/logout.py
index 616b792b9f22..083e8d4ccb0d 100644
--- a/openedx/core/djangoapps/user_authn/views/logout.py
+++ b/openedx/core/djangoapps/user_authn/views/logout.py
@@ -8,7 +8,6 @@
import nh3
from django.conf import settings
from django.contrib.auth import logout
-from django.shortcuts import redirect
from django.utils.http import urlencode
from django.views.generic import TemplateView
from oauth2_provider.models import Application
@@ -47,7 +46,13 @@ def target(self):
If a redirect_url is specified in the querystring for this request, and the value is a safe
url for redirect, the view will redirect to this page after rendering the template.
If it is not specified, we will use the default target url.
+ Redirect to tpa_logout_url if TPA_AUTOMATIC_LOGOUT_ENABLED is set to True and if
+ tpa_logout_url is configured.
"""
+
+ if getattr(settings, 'TPA_AUTOMATIC_LOGOUT_ENABLED', False) and self.tpa_logout_url:
+ return self.tpa_logout_url
+
target_url = self.request.GET.get('redirect_url') or self.request.GET.get('next')
# Some third party apps do not build URLs correctly and send next query param without URL-encoding, resulting
@@ -85,16 +90,6 @@ def dispatch(self, request, *args, **kwargs):
mark_user_change_as_expected(None)
- # Redirect to tpa_logout_url if TPA_AUTOMATIC_LOGOUT_ENABLED is set to True and if
- # tpa_logout_url is configured.
- #
- # NOTE: This step skips rendering logout.html, which is used to log the user out from the
- # different IDAs. To ensure the user is logged out of all the IDAs be sure to redirect
- # back to /logout after logging out of the TPA.
- if getattr(settings, 'TPA_AUTOMATIC_LOGOUT_ENABLED', False):
- if self.tpa_logout_url:
- return redirect(self.tpa_logout_url)
-
return response
def _build_logout_url(self, url):
diff --git a/openedx/core/djangoapps/user_authn/views/tests/test_logout.py b/openedx/core/djangoapps/user_authn/views/tests/test_logout.py
index c59969c2d00d..77c21c86e1b1 100644
--- a/openedx/core/djangoapps/user_authn/views/tests/test_logout.py
+++ b/openedx/core/djangoapps/user_authn/views/tests/test_logout.py
@@ -211,8 +211,10 @@ def test_automatic_tpa_logout_url_redirect(self):
mock_idp_logout_url.return_value = idp_logout_url
self._authenticate_with_oauth(client)
response = self.client.get(reverse('logout'))
- assert response.status_code == 302
- assert response.url == idp_logout_url
+ expected = {
+ 'target': idp_logout_url,
+ }
+ self.assertDictContainsSubset(expected, response.context_data)
@mock.patch('django.conf.settings.TPA_AUTOMATIC_LOGOUT_ENABLED', True)
def test_no_automatic_tpa_logout_without_logout_url(self):
diff --git a/openedx/core/lib/celery/task_utils.py b/openedx/core/lib/celery/task_utils.py
index 9a54f1b3a550..10a3809fa651 100644
--- a/openedx/core/lib/celery/task_utils.py
+++ b/openedx/core/lib/celery/task_utils.py
@@ -45,7 +45,7 @@ def emulate_http_request(site=None, user=None, middleware_classes=None):
_run_method_if_implemented(middleware, 'process_request', request)
try:
- yield
+ yield request
except Exception as exc:
for middleware in reversed(middleware_instances):
_run_method_if_implemented(middleware, 'process_exception', request, exc)