From 6ae217b69c7129d760c2278034f734fba45f313d Mon Sep 17 00:00:00 2001 From: James Kachel Date: Thu, 16 Oct 2025 12:25:43 -0500 Subject: [PATCH 1/4] Add ability to manage the link between programs and contracts (#3006) Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- b2b/admin.py | 55 ++++ b2b/api.py | 69 +++- b2b/management/commands/b2b_contract.py | 152 --------- b2b/management/commands/b2b_courseware.py | 307 ++++++++++++++++++ b2b/management/commands/b2b_list.py | 47 +-- b2b/models.py | 5 +- b2b/models_test.py | 4 +- b2b/views/v0/views_test.py | 2 +- courses/admin.py | 8 + courses/api.py | 173 +++++++++- courses/factories.py | 1 + .../migrations/0069_add_source_run_flag.py | 20 ++ courses/models.py | 4 + courses/views/v2/views_test.py | 2 +- 14 files changed, 653 insertions(+), 196 deletions(-) create mode 100644 b2b/management/commands/b2b_courseware.py create mode 100644 courses/migrations/0069_add_source_run_flag.py diff --git a/b2b/admin.py b/b2b/admin.py index 54de66fd56..3858c3e198 100644 --- a/b2b/admin.py +++ b/b2b/admin.py @@ -7,6 +7,7 @@ DiscountContractAttachmentRedemption, OrganizationPage, ) +from courses.models import CourseRun class ReadOnlyModelAdmin(admin.ModelAdmin): @@ -42,6 +43,59 @@ class DiscountContractAttachmentRedemptionAdmin(ReadOnlyModelAdmin): readonly_fields = ["user", "contract", "discount", "created_on"] +class ContractPageProgramInline(admin.TabularInline): + """Inline to display programs for contract pages.""" + + model = ContractPage.programs.through + extra = 0 + verbose_name = "Contract Program" + verbose_name_plural = "Contract Programs" + + def has_add_permission(self, request, obj): # noqa: ARG002 + """Turn off add permission. These admins are supposed to be read-only.""" + + return False + + def has_delete_permission(self, request, obj): # noqa: ARG002 + """Turn off delete permission. These admins are supposed to be read-only.""" + + return False + + def has_change_permission(self, request, obj): # noqa: ARG002 + """Turn off change permission. These admins are supposed to be read-only.""" + + return False + + +class ContractPageCourseRunInline(admin.TabularInline): + """Inline to display course runs for contract pages.""" + + model = CourseRun + fk_name = "b2b_contract" + extra = 0 + fields = [ + "courseware_id", + "title", + ] + verbose_name = "Contract Course Run" + verbose_name_plural = "Contract Course Runs" + + def has_add_permission(self, request, obj): # noqa: ARG002 + """Turn off add permission. These admins are supposed to be read-only.""" + + return False + + def has_delete_permission(self, request, obj): # noqa: ARG002 + """Turn off delete permission. These admins are supposed to be read-only.""" + + return False + + def has_change_permission(self, request, obj): # noqa: ARG002 + """Turn off change permission. These admins are supposed to be read-only.""" + + return False + + @admin.register(ContractPage) class ContractPageAdmin(ReadOnlyModelAdmin): """Admin for contract pages.""" @@ -70,6 +124,7 @@ class ContractPageAdmin(ReadOnlyModelAdmin): "max_learners", "enrollment_fixed_price", ] + inlines = [ContractPageCourseRunInline, ContractPageProgramInline] @admin.register(OrganizationPage) diff --git a/b2b/api.py b/b2b/api.py index 54b487d872..a95f27d03f 100644 --- a/b2b/api.py +++ b/b2b/api.py @@ -71,8 +71,34 @@ def ensure_b2b_organization_index() -> OrganizationIndexPage: return org_index_page +def import_and_create_contract_run(contract: ContractPage, course_run_id: str): + """ + Create a contract run for the given course, importing it from edX if necessary. + + Check for the specified course run. If it exists, create the contract run in + the usual fashion. If it doesn't, check for it in edX and import it into + MITx Online first, then create the contract run. + + If the specified run is imported, it will have the "is_source_run" flag set. + + Args: + contract (ContractPage): The contract to create the run for. + course_run_id (str): The readable ID for the source course run. + Keyword Args: + skip_edx (bool): Don't try to create a course run in edX. + require_designated_source_run (bool): Require a flagged source run. + Returns: + CourseRun: The created CourseRun object. + Product: The created Product object. + """ + + def create_contract_run( - contract: ContractPage, course: Course + contract: ContractPage, + course: Course, + *, + skip_edx=False, + require_designated_source_run=True, ) -> tuple[CourseRun, Product]: """ Create a run for the specified contract. @@ -84,24 +110,24 @@ def create_contract_run( source course run in edX. - Create a product for the run. - Source course runs are identified by looking for the most recent course run - for the given course. This code expects you to pass in an MITx Online course - that has a readable ID like 'course-v1:UAI_SOURCE+number` and that it has a - run of some sort. This should just be a single run. If there's multiple runs, - this will look for a run tag of "SOURCE". Failing that, it will try to use - the _first_ run in the list. + Source course runs are runs that have the "is_source_run" flag set. If one + cannot be found, this will also check for a run with the run tag "SOURCE". + If neither of those are found, it will throw an error. However, setting + "require_designated_source_run" to False will add a third attempt, which will + try to use whatever the first course run is for the specified course. This + may not be what you want, so this functionality is disabled by default. The MITx Online course run will belong to the source course. They will get a course key that is modified to represent the organization they belong to, the current year, and the contract ID. This means that the source course will have runs that have readable IDs that do not match the course ID. - The course key is changed according to a set algorithm. For more information, - see this discussion post: https://github.com/mitodl/hq/discussions/7525 - In general, this expects a course run that is in org `UAI_SOURCE` and then - will create a new run that is `UAI_orgkey`, with a run tag that reflects - the year we're creating the run in and the contract ID (`2025_C19` for - instance). + The course key is generated according to a set algorithm. The new course key + will have the organization part set to "UAI_orgkey" and the run tag set to + "year_Cid" where orgkey is the organization key (set in the organization + record), year is the current year, and id is the ID of the contract. For more + information on the key format, see this discussion post: + https://github.com/mitodl/hq/discussions/7525 A product will be created for the new contract course run, and its price will either be zero or the amount specified by the contract. (Free courses still @@ -111,18 +137,23 @@ def create_contract_run( Args: contract (ContractPage): The contract to create the run for. course (Course): The course for which we should create a run. + Keyword Args: + skip_edx (bool): Don't try to create a course run in edX. + require_designated_source_run (bool): Require a flagged source run. Returns: CourseRun: The created CourseRun object. Product: The created Product object. """ - clone_course_run = course.courseruns.filter(run_tag="SOURCE").first() + clone_course_run = course.courseruns.filter( + Q(is_source_run=True) | Q(run_tag="SOURCE") + ).first() - if not clone_course_run: + if not clone_course_run and not require_designated_source_run: try: clone_course_run = course.courseruns.order_by("-id").first() log.warning( - "create_contract_run: No SOURCE run for %s, using %s", + "create_contract_run: Couldn't find an appropriate source run for %s, using %s", course, clone_course_run, ) @@ -175,7 +206,9 @@ def create_contract_run( ), ) course_run.save() - clone_courserun.delay(course_run.id, base_id=clone_course_run.courseware_id) + + if not skip_edx: + clone_courserun.delay(course_run.id, base_id=clone_course_run.courseware_id) log.debug( "Created run %s for course %s in contract %s from course run %s", @@ -424,7 +457,7 @@ def _handle_unlimited_seats( ) -> tuple[int, int, int]: """Handle unlimited seat contracts by creating/updating one discount per product.""" created = updated = errors = 0 - discount_amount = contract.enrollment_fixed_price or 0 + discount_amount = contract.enrollment_fixed_price or Decimal(0) if len(product_discounts) == 0: discount = _create_discount_with_product( diff --git a/b2b/management/commands/b2b_contract.py b/b2b/management/commands/b2b_contract.py index 9ddf53dc12..31db33f073 100644 --- a/b2b/management/commands/b2b_contract.py +++ b/b2b/management/commands/b2b_contract.py @@ -6,10 +6,8 @@ from django.core.management import BaseCommand, CommandError from django.db.models import Q -from b2b.api import create_contract_run from b2b.constants import CONTRACT_INTEGRATION_NONSSO, CONTRACT_INTEGRATION_SSO from b2b.models import ContractPage, OrganizationIndexPage, OrganizationPage -from courses.api import resolve_courseware_object_from_id log = logging.getLogger(__name__) @@ -19,26 +17,6 @@ class Command(BaseCommand): help = "Manage B2B contracts." - def create_run(self, contract, courseware): - """Create a run for the specified contract.""" - run_tuple = create_contract_run(contract=contract, course=courseware) - - if not run_tuple: - self.stdout.write( - self.style.ERROR( - f"Failed to create run for course {courseware} for contract {contract}." - ) - ) - return False - - self.stdout.write( - self.style.SUCCESS( - f"Created run {run_tuple[0]} and product {run_tuple[1]} for course {courseware} for contract {contract}." - ) - ) - - return True - def add_arguments(self, parser): """Add command line arguments.""" @@ -174,34 +152,6 @@ def add_arguments(self, parser): help="Clear the end date.", ) - courseware_parser = subparsers.add_parser( - "courseware", - help="Manage courseware assigned to a contract.", - ) - courseware_parser.add_argument( - "contract_id", - type=int, - help="The ID of the contract to courseware courseware to.", - ) - courseware_parser.add_argument( - "--remove", - action="store_true", - help="Remove courseware from the contract. (Default is to add.)", - dest="remove", - ) - courseware_parser.add_argument( - "--no-create-runs", - action="store_false", - help="Don't create new runs for this contract.", - dest="create_runs", - ) - courseware_parser.add_argument( - "courseware_id", - type=str, - help="The ID of the courseware to courseware.", - action="append", - ) - return super().add_arguments(parser) def handle_create(self, *args, **kwargs): # noqa: ARG002 @@ -303,106 +253,6 @@ def handle_modify(self, *args, **kwargs): # noqa: ARG002, C901 contract.save() self.stdout.write(f"Modified contract with ID '{contract_id}'") - def handle_courseware(self, *args, **kwargs): # noqa: ARG002, C901 - """Add/remove courseware in a contract.""" - contract_id = kwargs.pop("contract_id") - remove = kwargs.pop("remove") - create_runs = kwargs.pop("create_runs") - courseware_ids = kwargs.pop("courseware_id") - - contract = ContractPage.objects.filter(id=contract_id).first() - if not contract: - msg = f"Contract with ID '{contract_id}' does not exist." - raise CommandError(msg) - - managed = skipped = 0 - - for courseware_id in courseware_ids: - courseware = resolve_courseware_object_from_id(courseware_id) - if not courseware: - self.stdout.write( - self.style.ERROR( - f"Courseware with ID '{courseware_id}' does not exist, skipping." - ) - ) - skipped += 1 - elif courseware.is_program: - # If you're specifying a program, we will always make new runs - # since we won't be able to tell which existing ones to use. - - self.stdout.write( - self.style.WARNING( - f"'{courseware_id}' is a program, so creating runs for all of its courses." - ) - ) - - prog_add, prog_skip = contract.add_program_courses(courseware) - contract.save() - managed += prog_add - skipped += prog_skip - elif courseware.is_run: - # This run already exists, so just add/remove it. - # - If the run is owned by a different contract, skip it. - # - If remove is True, remove the run from the contract. - # - If remove is False, add the run to the contract. - - if courseware.b2b_contract and courseware.b2b_contract != contract: - # Already owned by another contract, so skip - self.stdout.write( - self.style.WARNING( - f"Run '{courseware_id}' is already owned by {courseware.b2b_contract}." - ) - ) - skipped += 1 - continue - - if remove: - # Remove the run from the contract - courseware.b2b_contract = None - courseware.save() - managed += 1 - elif courseware.b2b_contract == contract: - # Already owned by this contract, so skip - self.stdout.write( - self.style.WARNING( - f"Run '{courseware_id}' is already owned by this contract." - ) - ) - skipped += 1 - else: - # Add the run to the contract - courseware.b2b_contract = contract - courseware.save() - managed += 1 - elif remove: - # If we're removing courseware, skip it if it's not a run. - self.stdout.write( - self.style.WARNING( - f"Skipping removal of courseware '{courseware_id}' for contract {contract} because it is not a run. Removals must specify runs." - ) - ) - skipped += 1 - elif create_runs: - # This is a course, so create a run (unless we've been told not to). - - if self.create_run(contract, courseware): - managed += 1 - else: - skipped += 1 - else: - self.stdout.write( - self.style.WARNING( - f"Skipped run creation for for course {courseware} for contract {contract}." - ) - ) - skipped += 1 - - self.stdout.write( - self.style.SUCCESS( - f"Managed {managed} courseware items and skipped {skipped} courseware items for {len(courseware_ids)} specified courseware IDs." - ) - ) - def handle(self, *args, **kwargs): # noqa: ARG002 """Handle the command.""" subcommand = kwargs.pop("subcommand") @@ -410,8 +260,6 @@ def handle(self, *args, **kwargs): # noqa: ARG002 self.handle_create(**kwargs) elif subcommand == "modify": self.handle_modify(**kwargs) - elif subcommand == "courseware": - self.handle_courseware(**kwargs) else: log.error("Unknown subcommand: %s", subcommand) return 1 diff --git a/b2b/management/commands/b2b_courseware.py b/b2b/management/commands/b2b_courseware.py new file mode 100644 index 0000000000..8133219248 --- /dev/null +++ b/b2b/management/commands/b2b_courseware.py @@ -0,0 +1,307 @@ +""" +Manage courseware objects for B2B contracts. + +Allows you to +""" + +import logging +from argparse import RawTextHelpFormatter + +from django.core.management import BaseCommand, CommandError + +from b2b.api import create_contract_run +from b2b.models import ContractPage +from courses.api import resolve_courseware_object_from_id +from courses.models import CourseRun + +log = logging.getLogger(__name__) + + +class Command(BaseCommand): + """Manage B2B contract courseware objects.""" + + help = """Add or remove a B2B contract's courseware objects. + +Courseware objects can be course runs, courses, or programs. specified by their readable ID (i.e. course-v1:MITxT+12.345s+3T2022). Contract should be specified by either their numeric ID or their slug. + +Specifying courseware: You must specify one courseware item (of any type). You can specify more than one by adding "--also " to the end of the command. You can repeat this as many times as necessary. + +To add courseware: + b2b_courseware add [--no-create-runs] [--force] contract courseware [--also courseware] [--also courseware...] + +Example: b2b_courseware add contract-100-101 program-v1:UAI+Fundamentals --also course-v1:UAI_C100+14.314x+2025_C101 + +Specifying a course run will attach it to the contract unless the contract is already attached to a contract. Specify "--force" to override any existing contract attachment. + +Specifying a course will attempt to create a course run for the contract for the specified course. It will try to create a course run in edX as well unless "--no-create-runs" is specified. + +Specifying a program will iterate through the program's courses and create runs for each. It will also link the program to the contract. + +To remove: + b2b_courseware remove [--remove-program-runs] contract courseware [--also courseware] [--also courseware...] + +Example: b2b_courseware remove --remove-program-runs contract-100-101 program-v1:UAI+Fundamentals --also course-v1:UAI_C100+14.314x+2025_C101 + +Specifying a course run will unlink the run from the contract. + +Specifying a course will unlink any of the course's runs that are attached to the contract from the contract. + +Specifying a program will only unlink the program from the contract, unless "--remove-program-runs" is set. If it is, then all the runs that belong to both the contract and the program's courses will be removed from the contract. Note that doing this and then re-adding the program will *not* re-attach the existing runs to the contract - you will need to do that manually. + """ + + def create_run(self, contract, courseware, *, skip_edx=False): + """Create a run for the specified contract.""" + run_tuple = create_contract_run( + contract=contract, course=courseware, skip_edx=skip_edx + ) + + if not run_tuple: + self.stdout.write( + self.style.ERROR( + f"Failed to create run for course {courseware} for contract {contract}." + ) + ) + return False + + self.stdout.write( + self.style.SUCCESS( + f"Created run {run_tuple[0]} and product {run_tuple[1]} for course {courseware} for contract {contract}." + ) + ) + + return True + + def add_arguments(self, parser): + """Add command line arguments.""" + + parser.formatter_class = RawTextHelpFormatter + + subparsers = parser.add_subparsers( + title="Task", + dest="subcommand", + required=True, + help="The task to perform - add or remove.", + ) + parser.add_argument( + "contract", type=str, help="The contract to work on (slug or ID)." + ) + parser.add_argument( + "courseware", + type=str, + help="The courseware object (readable ID) to work with. Can be a program, course, or course run.", + ) + parser.add_argument( + "--also", + type=str, + action="append", + dest="additional_courseware", + help="Additional courseware objects (readable IDs) to work with.", + ) + + add_subparser = subparsers.add_parser( + "add", + help="Add courseware to a contract.", + ) + add_subparser.add_argument( + "--no-create-runs", + help="Don't create contract runs in edX for the specified course, just add it to the contract.", + dest="create_runs", + action="store_false", + ) + add_subparser.add_argument( + "--force", + help="Force adding any specified runs to the contract (overwrite existing contract associations).", + dest="force", + action="store_true", + ) + + remove_subparser = subparsers.add_parser( + "remove", + help="Remove courseware from a contract.", + ) + + remove_subparser.add_argument( + "--remove-program-runs", + help="For programs, unlink the program's contract runs as well as the program.", + action="store_true", + ) + + return super().add_arguments(parser) + + def handle_add(self, contract, coursewares, **kwargs): + """Handle the add subcommand.""" + + create_runs = kwargs.pop("create_runs") + force_associate = kwargs.pop("force") + + managed = skipped = 0 + + for courseware in coursewares: + if courseware.is_program: + # If you're specifying a program, we will always make new runs + # since we won't be able to tell which existing ones to use. + + self.stdout.write( + self.style.WARNING( + f"'{courseware.readable_id}' is a program, so creating runs for all of its courses." + ) + ) + + prog_add, prog_skip = contract.add_program_courses(courseware) + contract.save() + managed += prog_add + skipped += prog_skip + contract.programs.add(courseware) + self.stdout.write( + self.style.SUCCESS(f"Added {courseware.readable_id} to {contract}.") + ) + elif courseware.is_run: + # This run already exists, so: + # - If it's in a contract already and we're not forcing it, skip it. + # - If it's in a contract already and we *are* forcing it, set it to be in this contract. + # - If it's not in a contract, add it to this contract. + + if ( + not force_associate + and courseware.b2b_contract + and courseware.b2b_contract != contract + ): + # Already owned by another contract, so skip + self.stdout.write( + self.style.WARNING( + f"Run '{courseware.courseware_id}' is already owned by {courseware.b2b_contract}." + ) + ) + skipped += 1 + continue + elif courseware.b2b_contract == contract: + # Already owned by this contract, so skip + self.stdout.write( + self.style.WARNING( + f"Run '{courseware.courseware_id}' is already owned by this contract." + ) + ) + skipped += 1 + continue + + # Add the run to the contract + courseware.b2b_contract = contract + courseware.save() + managed += 1 + elif create_runs: + # This is a course, so create a run (unless we've been told not to). + + if self.create_run(contract, courseware, skip_edx=create_runs): + managed += 1 + else: + skipped += 1 + else: + self.stdout.write( + self.style.WARNING( + f"Skipped run creation for for course {courseware.readable_id} for contract {contract}." + ) + ) + skipped += 1 + + self.stdout.write( + self.style.SUCCESS( + f"Managed {managed} courseware items and skipped {skipped} courseware items for {len(coursewares)} specified courseware IDs." + ) + ) + + return True + + def handle_remove(self, contract, coursewares, **kwargs): + """Handle removing courseware from a contract.""" + + remove_runs = kwargs.pop("remove_program_runs") + + for courseware in coursewares: + if courseware.is_program: + # If we have a program, unlink the program from the contract. + # Then, if we're told to, unlink any contract runs that are + # part of the program too. + + if remove_runs: + program_courses = courseware.courses + program_runs = CourseRun.objects.filter( + b2b_contract=contract, + course__in=[course for (course, _) in program_courses], + ).all() + + coursewares.extend(program_runs) + + self.stdout.write( + self.style.NOTICE( + f"{courseware.readable_id} is a program and --remove-program-runs set, so adding {len(program_runs)} course runs" + ) + ) + + contract.programs.remove(courseware) + self.stdout.write( + self.style.SUCCESS( + f"Removed program {courseware.readable_id} from contract {contract}." + ) + ) + elif not courseware.is_run: + # If we have a course, find and add the contract runs for the + # course to the list. We don't link courses to contracts, so + # there's nothing else to do here. + + course_contract_runs = courseware.courseruns.filter( + b2b_contract=contract + ).all() + + coursewares.extend(course_contract_runs) + + self.stdout.write( + self.style.SUCCESS( + f"Added {len(course_contract_runs)} course runs from course {courseware.readable_id} to remove from contract {contract}." + ) + ) + else: + # We're actually at a course run now. + + courseware.b2b_contract = None + courseware.save() + + self.stdout.write( + self.style.SUCCESS( + f"Unlinked {courseware.courseware_id} from {contract}." + ) + ) + + return True + + def handle(self, *args, **kwargs): # noqa: ARG002 + """Dispatch the requested task.""" + + contract_id = kwargs.pop("contract") + courseware_id = kwargs.pop("courseware") + additional_courseware_ids = kwargs.pop("additional_courseware") + subcommand = kwargs.pop("subcommand") + + if contract_id.isdecimal(): + contract = ContractPage.objects.filter(id=contract_id).first() + else: + contract = ContractPage.objects.filter(slug=contract_id).first() + + if not contract: + msg = f"Contract with ID/slug '{contract_id}' does not exist." + raise CommandError(msg) + + courseware_ids = [courseware_id] + if additional_courseware_ids: + courseware_ids.extend(additional_courseware_ids) + + coursewares = [ + resolve_courseware_object_from_id(courseware_id) + for courseware_id in courseware_ids + ] + + if subcommand == "add": + self.handle_add(contract, coursewares, **kwargs) + elif subcommand == "remove": + self.handle_remove(contract, coursewares, **kwargs) + else: + self.stderr.write(self.style.ERROR(f"Unknown command {subcommand}")) diff --git a/b2b/management/commands/b2b_list.py b/b2b/management/commands/b2b_list.py index 0d2bb4265f..4cfc01c4d5 100644 --- a/b2b/management/commands/b2b_list.py +++ b/b2b/management/commands/b2b_list.py @@ -13,7 +13,6 @@ from rich.table import Table from b2b.models import ContractPage, OrganizationPage -from courses.models import CourseRun log = logging.getLogger(__name__) @@ -275,8 +274,7 @@ def handle_list_courseware(self, *args, **kwargs): # noqa: ARG002 org_text = kwargs.pop("org_text") contract_id = kwargs.pop("contract_id") - # We only link course runs to contracts. This will need to be updated - # if we ever have other types (like program runs or something). + # We now attach course runs and programs to contracts. contract_page_qs = ContractPage.objects @@ -295,13 +293,7 @@ def handle_list_courseware(self, *args, **kwargs): # noqa: ARG002 if contract_id: contract_page_qs = contract_page_qs.filter(id=contract_id) - contracts = contract_page_qs.all() - - courseware = ( - CourseRun.objects.prefetch_related("b2b_contract") - .filter(b2b_contract__in=contracts) - .all() - ) + contracts = contract_page_qs.prefetch_related("programs", "course_runs").all() courseware_table = Table(title="Courseware") courseware_table.add_column("ID", justify="right") @@ -312,16 +304,31 @@ def handle_list_courseware(self, *args, **kwargs): # noqa: ARG002 courseware_table.add_column("Start", justify="left") courseware_table.add_column("End", justify="left") - for cw in courseware: - courseware_table.add_row( - str(cw.id), - f"{cw.b2b_contract.organization.name}\n{cw.b2b_contract.name}", - "CR", - cw.readable_id, - cw.title, - cw.start_date.strftime("%Y-%m-%d\n%H:%M") if cw.start_date else "", - cw.end_date.strftime("%Y-%m-%d\n%H:%M") if cw.end_date else "", - ) + for contract in contracts: + # We'll start by listing out the associated programs, and then listing + # out the course runs. + + for prog in contract.programs.all(): + courseware_table.add_row( + str(prog.id), + f"{contract.organization.name}\n{contract.name}", + "PROG", + prog.readable_id, + prog.title, + "---", + "---", + ) + + for cw in contract.course_runs.all(): + courseware_table.add_row( + str(cw.id), + f"{contract.organization.name}\n{contract.name}", + "CR", + cw.readable_id, + cw.title, + cw.start_date.strftime("%Y-%m-%d\n%H:%M") if cw.start_date else "", + cw.end_date.strftime("%Y-%m-%d\n%H:%M") if cw.end_date else "", + ) self.console.print(courseware_table) diff --git a/b2b/models.py b/b2b/models.py index a5cd75de5d..01449ef64b 100644 --- a/b2b/models.py +++ b/b2b/models.py @@ -314,7 +314,10 @@ def add_program_courses(self, program): if hasattr(program, "_courses_with_requirements_data"): delattr(program, "_courses_with_requirements_data") - for course, _ in program.courses: + for course in program.courses_qset.filter( + models.Q(courseruns__is_source_run=True) + | models.Q(courseruns__run_tag="SOURCE") + ).all(): try: create_contract_run(self, course) managed += 1 diff --git a/b2b/models_test.py b/b2b/models_test.py index 7535a7c28c..8501c10734 100644 --- a/b2b/models_test.py +++ b/b2b/models_test.py @@ -16,7 +16,7 @@ def test_add_program_courses_to_contract(mocker): mocker.patch("openedx.tasks.clone_courserun.delay") program = ProgramFactory.create() - courseruns = CourseRunFactory.create_batch(3) + courseruns = CourseRunFactory.create_batch(3, is_source_run=True) contract = ContractPageFactory.create() for courserun in courseruns: @@ -35,7 +35,7 @@ def test_add_program_courses_to_contract(mocker): assert contract.programs.count() == 1 assert contract.get_course_runs().count() == 3 - new_courserun = CourseRunFactory.create() + new_courserun = CourseRunFactory.create(is_source_run=True) program.add_requirement(new_courserun.course) program.save() program.refresh_from_db() diff --git a/b2b/views/v0/views_test.py b/b2b/views/v0/views_test.py index 2db61bdf54..3d11079408 100644 --- a/b2b/views/v0/views_test.py +++ b/b2b/views/v0/views_test.py @@ -237,7 +237,7 @@ def test_b2b_enroll(mocker, settings, user_has_edx_user, has_price): integration_type=CONTRACT_INTEGRATION_SSO, enrollment_fixed_price=100 if has_price else 0, ) - source_courserun = CourseRunFactory.create() + source_courserun = CourseRunFactory.create(is_source_run=True) courserun, _ = create_contract_run(contract, source_courserun.course) diff --git a/courses/admin.py b/courses/admin.py index 9d7ea958c2..39678b2f93 100644 --- a/courses/admin.py +++ b/courses/admin.py @@ -35,6 +35,13 @@ from openedx.tasks import retry_failed_edx_enrollments +class ProgramContractPageInline(admin.TabularInline): + """Inline for contract pages""" + + model = Program.contracts.through + extra = 0 + + @admin.register(Program) class ProgramAdmin(admin.ModelAdmin): """Admin for Program""" @@ -44,6 +51,7 @@ class ProgramAdmin(admin.ModelAdmin): search_fields = ["title", "readable_id", "program_type"] list_display = ("id", "title", "live", "readable_id", "program_type") list_filter = ["live", "program_type", "departments"] + inlines = [ProgramContractPageInline] @admin.register(ProgramRun) diff --git a/courses/api.py b/courses/api.py index 6510022b66..56d7f664ce 100644 --- a/courses/api.py +++ b/courses/api.py @@ -6,23 +6,29 @@ import re from collections import namedtuple from datetime import timedelta +from decimal import Decimal from traceback import format_exc from typing import TYPE_CHECKING +from urllib.parse import urljoin +import reversion from django.conf import settings from django.contrib.contenttypes.models import ContentType from django.core.exceptions import ValidationError from django.db import IntegrityError, transaction from django.db.models import Q +from django_countries import countries from mitol.common.utils import now_in_utc from mitol.common.utils.collections import ( first_or_none, has_equal_properties, ) +from opaque_keys.edx.keys import CourseKey from requests.exceptions import ConnectionError as RequestsConnectionError from requests.exceptions import HTTPError from rest_framework.status import HTTP_404_NOT_FOUND +from cms.api import create_default_courseware_page from courses import mail_api from courses.constants import ( COURSE_KEY_PATTERN, @@ -31,11 +37,13 @@ PROGRAM_TEXT_ID_PREFIX, ) from courses.models import ( + BlockedCountry, Course, CourseRun, CourseRunCertificate, CourseRunEnrollment, CourseRunGrade, + Department, PaidCourseRun, Program, ProgramCertificate, @@ -49,10 +57,11 @@ is_grade_valid, is_letter_grade_valid, ) -from ecommerce.models import OrderStatus +from ecommerce.models import OrderStatus, Product from openedx.api import ( create_edx_course_mode, enroll_in_edx_course_runs, + get_edx_api_course_detail_client, get_edx_api_course_list_client, get_edx_api_course_mode_client, get_edx_course_modes, @@ -1096,3 +1105,165 @@ def resolve_courseware_object_from_id( CourseRun.objects.filter(courseware_id=courseware_id).first() or Course.objects.filter(readable_id=courseware_id).first() ) + + +def import_courserun_from_edx( # noqa: C901, PLR0913 + course_key: str, + *, + live: bool = False, + use_specific_course: str | None = None, + departments: list[Department | str] | None = None, + create_depts: bool = False, + block_countries: list[str] | None = None, + price: Decimal | None = None, + create_cms_page: bool = False, + publish_cms_page: bool = False, + include_in_learn_catalog: bool = False, + ingest_content_files_for_ai: bool = False, +): + """ + Import a course run from edX. + + This checks for the course run in edX, and imports it if it exists. If + necessary, it creates: + - The underlying Course object + - Any necessary Departments (if the flag for this is set) + - A Product (if a price is set) + - A CMS page (if the flag is set; will publish if the flag is set) + + It will also add the blocked countries for the course if those are set. + + If the course does need to be created, departments must be supplied. The + function will throw an AttributeError if there aren't any. An empty list can + be supplied if the course exists. + + A specific course can be specified. This is to cover cases where the run you + may want to import doesn't technically "live" under the course that is + specified in its key. (This happens with B2B/UAI courses - the course will + be in the UAI_SOURCE org, but the runs all use an org that matches up with + the contract on the MITx Online side. E.g. course-v1:UAI_SOURCE+UAI.0 is the + root course for course-v1:UAI_MIT+UAI.0+2025_C999.) + + If the specified course run exists, then this won't do anything. There are + separate processes to update an existing run from edX data. + + This will not add the course to any programs - you can do that later. + + Args: + - course_key (str): The readable ID of the course run to import. + - live (bool): Make the new course run live, and the course if one is created. + - use_specific_course (str|None): Readable ID of a specific course to use as the base course. + - departments (list[Department | str] | None): Departments to add to the new course. Only required if creating a new course. + - create_depts (bool): Create departments. + - block_countries (list[str] | None): Country codes to add to the block list for the course. + - price (Decimal | None): Price for the course product, if any. If no price is set, a product won't be created. + - create_cms_page (bool): Create a CMS page for the course. Only applies if a course is being created. + - publish_cms_page (bool): Publish the new CMS page. Only takes effect if creating a CMS page. + - include_in_learn_catalog (bool): Set the "include_in_learn_catalog" flag on the new page. + - ingest_content_files_for_ai (bool): Set the "ingest_content_files_for_ai" flag on the new page. + Returns: + tuple of (CourseRun, CoursePage|None, Product|None) - relevant objects for the imported run + """ + + if CourseRun.objects.filter(courseware_id=course_key).exists(): + return False + + processed_course_key = CourseKey(course_key) + + edx_course_detail = get_edx_api_course_detail_client() + + edx_course_run = edx_course_detail.get_detail( + course_id=course_key, + username=settings.OPENEDX_SERVICE_WORKER_USERNAME, + ) + + processed_run_key = CourseKey(edx_course_run.course_id) + + if use_specific_course: + root_course = Course.objects.get(readable_id=use_specific_course) + else: + # edX doesn't have the concept of a "Course", so there's not an opaque + # key type for it. + root_course_id = ( + f"course-v1:{processed_course_key.org}+{processed_course_key.course}" + ) + root_course = Course.objects.filter(readable_id=root_course_id).first() + + if not root_course: + if not departments or len(departments) == 0: + msg = f"Course {root_course_id} would be created, so departments are required." + raise AttributeError(msg) + + root_course = Course.objects.create( + readable_id=root_course_id, + title=edx_course_run.name, + live=live, + ) + + for department in departments: + if isinstance(department, str) and create_depts: + dept = Department.objects.get_or_create(name=department) + else: + dept = department + + root_course.departments.add(dept) + + new_run = CourseRun.objects.create( + course=root_course, + run_tag=processed_run_key.run, + courseware_id=edx_course_run.course_id, + start_date=edx_course_run.start, + end_date=edx_course_run.end, + enrollment_start=edx_course_run.enrollment_start, + enrollment_end=edx_course_run.enrollment_end, + title=edx_course_run.name, + live=live, + is_self_paced=edx_course_run.is_self_paced(), + courseware_url_path=urljoin( + settings.OPENEDX_COURSE_BASE_URL, + f"/{edx_course_run.course_id}/course", + ), + ) + + course_page = None + if create_cms_page: + course_page = create_default_courseware_page( + courseware=new_run.course, + live=publish_cms_page, + ) + + course_page.ingest_content_files_for_ai = ingest_content_files_for_ai + course_page.include_in_learn_catalog = include_in_learn_catalog + course_page.save() + + course_product = None + if price: + content_type = ContentType.objects.get_for_model(CourseRun) + with reversion.create_revision(): + course_product, _ = Product.objects.update_or_create( + content_type=content_type, + object_id=new_run.id, + defaults={ + "price": Decimal(price), + "description": new_run.courseware_id, + "is_active": True, + }, + ) + + course_product.save() + + if block_countries: + for block_country in block_countries: + country_code = countries.by_name(block_country) + if not country_code: + country_name = countries.countries.get(block_country, None) + country_code = block_country if country_name else None + else: + country_name = block_country + + if country_code: + BlockedCountry.objects.get_or_create( + course=new_run.course, country=country_code + ) + + return (new_run, course_page, course_product) diff --git a/courses/factories.py b/courses/factories.py index 70e5144774..ece9db50b7 100644 --- a/courses/factories.py +++ b/courses/factories.py @@ -137,6 +137,7 @@ class CourseRunFactory(DjangoModelFactory): live = True b2b_contract = None + is_source_run = False class Meta: model = CourseRun diff --git a/courses/migrations/0069_add_source_run_flag.py b/courses/migrations/0069_add_source_run_flag.py new file mode 100644 index 0000000000..a56555447a --- /dev/null +++ b/courses/migrations/0069_add_source_run_flag.py @@ -0,0 +1,20 @@ +# Generated by Django 4.2.25 on 2025-10-14 15:40 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + dependencies = [ + ("courses", "0068_department_related_names"), + ] + + operations = [ + migrations.AddField( + model_name="courserun", + name="is_source_run", + field=models.BooleanField( + default=False, + help_text='Designate this run as a "source" run for contract re-runs of the course.', + ), + ), + ] diff --git a/courses/models.py b/courses/models.py index 9790cf4dbd..1f2e5925a6 100644 --- a/courses/models.py +++ b/courses/models.py @@ -1104,6 +1104,10 @@ class CourseRun(TimestampedModel): on_delete=models.DO_NOTHING, related_name="course_runs", ) + is_source_run = models.BooleanField( + default=False, + help_text='Designate this run as a "source" run for contract re-runs of the course.', + ) class Meta: unique_together = ("course", "courseware_id", "run_tag") diff --git a/courses/views/v2/views_test.py b/courses/views/v2/views_test.py index 96a862855f..62f922492b 100644 --- a/courses/views/v2/views_test.py +++ b/courses/views/v2/views_test.py @@ -628,7 +628,7 @@ def test_program_filter_for_b2b_org(user, mock_course_run_clone): regular_program.save() b2b_course = CourseFactory.create() - CourseRunFactory.create(course=b2b_course) + CourseRunFactory.create(course=b2b_course, is_source_run=True) b2b_program.add_requirement(b2b_course) b2b_program.add_requirement(regular_course) b2b_program.b2b_only = True From 549cab38f51a20a3b37151198328922f49dd072f Mon Sep 17 00:00:00 2001 From: Chris Chudzicki Date: Thu, 16 Oct 2025 15:06:35 -0400 Subject: [PATCH 2/4] Change v2 req_tree to not return root node, improve OpenAPI spec (#3010) --- courses/serializers/v1/programs_test.py | 63 ++++++++++++- courses/serializers/v2/programs.py | 20 ++++- courses/serializers/v2/programs_test.py | 113 ++++++++++++++++++++++++ openapi/specs/v0.yaml | 17 ++-- openapi/specs/v1.yaml | 17 ++-- openapi/specs/v2.yaml | 17 ++-- 6 files changed, 228 insertions(+), 19 deletions(-) diff --git a/courses/serializers/v1/programs_test.py b/courses/serializers/v1/programs_test.py index f516ca12ba..a6d116f81f 100644 --- a/courses/serializers/v1/programs_test.py +++ b/courses/serializers/v1/programs_test.py @@ -1,5 +1,6 @@ from datetime import timedelta from decimal import Decimal +from unittest.mock import ANY import pytest from django.utils.timezone import now @@ -108,7 +109,7 @@ def test_serialize_program(mock_context, remove_tree, program_with_empty_require ) -def test_program_requirement_tree_serializer_valid(): +def test_program_requirement_tree_serializer_save(): """Verify that the ProgramRequirementTreeSerializer validates data""" program = ProgramFactory.create() course1, course2, course3 = CourseFactory.create_batch(3) @@ -142,6 +143,66 @@ def test_program_requirement_tree_serializer_valid(): serializer.is_valid(raise_exception=True) serializer.save() + root.refresh_from_db() + assert ProgramRequirementTreeSerializer(instance=root).data == [ + { + "data": { + "node_type": "program_root", + "operator": None, + "operator_value": None, + "program": program.id, + "course": None, + "required_program": None, + "title": "", + "elective_flag": False, + }, + "id": ANY, + "children": [ + { + "data": { + "node_type": "operator", + "operator": "all_of", + "operator_value": None, + "program": program.id, + "course": None, + "required_program": None, + "title": "Required Courses", + "elective_flag": False, + }, + "id": ANY, + "children": [ + { + "data": { + "node_type": "course", + "operator": None, + "operator_value": None, + "program": program.id, + "course": course1.id, + "required_program": None, + "title": None, + "elective_flag": False, + }, + "id": ANY, + } + ], + }, + { + "data": { + "node_type": "operator", + "operator": "min_number_of", + "operator_value": "1", + "program": program.id, + "course": None, + "required_program": None, + "title": "Elective Courses", + "elective_flag": False, + }, + "id": ANY, + }, + ], + } + ] + def test_program_requirement_deletion(): """Verify that saving the requirements for one program doesn't affect other programs""" diff --git a/courses/serializers/v2/programs.py b/courses/serializers/v2/programs.py index b1d376b94b..fba49ab996 100644 --- a/courses/serializers/v2/programs.py +++ b/courses/serializers/v2/programs.py @@ -24,12 +24,18 @@ class ProgramRequirementDataSerializer(StrictFieldsSerializer): node_type = serializers.ChoiceField( choices=( - ProgramRequirementNodeType.OPERATOR, ProgramRequirementNodeType.COURSE, + ProgramRequirementNodeType.PROGRAM, + ProgramRequirementNodeType.OPERATOR, ) ) - course = serializers.CharField(source="course_id", allow_null=True, default=None) - program = serializers.CharField(source="program_id", required=False) + course = serializers.IntegerField(source="course_id", allow_null=True, default=None) + program = serializers.IntegerField( + source="program_id", allow_null=True, default=None + ) + required_program = serializers.IntegerField( + source="required_program_id", allow_null=True, default=None + ) title = serializers.CharField(allow_null=True, default=None) operator = serializers.CharField(allow_null=True, default=None) operator_value = serializers.CharField(allow_null=True, default=None) @@ -79,6 +85,14 @@ class Meta: class ProgramRequirementTreeSerializer(BaseProgramRequirementTreeSerializer): child = ProgramRequirementSerializer() + @property + def data(self): + """Return children of root node directly, or empty array if no children""" + # BaseProgramRequirementTreeSerializer overrides the data property + # to bypass to_implementation, so we do also. + full_data = super().data + return full_data[0].get("children", []) if full_data else [] + @extend_schema_serializer( component_name="V2ProgramSerializer", diff --git a/courses/serializers/v2/programs_test.py b/courses/serializers/v2/programs_test.py index 66036ef0c3..2324fae1b3 100644 --- a/courses/serializers/v2/programs_test.py +++ b/courses/serializers/v2/programs_test.py @@ -1,4 +1,5 @@ from datetime import timedelta +from unittest.mock import ANY import pytest from django.utils.timezone import now @@ -6,6 +7,7 @@ from cms.factories import CoursePageFactory from cms.serializers import ProgramPageSerializer from courses.factories import ( # noqa: F401 + CourseFactory, CourseRunFactory, ProgramCollectionFactory, ProgramFactory, @@ -144,3 +146,114 @@ def test_serialize_program( "max_price": program_with_empty_requirements.page.max_price, }, ) + + +def test_program_requirement_tree_serializer_save(): + """Verify that the ProgramRequirementTreeSerializer validates data""" + program = ProgramFactory.create() + course1, course2, course3 = CourseFactory.create_batch(3) + root = program.requirements_root + + serializer = ProgramRequirementTreeSerializer( + instance=root, + data=[ + { + "data": { + "node_type": "operator", + "title": "Required Courses", + "operator": "all_of", + }, + "children": [ + {"id": None, "data": {"node_type": "course", "course": course1.id}} + ], + }, + { + "data": { + "node_type": "operator", + "title": "Elective Courses", + "operator": "min_number_of", + "operator_value": "1", + }, + "children": [ + {"id": None, "data": {"node_type": "course", "course": course2.id}}, + {"id": None, "data": {"node_type": "course", "course": course3.id}}, + ], + }, + ], + context={"program": program}, + ) + serializer.is_valid(raise_exception=True) + serializer.save() + + root.refresh_from_db() + assert ProgramRequirementTreeSerializer(instance=root).data == [ + { + "data": { + "node_type": "operator", + "operator": "all_of", + "operator_value": None, + "program": program.id, + "course": None, + "required_program": None, + "title": "Required Courses", + "elective_flag": False, + }, + "id": ANY, + "children": [ + { + "data": { + "node_type": "course", + "operator": None, + "operator_value": None, + "program": program.id, + "course": course1.id, + "required_program": None, + "title": None, + "elective_flag": False, + }, + "id": ANY, + } + ], + }, + { + "data": { + "node_type": "operator", + "operator": "min_number_of", + "operator_value": "1", + "program": program.id, + "course": None, + "required_program": None, + "title": "Elective Courses", + "elective_flag": False, + }, + "id": ANY, + "children": [ + { + "data": { + "node_type": "course", + "operator": None, + "operator_value": None, + "program": program.id, + "course": course2.id, + "required_program": None, + "title": None, + "elective_flag": False, + }, + "id": ANY, + }, + { + "data": { + "node_type": "course", + "operator": None, + "operator_value": None, + "program": program.id, + "course": course3.id, + "required_program": None, + "title": None, + "elective_flag": False, + }, + "id": ANY, + }, + ], + }, + ] diff --git a/openapi/specs/v0.yaml b/openapi/specs/v0.yaml index 756ee79630..79bba4c18c 100644 --- a/openapi/specs/v0.yaml +++ b/openapi/specs/v0.yaml @@ -4884,10 +4884,14 @@ components: node_type: $ref: '#/components/schemas/V2ProgramRequirementDataNodeTypeEnum' course: - type: string + type: integer nullable: true program: - type: string + type: integer + nullable: true + required_program: + type: integer + nullable: true title: type: string nullable: true @@ -4905,15 +4909,18 @@ components: - node_type V2ProgramRequirementDataNodeTypeEnum: enum: - - operator - course + - program + - operator type: string description: |- - * `operator` - operator * `course` - course + * `program` - program + * `operator` - operator x-enum-descriptions: - - operator - course + - program + - operator YearsExperienceEnum: enum: - 2 diff --git a/openapi/specs/v1.yaml b/openapi/specs/v1.yaml index 8812be584c..74da736ec6 100644 --- a/openapi/specs/v1.yaml +++ b/openapi/specs/v1.yaml @@ -4884,10 +4884,14 @@ components: node_type: $ref: '#/components/schemas/V2ProgramRequirementDataNodeTypeEnum' course: - type: string + type: integer nullable: true program: - type: string + type: integer + nullable: true + required_program: + type: integer + nullable: true title: type: string nullable: true @@ -4905,15 +4909,18 @@ components: - node_type V2ProgramRequirementDataNodeTypeEnum: enum: - - operator - course + - program + - operator type: string description: |- - * `operator` - operator * `course` - course + * `program` - program + * `operator` - operator x-enum-descriptions: - - operator - course + - program + - operator YearsExperienceEnum: enum: - 2 diff --git a/openapi/specs/v2.yaml b/openapi/specs/v2.yaml index 037505f2d1..181b69c01d 100644 --- a/openapi/specs/v2.yaml +++ b/openapi/specs/v2.yaml @@ -4884,10 +4884,14 @@ components: node_type: $ref: '#/components/schemas/V2ProgramRequirementDataNodeTypeEnum' course: - type: string + type: integer nullable: true program: - type: string + type: integer + nullable: true + required_program: + type: integer + nullable: true title: type: string nullable: true @@ -4905,15 +4909,18 @@ components: - node_type V2ProgramRequirementDataNodeTypeEnum: enum: - - operator - course + - program + - operator type: string description: |- - * `operator` - operator * `course` - course + * `program` - program + * `operator` - operator x-enum-descriptions: - - operator - course + - program + - operator YearsExperienceEnum: enum: - 2 From 786e840a47a93d87f4eb64bc026d2b691e3113d1 Mon Sep 17 00:00:00 2001 From: James Kachel Date: Thu, 16 Oct 2025 15:00:53 -0500 Subject: [PATCH 3/4] B2B: Update Keycloak org membership when a user redeems an enrollment code (#2989) --- b2b/admin.py | 1 + b2b/api.py | 220 ++++++++--- b2b/api_test.py | 350 ++++++++++++++++-- b2b/constants.py | 59 ++- b2b/factories.py | 8 +- b2b/management/commands/b2b_contract.py | 13 +- ...9_contractpage_membership_type_and_more.py | 96 +++++ b2b/models.py | 101 ++++- b2b/models_test.py | 5 +- b2b/serializers/v0/__init__.py | 71 +++- b2b/views/v0/__init__.py | 5 +- b2b/views/v0/views_test.py | 36 +- courses/api.py | 6 +- courses/api_test.py | 5 +- courses/views/v2/__init__.py | 2 +- courses/views/v2/views_test.py | 7 + ecommerce/discounts_test.py | 8 +- openapi/specs/v0.yaml | 61 +-- openapi/specs/v1.yaml | 61 +-- openapi/specs/v2.yaml | 61 +-- pytest.ini | 5 + users/admin.py | 9 + .../migrations/0038_user_b2b_organizations.py | 23 ++ users/models.py | 25 +- users/serializers.py | 44 +-- users/views_test.py | 2 + 26 files changed, 981 insertions(+), 303 deletions(-) create mode 100644 b2b/migrations/0009_contractpage_membership_type_and_more.py create mode 100644 users/migrations/0038_user_b2b_organizations.py diff --git a/b2b/admin.py b/b2b/admin.py index 3858c3e198..8a6f4ec58e 100644 --- a/b2b/admin.py +++ b/b2b/admin.py @@ -119,6 +119,7 @@ class ContractPageAdmin(ReadOnlyModelAdmin): "title", "description", "integration_type", + "membership_type", "contract_start", "contract_end", "max_learners", diff --git a/b2b/api.py b/b2b/api.py index a95f27d03f..03aa715f10 100644 --- a/b2b/api.py +++ b/b2b/api.py @@ -19,13 +19,18 @@ from b2b.constants import ( B2B_RUN_TAG_FORMAT, - CONTRACT_INTEGRATION_SSO, + CONTRACT_MEMBERSHIP_AUTOS, ORG_KEY_MAX_LENGTH, ) from b2b.exceptions import SourceCourseIncompleteError, TargetCourseRunExistsError from b2b.keycloak_admin_api import KCAM_ORGANIZATIONS, get_keycloak_model from b2b.keycloak_admin_dataclasses import OrganizationRepresentation -from b2b.models import ContractPage, OrganizationIndexPage, OrganizationPage +from b2b.models import ( + ContractPage, + OrganizationIndexPage, + OrganizationPage, + UserOrganization, +) from cms.api import get_home_page from courses.constants import UAI_COURSEWARE_ID_PREFIX from courses.models import Course, CourseRun @@ -596,7 +601,10 @@ def ensure_enrollment_codes_exist(contract: ContractPage): """ log.info("Checking enrollment codes for contract %s", contract) - if contract.integration_type == "sso" and not contract.enrollment_fixed_price: + if ( + contract.integration_type in CONTRACT_MEMBERSHIP_AUTOS + or contract.membership_type in CONTRACT_MEMBERSHIP_AUTOS + ) and not contract.enrollment_fixed_price: # SSO contracts w/out price don't need discounts. return _handle_sso_free_contract(contract) @@ -762,19 +770,30 @@ def create_b2b_enrollment(request, product: Product): def reconcile_user_orgs(user, organizations): """ - Reconcile the specified users with the provided organization list. - - When we get a list of organizations from an authoritative source, we need to - be able to parse that list and make sure the user's org attachments match. - This will pull the contracts that the user belongs to that are also - SSO-enabled, and will remove the user from the contract if they're not - supposed to be in them. It will also add the user to any SSO-enabled contract - that the org has. - - This only considers contracts that are SSO-enabled and zero-cost. If the - contract is seat limited, we will only add the user if there's room. - (If there isn't, we will log an error.) Only SSO-enabled contracts are - considered; any that the user is in that aren't SSO-enabled will be left alone. + Reconcile the specified user with the provided organization list. + + When we get a list of organizations from a source (so, in the user payload + from APISIX) for a particular user, we need to ensure that the user's + organization membership in MITx Online matches up with what we're given. In + addition, once we've done that, we need to ensure they're also in the + contracts that are marked as "managed". If the user is in an organization + that isn't in the list we've received, we need to remove them; in addition, + they should be removed from any "managed" contracts for the org they're in + as well. + + There is a special case where the user may be in an organization that isn't + represented in the payload we're given. This happens when the user uses an + enrollment code. We update the org membership here and in Keycloak, but the + payload from APISIX won't include their updated org membership until their + APISIX session expires. We have a flag on the many-to-many table that + indicates that we should leave those memberships alone - otherwise, we'll + inadvertently add them to the org and then immediately remove them. (Once + the org _does_ show up in the list, we should clear the flag.) + + We cache the user's org membership in redis to save some hits to the + database. This gets hit on every authenticated request, so probably good to + try to keep the query count low. The cache is a list of tuples of (org_uuid, + not_expected_in_payload). If the user is enrolled in any courses that are in a contract they'll be removed from, they will be left there. Not real sure what we should do in @@ -791,61 +810,70 @@ def reconcile_user_orgs(user, organizations): user_org_cache_key = f"org-membership-cache-{user.id}" cached_org_membership = caches["redis"].get(user_org_cache_key, False) - if cached_org_membership and sorted(cached_org_membership) == sorted(organizations): - log.info("reconcile_user_orgs: skipping reconcilation for %s", user.id) - return ( - 0, - 0, - ) + if cached_org_membership: + cached_expected_org_membership = [ + str(org_id) + for org_id, not_expected_in_payload in cached_org_membership + if not_expected_in_payload + ] + + if sorted(cached_expected_org_membership) == sorted(organizations): + log.info( + "reconcile_user_orgs: everything OK, skipping reconcilation for %s", + user.id, + ) + return ( + 0, + 0, + ) log.info("reconcile_user_orgs: running reconcilation for %s", user.id) - user_contracts_qs = user.b2b_contracts.filter( - integration_type=CONTRACT_INTEGRATION_SSO - ) + # we've checked the cached org membership, so now figure out what orgs + # we're in but aren't in the list, and vice versa - if len(organizations) == 0: - # User has no orgs, so we should clear them from all SSO contracts. - contracts_to_remove = user_contracts_qs.all() - [user.b2b_contracts.remove(contract) for contract in contracts_to_remove] - user.save() - return (0, len(contracts_to_remove)) + orgs_to_add = OrganizationPage.objects.filter( + Q(sso_organization_id__in=organizations) & ~Q(organization_users__user=user) + ).filter(sso_organization_id__isnull=False) - orgs = OrganizationPage.objects.filter(sso_organization_id__in=organizations).all() - no_orgs = OrganizationPage.objects.exclude( - sso_organization_id__in=organizations - ).all() + orgs_to_remove = UserOrganization.objects.filter( + ~Q(organization__sso_organization_id__in=organizations) + & Q(user=user, keep_until_seen=False) + ).filter(organization__sso_organization_id__isnull=False) - contracts_to_remove = user_contracts_qs.filter(organization__in=no_orgs).all() + for add_org in orgs_to_add: + # add org, add contracts, clear flag if we need to + UserOrganization.objects.update_or_create( + user=user, + organization=add_org, + defaults={"keep_until_seen": False}, + ) - if contracts_to_remove.count() > 0: - [ - user.b2b_contracts.remove(contract_to_remove) - for contract_to_remove in contracts_to_remove - ] + add_org.add_user_contracts(user) + log.info("reconcile_user_orgs: added user %s to org %s", user.id, add_org) - contracts_to_add = ( - ContractPage.objects.filter( - integration_type=CONTRACT_INTEGRATION_SSO, organization__in=orgs + for remove_org in orgs_to_remove: + # remove org, remove contracts + remove_org.organization.remove_user_contracts(user) + log.info( + "reconcile_user_orgs: removed user %s from org %s", user.id, remove_org ) - .exclude(pk__in=user_contracts_qs.all().values_list("id", flat=True)) - .all() - ) - - if contracts_to_add.count() > 0: - [ - user.b2b_contracts.add(contract_to_add) - for contract_to_add in contracts_to_add - ] + remove_org.delete() - user.save() user.refresh_from_db() - orgs = [str(org_id) for org_id in user.b2b_organization_sso_ids] + orgs = [ + (str(org.organization.sso_organization_id), not org.keep_until_seen) + for org in user.user_organizations.all() + ] + + user.user_organizations.filter( + organization__sso_organization_id__in=organizations, keep_until_seen=True + ).update(keep_until_seen=False) user_org_cache_key = f"org-membership-cache-{user.id}" caches["redis"].set(user_org_cache_key, sorted(orgs)) - return (len(contracts_to_add), len(contracts_to_remove)) + return (len(orgs_to_add), len(orgs_to_remove)) def reconcile_single_keycloak_org(keycloak_org: OrganizationRepresentation): @@ -925,3 +953,83 @@ def reconcile_keycloak_orgs(): ) return (created_count, updated_count) + + +def add_user_org_membership(org, user): + """ + Add a given user to a Keycloak organization. + + If we're adding a user to a contract, and they're not in that contract's + organization, we need to do that and update Keycloak as well. Since the user + won't have the org in their user data list initially, we'll also need to + flag the membership so we don't remove it immediately later in the + middleware. + + Args: + - org (OrganizationPage): The organization to add the user to. + - user (User): The user to add to the organization. + Returns: + - bool: True if the user was added, False otherwise. + """ + + org_model = get_keycloak_model(OrganizationRepresentation, "organizations") + + kc_org = org_model.get(org.sso_organization_id) + + if not kc_org: + log.warning("No Keycloak organization found for %s", org.sso_organization_id) + return False + + return org_model.associate("members", org.sso_organization_id, user.global_id) + + +def process_add_org_membership(user, organization, *, keep_until_seen=False): + """ + Add a user to an org, and kick off contract processing. + + This allows us to manage UserOrganization records without necessarily + being forced to process contract memberships at the same time. + + Args: + - user (users.models.User): the user to add + - organization (b2b.models.OrganizationPage): the organization to add the user to + - keep_until_seen (bool): if True, the user will be kept in the org until the + organization is seen in their SSO data. + """ + + obj, created = UserOrganization.objects.get_or_create( + user=user, + organization=organization, + ) + if created: + obj.keep_until_seen = keep_until_seen + obj.save() + try: + organization.attach_user(user) + except ConnectionError: + log.exception( + "Could not attach %s to Keycloak org for %s", user, organization + ) + organization.add_user_contracts(user) + + return obj + + +def process_remove_org_membership(user, organization): + """ + Remove a user from an org, and kick off contract processing. + + Other side of the process_add_org_membership function - removes the membership + and associated managed contracts. + + Args: + - user (users.models.User): the user to remove + - organization (b2b.models.OrganizationPage): the organization to remove the user from + """ + + organization.remove_user_contracts(user) + + UserOrganization.objects.filter( + user=user, + organization=organization, + ).get().delete() diff --git a/b2b/api_test.py b/b2b/api_test.py index 55c87f31f7..b3112e98ab 100644 --- a/b2b/api_test.py +++ b/b2b/api_test.py @@ -18,6 +18,8 @@ create_contract_run, ensure_enrollment_codes_exist, get_active_contracts_from_basket_items, + process_add_org_membership, + process_remove_org_membership, reconcile_keycloak_orgs, reconcile_single_keycloak_org, reconcile_user_orgs, @@ -25,13 +27,18 @@ ) from b2b.constants import ( B2B_RUN_TAG_FORMAT, - CONTRACT_INTEGRATION_NONSSO, - CONTRACT_INTEGRATION_SSO, + CONTRACT_MEMBERSHIP_NONSSO, + CONTRACT_MEMBERSHIP_SSO, ) from b2b.exceptions import SourceCourseIncompleteError, TargetCourseRunExistsError -from b2b.models import OrganizationIndexPage, OrganizationPage +from b2b.factories import ContractPageFactory, OrganizationPageFactory +from b2b.models import OrganizationIndexPage, OrganizationPage, UserOrganization from courses.constants import UAI_COURSEWARE_ID_PREFIX -from courses.factories import CourseFactory, CourseRunFactory +from courses.factories import ( + CourseFactory, + CourseRunEnrollmentFactory, + CourseRunFactory, +) from courses.models import CourseRunEnrollment from ecommerce.api_test import create_basket from ecommerce.constants import REDEMPTION_TYPE_ONE_TIME, REDEMPTION_TYPE_UNLIMITED @@ -59,6 +66,13 @@ ] +@pytest.fixture +def mocked_b2b_org_attach(mocker): + """Mock the org attachment call.""" + + return mocker.patch("b2b.api.add_user_org_membership", return_value=True) + + @pytest.mark.parametrize( ( "has_start", @@ -260,9 +274,12 @@ def test_ensure_enrollment_codes( # noqa: PLR0913 assert_price = price if price else Decimal(0) contract = factories.ContractPageFactory( - integration_type=CONTRACT_INTEGRATION_SSO + integration_type=CONTRACT_MEMBERSHIP_SSO if is_sso - else CONTRACT_INTEGRATION_NONSSO, + else CONTRACT_MEMBERSHIP_NONSSO, + membership_type=CONTRACT_MEMBERSHIP_SSO + if is_sso + else CONTRACT_MEMBERSHIP_NONSSO, enrollment_fixed_price=price, max_learners=max_learners, ) @@ -305,7 +322,10 @@ def test_ensure_enrollment_codes( # noqa: PLR0913 contract.enrolment_fixed_price = None if update_sso: contract.integration_type = ( - CONTRACT_INTEGRATION_NONSSO if is_sso else CONTRACT_INTEGRATION_SSO + CONTRACT_MEMBERSHIP_NONSSO if is_sso else CONTRACT_MEMBERSHIP_SSO + ) + contract.membership_type = ( + CONTRACT_MEMBERSHIP_NONSSO if is_sso else CONTRACT_MEMBERSHIP_SSO ) contract.save() @@ -361,7 +381,8 @@ def test_create_b2b_enrollment( # noqa: PLR0913, C901, PLR0915 settings.OPENEDX_SERVICE_WORKER_USERNAME = "a username" contract = factories.ContractPageFactory.create( - integration_type=CONTRACT_INTEGRATION_SSO, + integration_type=CONTRACT_MEMBERSHIP_SSO, + membership_type=CONTRACT_MEMBERSHIP_SSO, enrollment_fixed_price=Decimal(0) if price_is_zero else FAKE.pydecimal(left_digits=2, right_digits=2, positive=True), @@ -520,44 +541,118 @@ def test_create_contract_run(mocker, source_run_exists, run_exists): mocked_clone_run.assert_called() -def test_b2b_reconcile_user_orgs(): +def test_b2b_reconcile_user_orgs(): # noqa: PLR0915 """Test that we can get a list of B2B orgs from somewhere and fix a user's associations.""" - contracts = factories.ContractPageFactory.create_batch( - 2, integration_type=CONTRACT_INTEGRATION_NONSSO - ) - sso_contracts = factories.ContractPageFactory.create_batch( - 2, integration_type=CONTRACT_INTEGRATION_SSO - ) user = UserFactory.create() + organization_to_add = OrganizationPageFactory.create() + organization_to_ignore = OrganizationPageFactory.create() + organization_to_remove = OrganizationPageFactory.create() + weird_organization = OrganizationPageFactory.create(sso_organization_id=None) assert user.b2b_contracts.count() == 0 + assert user.b2b_organizations.count() == 0 - user.b2b_contracts.add(contracts[0]) - user.b2b_contracts.add(contracts[1]) - user.b2b_contracts.add(sso_contracts[0]) - user.save() + # Step 1: pass in an org to a user that's not in anything + # We should get back one addition, which is the org we're adding - assert user.b2b_contracts.count() == 3 + added, removed = reconcile_user_orgs( + user, [organization_to_add.sso_organization_id] + ) - sso_required_org = sso_contracts[1].organization.sso_organization_id + assert added == 1 + assert removed == 0 - added, removed = reconcile_user_orgs(user, [sso_required_org]) + user.refresh_from_db() + assert user.b2b_organizations.count() == 1 + assert user.b2b_organizations.filter(pk=organization_to_add.id).exists() + assert not user.b2b_organizations.filter(pk=organization_to_ignore.id).exists() + assert not user.b2b_organizations.filter(pk=organization_to_remove.id).exists() - assert added == 1 + # Step 2: Add an org through a back channel, and then reconcile + # The org should be removed + + UserOrganization.objects.create( + user=user, organization=organization_to_remove, keep_until_seen=False + ) + + assert user.b2b_organizations.count() == 2 + + added, removed = reconcile_user_orgs( + user, [organization_to_add.sso_organization_id] + ) + + assert added == 0 assert removed == 1 user.refresh_from_db() - assert user.b2b_contracts.count() == 3 - assert ( - user.b2b_contracts.filter( - organization__sso_organization_id=sso_contracts[ - 0 - ].organization.sso_organization_id - ).count() - == 0 + assert user.b2b_organizations.count() == 1 + assert user.b2b_organizations.filter(pk=organization_to_add.id).exists() + assert not user.b2b_organizations.filter(pk=organization_to_ignore.id).exists() + assert not user.b2b_organizations.filter(pk=organization_to_remove.id).exists() + + # Step 3: Add the remove org, but set the flag so it should be kept now. + + UserOrganization.objects.create( + user=user, organization=organization_to_remove, keep_until_seen=True + ) + + assert user.b2b_organizations.count() == 2 + + added, removed = reconcile_user_orgs( + user, [organization_to_add.sso_organization_id] + ) + + assert added == 0 + assert removed == 0 + + user.refresh_from_db() + assert user.b2b_organizations.count() == 2 + assert user.b2b_organizations.filter(pk=organization_to_add.id).exists() + assert not user.b2b_organizations.filter(pk=organization_to_ignore.id).exists() + assert user.b2b_organizations.filter(pk=organization_to_remove.id).exists() + + # Step 3.5: now reconcile with the remove org, we should clear the flag + + added, removed = reconcile_user_orgs( + user, + [ + organization_to_add.sso_organization_id, + organization_to_remove.sso_organization_id, + ], + ) + + assert added == 0 + assert removed == 0 + + user.refresh_from_db() + assert user.b2b_organizations.count() == 2 + assert user.b2b_organizations.filter(pk=organization_to_add.id).exists() + assert not user.b2b_organizations.filter(pk=organization_to_ignore.id).exists() + assert user.b2b_organizations.filter(pk=organization_to_remove.id).exists() + + # Step 4: add the weird org that doesn't have a UUID + # Legacy non-manged orgs won't have a UUID, so we should leave them alone + + UserOrganization.objects.create( + user=user, organization=weird_organization, keep_until_seen=False + ) + + added, removed = reconcile_user_orgs( + user, + [ + organization_to_add.sso_organization_id, + organization_to_remove.sso_organization_id, + ], ) + assert added == 0 + assert removed == 0 + + user.refresh_from_db() + assert user.b2b_organizations.count() == 3 + assert user.b2b_organizations.filter(pk=weird_organization.id).exists() + @pytest.mark.parametrize( "update_an_org", @@ -650,3 +745,196 @@ def test_reconcile_bad_keycloak_org(mocker): page.save() assert "Organization with this Org key already exists." in str(exc) + + +def test_user_add_b2b_org(mocked_b2b_org_attach): + """Ensure adding a user to an organization works as expected.""" + + orgs = OrganizationPageFactory.create_batch(2) + user = UserFactory.create() + + # New-style ones + contract_auto = ContractPageFactory.create( + organization=orgs[0], + membership_type="auto", + integration_type="auto", + title="Contract Auto", + name="Contract Auto", + ) + contract_managed = ContractPageFactory.create( + organization=orgs[0], + membership_type="managed", + integration_type="managed", + title="Contract Managed", + name="Contract Managed", + ) + contract_code = ContractPageFactory.create( + organization=orgs[0], + membership_type="code", + integration_type="code", + title="Contract Enrollment Code", + name="Contract Enrollment Code", + ) + # Legacy ones - these will migrate to "managed" and "code" + contract_sso = ContractPageFactory.create( + organization=orgs[0], + membership_type="sso", + integration_type="sso", + title="Contract SSO", + name="Contract SSO", + ) + contract_non_sso = ContractPageFactory.create( + organization=orgs[0], + membership_type="non-sso", + integration_type="non-sso", + title="Contract NonSSO", + name="Contract NonSSO", + ) + + process_add_org_membership(user, orgs[0]) + + # We should now be in the SSO, auto, and managed contracts + # but not the other two. + + user.refresh_from_db() + assert user.b2b_contracts.count() == 3 + assert user.b2b_organizations.filter(pk=orgs[0].id).exists() + assert ( + user.b2b_contracts.filter( + pk__in=[ + contract_auto.id, + contract_sso.id, + contract_managed.id, + ] + ).count() + == 3 + ) + assert ( + user.b2b_contracts.filter( + pk__in=[ + contract_code.id, + contract_non_sso.id, + ] + ).count() + == 0 + ) + + +def test_user_remove_b2b_org(mocked_b2b_org_attach): + """Ensure removing a user from an org also clears the appropriate contracts.""" + + orgs = OrganizationPageFactory.create_batch(2) + user = UserFactory.create() + + # New-style ones + contract_auto = ContractPageFactory.create( + organization=orgs[0], + membership_type="auto", + integration_type="auto", + title="Contract Auto", + name="Contract Auto", + ) + contract_managed = ContractPageFactory.create( + organization=orgs[0], + membership_type="managed", + integration_type="managed", + title="Contract Managed", + name="Contract Managed", + ) + contract_code = ContractPageFactory.create( + organization=orgs[1], + membership_type="code", + integration_type="code", + title="Contract Enrollment Code", + name="Contract Enrollment Code", + ) + # Legacy ones - these will migrate to "managed" and "code" + contract_sso = ContractPageFactory.create( + organization=orgs[0], + membership_type="sso", + integration_type="sso", + title="Contract SSO", + name="Contract SSO", + ) + contract_non_sso = ContractPageFactory.create( + organization=orgs[1], + membership_type="non-sso", + integration_type="non-sso", + title="Contract NonSSO", + name="Contract NonSSO", + ) + + managed_ids = [ + contract_auto.id, + contract_managed.id, + contract_sso.id, + ] + unmanaged_ids = [ + contract_code.id, + contract_non_sso.id, + ] + + process_add_org_membership(user, orgs[0]) + process_add_org_membership(user, orgs[1]) + + user.b2b_contracts.add(contract_code) + user.b2b_contracts.add(contract_non_sso) + user.save() + + user.refresh_from_db() + + assert user.b2b_contracts.count() == 5 + + process_remove_org_membership(user, orgs[1]) + + assert user.b2b_contracts.filter(id__in=managed_ids).count() == 3 + assert user.b2b_contracts.filter(id__in=unmanaged_ids).count() == 0 + + process_remove_org_membership(user, orgs[0]) + + # we should have no contracts now since we're no longer in any orgs + + assert user.b2b_contracts.count() == 0 + + +def test_b2b_contract_removal_keeps_enrollments(mocked_b2b_org_attach): + """Ensure that removing a user from a B2B contract leaves their enrollments alone.""" + + org = OrganizationPageFactory.create() + user = UserFactory.create() + + contract_auto = ContractPageFactory.create( + organization=org, + membership_type="auto", + integration_type="auto", + title="Contract Auto", + name="Contract Auto", + ) + + courserun = CourseRunFactory.create(b2b_contract=contract_auto) + + process_add_org_membership(user, org) + + CourseRunEnrollmentFactory( + user=user, + run=courserun, + ) + + user.refresh_from_db() + + assert courserun.enrollments.filter(user=user).count() == 1 + + process_remove_org_membership(user, org) + + assert courserun.enrollments.filter(user=user).count() == 1 + + +def test_b2b_org_attach_calls_keycloak(mocked_b2b_org_attach): + """Test that attaching a user to an org calls Keycloak successfully.""" + + org = OrganizationPageFactory.create() + user = UserFactory.create() + + process_add_org_membership(user, org) + + mocked_b2b_org_attach.assert_called() diff --git a/b2b/constants.py b/b2b/constants.py index 92871c7993..02c2878d65 100644 --- a/b2b/constants.py +++ b/b2b/constants.py @@ -2,14 +2,59 @@ ORG_INDEX_SLUG = "organizations" -CONTRACT_INTEGRATION_SSO = "sso" -CONTRACT_INTEGRATION_SSO_NAME = "SSO" -CONTRACT_INTEGRATION_NONSSO = "non-sso" -CONTRACT_INTEGRATION_NONSSO_NAME = "Non-SSO" +# Old values which will be removed in a future PR +CONTRACT_MEMBERSHIP_SSO = "sso" +CONTRACT_MEMBERSHIP_SSO_NAME = "SSO" +CONTRACT_MEMBERSHIP_NONSSO = "non-sso" +CONTRACT_MEMBERSHIP_NONSSO_NAME = "Non-SSO" +# New values +CONTRACT_MEMBERSHIP_MANAGED = "managed" +CONTRACT_MEMBERSHIP_MANAGED_NAME = "Managed" +CONTRACT_MEMBERSHIP_CODE = "code" +CONTRACT_MEMBERSHIP_CODE_NAME = "Enrollment Code" +CONTRACT_MEMBERSHIP_AUTO = "auto" +CONTRACT_MEMBERSHIP_AUTO_NAME = "Auto Enrollment" -CONTRACT_INTEGRATION_CHOICES = zip( - [CONTRACT_INTEGRATION_SSO, CONTRACT_INTEGRATION_NONSSO], - [CONTRACT_INTEGRATION_SSO_NAME, CONTRACT_INTEGRATION_NONSSO_NAME], +CONTRACT_MEMBERSHIP_AUTOS = [ + CONTRACT_MEMBERSHIP_AUTO, + CONTRACT_MEMBERSHIP_MANAGED, + CONTRACT_MEMBERSHIP_SSO, +] + +CONTRACT_MEMBERSHIP_CHOICES = zip( + [ + CONTRACT_MEMBERSHIP_SSO, + CONTRACT_MEMBERSHIP_NONSSO, + CONTRACT_MEMBERSHIP_MANAGED, + CONTRACT_MEMBERSHIP_CODE, + CONTRACT_MEMBERSHIP_AUTO, + ], + [ + CONTRACT_MEMBERSHIP_SSO_NAME, + CONTRACT_MEMBERSHIP_NONSSO_NAME, + CONTRACT_MEMBERSHIP_MANAGED_NAME, + CONTRACT_MEMBERSHIP_CODE_NAME, + CONTRACT_MEMBERSHIP_AUTO_NAME, + ], +) +# When the integration_type field is removed, this should be removed as well. +# It is just here so we can have the same choices in both integration_type and +# membership_type fields. (Using the constant for `choices` consumes it.) +CONTRACT_MEMBERSHIP_TYPE_CHOICES = zip( + [ + CONTRACT_MEMBERSHIP_SSO, + CONTRACT_MEMBERSHIP_NONSSO, + CONTRACT_MEMBERSHIP_MANAGED, + CONTRACT_MEMBERSHIP_CODE, + CONTRACT_MEMBERSHIP_AUTO, + ], + [ + CONTRACT_MEMBERSHIP_SSO_NAME, + CONTRACT_MEMBERSHIP_NONSSO_NAME, + CONTRACT_MEMBERSHIP_MANAGED_NAME, + CONTRACT_MEMBERSHIP_CODE_NAME, + CONTRACT_MEMBERSHIP_AUTO_NAME, + ], ) B2B_RUN_TAG_FORMAT = "{year}_C{contract_id}" diff --git a/b2b/factories.py b/b2b/factories.py index 9541639127..811d42e264 100644 --- a/b2b/factories.py +++ b/b2b/factories.py @@ -6,7 +6,7 @@ import wagtail_factories from factory import Factory, LazyAttribute, LazyFunction, SubFactory -from b2b.constants import CONTRACT_INTEGRATION_NONSSO, CONTRACT_INTEGRATION_SSO +from b2b.constants import CONTRACT_MEMBERSHIP_NONSSO, CONTRACT_MEMBERSHIP_SSO from b2b.keycloak_admin_dataclasses import ( OrganizationRepresentation, RealmRepresentation, @@ -56,11 +56,11 @@ class ContractPageFactory(wagtail_factories.PageFactory): organization = SubFactory(OrganizationPageFactory) parent = LazyAttribute(lambda o: o.organization) integration_type = LazyFunction( - lambda: CONTRACT_INTEGRATION_NONSSO + lambda: CONTRACT_MEMBERSHIP_NONSSO if FAKE.boolean() - else CONTRACT_INTEGRATION_SSO + else CONTRACT_MEMBERSHIP_SSO ) - slug = LazyAttribute(lambda _: FAKE.unique.slug()) + membership_type = LazyAttribute(lambda o: o.integration_type) class Meta: model = ContractPage diff --git a/b2b/management/commands/b2b_contract.py b/b2b/management/commands/b2b_contract.py index 31db33f073..96c8991b59 100644 --- a/b2b/management/commands/b2b_contract.py +++ b/b2b/management/commands/b2b_contract.py @@ -6,7 +6,9 @@ from django.core.management import BaseCommand, CommandError from django.db.models import Q -from b2b.constants import CONTRACT_INTEGRATION_NONSSO, CONTRACT_INTEGRATION_SSO +from b2b.constants import ( + CONTRACT_MEMBERSHIP_CHOICES, +) from b2b.models import ContractPage, OrganizationIndexPage, OrganizationPage log = logging.getLogger(__name__) @@ -43,12 +45,8 @@ def add_arguments(self, parser): create_parser.add_argument( "integration_type", type=str, - help="The type of integration for this contract.", - choices=[ - CONTRACT_INTEGRATION_SSO, - CONTRACT_INTEGRATION_NONSSO, - ], - default=CONTRACT_INTEGRATION_NONSSO, + help="The membership type for this contract.", + choices=[value[0] for value in CONTRACT_MEMBERSHIP_CHOICES], ) create_parser.add_argument( "--description", @@ -197,6 +195,7 @@ def handle_create(self, *args, **kwargs): # noqa: ARG002 name=contract_name, description=description or "", integration_type=integration_type, + membership_type=integration_type, organization=org, contract_start=start_date, contract_end=end_date, diff --git a/b2b/migrations/0009_contractpage_membership_type_and_more.py b/b2b/migrations/0009_contractpage_membership_type_and_more.py new file mode 100644 index 0000000000..a0ea3a074b --- /dev/null +++ b/b2b/migrations/0009_contractpage_membership_type_and_more.py @@ -0,0 +1,96 @@ +# Generated by Django 4.2.25 on 2025-10-08 19:32 + +import django.db.models.deletion +from django.conf import settings +from django.db import migrations, models + + +class Migration(migrations.Migration): + dependencies = [ + migrations.swappable_dependency(settings.AUTH_USER_MODEL), + ("b2b", "0008_increase_length_on_org_key_field"), + ] + + operations = [ + migrations.AddField( + model_name="contractpage", + name="membership_type", + field=models.CharField( + choices=[ + ("sso", "SSO"), + ("non-sso", "Non-SSO"), + ("managed", "Managed"), + ("code", "Enrollment Code"), + ("auto", "Auto Enrollment"), + ], + default="managed", + help_text="The method to use to manage membership in the contract.", + max_length=255, + ), + ), + migrations.AlterField( + model_name="contractpage", + name="integration_type", + field=models.CharField( + choices=[ + ("sso", "SSO"), + ("non-sso", "Non-SSO"), + ("managed", "Managed"), + ("code", "Enrollment Code"), + ("auto", "Auto Enrollment"), + ], + help_text="The type of integration for this contract.", + max_length=255, + ), + ), + migrations.AlterField( + model_name="contractpage", + name="organization", + field=models.ForeignKey( + help_text="The organization that owns this contract.", + on_delete=django.db.models.deletion.PROTECT, + related_name="contracts", + to="b2b.organizationpage", + ), + ), + migrations.CreateModel( + name="UserOrganization", + fields=[ + ( + "id", + models.BigAutoField( + auto_created=True, + primary_key=True, + serialize=False, + verbose_name="ID", + ), + ), + ( + "keep_until_seen", + models.BooleanField( + default=False, + help_text="If True, the user will be kept in the organization until the organization is seen in their SSO data.", + ), + ), + ( + "organization", + models.ForeignKey( + on_delete=django.db.models.deletion.CASCADE, + related_name="organization_users", + to="b2b.organizationpage", + ), + ), + ( + "user", + models.ForeignKey( + on_delete=django.db.models.deletion.CASCADE, + related_name="user_organizations", + to=settings.AUTH_USER_MODEL, + ), + ), + ], + options={ + "unique_together": {("user", "organization")}, + }, + ), + ] diff --git a/b2b/models.py b/b2b/models.py index 01449ef64b..e1846f7206 100644 --- a/b2b/models.py +++ b/b2b/models.py @@ -1,5 +1,7 @@ """Models for B2B data.""" +import logging + from django.conf import settings from django.contrib.auth import get_user_model from django.contrib.contenttypes.models import ContentType @@ -12,10 +14,18 @@ from wagtail.fields import RichTextField from wagtail.models import Page -from b2b.constants import CONTRACT_INTEGRATION_CHOICES, ORG_INDEX_SLUG +from b2b.constants import ( + CONTRACT_MEMBERSHIP_AUTOS, + CONTRACT_MEMBERSHIP_CHOICES, + CONTRACT_MEMBERSHIP_MANAGED, + CONTRACT_MEMBERSHIP_TYPE_CHOICES, + ORG_INDEX_SLUG, +) from b2b.exceptions import TargetCourseRunExistsError from b2b.tasks import queue_enrollment_code_check +log = logging.getLogger(__name__) + class OrganizationObjectIndexPage(Page): """The index page for organizations - provides the root level folder.""" @@ -109,6 +119,53 @@ def get_learners(self): .distinct() ) + def attach_user(self, user): + """ + Attach the given user to the org in Keycloak. + + Args: + - user (User): the user to add to the org + Returns: + - bool: success flag + """ + + from b2b.api import add_user_org_membership + + return add_user_org_membership(self, user) + + def add_user_contracts(self, user): + """ + Add contracts that the user should get automatically to the user. + + Args: + - user (User): the user to add contracts to + Returns: + - int: number of contracts added + """ + + contracts_qs = self.contracts.filter( + integration_type__in=CONTRACT_MEMBERSHIP_AUTOS, active=True + ) + + for contract in contracts_qs.all(): + user.b2b_contracts.add(contract) + + return contracts_qs.count() + + def remove_user_contracts(self, user): + """ + Remove managed contracts from the given user. + + Args: + - user (User): the user to remove contracts from + Returns: + - int: number of contracts removed + """ + + return user.b2b_contracts.through.objects.filter( + contractpage_id__in=self.contracts.all().values_list("id", flat=True) + ).delete() + def __str__(self): """Return a reasonable representation of the org as a string.""" @@ -137,14 +194,22 @@ class ContractPage(Page): ) integration_type = models.CharField( max_length=255, - choices=CONTRACT_INTEGRATION_CHOICES, + choices=CONTRACT_MEMBERSHIP_CHOICES, help_text="The type of integration for this contract.", ) + # This doesn't have a choices setting because you can't re-use a constant. + # + membership_type = models.CharField( + max_length=255, + choices=CONTRACT_MEMBERSHIP_TYPE_CHOICES, + help_text="The method to use to manage membership in the contract.", + default=CONTRACT_MEMBERSHIP_MANAGED, + ) organization = models.ForeignKey( OrganizationPage, on_delete=models.PROTECT, related_name="contracts", - help_text="The organization this contract is with.", + help_text="The organization that owns this contract.", ) contract_start = models.DateField( blank=True, @@ -191,6 +256,7 @@ class ContractPage(Page): MultiFieldPanel( [ FieldPanel("integration_type"), + FieldPanel("membership_type"), FieldPanel("max_learners"), FieldPanel("enrollment_fixed_price"), ], @@ -231,7 +297,7 @@ def save(self, clean=True, user=None, log_action=False, **kwargs): # noqa: FBT0 self.title = str(self.name) - self.slug = slugify(f"contract-{self.organization.id}-{self.id}") + self.slug = slugify(f"contract-{self.organization.id}-{self.title}") Page.save(self, clean=clean, user=user, log_action=log_action, **kwargs) queue_enrollment_code_check.delay(self.id) @@ -353,3 +419,30 @@ class DiscountContractAttachmentRedemption(TimestampedModel): on_delete=models.DO_NOTHING, help_text="The contract that the user was attached to.", ) + + +class UserOrganization(models.Model): + """The user's organizations memberships.""" + + user = models.ForeignKey( + settings.AUTH_USER_MODEL, + on_delete=models.CASCADE, + related_name="user_organizations", + ) + organization = models.ForeignKey( + "b2b.OrganizationPage", + on_delete=models.CASCADE, + related_name="organization_users", + ) + keep_until_seen = models.BooleanField( + default=False, + help_text="If True, the user will be kept in the organization until the organization is seen in their SSO data.", + ) + + class Meta: + unique_together = ("user", "organization") + + def __str__(self): + """Return a reasonable representation of the object as a string.""" + + return f"UserOrganization: {self.user} in {self.organization}" diff --git a/b2b/models_test.py b/b2b/models_test.py index 8501c10734..265962ba03 100644 --- a/b2b/models_test.py +++ b/b2b/models_test.py @@ -4,7 +4,10 @@ import pytest from b2b.factories import ContractPageFactory -from courses.factories import CourseRunFactory, ProgramFactory +from courses.factories import ( + CourseRunFactory, + ProgramFactory, +) pytestmark = [pytest.mark.django_db] FAKE = faker.Faker() diff --git a/b2b/serializers/v0/__init__.py b/b2b/serializers/v0/__init__.py index e1b68c11d7..232f788533 100644 --- a/b2b/serializers/v0/__init__.py +++ b/b2b/serializers/v0/__init__.py @@ -1,8 +1,10 @@ """Serializers for the B2B API (v0).""" +from drf_spectacular.utils import extend_schema_field from rest_framework import serializers -from b2b.models import ContractPage, OrganizationPage +from b2b.models import ContractPage, OrganizationPage, UserOrganization +from cms.api import get_wagtail_img_src from main.constants import USER_MSG_TYPE_B2B_CHOICES @@ -11,6 +13,8 @@ class ContractPageSerializer(serializers.ModelSerializer): Serializer for the ContractPage model. """ + membership_type = serializers.CharField() + class Meta: model = ContractPage fields = [ @@ -18,6 +22,7 @@ class Meta: "name", "description", "integration_type", + "membership_type", "organization", "contract_start", "contract_end", @@ -30,6 +35,7 @@ class Meta: "name", "description", "integration_type", + "membership_type", "organization", "contract_start", "contract_end", @@ -109,3 +115,66 @@ class CreateB2BEnrollmentSerializer(serializers.Serializer): max_digits=None, decimal_places=2, read_only=True, required=False ) checkout_result = GenerateCheckoutPayloadSerializer(required=False) + + +class UserOrganizationSerializer(serializers.ModelSerializer): + """ + Serializer for user organization data. + + Return the user's organizations in a manner that makes them look like + OrganizationPage objects. (Previously, the user organizations were a queryset + of OrganizationPages that related to the user, but now we have a through + table.) + """ + + contracts = serializers.SerializerMethodField() + id = serializers.IntegerField(source="organization.id") + name = serializers.CharField(source="organization.name") + description = serializers.CharField(source="organization.description") + logo = serializers.SerializerMethodField() + slug = serializers.CharField(source="organization.slug") + + @extend_schema_field(ContractPageSerializer(many=True)) + def get_contracts(self, instance): + """Get the contracts for the organization for the user""" + contracts = ( + self.context["user"] + .b2b_contracts.filter( + organization=instance.organization, + ) + .all() + ) + return ContractPageSerializer(contracts, many=True).data + + @extend_schema_field(str) + def get_logo(self, instance): + """Get logo""" + + if hasattr(instance.organization, "logo"): + try: + return get_wagtail_img_src(instance.organization.logo) + except AttributeError: + pass + + return None + + class Meta: + """Meta opts for the serializer.""" + + model = UserOrganization + fields = [ + "id", + "name", + "description", + "logo", + "slug", + "contracts", + ] + read_only_fields = [ + "id", + "name", + "description", + "logo", + "slug", + "contracts", + ] diff --git a/b2b/views/v0/__init__.py b/b2b/views/v0/__init__.py index 62bd0d1b36..7289660b4b 100644 --- a/b2b/views/v0/__init__.py +++ b/b2b/views/v0/__init__.py @@ -11,7 +11,7 @@ from rest_framework.views import APIView from rest_framework_api_key.permissions import HasAPIKey -from b2b.api import create_b2b_enrollment +from b2b.api import create_b2b_enrollment, process_add_org_membership from b2b.models import ( ContractPage, DiscountContractAttachmentRedemption, @@ -141,6 +141,9 @@ def post(self, request, enrollment_code: str, format=None): # noqa: A002, ARG00 if contract.is_full(): continue + process_add_org_membership( + request.user, contract.organization, keep_until_seen=True + ) request.user.b2b_contracts.add(contract) DiscountContractAttachmentRedemption.objects.create( user=request.user, discount=code, contract=contract diff --git a/b2b/views/v0/views_test.py b/b2b/views/v0/views_test.py index 3d11079408..5c15d49cf0 100644 --- a/b2b/views/v0/views_test.py +++ b/b2b/views/v0/views_test.py @@ -11,7 +11,7 @@ from rest_framework.test import APIClient from b2b.api import create_contract_run, ensure_enrollment_codes_exist -from b2b.constants import CONTRACT_INTEGRATION_NONSSO, CONTRACT_INTEGRATION_SSO +from b2b.constants import CONTRACT_MEMBERSHIP_NONSSO, CONTRACT_MEMBERSHIP_SSO from b2b.factories import ContractPageFactory from b2b.models import DiscountContractAttachmentRedemption from courses.factories import CourseRunFactory @@ -39,11 +39,16 @@ def test_b2b_contract_attachment_bad_code(user): assert user.b2b_contracts.count() == 0 -def test_b2b_contract_attachment(user): +def test_b2b_contract_attachment(mocker, user): """Ensure a supplied code results in attachment for the user.""" + mocked_attach_user = mocker.patch( + "b2b.models.OrganizationPage.attach_user", return_value=True + ) + contract = ContractPageFactory.create( - integration_type=CONTRACT_INTEGRATION_NONSSO, + membership_type=CONTRACT_MEMBERSHIP_NONSSO, + integration_type=CONTRACT_MEMBERSHIP_NONSSO, max_learners=10, ) @@ -62,8 +67,10 @@ def test_b2b_contract_attachment(user): resp = client.post(url) assert resp.status_code == 200 + mocked_attach_user.assert_called() user.refresh_from_db() + assert user.b2b_organizations.filter(pk=contract.organization.id).exists() assert user.b2b_contracts.filter(pk=contract.id).exists() assert DiscountContractAttachmentRedemption.objects.filter( @@ -82,7 +89,8 @@ def test_b2b_contract_attachment_invalid_code_dates(user, bad_start_or_end): """Test that the attachment fails properly if the code has invalid dates.""" contract = ContractPageFactory.create( - integration_type=CONTRACT_INTEGRATION_NONSSO, + membership_type=CONTRACT_MEMBERSHIP_NONSSO, + integration_type=CONTRACT_MEMBERSHIP_NONSSO, max_learners=1, ) @@ -135,7 +143,8 @@ def test_b2b_contract_attachment_invalid_contract_dates(user, bad_start_or_end): """Test that the attachment fails properly if the contract has invalid dates.""" contract = ContractPageFactory.create( - integration_type=CONTRACT_INTEGRATION_NONSSO, + membership_type=CONTRACT_MEMBERSHIP_NONSSO, + integration_type=CONTRACT_MEMBERSHIP_NONSSO, max_learners=1, ) @@ -177,11 +186,16 @@ def test_b2b_contract_attachment_invalid_contract_dates(user, bad_start_or_end): ).exists() -def test_b2b_contract_attachment_full_contract(): +def test_b2b_contract_attachment_full_contract(mocker): """Test that the attachment fails properly if the contract is full.""" + mocked_attach_user = mocker.patch( + "b2b.models.OrganizationPage.attach_user", return_value=True + ) + contract = ContractPageFactory.create( - integration_type=CONTRACT_INTEGRATION_NONSSO, + membership_type=CONTRACT_MEMBERSHIP_NONSSO, + integration_type=CONTRACT_MEMBERSHIP_NONSSO, max_learners=1, ) @@ -201,6 +215,7 @@ def test_b2b_contract_attachment_full_contract(): resp = client.post(url) assert resp.status_code == 200 + mocked_attach_user.assert_called() user.refresh_from_db() assert user.b2b_contracts.filter(pk=contract.id).exists() @@ -209,14 +224,18 @@ def test_b2b_contract_attachment_full_contract(): client = APIClient() client.force_login(user) + mocked_attach_user.reset_mock() + url = reverse( "b2b:attach-user", kwargs={"enrollment_code": contract_code.discount_code} ) resp = client.post(url) assert resp.status_code == 200 + mocked_attach_user.assert_not_called() user.refresh_from_db() + assert not user.b2b_organizations.filter(pk=contract.organization.id).exists() assert not user.b2b_contracts.filter(pk=contract.id).exists() assert not DiscountContractAttachmentRedemption.objects.filter( contract=contract, user=user, discount=contract_code @@ -234,7 +253,8 @@ def test_b2b_enroll(mocker, settings, user_has_edx_user, has_price): settings.OPENEDX_SERVICE_WORKER_API_TOKEN = "a token" # noqa: S105 contract = ContractPageFactory.create( - integration_type=CONTRACT_INTEGRATION_SSO, + membership_type=CONTRACT_MEMBERSHIP_SSO, + integration_type=CONTRACT_MEMBERSHIP_SSO, enrollment_fixed_price=100 if has_price else 0, ) source_courserun = CourseRunFactory.create(is_source_run=True) diff --git a/courses/api.py b/courses/api.py index 56d7f664ce..b6c87780a9 100644 --- a/courses/api.py +++ b/courses/api.py @@ -28,6 +28,7 @@ from requests.exceptions import HTTPError from rest_framework.status import HTTP_404_NOT_FOUND +from b2b.api import process_add_org_membership from cms.api import create_default_courseware_page from courses import mail_api from courses.constants import ( @@ -214,8 +215,11 @@ def _enroll_learner_into_associated_programs(): _enroll_learner_into_associated_programs() # If the run is associated with a B2B contract, add the contract - # to the user's contract list + # to the user's contract list and update their org memberships if run.b2b_contract: + process_add_org_membership( + user, run.b2b_contract.organization, keep_until_seen=True + ) user.b2b_contracts.add(run.b2b_contract) user.save() diff --git a/courses/api_test.py b/courses/api_test.py index 7f4b3d9b1c..c4e5630edc 100644 --- a/courses/api_test.py +++ b/courses/api_test.py @@ -17,7 +17,7 @@ from reversion.models import Version from b2b.api import create_b2b_enrollment -from b2b.constants import CONTRACT_INTEGRATION_NONSSO +from b2b.constants import CONTRACT_MEMBERSHIP_NONSSO from b2b.factories import ( ContractPageFactory, OrganizationIndexPageFactory, @@ -1807,7 +1807,8 @@ def test_b2b_re_enrollment_after_multiple_unenrollments(mocker, user): contract = ContractPageFactory.create( organization=org, enrollment_fixed_price=Decimal("0.00"), - integration_type=CONTRACT_INTEGRATION_NONSSO, + membership_type=CONTRACT_MEMBERSHIP_NONSSO, + integration_type=CONTRACT_MEMBERSHIP_NONSSO, ) course_run = CourseRunFactory.create(b2b_contract=contract) with reversion.create_revision(): diff --git a/courses/views/v2/__init__.py b/courses/views/v2/__init__.py index d5d5ce90bd..4d5ce1c035 100644 --- a/courses/views/v2/__init__.py +++ b/courses/views/v2/__init__.py @@ -114,7 +114,7 @@ class ProgramViewSet(viewsets.ReadOnlyModelViewSet): def get_queryset(self): return ( Program.objects.filter() - .select_related("page") + .select_related("page", "page__feature_image") .prefetch_related( Prefetch("departments", queryset=Department.objects.only("id", "name")), Prefetch( diff --git a/courses/views/v2/views_test.py b/courses/views/v2/views_test.py index 62f922492b..8debd61c31 100644 --- a/courses/views/v2/views_test.py +++ b/courses/views/v2/views_test.py @@ -310,7 +310,10 @@ def test_filter_with_org_id_returns_contracted_course( org = OrganizationPageFactory(name="Test Org") contract = ContractPageFactory(organization=org, active=True) user = UserFactory() + user.b2b_organizations.add(org) user.b2b_contracts.add(contract) + user.refresh_from_db() + (course, _) = contract_ready_course create_contract_run(contract, course) @@ -498,7 +501,10 @@ def test_next_run_id_with_org_filter( # noqa: PLR0915 contract = ContractPageFactory.create(organization=orgs[0]) second_contract = ContractPageFactory.create(organization=orgs[1]) test_user = UserFactory() + test_user.b2b_organizations.add(contract.organization) test_user.b2b_contracts.add(contract) + test_user.save() + test_user.refresh_from_db() auth_api_client = APIClient() auth_api_client.force_authenticate(user=test_user) @@ -637,6 +643,7 @@ def test_program_filter_for_b2b_org(user, mock_course_run_clone): contract.add_program_courses(b2b_program) contract.save() + user.b2b_organizations.add(org) user.b2b_contracts.add(contract) user.save() diff --git a/ecommerce/discounts_test.py b/ecommerce/discounts_test.py index dea01d1d20..04634fde62 100644 --- a/ecommerce/discounts_test.py +++ b/ecommerce/discounts_test.py @@ -3,6 +3,7 @@ import pytest +from ecommerce.constants import ALL_DISCOUNT_TYPES from ecommerce.discounts import ( DiscountType, DollarsOffDiscount, @@ -85,14 +86,17 @@ def test_discount_factory_adjustment(discounts, products): ) -def test_discounted_price(products): +@pytest.mark.parametrize("discount_type", ALL_DISCOUNT_TYPES) +def test_discounted_price(products, discount_type): """ Tests the get_discounted_price call with some products to make sure the discount is applied successfully. """ product = products[random.randrange(0, len(products), 1)] # noqa: S311 - applied_discounts = [UnlimitedUseDiscountFactory.create()] + applied_discounts = [ + UnlimitedUseDiscountFactory.create(discount_type=discount_type) + ] manually_discounted_prices = DiscountType.for_discount( applied_discounts[0] diff --git a/openapi/specs/v0.yaml b/openapi/specs/v0.yaml index 79bba4c18c..b9ffe8c8b4 100644 --- a/openapi/specs/v0.yaml +++ b/openapi/specs/v0.yaml @@ -1519,10 +1519,15 @@ components: * `sso` - SSO * `non-sso` - Non-SSO + * `managed` - Managed + * `code` - Enrollment Code + * `auto` - Auto Enrollment + membership_type: + type: string organization: type: integer readOnly: true - description: The organization this contract is with. + description: The organization that owns this contract. contract_start: type: string format: date @@ -1551,6 +1556,7 @@ components: - description - id - integration_type + - membership_type - name - organization - slug @@ -2538,13 +2544,22 @@ components: enum: - sso - non-sso + - managed + - code + - auto type: string description: |- * `sso` - SSO * `non-sso` - Non-SSO + * `managed` - Managed + * `code` - Enrollment Code + * `auto` - Auto Enrollment x-enum-descriptions: - SSO - Non-SSO + - Managed + - Enrollment Code + - Auto Enrollment LearnerProgramRecordShare: type: object properties: @@ -3583,7 +3598,7 @@ components: b2b_organizations: type: array items: - $ref: '#/components/schemas/UserOrganization' + $ref: '#/components/schemas/OrganizationPage' readOnly: true required: - b2b_organizations @@ -3597,48 +3612,6 @@ components: - is_superuser - legal_address - updated_on - UserOrganization: - type: object - description: |- - Serializer for user organization data. - - Slightly different from the OrganizationPageSerializer; we only need - the user's orgs and contracts. - properties: - id: - type: integer - readOnly: true - name: - type: string - readOnly: true - description: The name of the organization - description: - type: string - readOnly: true - description: Any useful extra information about the organization - logo: - type: string - format: uri - readOnly: true - description: The organization's logo. Will be displayed in the app in various - places. - slug: - type: string - readOnly: true - description: The name of the page as it will appear in URLs e.g http://domain.com/blog/[my-slug]/ - pattern: ^[-\w]+$ - contracts: - type: array - items: - $ref: '#/components/schemas/ContractPage' - readOnly: true - required: - - contracts - - description - - id - - logo - - name - - slug UserProfile: type: object description: Serializer for profile diff --git a/openapi/specs/v1.yaml b/openapi/specs/v1.yaml index 74da736ec6..e20a068305 100644 --- a/openapi/specs/v1.yaml +++ b/openapi/specs/v1.yaml @@ -1519,10 +1519,15 @@ components: * `sso` - SSO * `non-sso` - Non-SSO + * `managed` - Managed + * `code` - Enrollment Code + * `auto` - Auto Enrollment + membership_type: + type: string organization: type: integer readOnly: true - description: The organization this contract is with. + description: The organization that owns this contract. contract_start: type: string format: date @@ -1551,6 +1556,7 @@ components: - description - id - integration_type + - membership_type - name - organization - slug @@ -2538,13 +2544,22 @@ components: enum: - sso - non-sso + - managed + - code + - auto type: string description: |- * `sso` - SSO * `non-sso` - Non-SSO + * `managed` - Managed + * `code` - Enrollment Code + * `auto` - Auto Enrollment x-enum-descriptions: - SSO - Non-SSO + - Managed + - Enrollment Code + - Auto Enrollment LearnerProgramRecordShare: type: object properties: @@ -3583,7 +3598,7 @@ components: b2b_organizations: type: array items: - $ref: '#/components/schemas/UserOrganization' + $ref: '#/components/schemas/OrganizationPage' readOnly: true required: - b2b_organizations @@ -3597,48 +3612,6 @@ components: - is_superuser - legal_address - updated_on - UserOrganization: - type: object - description: |- - Serializer for user organization data. - - Slightly different from the OrganizationPageSerializer; we only need - the user's orgs and contracts. - properties: - id: - type: integer - readOnly: true - name: - type: string - readOnly: true - description: The name of the organization - description: - type: string - readOnly: true - description: Any useful extra information about the organization - logo: - type: string - format: uri - readOnly: true - description: The organization's logo. Will be displayed in the app in various - places. - slug: - type: string - readOnly: true - description: The name of the page as it will appear in URLs e.g http://domain.com/blog/[my-slug]/ - pattern: ^[-\w]+$ - contracts: - type: array - items: - $ref: '#/components/schemas/ContractPage' - readOnly: true - required: - - contracts - - description - - id - - logo - - name - - slug UserProfile: type: object description: Serializer for profile diff --git a/openapi/specs/v2.yaml b/openapi/specs/v2.yaml index 181b69c01d..eca38f4cf1 100644 --- a/openapi/specs/v2.yaml +++ b/openapi/specs/v2.yaml @@ -1519,10 +1519,15 @@ components: * `sso` - SSO * `non-sso` - Non-SSO + * `managed` - Managed + * `code` - Enrollment Code + * `auto` - Auto Enrollment + membership_type: + type: string organization: type: integer readOnly: true - description: The organization this contract is with. + description: The organization that owns this contract. contract_start: type: string format: date @@ -1551,6 +1556,7 @@ components: - description - id - integration_type + - membership_type - name - organization - slug @@ -2538,13 +2544,22 @@ components: enum: - sso - non-sso + - managed + - code + - auto type: string description: |- * `sso` - SSO * `non-sso` - Non-SSO + * `managed` - Managed + * `code` - Enrollment Code + * `auto` - Auto Enrollment x-enum-descriptions: - SSO - Non-SSO + - Managed + - Enrollment Code + - Auto Enrollment LearnerProgramRecordShare: type: object properties: @@ -3583,7 +3598,7 @@ components: b2b_organizations: type: array items: - $ref: '#/components/schemas/UserOrganization' + $ref: '#/components/schemas/OrganizationPage' readOnly: true required: - b2b_organizations @@ -3597,48 +3612,6 @@ components: - is_superuser - legal_address - updated_on - UserOrganization: - type: object - description: |- - Serializer for user organization data. - - Slightly different from the OrganizationPageSerializer; we only need - the user's orgs and contracts. - properties: - id: - type: integer - readOnly: true - name: - type: string - readOnly: true - description: The name of the organization - description: - type: string - readOnly: true - description: Any useful extra information about the organization - logo: - type: string - format: uri - readOnly: true - description: The organization's logo. Will be displayed in the app in various - places. - slug: - type: string - readOnly: true - description: The name of the page as it will appear in URLs e.g http://domain.com/blog/[my-slug]/ - pattern: ^[-\w]+$ - contracts: - type: array - items: - $ref: '#/components/schemas/ContractPage' - readOnly: true - required: - - contracts - - description - - id - - logo - - name - - slug UserProfile: type: object description: Serializer for profile diff --git a/pytest.ini b/pytest.ini index 599f8a5198..97cbb9b657 100644 --- a/pytest.ini +++ b/pytest.ini @@ -15,6 +15,11 @@ env = KEYCLOAK_BASE_URL=http://keycloak/ KEYCLOAK_REALM_NAME=ol-test KEYCLOAK_CLIENT_ID=mitxonline + KEYCLOAK_DISCOVERY_URL=https://keycloak/ + KEYCLOAK_ADMIN_CLIENT_ID=apisix + KEYCLOAK_ADMIN_CLIENT_SECRET=fake_admin_secret + KEYCLOAK_ADMIN_CLIENT_SCOPES="openid profile email" + KEYCLOAK_ADMIN_CLIENT_NO_VERIFY_SSL=True LOGOUT_REDIRECT_URL=https://openedx.odl.local/logout MAILGUN_KEY=invalid-key MAILGUN_SENDER_DOMAIN=localhost diff --git a/users/admin.py b/users/admin.py index 352f0d0a38..c4841ef2a5 100644 --- a/users/admin.py +++ b/users/admin.py @@ -11,6 +11,7 @@ from hijack.contrib.admin import HijackUserAdminMixin from mitol.common.admin import TimestampedModelAdmin +from b2b.models import UserOrganization from openedx.models import OpenEdxUser from users.models import BlockList, LegalAddress, User, UserProfile @@ -109,6 +110,13 @@ class UserContractPageInline(admin.TabularInline): extra = 0 +class UserOrganizationInline(admin.TabularInline): + """Inline to allow modifying the contracts associated with a user""" + + model = UserOrganization + extra = 0 + + @admin.register(User) class UserAdmin( DjangoObjectActions, ContribUserAdmin, HijackUserAdminMixin, TimestampedModelAdmin @@ -165,6 +173,7 @@ class UserAdmin( OpenEdxUserInline, UserLegalAddressInline, UserProfileInline, + UserOrganizationInline, UserContractPageInline, ] diff --git a/users/migrations/0038_user_b2b_organizations.py b/users/migrations/0038_user_b2b_organizations.py new file mode 100644 index 0000000000..a500f45cd7 --- /dev/null +++ b/users/migrations/0038_user_b2b_organizations.py @@ -0,0 +1,23 @@ +# Generated by Django 4.2.25 on 2025-10-08 19:34 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + dependencies = [ + ("b2b", "0009_contractpage_membership_type_and_more"), + ("users", "0037_add_global_id_default"), + ] + + operations = [ + migrations.AddField( + model_name="user", + name="b2b_organizations", + field=models.ManyToManyField( + help_text="The organizations the user is associated with.", + related_name="+", + through="b2b.UserOrganization", + to="b2b.organizationpage", + ), + ), + ] diff --git a/users/models.py b/users/models.py index 8cb8a712d6..9bcaf96943 100644 --- a/users/models.py +++ b/users/models.py @@ -1,5 +1,6 @@ """User models""" +import logging import uuid from datetime import timedelta from functools import cached_property @@ -18,11 +19,12 @@ from mitol.common.utils import now_in_utc from mitol.common.utils.collections import chunks -from b2b.models import OrganizationPage from cms.constants import CMS_EDITORS_GROUP_NAME from openedx.constants import OPENEDX_REPAIR_GRACE_PERIOD_MINS, OPENEDX_USERNAME_MAX_LEN from openedx.models import OpenEdxUser +log = logging.getLogger(__name__) + MALE = "m" FEMALE = "f" OTHER = "o" @@ -306,6 +308,12 @@ class User( related_name="users", help_text="The contracts the user is associated with.", ) + b2b_organizations = models.ManyToManyField( + "b2b.OrganizationPage", + through="b2b.UserOrganization", + related_name="+", + help_text="The organizations the user is associated with.", + ) objects = UserManager() faulty_openedx_users = FaultyOpenEdxUserManager() @@ -362,20 +370,13 @@ def is_editor(self) -> bool: or self.groups.filter(name=CMS_EDITORS_GROUP_NAME).exists() ) - @cached_property - def b2b_organizations(self): - """Return the organizations the user is associated with.""" - return OrganizationPage.objects.filter( - pk__in=self.b2b_contracts.values_list("organization", flat=True).distinct() - ).all() - @cached_property def b2b_organization_sso_ids(self): - """Similar to b2b_organizations, but returns just the UUIDs.""" + """Just the UUIDs for the organizations the user is in.""" return list( - self.b2b_organizations.filter( - sso_organization_id__isnull=False - ).values_list("sso_organization_id", flat=True) + self.organizations.filter(sso_organization_id__isnull=False).values_list( + "sso_organization_id", flat=True + ) ) @property diff --git a/users/serializers.py b/users/serializers.py index ace4f925b0..2bae801841 100644 --- a/users/serializers.py +++ b/users/serializers.py @@ -11,7 +11,7 @@ from rest_framework import serializers from social_django.models import UserSocialAuth -from b2b.serializers.v0 import ContractPageSerializer, OrganizationPageSerializer +from b2b.serializers.v0 import OrganizationPageSerializer from hubspot_sync.task_helpers import sync_hubspot_user # from ecommerce.api import fetch_and_serialize_unused_coupons # noqa: ERA001 @@ -22,7 +22,12 @@ from openedx.exceptions import EdxApiRegistrationValidationException from openedx.models import OpenEdxUser from openedx.tasks import change_edx_user_email_async -from users.models import ChangeEmailRequest, LegalAddress, User, UserProfile +from users.models import ( + ChangeEmailRequest, + LegalAddress, + User, + UserProfile, +) log = logging.getLogger() @@ -202,33 +207,6 @@ class Meta: ) -class UserOrganizationSerializer(OrganizationPageSerializer): - """ - Serializer for user organization data. - - Slightly different from the OrganizationPageSerializer; we only need - the user's orgs and contracts. - """ - - contracts = serializers.SerializerMethodField() - - @extend_schema_field(ContractPageSerializer(many=True)) - def get_contracts(self, instance): - """Get the contracts for the organization for the user""" - contracts = ( - self.context["user"] - .b2b_contracts.filter( - organization=instance, - ) - .all() - ) - return ContractPageSerializer(contracts, many=True).data - - class Meta(OrganizationPageSerializer.Meta): - fields = (*OrganizationPageSerializer.Meta.fields, "contracts") - read_only_fields = (*OrganizationPageSerializer.Meta.fields, "contracts") - - class UserSerializer(serializers.ModelSerializer): """Serializer for users""" @@ -268,15 +246,15 @@ def validate_username(self, value): def get_grants(self, instance): return instance.get_all_permissions() - @extend_schema_field(UserOrganizationSerializer(many=True)) + @extend_schema_field(OrganizationPageSerializer(many=True)) def get_b2b_organizations(self, instance): """Get the organizations for the user""" if instance.is_anonymous: return [] - organizations = instance.b2b_organizations - return UserOrganizationSerializer( - organizations, many=True, context={"user": instance} + return OrganizationPageSerializer( + instance.b2b_organizations, + many=True, ).data def validate(self, data): diff --git a/users/views_test.py b/users/views_test.py index 6565eab85d..1d88bf15cc 100644 --- a/users/views_test.py +++ b/users/views_test.py @@ -58,6 +58,7 @@ def test_get_user_by_me(mocker, client, user, is_anonymous, has_orgs): if has_orgs: contract = ContractPageFactory.create() + user.b2b_organizations.add(contract.organization) user.b2b_contracts.add(contract) user.save() b2b_orgs = [ @@ -72,6 +73,7 @@ def test_get_user_by_me(mocker, client, user, is_anonymous, has_orgs): "id": contract.id, "name": contract.name, "description": contract.description, + "membership_type": contract.membership_type, "integration_type": contract.integration_type, "contract_start": None, "contract_end": None, From 1dbaf1d92219132e7e5d7c82b4899fabfbb80921 Mon Sep 17 00:00:00 2001 From: Doof Date: Thu, 16 Oct 2025 20:01:38 +0000 Subject: [PATCH 4/4] Release 0.131.6 --- RELEASE.rst | 7 +++++++ main/settings.py | 2 +- 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/RELEASE.rst b/RELEASE.rst index c0c7e56005..4f7df15d3c 100644 --- a/RELEASE.rst +++ b/RELEASE.rst @@ -1,6 +1,13 @@ Release Notes ============= +Version 0.131.6 +--------------- + +- B2B: Update Keycloak org membership when a user redeems an enrollment code (#2989) +- Change v2 req_tree to not return root node, improve OpenAPI spec (#3010) +- Add ability to manage the link between programs and contracts (#3006) + Version 0.131.4 (Released October 16, 2025) --------------- diff --git a/main/settings.py b/main/settings.py index 56f962d386..68c0527f85 100644 --- a/main/settings.py +++ b/main/settings.py @@ -36,7 +36,7 @@ from main.sentry import init_sentry from openapi.settings_spectacular import open_spectacular_settings -VERSION = "0.131.4" +VERSION = "0.131.6" log = logging.getLogger()