diff --git a/RELEASE.rst b/RELEASE.rst index 32757f41e6..5245bddbcd 100644 --- a/RELEASE.rst +++ b/RELEASE.rst @@ -1,6 +1,25 @@ Release Notes ============= +Version 0.136.1 +--------------- + +- Remove social auth (#3127) +- REmoved banner and footer (#3120) +- Handle blank signatories in migrate_edx_data (#3125) +- add a management command backfill_certificate_page_revision (#3115) +- fix: Fix LTI based duplicate email users (#3098) +- fix: sync keycloak users to hubspot on create (#3119) +- fix(deps): update dependency social-auth-app-django to v5.6.0 [security] (#3005) +- Add multiple items to cart (#3117) +- Upgrade django to 5.1.15 (#3108) +- Removed option related to unused pgbouncer (#3118) +- Fix close button on add to cart confirm modal (#3114) +- Update retrieve user api to be versioned (#3102) +- Fix a 500 error when previewing CertificatePage from CMS (#3116) +- Hubspot course name fix (#3100) +- Fix b2b management commands (#3113) + Version 0.135.6 (Released December 01, 2025) --------------- diff --git a/authentication/backends/apisix_remote_user_org.py b/authentication/backends/apisix_remote_user_org.py index e3b711ccae..4a097f4e6d 100644 --- a/authentication/backends/apisix_remote_user_org.py +++ b/authentication/backends/apisix_remote_user_org.py @@ -10,6 +10,7 @@ from mitol.apigateway.backends import ApisixRemoteUserBackend from b2b.api import reconcile_user_orgs +from hubspot_sync.task_helpers import sync_hubspot_user log = logging.getLogger(__name__) @@ -42,4 +43,12 @@ def configure_user(self, request, user, *args, created=True): # Task should check to see if it needs to run or not reconcile_user_orgs(user, org_uuids) + if created: + log.info( + "New user created via APISIX/Keycloak, syncing to HubSpot: user_id=%s, email=%s", + user.id, + user.email, + ) + sync_hubspot_user(user) + return user diff --git a/authentication/backends/ol_open_id_connect.py b/authentication/backends/ol_open_id_connect.py deleted file mode 100644 index 440a3b6845..0000000000 --- a/authentication/backends/ol_open_id_connect.py +++ /dev/null @@ -1,37 +0,0 @@ -"""Keycloak Authentication Configuration""" - -import logging - -from social_core.backends.open_id_connect import OpenIdConnectAuth - -log = logging.getLogger(__name__) - - -class OlOpenIdConnectAuth(OpenIdConnectAuth): - """ - Custom wrapper class for adding additional functionality to the - OpenIdConnectAuth child class. - """ - - name = "ol-oidc" - REQUIRES_EMAIL_VALIDATION = False - - def get_user_details(self, response): - """Get the user details from the API response""" - details = super().get_user_details(response) - - return { - **details, - "global_id": response.get("sub", None), - "name": response.get("name", None), - "is_active": True, - "profile": { - "name": response.get("name", ""), - "email_optin": bool(int(response["email_optin"])) - if "email_optin" in response - else None, - }, - } - - def __str__(self): - return "OL OpenID Connect (ol-oidc)" diff --git a/authentication/exceptions.py b/authentication/exceptions.py deleted file mode 100644 index 32ae7712c7..0000000000 --- a/authentication/exceptions.py +++ /dev/null @@ -1,98 +0,0 @@ -"""Authentication exceptions""" - -from social_core.exceptions import AuthException - - -class RequireProviderException(AuthException): - """The user is required to authenticate via a specific provider/backend""" - - def __init__(self, backend, social_auth): - """ - Args: - social_auth (social_django.models.UserSocialAuth): A social auth objects - """ - self.social_auth = social_auth - super().__init__(backend) - - -class PartialException(AuthException): - """Partial pipeline exception""" - - def __init__(self, backend, partial, errors=None, field_errors=None): - self.partial = partial - self.errors = errors - self.field_errors = field_errors - super().__init__(backend) - - -class InvalidPasswordException(PartialException): - """Provided password was invalid""" - - def __str__(self): - return "Unable to login with that email and password combination" - - -class RequireEmailException(PartialException): - """Authentication requires an email""" - - def __str__(self): - return "Email is required to login" - - -class RequireRegistrationException(PartialException): - """Authentication requires registration""" - - def __str__(self): - return "There is no account with that email" - - -class RequirePasswordException(PartialException): - """Authentication requires a password""" - - def __str__(self): - return "Password is required to login" - - -class EmailBlockedException(PartialException): - """Raised if a user's email is marked blocked""" - - def __str__(self): - return "Email address is marked blocked" - - -class RequirePasswordAndPersonalInfoException(PartialException): - """Authentication requires a password and address""" - - def __str__(self): - return "Password and address need to be filled out" - - -class RequireProfileException(PartialException): - """Authentication requires a profile""" - - def __str__(self): - return "Profile needs to be filled out" - - -class UnexpectedExistingUserException(PartialException): - """Raised if a user already exists but shouldn't in the given pipeline step""" - - -class UserCreationFailedException(PartialException): - """Raised if user creation with a generated username failed""" - - -class UserExportBlockedException(AuthException): - """The user is blocked for export reasons from continuing to sign up""" - - def __init__(self, backend, reason_code): - super().__init__(backend) - self.reason_code = reason_code - - -class UserTryAgainLaterException(AuthException): - """The user should try to register again later""" - - -class UserMissingSocialAuthException(Exception): # noqa: N818 - """Raised if the user doesn't have a social auth""" diff --git a/authentication/middleware.py b/authentication/middleware.py deleted file mode 100644 index 0bd31ea87a..0000000000 --- a/authentication/middleware.py +++ /dev/null @@ -1,36 +0,0 @@ -"""Authentication middleware""" - -from urllib.parse import quote - -from django.shortcuts import redirect -from social_core.exceptions import SocialAuthBaseException -from social_django.middleware import SocialAuthExceptionMiddleware - - -class SocialAuthExceptionRedirectMiddleware(SocialAuthExceptionMiddleware): - """ - This middleware subclasses SocialAuthExceptionMiddleware and overrides - process_exception to provide an implementation that does not use - django.contrib.messages and instead only issues a redirect - """ - - def process_exception(self, request, exception): - """ - Note: this is a subset of the SocialAuthExceptionMiddleware implementation - """ - strategy = getattr(request, "social_strategy", None) - if strategy is None or self.raise_exception(request, exception): - return None - - if isinstance(exception, SocialAuthBaseException): # noqa: RET503 - backend = getattr(request, "backend", None) - backend_name = getattr(backend, "name", "unknown-backend") - - message = self.get_message(request, exception) - url = self.get_redirect_uri(request, exception) - - if url: # noqa: RET503 - url += ( - ("?" in url and "&") or "?" - ) + f"message={quote(message)}&backend={backend_name}" - return redirect(url) diff --git a/authentication/middleware_test.py b/authentication/middleware_test.py deleted file mode 100644 index ea1f3bf930..0000000000 --- a/authentication/middleware_test.py +++ /dev/null @@ -1,61 +0,0 @@ -"""Tests for auth middleware""" - -from urllib.parse import quote - -from django.contrib.sessions.middleware import SessionMiddleware -from django.shortcuts import reverse -from rest_framework import status -from social_core.exceptions import AuthAlreadyAssociated -from social_django.utils import load_backend, load_strategy - -from authentication.middleware import SocialAuthExceptionRedirectMiddleware - - -def test_process_exception_no_strategy(mocker, rf, settings): - """Tests that if the request has no strategy it does nothing""" - settings.DEBUG = False - get_response = mocker.MagicMock() - request = rf.get(reverse("social:complete", args=("email",))) - middleware = SocialAuthExceptionRedirectMiddleware(get_response) - assert middleware.process_exception(request, None) is None - - -def test_process_exception(mocker, rf, settings): - """Tests that a process_exception handles auth exceptions correctly""" - settings.DEBUG = False - request = rf.get(reverse("social:complete", args=("email",))) - # social_django depends on request.sesssion, so use the middleware to set that - get_response = mocker.MagicMock() - SessionMiddleware(get_response).process_request(request) - strategy = load_strategy(request) - backend = load_backend(strategy, "email", None) - request.social_strategy = strategy - request.backend = backend - - middleware = SocialAuthExceptionRedirectMiddleware(get_response) - error = AuthAlreadyAssociated(backend) - result = middleware.process_exception(request, error) - assert result.status_code == status.HTTP_302_FOUND - assert ( - result.url - == f"{reverse(settings.LOGIN_URL)}?message={quote(error.__str__())}&backend={backend.name}" - ) - - -def test_process_exception_non_auth_error(mocker, rf, settings): - """Tests that a process_exception handles non-auth exceptions correctly""" - settings.DEBUG = False - request = rf.get(reverse("social:complete", args=("email",))) - # social_django depends on request.sesssion, so use the middleware to set that - get_response = mocker.MagicMock() - SessionMiddleware(get_response).process_request(request) - strategy = load_strategy(request) - backend = load_backend(strategy, "email", None) - request.social_strategy = strategy - request.backend = backend - - middleware = SocialAuthExceptionRedirectMiddleware(get_response) - assert ( - middleware.process_exception(request, Exception("something bad happened")) - is None - ) diff --git a/authentication/pipeline/__init__.py b/authentication/pipeline/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/authentication/pipeline/user.py b/authentication/pipeline/user.py deleted file mode 100644 index f2655136e6..0000000000 --- a/authentication/pipeline/user.py +++ /dev/null @@ -1,288 +0,0 @@ -"""Auth pipline functions for user authentication""" - -import logging - -from django.db import IntegrityError -from mitol.common.utils import dict_without_keys -from social_core.backends.email import EmailAuth -from social_core.exceptions import AuthException -from social_core.pipeline.partial import partial -from social_core.pipeline.user import create_user - -from authentication.backends.ol_open_id_connect import OlOpenIdConnectAuth -from authentication.exceptions import ( - EmailBlockedException, - InvalidPasswordException, - RequireEmailException, - RequirePasswordAndPersonalInfoException, - RequirePasswordException, - RequireRegistrationException, - UnexpectedExistingUserException, - UserCreationFailedException, -) -from authentication.utils import SocialAuthState, is_user_email_blocked -from openedx import api as openedx_api -from openedx import tasks as openedx_tasks -from users.serializers import UserSerializer - -log = logging.getLogger() - -CREATE_OPENEDX_USER_RETRY_DELAY = 60 -NAME_MIN_LENGTH = 2 - -# pylint: disable=keyword-arg-before-vararg - - -def forbid_hijack(strategy, backend, **kwargs): # pylint: disable=unused-argument # noqa: ARG001 - """ - Forbid an admin user from trying to login/register while hijacking another user - - Args: - strategy (social_django.strategy.DjangoStrategy): the strategy used to authenticate - backend (social_core.backends.base.BaseAuth): the backend being used to authenticate - """ - # As first step in pipeline, stop a hijacking admin from going any further - if bool(strategy.session_get("hijack_history")): - raise AuthException("You are hijacking another user, don't try to login again") # noqa: EM101 - return {} - - -# Pipeline steps for OIDC logins - - -def create_ol_oidc_user(strategy, details, backend, user=None, *args, **kwargs): - """ - Create the user if we're using the ol-oidc backend. - - This also does a blocked user check and makes sure there's an email address. - If the created user is new, we make sure they're set active. (If the user is - inactive, they'll get knocked out of the pipeline elsewhere.) - """ - - if backend.name != OlOpenIdConnectAuth.name: - return {} - - if "email" not in details: - raise RequireEmailException(backend, None) - - if "email" in details and is_user_email_blocked(details["email"]): - raise EmailBlockedException(backend, None) - - retval = create_user(strategy, details, backend, user, *args, **kwargs) - - # When we have deprecated direct login, remove this and default the is_active - # flag to True in the User model. - if retval.get("is_new"): - retval["user"].is_active = True - retval["user"].save() - - return retval - - -# Pipeline steps for email logins - - -def validate_email_auth_request(strategy, backend, user=None, *args, **kwargs): # pylint: disable=unused-argument # noqa: ARG001 - """ - Validates an auth request for email - - Args: - strategy (social_django.strategy.DjangoStrategy): the strategy used to authenticate - backend (social_core.backends.base.BaseAuth): the backend being used to authenticate - user (User): the current user - """ - if backend.name != EmailAuth.name: - return {} - - # if there's a user, force this to be a login - if user is not None: - return {"flow": SocialAuthState.FLOW_LOGIN} - - return {} - - -def get_username(strategy, backend, user=None, details=None, *args, **kwargs): # pylint: disable=unused-argument # noqa: ARG001 - """ - Gets the username for a user - - Args: - strategy (social_django.strategy.DjangoStrategy): the strategy used to authenticate - backend (social_core.backends.base.BaseAuth): the backend being used to authenticate - user (User): the current user - """ - - if backend and backend.name == OlOpenIdConnectAuth.name: - return {"username": details["username"] if not user else user.edx_username} - - return {"username": None if not user else strategy.storage.user.get_username(user)} - - -@partial -def create_user_via_email( - strategy, - backend, - user=None, - flow=None, - current_partial=None, - *args, # noqa: ARG001 - **kwargs, -): # pylint: disable=too-many-arguments,unused-argument - """ - Creates a new user if needed and sets the password and name. - Args: - strategy (social_django.strategy.DjangoStrategy): the strategy used to authenticate - backend (social_core.backends.base.BaseAuth): the backend being used to authenticate - user (User): the current user - details (dict): Dict of user details - flow (str): the type of flow (login or register) - current_partial (Partial): the partial for the step in the pipeline - - Raises: - RequirePasswordAndPersonalInfoException: if the user hasn't set password or name - """ - if backend.name != EmailAuth.name or flow != SocialAuthState.FLOW_REGISTER: - return {} - - if user is not None: - raise UnexpectedExistingUserException(backend, current_partial) - - context = {} - data = strategy.request_data().copy() - expected_data_fields = {"name", "password", "username"} - if any(field for field in expected_data_fields if field not in data): - raise RequirePasswordAndPersonalInfoException(backend, current_partial) - if len(data.get("name", 0)) < NAME_MIN_LENGTH: - raise RequirePasswordAndPersonalInfoException( - backend, - current_partial, - errors=["Full name must be at least 2 characters long."], - ) - - data["email"] = kwargs.get("email", kwargs.get("details", {}).get("email")) - data["is_active"] = True - serializer = UserSerializer(data=data, context=context) - - if not serializer.is_valid(): - e = RequirePasswordAndPersonalInfoException( - backend, - current_partial, - errors=serializer.errors.get("non_field_errors"), - field_errors=dict_without_keys(serializer.errors, "non_field_errors"), - ) - - raise e - - try: - created_user = serializer.save() - except IntegrityError: - # 'email' and 'username' are the only unique fields that can be supplied by the user at this point, and a user - # cannot reach this point of the auth flow without a unique email, so we know that the IntegrityError is caused - # by the username not being unique. - username = data["username"] - raise RequirePasswordAndPersonalInfoException( # noqa: B904 - backend, - current_partial, - field_errors={ - "username": f"The username '{username}' is already taken. Please try a different username." - }, - ) - except Exception as exc: - raise UserCreationFailedException(backend, current_partial) from exc - - return {"is_new": True, "user": created_user, "username": created_user.edx_username} - - -@partial -def validate_email( - strategy, - backend, - user=None, # noqa: ARG001 - flow=None, # noqa: ARG001 - current_partial=None, - *args, # noqa: ARG001 - **kwargs, # noqa: ARG001 -): # pylint: disable=unused-argument - """ - Validates a user's email for register - - Args: - strategy (social_django.strategy.DjangoStrategy): the strategy used to authenticate - backend (social_core.backends.base.BaseAuth): the backend being used to authenticate - user (User): the current user - flow (str): the type of flow (login or register) - current_partial (Partial): the partial for the step in the pipeline - - Raises: - EmailBlockedException: if the user email is blocked - """ - data = strategy.request_data() - authentication_flow = data.get("flow") - if authentication_flow == SocialAuthState.FLOW_REGISTER and "email" in data: # noqa: SIM102 - if is_user_email_blocked(data["email"]): - raise EmailBlockedException(backend, current_partial) - return {} - - -@partial -def validate_password( - strategy, - backend, - user=None, - flow=None, - current_partial=None, - *args, # noqa: ARG001 - **kwargs, # noqa: ARG001 -): # pylint: disable=unused-argument - """ - Validates a user's password for login - - Args: - strategy (social_django.strategy.DjangoStrategy): the strategy used to authenticate - backend (social_core.backends.base.BaseAuth): the backend being used to authenticate - user (User): the current user - flow (str): the type of flow (login or register) - current_partial (Partial): the partial for the step in the pipeline - - Raises: - RequirePasswordException: if the user password is not provided - InvalidPasswordException: if the password does not match the user, or the user is not active. - """ - if backend.name != EmailAuth.name or flow != SocialAuthState.FLOW_LOGIN: - return {} - - data = strategy.request_data() - if user is None: - raise RequireRegistrationException(backend, current_partial) - - if "password" not in data: - raise RequirePasswordException(backend, current_partial) - - password = data["password"] - - if not user or not user.check_password(password) or not user.is_active: - raise InvalidPasswordException(backend, current_partial) - - return {} - - -def create_openedx_user(strategy, backend, user=None, is_new=False, **kwargs): # pylint: disable=unused-argument # noqa: FBT002, ARG001 - """ - Create a user in the openedx, deferring a retry via celery if it fails - - Args: - user (users.models.User): the user that was just created - is_new (bool): True if the user was just created - """ - if not is_new or not user.is_active: - return {} - - try: - openedx_api.create_user(user) - except Exception: # pylint: disable=broad-except - log.exception("Error creating openedx user records on User create") - # try again later - openedx_tasks.create_user_from_id.apply_async( - (user.id,), countdown=CREATE_OPENEDX_USER_RETRY_DELAY - ) - - return {} diff --git a/authentication/pipeline/user_test.py b/authentication/pipeline/user_test.py deleted file mode 100644 index ad6335e229..0000000000 --- a/authentication/pipeline/user_test.py +++ /dev/null @@ -1,630 +0,0 @@ -"""Tests of user pipeline actions""" -# pylint: disable=redefined-outer-name - -import faker -import pytest -import responses -from django.contrib.auth import get_user_model -from django.contrib.sessions.middleware import SessionMiddleware -from django.core.exceptions import ObjectDoesNotExist -from django.db import IntegrityError -from rest_framework import status -from social_core.backends.email import EmailAuth -from social_django.utils import load_backend, load_strategy - -from authentication.exceptions import ( - EmailBlockedException, - InvalidPasswordException, - RequirePasswordAndPersonalInfoException, - RequirePasswordException, - RequireRegistrationException, - UnexpectedExistingUserException, - UserCreationFailedException, -) -from authentication.pipeline import user as user_actions -from authentication.utils import SocialAuthState -from openedx.api import OPENEDX_REGISTRATION_VALIDATION_PATH -from users.factories import UserFactory - -User = get_user_model() -FAKE = faker.Faker() - - -@pytest.fixture -def backend_settings(settings): - """A dictionary of settings for the backend""" - return {"USER_FIELDS": settings.SOCIAL_AUTH_EMAIL_USER_FIELDS} - - -@pytest.fixture -def mock_email_backend(mocker, backend_settings): - """Fixture that returns a fake EmailAuth backend object""" - backend = mocker.Mock() - backend.name = "email" - backend.setting.side_effect = lambda key, default, **kwargs: backend_settings.get( # noqa: ARG005 - key, default - ) - return backend - - -@pytest.fixture -def mock_ol_oidc_backend(mocker, backend_settings): - """Fixture that returns a fake OlOpenIdConnectAuth backend object""" - backend = mocker.Mock() - backend.name = "ol-oidc" - backend.setting.side_effect = lambda key, default, **kwargs: backend_settings.get( # noqa: ARG005 - key, default - ) - return backend - - -@pytest.fixture -def mock_create_user_strategy(mocker): - """Fixture that returns a valid strategy for create_user_via_email""" - strategy = mocker.Mock() - strategy.request = mocker.Mock() - strategy.request_data.return_value = { - "name": "Jane Doe", - "password": "password1", - "username": "custom-username", - "legal_address": { - "country": "US", - "state": "US-MA", - }, - } - return strategy - - -@pytest.fixture -def application(settings): - """Test data and settings needed for create_edx_user tests""" - settings.OPENEDX_API_BASE_URL = "http://example.com" - - -def validate_email_auth_request_not_email_backend(mocker): - """Tests that validate_email_auth_request return if not using the email backend""" - mock_strategy = mocker.Mock() - mock_backend = mocker.Mock() - mock_backend.name = "notemail" - assert user_actions.validate_email_auth_request(mock_strategy, mock_backend) == {} - - -@pytest.mark.parametrize( - "has_user,expected", # noqa: PT006 - [(True, {"flow": SocialAuthState.FLOW_LOGIN}), (False, {})], -) -@pytest.mark.django_db -def test_validate_email_auth_request(mocker, rf, has_user, expected): - """Test that validate_email_auth_request returns correctly given the input""" - request = rf.post("/complete/email") - get_response = mocker.MagicMock() - middleware = SessionMiddleware(get_response) - middleware.process_request(request) - request.session.save() - strategy = load_strategy(request) - backend = load_backend(strategy, "email", None) - - user = UserFactory.create() if has_user else None - - assert ( - user_actions.validate_email_auth_request( - strategy, backend, pipeline_index=0, user=user - ) - == expected - ) - - -def test_get_username(mocker, user): - """Tests that we get a username for a new user""" - mock_strategy = mocker.Mock() - mock_strategy.storage.user.get_username.return_value = user.edx_username - assert user_actions.get_username(mock_strategy, None, user) == { - "username": user.edx_username - } - mock_strategy.storage.user.get_username.assert_called_once_with(user) - - -def test_get_username_no_user(mocker): - """Tests that get_username returns None if there is no User""" - mock_strategy = mocker.Mock() - assert user_actions.get_username(mock_strategy, None, None)["username"] is None - mock_strategy.storage.user.get_username.assert_not_called() - - -def test_user_password_not_email_backend(mocker): - """Tests that user_password return if not using the email backend""" - mock_strategy = mocker.MagicMock() - mock_user = mocker.Mock() - mock_backend = mocker.Mock() - mock_backend.name = "notemail" - assert ( - user_actions.validate_password( - mock_strategy, - mock_backend, - pipeline_index=0, - user=mock_user, - flow=SocialAuthState.FLOW_LOGIN, - ) - == {} - ) - # make sure we didn't update or check the password - mock_user.set_password.assert_not_called() - mock_user.save.assert_not_called() - mock_user.check_password.assert_not_called() - - -@pytest.mark.parametrize("user_password", ["abc123", "def456"]) -def test_user_password_login(mocker, rf, user, user_password): - """Tests that user_password works for login case""" - request_password = "abc123" # noqa: S105 - user.set_password(user_password) - user.save() - request = rf.post( - "/complete/email", {"password": request_password, "email": user.email} - ) - get_response = mocker.MagicMock() - middleware = SessionMiddleware(get_response) - middleware.process_request(request) - request.session.save() - strategy = load_strategy(request) - backend = load_backend(strategy, "email", None) - - if request_password == user_password: - assert ( - user_actions.validate_password( - strategy, - backend, - pipeline_index=0, - user=user, - flow=SocialAuthState.FLOW_LOGIN, - ) - == {} - ) - else: - with pytest.raises(InvalidPasswordException): - user_actions.validate_password( - strategy, - backend, - pipeline_index=0, - user=user, - flow=SocialAuthState.FLOW_LOGIN, - ) - - -def test_user_password_not_login(mocker, rf, user): - """ - Tests that user_password performs denies authentication - for an existing user if password not provided regardless of auth_type - """ - user.set_password("abc123") - user.save() - request = rf.post("/complete/email", {"email": user.email}) - get_response = mocker.MagicMock() - middleware = SessionMiddleware(get_response) - middleware.process_request(request) - request.session.save() - strategy = load_strategy(request) - backend = load_backend(strategy, "email", None) - - with pytest.raises(RequirePasswordException): - user_actions.validate_password( - strategy, - backend, - pipeline_index=0, - user=user, - flow=SocialAuthState.FLOW_LOGIN, - ) - - -def test_user_password_not_exists(mocker, rf): - """Tests that user_password raises auth error for nonexistent user""" - request = rf.post( - "/complete/email", {"password": "abc123", "email": "doesntexist@localhost"} - ) - get_response = mocker.MagicMock() - middleware = SessionMiddleware(get_response) - middleware.process_request(request) - request.session.save() - strategy = load_strategy(request) - backend = load_backend(strategy, "email", None) - - with pytest.raises(RequireRegistrationException): - user_actions.validate_password( - strategy, - backend, - pipeline_index=0, - user=None, - flow=SocialAuthState.FLOW_LOGIN, - ) - - -def test_user_not_active(mocker, rf, user): - """Tests that an inactive user raises auth error, InvalidPasswordException""" - user.set_password("abc123") - user.is_active = False - user.save() - request = rf.post("/complete/email", {"password": "abc123", "email": user.email}) - get_response = mocker.MagicMock() - middleware = SessionMiddleware(get_response) - middleware.process_request(request) - request.session.save() - strategy = load_strategy(request) - backend = load_backend(strategy, "email", None) - - with pytest.raises(InvalidPasswordException): - user_actions.validate_password( - strategy, - backend, - pipeline_index=0, - user=user, - flow=SocialAuthState.FLOW_LOGIN, - ) - - -@pytest.mark.parametrize( - "backend_name,flow", # noqa: PT006 - [ - ("notemail", None), - ("notemail", SocialAuthState.FLOW_REGISTER), - ("notemail", SocialAuthState.FLOW_LOGIN), - (EmailAuth.name, None), - (EmailAuth.name, SocialAuthState.FLOW_LOGIN), - ], -) -def test_create_user_via_email_exit(mocker, backend_name, flow): - """ - Tests that create_user_via_email returns if not using the email backend and attempting the - 'register' step of the auth flow - """ - mock_strategy = mocker.Mock() - mock_backend = mocker.Mock() - mock_backend.name = backend_name - assert ( - user_actions.create_user_via_email( - mock_strategy, mock_backend, pipeline_index=0, flow=flow - ) - == {} - ) - - mock_strategy.request_data.assert_not_called() - - -@responses.activate -@pytest.mark.django_db -def test_create_user_via_email( - mocker, mock_email_backend, mock_create_user_strategy, settings -): - """ - Tests that create_user_via_email creates a user via social_core.pipeline.user.create_user_via_email - and sets a name and password - """ - responses.add( - responses.POST, - settings.OPENEDX_API_BASE_URL + OPENEDX_REGISTRATION_VALIDATION_PATH, - json={"validation_decisions": {"username": "", "email": ""}}, - status=status.HTTP_200_OK, - ) - email = "user@example.com" - response = user_actions.create_user_via_email( - mock_create_user_strategy, - mock_email_backend, - details=dict(email=email), # noqa: C408 - pipeline_index=0, - flow=SocialAuthState.FLOW_REGISTER, - ) - assert isinstance(response["user"], User) is True - assert response["user"].edx_username == "custom-username" - assert response["user"].is_active is True - assert response["username"] == "custom-username" - assert response["is_new"] is True - - -@pytest.mark.django_db -def test_create_user_via_email_no_data(mocker, mock_email_backend): - """Tests that create_user_via_email raises an error if no data for name and password provided""" - mock_strategy = mocker.Mock() - mock_strategy.request_data.return_value = {} - with pytest.raises(RequirePasswordAndPersonalInfoException): - user_actions.create_user_via_email( - mock_strategy, - mock_email_backend, - pipeline_index=0, - flow=SocialAuthState.FLOW_REGISTER, - ) - - -@pytest.mark.django_db -def test_create_user_via_email_with_shorter_name(mocker, mock_email_backend): - """Tests that create_user_via_email raises an error if name field is shorter than 2 characters""" - mock_strategy = mocker.Mock() - mock_strategy.request_data.return_value = { - "name": "a", - "password": "password1", - "username": "custom-username", - "legal_address": { - "country": "US", - }, - } - - with pytest.raises(RequirePasswordAndPersonalInfoException) as exc: - user_actions.create_user_via_email( - mock_strategy, - mock_email_backend, - details=dict(email="test@example.com"), # noqa: C408 - pipeline_index=0, - flow=SocialAuthState.FLOW_REGISTER, - ) - - assert exc.value.errors == ["Full name must be at least 2 characters long."] - - -@pytest.mark.django_db -def test_create_user_via_email_existing_user_raises( - user, mock_email_backend, mock_create_user_strategy -): - """Tests that create_user_via_email raises an error if a user already exists in the pipeline""" - with pytest.raises(UnexpectedExistingUserException): - user_actions.create_user_via_email( - mock_create_user_strategy, - mock_email_backend, - user=user, - pipeline_index=0, - flow=SocialAuthState.FLOW_REGISTER, - ) - - -def test_create_user_via_email_create_fail( - mocker, mock_email_backend, mock_create_user_strategy -): - """Tests that create_user_via_email raises an error if user creation fails""" - mock_serializer_obj = mocker.Mock() - mock_serializer_obj.is_valid = mocker.Mock(return_value=True) - mock_serializer_obj.save = mocker.Mock(side_effect=ValueError) - patched_user_serializer = mocker.patch( - "authentication.pipeline.user.UserSerializer", return_value=mock_serializer_obj - ) - with pytest.raises(UserCreationFailedException): - user_actions.create_user_via_email( - mock_create_user_strategy, - mock_email_backend, - details=dict(email="someuser@example.com"), # noqa: C408 - pipeline_index=0, - flow=SocialAuthState.FLOW_REGISTER, - ) - patched_user_serializer.assert_called_once() - - -def test_create_user_via_email_validation( - mocker, mock_email_backend, mock_create_user_strategy -): - """Tests that create_user_via_email raises an exception if serializer validation fails""" - mock_serializer_obj = mocker.Mock() - mock_serializer_obj.is_valid = mocker.Mock(return_value=False) - mock_serializer_obj.errors = { - "non_field_errors": ["non field error"], - "username": "Invalid username", - } - patched_user_serializer = mocker.patch( - "authentication.pipeline.user.UserSerializer", return_value=mock_serializer_obj - ) - with pytest.raises(RequirePasswordAndPersonalInfoException) as exc: - user_actions.create_user_via_email( - mock_create_user_strategy, - mock_email_backend, - details=dict(email="someuser@example.com"), # noqa: C408 - pipeline_index=0, - flow=SocialAuthState.FLOW_REGISTER, - ) - patched_user_serializer.assert_called_once() - assert exc.value.errors == mock_serializer_obj.errors["non_field_errors"] - assert exc.value.field_errors == {"username": "Invalid username"} - - -@pytest.mark.django_db -def test_create_user_via_email_unique( - mocker, mock_email_backend, mock_create_user_strategy -): - """Tests that create_user_via_email raises an exception the given username is not unique""" - email = "user@example.com" - username = mock_create_user_strategy.request_data.return_value["username"] - mock_serializer_obj = mocker.Mock() - mock_serializer_obj.is_valid = mocker.Mock(return_value=True) - mock_serializer_obj.save = mocker.Mock(side_effect=IntegrityError) - patched_user_serializer = mocker.patch( - "authentication.pipeline.user.UserSerializer", return_value=mock_serializer_obj - ) - with pytest.raises(RequirePasswordAndPersonalInfoException) as exc: - user_actions.create_user_via_email( - mock_create_user_strategy, - mock_email_backend, - details=dict(email=email), # noqa: C408 - pipeline_index=0, - flow=SocialAuthState.FLOW_REGISTER, - ) - patched_user_serializer.assert_called_once() - assert exc.value.field_errors == { - "username": f"The username '{username}' is already taken. Please try a different username." - } - - -@pytest.mark.parametrize("hijacked", [True, False]) -def test_forbid_hijack(mocker, hijacked): - """ - Tests that forbid_hijack action raises an exception if a user is hijacked - """ - mock_strategy = mocker.Mock() - mock_strategy.session_get.return_value = hijacked - - mock_backend = mocker.Mock(name="email") - - args = [mock_strategy, mock_backend] - kwargs = {"flow": SocialAuthState.FLOW_LOGIN} - - if hijacked: - with pytest.raises(ValueError): # noqa: PT011 - user_actions.forbid_hijack(*args, **kwargs) - else: - assert user_actions.forbid_hijack(*args, **kwargs) == {} - - -@pytest.mark.parametrize("raises_error", [True, False]) -@pytest.mark.parametrize( - "is_active, is_new, creates_records", # noqa: PT006 - [ - [True, True, True], # noqa: PT007 - [True, False, False], # noqa: PT007 - [False, True, False], # noqa: PT007 - [False, False, False], # noqa: PT007 - ], -) -def test_create_openedx_user( # noqa: PLR0913 - mocker, user, raises_error, is_active, is_new, creates_records -): # pylint: disable=too-many-arguments - """Test that activate_user takes the correct action""" - user.is_active = is_active - - mock_create_user_api = mocker.patch( - "authentication.pipeline.user.openedx_api.create_user" - ) - if raises_error: - mock_create_user_api.side_effect = Exception("error") - mock_create_user_task = mocker.patch( - "authentication.pipeline.user.openedx_tasks.create_user_from_id" - ) - - assert user_actions.create_openedx_user(None, None, user=user, is_new=is_new) == {} - - if creates_records: - mock_create_user_api.assert_called_once_with(user) - - if raises_error: - mock_create_user_task.apply_async.assert_called_once_with( - (user.id,), countdown=60 - ) - else: - mock_create_user_task.apply_async.assert_not_called() - else: - mock_create_user_api.assert_not_called() - mock_create_user_task.apply_async.assert_not_called() - - -@pytest.mark.parametrize( - "backend_name,flow,data", # noqa: PT006 - [ - ("notemail", SocialAuthState.FLOW_REGISTER, {}), - ("notemail", SocialAuthState.FLOW_LOGIN, dict(email="test@example.com")), # noqa: C408 - ], -) -def test_validate_email_backend(mocker, backend_name, flow, data): - """Tests validate_email with data and flows""" - mock_strategy = mocker.Mock() - mock_backend = mocker.Mock() - mock_backend.name = backend_name - mock_strategy.request_data.return_value = data - assert ( - user_actions.validate_email( - mock_strategy, mock_backend, pipeline_index=0, flow=flow - ) - == {} - ) - - mock_strategy.request_data.assert_called_once() - - -@pytest.mark.django_db -def test_create_user_when_email_blocked(mocker): - """Tests that validate_email raises an error if user email is blocked""" - mock_strategy = mocker.Mock() - mock_email_backend = mocker.Mock() - mock_strategy.request_data.return_value = { - "email": "test@example.com", - "flow": "register", - } - mocker.patch( - "authentication.pipeline.user.is_user_email_blocked", return_value=True - ) - with pytest.raises(EmailBlockedException): - user_actions.validate_email( - mock_strategy, - mock_email_backend, - pipeline_index=0, - flow=SocialAuthState.FLOW_REGISTER, - ) - - -@pytest.mark.django_db -@pytest.mark.parametrize( - ( - "new_user_login", - "use_backend", - ), - [ - (False, "email"), - (False, "oidc"), - (True, "oidc"), - ], -) -def test_create_ol_oidc_user( # noqa: PLR0913 - mocker, - new_user_login, - use_backend, - mock_email_backend, - mock_ol_oidc_backend, - mock_create_user_strategy, -): - """Tests that create_ol_oidc_user creates a new user for an OIDC login""" - - backend = mock_email_backend if use_backend == "email" else mock_ol_oidc_backend - user_global_id = FAKE.uuid4() - base_details = { - "email": "admin@odl.local", - "global_id": user_global_id, - "is_active": True, - "name": "Test Admin", - "username": "admin@odl.local", - } - details = { - **base_details, - "fullname": "Test Admin", - "last_name": "Admin", - "profile": { - "email_optin": None, - "name": "Test Admin", - }, - } - - user = None if new_user_login else UserFactory.create(**base_details) - - strategy = mock_create_user_strategy - strategy.request_data.return_value = { - **strategy.request_data.return_value, - "global_id": user_global_id, - "is_active": True, - } - - if new_user_login: - strategy.create_user = mocker.Mock( - side_effect=lambda *args, **kwargs: UserFactory.create( # noqa: ARG005 - password="fake password", # noqa: S106 - **base_details, - ) - ) - - with pytest.raises(ObjectDoesNotExist): - User.objects.get(global_id=user_global_id) - - response = user_actions.create_ol_oidc_user( - strategy, details, backend, user, pipeline_index=0 - ) - - if use_backend == "oidc": - if new_user_login: - assert response["is_new"] - assert response["user"].global_id == user_global_id - assert strategy.create_user.called - assert User.objects.get(global_id=user_global_id) - else: - assert not response["is_new"] - assert not strategy.create_user.called - else: - assert response == {} diff --git a/authentication/strategy.py b/authentication/strategy.py deleted file mode 100644 index 636d05a7c4..0000000000 --- a/authentication/strategy.py +++ /dev/null @@ -1,21 +0,0 @@ -"""Custom strategy""" - -from social_django.strategy import DjangoStrategy - - -class DjangoRestFrameworkStrategy(DjangoStrategy): - """Strategy specific to handling DRF requests""" - - def __init__(self, storage, drf_request=None, tpl=None): - self.drf_request = drf_request - # pass the original django request to DjangoStrategy - request = drf_request._request # pylint: disable=protected-access # noqa: SLF001 - super().__init__(storage, request=request, tpl=tpl) - - def request_data(self, merge=True): # noqa: FBT002, ARG002 - """Returns the request data""" - if not self.drf_request: - return {} - - # DRF stores json payload data here, not in request.POST or request.GET like PSA expects - return self.drf_request.data diff --git a/authentication/strategy_test.py b/authentication/strategy_test.py deleted file mode 100644 index eaf6a32836..0000000000 --- a/authentication/strategy_test.py +++ /dev/null @@ -1,32 +0,0 @@ -"""Tests for the strategy""" - -from django.http import HttpRequest -from rest_framework.request import Request - -from authentication.utils import load_drf_strategy - - -def test_strategy_init(mocker): - """Test that the constructor works as expected""" - drf_request = mocker.Mock() - strategy = load_drf_strategy(request=drf_request) - assert strategy.drf_request == drf_request - assert strategy.request == drf_request._request # pylint: disable=protected-access # noqa: SLF001 - - -def test_strategy_request_data(mocker): - """Tests that the strategy request_data correctly returns the DRF request data""" - drf_request = mocker.Mock() - strategy = load_drf_strategy(request=drf_request) - assert strategy.request_data() == drf_request.data - - -def test_strategy_clean_authenticate_args(mocker): - """Tests that the strategy clean_authenticate_args moves the request to kwargs""" - # NOTE: don't pass this to load_drf_Strategy, it will error - drf_request = Request(mocker.Mock(spec=HttpRequest)) - strategy = load_drf_strategy(mocker.Mock()) - assert strategy.clean_authenticate_args(drf_request, 2, 3, kwarg1=1, kwarg2=2) == ( - (2, 3), - {"request": drf_request, "kwarg1": 1, "kwarg2": 2}, - ) diff --git a/authentication/utils.py b/authentication/utils.py index e54ef4896f..7a0b17566c 100644 --- a/authentication/utils.py +++ b/authentication/utils.py @@ -2,73 +2,9 @@ import hashlib -from social_core.utils import get_strategy -from social_django.utils import STORAGE - from users.models import BlockList -class SocialAuthState: # pylint: disable=too-many-instance-attributes - """Social auth state""" - - FLOW_REGISTER = "register" - FLOW_LOGIN = "login" - - # login states - STATE_LOGIN_EMAIL = "login/email" - STATE_LOGIN_PASSWORD = "login/password" # noqa: S105 - STATE_LOGIN_PROVIDER = "login/provider" - - # registration states - STATE_REGISTER_EMAIL = "register/email" - STATE_REGISTER_CONFIRM_SENT = "register/confirm-sent" - STATE_REGISTER_CONFIRM = "register/confirm" - STATE_REGISTER_DETAILS = "register/details" - STATE_REGISTER_REQUIRED = "register/required" - - # end states - STATE_SUCCESS = "success" - STATE_ERROR = "error" - STATE_ERROR_TEMPORARY = "error-temporary" - STATE_INACTIVE = "inactive" - STATE_INVALID_EMAIL = "invalid-email" - STATE_USER_BLOCKED = "user-blocked" - STATE_INVALID_LINK = "invalid-link" - STATE_EXISTING_ACCOUNT = "existing-account" - - def __init__( # noqa: PLR0913 - self, - state, - *, - provider=None, - partial=None, - flow=None, - errors=None, - field_errors=None, - redirect_url=None, - user=None, - ): # pylint: disable=too-many-arguments - self.state = state - self.partial = partial - self.flow = flow - self.provider = provider - self.errors = errors or [] - self.field_errors = field_errors or {} - self.redirect_url = redirect_url - self.user = user - - def get_partial_token(self): - """Return the partial token or None""" - return self.partial.token if self.partial else None - - -def load_drf_strategy(request=None): - """Returns the DRF strategy""" - return get_strategy( - "authentication.strategy.DjangoRestFrameworkStrategy", STORAGE, request - ) - - def get_md5_hash(value): """Returns the md5 hash object for the given value""" return hashlib.md5(value.lower().encode("utf-8")) # noqa: S324 diff --git a/authentication/utils_test.py b/authentication/utils_test.py deleted file mode 100644 index 7fc4056eb2..0000000000 --- a/authentication/utils_test.py +++ /dev/null @@ -1,9 +0,0 @@ -"""Utils tests""" - -from authentication.strategy import DjangoRestFrameworkStrategy -from authentication.utils import load_drf_strategy - - -def test_load_drf_strategy(mocker): - """Test that load_drf_strategy returns a DjangoRestFrameworkStrategy instance""" - assert isinstance(load_drf_strategy(mocker.Mock()), DjangoRestFrameworkStrategy) diff --git a/b2b/management/commands/b2b_courseware.py b/b2b/management/commands/b2b_courseware.py index 2ee6d6b992..33881f2e1e 100644 --- a/b2b/management/commands/b2b_courseware.py +++ b/b2b/management/commands/b2b_courseware.py @@ -143,7 +143,7 @@ def handle_add(self, contract, coursewares, **kwargs): # noqa: PLR0915, C901 create_runs = kwargs.pop("create_runs") force_associate = kwargs.pop("force") - can_import = kwargs.pop("import") + can_import = kwargs.pop("can_import") managed = skipped = 0 @@ -171,6 +171,7 @@ def handle_add(self, contract, coursewares, **kwargs): # noqa: PLR0915, C901 course_run_id=importable_id, departments=can_import.split(sep=","), create_cms_page=True, + create_depts=True, ) if not imported_run: diff --git a/b2b/management/commands/b2b_list.py b/b2b/management/commands/b2b_list.py index 9e5b1feaea..0e4db704a3 100644 --- a/b2b/management/commands/b2b_list.py +++ b/b2b/management/commands/b2b_list.py @@ -293,7 +293,9 @@ def handle_list_courseware(self, *args, **kwargs): # noqa: ARG002 if contract_id: contract_page_qs = contract_page_qs.filter(id=contract_id) - contracts = contract_page_qs.prefetch_related("programs", "course_runs").all() + contracts = contract_page_qs.prefetch_related( + "contract_programs__program", "course_runs" + ).all() courseware_table = Table(title="Courseware") courseware_table.add_column("ID", justify="right") diff --git a/cms/management/commands/backfill_certificate_page_revision.py b/cms/management/commands/backfill_certificate_page_revision.py new file mode 100644 index 0000000000..25fa457b2b --- /dev/null +++ b/cms/management/commands/backfill_certificate_page_revision.py @@ -0,0 +1,153 @@ +import csv +import json +from pathlib import Path + +from django.core.management.base import BaseCommand +from wagtail.blocks import StreamValue + +from cms.models import CertificatePage, SignatoryPage + + +class Command(BaseCommand): + """Django management command to Wagtail CertificatePage revisions using a CSV file""" + + help = "Create backfilled revisions for CertificatePage objects using CSV input" + + def add_arguments(self, parser): + """Parses command line arguments.""" + parser.add_argument( + "--csv-file", + type=str, + help="Path to the CSV file containing signatory and certificate page information", + required=True, + ) + + def handle(self, *args, **options): # noqa: ARG002 + """Handles the command execution.""" + csv_file_path = options["csv_file"] + + try: + with Path(csv_file_path).open("r", newline="", encoding="utf-8") as csvfile: + reader = csv.DictReader(csvfile) + for row_num, row in enumerate(reader, start=2): # Start at 2 for header + self.process_certificate_page_revision(row, row_num) + + except FileNotFoundError: + self.stdout.write(self.style.ERROR(f"CSV file not found: {csv_file_path}")) + + def revision_has_same_signatories(self, certificate_revision, signatory_blocks): + page_obj = certificate_revision.as_object() + rev_signatory_ids = [child.value.id for child in page_obj.signatories] + new_signatory_ids = [sp.id for _, sp in signatory_blocks] + return rev_signatory_ids == new_signatory_ids + + def process_certificate_page_revision(self, row, row_num): + certificate_page_id = row.get("certificate_page_id", "").strip() + signatories_json = row.get("signatory_names", "").strip() + + if not certificate_page_id: + self.stderr.write( + self.style.ERROR(f"Row {row_num}: Missing certificate_page_id") + ) + return + + if not signatories_json: + self.stderr.write( + self.style.ERROR(f"Row {row_num}: Missing signatory_names JSON array") + ) + return + + try: + # Example: a list of pairs, e.g. [["Name1", "Name2"], ["Name3", "Name4"]] + signatory_pairs = json.loads(signatories_json) + except Exception: # noqa: BLE001 + self.stderr.write( + self.style.ERROR( + f"Row {row_num}: signatory_names is not valid JSON: {signatories_json}" + ) + ) + return + + # Load the CertificatePage + certificate_page = CertificatePage.objects.get(id=certificate_page_id).specific + + # Copy the original live revision to restore later + original_live_revision = certificate_page.get_latest_revision() + + backfill_created = False + for index, signatory_names in enumerate(signatory_pairs, start=1): + if not isinstance(signatory_names, list) or not signatory_names: + self.stderr.write( + self.style.ERROR( + f"Row {row_num}: Pair #{index} must be a non-empty list of names: {signatory_names}" + ) + ) + continue + + signatory_pages = list( + SignatoryPage.objects.filter(name__in=signatory_names) + ) + found_names = {s.name for s in signatory_pages} + missing = [n for n in signatory_names if n not in found_names] + + if missing: + self.stderr.write( + self.style.WARNING( + f"Row {row_num} Pair {index}: Missing SignatoryPage(s): {missing}" + ) + ) + continue + + signatory_blocks = [ + ("signatory", signatory_page) for signatory_page in signatory_pages + ] + + backfill_signatories = StreamValue( + certificate_page.signatories.stream_block, + signatory_blocks, + is_lazy=False, + ) + + if any( + self.revision_has_same_signatories(revision, signatory_blocks) + for revision in certificate_page.revisions.all() + ): + self.stdout.write( + self.style.WARNING( + f"Row {row_num} Pair {index}: Identical revision already exists" + ) + ) + continue + + # Apply new signatories + certificate_page.signatories = backfill_signatories + + # create the backfilled historical revision + backfill_revision = certificate_page.save_revision( + changed=True, + log_action="wagtail.edit", + ) + backfill_revision.publish() + + backfill_created = True + self.stdout.write( + self.style.SUCCESS( + f"Row {row_num} Pair {index}: Created revision {backfill_revision.id} for CertificatePage {certificate_page_id}" + ) + ) + + if backfill_created: + # Restore original live revision + restored_page = original_live_revision.as_object() + restored_page.pk = certificate_page.pk + restored_revision = restored_page.save_revision( + changed=False, + log_action="wagtail.revert", + ) + restored_revision.publish() + + self.stdout.write( + self.style.SUCCESS( + f"Row {row_num}: Restored original live revision {restored_revision.id} for CertificatePage {certificate_page_id}" + ) + ) diff --git a/cms/models.py b/cms/models.py index 69abb46553..fff0b4671f 100644 --- a/cms/models.py +++ b/cms/models.py @@ -459,7 +459,7 @@ def get_context(self, request, *args, **kwargs): "uuid": "fake-uuid", "learner_name": "Anthony M. Stark", "product_name": product_name, - "issue_date": self.issue_date, + "end_date": datetime.now(), # noqa: DTZ005 "CEUs": self.CEUs, "is_program_certificate": is_program_certificate, } diff --git a/courses/api_test.py b/courses/api_test.py index 02276c8955..f250255bbe 100644 --- a/courses/api_test.py +++ b/courses/api_test.py @@ -925,7 +925,9 @@ def test_sync_course_mode(settings, mocker, mocked_api_response, expect_success) [1.0, False, True, False, False, False], # noqa: PT007 ], ) +@patch("courses.signals.upsert_custom_properties") def test_course_run_certificate( # noqa: PLR0913 + mock_upsert_custom_properties, user, passed_grade_with_enrollment, grade, @@ -943,7 +945,7 @@ def test_course_run_certificate( # noqa: PLR0913 "hubspot_sync.task_helpers.sync_hubspot_user", ) mocker.patch( - "hubspot_sync.management.commands.configure_hubspot_properties._upsert_custom_properties", + "hubspot_sync.api.upsert_custom_properties", ) passed_grade_with_enrollment.grade = grade passed_grade_with_enrollment.passed = passed @@ -963,7 +965,10 @@ def test_course_run_certificate( # noqa: PLR0913 assert deleted is exp_deleted -def test_course_run_certificate_idempotent(passed_grade_with_enrollment, mocker, user): +@patch("courses.signals.upsert_custom_properties") +def test_course_run_certificate_idempotent( + mock_upsert_custom_properties, passed_grade_with_enrollment, mocker, user +): """ Test that the certificate generation is idempotent """ @@ -971,7 +976,7 @@ def test_course_run_certificate_idempotent(passed_grade_with_enrollment, mocker, "hubspot_sync.task_helpers.sync_hubspot_user", ) mocker.patch( - "hubspot_sync.management.commands.configure_hubspot_properties._upsert_custom_properties", + "hubspot_sync.api.upsert_custom_properties", ) # Certificate is created the first time certificate, created, deleted = process_course_run_grade_certificate( @@ -992,12 +997,15 @@ def test_course_run_certificate_idempotent(passed_grade_with_enrollment, mocker, assert not deleted -def test_course_run_certificate_not_passing(passed_grade_with_enrollment, mocker): +@patch("courses.signals.upsert_custom_properties") +def test_course_run_certificate_not_passing( + mock_upsert_custom_properties, passed_grade_with_enrollment, mocker +): """ Test that the certificate is not generated if the grade is set to not passed """ mocker.patch( - "hubspot_sync.management.commands.configure_hubspot_properties._upsert_custom_properties", + "hubspot_sync.api.upsert_custom_properties", ) # Initially the certificate is created certificate, created, deleted = process_course_run_grade_certificate( @@ -1039,8 +1047,12 @@ def test_generate_course_certificates_no_valid_course_run(settings, courses_api_ ) +@patch("courses.signals.upsert_custom_properties") def test_generate_course_certificates_self_paced_course( - mocker, courses_api_logs, passed_grade_with_enrollment + mock_upsert_custom_properties, + mocker, + courses_api_logs, + passed_grade_with_enrollment, ): """Test that certificates are generated for self paced course runs independent of course run end date""" course_run = passed_grade_with_enrollment.course_run @@ -1048,7 +1060,7 @@ def test_generate_course_certificates_self_paced_course( course_run.is_self_paced = True course_run.save() mocker.patch( - "hubspot_sync.management.commands.configure_hubspot_properties._upsert_custom_properties", + "hubspot_sync.api.upsert_custom_properties", ) mocker.patch( "courses.api.ensure_course_run_grade", @@ -1073,7 +1085,9 @@ def test_generate_course_certificates_self_paced_course( (False, None), ], ) +@patch("courses.signals.upsert_custom_properties") def test_course_certificates_with_course_end_date_self_paced_combination( # noqa: PLR0913 + mock_upsert_custom_properties, mocker, settings, courses_api_logs, @@ -1093,7 +1107,7 @@ def test_course_certificates_with_course_end_date_self_paced_combination( # noq "hubspot_sync.task_helpers.sync_hubspot_user", ) mocker.patch( - "hubspot_sync.management.commands.configure_hubspot_properties._upsert_custom_properties", + "hubspot_sync.api.upsert_custom_properties", ) mocker.patch( "courses.api.exception_logging_generator", @@ -1112,8 +1126,13 @@ def test_course_certificates_with_course_end_date_self_paced_combination( # noq ) +@patch("courses.signals.upsert_custom_properties") def test_generate_course_certificates_with_course_end_date( - mocker, courses_api_logs, passed_grade_with_enrollment, settings + mock_upsert_custom_properties, + mocker, + courses_api_logs, + passed_grade_with_enrollment, + settings, ): """Test that certificates are generated for passed grades when there are valid course runs for certificates""" course_run = passed_grade_with_enrollment.course_run @@ -1122,7 +1141,7 @@ def test_generate_course_certificates_with_course_end_date( user = passed_grade_with_enrollment.user mocker.patch( - "hubspot_sync.management.commands.configure_hubspot_properties._upsert_custom_properties", + "hubspot_sync.api.upsert_custom_properties", ) mocker.patch( "courses.api.ensure_course_run_grade", @@ -1139,10 +1158,11 @@ def test_generate_course_certificates_with_course_end_date( ) -def test_course_run_certificates_access(mocker): +@patch("courses.signals.upsert_custom_properties") +def test_course_run_certificates_access(mock_upsert_custom_properties, mocker): """Tests that the revoke and unrevoke for a course run certificates sets the states properly""" mocker.patch( - "hubspot_sync.management.commands.configure_hubspot_properties._upsert_custom_properties", + "hubspot_sync.api.upsert_custom_properties", ) test_certificate = CourseRunCertificateFactory.create(is_revoked=False) @@ -1264,7 +1284,9 @@ def test_generate_program_certificate_failure_missing_certificates( assert len(ProgramCertificate.objects.all()) == 0 +@patch("courses.signals.upsert_custom_properties") def test_generate_program_certificate_failure_not_all_passed( + mock_upsert_custom_properties, user, program_with_requirements, # noqa: F811 mocker, @@ -1274,7 +1296,7 @@ def test_generate_program_certificate_failure_not_all_passed( if there is not any course_run certificate for the given course. """ mocker.patch( - "hubspot_sync.management.commands.configure_hubspot_properties._upsert_custom_properties", + "hubspot_sync.api.upsert_custom_properties", ) courses = CourseFactory.create_batch(3) course_runs = CourseRunFactory.create_batch(3, course=factory.Iterator(courses)) @@ -1291,7 +1313,10 @@ def test_generate_program_certificate_failure_not_all_passed( assert len(ProgramCertificate.objects.all()) == 0 -def test_generate_program_certificate_success_single_requirement_course(user, mocker): +@patch("courses.signals.upsert_custom_properties") +def test_generate_program_certificate_success_single_requirement_course( + mock_upsert_custom_properties, user, mocker +): """ Test that generate_program_certificate generates a program certificate for a Program with a single required Course. """ @@ -1299,7 +1324,7 @@ def test_generate_program_certificate_success_single_requirement_course(user, mo "hubspot_sync.task_helpers.sync_hubspot_user", ) mocker.patch( - "hubspot_sync.management.commands.configure_hubspot_properties._upsert_custom_properties", + "hubspot_sync.api.upsert_custom_properties", ) course = CourseFactory.create() program = ProgramFactory.create() @@ -1323,7 +1348,10 @@ def test_generate_program_certificate_success_single_requirement_course(user, mo patched_sync_hubspot_user.assert_called_once_with(user) -def test_generate_program_certificate_success_multiple_required_courses(user, mocker): +@patch("courses.signals.upsert_custom_properties") +def test_generate_program_certificate_success_multiple_required_courses( + mock_upsert_custom_properties, user, mocker +): """ Test that generate_program_certificate generate a program certificate """ @@ -1331,7 +1359,7 @@ def test_generate_program_certificate_success_multiple_required_courses(user, mo "hubspot_sync.task_helpers.sync_hubspot_user", ) mocker.patch( - "hubspot_sync.management.commands.configure_hubspot_properties._upsert_custom_properties", + "hubspot_sync.api.upsert_custom_properties", ) courses = CourseFactory.create_batch(3) program = ProgramFactory.create() @@ -1356,13 +1384,16 @@ def test_generate_program_certificate_success_multiple_required_courses(user, mo patched_sync_hubspot_user.assert_called_once_with(user) -def test_generate_program_certificate_success_minimum_electives_not_met(user, mocker): +@patch("courses.signals.upsert_custom_properties") +def test_generate_program_certificate_success_minimum_electives_not_met( + mock_upsert_custom_properties, user, mocker +): """ Test that generate_program_certificate does not generate a program certificate if minimum electives have not been met. """ courses = CourseFactory.create_batch(3) mocker.patch( - "hubspot_sync.management.commands.configure_hubspot_properties._upsert_custom_properties", + "hubspot_sync.api.upsert_custom_properties", ) # Create Program with 2 minimum elective courses. @@ -1404,7 +1435,9 @@ def test_generate_program_certificate_success_minimum_electives_not_met(user, mo assert len(ProgramCertificate.objects.all()) == 0 +@patch("courses.signals.upsert_custom_properties") def test_force_generate_program_certificate_success( + mock_upsert_custom_properties, user, program_with_requirements, # noqa: F811 mocker, @@ -1417,7 +1450,7 @@ def test_force_generate_program_certificate_success( "hubspot_sync.task_helpers.sync_hubspot_user", ) mocker.patch( - "hubspot_sync.management.commands.configure_hubspot_properties._upsert_custom_properties", + "hubspot_sync.api.upsert_custom_properties", ) courses = CourseFactory.create_batch(3) course_runs = CourseRunFactory.create_batch(3, course=factory.Iterator(courses)) @@ -1480,7 +1513,9 @@ def test_program_certificates_access(): assert test_certificate.is_revoked is False +@patch("courses.signals.upsert_custom_properties") def test_generate_program_certificate_failure_not_all_passed_nested_elective_stipulation( + mock_upsert_custom_properties, user, mocker, ): @@ -1491,7 +1526,7 @@ def test_generate_program_certificate_failure_not_all_passed_nested_elective_sti courses = CourseFactory.create_batch(3) course_runs = CourseRunFactory.create_batch(3, course=factory.Iterator(courses)) mocker.patch( - "hubspot_sync.management.commands.configure_hubspot_properties._upsert_custom_properties", + "hubspot_sync.api.upsert_custom_properties", ) # Create Program program = ProgramFactory.create() @@ -1566,7 +1601,10 @@ def test_program_enrollment_unenrollment_re_enrollment( ).exists() -def test_generate_program_certificate_with_subprogram_requirement(user, mocker): +@patch("courses.signals.upsert_custom_properties") +def test_generate_program_certificate_with_subprogram_requirement( + mock_upsert_custom_properties, user, mocker +): """ Test that generate_program_certificate considers sub-program (nested program) requirements when determining if a user has earned a program certificate. @@ -1575,7 +1613,7 @@ def test_generate_program_certificate_with_subprogram_requirement(user, mocker): "hubspot_sync.task_helpers.sync_hubspot_user", ) mocker.patch( - "hubspot_sync.management.commands.configure_hubspot_properties._upsert_custom_properties", + "hubspot_sync.api.upsert_custom_properties", ) # Create a sub-program that the user will complete @@ -1620,7 +1658,7 @@ def test_generate_program_certificate_with_subprogram_requirement_missing_certif sub-program certificate is missing. """ mocker.patch( - "hubspot_sync.management.commands.configure_hubspot_properties._upsert_custom_properties", + "hubspot_sync.api.upsert_custom_properties", ) # Create a sub-program @@ -1643,13 +1681,16 @@ def test_generate_program_certificate_with_subprogram_requirement_missing_certif assert len(ProgramCertificate.objects.all()) == 0 -def test_generate_program_certificate_with_revoked_subprogram_certificate(user, mocker): +@patch("courses.signals.upsert_custom_properties") +def test_generate_program_certificate_with_revoked_subprogram_certificate( + mock_upsert_custom_properties, user, mocker +): """ Test that generate_program_certificate does NOT consider revoked sub-program certificates when determining if a user has earned a program certificate. """ mocker.patch( - "hubspot_sync.management.commands.configure_hubspot_properties._upsert_custom_properties", + "hubspot_sync.api.upsert_custom_properties", ) # Create a sub-program diff --git a/courses/management/commands/migrate_edx_data.py b/courses/management/commands/migrate_edx_data.py index b20cf7ad6c..7153a0cf1a 100644 --- a/courses/management/commands/migrate_edx_data.py +++ b/courses/management/commands/migrate_edx_data.py @@ -193,9 +193,18 @@ def _migrate_course_runs(self, conn, options): continue signatories = self._get_signatories( - row.get("signatory_names", []), use_default_signatory + row.get("signatory_names") or [], use_default_signatory ) + if not signatories and row.get("mitxonline_course_id") is None: + self.stdout.write( + self.style.ERROR( + f"No valid signatories found with names {row.get('signatory_names')} for course " + f"{row.get('course_readable_id')}, skipping it." + ) + ) + continue + (course, course_created) = self._create_course(row) if course_created: diff --git a/courses/management/commands/test_manage_certificate.py b/courses/management/commands/test_manage_certificate.py index be57fa1d27..8e7b8e4911 100644 --- a/courses/management/commands/test_manage_certificate.py +++ b/courses/management/commands/test_manage_certificate.py @@ -125,7 +125,7 @@ def test_certificate_management_revoke_unrevoke_invalid_args( def test_certificate_management_revoke_unrevoke_success(user, revoke, unrevoke, mocker): """Test that certificate revoke, un-revoke work as expected and manage the certificate access properly""" mocker.patch( - "hubspot_sync.management.commands.configure_hubspot_properties._upsert_custom_properties", + "hubspot_sync.api.upsert_custom_properties", ) course_run = CourseRunFactory.create() certificate = CourseRunCertificateFactory( @@ -150,7 +150,7 @@ def test_certificate_management_create(mocker, user, edx_grade_json, revoked): when a user is provided """ mocker.patch( - "hubspot_sync.management.commands.configure_hubspot_properties._upsert_custom_properties", + "hubspot_sync.api.upsert_custom_properties", ) edx_grade = CurrentGrade(edx_grade_json) course_run = CourseRunFactory.create() @@ -192,7 +192,7 @@ def test_certificate_management_create_no_user(mocker, edx_grade_json, user): enrolled users in a run when no user is provided """ mocker.patch( - "hubspot_sync.management.commands.configure_hubspot_properties._upsert_custom_properties", + "hubspot_sync.api.upsert_custom_properties", ) passed_edx_grade = CurrentGrade(edx_grade_json) course_run = CourseRunFactory.create() diff --git a/courses/management/commands/test_manage_program_certificate.py b/courses/management/commands/test_manage_program_certificate.py index 1d47420252..ae02cabd10 100644 --- a/courses/management/commands/test_manage_program_certificate.py +++ b/courses/management/commands/test_manage_program_certificate.py @@ -130,7 +130,7 @@ def test_program_certificate_management_create( creates the program certificate for a user """ mocker.patch( - "hubspot_sync.management.commands.configure_hubspot_properties._upsert_custom_properties", + "hubspot_sync.api.upsert_custom_properties", ) courses = CourseFactory.create_batch(2) program_with_empty_requirements.add_requirement(courses[0]) @@ -162,7 +162,7 @@ def test_program_certificate_management_force_create( forcefully creates the certificate for a user """ mocker.patch( - "hubspot_sync.management.commands.configure_hubspot_properties._upsert_custom_properties", + "hubspot_sync.api.upsert_custom_properties", ) courses = CourseFactory.create_batch(3) course_runs = CourseRunFactory.create_batch(3, course=factory.Iterator(courses)) diff --git a/courses/models_test.py b/courses/models_test.py index bfce94a114..89c0321383 100644 --- a/courses/models_test.py +++ b/courses/models_test.py @@ -420,7 +420,7 @@ def test_course_run_certificate_start_end_dates_and_page_revision(mocker): Test that the CourseRunCertificate start_end_dates property works properly """ mocker.patch( - "hubspot_sync.management.commands.configure_hubspot_properties._upsert_custom_properties", + "hubspot_sync.api.upsert_custom_properties", ) certificate = CourseRunCertificateFactory.create( course_run__course__page__certificate_page__product_name="product_name" @@ -441,7 +441,7 @@ def test_program_certificate_start_end_dates_and_page_revision(user, mocker): The end date is the date the user received the program certificate. """ mocker.patch( - "hubspot_sync.management.commands.configure_hubspot_properties._upsert_custom_properties", + "hubspot_sync.api.upsert_custom_properties", ) now = now_in_utc() start_date = now + timedelta(days=1) diff --git a/courses/signals.py b/courses/signals.py index 0ed43b0913..6e9e07581f 100644 --- a/courses/signals.py +++ b/courses/signals.py @@ -2,6 +2,8 @@ Signals for mitxonline course certificates """ +import logging + from django.db import transaction from django.db.models.signals import post_save from django.dispatch import receiver @@ -11,6 +13,7 @@ CourseRunCertificate, Program, ) +from hubspot_sync.api import upsert_custom_properties from hubspot_sync.task_helpers import sync_hubspot_user @@ -40,4 +43,11 @@ def handle_create_course_run_certificate( transaction.on_commit( lambda: generate_multiple_programs_certificate(user, programs) ) - sync_hubspot_user(instance.user) + + try: + upsert_custom_properties() + sync_hubspot_user(instance.user) + except Exception: # pylint: disable=broad-except + logger = logging.getLogger(__name__) + logger.exception("Error syncing Hubspot user") + # avoid blocking certificate creation diff --git a/ecommerce/api.py b/ecommerce/api.py index 177e9189f6..97a4735dbb 100644 --- a/ecommerce/api.py +++ b/ecommerce/api.py @@ -253,7 +253,14 @@ def apply_user_discounts(request): if BasketDiscount.objects.filter(redeemed_basket=basket).count() > 0: return - product = BasketItem.objects.get(basket=basket).product + # For multiple items, check each item for flexible pricing discounts + basket_items = BasketItem.objects.filter(basket=basket) + if basket_items.count() == 0: + return + + # Use the first item's product for flexible pricing determination + # This maintains backward compatibility while supporting multiple items + product = basket_items.first().product flexible_price_discount = determine_courseware_flexible_price_discount( product, user ) diff --git a/ecommerce/migrations/0022_backfill_reference_number.py b/ecommerce/migrations/0022_backfill_reference_number.py index 6a301e2dcc..aa1dad8c2e 100644 --- a/ecommerce/migrations/0022_backfill_reference_number.py +++ b/ecommerce/migrations/0022_backfill_reference_number.py @@ -2,7 +2,7 @@ from django.conf import settings from django.db import migrations -from django.db.models import F, Value +from django.db.models import CharField, F, Value from django.db.models.functions import Concat from ecommerce.constants import REFERENCE_NUMBER_PREFIX @@ -16,6 +16,7 @@ def backfill_order_reference_number(apps, schema_editor): Value(settings.ENVIRONMENT), Value("-"), F("id"), + output_field=CharField(), ) ) diff --git a/ecommerce/models.py b/ecommerce/models.py index c51fe4b419..762e998729 100644 --- a/ecommerce/models.py +++ b/ecommerce/models.py @@ -198,7 +198,7 @@ def get_products(self): Returns the products that have been added to the basket so far. """ - return [item.product for item in self.basket_items.all()] + return [item.product for item in self.basket_items.select_related("product")] class BasketItem(TimestampedModel): diff --git a/ecommerce/serializers.py b/ecommerce/serializers.py index bf0b49f873..4216f6bcee 100644 --- a/ecommerce/serializers.py +++ b/ecommerce/serializers.py @@ -178,7 +178,7 @@ def get_basket_items(self, instance): """Get items in the basket""" return [ BasketItemSerializer(instance=basket, context=self.context).data - for basket in instance.basket_items.all() + for basket in instance.basket_items.select_related("product") ] class Meta: @@ -263,14 +263,15 @@ def get_basket_items(self, instance) -> list[dict[str, any]]: """ return [ BasketItemWithProductSerializer(instance=basket, context=self.context).data - for basket in instance.basket_items.all() + for basket in instance.basket_items.select_related("product") ] @extend_schema_field(Decimal) def get_total_price(self, instance) -> Decimal: """Get total price of all items in basket before discounts""" return sum( - basket_item.base_price for basket_item in instance.basket_items.all() + basket_item.base_price + for basket_item in instance.basket_items.select_related("product") ) @extend_schema_field(Decimal) @@ -280,7 +281,8 @@ def get_discounted_price(self, instance) -> Decimal: if discounts.count() == 0: return self.get_total_price(instance) return sum( - basket_item.discounted_price for basket_item in instance.basket_items.all() + basket_item.discounted_price + for basket_item in instance.basket_items.select_related("product") ) @extend_schema_field(list[BasketDiscountSerializer]) diff --git a/ecommerce/tests/test_multiple_cart_items.py b/ecommerce/tests/test_multiple_cart_items.py new file mode 100644 index 0000000000..aec737d917 --- /dev/null +++ b/ecommerce/tests/test_multiple_cart_items.py @@ -0,0 +1,136 @@ +"""Tests for multiple cart items functionality""" + +import pytest +from django.test import override_settings +from django.urls import reverse +from rest_framework import status + +from ecommerce.factories import BasketFactory, BasketItemFactory, ProductFactory + + +@pytest.mark.django_db +class TestMultipleCartItems: + """Test class for multiple cart items functionality""" + + @override_settings(ENABLE_MULTIPLE_CART_ITEMS=False) + def test_add_to_cart_single_item_mode_default(self, user_drf_client, user): + """Test that with feature flag disabled, adding items replaces existing items""" + # Create a product and basket with existing item + existing_product = ProductFactory.create() + new_product = ProductFactory.create() + + basket = BasketFactory.create(user=user) + BasketItemFactory.create(basket=basket, product=existing_product) + + assert basket.basket_items.count() == 1 + + # Add new product to cart + response = user_drf_client.post( + reverse("checkout_api-add_to_cart"), + data={"product_id": new_product.id}, + ) + + assert response.status_code == status.HTTP_200_OK + assert response.data["message"] == "Product added to cart" + + # Verify that only the new product is in the basket + basket.refresh_from_db() + assert basket.basket_items.count() == 1 + assert basket.basket_items.first().product == new_product + + @override_settings(ENABLE_MULTIPLE_CART_ITEMS=True) + def test_add_to_cart_multiple_items_mode_new_product(self, user_drf_client, user): + """Test that with feature flag enabled, adding new items keeps existing items""" + # Create products and basket with existing item + existing_product = ProductFactory.create() + new_product = ProductFactory.create() + + basket = BasketFactory.create(user=user) + BasketItemFactory.create(basket=basket, product=existing_product) + + assert basket.basket_items.count() == 1 + + # Add new product to cart + response = user_drf_client.post( + reverse("checkout_api-add_to_cart"), + data={"product_id": new_product.id}, + ) + + assert response.status_code == status.HTTP_200_OK + assert response.data["message"] == "Product added to cart" + + # Verify that both products are in the basket + basket.refresh_from_db() + assert basket.basket_items.count() == 2 + products_in_basket = {item.product for item in basket.basket_items.all()} + assert products_in_basket == {existing_product, new_product} + + @override_settings(ENABLE_MULTIPLE_CART_ITEMS=True) + def test_add_to_cart_nonexistent_product(self, user_drf_client): + """Test that adding a non-existent product returns 404""" + # Add non-existent product to cart + response = user_drf_client.post( + reverse("checkout_api-add_to_cart"), + data={"product_id": 99999}, + ) + + assert response.status_code == status.HTTP_404_NOT_FOUND + assert response.data["message"] == "Product not found" + + @override_settings(ENABLE_MULTIPLE_CART_ITEMS=True) + def test_cart_with_multiple_items_pricing(self, user_drf_client, user): + """Test that cart pricing works correctly with multiple items""" + # Create products with different prices + product1 = ProductFactory.create(price=50.00) + product2 = ProductFactory.create(price=30.00) + + basket = BasketFactory.create(user=user) + BasketItemFactory.create(basket=basket, product=product1, quantity=2) + BasketItemFactory.create(basket=basket, product=product2, quantity=1) + + # Get cart info + response = user_drf_client.get(reverse("checkout_api-cart")) + + assert response.status_code == status.HTTP_200_OK + cart_data = response.data + + # Verify total price calculation (2 * $50 + 1 * $30 = $130) + assert float(cart_data["total_price"]) == 130.00 + assert len(cart_data["basket_items"]) == 2 + + def test_basket_items_count_endpoint(self, user_drf_client, user): + """Test that basket items count endpoint works with multiple items""" + # Create products and add to basket + product1 = ProductFactory.create() + product2 = ProductFactory.create() + + basket = BasketFactory.create(user=user) + BasketItemFactory.create(basket=basket, product=product1, quantity=2) + BasketItemFactory.create(basket=basket, product=product2, quantity=1) + + # Get basket items count + response = user_drf_client.get(reverse("checkout_api-basket_items_count")) + + assert response.status_code == status.HTTP_200_OK + # Should return count of distinct items, not total quantity + assert response.data == 2 + + @override_settings(ENABLE_MULTIPLE_CART_ITEMS=False) + def test_existing_basket_item_viewset_still_works(self, user_drf_client, user): + """Test that the existing BasketItemViewSet still works regardless of feature flag""" + product1 = ProductFactory.create() + product2 = ProductFactory.create() + + basket = BasketFactory.create(user=user) + BasketItemFactory.create(basket=basket, product=product1) + + # Add item using the existing ViewSet endpoint + response = user_drf_client.post( + f"/api/baskets/{basket.id}/items/", + data={"product": product2.id}, + ) + + assert response.status_code == status.HTTP_201_CREATED + + # Verify both items are in basket + assert basket.basket_items.count() == 2 diff --git a/ecommerce/views/v0/__init__.py b/ecommerce/views/v0/__init__.py index 29df6bbe1a..7768527ace 100644 --- a/ecommerce/views/v0/__init__.py +++ b/ecommerce/views/v0/__init__.py @@ -647,17 +647,44 @@ def add_to_cart(self, request): basket, _ = Basket.objects.select_for_update().get_or_create( user=self.request.user ) - basket.basket_items.all().delete() - BasketDiscount.objects.filter(redeemed_basket=basket).delete() - all_product_ids = [request.data["product_id"]] + # Check if multiple cart items feature is enabled + allow_multiple_items = getattr( + settings, "ENABLE_MULTIPLE_CART_ITEMS", False + ) - for product in Product.objects.filter(id__in=all_product_ids): + if not allow_multiple_items: + # Legacy behavior: clear existing items and discounts + basket.basket_items.all().delete() + BasketDiscount.objects.filter(redeemed_basket=basket).delete() + else: + # New behavior: only clear discounts, keep existing items + BasketDiscount.objects.filter(redeemed_basket=basket).delete() + + product_id = request.data["product_id"] + + try: + product = Product.objects.get(id=product_id) + except Product.DoesNotExist: + return Response( + {"message": "Product not found"}, + status=status.HTTP_404_NOT_FOUND, + ) + message = "Product already in cart" + if allow_multiple_items: + # Check if product already exists in basket + if not basket.basket_items.filter(product=product).exists(): + # Add new item to basket + BasketItem.objects.create(basket=basket, product=product) + message = "Product added to cart" + else: + # Legacy behavior: add single item BasketItem.objects.create(basket=basket, product=product) + message = "Product added to cart" return Response( { - "message": "Product added to cart", + "message": message, } ) diff --git a/ecommerce/views_test.py b/ecommerce/views_test.py index b96765cc45..cd1fcf77aa 100644 --- a/ecommerce/views_test.py +++ b/ecommerce/views_test.py @@ -739,6 +739,46 @@ def test_checkout_product_with_program_id(user, user_client): assert [item.product for item in basket.basket_items.all()] == [product] +@pytest.mark.parametrize("multiple_cart_enabled", [True, False]) +def test_add_to_cart_api_with_feature_flag( + user_drf_client, user, multiple_cart_enabled, settings +): + """Test add_to_cart API behavior with and without multiple cart items feature""" + settings.ENABLE_MULTIPLE_CART_ITEMS = multiple_cart_enabled + + # Create products + product1 = ProductFactory.create() + product2 = ProductFactory.create() + + # Add first product + resp = user_drf_client.post( + reverse("checkout_api-add_to_cart"), + data={"product_id": product1.id}, + ) + assert resp.status_code == 200 + + basket = Basket.objects.get(user=user) + assert basket.basket_items.count() == 1 + + # Add second product + resp = user_drf_client.post( + reverse("checkout_api-add_to_cart"), + data={"product_id": product2.id}, + ) + assert resp.status_code == 200 + + basket.refresh_from_db() + if multiple_cart_enabled: + # Should have both products + assert basket.basket_items.count() == 2 + products_in_basket = {item.product for item in basket.basket_items.all()} + assert products_in_basket == {product1, product2} + else: + # Should only have the second product (legacy behavior) + assert basket.basket_items.count() == 1 + assert basket.basket_items.first().product == product2 + + def test_discount_rest_api(admin_drf_client, user_drf_client): """ Checks that the admin REST API is only accessible by an admin diff --git a/frontend/public/src/components/CourseProductDetailEnroll.js b/frontend/public/src/components/CourseProductDetailEnroll.js index fb3d015d19..6abc41e9c4 100644 --- a/frontend/public/src/components/CourseProductDetailEnroll.js +++ b/frontend/public/src/components/CourseProductDetailEnroll.js @@ -292,7 +292,10 @@ export class CourseProductDetailEnroll extends React.Component< {course && course.title} added to your cart.
-