diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml new file mode 100644 index 0000000000..1cff6e92e9 --- /dev/null +++ b/.github/workflows/lint.yml @@ -0,0 +1,26 @@ +name: Lint +on: [push, pull_request] + +jobs: + lint: + name: Lint + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v2 + + - name: Set up Python + uses: actions/setup-python@v2 + with: + python-version: 3.8 + + - name: Install Python Packages + run: | + pip install --upgrade pip + pip install tox + + - name: Run isort + run: tox -e isort + + - name: Run Black + run: tox -e black diff --git a/api/admin/admin_authentication_provider.py b/api/admin/admin_authentication_provider.py index 57f79e5824..1ac953f377 100644 --- a/api/admin/admin_authentication_provider.py +++ b/api/admin/admin_authentication_provider.py @@ -1,4 +1,3 @@ - class AdminAuthenticationProvider(object): def __init__(self, integration): self.integration = integration diff --git a/api/admin/announcement_list_validator.py b/api/admin/announcement_list_validator.py index 1752086258..006eaba368 100644 --- a/api/admin/announcement_list_validator.py +++ b/api/admin/announcement_list_validator.py @@ -6,17 +6,21 @@ from flask_babel import lazy_gettext as _ from api.admin.validator import Validator - -from core.util.problem_detail import ProblemDetail from core.problem_details import * +from core.util.problem_detail import ProblemDetail class AnnouncementListValidator(Validator): - DATE_FORMAT = '%Y-%m-%d' + DATE_FORMAT = "%Y-%m-%d" - def __init__(self, maximum_announcements=3, minimum_announcement_length=15, - maximum_announcement_length=350, default_duration_days=60): + def __init__( + self, + maximum_announcements=3, + minimum_announcement_length=15, + maximum_announcement_length=350, + default_duration_days=60, + ): super(AnnouncementListValidator, self).__init__() self.maximum_announcements = maximum_announcements self.minimum_announcement_length = minimum_announcement_length @@ -26,8 +30,10 @@ def __init__(self, maximum_announcements=3, minimum_announcement_length=15, def validate_announcements(self, announcements): validated_announcements = [] bad_format = INVALID_INPUT.detailed( - _("Invalid announcement list format: %(announcements)r", - announcements=announcements) + _( + "Invalid announcement list format: %(announcements)r", + announcements=announcements, + ) ) if isinstance(announcements, (bytes, str)): try: @@ -38,8 +44,10 @@ def validate_announcements(self, announcements): return bad_format if len(announcements) > self.maximum_announcements: return INVALID_INPUT.detailed( - _("Too many announcements: maximum is %(maximum)d", - maximum=self.maximum_announcements) + _( + "Too many announcements: maximum is %(maximum)d", + maximum=self.maximum_announcements, + ) ) seen_ids = set() @@ -47,7 +55,7 @@ def validate_announcements(self, announcements): validated = self.validate_announcement(announcement) if isinstance(validated, ProblemDetail): return validated - id = validated['id'] + id = validated["id"] if id in seen_ids: return INVALID_INPUT.detailed(_("Duplicate announcement ID: %s" % id)) seen_ids.add(id) @@ -58,46 +66,47 @@ def validate_announcement(self, announcement): validated = dict() if not isinstance(announcement, dict): return INVALID_INPUT.detailed( - _("Invalid announcement format: %(announcement)r", announcement=announcement) + _( + "Invalid announcement format: %(announcement)r", + announcement=announcement, + ) ) - validated['id'] = announcement.get('id', str(uuid.uuid4())) + validated["id"] = announcement.get("id", str(uuid.uuid4())) - for required_field in ('content',): + for required_field in ("content",): if not required_field in announcement: return INVALID_INPUT.detailed( _("Missing required field: %(field)s", field=required_field) ) # Validate the content of the announcement. - content = announcement['content'] + content = announcement["content"] content = self.validate_length( content, self.minimum_announcement_length, self.maximum_announcement_length ) if isinstance(content, ProblemDetail): return content - validated['content'] = content + validated["content"] = content # Validate the dates associated with the announcement today_local = datetime.date.today() - start = self.validate_date( - 'start', announcement.get('start', today_local) - ) + start = self.validate_date("start", announcement.get("start", today_local)) if isinstance(start, ProblemDetail): return start - validated['start'] = start + validated["start"] = start default_finish = start + datetime.timedelta(days=self.default_duration_days) day_after_start = start + datetime.timedelta(days=1) finish = self.validate_date( - 'finish', - announcement.get('finish', default_finish), + "finish", + announcement.get("finish", default_finish), minimum=day_after_start, ) if isinstance(finish, ProblemDetail): return finish - validated['finish'] = finish + validated["finish"] = finish # That's it! return validated @@ -114,14 +123,22 @@ def validate_length(self, value, minimum, maximum): """ if len(value) < minimum: return INVALID_INPUT.detailed( - _('Value too short (%(length)d versus %(limit)d characters): %(value)s', - length=len(value), limit=minimum, value=value) + _( + "Value too short (%(length)d versus %(limit)d characters): %(value)s", + length=len(value), + limit=minimum, + value=value, + ) ) if len(value) > maximum: return INVALID_INPUT.detailed( - _('Value too long (%(length)d versus %(limit)d characters): %(value)s', - length=len(value), limit=maximum, value=value) + _( + "Value too long (%(length)d versus %(limit)d characters): %(value)s", + length=len(value), + limit=maximum, + value=value, + ) ) return value @@ -145,7 +162,11 @@ def validate_date(cls, field, value, minimum=None): value = value.replace(tzinfo=dateutil.tz.tzlocal()) except ValueError as e: return INVALID_INPUT.detailed( - _("Value for %(field)s is not a date: %(date)s", field=field, date=value) + _( + "Value for %(field)s is not a date: %(date)s", + field=field, + date=value, + ) ) if isinstance(value, datetime.datetime): value = value.date() @@ -153,8 +174,10 @@ def validate_date(cls, field, value, minimum=None): minimum = minimum.date() if minimum and value < minimum: return INVALID_INPUT.detailed( - _("Value for %(field)s must be no earlier than %(minimum)s", - field=field, minimum=minimum.strftime(cls.DATE_FORMAT) + _( + "Value for %(field)s must be no earlier than %(minimum)s", + field=field, + minimum=minimum.strftime(cls.DATE_FORMAT), ) ) return value @@ -162,4 +185,5 @@ def validate_date(cls, field, value, minimum=None): def format_as_string(self, value): """Format the output of validate_announcements for storage in ConfigurationSetting.value""" from ..announcements import Announcements + return json.dumps([x.json_ready for x in Announcements(value).announcements]) diff --git a/api/admin/config.py b/api/admin/config.py index fa4a5360c4..08969fd9f9 100644 --- a/api/admin/config.py +++ b/api/admin/config.py @@ -1,52 +1,54 @@ -from enum import Enum import os +from enum import Enum from urllib.parse import urljoin class OperationalMode(str, Enum): - production = 'production' - development = 'development' + production = "production" + development = "development" class Configuration: - APP_NAME = 'Palace Collection Manager' - PACKAGE_NAME = '@thepalaceproject/circulation-admin' - PACKAGE_VERSION = '0.0.6' + APP_NAME = "Palace Collection Manager" + PACKAGE_NAME = "@thepalaceproject/circulation-admin" + PACKAGE_VERSION = "0.0.6" STATIC_ASSETS = { - 'admin_js': 'circulation-admin.js', - 'admin_css': 'circulation-admin.css', - 'admin_logo': 'PalaceCollectionManagerLogo.svg', + "admin_js": "circulation-admin.js", + "admin_css": "circulation-admin.css", + "admin_logo": "PalaceCollectionManagerLogo.svg", } # For proper operation, `package_url` MUST end with a slash ('/') and # `asset_rel_url` MUST NOT begin with one. PACKAGE_TEMPLATES = { OperationalMode.production: { - 'package_url': 'https://cdn.jsdelivr.net/npm/{name}@{version}/', - 'asset_rel_url': 'dist/{filename}' + "package_url": "https://cdn.jsdelivr.net/npm/{name}@{version}/", + "asset_rel_url": "dist/{filename}", }, OperationalMode.development: { - 'package_url': '/admin/', - 'asset_rel_url': 'static/{filename}', + "package_url": "/admin/", + "asset_rel_url": "static/{filename}", }, } - DEVELOPMENT_MODE_PACKAGE_TEMPLATE = 'node_modules/{name}' - STATIC_ASSETS_REL_PATH = 'dist' + DEVELOPMENT_MODE_PACKAGE_TEMPLATE = "node_modules/{name}" + STATIC_ASSETS_REL_PATH = "dist" ADMIN_DIRECTORY = os.path.abspath(os.path.dirname(__file__)) # Environment variables that contain admin client package information. - ENV_ADMIN_UI_PACKAGE_NAME = 'TPP_CIRCULATION_ADMIN_PACKAGE_NAME' - ENV_ADMIN_UI_PACKAGE_VERSION = 'TPP_CIRCULATION_ADMIN_PACKAGE_VERSION' + ENV_ADMIN_UI_PACKAGE_NAME = "TPP_CIRCULATION_ADMIN_PACKAGE_NAME" + ENV_ADMIN_UI_PACKAGE_VERSION = "TPP_CIRCULATION_ADMIN_PACKAGE_VERSION" @classmethod def operational_mode(cls) -> OperationalMode: - return (OperationalMode.development - if os.path.isdir(cls.package_development_directory()) - else OperationalMode.production) + return ( + OperationalMode.development + if os.path.isdir(cls.package_development_directory()) + else OperationalMode.production + ) @classmethod def _package_name(cls) -> str: @@ -58,7 +60,9 @@ def _package_name(cls) -> str: return os.environ.get(cls.ENV_ADMIN_UI_PACKAGE_NAME) or cls.PACKAGE_NAME @classmethod - def lookup_asset_url(cls, key: str, *, _operational_mode: OperationalMode = None) -> str: + def lookup_asset_url( + cls, key: str, *, _operational_mode: OperationalMode = None + ) -> str: """Get the URL for the asset_type. :param key: The key used to lookup an asset's filename. If the key is @@ -72,8 +76,12 @@ def lookup_asset_url(cls, key: str, *, _operational_mode: OperationalMode = None """ operational_mode = _operational_mode or cls.operational_mode() filename = cls.STATIC_ASSETS.get(key, key) - return urljoin(cls.package_url(_operational_mode=operational_mode), - cls.PACKAGE_TEMPLATES[operational_mode]['asset_rel_url'].format(filename=filename)) + return urljoin( + cls.package_url(_operational_mode=operational_mode), + cls.PACKAGE_TEMPLATES[operational_mode]["asset_rel_url"].format( + filename=filename + ), + ) @classmethod def package_url(cls, *, _operational_mode: OperationalMode = None) -> str: @@ -88,11 +96,13 @@ def package_url(cls, *, _operational_mode: OperationalMode = None) -> str: :rtype: str """ operational_mode = _operational_mode or cls.operational_mode() - version = (os.environ.get(cls.ENV_ADMIN_UI_PACKAGE_VERSION) or cls.PACKAGE_VERSION) - template = cls.PACKAGE_TEMPLATES[operational_mode]['package_url'] + version = ( + os.environ.get(cls.ENV_ADMIN_UI_PACKAGE_VERSION) or cls.PACKAGE_VERSION + ) + template = cls.PACKAGE_TEMPLATES[operational_mode]["package_url"] url = template.format(name=cls._package_name(), version=version) - if not url.endswith('/'): - url += '/' + if not url.endswith("/"): + url += "/" return url @classmethod @@ -105,8 +115,10 @@ def package_development_directory(cls, *, _base_dir: str = None) -> str: :rtype: str """ base_dir = _base_dir or cls.ADMIN_DIRECTORY - return os.path.join(base_dir, - cls.DEVELOPMENT_MODE_PACKAGE_TEMPLATE.format(name=cls._package_name())) + return os.path.join( + base_dir, + cls.DEVELOPMENT_MODE_PACKAGE_TEMPLATE.format(name=cls._package_name()), + ) @classmethod def static_files_directory(cls, *, _base_dir: str = None) -> str: diff --git a/api/admin/controller/__init__.py b/api/admin/controller/__init__.py index 61482e558b..5a864b4d76 100644 --- a/api/admin/controller/__init__.py +++ b/api/admin/controller/__init__.py @@ -5,35 +5,32 @@ import os import sys import urllib.parse -from datetime import ( - date, - datetime, - timedelta, -) +from datetime import date, datetime, timedelta import flask import jwt -from flask import (redirect, Response) +from flask import Response, redirect from flask_babel import lazy_gettext as _ from sqlalchemy.sql import func -from sqlalchemy.sql.expression import (and_, desc, distinct, join, nullslast, select) +from sqlalchemy.sql.expression import and_, desc, distinct, join, nullslast, select from api.admin.config import Configuration as AdminClientConfig from api.admin.exceptions import * -from api.admin.google_oauth_admin_authentication_provider import GoogleOAuthAdminAuthenticationProvider -from api.admin.opds import ( - AdminAnnotator, - AdminFeed, +from api.admin.google_oauth_admin_authentication_provider import ( + GoogleOAuthAdminAuthenticationProvider, +) +from api.admin.opds import AdminAnnotator, AdminFeed +from api.admin.password_admin_authentication_provider import ( + PasswordAdminAuthenticationProvider, ) -from api.admin.password_admin_authentication_provider import PasswordAdminAuthenticationProvider from api.admin.template_styles import * from api.admin.templates import admin as admin_template from api.admin.validator import Validator from api.adobe_vendor_id import AuthdataUtility -from api.authenticator import (CannotCreateLocalPatron, LibraryAuthenticator, PatronData) +from api.authenticator import CannotCreateLocalPatron, LibraryAuthenticator, PatronData from api.axis import Axis360API from api.bibliotheca import BibliothecaAPI -from api.config import (CannotLoadConfiguration, Configuration) +from api.config import CannotLoadConfiguration, Configuration from api.controller import CirculationManagerController from api.enki import EnkiAPI from api.feedbooks import FeedbooksOPDSImporter @@ -41,25 +38,15 @@ from api.lcp.collection import LCPAPI from api.local_analytics_exporter import LocalAnalyticsExporter from api.odilo import OdiloAPI -from api.odl import ( - ODLAPI, - SharedODLAPI, -) +from api.odl import ODLAPI, SharedODLAPI from api.odl2 import ODL2API from api.opds_for_distributors import OPDSForDistributorsAPI from api.overdrive import OverdriveAPI from api.proquest.importer import ProQuestOPDS2Importer -from core.app_server import ( - load_pagination_from_request, -) -from core.classifier import ( - genres -) +from core.app_server import load_pagination_from_request +from core.classifier import genres from core.external_search import ExternalSearchIndex -from core.lane import ( - Lane, - WorkList, -) +from core.lane import Lane, WorkList from core.local_analytics_provider import LocalAnalyticsProvider from core.model import ( Admin, @@ -67,13 +54,10 @@ CirculationEvent, Collection, ConfigurationSetting, - create, CustomList, CustomListEntry, DataSource, ExternalIntegration, - get_one, - get_one_or_create, Hold, Identifier, Library, @@ -82,11 +66,14 @@ Patron, Timestamp, Work, + create, + get_one, + get_one_or_create, ) from core.model.configuration import ExternalIntegrationLink from core.opds import AcquisitionFeed from core.opds2_import import OPDS2Importer -from core.opds_import import (OPDSImporter, OPDSImportMonitor) +from core.opds_import import OPDSImporter, OPDSImportMonitor from core.s3 import S3UploaderConfiguration from core.selftest import HasSelfTests from core.util.datetime_helpers import utc_now @@ -108,6 +95,7 @@ def setup_admin_controllers(manager): manager.admin_sign_in_controller = SignInController(manager) manager.timestamps_controller = TimestampsController(manager) from api.admin.controller.work_editor import WorkController + manager.admin_work_controller = WorkController(manager) manager.admin_feed_controller = FeedController(manager) manager.admin_custom_lists_controller = CustomListsController(manager) @@ -116,50 +104,105 @@ def setup_admin_controllers(manager): manager.admin_settings_controller = SettingsController(manager) manager.admin_patron_controller = PatronController(manager) from api.admin.controller.self_tests import SelfTestsController + manager.admin_self_tests_controller = SelfTestsController(manager) from api.admin.controller.discovery_services import DiscoveryServicesController + manager.admin_discovery_services_controller = DiscoveryServicesController(manager) - from api.admin.controller.discovery_service_library_registrations import DiscoveryServiceLibraryRegistrationsController - manager.admin_discovery_service_library_registrations_controller = DiscoveryServiceLibraryRegistrationsController(manager) + from api.admin.controller.discovery_service_library_registrations import ( + DiscoveryServiceLibraryRegistrationsController, + ) + + manager.admin_discovery_service_library_registrations_controller = ( + DiscoveryServiceLibraryRegistrationsController(manager) + ) from api.admin.controller.cdn_services import CDNServicesController + manager.admin_cdn_services_controller = CDNServicesController(manager) from api.admin.controller.analytics_services import AnalyticsServicesController + manager.admin_analytics_services_controller = AnalyticsServicesController(manager) from api.admin.controller.metadata_services import MetadataServicesController + manager.admin_metadata_services_controller = MetadataServicesController(manager) + from api.admin.controller.metadata_service_self_tests import ( + MetadataServiceSelfTestsController, + ) from api.admin.controller.patron_auth_services import PatronAuthServicesController - from api.admin.controller.metadata_service_self_tests import MetadataServiceSelfTestsController - manager.admin_metadata_service_self_tests_controller = MetadataServiceSelfTestsController(manager) - manager.admin_patron_auth_services_controller = PatronAuthServicesController(manager) - from api.admin.controller.patron_auth_service_self_tests import PatronAuthServiceSelfTestsController - manager.admin_patron_auth_service_self_tests_controller = PatronAuthServiceSelfTestsController(manager) + + manager.admin_metadata_service_self_tests_controller = ( + MetadataServiceSelfTestsController(manager) + ) + manager.admin_patron_auth_services_controller = PatronAuthServicesController( + manager + ) + from api.admin.controller.patron_auth_service_self_tests import ( + PatronAuthServiceSelfTestsController, + ) + + manager.admin_patron_auth_service_self_tests_controller = ( + PatronAuthServiceSelfTestsController(manager) + ) from api.admin.controller.admin_auth_services import AdminAuthServicesController + manager.admin_auth_services_controller = AdminAuthServicesController(manager) from api.admin.controller.collection_settings import CollectionSettingsController + manager.admin_collection_settings_controller = CollectionSettingsController(manager) from api.admin.controller.collection_self_tests import CollectionSelfTestsController - manager.admin_collection_self_tests_controller = CollectionSelfTestsController(manager) - from api.admin.controller.collection_library_registrations import CollectionLibraryRegistrationsController - manager.admin_collection_library_registrations_controller = CollectionLibraryRegistrationsController(manager) - from api.admin.controller.sitewide_settings import SitewideConfigurationSettingsController - manager.admin_sitewide_configuration_settings_controller = SitewideConfigurationSettingsController(manager) + + manager.admin_collection_self_tests_controller = CollectionSelfTestsController( + manager + ) + from api.admin.controller.collection_library_registrations import ( + CollectionLibraryRegistrationsController, + ) + + manager.admin_collection_library_registrations_controller = ( + CollectionLibraryRegistrationsController(manager) + ) + from api.admin.controller.sitewide_settings import ( + SitewideConfigurationSettingsController, + ) + + manager.admin_sitewide_configuration_settings_controller = ( + SitewideConfigurationSettingsController(manager) + ) from api.admin.controller.library_settings import LibrarySettingsController + manager.admin_library_settings_controller = LibrarySettingsController(manager) - from api.admin.controller.individual_admin_settings import IndividualAdminSettingsController - manager.admin_individual_admin_settings_controller = IndividualAdminSettingsController(manager) - from api.admin.controller.sitewide_services import SitewideServicesController, LoggingServicesController, SearchServicesController + from api.admin.controller.individual_admin_settings import ( + IndividualAdminSettingsController, + ) + + manager.admin_individual_admin_settings_controller = ( + IndividualAdminSettingsController(manager) + ) + from api.admin.controller.sitewide_services import ( + LoggingServicesController, + SearchServicesController, + SitewideServicesController, + ) + manager.admin_sitewide_services_controller = SitewideServicesController(manager) manager.admin_logging_services_controller = LoggingServicesController(manager) - from api.admin.controller.search_service_self_tests import SearchServiceSelfTestsController - manager.admin_search_service_self_tests_controller = SearchServiceSelfTestsController(manager) + from api.admin.controller.search_service_self_tests import ( + SearchServiceSelfTestsController, + ) + + manager.admin_search_service_self_tests_controller = ( + SearchServiceSelfTestsController(manager) + ) manager.admin_search_services_controller = SearchServicesController(manager) from api.admin.controller.storage_services import StorageServicesController + manager.admin_storage_services_controller = StorageServicesController(manager) from api.admin.controller.catalog_services import CatalogServicesController + manager.admin_catalog_services_controller = CatalogServicesController(manager) -class AdminController(object): +class AdminController(object): def __init__(self, manager): self.manager = manager self._db = self.manager._db @@ -171,15 +214,19 @@ def admin_auth_providers(self): auth_providers = [] auth_service = ExternalIntegration.admin_authentication(self._db) if auth_service and auth_service.protocol == ExternalIntegration.GOOGLE_OAUTH: - auth_providers.append(GoogleOAuthAdminAuthenticationProvider( - auth_service, - self.url_for('google_auth_callback'), - test_mode=self.manager.testing, - )) + auth_providers.append( + GoogleOAuthAdminAuthenticationProvider( + auth_service, + self.url_for("google_auth_callback"), + test_mode=self.manager.testing, + ) + ) if Admin.with_password(self._db).count() != 0: - auth_providers.append(PasswordAdminAuthenticationProvider( - auth_service, - )) + auth_providers.append( + PasswordAdminAuthenticationProvider( + auth_service, + ) + ) return auth_providers def admin_auth_provider(self, type): @@ -212,23 +259,27 @@ def authenticated_admin_from_request(self): def authenticated_admin(self, admin_details): """Creates or updates an admin with the given details""" - admin, is_new = get_one_or_create( - self._db, Admin, email=admin_details['email'] - ) + admin, is_new = get_one_or_create(self._db, Admin, email=admin_details["email"]) admin.update_credentials( self._db, - credential=admin_details.get('credentials'), + credential=admin_details.get("credentials"), ) if is_new and admin_details.get("roles"): for role in admin_details.get("roles"): if role.get("role") in AdminRole.ROLES: library = Library.lookup(self._db, role.get("library")) if role.get("library") and not library: - self.log.warn("%s authentication provider specifiec an unknown library for a new admin: %s" % (admin_details.get("type"), role.get("library"))) + self.log.warn( + "%s authentication provider specifiec an unknown library for a new admin: %s" + % (admin_details.get("type"), role.get("library")) + ) else: admin.add_role(role.get("role"), library) else: - self.log.warn("%s authentication provider specified an unknown role for a new admin: %s" % (admin_details.get("type"), role.get("role"))) + self.log.warn( + "%s authentication provider specified an unknown role for a new admin: %s" + % (admin_details.get("type"), role.get("role")) + ) # Set up the admin's flask session. flask.session["admin_email"] = admin_details.get("email") @@ -244,11 +295,9 @@ def authenticated_admin(self, admin_details): # current request. This assumes the first authenticated admin # is accessing the admin interface through the hostname they # want to be used for the site itself. - base_url = ConfigurationSetting.sitewide( - self._db, Configuration.BASE_URL_KEY - ) + base_url = ConfigurationSetting.sitewide(self._db, Configuration.BASE_URL_KEY) if not base_url.value: - base_url.value = urllib.parse.urljoin(flask.request.url, '/') + base_url.value = urllib.parse.urljoin(flask.request.url, "/") return admin @@ -270,6 +319,7 @@ def generate_csrf_token(self): """Generate a random CSRF token.""" return base64.b64encode(os.urandom(24)).decode("utf-8") + class AdminCirculationManagerController(CirculationManagerController): """Parent class that provides methods for verifying an admin's roles.""" @@ -304,24 +354,24 @@ def require_higher_than_librarian(self): class ViewController(AdminController): def __call__(self, collection, book, path=None): - setting_up = (self.admin_auth_providers == []) + setting_up = self.admin_auth_providers == [] email = None roles = [] if not setting_up: admin = self.authenticated_admin_from_request() if isinstance(admin, ProblemDetail): redirect_url = flask.request.url - if (collection): + if collection: quoted_collection = urllib.parse.quote(collection) redirect_url = redirect_url.replace( - quoted_collection, - quoted_collection.replace("/", "%2F")) - if (book): + quoted_collection, quoted_collection.replace("/", "%2F") + ) + if book: quoted_book = urllib.parse.quote(book) redirect_url = redirect_url.replace( - quoted_book, - quoted_book.replace("/", "%2F")) - return redirect(self.url_for('admin_sign_in', redirect=redirect_url)) + quoted_book, quoted_book.replace("/", "%2F") + ) + return redirect(self.url_for("admin_sign_in", redirect=redirect_url)) if not collection and not book and not path: if self._db.query(Library).count() > 0: @@ -332,49 +382,62 @@ def __call__(self, collection, book, path=None): library_name = library.short_name break if not library_name: - return Response(_("Your admin account doesn't have access to any libraries. Contact your library manager for assistance."), 200) - return redirect(self.url_for('admin_view', collection=library_name)) + return Response( + _( + "Your admin account doesn't have access to any libraries. Contact your library manager for assistance." + ), + 200, + ) + return redirect(self.url_for("admin_view", collection=library_name)) email = admin.email for role in admin.roles: if role.library: - roles.append({ "role": role.role, "library": role.library }) + roles.append({"role": role.role, "library": role.library}) else: - roles.append({ "role": role.role }) + roles.append({"role": role.role}) - csrf_token = flask.request.cookies.get("csrf_token") or self.generate_csrf_token() - admin_js = AdminClientConfig.lookup_asset_url(key='admin_js') - admin_css = AdminClientConfig.lookup_asset_url(key='admin_css') + csrf_token = ( + flask.request.cookies.get("csrf_token") or self.generate_csrf_token() + ) + admin_js = AdminClientConfig.lookup_asset_url(key="admin_js") + admin_css = AdminClientConfig.lookup_asset_url(key="admin_css") # Find the URL and text to use when rendering the Terms of # Service link in the footer. - sitewide_tos_href = ConfigurationSetting.sitewide( - self._db, Configuration.CUSTOM_TOS_HREF - ).value or Configuration.DEFAULT_TOS_HREF + sitewide_tos_href = ( + ConfigurationSetting.sitewide(self._db, Configuration.CUSTOM_TOS_HREF).value + or Configuration.DEFAULT_TOS_HREF + ) - sitewide_tos_text = ConfigurationSetting.sitewide( - self._db, Configuration.CUSTOM_TOS_TEXT - ).value or Configuration.DEFAULT_TOS_TEXT + sitewide_tos_text = ( + ConfigurationSetting.sitewide(self._db, Configuration.CUSTOM_TOS_TEXT).value + or Configuration.DEFAULT_TOS_TEXT + ) local_analytics = get_one( - self._db, ExternalIntegration, + self._db, + ExternalIntegration, protocol=LocalAnalyticsProvider.__module__, - goal=ExternalIntegration.ANALYTICS_GOAL) - show_circ_events_download = (local_analytics != None) - - response = Response(flask.render_template_string( - admin_template, - app_name=AdminClientConfig.APP_NAME, - csrf_token=csrf_token, - sitewide_tos_href=sitewide_tos_href, - sitewide_tos_text=sitewide_tos_text, - show_circ_events_download=show_circ_events_download, - setting_up=setting_up, - email=email, - roles=roles, - admin_js=admin_js, - admin_css=admin_css - )) + goal=ExternalIntegration.ANALYTICS_GOAL, + ) + show_circ_events_download = local_analytics != None + + response = Response( + flask.render_template_string( + admin_template, + app_name=AdminClientConfig.APP_NAME, + csrf_token=csrf_token, + sitewide_tos_href=sitewide_tos_href, + sitewide_tos_text=sitewide_tos_text, + show_circ_events_download=show_circ_events_download, + setting_up=setting_up, + email=email, + roles=roles, + admin_js=admin_js, + admin_css=admin_css, + ) + ) # The CSRF token is in its own cookie instead of the session cookie, # because if your session expires and you log in again, you should @@ -383,6 +446,7 @@ def __call__(self, collection, book, path=None): response.set_cookie("csrf_token", csrf_token, httponly=True) return response + class TimestampsController(AdminCirculationManagerController): """Returns a dict: each key is a type of service (script, monitor, or coverage provider); each value is a nested dict in which timestamps are organized by service name and then by collection ID.""" @@ -447,9 +511,10 @@ def _extract_info(self, timestamp): exception=timestamp.exception, service=timestamp.service, collection_name=collection_name, - achievements=timestamp.achievements + achievements=timestamp.achievements, ) + class SignInController(AdminController): HEAD_TEMPLATE = """ @@ -458,7 +523,9 @@ class SignInController(AdminController): @import url('https://fonts.googleapis.com/css2?family=Open+Sans:wght@400;700&display=swap'); -""".format(app_name=AdminClientConfig.APP_NAME) +""".format( + app_name=AdminClientConfig.APP_NAME + ) ERROR_RESPONSE_TEMPLATE = """ @@ -468,7 +535,9 @@ class SignInController(AdminController):
Try again -""".format(head_html=HEAD_TEMPLATE, error=error_style, hr=hr_style, link=small_link_style) +""".format( + head_html=HEAD_TEMPLATE, error=error_style, hr=hr_style, link=small_link_style + ) SIGN_IN_TEMPLATE = """ @@ -477,7 +546,12 @@ class SignInController(AdminController): {app_name} %(auth_provider_html)s -""".format(head_html=HEAD_TEMPLATE, body=body_style, app_name=AdminClientConfig.APP_NAME, logo=logo_style) +""".format( + head_html=HEAD_TEMPLATE, + body=body_style, + app_name=AdminClientConfig.APP_NAME, + logo=logo_style, + ) def sign_in(self): """Redirects admin if they're signed in, or shows the sign in page.""" @@ -488,19 +562,26 @@ def sign_in(self): if isinstance(admin, ProblemDetail): redirect_url = flask.request.args.get("redirect") - auth_provider_html = [auth.sign_in_template(redirect_url) for auth in self.admin_auth_providers] + auth_provider_html = [ + auth.sign_in_template(redirect_url) + for auth in self.admin_auth_providers + ] auth_provider_html = """

or
- """.format(section=section_style, hr=hr_style).join(auth_provider_html) + """.format( + section=section_style, hr=hr_style + ).join( + auth_provider_html + ) html = self.SIGN_IN_TEMPLATE % dict( auth_provider_html=auth_provider_html, - logo_url=AdminClientConfig.lookup_asset_url(key='admin_logo') + logo_url=AdminClientConfig.lookup_asset_url(key="admin_logo"), ) headers = dict() - headers['Content-Type'] = "text/html" + headers["Content-Type"] = "text/html" return Response(html, 200, headers) elif admin: return redirect(flask.request.args.get("redirect"), Response=Response) @@ -549,19 +630,20 @@ def sign_out(self): flask.session.pop("admin_email", None) flask.session.pop("auth_type", None) - redirect_url = self.url_for("admin_sign_in", redirect=self.url_for("admin_view"), _external=True) + redirect_url = self.url_for( + "admin_sign_in", redirect=self.url_for("admin_view"), _external=True + ) return redirect(redirect_url) def error_response(self, problem_detail): """Returns a problem detail as an HTML response""" html = self.ERROR_RESPONSE_TEMPLATE % dict( - status_code=problem_detail.status_code, - message=problem_detail.detail + status_code=problem_detail.status_code, message=problem_detail.detail ) return Response(html, problem_detail.status_code) -class PatronController(AdminCirculationManagerController): +class PatronController(AdminCirculationManagerController): def _load_patrondata(self, authenticator=None): """Extract a patron identifier from an incoming form submission, and ask the library's LibraryAuthenticator to turn it into a @@ -597,8 +679,10 @@ def _load_patrondata(self, authenticator=None): # If we get here, none of the providers succeeded. if not complete_patron_data: return NO_SUCH_PATRON.detailed( - _("No patron with identifier %(patron_identifier)s was found at your library", - patron_identifier=identifier), + _( + "No patron with identifier %(patron_identifier)s was found at your library", + patron_identifier=identifier, + ), ) def lookup_patron(self, authenticator=None): @@ -629,8 +713,9 @@ def reset_adobe_id(self, authenticator=None): ) except CannotCreateLocalPatron as e: return NO_SUCH_PATRON.detailed( - _("Could not create local patron object for %(patron_identifier)s", - patron_identifier=patrondata.authorization_identifier + _( + "Could not create local patron object for %(patron_identifier)s", + patron_identifier=patrondata.authorization_identifier, ) ) @@ -642,56 +727,65 @@ def reset_adobe_id(self, authenticator=None): else: identifier = "with identifier " + patron.authorization_identifier return Response( - str(_("Adobe ID for patron %(name_or_auth_id)s has been reset.", name_or_auth_id=identifier)), - 200 + str( + _( + "Adobe ID for patron %(name_or_auth_id)s has been reset.", + name_or_auth_id=identifier, + ) + ), + 200, ) -class FeedController(AdminCirculationManagerController): +class FeedController(AdminCirculationManagerController): def complaints(self): self.require_librarian(flask.request.library) - this_url = self.url_for('complaints') + this_url = self.url_for("complaints") annotator = AdminAnnotator(self.circulation, flask.request.library) pagination = load_pagination_from_request() if isinstance(pagination, ProblemDetail): return pagination opds_feed = AdminFeed.complaints( - library=flask.request.library, title="Complaints", - url=this_url, annotator=annotator, - pagination=pagination + library=flask.request.library, + title="Complaints", + url=this_url, + annotator=annotator, + pagination=pagination, ) return OPDSFeedResponse(opds_feed, max_age=0) def suppressed(self): self.require_librarian(flask.request.library) - this_url = self.url_for('suppressed') + this_url = self.url_for("suppressed") annotator = AdminAnnotator(self.circulation, flask.request.library) pagination = load_pagination_from_request() if isinstance(pagination, ProblemDetail): return pagination opds_feed = AdminFeed.suppressed( - _db=self._db, title="Hidden Books", - url=this_url, annotator=annotator, - pagination=pagination + _db=self._db, + title="Hidden Books", + url=this_url, + annotator=annotator, + pagination=pagination, ) return OPDSFeedResponse(opds_feed, max_age=0) def genres(self): - data = dict({ - "Fiction": dict({}), - "Nonfiction": dict({}) - }) + data = dict({"Fiction": dict({}), "Nonfiction": dict({})}) for name in genres: top = "Fiction" if genres[name].is_fiction else "Nonfiction" - data[top][name] = dict({ - "name": name, - "parents": [parent.name for parent in genres[name].parents], - "subgenres": [subgenre.name for subgenre in genres[name].subgenres] - }) + data[top][name] = dict( + { + "name": name, + "parents": [parent.name for parent in genres[name].parents], + "subgenres": [subgenre.name for subgenre in genres[name].subgenres], + } + ) return data + class CustomListsController(AdminCirculationManagerController): def custom_lists(self): library = flask.request.library @@ -702,16 +796,33 @@ def custom_lists(self): for list in library.custom_lists: collections = [] for collection in list.collections: - collections.append(dict(id=collection.id, name=collection.name, protocol=collection.protocol)) - custom_lists.append(dict(id=list.id, name=list.name, collections=collections, entry_count=list.size)) + collections.append( + dict( + id=collection.id, + name=collection.name, + protocol=collection.protocol, + ) + ) + custom_lists.append( + dict( + id=list.id, + name=list.name, + collections=collections, + entry_count=list.size, + ) + ) return dict(custom_lists=custom_lists) if flask.request.method == "POST": id = flask.request.form.get("id") name = flask.request.form.get("name") entries = self._getJSONFromRequest(flask.request.form.get("entries")) - collections = self._getJSONFromRequest(flask.request.form.get("collections")) - return self._create_or_update_list(library, name, entries, collections, id=id) + collections = self._getJSONFromRequest( + flask.request.form.get("collections") + ) + return self._create_or_update_list( + library, name, entries, collections, id=id + ) def _getJSONFromRequest(self, values): if values: @@ -723,21 +834,19 @@ def _getJSONFromRequest(self, values): def _get_work_from_urn(self, library, urn): identifier, ignore = Identifier.parse_urn(self._db, urn) - query = self._db.query( - Work - ).join( - LicensePool, LicensePool.work_id==Work.id - ).join( - Collection, LicensePool.collection_id==Collection.id - ).filter( - LicensePool.identifier_id==identifier.id - ).filter( - Collection.id.in_([c.id for c in library.all_collections]) + query = ( + self._db.query(Work) + .join(LicensePool, LicensePool.work_id == Work.id) + .join(Collection, LicensePool.collection_id == Collection.id) + .filter(LicensePool.identifier_id == identifier.id) + .filter(Collection.id.in_([c.id for c in library.all_collections])) ) work = query.one() return work - def _create_or_update_list(self, library, name, entries, collections, deletedEntries=None, id=None): + def _create_or_update_list( + self, library, name, entries, collections, deletedEntries=None, id=None + ): data_source = DataSource.lookup(self._db, DataSource.LIBRARY_STAFF) old_list_with_name = CustomList.find(self._db, name, library=library) @@ -754,7 +863,9 @@ def _create_or_update_list(self, library, name, entries, collections, deletedEnt elif old_list_with_name: return CUSTOM_LIST_NAME_ALREADY_IN_USE else: - list, is_new = create(self._db, CustomList, name=name, data_source=data_source) + list, is_new = create( + self._db, CustomList, name=name, data_source=data_source + ) list.created = datetime.now() list.library = library @@ -812,7 +923,13 @@ def _create_or_update_list(self, library, name, entries, collections, deletedEnt def url_for_custom_list(self, library, list): def url_fn(after): - return self.url_for("custom_list", after=after, library_short_name=library.short_name, list_id=list.id) + return self.url_for( + "custom_list", + after=after, + library_short_name=library.short_name, + list_id=list.id, + ) + return url_fn def custom_list(self, list_id): @@ -829,17 +946,15 @@ def custom_list(self, list_id): if isinstance(pagination, ProblemDetail): return pagination - query = self._db.query( - Work - ).join( - Work.custom_list_entries - ).filter( - CustomListEntry.list_id==list_id - ).order_by( - Work.id + query = ( + self._db.query(Work) + .join(Work.custom_list_entries) + .filter(CustomListEntry.list_id == list_id) + .order_by(Work.id) ) url = self.url_for( - "custom_list", list_name=list.name, + "custom_list", + list_name=list.name, library_short_name=library.short_name, list_id=list_id, ) @@ -850,8 +965,7 @@ def custom_list(self, list_id): annotator = self.manager.annotator(worklist) url_fn = self.url_for_custom_list(library, list) feed = AcquisitionFeed.from_query( - query, self._db, list.name, - url, pagination, url_fn, annotator + query, self._db, list.name, url, pagination, url_fn, annotator ) annotator.annotate_feed(feed, worklist) @@ -860,9 +974,20 @@ def custom_list(self, list_id): elif flask.request.method == "POST": name = flask.request.form.get("name") entries = self._getJSONFromRequest(flask.request.form.get("entries")) - collections = self._getJSONFromRequest(flask.request.form.get("collections")) - deletedEntries = self._getJSONFromRequest(flask.request.form.get("deletedEntries")) - return self._create_or_update_list(library, name, entries, collections, deletedEntries=deletedEntries, id=list_id) + collections = self._getJSONFromRequest( + flask.request.form.get("collections") + ) + deletedEntries = self._getJSONFromRequest( + flask.request.form.get("deletedEntries") + ) + return self._create_or_update_list( + library, + name, + entries, + collections, + deletedEntries=deletedEntries, + id=list_id, + ) elif flask.request.method == "DELETE": # Deleting requires a library manager. @@ -873,8 +998,7 @@ def custom_list(self, list_id): affected_lanes = Lane.affected_by_customlist(list) surviving_lanes = [] for lane in affected_lanes: - if (lane.list_datasource == None - and len(lane.customlist_ids) == 1): + if lane.list_datasource == None and len(lane.customlist_ids) == 1: # This Lane is based solely upon this custom list, # which is about to be deleted. Delete the Lane # itself. @@ -893,22 +1017,32 @@ def custom_list(self, list_id): class LanesController(AdminCirculationManagerController): - def lanes(self): library = flask.request.library self.require_librarian(library) if flask.request.method == "GET": + def lanes_for_parent(parent): - lanes = self._db.query(Lane).filter(Lane.library==library).filter(Lane.parent==parent).order_by(Lane.priority) - return [{ "id": lane.id, - "display_name": lane.display_name, - "visible": lane.visible, - "count": lane.size, - "sublanes": lanes_for_parent(lane), - "custom_list_ids": [list.id for list in lane.customlists], - "inherit_parent_restrictions": lane.inherit_parent_restrictions, - } for lane in lanes] + lanes = ( + self._db.query(Lane) + .filter(Lane.library == library) + .filter(Lane.parent == parent) + .order_by(Lane.priority) + ) + return [ + { + "id": lane.id, + "display_name": lane.display_name, + "visible": lane.visible, + "count": lane.size, + "sublanes": lanes_for_parent(lane), + "custom_list_ids": [list.id for list in lane.customlists], + "inherit_parent_restrictions": lane.inherit_parent_restrictions, + } + for lane in lanes + ] + return dict(lanes=lanes_for_parent(None)) if flask.request.method == "POST": @@ -917,8 +1051,12 @@ def lanes_for_parent(parent): id = flask.request.form.get("id") parent_id = flask.request.form.get("parent_id") display_name = flask.request.form.get("display_name") - custom_list_ids = json.loads(flask.request.form.get("custom_list_ids", "[]")) - inherit_parent_restrictions = flask.request.form.get("inherit_parent_restrictions") + custom_list_ids = json.loads( + flask.request.form.get("custom_list_ids", "[]") + ) + inherit_parent_restrictions = flask.request.form.get( + "inherit_parent_restrictions" + ) if inherit_parent_restrictions == "true": inherit_parent_restrictions = True else: @@ -938,7 +1076,9 @@ def lanes_for_parent(parent): if not lane.customlists: return CANNOT_EDIT_DEFAULT_LANE if display_name != lane.display_name: - old_lane = get_one(self._db, Lane, display_name=display_name, parent=lane.parent) + old_lane = get_one( + self._db, Lane, display_name=display_name, parent=lane.parent + ) if old_lane: return LANE_WITH_PARENT_AND_DISPLAY_NAME_ALREADY_EXISTS lane.display_name = display_name @@ -947,17 +1087,32 @@ def lanes_for_parent(parent): if parent_id: parent = get_one(self._db, Lane, id=parent_id, library=library) if not parent: - return MISSING_LANE.detailed(_("The specified parent lane does not exist, or is associated with a different library.")) - old_lane = get_one(self._db, Lane, display_name=display_name, parent=parent) + return MISSING_LANE.detailed( + _( + "The specified parent lane does not exist, or is associated with a different library." + ) + ) + old_lane = get_one( + self._db, Lane, display_name=display_name, parent=parent + ) if old_lane: return LANE_WITH_PARENT_AND_DISPLAY_NAME_ALREADY_EXISTS lane, is_new = create( - self._db, Lane, display_name=display_name, - parent=parent, library=library) + self._db, + Lane, + display_name=display_name, + parent=parent, + library=library, + ) # Make a new lane the first child of its parent and bump all the siblings down in priority. - siblings = self._db.query(Lane).filter(Lane.library==library).filter(Lane.parent==lane.parent).filter(Lane.id!=lane.id) + siblings = ( + self._db.query(Lane) + .filter(Lane.library == library) + .filter(Lane.parent == lane.parent) + .filter(Lane.id != lane.id) + ) for sibling in siblings: sibling.priority += 1 lane.priority = 0 @@ -969,7 +1124,11 @@ def lanes_for_parent(parent): if not list: self._db.rollback() return MISSING_CUSTOM_LIST.detailed( - _("The list with id %(list_id)s does not exist or is associated with a different library.", list_id=list_id)) + _( + "The list with id %(list_id)s does not exist or is associated with a different library.", + list_id=list_id, + ) + ) lane.customlists.append(list) for list in lane.customlists: @@ -1038,7 +1197,7 @@ def change_order(self): def update_lane_order(lanes): for index, lane_data in enumerate(lanes): lane_id = lane_data.get("id") - lane = self._db.query(Lane).filter(Lane.id==lane_id).one() + lane = self._db.query(Lane).filter(Lane.id == lane_id).one() lane.priority = index update_lane_order(lane_data.get("sublanes", [])) @@ -1048,7 +1207,6 @@ def update_lane_order(lanes): class DashboardController(AdminCirculationManagerController): - def stats(self): library_stats = {} @@ -1058,46 +1216,52 @@ def stats(self): collection_counts = dict() for collection in self._db.query(Collection): - if not flask.request.admin or not flask.request.admin.can_see_collection(collection): + if not flask.request.admin or not flask.request.admin.can_see_collection( + collection + ): continue - licensed_title_count = self._db.query( - LicensePool - ).filter( - LicensePool.collection_id == collection.id - ).filter( - and_( - LicensePool.licenses_owned > 0, - LicensePool.open_access == False, + licensed_title_count = ( + self._db.query(LicensePool) + .filter(LicensePool.collection_id == collection.id) + .filter( + and_( + LicensePool.licenses_owned > 0, + LicensePool.open_access == False, + ) ) - ).count() + .count() + ) - open_title_count = self._db.query( - LicensePool - ).filter( - LicensePool.collection_id == collection.id - ).filter( - LicensePool.open_access == True - ).count() + open_title_count = ( + self._db.query(LicensePool) + .filter(LicensePool.collection_id == collection.id) + .filter(LicensePool.open_access == True) + .count() + ) # The sum queries return None instead of 0 if there are # no license pools in the db. - license_count = self._db.query( - func.sum(LicensePool.licenses_owned) - ).filter( - LicensePool.collection_id == collection.id - ).filter( - LicensePool.open_access == False, - ).all()[0][0] or 0 - - available_license_count = self._db.query( - func.sum(LicensePool.licenses_available) - ).filter( - LicensePool.collection_id == collection.id - ).filter( - LicensePool.open_access == False, - ).all()[0][0] or 0 + license_count = ( + self._db.query(func.sum(LicensePool.licenses_owned)) + .filter(LicensePool.collection_id == collection.id) + .filter( + LicensePool.open_access == False, + ) + .all()[0][0] + or 0 + ) + + available_license_count = ( + self._db.query(func.sum(LicensePool.licenses_available)) + .filter(LicensePool.collection_id == collection.id) + .filter( + LicensePool.open_access == False, + ) + .all()[0][0] + or 0 + ) total_title_count += licensed_title_count + open_title_count total_license_count += license_count @@ -1110,80 +1274,76 @@ def stats(self): available_licenses=available_license_count, ) - for library in self._db.query(Library): # Only include libraries this admin has librarian access to. if not flask.request.admin or not flask.request.admin.is_librarian(library): continue - patron_count = self._db.query(Patron).filter(Patron.library_id==library.id).count() + patron_count = ( + self._db.query(Patron).filter(Patron.library_id == library.id).count() + ) - active_loans_patron_count = self._db.query( - distinct(Patron.id) - ).join( - Patron.loans - ).filter( - Loan.end >= datetime.now(), - ).filter( - Patron.library_id == library.id - ).count() - - active_patrons = select( - [Patron.id] - ).select_from( - join( - Loan, - Patron, - and_( - Patron.id == Loan.patron_id, - Patron.library_id == library.id, - Loan.id != None, - Loan.end >= datetime.now() - ) + active_loans_patron_count = ( + self._db.query(distinct(Patron.id)) + .join(Patron.loans) + .filter( + Loan.end >= datetime.now(), ) - ).union( - select( - [Patron.id] - ).select_from( + .filter(Patron.library_id == library.id) + .count() + ) + + active_patrons = ( + select([Patron.id]) + .select_from( join( - Hold, + Loan, Patron, and_( - Patron.id == Hold.patron_id, + Patron.id == Loan.patron_id, Patron.library_id == library.id, - Hold.id != None, + Loan.id != None, + Loan.end >= datetime.now(), + ), + ) + ) + .union( + select([Patron.id]).select_from( + join( + Hold, + Patron, + and_( + Patron.id == Hold.patron_id, + Patron.library_id == library.id, + Hold.id != None, + ), ) ) ) - ).alias() - + .alias() + ) active_loans_or_holds_patron_count_query = select( [func.count(distinct(active_patrons.c.id))] - ).select_from( - active_patrons - ) + ).select_from(active_patrons) result = self._db.execute(active_loans_or_holds_patron_count_query) active_loans_or_holds_patron_count = [r[0] for r in result][0] - loan_count = self._db.query( - Loan - ).join( - Loan.patron - ).filter( - Patron.library_id == library.id - ).filter( - Loan.end >= datetime.now() - ).count() - - hold_count = self._db.query( - Hold - ).join( - Hold.patron - ).filter( - Patron.library_id == library.id - ).count() + loan_count = ( + self._db.query(Loan) + .join(Loan.patron) + .filter(Patron.library_id == library.id) + .filter(Loan.end >= datetime.now()) + .count() + ) + + hold_count = ( + self._db.query(Hold) + .join(Hold.patron) + .filter(Patron.library_id == library.id) + .count() + ) title_count = 0 license_count = 0 @@ -1193,7 +1353,9 @@ def stats(self): for collection in library.all_collections: counts = collection_counts[collection.name] library_collection_counts[collection.name] = counts - title_count += counts.get("licensed_titles", 0) + counts.get("open_access_titles", 0) + title_count += counts.get("licensed_titles", 0) + counts.get( + "open_access_titles", 0 + ) license_count += counts.get("licenses", 0) available_license_count += counts.get("available_licenses", 0) @@ -1213,24 +1375,39 @@ def stats(self): collections=library_collection_counts, ) - total_patrons = sum([ - stats.get("patrons", {}).get("total", 0) - for stats in list(library_stats.values())]) - total_with_active_loans = sum([ - stats.get("patrons", {}).get("with_active_loans", 0) - for stats in list(library_stats.values())]) - total_with_active_loans_or_holds = sum([ - stats.get("patrons", {}).get("with_active_loans_or_holds", 0) - for stats in list(library_stats.values())]) + total_patrons = sum( + [ + stats.get("patrons", {}).get("total", 0) + for stats in list(library_stats.values()) + ] + ) + total_with_active_loans = sum( + [ + stats.get("patrons", {}).get("with_active_loans", 0) + for stats in list(library_stats.values()) + ] + ) + total_with_active_loans_or_holds = sum( + [ + stats.get("patrons", {}).get("with_active_loans_or_holds", 0) + for stats in list(library_stats.values()) + ] + ) # TODO: show shared collection loans and holds for libraries outside this # circ manager? - total_loans = sum([ - stats.get("patrons", {}).get("loans", 0) - for stats in list(library_stats.values())]) - total_holds = sum([ - stats.get("patrons", {}).get("holds", 0) - for stats in list(library_stats.values())]) + total_loans = sum( + [ + stats.get("patrons", {}).get("loans", 0) + for stats in list(library_stats.values()) + ] + ) + total_holds = sum( + [ + stats.get("patrons", {}).get("holds", 0) + for stats in list(library_stats.values()) + ] + ) library_stats["total"] = dict( patrons=dict( @@ -1254,29 +1431,39 @@ def circulation_events(self): annotator = AdminAnnotator(self.circulation, flask.request.library) num = min(int(flask.request.args.get("num", "100")), 500) - results = self._db.query(CirculationEvent) \ - .join(LicensePool) \ - .join(Work) \ - .join(DataSource) \ - .join(Identifier) \ - .order_by(nullslast(desc(CirculationEvent.start))) \ - .limit(num) \ + results = ( + self._db.query(CirculationEvent) + .join(LicensePool) + .join(Work) + .join(DataSource) + .join(Identifier) + .order_by(nullslast(desc(CirculationEvent.start))) + .limit(num) .all() + ) - events = [{ - "id": result.id, - "type": result.type, - "time": result.start, - "book": { - "title": result.license_pool.work.title, - "url": annotator.permalink_for(result.license_pool.work, result.license_pool, result.license_pool.identifier) + events = [ + { + "id": result.id, + "type": result.type, + "time": result.start, + "book": { + "title": result.license_pool.work.title, + "url": annotator.permalink_for( + result.license_pool.work, + result.license_pool, + result.license_pool.identifier, + ), + }, } - } for result in results] + for result in results + ] - return dict({ "circulation_events": events }) + return dict({"circulation_events": events}) def bulk_circulation_events(self, analytics_exporter=None): date_format = "%Y-%m-%d" + def get_date(field): # Return a date or datetime object representing the # _beginning_ of the asked-for day, local time. @@ -1306,19 +1493,24 @@ def get_date(field): date_end_label = get_date("dateEnd") date_end = date_end_label + timedelta(days=1) locations = flask.request.args.get("locations", None) - library = getattr(flask.request, 'library', None) + library = getattr(flask.request, "library", None) library_short_name = library.short_name if library else None analytics_exporter = analytics_exporter or LocalAnalyticsExporter() data = analytics_exporter.export( self._db, date_start, date_end, locations, library ) - return (data, date_start.strftime(date_format), - date_end_label.strftime(date_format), library_short_name) + return ( + data, + date_start.strftime(date_format), + date_end_label.strftime(date_format), + library_short_name, + ) + class SettingsController(AdminCirculationManagerController): - METADATA_SERVICE_URI_TYPE = 'application/opds+json;profile=https://librarysimplified.org/rel/profile/metadata-service' + METADATA_SERVICE_URI_TYPE = "application/opds+json;profile=https://librarysimplified.org/rel/profile/metadata-service" NO_MIRROR_INTEGRATION = "NO_MIRROR" @@ -1336,7 +1528,7 @@ class SettingsController(AdminCirculationManagerController): ODL2API, SharedODLAPI, FeedbooksOPDSImporter, - LCPAPI + LCPAPI, ] @classmethod @@ -1373,16 +1565,16 @@ def _get_integration_protocols(cls, provider_apis, protocol_name_attr="__module_ if library_settings != None: protocol["library_settings"] = list(library_settings) - cardinality = getattr(api, 'CARDINALITY', None) + cardinality = getattr(api, "CARDINALITY", None) if cardinality != None: - protocol['cardinality'] = cardinality + protocol["cardinality"] = cardinality supports_registration = getattr(api, "SUPPORTS_REGISTRATION", None) if supports_registration != None: - protocol['supports_registration'] = supports_registration + protocol["supports_registration"] = supports_registration supports_staging = getattr(api, "SUPPORTS_STAGING", None) if supports_staging != None: - protocol['supports_staging'] = supports_staging + protocol["supports_staging"] = supports_staging protocols.append(protocol) return protocols @@ -1406,7 +1598,8 @@ def _get_integration_library_info(self, integration, library, protocol): def _get_integration_info(self, goal, protocols): services = [] for service in self._db.query(ExternalIntegration).filter( - ExternalIntegration.goal==goal): + ExternalIntegration.goal == goal + ): candidates = [p for p in protocols if p.get("name") == service.protocol] if not candidates: continue @@ -1414,8 +1607,9 @@ def _get_integration_info(self, goal, protocols): libraries = [] if not protocol.get("sitewide") or protocol.get("library_settings"): for library in service.libraries: - libraries.append(self._get_integration_library_info( - service, library, protocol)) + libraries.append( + self._get_integration_library_info(service, library, protocol) + ) settings = dict() for setting in protocol.get("settings", []): @@ -1424,9 +1618,11 @@ def _get_integration_info(self, goal, protocols): # If the setting is a covers or books mirror, we need to get # the value from ExternalIntegrationLink and # not from a ConfigurationSetting. - if key.endswith('mirror_integration_id'): + if key.endswith("mirror_integration_id"): storage_integration = get_one( - self._db, ExternalIntegrationLink, external_integration_id=service.id + self._db, + ExternalIntegrationLink, + external_integration_id=service.id, ) if storage_integration: value = str(storage_integration.other_integration_id) @@ -1435,10 +1631,12 @@ def _get_integration_info(self, goal, protocols): else: if setting.get("type") in ("list", "menu"): value = ConfigurationSetting.for_externalintegration( - key, service).json_value + key, service + ).json_value else: value = ConfigurationSetting.for_externalintegration( - key, service).value + key, service + ).value settings[key] = value service_info = dict( @@ -1450,7 +1648,9 @@ def _get_integration_info(self, goal, protocols): ) if "test_search_term" in [x.get("key") for x in protocol.get("settings")]: - service_info["self_test_results"] = self._get_prior_test_results(service) + service_info["self_test_results"] = self._get_prior_test_results( + service + ) services.append(service_info) return services @@ -1480,9 +1680,9 @@ def _get_menu_values(setting_key, form): for form_item_key in list(form.keys()): if setting_key in form_item_key: - value = form_item_key.replace(setting_key, '').lstrip('_') + value = form_item_key.replace(setting_key, "").lstrip("_") - if value != 'menu': + if value != "menu": values.append(value) return values @@ -1495,7 +1695,7 @@ def _set_integration_setting(self, integration, setting): value = [item for item in flask.request.form.getlist(setting_key) if item] if value: value = json.dumps(value) - elif setting_type == 'menu': + elif setting_type == "menu": value = self._get_menu_values(setting_key, flask.request.form) else: value = flask.request.form.get(setting_key) @@ -1511,14 +1711,23 @@ def _set_integration_setting(self, integration, setting): for submitted_value in submitted_values: if submitted_value not in allowed_values: - return INVALID_CONFIGURATION_OPTION.detailed(_( - "The configuration value for %(setting)s is invalid.", - setting=setting.get("label"), - )) - if not value and setting.get("required") and not "default" in list(setting.keys()): + return INVALID_CONFIGURATION_OPTION.detailed( + _( + "The configuration value for %(setting)s is invalid.", + setting=setting.get("label"), + ) + ) + if ( + not value + and setting.get("required") + and not "default" in list(setting.keys()) + ): return INCOMPLETE_CONFIGURATION.detailed( - _("The configuration is missing a required setting: %(setting)s", - setting=setting.get("label"))) + _( + "The configuration is missing a required setting: %(setting)s", + setting=setting.get("label"), + ) + ) if isinstance(value, list): value = json.dumps(value) @@ -1528,7 +1737,12 @@ def _set_integration_setting(self, integration, setting): def _set_integration_library(self, integration, library_info, protocol): library = get_one(self._db, Library, short_name=library_info.get("short_name")) if not library: - return NO_SUCH_LIBRARY.detailed(_("You attempted to add the integration to %(library_short_name)s, but it does not exist.", library_short_name=library_info.get("short_name"))) + return NO_SUCH_LIBRARY.detailed( + _( + "You attempted to add the integration to %(library_short_name)s, but it does not exist.", + library_short_name=library_info.get("short_name"), + ) + ) integration.libraries += [library] for setting in protocol.get("library_settings", []): @@ -1536,23 +1750,31 @@ def _set_integration_library(self, integration, library_info, protocol): value = library_info.get(key) if value and setting.get("type") == "list" and not setting.get("options"): value = json.dumps(value) - if setting.get("options") and value not in [option.get("key") for option in setting.get("options")]: - return INVALID_CONFIGURATION_OPTION.detailed(_( - "The configuration value for %(setting)s is invalid.", - setting=setting.get("label"), - )) + if setting.get("options") and value not in [ + option.get("key") for option in setting.get("options") + ]: + return INVALID_CONFIGURATION_OPTION.detailed( + _( + "The configuration value for %(setting)s is invalid.", + setting=setting.get("label"), + ) + ) if not value and setting.get("required"): return INCOMPLETE_CONFIGURATION.detailed( - _("The configuration is missing a required setting: %(setting)s for library %(library)s", - setting=setting.get("label"), - library=library.short_name, - )) - ConfigurationSetting.for_library_and_externalintegration(self._db, key, library, integration).value = value + _( + "The configuration is missing a required setting: %(setting)s for library %(library)s", + setting=setting.get("label"), + library=library.short_name, + ) + ) + ConfigurationSetting.for_library_and_externalintegration( + self._db, key, library, integration + ).value = value def _set_integration_settings_and_libraries(self, integration, protocol): settings = protocol.get("settings") for setting in settings: - if not setting.get('key').endswith('mirror_integration_id'): + if not setting.get("key").endswith("mirror_integration_id"): result = self._set_integration_setting(integration, setting) if isinstance(result, ProblemDetail): return result @@ -1565,7 +1787,9 @@ def _set_integration_settings_and_libraries(self, integration, protocol): libraries = json.loads(flask.request.form.get("libraries")) for library_info in libraries: - result = self._set_integration_library(integration, library_info, protocol) + result = self._set_integration_library( + integration, library_info, protocol + ) if isinstance(result, ProblemDetail): return result return True @@ -1575,22 +1799,27 @@ def _delete_integration(self, integration_id, goal): return self.require_system_admin() - integration = get_one(self._db, ExternalIntegration, - id=integration_id, goal=goal) + integration = get_one( + self._db, ExternalIntegration, id=integration_id, goal=goal + ) if not integration: return MISSING_SERVICE self._db.delete(integration) return Response(str(_("Deleted")), 200) def _get_collection_protocols(self, provider_apis): - protocols = self._get_integration_protocols(provider_apis, protocol_name_attr="NAME") + protocols = self._get_integration_protocols( + provider_apis, protocol_name_attr="NAME" + ) protocols.append( { - 'name': ExternalIntegration.MANUAL, - 'label': _('Manual import'), - 'description': _('Books will be manually added to the circulation manager, ' - 'not imported automatically through a protocol.'), - 'settings': [] + "name": ExternalIntegration.MANUAL, + "label": _("Manual import"), + "description": _( + "Books will be manually added to the circulation manager, " + "not imported automatically through a protocol." + ), + "settings": [], } ) @@ -1616,8 +1845,10 @@ def _get_prior_test_results(self, item, protocol_class=None, *extra_args): if item.protocol == OPDSImportMonitor.PROTOCOL: protocol_class = OPDSImportMonitor - if protocol_class in provider_apis and issubclass(protocol_class, HasSelfTests): - if (item.protocol == OPDSImportMonitor.PROTOCOL): + if protocol_class in provider_apis and issubclass( + protocol_class, HasSelfTests + ): + if item.protocol == OPDSImportMonitor.PROTOCOL: extra_args = (OPDSImporter,) else: extra_args = () @@ -1643,8 +1874,10 @@ def _get_prior_test_results(self, item, protocol_class=None, *extra_args): ) else: self_test_results = dict( - exception=_("You must associate this service with at least one library before you can run self tests for it."), - disabled=True + exception=_( + "You must associate this service with at least one library before you can run self tests for it." + ), + disabled=True, ) except Exception as e: @@ -1664,44 +1897,48 @@ def _mirror_integration_settings(self): """Create a setting interface for selecting a storage integration to be used when mirroring items from a collection. """ - integrations = self._db.query(ExternalIntegration).filter( - ExternalIntegration.goal==ExternalIntegration.STORAGE_GOAL - ).order_by( - ExternalIntegration.name + integrations = ( + self._db.query(ExternalIntegration) + .filter(ExternalIntegration.goal == ExternalIntegration.STORAGE_GOAL) + .order_by(ExternalIntegration.name) ) if not integrations.all(): return - mirror_integration_settings = copy.deepcopy(ExternalIntegrationLink.COLLECTION_MIRROR_SETTINGS) + mirror_integration_settings = copy.deepcopy( + ExternalIntegrationLink.COLLECTION_MIRROR_SETTINGS + ) for integration in integrations: - book_covers_bucket = integration.setting(S3UploaderConfiguration.BOOK_COVERS_BUCKET_KEY).value - open_access_bucket = integration.setting(S3UploaderConfiguration.OA_CONTENT_BUCKET_KEY).value - protected_access_bucket = integration.setting(S3UploaderConfiguration.PROTECTED_CONTENT_BUCKET_KEY).value + book_covers_bucket = integration.setting( + S3UploaderConfiguration.BOOK_COVERS_BUCKET_KEY + ).value + open_access_bucket = integration.setting( + S3UploaderConfiguration.OA_CONTENT_BUCKET_KEY + ).value + protected_access_bucket = integration.setting( + S3UploaderConfiguration.PROTECTED_CONTENT_BUCKET_KEY + ).value for setting in mirror_integration_settings: - if setting['key'] == ExternalIntegrationLink.COVERS_KEY and book_covers_bucket: - setting['options'].append( - { - 'key': str(integration.id), - 'label': integration.name - } + if ( + setting["key"] == ExternalIntegrationLink.COVERS_KEY + and book_covers_bucket + ): + setting["options"].append( + {"key": str(integration.id), "label": integration.name} ) - elif setting['key'] == ExternalIntegrationLink.OPEN_ACCESS_BOOKS_KEY: + elif setting["key"] == ExternalIntegrationLink.OPEN_ACCESS_BOOKS_KEY: if open_access_bucket: - setting['options'].append( - { - 'key': str(integration.id), - 'label': integration.name - } + setting["options"].append( + {"key": str(integration.id), "label": integration.name} ) - elif setting['key'] == ExternalIntegrationLink.PROTECTED_ACCESS_BOOKS_KEY: + elif ( + setting["key"] == ExternalIntegrationLink.PROTECTED_ACCESS_BOOKS_KEY + ): if protected_access_bucket: - setting['options'].append( - { - 'key': str(integration.id), - 'label': integration.name - } + setting["options"].append( + {"key": str(integration.id), "label": integration.name} ) return mirror_integration_settings @@ -1717,7 +1954,7 @@ def _create_integration(self, protocol_definitions, protocol, goal): """ if not protocol: return NO_PROTOCOL_FOR_NEW_SERVICE, False - matches = [x for x in protocol_definitions if x.get('name') == protocol] + matches = [x for x in protocol_definitions if x.get("name") == protocol] if not matches: return UNKNOWN_PROTOCOL, False definition = matches[0] @@ -1728,7 +1965,7 @@ def _create_integration(self, protocol_definitions, protocol, goal): m = create args = (self._db, ExternalIntegration) kwargs = dict(protocol=protocol, goal=goal) - if definition.get('cardinality') == 1: + if definition.get("cardinality") == 1: # ...but not all the time. allow_multiple = False existing = get_one(*args, **kwargs) @@ -1786,7 +2023,7 @@ def url_variants(cls, url, check_protocol_variant=True): if url.endswith("/"): yield url[:-1] else: - yield url + '/' + yield url + "/" # Changing protocols may create one or more variants. https = "https://" @@ -1820,22 +2057,29 @@ def check_url_unique(self, new_service, url, protocol, goal): # because we're doing the comparison in the database. urls = list(self.url_variants(url)) - qu = self._db.query(ExternalIntegration).join( - ExternalIntegration.settings - ).filter( - # Protocol must match. - ExternalIntegration.protocol==protocol - ).filter( - # Goal must match. - ExternalIntegration.goal==goal - ).filter( - ConfigurationSetting.key==ExternalIntegration.URL - ).filter( - # URL must be one of the URLs we're concerned about. - ConfigurationSetting.value.in_(urls) - ).filter( - # But don't count the service we're trying to edit. - ExternalIntegration.id != new_service.id + qu = ( + self._db.query(ExternalIntegration) + .join(ExternalIntegration.settings) + .filter( + # Protocol must match. + ExternalIntegration.protocol + == protocol + ) + .filter( + # Goal must match. + ExternalIntegration.goal + == goal + ) + .filter(ConfigurationSetting.key == ExternalIntegration.URL) + .filter( + # URL must be one of the URLs we're concerned about. + ConfigurationSetting.value.in_(urls) + ) + .filter( + # But don't count the service we're trying to edit. + ExternalIntegration.id + != new_service.id + ) ) if qu.count() > 0: return INTEGRATION_URL_ALREADY_IN_USE @@ -1879,8 +2123,12 @@ def validate_protocol(self, protocols=None): return UNKNOWN_PROTOCOL def _get_settings(self): - if hasattr(self, 'protocols'): - [protocol] = [p for p in self.protocols if p.get("name") == flask.request.form.get("protocol")] + if hasattr(self, "protocols"): + [protocol] = [ + p + for p in self.protocols + if p.get("name") == flask.request.form.get("protocol") + ] return protocol.get("settings") return [] @@ -1911,8 +2159,8 @@ class SitewideRegistrationController(SettingsController): the circulation manager. """ - def process_sitewide_registration(self, integration, do_get=HTTP.debuggable_get, - do_post=HTTP.debuggable_post + def process_sitewide_registration( + self, integration, do_get=HTTP.debuggable_get, do_post=HTTP.debuggable_post ): """Performs a sitewide registration for a particular service. @@ -1932,7 +2180,7 @@ def process_sitewide_registration(self, integration, do_get=HTTP.debuggable_get, return self.check_content_type(catalog_response) catalog = catalog_response.json() - links = catalog.get('links', []) + links = catalog.get("links", []) register_url = self.get_registration_link(catalog, links) if isinstance(register_url, ProblemDetail): @@ -1970,31 +2218,31 @@ def get_catalog(self, do_get, url): def check_content_type(self, catalog_response): """Make sure the catalog for the service is in a valid format.""" - content_type = catalog_response.headers.get('Content-Type') - if content_type != 'application/opds+json': + content_type = catalog_response.headers.get("Content-Type") + if content_type != "application/opds+json": return REMOTE_INTEGRATION_FAILED.detailed( - _('The service did not provide a valid catalog.') + _("The service did not provide a valid catalog.") ) def get_registration_link(self, catalog, links): """Get the link for registration from the catalog.""" register_link_filter = lambda l: ( - l.get('rel')=='register' and - l.get('type')==self.METADATA_SERVICE_URI_TYPE + l.get("rel") == "register" + and l.get("type") == self.METADATA_SERVICE_URI_TYPE ) register_urls = list(filter(register_link_filter, links)) if not register_urls: return REMOTE_INTEGRATION_FAILED.detailed( - _('The service did not provide a register link.') + _("The service did not provide a register link.") ) # Get the full registration url. - register_url = register_urls[0].get('href') - if not register_url.startswith('http'): + register_url = register_urls[0].get("href") + if not register_url.startswith("http"): # We have a relative path. Create a full registration url. - base_url = catalog.get('id') + base_url = catalog.get("id") register_url = urllib.parse.urljoin(base_url, register_url) return register_url @@ -2005,10 +2253,10 @@ def update_headers(self, integration): # NOTE: This is no longer technically necessary since we prove # ownership with a signed JWT. - headers = { 'Content-Type' : 'application/x-www-form-urlencoded' } + headers = {"Content-Type": "application/x-www-form-urlencoded"} if integration.password: - token = base64.b64encode(integration.password.encode('utf-8')) - headers['Authorization'] = 'Bearer ' + token + token = base64.b64encode(integration.password.encode("utf-8")) + headers["Authorization"] = "Bearer " + token return headers def register(self, register_url, headers, do_post): @@ -2017,8 +2265,7 @@ def register(self, register_url, headers, do_post): try: body = self.sitewide_registration_document() response = do_post( - register_url, body, allowed_response_codes=['2xx'], - headers=headers + register_url, body, allowed_response_codes=["2xx"], headers=headers ) except Exception as e: return REMOTE_INTEGRATION_FAILED.detailed(str(e)) @@ -2029,10 +2276,10 @@ def get_shared_secret(self, response): service, or return an error message if there is no shared secret.""" registration_info = response.json() - shared_secret = registration_info.get('metadata', {}).get('shared_secret') + shared_secret = registration_info.get("metadata", {}).get("shared_secret") if not shared_secret: return REMOTE_INTEGRATION_FAILED.detailed( - _('The service did not provide registration information.') + _("The service did not provide registration information.") ) return shared_secret @@ -2048,10 +2295,10 @@ def sitewide_registration_document(self): public_key, private_key = self.manager.sitewide_key_pair # Advertise the public key so that the foreign site can encrypt # things for us. - public_key_dict = dict(type='RSA', value=public_key) - public_key_url = self.url_for('public_key_document') + public_key_dict = dict(type="RSA", value=public_key) + public_key_url = self.url_for("public_key_document") in_one_minute = utc_now() + timedelta(seconds=60) - payload = {'exp': in_one_minute} + payload = {"exp": in_one_minute} # Sign a JWT with the private key to prove ownership of the site. - token = jwt.encode(payload, private_key, algorithm='RS256') + token = jwt.encode(payload, private_key, algorithm="RS256") return dict(url=public_key_url, jwt=token) diff --git a/api/admin/controller/admin_auth_services.py b/api/admin/controller/admin_auth_services.py index d4bc5d8b10..8084abfc1c 100644 --- a/api/admin/controller/admin_auth_services.py +++ b/api/admin/controller/admin_auth_services.py @@ -1,32 +1,36 @@ import flask from flask import Response from flask_babel import lazy_gettext as _ -from core.model import ( - ExternalIntegration, - get_one, - get_one_or_create, + +from api.admin.google_oauth_admin_authentication_provider import ( + GoogleOAuthAdminAuthenticationProvider, ) +from api.admin.problem_details import * +from core.model import ExternalIntegration, get_one, get_one_or_create from core.util.problem_detail import ProblemDetail + from . import SettingsController -from api.admin.google_oauth_admin_authentication_provider import GoogleOAuthAdminAuthenticationProvider -from api.admin.problem_details import * -class AdminAuthServicesController(SettingsController): +class AdminAuthServicesController(SettingsController): def __init__(self, manager): super(AdminAuthServicesController, self).__init__(manager) provider_apis = [GoogleOAuthAdminAuthenticationProvider] - self.protocols = self._get_integration_protocols(provider_apis, protocol_name_attr="NAME") + self.protocols = self._get_integration_protocols( + provider_apis, protocol_name_attr="NAME" + ) def process_admin_auth_services(self): self.require_system_admin() - if flask.request.method == 'GET': + if flask.request.method == "GET": return self.process_get() else: return self.process_post() def process_get(self): - auth_services = self._get_integration_info(ExternalIntegration.ADMIN_AUTH_GOAL, self.protocols) + auth_services = self._get_integration_info( + ExternalIntegration.ADMIN_AUTH_GOAL, self.protocols + ) return dict( admin_auth_services=auth_services, protocols=self.protocols, @@ -46,8 +50,10 @@ def process_post(self): if not auth_service: if protocol: auth_service, is_new = get_one_or_create( - self._db, ExternalIntegration, protocol=protocol, - goal=ExternalIntegration.ADMIN_AUTH_GOAL + self._db, + ExternalIntegration, + protocol=protocol, + goal=ExternalIntegration.ADMIN_AUTH_GOAL, ) else: return NO_PROTOCOL_FOR_NEW_SERVICE @@ -92,7 +98,12 @@ def validate_form_fields(self, **fields): def process_delete(self, protocol): self.require_system_admin() - service = get_one(self._db, ExternalIntegration, protocol=protocol, goal=ExternalIntegration.ADMIN_AUTH_GOAL) + service = get_one( + self._db, + ExternalIntegration, + protocol=protocol, + goal=ExternalIntegration.ADMIN_AUTH_GOAL, + ) if not service: return MISSING_SERVICE self._db.delete(service) diff --git a/api/admin/controller/analytics_services.py b/api/admin/controller/analytics_services.py index 38a9d81b6a..a56b3b1b24 100644 --- a/api/admin/controller/analytics_services.py +++ b/api/admin/controller/analytics_services.py @@ -1,38 +1,38 @@ import flask from flask import Response + from api.admin.problem_details import * from api.google_analytics_provider import GoogleAnalyticsProvider from core.local_analytics_provider import LocalAnalyticsProvider -from core.model import ( - ExternalIntegration, - get_one, -) +from core.model import ExternalIntegration, get_one from core.util.problem_detail import ProblemDetail + from . import SettingsController -class AnalyticsServicesController(SettingsController): +class AnalyticsServicesController(SettingsController): def __init__(self, manager): super(AnalyticsServicesController, self).__init__(manager) - provider_apis = [GoogleAnalyticsProvider, - LocalAnalyticsProvider, - ] + provider_apis = [ + GoogleAnalyticsProvider, + LocalAnalyticsProvider, + ] self.protocols = self._get_integration_protocols(provider_apis) self.goal = ExternalIntegration.ANALYTICS_GOAL def process_analytics_services(self): - if flask.request.method == 'GET': + if flask.request.method == "GET": return self.process_get() else: return self.process_post() def process_get(self): - if flask.request.method == 'GET': + if flask.request.method == "GET": services = self._get_integration_info(self.goal, self.protocols) # Librarians should be able to see, but not modify local analytics services. # Setting the level to 2 will communicate that to the front end. for x in services: - if (x["protocol"] == 'core.local_analytics_provider'): + if x["protocol"] == "core.local_analytics_provider": x["level"] = 2 return dict( analytics_services=services, @@ -46,7 +46,7 @@ def process_post(self): fields = {"name": name, "protocol": protocol, "url": url} # Don't let librarians create local analytics services. - if protocol == 'core.local_analytics_provider': + if protocol == "core.local_analytics_provider": self.require_higher_than_librarian() form_field_error = self.validate_form_fields(**fields) @@ -109,6 +109,4 @@ def validate_form_fields(self, **fields): return INCOMPLETE_CONFIGURATION def process_delete(self, service_id): - return self._delete_integration( - service_id, self.goal - ) + return self._delete_integration(service_id, self.goal) diff --git a/api/admin/controller/catalog_services.py b/api/admin/controller/catalog_services.py index 7dbf9646f0..c7845ea75f 100644 --- a/api/admin/controller/catalog_services.py +++ b/api/admin/controller/catalog_services.py @@ -4,27 +4,25 @@ from api.admin.problem_details import * from core.marc import MARCExporter -from core.model import ( - ExternalIntegration, - get_one, - get_one_or_create -) +from core.model import ExternalIntegration, get_one, get_one_or_create from core.model.configuration import ExternalIntegrationLink from core.s3 import S3UploaderConfiguration from core.util.problem_detail import ProblemDetail + from . import SettingsController class CatalogServicesController(SettingsController): - def __init__(self, manager): super(CatalogServicesController, self).__init__(manager) service_apis = [MARCExporter] - self.protocols = self._get_integration_protocols(service_apis, protocol_name_attr="NAME") + self.protocols = self._get_integration_protocols( + service_apis, protocol_name_attr="NAME" + ) self.update_protocol_settings() - + def update_protocol_settings(self): - self.protocols[0]['settings'] = [MARCExporter.get_storage_settings(self._db)] + self.protocols[0]["settings"] = [MARCExporter.get_storage_settings(self._db)] def process_catalog_services(self): self.require_system_admin() @@ -35,7 +33,9 @@ def process_catalog_services(self): return self.process_post() def process_get(self): - services = self._get_integration_info(ExternalIntegration.CATALOG_GOAL, self.protocols) + services = self._get_integration_info( + ExternalIntegration.CATALOG_GOAL, self.protocols + ) self.update_protocol_settings() return dict( catalog_services=services, @@ -52,7 +52,12 @@ def process_post(self): id = flask.request.form.get("id") if id: # Find an existing service to edit - service = get_one(self._db, ExternalIntegration, id=id, goal=ExternalIntegration.CATALOG_GOAL) + service = get_one( + self._db, + ExternalIntegration, + id=id, + goal=ExternalIntegration.CATALOG_GOAL, + ) if not service: return MISSING_SERVICE if protocol != service.protocol: @@ -60,7 +65,9 @@ def process_post(self): else: # Create a new service service, is_new = self._create_integration( - self.protocols, protocol, ExternalIntegration.CATALOG_GOAL, + self.protocols, + protocol, + ExternalIntegration.CATALOG_GOAL, ) if isinstance(service, ProblemDetail): return service @@ -91,20 +98,21 @@ def process_post(self): return Response(str(service.id), 201) else: return Response(str(service.id), 200) - + def _set_external_integration_link(self, service): """Either set or delete the external integration link between the service and the storage integration. """ - mirror_integration_id = flask.request.form.get('mirror_integration_id') - + mirror_integration_id = flask.request.form.get("mirror_integration_id") + # If no storage integration was selected, then delete the existing # external integration link. current_integration_link, ignore = get_one_or_create( - self._db, ExternalIntegrationLink, + self._db, + ExternalIntegrationLink, library_id=None, external_integration_id=service.id, - purpose=ExternalIntegrationLink.MARC + purpose=ExternalIntegrationLink.MARC, ) if mirror_integration_id == self.NO_MIRROR_INTEGRATION: @@ -115,10 +123,14 @@ def _set_external_integration_link(self, service): self._db, ExternalIntegration, id=mirror_integration_id ) # Only get storage integrations that have a MARC file option set - if not storage_integration or \ - not storage_integration.setting(S3UploaderConfiguration.MARC_BUCKET_KEY).value: + if ( + not storage_integration + or not storage_integration.setting( + S3UploaderConfiguration.MARC_BUCKET_KEY + ).value + ): return MISSING_INTEGRATION - current_integration_link.other_integration_id=storage_integration.id + current_integration_link.other_integration_id = storage_integration.id def validate_form_fields(self, protocol): """Verify that the protocol which the user has selected is in the list @@ -144,15 +156,18 @@ def check_libraries(self, service): for library in service.libraries: marc_export_count = 0 for integration in library.integrations: - if integration.goal == ExternalIntegration.CATALOG_GOAL and integration.protocol == ExternalIntegration.MARC_EXPORT: + if ( + integration.goal == ExternalIntegration.CATALOG_GOAL + and integration.protocol == ExternalIntegration.MARC_EXPORT + ): marc_export_count += 1 if marc_export_count > 1: - return MULTIPLE_SERVICES_FOR_LIBRARY.detailed(_( - "You tried to add a MARC export service to %(library)s, but it already has one.", - library=library.short_name, - )) + return MULTIPLE_SERVICES_FOR_LIBRARY.detailed( + _( + "You tried to add a MARC export service to %(library)s, but it already has one.", + library=library.short_name, + ) + ) def process_delete(self, service_id): - return self._delete_integration( - service_id, ExternalIntegration.CATALOG_GOAL - ) + return self._delete_integration(service_id, ExternalIntegration.CATALOG_GOAL) diff --git a/api/admin/controller/cdn_services.py b/api/admin/controller/cdn_services.py index 3704570959..eef53af1a2 100644 --- a/api/admin/controller/cdn_services.py +++ b/api/admin/controller/cdn_services.py @@ -1,18 +1,15 @@ import flask from flask import Response from flask_babel import lazy_gettext as _ + from api.admin.problem_details import * -from core.model import ( - Configuration, - ExternalIntegration, - get_one, -) +from core.model import Configuration, ExternalIntegration, get_one from core.util.problem_detail import ProblemDetail from . import SettingsController -class CDNServicesController(SettingsController): +class CDNServicesController(SettingsController): def __init__(self, manager): super(CDNServicesController, self).__init__(manager) self.protocols = [ @@ -20,8 +17,17 @@ def __init__(self, manager): "name": ExternalIntegration.CDN, "sitewide": True, "settings": [ - { "key": ExternalIntegration.URL, "label": _("CDN URL"), "required": True, "format": "url" }, - { "key": Configuration.CDN_MIRRORED_DOMAIN_KEY, "label": _("Mirrored domain"), "required": True }, + { + "key": ExternalIntegration.URL, + "label": _("CDN URL"), + "required": True, + "format": "url", + }, + { + "key": Configuration.CDN_MIRRORED_DOMAIN_KEY, + "label": _("Mirrored domain"), + "required": True, + }, ], } ] @@ -29,12 +35,11 @@ def __init__(self, manager): def process_cdn_services(self): self.require_system_admin() - if flask.request.method == 'GET': + if flask.request.method == "GET": return self.process_get() else: return self.process_post() - def process_get(self): services = self._get_integration_info(self.goal, self.protocols) return dict( @@ -101,6 +106,4 @@ def validate_form_fields(self, **fields): return wrong_format def process_delete(self, service_id): - return self._delete_integration( - service_id, self.goal - ) + return self._delete_integration(service_id, self.goal) diff --git a/api/admin/controller/collection_library_registrations.py b/api/admin/controller/collection_library_registrations.py index 0570d883c0..0475822e34 100644 --- a/api/admin/controller/collection_library_registrations.py +++ b/api/admin/controller/collection_library_registrations.py @@ -1,22 +1,17 @@ import flask from flask import Response from flask_babel import lazy_gettext as _ + from api.admin.problem_details import * from api.odl import SharedODLAPI -from api.registry import ( - Registration, - RemoteRegistry, -) -from core.model import ( - Collection, - ConfigurationSetting, - get_one, - Library, -) +from api.registry import Registration, RemoteRegistry +from core.model import Collection, ConfigurationSetting, Library, get_one from core.util.http import HTTP from core.util.problem_detail import ProblemDetail + from . import SettingsController + class CollectionLibraryRegistrationsController(SettingsController): """Use the OPDS Directory Registration Protocol to register a Collection with its remote source of truth. @@ -29,15 +24,17 @@ def __init__(self, manager): super(CollectionLibraryRegistrationsController, self).__init__(manager) self.shared_collection_provider_apis = [SharedODLAPI] - def process_collection_library_registrations(self, - do_get=HTTP.debuggable_get, - do_post=HTTP.debuggable_post, - key=None, - registration_class=Registration): + def process_collection_library_registrations( + self, + do_get=HTTP.debuggable_get, + do_post=HTTP.debuggable_post, + key=None, + registration_class=Registration, + ): registration_class = registration_class or Registration self.require_system_admin() - if flask.request.method == 'GET': + if flask.request.method == "GET": return self.process_get() else: return self.process_post(registration_class, do_get, do_post) @@ -48,7 +45,10 @@ def get_library_info(self, library, collection): library_info = dict(short_name=library.short_name) status = ConfigurationSetting.for_library_and_externalintegration( - self._db, Registration.LIBRARY_REGISTRATION_STATUS, library, collection.external_integration, + self._db, + Registration.LIBRARY_REGISTRATION_STATUS, + library, + collection.external_integration, ).value if status: library_info["status"] = status @@ -89,9 +89,11 @@ def process_post(self, registration_class, do_get, do_post): registration = registration_class(registry, library) registered = registration.push( - Registration.PRODUCTION_STAGE, self.url_for, + Registration.PRODUCTION_STAGE, + self.url_for, catalog_url=collection.external_account_id, - do_get=do_get, do_post=do_post + do_get=do_get, + do_post=do_post, ) if isinstance(registered, ProblemDetail): @@ -106,7 +108,9 @@ def look_up_collection(self, collection_id): collection = get_one(self._db, Collection, id=collection_id) if not collection: return MISSING_COLLECTION - if collection.protocol not in [api.NAME for api in self.shared_collection_provider_apis]: + if collection.protocol not in [ + api.NAME for api in self.shared_collection_provider_apis + ]: return COLLECTION_DOES_NOT_SUPPORT_REGISTRATION return collection diff --git a/api/admin/controller/collection_self_tests.py b/api/admin/controller/collection_self_tests.py index fe0b00575e..316430144f 100644 --- a/api/admin/controller/collection_self_tests.py +++ b/api/admin/controller/collection_self_tests.py @@ -1,17 +1,16 @@ import flask from flask import Response from flask_babel import lazy_gettext as _ + +from api.admin.controller.self_tests import SelfTestsController from api.admin.problem_details import * -from core.opds_import import (OPDSImporter, OPDSImportMonitor) -from core.model import ( - Collection -) +from core.model import Collection +from core.opds_import import OPDSImporter, OPDSImportMonitor from core.selftest import HasSelfTests from core.util.problem_detail import ProblemDetail -from api.admin.controller.self_tests import SelfTestsController -class CollectionSelfTestsController(SelfTestsController): +class CollectionSelfTestsController(SelfTestsController): def __init__(self, manager): super(CollectionSelfTestsController, self).__init__(manager) self.type = _("collection") @@ -46,7 +45,9 @@ def get_info(self, collection): def _find_protocol_class(self, collection): """Figure out which protocol is providing books to this collection""" if collection.protocol in [p.get("name") for p in self.protocols]: - protocol_class_found = [p for p in self.PROVIDER_APIS if p.NAME == collection.protocol] + protocol_class_found = [ + p for p in self.PROVIDER_APIS if p.NAME == collection.protocol + ] if len(protocol_class_found) == 1: return protocol_class_found[0] @@ -55,10 +56,14 @@ def run_tests(self, collection): if self.protocol_class: value = None - if (collection_protocol == OPDSImportMonitor.PROTOCOL): + if collection_protocol == OPDSImportMonitor.PROTOCOL: self.protocol_class = OPDSImportMonitor - value, results = self.protocol_class.run_self_tests(self._db, self.protocol_class, self._db, collection, OPDSImporter) + value, results = self.protocol_class.run_self_tests( + self._db, self.protocol_class, self._db, collection, OPDSImporter + ) elif issubclass(self.protocol_class, HasSelfTests): - value, results = self.protocol_class.run_self_tests(self._db, self.protocol_class, self._db, collection) + value, results = self.protocol_class.run_self_tests( + self._db, self.protocol_class, self._db, collection + ) return value diff --git a/api/admin/controller/collection_settings.py b/api/admin/controller/collection_settings.py index 5b8a63b647..5ddd5f9385 100644 --- a/api/admin/controller/collection_settings.py +++ b/api/admin/controller/collection_settings.py @@ -1,19 +1,23 @@ -from . import SettingsController +import json + import flask from flask import Response from flask_babel import lazy_gettext as _ -import json + from api.admin.problem_details import * from core.model import ( Collection, ConfigurationSetting, ExternalIntegration, + Library, get_one, get_one_or_create, - Library, ) -from core.util.problem_detail import ProblemDetail from core.model.configuration import ExternalIntegrationLink +from core.util.problem_detail import ProblemDetail + +from . import SettingsController + class CollectionSettingsController(SettingsController): def __init__(self, manager): @@ -21,17 +25,19 @@ def __init__(self, manager): self.type = _("collection") def _get_collection_protocols(self): - protocols = super(CollectionSettingsController, self)._get_collection_protocols(self.PROVIDER_APIS) + protocols = super(CollectionSettingsController, self)._get_collection_protocols( + self.PROVIDER_APIS + ) # If there are storage integrations, add a mirror integration # setting to every protocol's 'settings' block. mirror_integration_settings = self._mirror_integration_settings() if mirror_integration_settings: for protocol in protocols: - protocol['settings'] += mirror_integration_settings + protocol["settings"] += mirror_integration_settings return protocols def process_collections(self): - if flask.request.method == 'GET': + if flask.request.method == "GET": return self.process_get() else: return self.process_post() @@ -42,22 +48,34 @@ def process_get(self): user = flask.request.admin collections = [] protocolClass = None - for collection_object in self._db.query(Collection).order_by(Collection.name).all(): + for collection_object in ( + self._db.query(Collection).order_by(Collection.name).all() + ): if not user or not user.can_see_collection(collection_object): continue collection_dict = self.collection_to_dict(collection_object) if collection_object.protocol in [p.get("name") for p in protocols]: - [protocol] = [p for p in protocols if p.get("name") == collection_object.protocol] + [protocol] = [ + p for p in protocols if p.get("name") == collection_object.protocol + ] libraries = self.load_libraries(collection_object, user, protocol) - collection_dict['libraries'] = libraries - settings = self.load_settings(protocol.get("settings"), collection_object, collection_dict.get("settings")) - collection_dict['settings'] = settings + collection_dict["libraries"] = libraries + settings = self.load_settings( + protocol.get("settings"), + collection_object, + collection_dict.get("settings"), + ) + collection_dict["settings"] = settings protocolClass = self.find_protocol_class(collection_object) - collection_dict["self_test_results"] = self._get_prior_test_results(collection_object, protocolClass) - collection_dict["marked_for_deletion"] = collection_object.marked_for_deletion + collection_dict["self_test_results"] = self._get_prior_test_results( + collection_object, protocolClass + ) + collection_dict[ + "marked_for_deletion" + ] = collection_object.marked_for_deletion collections.append(collection_dict) @@ -82,8 +100,11 @@ def load_libraries(self, collection_object, user, protocol): for library in collection_object.libraries: if not user or not user.is_librarian(library): continue - libraries.append(self._get_integration_library_info( - collection_object.external_integration, library, protocol)) + libraries.append( + self._get_integration_library_info( + collection_object.external_integration, library, protocol + ) + ) return libraries @@ -97,19 +118,22 @@ def load_settings(self, protocol_settings, collection_object, collection_setting continue key = protocol_setting.get("key") if not collection_settings or key not in collection_settings: - if key.endswith('mirror_integration_id'): + if key.endswith("mirror_integration_id"): storage_integration = get_one( - self._db, ExternalIntegrationLink, + self._db, + ExternalIntegrationLink, external_integration_id=collection_object.external_integration_id, # either 'books_mirror' or 'covers_mirror' - purpose=key.rsplit('_', 2)[0] + purpose=key.rsplit("_", 2)[0], ) if storage_integration: value = str(storage_integration.other_integration_id) else: value = self.NO_MIRROR_INTEGRATION elif protocol_setting.get("type") in ("list", "menu"): - value = collection_object.external_integration.setting(key).json_value + value = collection_object.external_integration.setting( + key + ).json_value else: value = collection_object.external_integration.setting(key).value settings[key] = value @@ -120,7 +144,9 @@ def find_protocol_class(self, collection_object): """Figure out which class this collection's protocol belongs to, from the list of possible protocols defined in PROVIDER_APIS (in SettingsController)""" - protocolClassFound = [p for p in self.PROVIDER_APIS if p.NAME == collection_object.protocol] + protocolClassFound = [ + p for p in self.PROVIDER_APIS if p.NAME == collection_object.protocol + ] if len(protocolClassFound) == 1: return protocolClassFound[0] @@ -189,7 +215,9 @@ def validate_form_fields(self, is_new, protocols, **fields): if fields.get("protocol") not in [p.get("name") for p in protocols]: return UNKNOWN_PROTOCOL else: - [protocol] = [p for p in protocols if p.get("name") == fields.get("protocol")] + [protocol] = [ + p for p in protocols if p.get("name") == fields.get("protocol") + ] wrong_format = self.validate_formats(protocol.get("settings")) if wrong_format: return wrong_format @@ -202,7 +230,9 @@ def validate_collection(self, **fields): if fields.get("protocol") != fields.get("collection").protocol: return CANNOT_CHANGE_PROTOCOL if fields.get("name") != fields.get("collection").name: - collection_with_name = get_one(self._db, Collection, name=fields.get("name")) + collection_with_name = get_one( + self._db, Collection, name=fields.get("name") + ) if collection_with_name: return COLLECTION_NAME_ALREADY_IN_USE @@ -233,8 +263,11 @@ def validate_external_account_id_setting(self, value, setting): if not value and not setting.get("optional"): # Roll back any changes to the collection that have already been made. return INCOMPLETE_CONFIGURATION.detailed( - _("The collection configuration is missing a required setting: %(setting)s", - setting=setting.get("label"))) + _( + "The collection configuration is missing a required setting: %(setting)s", + setting=setting.get("label"), + ) + ) def process_settings(self, settings, collection): """Go through the settings that the user has just submitted for this collection, @@ -249,39 +282,48 @@ def process_settings(self, settings, collection): if error: return error collection.external_account_id = value - elif key.endswith('mirror_integration_id') and value: + elif key.endswith("mirror_integration_id") and value: external_integration_link = self._set_external_integration_link( - self._db, key, value, collection, + self._db, + key, + value, + collection, ) if isinstance(external_integration_link, ProblemDetail): return external_integration_link else: - result = self._set_integration_setting(collection.external_integration, setting) + result = self._set_integration_setting( + collection.external_integration, setting + ) if isinstance(result, ProblemDetail): return result def _set_external_integration_link( - self, _db, key, value, collection, + self, + _db, + key, + value, + collection, ): """Find or create a ExternalIntegrationLink and either delete it or update the other external integration it links to. """ collection_service = get_one( - _db, ExternalIntegration, - id=collection.external_integration_id + _db, ExternalIntegration, id=collection.external_integration_id ) storage_service = None other_integration_id = None - purpose = key.rsplit('_', 2)[0] + purpose = key.rsplit("_", 2)[0] external_integration_link, ignore = get_one_or_create( - _db, ExternalIntegrationLink, + _db, + ExternalIntegrationLink, library_id=None, external_integration_id=collection_service.id, - purpose=purpose + purpose=purpose, ) if not external_integration_link: return MISSING_INTEGRATION @@ -289,10 +331,7 @@ def _set_external_integration_link( if value == self.NO_MIRROR_INTEGRATION: _db.delete(external_integration_link) else: - storage_service = get_one( - _db, ExternalIntegration, - id=value - ) + storage_service = get_one(_db, ExternalIntegration, id=value) if storage_service: if storage_service.goal != ExternalIntegration.STORAGE_GOAL: return INTEGRATION_GOAL_CONFLICT @@ -315,12 +354,21 @@ def process_libraries(self, protocol, collection): libraries = json.loads(flask.request.form.get("libraries")) for library_info in libraries: - library = get_one(self._db, Library, short_name=library_info.get("short_name")) + library = get_one( + self._db, Library, short_name=library_info.get("short_name") + ) if not library: - return NO_SUCH_LIBRARY.detailed(_("You attempted to add the collection to %(library_short_name)s, but the library does not exist.", library_short_name=library_info.get("short_name"))) + return NO_SUCH_LIBRARY.detailed( + _( + "You attempted to add the collection to %(library_short_name)s, but the library does not exist.", + library_short_name=library_info.get("short_name"), + ) + ) if collection not in library.collections: library.collections.append(collection) - result = self._set_integration_library(collection.external_integration, library_info, protocol) + result = self._set_integration_library( + collection.external_integration, library_info, protocol + ) if isinstance(result, ProblemDetail): return result for library in collection.libraries: diff --git a/api/admin/controller/discovery_service_library_registrations.py b/api/admin/controller/discovery_service_library_registrations.py index 5dbc6fbc88..8c3afacc9a 100644 --- a/api/admin/controller/discovery_service_library_registrations.py +++ b/api/admin/controller/discovery_service_library_registrations.py @@ -1,21 +1,18 @@ +import json + import flask from flask import Response from flask_babel import lazy_gettext as _ -import json + from api.admin.problem_details import * -from api.registry import ( - RemoteRegistry, - Registration, -) -from core.model import ( - ExternalIntegration, - get_one, - Library, -) +from api.registry import Registration, RemoteRegistry +from core.model import ExternalIntegration, Library, get_one from core.util.http import HTTP from core.util.problem_detail import ProblemDetail + from . import SettingsController + class DiscoveryServiceLibraryRegistrationsController(SettingsController): """List the libraries that have been registered with a specific @@ -29,14 +26,15 @@ def __init__(self, manager): super(DiscoveryServiceLibraryRegistrationsController, self).__init__(manager) self.goal = ExternalIntegration.DISCOVERY_GOAL - def process_discovery_service_library_registrations(self, - registration_class=None, - do_get=HTTP.debuggable_get, - do_post=HTTP.debuggable_post + def process_discovery_service_library_registrations( + self, + registration_class=None, + do_get=HTTP.debuggable_get, + do_post=HTTP.debuggable_post, ): registration_class = registration_class or Registration self.require_system_admin() - if flask.request.method == 'GET': + if flask.request.method == "GET": return self.process_get(do_get) else: return self.process_post(registration_class, do_get, do_post) @@ -48,11 +46,9 @@ def process_get(self, do_get=HTTP.debuggable_get): services = [] for registry in RemoteRegistry.for_protocol_and_goal( - self._db, ExternalIntegration.OPDS_REGISTRATION, self.goal + self._db, ExternalIntegration.OPDS_REGISTRATION, self.goal ): - result = ( - registry.fetch_registration_document(do_get=do_get) - ) + result = registry.fetch_registration_document(do_get=do_get) if isinstance(result, ProblemDetail): # Unlike most cases like this, a ProblemDetail doesn't # mean the whole request is ruined -- just that one of @@ -98,7 +94,7 @@ def get_library_info(self, registration): def look_up_registry(self, integration_id): """Find the RemoteRegistry that the user is trying to register the library with, - and check that it actually exists.""" + and check that it actually exists.""" registry = RemoteRegistry.for_integration_id( self._db, integration_id, self.goal @@ -120,7 +116,9 @@ def process_post(self, registration_class, do_get, do_post): integration_id = flask.request.form.get("integration_id") library_short_name = flask.request.form.get("library_short_name") - stage = flask.request.form.get("registration_stage") or Registration.TESTING_STAGE + stage = ( + flask.request.form.get("registration_stage") or Registration.TESTING_STAGE + ) registry = self.look_up_registry(integration_id) if isinstance(registry, ProblemDetail): diff --git a/api/admin/controller/discovery_services.py b/api/admin/controller/discovery_services.py index 2a5f9a5a6e..5afe67e35d 100644 --- a/api/admin/controller/discovery_services.py +++ b/api/admin/controller/discovery_services.py @@ -1,18 +1,16 @@ import flask from flask import Response from flask_babel import lazy_gettext as _ + from api.admin.problem_details import * from api.registry import RemoteRegistry -from core.model import ( - ExternalIntegration, - get_one, - get_one_or_create, -) +from core.model import ExternalIntegration, get_one, get_one_or_create from core.util.problem_detail import ProblemDetail + from . import SettingsController -class DiscoveryServicesController(SettingsController): +class DiscoveryServicesController(SettingsController): def __init__(self, manager): super(DiscoveryServicesController, self).__init__(manager) self.opds_registration = ExternalIntegration.OPDS_REGISTRATION @@ -21,7 +19,12 @@ def __init__(self, manager): "name": self.opds_registration, "sitewide": True, "settings": [ - { "key": ExternalIntegration.URL, "label": _("URL"), "required": True, "format": "url" }, + { + "key": ExternalIntegration.URL, + "label": _("URL"), + "required": True, + "format": "url", + }, ], "supports_registration": True, "supports_staging": True, @@ -31,7 +34,7 @@ def __init__(self, manager): def process_discovery_services(self): self.require_system_admin() - if flask.request.method == 'GET': + if flask.request.method == "GET": return self.process_get() else: return self.process_post() @@ -57,8 +60,10 @@ def set_up_default_registry(self): """Set up the default library registry; no other registries exist yet.""" service, is_new = get_one_or_create( - self._db, ExternalIntegration, protocol=self.opds_registration, - goal=self.goal + self._db, + ExternalIntegration, + protocol=self.opds_registration, + goal=self.goal, ) if is_new: service.url = RemoteRegistry.DEFAULT_LIBRARY_REGISTRY_URL @@ -92,9 +97,7 @@ def process_post(self): return name_error url = flask.request.form.get("url") - url_not_unique = self.check_url_unique( - service, url, protocol, self.goal - ) + url_not_unique = self.check_url_unique(service, url, protocol, self.goal) if url_not_unique: self._db.rollback() return url_not_unique @@ -143,6 +146,4 @@ def look_up_service_from_registry(self, protocol, id): return service def process_delete(self, service_id): - return self._delete_integration( - service_id, ExternalIntegration.DISCOVERY_GOAL - ) + return self._delete_integration(service_id, ExternalIntegration.DISCOVERY_GOAL) diff --git a/api/admin/controller/individual_admin_settings.py b/api/admin/controller/individual_admin_settings.py index c0e762c5ba..5caa175189 100644 --- a/api/admin/controller/individual_admin_settings.py +++ b/api/admin/controller/individual_admin_settings.py @@ -1,24 +1,21 @@ -from . import SettingsController +import json + import flask from flask import Response from flask_babel import lazy_gettext as _ -import json -from core.model import ( - Admin, - AdminRole, - Library, - get_one, - get_one_or_create, -) -from core.util.problem_detail import ProblemDetail + from api.admin.exceptions import * from api.admin.problem_details import * from api.admin.validator import Validator +from core.model import Admin, AdminRole, Library, get_one, get_one_or_create +from core.util.problem_detail import ProblemDetail -class IndividualAdminSettingsController(SettingsController): +from . import SettingsController + +class IndividualAdminSettingsController(SettingsController): def process_individual_admins(self): - if flask.request.method == 'GET': + if flask.request.method == "GET": return self.process_get() else: return self.process_post() @@ -29,7 +26,9 @@ def process_get(self): roles = [] for role in admin.roles: if role.library: - if not flask.request.admin or not flask.request.admin.is_librarian(role.library): + if not flask.request.admin or not flask.request.admin.is_librarian( + role.library + ): continue roles.append(dict(role=role.role, library=role.library.short_name)) else: @@ -48,9 +47,11 @@ def process_post(self): return error # If there are no admins yet, anyone can create the first system admin. - settingUp = (self._db.query(Admin).count() == 0) + settingUp = self._db.query(Admin).count() == 0 if settingUp and not flask.request.form.get("password"): - return INCOMPLETE_CONFIGURATION.detailed(_("The password field cannot be blank.")) + return INCOMPLETE_CONFIGURATION.detailed( + _("The password field cannot be blank.") + ) admin, is_new = get_one_or_create(self._db, Admin, email=email) @@ -73,11 +74,11 @@ def process_post(self): def check_permissions(self, admin, settingUp): """Before going any further, check that the user actually has permission - to create/edit this type of admin""" + to create/edit this type of admin""" - # For readability: the person who is submitting the form is referred to as "user" - # rather than as something that could be confused with "admin" (the admin - # which the user is submitting the form in order to create/edit.) + # For readability: the person who is submitting the form is referred to as "user" + # rather than as something that could be confused with "admin" (the admin + # which the user is submitting the form in order to create/edit.) if not settingUp: user = flask.request.admin @@ -120,7 +121,12 @@ def look_up_library_for_role(self, role): if library_short_name: library = Library.lookup(self._db, library_short_name) if not library: - return LIBRARY_NOT_FOUND.detailed(_("Library \"%(short_name)s\" does not exist.", short_name=library_short_name)) + return LIBRARY_NOT_FOUND.detailed( + _( + 'Library "%(short_name)s" does not exist.', + short_name=library_short_name, + ) + ) return library def handle_roles(self, admin, roles, settingUp): @@ -148,7 +154,7 @@ def handle_roles(self, admin, roles, settingUp): return library if (role.get("role"), library) in old_roles_set: - # The admin already has this role. + # The admin already has this role. continue if library: diff --git a/api/admin/controller/library_settings.py b/api/admin/controller/library_settings.py index 3f5649f314..0561a78f62 100644 --- a/api/admin/controller/library_settings.py +++ b/api/admin/controller/library_settings.py @@ -1,35 +1,30 @@ import base64 -from io import BytesIO import json -from typing import Any, Dict, Optional import uuid +from io import BytesIO +from typing import Any, Dict, Optional import flask +import wcag_contrast_ratio from flask import Response from flask_babel import lazy_gettext as _ from PIL import Image -import wcag_contrast_ratio -from . import SettingsController from api.admin.announcement_list_validator import AnnouncementListValidator -from api.config import Configuration -from api.lanes import create_default_lanes from api.admin.geographic_validator import GeographicValidator from api.admin.problem_details import * -from core.model import ( - ConfigurationSetting, - create, - get_one, - Library, -) -from core.util.problem_detail import ProblemDetail +from api.config import Configuration +from api.lanes import create_default_lanes +from core.model import ConfigurationSetting, Library, create, get_one from core.util import LanguageCodes +from core.util.problem_detail import ProblemDetail +from . import SettingsController -class LibrarySettingsController(SettingsController): +class LibrarySettingsController(SettingsController): def process_libraries(self): - if flask.request.method == 'GET': + if flask.request.method == "GET": return self.process_get() else: return self.process_post() @@ -44,11 +39,17 @@ def process_get(self): settings = dict() for setting in Configuration.LIBRARY_SETTINGS: if setting.get("type") == "announcements": - value = ConfigurationSetting.for_library(setting.get("key"), library).json_value + value = ConfigurationSetting.for_library( + setting.get("key"), library + ).json_value if value: - value = AnnouncementListValidator().validate_announcements(value) + value = AnnouncementListValidator().validate_announcements( + value + ) if setting.get("type") == "list": - value = ConfigurationSetting.for_library(setting.get("key"), library).json_value + value = ConfigurationSetting.for_library( + setting.get("key"), library + ).json_value if value and setting.get("format") == "geographic": value = self.get_extra_geographic_information(value) else: @@ -57,12 +58,14 @@ def process_get(self): if value: settings[setting.get("key")] = value - libraries += [dict( - uuid=library.uuid, - name=library.name, - short_name=library.short_name, - settings=settings, - )] + libraries += [ + dict( + uuid=library.uuid, + name=library.name, + short_name=library.short_name, + settings=settings, + ) + ] return dict(libraries=libraries, settings=Configuration.LIBRARY_SETTINGS) def process_post(self, validators_by_type=None): @@ -70,8 +73,8 @@ def process_post(self, validators_by_type=None): is_new = False if validators_by_type is None: validators_by_type = dict() - validators_by_type['geographic'] = GeographicValidator() - validators_by_type['announcements'] = AnnouncementListValidator() + validators_by_type["geographic"] = GeographicValidator() + validators_by_type["announcements"] = AnnouncementListValidator() library_uuid = flask.request.form.get("uuid") library = self.get_library_from_uuid(library_uuid) @@ -98,7 +101,9 @@ def process_post(self, validators_by_type=None): if short_name: library.short_name = short_name - configuration_settings = self.library_configuration_settings(library, validators_by_type) + configuration_settings = self.library_configuration_settings( + library, validators_by_type + ) if isinstance(configuration_settings, ProblemDetail): return configuration_settings @@ -113,8 +118,8 @@ def process_post(self, validators_by_type=None): def create_library(self, short_name, library_uuid): self.require_system_admin() library, is_new = create( - self._db, Library, short_name=short_name, - uuid=str(uuid.uuid4())) + self._db, Library, short_name=short_name, uuid=str(uuid.uuid4()) + ) return library, is_new def process_delete(self, library_uuid): @@ -123,7 +128,7 @@ def process_delete(self, library_uuid): self._db.delete(library) return Response(str(_("Deleted")), 200) -# Validation methods: + # Validation methods: def validate_form_fields(self): settings = Configuration.LIBRARY_SETTINGS @@ -131,7 +136,7 @@ def validate_form_fields(self): self.check_for_missing_fields, self.check_web_color_contrast, self.check_header_links, - self.validate_formats + self.validate_formats, ] for validation in validations: result = validation(settings) @@ -147,12 +152,18 @@ def check_for_missing_fields(self, settings): return error def check_for_missing_settings(self, settings): - required = [s for s in Configuration.LIBRARY_SETTINGS if s.get('required') and not s.get('default')] + required = [ + s + for s in Configuration.LIBRARY_SETTINGS + if s.get("required") and not s.get("default") + ] missing = [s for s in required if not flask.request.form.get(s.get("key"))] if missing: return INCOMPLETE_CONFIGURATION.detailed( - _("The configuration is missing a required setting: %(setting)s", - setting=missing[0].get("label")) + _( + "The configuration is missing a required setting: %(setting)s", + setting=missing[0].get("label"), + ) ) def check_web_color_contrast(self, settings): @@ -161,19 +172,43 @@ def check_web_color_contrast(self, settings): well on white, as these colors will serve as button backgrounds with white test, as well as text color on white backgrounds. """ - primary = flask.request.form.get(Configuration.WEB_PRIMARY_COLOR, Configuration.DEFAULT_WEB_PRIMARY_COLOR) - secondary = flask.request.form.get(Configuration.WEB_SECONDARY_COLOR, Configuration.DEFAULT_WEB_SECONDARY_COLOR) + primary = flask.request.form.get( + Configuration.WEB_PRIMARY_COLOR, Configuration.DEFAULT_WEB_PRIMARY_COLOR + ) + secondary = flask.request.form.get( + Configuration.WEB_SECONDARY_COLOR, Configuration.DEFAULT_WEB_SECONDARY_COLOR + ) + def hex_to_rgb(hex): hex = hex.lstrip("#") - return tuple(int(hex[i:i+2], 16)/255.0 for i in (0, 2 ,4)) - primary_passes = wcag_contrast_ratio.passes_AA(wcag_contrast_ratio.rgb(hex_to_rgb(primary), hex_to_rgb("#ffffff"))) - secondary_passes = wcag_contrast_ratio.passes_AA(wcag_contrast_ratio.rgb(hex_to_rgb(secondary), hex_to_rgb("#ffffff"))) + return tuple(int(hex[i : i + 2], 16) / 255.0 for i in (0, 2, 4)) + + primary_passes = wcag_contrast_ratio.passes_AA( + wcag_contrast_ratio.rgb(hex_to_rgb(primary), hex_to_rgb("#ffffff")) + ) + secondary_passes = wcag_contrast_ratio.passes_AA( + wcag_contrast_ratio.rgb(hex_to_rgb(secondary), hex_to_rgb("#ffffff")) + ) if not (primary_passes and secondary_passes): - primary_check_url = "https://contrast-ratio.com/#%23" + secondary[1:] + "-on-%23" + "#ffffff"[1:] - secondary_check_url = "https://contrast-ratio.com/#%23" + secondary[1:] + "-on-%23" + "#ffffff"[1:] + primary_check_url = ( + "https://contrast-ratio.com/#%23" + + secondary[1:] + + "-on-%23" + + "#ffffff"[1:] + ) + secondary_check_url = ( + "https://contrast-ratio.com/#%23" + + secondary[1:] + + "-on-%23" + + "#ffffff"[1:] + ) return INVALID_CONFIGURATION_OPTION.detailed( - _("The web primary and secondary colors don't have enough contrast to pass the WCAG 2.0 AA guidelines and will be difficult for some patrons to read. Check contrast for primary here and secondary here.", - primary_check_url=primary_check_url, secondary_check_url=secondary_check_url)) + _( + "The web primary and secondary colors don't have enough contrast to pass the WCAG 2.0 AA guidelines and will be difficult for some patrons to read. Check contrast for primary here and secondary here.", + primary_check_url=primary_check_url, + secondary_check_url=secondary_check_url, + ) + ) def check_header_links(self, settings): """Verify that header links and labels are the same length.""" @@ -181,19 +216,26 @@ def check_header_links(self, settings): header_labels = flask.request.form.getlist(Configuration.WEB_HEADER_LABELS) if len(header_links) != len(header_labels): return INVALID_CONFIGURATION_OPTION.detailed( - _("There must be the same number of web header links and web header labels.")) + _( + "There must be the same number of web header links and web header labels." + ) + ) def get_library_from_uuid(self, library_uuid): if library_uuid: # Library UUID is required when editing an existing library # from the admin interface, and isn't present for new libraries. library = get_one( - self._db, Library, uuid=library_uuid, + self._db, + Library, + uuid=library_uuid, ) if library: return library else: - return LIBRARY_NOT_FOUND.detailed(_("The specified library uuid does not exist.")) + return LIBRARY_NOT_FOUND.detailed( + _("The specified library uuid does not exist.") + ) def check_short_name_unique(self, library, short_name): if not library or short_name != library.short_name: @@ -203,7 +245,7 @@ def check_short_name_unique(self, library, short_name): if library_with_short_name: return LIBRARY_SHORT_NAME_ALREADY_IN_USE -# Configuration settings: + # Configuration settings: def get_extra_geographic_information(self, value): validator = GeographicValidator() @@ -220,7 +262,7 @@ def get_extra_geographic_information(self, value): return value def library_configuration_settings( - self, library, validators_by_format, settings=None + self, library, validators_by_format, settings=None ): """Validate and update a library's configuration settings based on incoming new values. @@ -235,9 +277,9 @@ def library_configuration_settings( for setting in settings: # Validate the incoming value. validator = None - if 'format' in setting: + if "format" in setting: validator = validators_by_format.get(setting["format"]) - elif 'type' in setting: + elif "type" in setting: validator = validators_by_format.get(setting["type"]) validated_value = self._validate_setting(library, setting, validator) @@ -246,9 +288,9 @@ def library_configuration_settings( return validated_value # Validation succeeded -- set the new value. - ConfigurationSetting.for_library(setting['key'], library).value = self._format_validated_value( - validated_value, validator - ) + ConfigurationSetting.for_library( + setting["key"], library + ).value = self._format_validated_value(validated_value, validator) def _validate_setting(self, library, setting, validator=None): """Validate the incoming value for a single library setting. @@ -266,8 +308,8 @@ def _validate_setting(self, library, setting, validator=None): # * A list value is returned as a JSON-encoded string. It # would be better to keep that as a list for longer in case # controller code needs to look at it. - format = setting.get('format') - type = setting.get('type') + format = setting.get("format") + type = setting.get("type") # In some cases, if there is no incoming value we can use a # default value or the current value. @@ -275,7 +317,7 @@ def _validate_setting(self, library, setting, validator=None): # When the configuration item is a list, we can't do this # because an empty list may be a valid value. current_value = self.current_value(setting, library) - default_value = setting.get('default') or current_value + default_value = setting.get("default") or current_value if format == "geographic": value = self.list_setting(setting) @@ -286,7 +328,12 @@ def _validate_setting(self, library, setting, validator=None): elif type == "list": value = self.list_setting(setting) if format == "language-code": - value = json.dumps([LanguageCodes.string_to_alpha_3(language) for language in json.loads(value)]) + value = json.dumps( + [ + LanguageCodes.string_to_alpha_3(language) + for language in json.loads(value) + ] + ) else: if type == "image": value = self.image_setting(setting) or default_value @@ -296,7 +343,7 @@ def _validate_setting(self, library, setting, validator=None): def scalar_setting(self, setting): """Retrieve the single value of the given setting from the current HTTP request.""" - return flask.request.form.get(setting['key']) + return flask.request.form.get(setting["key"]) def list_setting(self, setting, json_objects=False): """Retrieve the list of values for the given setting from the current HTTP request. @@ -307,7 +354,7 @@ def list_setting(self, setting, json_objects=False): :return: A JSON-encoded string encoding the list of values set for the given setting in the current request. """ - if setting.get('options'): + if setting.get("options"): # Restrict to the values in 'options'. value = [] for option in setting.get("options"): @@ -342,9 +389,11 @@ def _data_url_for_image(image: Image, _format="PNG") -> str: buffer = BytesIO() image.save(buffer, format=_format) b64 = base64.b64encode(buffer.getvalue()) - return "data:image/png;base64,%s" % b64.decode('utf-8') + return "data:image/png;base64,%s" % b64.decode("utf-8") - def image_setting(self, setting: Dict[str, Any], max_dimension=Configuration.LOGO_MAX_DIMENSION) -> Optional[str]: + def image_setting( + self, setting: Dict[str, Any], max_dimension=Configuration.LOGO_MAX_DIMENSION + ) -> Optional[str]: """Retrieve an uploaded image file for the setting and return its data URL. If the image is too large, scale it down to the `max_dimension` @@ -365,12 +414,11 @@ def image_setting(self, setting: Dict[str, Any], max_dimension=Configuration.LOG def current_value(self, setting, library): """Retrieve the current value of the given setting from the database.""" - return ConfigurationSetting.for_library(setting['key'], library).value + return ConfigurationSetting.for_library(setting["key"], library).value @classmethod def _format_validated_value(cls, value, validator=None): - """Convert a validated value to a string that can be stored in ConfigurationSetting.value - """ + """Convert a validated value to a string that can be stored in ConfigurationSetting.value""" if not validator: # Assume the value is already a string. return value diff --git a/api/admin/controller/metadata_service_self_tests.py b/api/admin/controller/metadata_service_self_tests.py index 8fd1729983..e04561a3f1 100644 --- a/api/admin/controller/metadata_service_self_tests.py +++ b/api/admin/controller/metadata_service_self_tests.py @@ -1,15 +1,16 @@ """Self-tests for metadata integrations.""" from flask_babel import lazy_gettext as _ -from core.opds_import import MetadataWranglerOPDSLookup -from core.model import ( - ExternalIntegration -) -from api.nyt import NYTBestSellerAPI -from api.admin.controller.self_tests import SelfTestsController + from api.admin.controller.metadata_services import MetadataServicesController +from api.admin.controller.self_tests import SelfTestsController +from api.nyt import NYTBestSellerAPI +from core.model import ExternalIntegration +from core.opds_import import MetadataWranglerOPDSLookup -class MetadataServiceSelfTestsController(MetadataServicesController, SelfTestsController): +class MetadataServiceSelfTestsController( + MetadataServicesController, SelfTestsController +): def __init__(self, manager): super(MetadataServiceSelfTestsController, self).__init__(manager) self.type = _("metadata service") diff --git a/api/admin/controller/metadata_services.py b/api/admin/controller/metadata_services.py index 062376b521..044c173772 100644 --- a/api/admin/controller/metadata_services.py +++ b/api/admin/controller/metadata_services.py @@ -1,36 +1,36 @@ import flask from flask import Response from flask_babel import lazy_gettext as _ + from api.admin.problem_details import * -from api.nyt import NYTBestSellerAPI from api.novelist import NoveListAPI +from api.nyt import NYTBestSellerAPI +from core.model import ExternalIntegration, get_one from core.opds_import import MetadataWranglerOPDSLookup -from core.model import ( - ExternalIntegration, - get_one, -) from core.util.http import HTTP from core.util.problem_detail import ProblemDetail from . import SitewideRegistrationController -class MetadataServicesController(SitewideRegistrationController): +class MetadataServicesController(SitewideRegistrationController): def __init__(self, manager): super(MetadataServicesController, self).__init__(manager) self.provider_apis = [ - NYTBestSellerAPI, - NoveListAPI, - MetadataWranglerOPDSLookup, - ] + NYTBestSellerAPI, + NoveListAPI, + MetadataWranglerOPDSLookup, + ] - self.protocols = self._get_integration_protocols(self.provider_apis, protocol_name_attr="PROTOCOL") + self.protocols = self._get_integration_protocols( + self.provider_apis, protocol_name_attr="PROTOCOL" + ) self.goal = ExternalIntegration.METADATA_GOAL self.type = _("metadata service") def process_metadata_services(self): self.require_system_admin() - if flask.request.method == 'GET': + if flask.request.method == "GET": return self.process_get() else: return self.process_post() @@ -38,9 +38,16 @@ def process_metadata_services(self): def process_get(self): metadata_services = self._get_integration_info(self.goal, self.protocols) for service in metadata_services: - service_object = get_one(self._db, ExternalIntegration, id=service.get("id"), goal=ExternalIntegration.METADATA_GOAL) + service_object = get_one( + self._db, + ExternalIntegration, + id=service.get("id"), + goal=ExternalIntegration.METADATA_GOAL, + ) protocol_class, tuple = self.find_protocol_class(service_object) - service["self_test_results"] = self._get_prior_test_results(service_object, protocol_class, *tuple) + service["self_test_results"] = self._get_prior_test_results( + service_object, protocol_class, *tuple + ) return dict( metadata_services=metadata_services, @@ -51,18 +58,12 @@ def find_protocol_class(self, integration): if integration.protocol == ExternalIntegration.METADATA_WRANGLER: return ( MetadataWranglerOPDSLookup, - (MetadataWranglerOPDSLookup.from_config, self._db) + (MetadataWranglerOPDSLookup.from_config, self._db), ) elif integration.protocol == ExternalIntegration.NYT: - return ( - NYTBestSellerAPI, - (NYTBestSellerAPI.from_config, self._db) - ) + return (NYTBestSellerAPI, (NYTBestSellerAPI.from_config, self._db)) elif integration.protocol == ExternalIntegration.NOVELIST: - return ( - NoveListAPI, - (NoveListAPI.from_config, self._db) - ) + return (NoveListAPI, (NoveListAPI.from_config, self._db)) raise NotImplementedError( "No metadata self-test class for protocol %s" % integration.protocol ) @@ -137,14 +138,13 @@ def validate_form_fields(self, **fields): def register_with_metadata_wrangler(self, do_get, do_post, is_new, service): """Register this site with the Metadata Wrangler.""" - if ((is_new or not service.password) and - service.protocol == ExternalIntegration.METADATA_WRANGLER): + if ( + is_new or not service.password + ) and service.protocol == ExternalIntegration.METADATA_WRANGLER: return self.process_sitewide_registration( integration=service, do_get=do_get, do_post=do_post ) def process_delete(self, service_id): - return self._delete_integration( - service_id, self.goal - ) + return self._delete_integration(service_id, self.goal) diff --git a/api/admin/controller/patron_auth_service_self_tests.py b/api/admin/controller/patron_auth_service_self_tests.py index ac351726b9..e74abad8b4 100644 --- a/api/admin/controller/patron_auth_service_self_tests.py +++ b/api/admin/controller/patron_auth_service_self_tests.py @@ -1,24 +1,24 @@ import flask from flask import Response from flask_babel import lazy_gettext as _ + +from api.admin.controller.patron_auth_services import PatronAuthServicesController +from api.admin.controller.self_tests import SelfTestsController from api.admin.problem_details import * -from api.simple_authentication import SimpleAuthenticationProvider -from api.millenium_patron import MilleniumPatronAPI -from api.sip import SIP2AuthenticationProvider +from api.clever import CleverAuthenticationAPI from api.firstbook import FirstBookAuthenticationAPI as OldFirstBookAuthenticationAPI from api.firstbook2 import FirstBookAuthenticationAPI -from api.clever import CleverAuthenticationAPI -from core.model import ( - get_one, - ExternalIntegration, -) +from api.millenium_patron import MilleniumPatronAPI +from api.simple_authentication import SimpleAuthenticationProvider +from api.sip import SIP2AuthenticationProvider +from core.model import ExternalIntegration, get_one from core.selftest import HasSelfTests from core.util.problem_detail import ProblemDetail -from api.admin.controller.self_tests import SelfTestsController -from api.admin.controller.patron_auth_services import PatronAuthServicesController -class PatronAuthServiceSelfTestsController(SelfTestsController, PatronAuthServicesController): +class PatronAuthServiceSelfTestsController( + SelfTestsController, PatronAuthServicesController +): def process_patron_auth_service_self_tests(self, identifier): return self._manage_self_tests(identifier) @@ -27,20 +27,24 @@ def look_up_by_id(self, identifier): self._db, ExternalIntegration, id=identifier, - goal=ExternalIntegration.PATRON_AUTH_GOAL + goal=ExternalIntegration.PATRON_AUTH_GOAL, ) if not service: return MISSING_SERVICE return service def get_info(self, patron_auth_service): - [protocol] = [p for p in self._get_integration_protocols(self.provider_apis) if p.get("name") == patron_auth_service.protocol] + [protocol] = [ + p + for p in self._get_integration_protocols(self.provider_apis) + if p.get("name") == patron_auth_service.protocol + ] info = dict( id=patron_auth_service.id, name=patron_auth_service.name, protocol=patron_auth_service.protocol, goal=patron_auth_service.goal, - settings=protocol.get("settings") + settings=protocol.get("settings"), ) return info diff --git a/api/admin/controller/patron_auth_services.py b/api/admin/controller/patron_auth_services.py index 9d37eb6c57..936d4d4985 100644 --- a/api/admin/controller/patron_auth_services.py +++ b/api/admin/controller/patron_auth_services.py @@ -16,56 +16,60 @@ from api.saml.provider import SAMLWebSSOAuthenticationProvider from api.simple_authentication import SimpleAuthenticationProvider from api.sip import SIP2AuthenticationProvider -from core.model import ( - ConfigurationSetting, - ExternalIntegration, - get_one, -) +from core.model import ConfigurationSetting, ExternalIntegration, get_one from core.util.problem_detail import ProblemDetail class PatronAuthServicesController(SettingsController): def __init__(self, manager): super(PatronAuthServicesController, self).__init__(manager) - self.provider_apis = [SimpleAuthenticationProvider, - MilleniumPatronAPI, - SIP2AuthenticationProvider, - FirstBookAuthenticationAPI, - OldFirstBookAuthenticationAPI, - CleverAuthenticationAPI, - KansasAuthenticationAPI, - SAMLWebSSOAuthenticationProvider - ] + self.provider_apis = [ + SimpleAuthenticationProvider, + MilleniumPatronAPI, + SIP2AuthenticationProvider, + FirstBookAuthenticationAPI, + OldFirstBookAuthenticationAPI, + CleverAuthenticationAPI, + KansasAuthenticationAPI, + SAMLWebSSOAuthenticationProvider, + ] self.protocols = self._get_integration_protocols(self.provider_apis) - self.basic_auth_protocols = [SimpleAuthenticationProvider.__module__, - MilleniumPatronAPI.__module__, - SIP2AuthenticationProvider.__module__, - FirstBookAuthenticationAPI.__module__, - OldFirstBookAuthenticationAPI.__module__, - KansasAuthenticationAPI.__module__, - ] + self.basic_auth_protocols = [ + SimpleAuthenticationProvider.__module__, + MilleniumPatronAPI.__module__, + SIP2AuthenticationProvider.__module__, + FirstBookAuthenticationAPI.__module__, + OldFirstBookAuthenticationAPI.__module__, + KansasAuthenticationAPI.__module__, + ] self.type = _("patron authentication service") self._validator_factory = PatronAuthenticationValidatorFactory() def process_patron_auth_services(self): self.require_system_admin() - if flask.request.method == 'GET': + if flask.request.method == "GET": return self.process_get() else: return self.process_post() def process_get(self): - services = self._get_integration_info(ExternalIntegration.PATRON_AUTH_GOAL, self.protocols) + services = self._get_integration_info( + ExternalIntegration.PATRON_AUTH_GOAL, self.protocols + ) for service in services: - service_object = get_one(self._db, ExternalIntegration, id=service.get("id"), goal=ExternalIntegration.PATRON_AUTH_GOAL) - service["self_test_results"] = self._get_prior_test_results(service_object, self._find_protocol_class(service_object)) - return dict( - patron_auth_services=services, - protocols=self.protocols - ) + service_object = get_one( + self._db, + ExternalIntegration, + id=service.get("id"), + goal=ExternalIntegration.PATRON_AUTH_GOAL, + ) + service["self_test_results"] = self._get_prior_test_results( + service_object, self._find_protocol_class(service_object) + ) + return dict(patron_auth_services=services, protocols=self.protocols) def process_post(self): protocol = flask.request.form.get("protocol") @@ -77,7 +81,12 @@ def process_post(self): id = flask.request.form.get("id") if id: # Find an existing service to edit - auth_service = get_one(self._db, ExternalIntegration, id=id, goal=ExternalIntegration.PATRON_AUTH_GOAL) + auth_service = get_one( + self._db, + ExternalIntegration, + id=id, + goal=ExternalIntegration.PATRON_AUTH_GOAL, + ) if not auth_service: return MISSING_SERVICE if protocol != auth_service.protocol: @@ -119,7 +128,9 @@ def process_post(self): return Response(str(auth_service.id), 200) def _find_protocol_class(self, service_object): - [protocol_class] = [p for p in self.provider_apis if p.__module__ == service_object.protocol] + [protocol_class] = [ + p for p in self.provider_apis if p.__module__ == service_object.protocol + ] return protocol_class def validate_form_fields(self, protocol): @@ -145,20 +156,28 @@ def check_library_integrations(self, library): basic_auth_count = 0 for integration in library.integrations: - if integration.goal == ExternalIntegration.PATRON_AUTH_GOAL and integration.protocol in self.basic_auth_protocols: + if ( + integration.goal == ExternalIntegration.PATRON_AUTH_GOAL + and integration.protocol in self.basic_auth_protocols + ): basic_auth_count += 1 if basic_auth_count > 1: - return MULTIPLE_BASIC_AUTH_SERVICES.detailed(_( - "You tried to add a patron authentication service that uses basic auth to %(library)s, but it already has one.", - library=library.short_name, - )) + return MULTIPLE_BASIC_AUTH_SERVICES.detailed( + _( + "You tried to add a patron authentication service that uses basic auth to %(library)s, but it already has one.", + library=library.short_name, + ) + ) def check_external_type(self, library, auth_service): """Check that the library's external type regular expression is valid, if it was set.""" value = ConfigurationSetting.for_library_and_externalintegration( - self._db, AuthenticationProvider.EXTERNAL_TYPE_REGULAR_EXPRESSION, - library, auth_service).value + self._db, + AuthenticationProvider.EXTERNAL_TYPE_REGULAR_EXPRESSION, + library, + auth_service, + ).value if value: try: re.compile(value) @@ -169,13 +188,27 @@ def check_identifier_restriction(self, library, auth_service): """Check whether the library's identifier restriction regular expression is set and is supposed to be a regular expression; if so, check that it's valid.""" - identifier_restriction_type = ConfigurationSetting.for_library_and_externalintegration( - self._db, AuthenticationProvider.LIBRARY_IDENTIFIER_RESTRICTION_TYPE, - library, auth_service).value - identifier_restriction = ConfigurationSetting.for_library_and_externalintegration( - self._db, AuthenticationProvider.LIBRARY_IDENTIFIER_RESTRICTION, - library, auth_service).value - if identifier_restriction and identifier_restriction_type == AuthenticationProvider.LIBRARY_IDENTIFIER_RESTRICTION_TYPE_REGEX: + identifier_restriction_type = ( + ConfigurationSetting.for_library_and_externalintegration( + self._db, + AuthenticationProvider.LIBRARY_IDENTIFIER_RESTRICTION_TYPE, + library, + auth_service, + ).value + ) + identifier_restriction = ( + ConfigurationSetting.for_library_and_externalintegration( + self._db, + AuthenticationProvider.LIBRARY_IDENTIFIER_RESTRICTION, + library, + auth_service, + ).value + ) + if ( + identifier_restriction + and identifier_restriction_type + == AuthenticationProvider.LIBRARY_IDENTIFIER_RESTRICTION_TYPE_REGEX + ): try: re.compile(identifier_restriction) except Exception as e: @@ -186,7 +219,11 @@ def check_libraries(self, auth_service): to configure this patron auth service.""" for library in auth_service.libraries: - error = self.check_library_integrations(library) or self.check_external_type(library, auth_service) or self.check_identifier_restriction(library, auth_service) + error = ( + self.check_library_integrations(library) + or self.check_external_type(library, auth_service) + or self.check_identifier_restriction(library, auth_service) + ) if error: return error diff --git a/api/admin/controller/search_service_self_tests.py b/api/admin/controller/search_service_self_tests.py index 8dc751e130..0ee2fad992 100644 --- a/api/admin/controller/search_service_self_tests.py +++ b/api/admin/controller/search_service_self_tests.py @@ -1,19 +1,17 @@ import flask from flask import Response from flask_babel import lazy_gettext as _ + +from api.admin.controller.self_tests import SelfTestsController from api.admin.problem_details import * -from core.model import ( - ExternalIntegration -) from core.external_search import ExternalSearchIndex -from core.testing import ExternalSearchTest - +from core.model import ExternalIntegration from core.selftest import HasSelfTests +from core.testing import ExternalSearchTest from core.util.problem_detail import ProblemDetail -from api.admin.controller.self_tests import SelfTestsController -class SearchServiceSelfTestsController(SelfTestsController, ExternalSearchTest): +class SearchServiceSelfTestsController(SelfTestsController, ExternalSearchTest): def __init__(self, manager): super(SearchServiceSelfTestsController, self).__init__(manager) self.type = _("search service") @@ -23,11 +21,14 @@ def process_search_service_self_tests(self, identifier): def _find_protocol_class(self, integration): # There's only one possibility for search integrations. - return ExternalSearchIndex, (None, self._db,) + return ExternalSearchIndex, ( + None, + self._db, + ) def look_up_by_id(self, identifier): return self.look_up_service_by_id( identifier, ExternalIntegration.ELASTICSEARCH, - ExternalIntegration.SEARCH_GOAL + ExternalIntegration.SEARCH_GOAL, ) diff --git a/api/admin/controller/self_tests.py b/api/admin/controller/self_tests.py index b1e5e1ff49..24c55f9e8d 100644 --- a/api/admin/controller/self_tests.py +++ b/api/admin/controller/self_tests.py @@ -1,10 +1,13 @@ import flask from flask import Response from flask_babel import lazy_gettext as _ + from api.admin.problem_details import * from core.util.problem_detail import ProblemDetail + from . import SettingsController + class SelfTestsController(SettingsController): def _manage_self_tests(self, identifier): """Generic request-processing method.""" @@ -37,14 +40,12 @@ def get_info(self, integration): name=integration.name, protocol=protocol, settings=protocol.get("settings"), - goal=integration.goal + goal=integration.goal, ) def run_tests(self, integration): protocol_class, extra_arguments = self.find_protocol_class(integration) - value, results = protocol_class.run_self_tests( - self._db, *extra_arguments - ) + value, results = protocol_class.run_self_tests(self._db, *extra_arguments) return value def self_tests_process_get(self, identifier): @@ -60,7 +61,7 @@ def self_tests_process_get(self, identifier): def self_tests_process_post(self, identifier): integration = self.look_up_by_id(identifier) - if isinstance (integration, ProblemDetail): + if isinstance(integration, ProblemDetail): return integration value = self.run_tests(integration) if value and isinstance(value, ProblemDetail): diff --git a/api/admin/controller/sitewide_services.py b/api/admin/controller/sitewide_services.py index 2f92b667c4..c4ca17f497 100644 --- a/api/admin/controller/sitewide_services.py +++ b/api/admin/controller/sitewide_services.py @@ -1,31 +1,31 @@ import flask from flask import Response from flask_babel import lazy_gettext as _ + from api.admin.problem_details import * from core.external_search import ExternalSearchIndex -from core.model import ( - ExternalIntegration, - get_one, - get_one_or_create, -) -from core.log import ( - Loggly, - SysLogger, - CloudwatchLogs, -) +from core.log import CloudwatchLogs, Loggly, SysLogger +from core.model import ExternalIntegration, get_one, get_one_or_create from core.util.problem_detail import ProblemDetail + from . import SettingsController -class SitewideServicesController(SettingsController): +class SitewideServicesController(SettingsController): def _manage_sitewide_service( - self, goal, provider_apis, service_key_name, - multiple_sitewide_services_detail, protocol_name_attr='NAME' + self, + goal, + provider_apis, + service_key_name, + multiple_sitewide_services_detail, + protocol_name_attr="NAME", ): - protocols = self._get_integration_protocols(provider_apis, protocol_name_attr=protocol_name_attr) + protocols = self._get_integration_protocols( + provider_apis, protocol_name_attr=protocol_name_attr + ) self.require_system_admin() - if flask.request.method == 'GET': + if flask.request.method == "GET": return self.process_get(protocols, goal, service_key_name) else: return self.process_post(protocols, goal, multiple_sitewide_services_detail) @@ -33,8 +33,8 @@ def _manage_sitewide_service( def process_get(self, protocols, goal, service_key_name): services = self._get_integration_info(goal, protocols) return { - service_key_name : services, - 'protocols' : protocols, + service_key_name: services, + "protocols": protocols, } def process_post(self, protocols, goal, multiple_sitewide_services_detail): @@ -59,8 +59,7 @@ def process_post(self, protocols, goal, multiple_sitewide_services_detail): else: if protocol: service, is_new = get_one_or_create( - self._db, ExternalIntegration, protocol=protocol, - goal=goal + self._db, ExternalIntegration, protocol=protocol, goal=goal ) # There can only be one of each sitewide service. if not is_new: @@ -104,19 +103,22 @@ def validate_form_fields(self, protocols, **fields): if protocol and protocol not in [p.get("name") for p in protocols]: return UNKNOWN_PROTOCOL + class LoggingServicesController(SitewideServicesController): def process_services(self): - detail = _("You tried to create a new logging service, but a logging service is already configured.") + detail = _( + "You tried to create a new logging service, but a logging service is already configured." + ) return self._manage_sitewide_service( ExternalIntegration.LOGGING_GOAL, [Loggly, SysLogger, CloudwatchLogs], - 'logging_services', detail + "logging_services", + detail, ) def process_delete(self, service_id): - return self._delete_integration( - service_id, ExternalIntegration.LOGGING_GOAL - ) + return self._delete_integration(service_id, ExternalIntegration.LOGGING_GOAL) + class SearchServicesController(SitewideServicesController): def __init__(self, manager): @@ -124,13 +126,15 @@ def __init__(self, manager): self.type = _("search service") def process_services(self): - detail = _("You tried to create a new search service, but a search service is already configured.") + detail = _( + "You tried to create a new search service, but a search service is already configured." + ) return self._manage_sitewide_service( - ExternalIntegration.SEARCH_GOAL, [ExternalSearchIndex], - 'search_services', detail + ExternalIntegration.SEARCH_GOAL, + [ExternalSearchIndex], + "search_services", + detail, ) def process_delete(self, service_id): - return self._delete_integration( - service_id, ExternalIntegration.SEARCH_GOAL - ) + return self._delete_integration(service_id, ExternalIntegration.SEARCH_GOAL) diff --git a/api/admin/controller/sitewide_settings.py b/api/admin/controller/sitewide_settings.py index a4013d42c0..5296fee347 100644 --- a/api/admin/controller/sitewide_settings.py +++ b/api/admin/controller/sitewide_settings.py @@ -1,15 +1,17 @@ -from core.model import ConfigurationSetting -from . import SettingsController -from api.config import Configuration -from flask import Response -from api.admin.problem_details import * import flask +from flask import Response from flask_babel import lazy_gettext as _ -class SitewideConfigurationSettingsController(SettingsController): +from api.admin.problem_details import * +from api.config import Configuration +from core.model import ConfigurationSetting +from . import SettingsController + + +class SitewideConfigurationSettingsController(SettingsController): def process_services(self): - if flask.request.method == 'GET': + if flask.request.method == "GET": return self.process_get() else: return self.process_post() @@ -22,7 +24,7 @@ def process_get(self): for s in sitewide_settings: setting = ConfigurationSetting.sitewide(self._db, s.get("key")) if setting.value: - settings += [{ "key": setting.key, "value": setting.value }] + settings += [{"key": setting.key, "value": setting.value}] return dict( settings=settings, @@ -50,15 +52,16 @@ def process_delete(self, key): def validate_form_fields(self, setting, fields): MISSING_FIELD_MESSAGES = dict( - key = MISSING_SITEWIDE_SETTING_KEY, - value = MISSING_SITEWIDE_SETTING_VALUE + key=MISSING_SITEWIDE_SETTING_KEY, value=MISSING_SITEWIDE_SETTING_VALUE ) for field in fields: if not flask.request.form.get(field): return MISSING_FIELD_MESSAGES.get(field) - [setting] = [s for s in Configuration.SITEWIDE_SETTINGS if s.get("key") == setting.key] + [setting] = [ + s for s in Configuration.SITEWIDE_SETTINGS if s.get("key") == setting.key + ] error = self.validate_formats([setting]) if error: return error diff --git a/api/admin/controller/storage_services.py b/api/admin/controller/storage_services.py index 1bfa3eecc2..34d8891d55 100644 --- a/api/admin/controller/storage_services.py +++ b/api/admin/controller/storage_services.py @@ -2,41 +2,35 @@ from flask import Response from api.admin.problem_details import * -from core.mirror import MirrorUploader -from core.model import ( - ExternalIntegration, - get_one -) -from core.util.problem_detail import ProblemDetail -from . import SettingsController # NOTE: We need to import it explicitly to initialize MirrorUploader.IMPLEMENTATION_REGISTRY from api.lcp import mirror +from core.mirror import MirrorUploader +from core.model import ExternalIntegration, get_one +from core.util.problem_detail import ProblemDetail + +from . import SettingsController class StorageServicesController(SettingsController): - def __init__(self, manager): super(StorageServicesController, self).__init__(manager) self.goal = ExternalIntegration.STORAGE_GOAL self.protocols = self._get_integration_protocols( list(MirrorUploader.IMPLEMENTATION_REGISTRY.values()), - protocol_name_attr="NAME" + protocol_name_attr="NAME", ) def process_services(self): - if flask.request.method == 'GET': + if flask.request.method == "GET": return self.process_get() else: return self.process_post() def process_get(self): services = self._get_integration_info(self.goal, self.protocols) - return dict( - storage_services=services, - protocols=self.protocols - ) - + return dict(storage_services=services, protocols=self.protocols) + def process_post(self): protocol = flask.request.form.get("protocol") name = flask.request.form.get("name") @@ -48,7 +42,9 @@ def process_post(self): id = flask.request.form.get("id") if id: # Find an existing service to edit - storage_service = get_one(self._db, ExternalIntegration, id=id, goal=self.goal) + storage_service = get_one( + self._db, ExternalIntegration, id=id, goal=self.goal + ) if not storage_service: return MISSING_SERVICE if protocol != storage_service.protocol: @@ -75,6 +71,4 @@ def process_post(self): return Response(str(storage_service.id), 200) def process_delete(self, service_id): - return self._delete_integration( - service_id, ExternalIntegration.STORAGE_GOAL - ) + return self._delete_integration(service_id, ExternalIntegration.STORAGE_GOAL) diff --git a/api/admin/controller/work_editor.py b/api/admin/controller/work_editor.py index 6439694acb..1ce242ff74 100644 --- a/api/admin/controller/work_editor.py +++ b/api/admin/controller/work_editor.py @@ -1,39 +1,29 @@ +import base64 +import json +import os +import textwrap +import urllib.error +import urllib.parse +import urllib.request +from collections import Counter +from io import BytesIO + import flask from flask import Response from flask_babel import lazy_gettext as _ -from . import AdminCirculationManagerController -from collections import Counter -from core.opds import AcquisitionFeed +from PIL import Image, ImageDraw, ImageFont + from api.admin.opds import AdminAnnotator, AdminFeed from api.admin.problem_details import * -from api.config import ( - Configuration, - CannotLoadConfiguration -) -from api.metadata_wrangler import MetadataWranglerCollectionRegistrar from api.admin.validator import Validator -from core.app_server import ( - load_pagination_from_request, -) -from core.classifier import ( - genres, - SimplifiedGenreClassifier, - NO_NUMBER, - NO_VALUE -) +from api.config import CannotLoadConfiguration, Configuration +from api.metadata_wrangler import MetadataWranglerCollectionRegistrar +from core.app_server import load_pagination_from_request +from core.classifier import NO_NUMBER, NO_VALUE, SimplifiedGenreClassifier, genres +from core.lane import Lane, WorkList +from core.metadata_layer import LinkData, Metadata, ReplacementPolicy from core.mirror import MirrorUploader -from core.util.problem_detail import ProblemDetail -from core.util import LanguageCodes -from core.metadata_layer import ( - Metadata, - LinkData, - ReplacementPolicy, -) -from core.lane import (Lane, WorkList) from core.model import ( - create, - get_one, - get_one_or_create, Classification, Collection, Complaint, @@ -48,20 +38,19 @@ Representation, RightsStatus, Subject, - Work + Work, + create, + get_one, + get_one_or_create, ) from core.model.configuration import ExternalIntegrationLink -from core.util.datetime_helpers import ( - strptime_utc, - utc_now, -) -import base64 -import json -import os -from PIL import Image, ImageDraw, ImageFont -from io import BytesIO -import textwrap -import urllib.request, urllib.parse, urllib.error +from core.opds import AcquisitionFeed +from core.util import LanguageCodes +from core.util.datetime_helpers import strptime_utc, utc_now +from core.util.problem_detail import ProblemDetail + +from . import AdminCirculationManagerController + class WorkController(AdminCirculationManagerController): @@ -96,13 +85,12 @@ def complaints(self, identifier_type, identifier): return work counter = self._count_complaints_for_work(work) - response = dict({ - "book": { - "identifier_type": identifier_type, - "identifier": identifier - }, - "complaints": counter - }) + response = dict( + { + "book": {"identifier_type": identifier_type, "identifier": identifier}, + "complaints": counter, + } + ) return response @@ -139,7 +127,7 @@ def roles(self): Contributor.PRODUCER_ROLE, Contributor.TRANSCRIBER_ROLE, Contributor.TRANSLATOR_ROLE, - ]: + ]: marc_to_role[CODES[role]] = role return marc_to_role @@ -154,10 +142,14 @@ def media(self): def rights_status(self): """Return the supported rights status values with their names and whether they are open access.""" - return {uri: dict(name=name, - open_access=(uri in RightsStatus.OPEN_ACCESS), - allows_derivatives=(uri in RightsStatus.ALLOWS_DERIVATIVES)) - for uri, name in list(RightsStatus.NAMES.items())} + return { + uri: dict( + name=name, + open_access=(uri in RightsStatus.OPEN_ACCESS), + allows_derivatives=(uri in RightsStatus.ALLOWS_DERIVATIVES), + ) + for uri, name in list(RightsStatus.NAMES.items()) + } def edit(self, identifier_type, identifier): """Edit a work's metadata.""" @@ -179,9 +171,10 @@ def edit(self, identifier_type, identifier): staff_data_source = DataSource.lookup(self._db, DataSource.LIBRARY_STAFF) primary_identifier = work.presentation_edition.primary_identifier staff_edition, is_new = get_one_or_create( - self._db, Edition, + self._db, + Edition, primary_identifier_id=primary_identifier.id, - data_source_id=staff_data_source.id + data_source_id=staff_data_source.id, ) self._db.expire(primary_identifier) @@ -199,7 +192,9 @@ def edit(self, identifier_type, identifier): # The form data includes roles and names for contributors in the same order. new_contributor_roles = flask.request.form.getlist("contributor-role") - new_contributor_names = [str(n) for n in flask.request.form.getlist("contributor-name")] + new_contributor_names = [ + str(n) for n in flask.request.form.getlist("contributor-name") + ] # The first author in the form is considered the primary author, even # though there's no separate MARC code for that. for i, role in enumerate(new_contributor_roles): @@ -212,12 +207,17 @@ def edit(self, identifier_type, identifier): # that already exist from the list so they won't be added again. deleted_contributions = False for contribution in staff_edition.contributions: - if (contribution.role, contribution.contributor.display_name) not in roles_and_names: + if ( + contribution.role, + contribution.contributor.display_name, + ) not in roles_and_names: self._db.delete(contribution) deleted_contributions = True changed = True else: - roles_and_names.remove((contribution.role, contribution.contributor.display_name)) + roles_and_names.remove( + (contribution.role, contribution.contributor.display_name) + ) if deleted_contributions: # Ensure the staff edition's contributions are up-to-date when # calculating the presentation edition later. @@ -232,8 +232,11 @@ def edit(self, identifier_type, identifier): if role not in list(Contributor.MARC_ROLE_CODES.keys()): self._db.rollback() return UNKNOWN_ROLE.detailed( - _("Role %(role)s is not one of the known contributor roles.", - role=role)) + _( + "Role %(role)s is not one of the known contributor roles.", + role=role, + ) + ) contributor = staff_edition.add_contributor(name=name, roles=[role]) contributor.display_name = name changed = True @@ -246,7 +249,7 @@ def edit(self, identifier_type, identifier): changed = True new_series_position = flask.request.form.get("series_position") - if new_series_position != None and new_series_position != '': + if new_series_position != None and new_series_position != "": try: new_series_position = int(new_series_position) except ValueError: @@ -265,13 +268,16 @@ def edit(self, identifier_type, identifier): if new_medium not in list(Edition.medium_to_additional_type.keys()): self._db.rollback() return UNKNOWN_MEDIUM.detailed( - _("Medium %(medium)s is not one of the known media.", - medium=new_medium)) + _( + "Medium %(medium)s is not one of the known media.", + medium=new_medium, + ) + ) staff_edition.medium = new_medium changed = True new_language = flask.request.form.get("language") - if new_language != None and new_language != '': + if new_language != None and new_language != "": new_language = LanguageCodes.string_to_alpha_3(new_language) if not new_language: self._db.rollback() @@ -297,9 +303,9 @@ def edit(self, identifier_type, identifier): changed = True new_issued = flask.request.form.get("issued") - if new_issued != None and new_issued != '': + if new_issued != None and new_issued != "": try: - new_issued = strptime_utc(new_issued, '%Y-%m-%d') + new_issued = strptime_utc(new_issued, "%Y-%m-%d") except ValueError: self._db.rollback() return INVALID_DATE_FORMAT @@ -317,7 +323,7 @@ def edit(self, identifier_type, identifier): # relates to the quality threshold in the library settings. changed_rating = False new_rating = flask.request.form.get("rating") - if new_rating != None and new_rating != '': + if new_rating != None and new_rating != "": try: new_rating = float(new_rating) except ValueError: @@ -327,10 +333,19 @@ def edit(self, identifier_type, identifier): if new_rating < scale[0] or new_rating > scale[1]: self._db.rollback() return INVALID_RATING.detailed( - _("The rating must be a number between %(low)s and %(high)s.", - low=scale[0], high=scale[1])) + _( + "The rating must be a number between %(low)s and %(high)s.", + low=scale[0], + high=scale[1], + ) + ) if (new_rating - scale[0]) / (scale[1] - scale[0]) != work.quality: - primary_identifier.add_measurement(staff_data_source, Measurement.RATING, new_rating, weight=WorkController.STAFF_WEIGHT) + primary_identifier.add_measurement( + staff_data_source, + Measurement.RATING, + new_rating, + weight=WorkController.STAFF_WEIGHT, + ) changed = True changed_rating = True @@ -342,8 +357,8 @@ def edit(self, identifier_type, identifier): old_summary = work.summary work.presentation_edition.primary_identifier.add_link( - Hyperlink.DESCRIPTION, None, - staff_data_source, content=new_summary) + Hyperlink.DESCRIPTION, None, staff_data_source, content=new_summary + ) # Delete previous staff summary if old_summary: @@ -376,7 +391,9 @@ def suppress(self, identifier_type, identifier): self.require_librarian(flask.request.library) # Turn source + identifier into a LicensePool - pools = self.load_licensepools(flask.request.library, identifier_type, identifier) + pools = self.load_licensepools( + flask.request.library, identifier_type, identifier + ) if isinstance(pools, ProblemDetail): # Something went wrong. return pools @@ -398,7 +415,9 @@ def unsuppress(self, identifier_type, identifier): self.require_librarian(flask.request.library) # Turn source + identifier into a group of LicensePools - pools = self.load_licensepools(flask.request.library, identifier_type, identifier) + pools = self.load_licensepools( + flask.request.library, identifier_type, identifier + ) if isinstance(pools, ProblemDetail): # Something went wrong. return pools @@ -418,7 +437,9 @@ def refresh_metadata(self, identifier_type, identifier, provider=None): if not provider and work.license_pools: try: - provider = MetadataWranglerCollectionRegistrar(work.license_pools[0].collection) + provider = MetadataWranglerCollectionRegistrar( + work.license_pools[0].collection + ) except CannotLoadConfiguration: return METADATA_REFRESH_FAILURE @@ -431,8 +452,9 @@ def refresh_metadata(self, identifier_type, identifier, provider=None): if record.exception: # There was a coverage failure. - if (str(record.exception).startswith("201") or - str(record.exception).startswith("202")): + if str(record.exception).startswith("201") or str( + record.exception + ).startswith("202"): # A 201/202 error means it's never looked up this work before # so it's started the resolution process or looking for sources. return METADATA_REFRESH_PENDING @@ -476,30 +498,34 @@ def classifications(self, identifier_type, identifier): return work identifier_id = work.presentation_edition.primary_identifier.id - results = self._db \ - .query(Classification) \ - .join(Subject) \ - .join(DataSource) \ - .filter(Classification.identifier_id == identifier_id) \ - .order_by(Classification.weight.desc()) \ + results = ( + self._db.query(Classification) + .join(Subject) + .join(DataSource) + .filter(Classification.identifier_id == identifier_id) + .order_by(Classification.weight.desc()) .all() + ) data = [] for result in results: - data.append(dict({ - "type": result.subject.type, - "name": result.subject.identifier, - "source": result.data_source.name, - "weight": result.weight - })) - - return dict({ - "book": { - "identifier_type": identifier_type, - "identifier": identifier - }, - "classifications": data - }) + data.append( + dict( + { + "type": result.subject.type, + "name": result.subject.identifier, + "source": result.data_source.name, + "weight": result.weight, + } + ) + ) + + return dict( + { + "book": {"identifier_type": identifier_type, "identifier": identifier}, + "classifications": data, + } + ) def edit_classifications(self, identifier_type, identifier): """Edit a work's audience, target age, fiction status, and genres.""" @@ -513,24 +539,19 @@ def edit_classifications(self, identifier_type, identifier): # Previous staff classifications primary_identifier = work.presentation_edition.primary_identifier - old_classifications = self._db \ - .query(Classification) \ - .join(Subject) \ + old_classifications = ( + self._db.query(Classification) + .join(Subject) .filter( Classification.identifier == primary_identifier, - Classification.data_source == staff_data_source + Classification.data_source == staff_data_source, ) - old_genre_classifications = old_classifications \ - .filter(Subject.genre_id != None) + ) + old_genre_classifications = old_classifications.filter(Subject.genre_id != None) old_staff_genres = [ - c.subject.genre.name - for c in old_genre_classifications - if c.subject.genre - ] - old_computed_genres = [ - work_genre.genre.name - for work_genre in work.work_genres + c.subject.genre.name for c in old_genre_classifications if c.subject.genre ] + old_computed_genres = [work_genre.genre.name for work_genre in work.work_genres] # New genres should be compared to previously computed genres new_genres = flask.request.form.getlist("genres") @@ -557,9 +578,14 @@ def edit_classifications(self, identifier_type, identifier): new_target_age_min = int(new_target_age_min) if new_target_age_min else None new_target_age_max = flask.request.form.get("target_age_max") new_target_age_max = int(new_target_age_max) if new_target_age_max else None - if new_target_age_max is not None and new_target_age_min is not None and \ - new_target_age_max < new_target_age_min: - return INVALID_EDIT.detailed(_("Minimum target age must be less than maximum target age.")) + if ( + new_target_age_max is not None + and new_target_age_min is not None + and new_target_age_max < new_target_age_min + ): + return INVALID_EDIT.detailed( + _("Minimum target age must be less than maximum target age.") + ) if work.target_age: old_target_age_min = work.target_age.lower @@ -567,7 +593,10 @@ def edit_classifications(self, identifier_type, identifier): else: old_target_age_min = None old_target_age_max = None - if new_target_age_min != old_target_age_min or new_target_age_max != old_target_age_max: + if ( + new_target_age_min != old_target_age_min + or new_target_age_max != old_target_age_max + ): # Delete all previous staff target age classifications for c in old_classifications: if c.subject.type == Subject.AGE_RANGE: @@ -575,7 +604,10 @@ def edit_classifications(self, identifier_type, identifier): # Create a new classification with a high weight - higher than audience if new_target_age_min and new_target_age_max: - age_range_identifier = "%s-%s" % (new_target_age_min, new_target_age_max) + age_range_identifier = "%s-%s" % ( + new_target_age_min, + new_target_age_max, + ) primary_identifier.classify( data_source=staff_data_source, subject_type=Subject.AGE_RANGE, @@ -609,7 +641,10 @@ def edit_classifications(self, identifier_type, identifier): genre, is_new = Genre.lookup(self._db, name) if not isinstance(genre, Genre): return GENRE_NOT_FOUND - if genres[name].is_fiction is not None and genres[name].is_fiction != new_fiction: + if ( + genres[name].is_fiction is not None + and genres[name].is_fiction != new_fiction + ): return INCOMPATIBLE_GENRE if name == "Erotica" and new_audience != "Adults Only": return EROTICA_FOR_ADULTS_ONLY @@ -627,7 +662,7 @@ def edit_classifications(self, identifier_type, identifier): data_source=staff_data_source, subject_type=Subject.SIMPLIFIED_GENRE, subject_identifier=genre, - weight=WorkController.STAFF_WEIGHT + weight=WorkController.STAFF_WEIGHT, ) # add NONE genre classification if we aren't keeping any genres @@ -636,18 +671,19 @@ def edit_classifications(self, identifier_type, identifier): data_source=staff_data_source, subject_type=Subject.SIMPLIFIED_GENRE, subject_identifier=SimplifiedGenreClassifier.NONE, - weight=WorkController.STAFF_WEIGHT + weight=WorkController.STAFF_WEIGHT, ) else: # otherwise delete existing NONE genre classification - none_classifications = self._db \ - .query(Classification) \ - .join(Subject) \ + none_classifications = ( + self._db.query(Classification) + .join(Subject) .filter( Classification.identifier == primary_identifier, - Subject.identifier == SimplifiedGenreClassifier.NONE - ) \ + Subject.identifier == SimplifiedGenreClassifier.NONE, + ) .all() + ) for c in none_classifications: self._db.delete(c) @@ -656,7 +692,7 @@ def edit_classifications(self, identifier_type, identifier): classify=True, regenerate_opds_entries=True, regenerate_marc_record=True, - update_search_index=True + update_search_index=True, ) work.calculate_presentation(policy=policy) @@ -664,16 +700,24 @@ def edit_classifications(self, identifier_type, identifier): MINIMUM_COVER_WIDTH = 600 MINIMUM_COVER_HEIGHT = 900 - TOP = 'top' - CENTER = 'center' - BOTTOM = 'bottom' + TOP = "top" + CENTER = "center" + BOTTOM = "bottom" TITLE_POSITIONS = [TOP, CENTER, BOTTOM] def _validate_cover_image(self, image): image_width, image_height = image.size - if image_width < self.MINIMUM_COVER_WIDTH or image_height < self.MINIMUM_COVER_HEIGHT: - return INVALID_IMAGE.detailed(_("Cover image must be at least %(width)spx in width and %(height)spx in height.", - width=self.MINIMUM_COVER_WIDTH, height=self.MINIMUM_COVER_HEIGHT)) + if ( + image_width < self.MINIMUM_COVER_WIDTH + or image_height < self.MINIMUM_COVER_HEIGHT + ): + return INVALID_IMAGE.detailed( + _( + "Cover image must be at least %(width)spx in width and %(height)spx in height.", + width=self.MINIMUM_COVER_WIDTH, + height=self.MINIMUM_COVER_HEIGHT, + ) + ) return True def _process_cover_image(self, work, image, title_position): @@ -684,7 +728,7 @@ def _process_cover_image(self, work, image, title_position): if title_position in self.TITLE_POSITIONS: # Convert image to 'RGB' mode if it's not already, so drawing on it works. - if image.mode != 'RGB': + if image.mode != "RGB": image = image.convert("RGB") draw = ImageDraw.Draw(image) @@ -693,7 +737,9 @@ def _process_cover_image(self, work, image, title_position): admin_dir = os.path.dirname(os.path.split(__file__)[0]) package_dir = os.path.join(admin_dir, "../..") bold_font_path = os.path.join(package_dir, "resources/OpenSans-Bold.ttf") - regular_font_path = os.path.join(package_dir, "resources/OpenSans-Regular.ttf") + regular_font_path = os.path.join( + package_dir, "resources/OpenSans-Regular.ttf" + ) font_size = image_width // 20 bold_font = ImageFont.truetype(bold_font_path, font_size) regular_font = ImageFont.truetype(regular_font_path, font_size) @@ -729,16 +775,24 @@ def _process_cover_image(self, work, image, title_position): else: start_y = image_height / 14 - draw.rectangle([(start_x, start_y), - (start_x + rectangle_width, start_y + rectangle_height)], - fill=(255,255,255,255)) + draw.rectangle( + [ + (start_x, start_y), + (start_x + rectangle_width, start_y + rectangle_height), + ], + fill=(255, 255, 255, 255), + ) current_y = start_y + line_height / 2 for lines, font in [(title_lines, bold_font), (author_lines, regular_font)]: for line in lines: line_width, ignore = font.getsize(line) - draw.text((start_x + (rectangle_width - line_width) / 2, current_y), - line, font=font, fill=(0,0,0,255)) + draw.text( + (start_x + (rectangle_width - line_width) / 2, current_y), + line, + font=font, + fill=(0, 0, 0, 255), + ) current_y += line_height del draw @@ -769,7 +823,9 @@ def generate_cover_image(self, work, identifier_type, identifier, preview=False) if not image_file and not image_url: return INVALID_IMAGE.detailed(_("Image file or image URL is required.")) elif image_url and not Validator()._is_url(image_url, []): - return INVALID_URL.detailed(_('"%(url)s" is not a valid URL.', url=image_url)) + return INVALID_URL.detailed( + _('"%(url)s" is not a valid URL.', url=image_url) + ) title_position = flask.request.form.get("title_position") if image_url and not image_file: @@ -791,7 +847,9 @@ def _title_position(self, work, image): return self._process_cover_image(work, image, title_position) return image - def _original_cover_info(self, image, work, data_source, rights_uri, rights_explanation): + def _original_cover_info( + self, image, work, data_source, rights_uri, rights_explanation + ): original, derivation_settings, cover_href = None, None, None cover_rights_explanation = rights_explanation title_position = flask.request.form.get("title_position") @@ -802,7 +860,12 @@ def _original_cover_info(self, image, work, data_source, rights_uri, rights_expl image.save(original_buffer, format="PNG") original_content = original_buffer.getvalue() if not original_href: - original_href = Hyperlink.generic_uri(data_source, work.presentation_edition.primary_identifier, Hyperlink.IMAGE, content=original_content) + original_href = Hyperlink.generic_uri( + data_source, + work.presentation_edition.primary_identifier, + Hyperlink.IMAGE, + content=original_content, + ) image = self._process_cover_image(work, image, title_position) @@ -810,19 +873,26 @@ def _original_cover_info(self, image, work, data_source, rights_uri, rights_expl if rights_uri != RightsStatus.IN_COPYRIGHT: original_rights_explanation = rights_explanation original = LinkData( - Hyperlink.IMAGE, original_href, rights_uri=rights_uri, - rights_explanation=original_rights_explanation, content=original_content, + Hyperlink.IMAGE, + original_href, + rights_uri=rights_uri, + rights_explanation=original_rights_explanation, + content=original_content, ) derivation_settings = dict(title_position=title_position) if rights_uri in RightsStatus.ALLOWS_DERIVATIVES: - cover_rights_explanation = "The original image license allows derivatives." + cover_rights_explanation = ( + "The original image license allows derivatives." + ) else: cover_href = cover_url return original, derivation_settings, cover_href, cover_rights_explanation def _get_collection_from_pools(self, identifier_type, identifier): - pools = self.load_licensepools(flask.request.library, identifier_type, identifier) + pools = self.load_licensepools( + flask.request.library, identifier_type, identifier + ) if isinstance(pools, ProblemDetail): return pools if not pools: @@ -853,31 +923,50 @@ def change_book_cover(self, identifier_type, identifier, mirrors=None): # Look for an appropriate mirror to store this cover image. Since the # mirror should be used for covers, we don't need a mirror for books. mirrors = mirrors or dict( - covers_mirror=MirrorUploader.for_collection(collection, ExternalIntegrationLink.COVERS), - books_mirror=None + covers_mirror=MirrorUploader.for_collection( + collection, ExternalIntegrationLink.COVERS + ), + books_mirror=None, ) if not mirrors.get(ExternalIntegrationLink.COVERS): - return INVALID_CONFIGURATION_OPTION.detailed(_("Could not find a storage integration for uploading the cover.")) + return INVALID_CONFIGURATION_OPTION.detailed( + _("Could not find a storage integration for uploading the cover.") + ) image = self.generate_cover_image(work, identifier_type, identifier) if isinstance(image, ProblemDetail): return image - original, derivation_settings, cover_href, cover_rights_explanation = self._original_cover_info(image, work, data_source, rights_uri, rights_explanation) + ( + original, + derivation_settings, + cover_href, + cover_rights_explanation, + ) = self._original_cover_info( + image, work, data_source, rights_uri, rights_explanation + ) buffer = BytesIO() image.save(buffer, format="PNG") content = buffer.getvalue() if not cover_href: - cover_href = Hyperlink.generic_uri(data_source, work.presentation_edition.primary_identifier, Hyperlink.IMAGE, content=content) + cover_href = Hyperlink.generic_uri( + data_source, + work.presentation_edition.primary_identifier, + Hyperlink.IMAGE, + content=content, + ) cover_data = LinkData( - Hyperlink.IMAGE, href=cover_href, + Hyperlink.IMAGE, + href=cover_href, media_type=Representation.PNG_MEDIA_TYPE, - content=content, rights_uri=rights_uri, + content=content, + rights_uri=rights_uri, rights_explanation=cover_rights_explanation, - original=original, transformation_settings=derivation_settings, + original=original, + transformation_settings=derivation_settings, ) presentation_policy = PresentationCalculationPolicy( @@ -902,9 +991,9 @@ def change_book_cover(self, identifier_type, identifier, mirrors=None): ) metadata = Metadata(data_source, links=[cover_data]) - metadata.apply(work.presentation_edition, - collection, - replace=replacement_policy) + metadata.apply( + work.presentation_edition, collection, replace=replacement_policy + ) # metadata.apply only updates the edition, so we also need # to update the work. @@ -913,7 +1002,9 @@ def change_book_cover(self, identifier_type, identifier, mirrors=None): return Response(_("Success"), 200) def _count_complaints_for_work(self, work): - complaint_types = [complaint.type for complaint in work.complaints if not complaint.resolved] + complaint_types = [ + complaint.type for complaint in work.complaints if not complaint.resolved + ] return Counter(complaint_types) def custom_lists(self, identifier_type, identifier): @@ -958,12 +1049,27 @@ def custom_lists(self, identifier_type, identifier): if id: is_new = False - list = get_one(self._db, CustomList, id=int(id), name=name, library=library, data_source=staff_data_source) + list = get_one( + self._db, + CustomList, + id=int(id), + name=name, + library=library, + data_source=staff_data_source, + ) if not list: self._db.rollback() - return MISSING_CUSTOM_LIST.detailed(_("Could not find list \"%(list_name)s\"", list_name=name)) + return MISSING_CUSTOM_LIST.detailed( + _('Could not find list "%(list_name)s"', list_name=name) + ) else: - list, is_new = create(self._db, CustomList, name=name, data_source=staff_data_source, library=library) + list, is_new = create( + self._db, + CustomList, + name=name, + data_source=staff_data_source, + library=library, + ) list.created = utc_now() entry, was_new = list.add_entry(work, featured=True) if was_new: diff --git a/api/admin/exceptions.py b/api/admin/exceptions.py index a47bb0eea5..73274985de 100644 --- a/api/admin/exceptions.py +++ b/api/admin/exceptions.py @@ -1,6 +1,8 @@ from .problem_details import * + class AdminNotAuthorized(Exception): status_code = 403 + def as_problem_detail_document(self, debug=False): return ADMIN_NOT_AUTHORIZED diff --git a/api/admin/geographic_validator.py b/api/admin/geographic_validator.py index a133e4809c..f929fdd9e9 100644 --- a/api/admin/geographic_validator.py +++ b/api/admin/geographic_validator.py @@ -1,30 +1,33 @@ -from api.problem_details import * +import json +import os +import re +import urllib.error +import urllib.parse +import urllib.request + +import uszipcode +from flask_babel import lazy_gettext as _ +from pypostalcode import PostalCodeDatabase + from api.admin.exceptions import * from api.admin.validator import Validator +from api.problem_details import * from api.registry import RemoteRegistry -from core.model import ( - ExternalIntegration, - Representation -) +from core.model import ExternalIntegration, Representation from core.util.http import HTTP from core.util.problem_detail import ProblemDetail -from flask_babel import lazy_gettext as _ -import json -from pypostalcode import PostalCodeDatabase -import re -import urllib.request, urllib.parse, urllib.error -import uszipcode -import os -class GeographicValidator(Validator): +class GeographicValidator(Validator): @staticmethod def get_us_search(): # Use a known path for the uszipcode db_file_dir that already contains the DB that the # library would otherwise download. This is done because the host for this file can # be flaky. There is an issue for this in the underlying library here: # https://github.com/MacHu-GWU/uszipcode-project/issues/40 - db_file_path = os.path.join(os.path.dirname(__file__), "..", "..", "data", "uszipcode") + db_file_path = os.path.join( + os.path.dirname(__file__), "..", "..", "data", "uszipcode" + ) return uszipcode.SearchEngine(simple_zipcode=True, db_file_dir=db_file_path) def validate_geographic_areas(self, values, db): @@ -45,7 +48,7 @@ def validate_geographic_areas(self, values, db): "PE": "Prince Edward Island", "QC": "Quebec", "SK": "Saskatchewan", - "YT": "Yukon Territories" + "YT": "Yukon Territories", } locations = {"US": [], "CA": []} @@ -62,7 +65,12 @@ def validate_geographic_areas(self, values, db): elif len(us_search.query(state=value)): locations["US"].append(value) else: - return UNKNOWN_LOCATION.detailed(_('"%(value)s" is not a valid U.S. state or Canadian province abbreviation.', value=value)) + return UNKNOWN_LOCATION.detailed( + _( + '"%(value)s" is not a valid U.S. state or Canadian province abbreviation.', + value=value, + ) + ) elif value in list(CA_PROVINCES.values()): locations["CA"].append(value) elif self.is_zip(value, "CA"): @@ -72,20 +80,39 @@ def validate_geographic_areas(self, values, db): formatted = "%s, %s" % (info.city, info.province) # In some cases--mainly involving very small towns--even if the zip code is valid, # the registry won't recognize the name of the place to which it corresponds. - registry_response = self.find_location_through_registry(formatted, db) + registry_response = self.find_location_through_registry( + formatted, db + ) if registry_response: - locations["CA"].append(formatted); + locations["CA"].append(formatted) else: - return UNKNOWN_LOCATION.detailed(_('Unable to locate "%(value)s" (%(formatted)s). Try entering the name of a larger area.', value=value, formatted=formatted)) + return UNKNOWN_LOCATION.detailed( + _( + 'Unable to locate "%(value)s" (%(formatted)s). Try entering the name of a larger area.', + value=value, + formatted=formatted, + ) + ) except: - return UNKNOWN_LOCATION.detailed(_('"%(value)s" is not a valid Canadian zipcode.', value=value)) + return UNKNOWN_LOCATION.detailed( + _( + '"%(value)s" is not a valid Canadian zipcode.', + value=value, + ) + ) elif len(value.split(", ")) == 2: # Is it in the format "[city], [state abbreviation]" or "[county], [state abbreviation]"? city_or_county, state = value.split(", ") if us_search.by_city_and_state(city_or_county, state): - locations["US"].append(value); - elif len([x for x in us_search.query(state=state, returns=None) if x.county == city_or_county]): - locations["US"].append(value); + locations["US"].append(value) + elif len( + [ + x + for x in us_search.query(state=state, returns=None) + if x.county == city_or_county + ] + ): + locations["US"].append(value) else: # Flag this as needing to be checked with the registry flagged = True @@ -93,19 +120,25 @@ def validate_geographic_areas(self, values, db): # Is it a US zipcode? info = self.look_up_zip(value, "US") if not info: - return UNKNOWN_LOCATION.detailed(_('"%(value)s" is not a valid U.S. zipcode.', value=value)) - locations["US"].append(value); + return UNKNOWN_LOCATION.detailed( + _('"%(value)s" is not a valid U.S. zipcode.', value=value) + ) + locations["US"].append(value) else: flagged = True if flagged: registry_response = self.find_location_through_registry(value, db) - if registry_response and isinstance(registry_response, ProblemDetail): + if registry_response and isinstance( + registry_response, ProblemDetail + ): return registry_response elif registry_response: locations[registry_response].append(value) else: - return UNKNOWN_LOCATION.detailed(_('Unable to locate "%(value)s".', value=value)) + return UNKNOWN_LOCATION.detailed( + _('Unable to locate "%(value)s".', value=value) + ) return locations def is_zip(self, value, country): @@ -127,7 +160,7 @@ def look_up_zip(self, zip, country, formatted=False): def format_place(self, zip, city, state_or_province): details = "%s, %s" % (city, state_or_province) - return { zip: details } + return {zip: details} def find_location_through_registry(self, value, db): for nation in ["US", "CA"]: @@ -143,13 +176,20 @@ def ask_registry(self, service_area_object, db, do_get=HTTP.debuggable_get): # If the circulation manager doesn't know about this location, check whether the Library Registry does. result = None for registry in RemoteRegistry.for_protocol_and_goal( - db, ExternalIntegration.OPDS_REGISTRATION, ExternalIntegration.DISCOVERY_GOAL + db, + ExternalIntegration.OPDS_REGISTRATION, + ExternalIntegration.DISCOVERY_GOAL, ): base_url = registry.integration.url + "/coverage?coverage=" response = do_get(base_url + service_area_object) if not response.status_code == 200: - result = REMOTE_INTEGRATION_FAILED.detailed(_("Unable to contact the registry at %(url)s.", url=registry.integration.url)) + result = REMOTE_INTEGRATION_FAILED.detailed( + _( + "Unable to contact the registry at %(url)s.", + url=registry.integration.url, + ) + ) if hasattr(response, "content"): content = json.loads(response.content) diff --git a/api/admin/google_oauth_admin_authentication_provider.py b/api/admin/google_oauth_admin_authentication_provider.py index 32b2b00fcb..132a596573 100644 --- a/api/admin/google_oauth_admin_authentication_provider.py +++ b/api/admin/google_oauth_admin_authentication_provider.py @@ -1,11 +1,10 @@ import json from collections import defaultdict -from api.admin.template_styles import * -from .admin_authentication_provider import AdminAuthenticationProvider -from .problem_details import GOOGLE_OAUTH_FAILURE, INVALID_ADMIN_CREDENTIALS -from oauth2client import client as GoogleClient from flask_babel import lazy_gettext as _ +from oauth2client import client as GoogleClient + +from api.admin.template_styles import * from core.model import ( Admin, AdminRole, @@ -15,52 +14,71 @@ get_one, ) +from .admin_authentication_provider import AdminAuthenticationProvider +from .problem_details import GOOGLE_OAUTH_FAILURE, INVALID_ADMIN_CREDENTIALS + + class GoogleOAuthAdminAuthenticationProvider(AdminAuthenticationProvider): NAME = ExternalIntegration.GOOGLE_OAUTH DESCRIPTION = _("How to Configure a Google OAuth Integration") DOMAINS = "domains" - INSTRUCTIONS = _("

Configuring a Google OAuth integration in the Circulation Manager " + - "will allow admins to sign into the Admin interface with their Google/GMail credentials.

" + - "

Configure the Google OAuth Service:

" + - "
  1. To use this integration, visit the " + - "Google developer console. " + - "Create a project, click 'Create Credentials' in the left sidebar, and select 'OAuth client ID'. " + - "If you get a warning about the consent screen, click 'Configure consent screen' and enter your " + - "library name as the product name. Save the consent screen information.
  2. " + - "
  3. Choose 'Web Application' as the application type.
  4. " + - "
  5. Leave 'Authorized JavaScript origins' blank, but under 'Authorized redirect URIs', add the url " + - "of your circulation manager followed by '/admin/GoogleAuth/callback', e.g. " + - "'http://mycircmanager.org/admin/GoogleAuth/callback'.
  6. " - "
  7. Click create, and you'll get a popup with your new client ID and secret. " + - "Copy these values and enter them in the form below.
") - + INSTRUCTIONS = _( + "

Configuring a Google OAuth integration in the Circulation Manager " + + "will allow admins to sign into the Admin interface with their Google/GMail credentials.

" + + "

Configure the Google OAuth Service:

" + + "
  1. To use this integration, visit the " + + "Google developer console. " + + "Create a project, click 'Create Credentials' in the left sidebar, and select 'OAuth client ID'. " + + "If you get a warning about the consent screen, click 'Configure consent screen' and enter your " + + "library name as the product name. Save the consent screen information.
  2. " + + "
  3. Choose 'Web Application' as the application type.
  4. " + + "
  5. Leave 'Authorized JavaScript origins' blank, but under 'Authorized redirect URIs', add the url " + + "of your circulation manager followed by '/admin/GoogleAuth/callback', e.g. " + + "'http://mycircmanager.org/admin/GoogleAuth/callback'.
  6. " + "
  7. Click create, and you'll get a popup with your new client ID and secret. " + + "Copy these values and enter them in the form below.
" + ) SETTINGS = [ { - "key": ExternalIntegration.URL, - "label": _("Authentication URI"), - "default": "https://accounts.google.com/o/oauth2/auth", - "required": True, - "format": "url", + "key": ExternalIntegration.URL, + "label": _("Authentication URI"), + "default": "https://accounts.google.com/o/oauth2/auth", + "required": True, + "format": "url", + }, + { + "key": ExternalIntegration.USERNAME, + "label": _("Client ID"), + "required": True, + }, + { + "key": ExternalIntegration.PASSWORD, + "label": _("Client Secret"), + "required": True, }, - { "key": ExternalIntegration.USERNAME, "label": _("Client ID"), "required": True }, - { "key": ExternalIntegration.PASSWORD, "label": _("Client Secret"), "required": True }, ] LIBRARY_SETTINGS = [ - { "key": DOMAINS, - "label": _("Allowed Domains"), - "description": _("Anyone who logs in with an email address from one of these domains will automatically have librarian-level access to this library. Library manager roles must still be granted individually by other admins. If you want to set up admins individually but still allow them to log in with Google, you can create the admin authentication service without adding any libraries."), - "type": "list" }, + { + "key": DOMAINS, + "label": _("Allowed Domains"), + "description": _( + "Anyone who logs in with an email address from one of these domains will automatically have librarian-level access to this library. Library manager roles must still be granted individually by other admins. If you want to set up admins individually but still allow them to log in with Google, you can create the admin authentication service without adding any libraries." + ), + "type": "list", + }, ] SITEWIDE = True TEMPLATE = """ Sign in with Google - """.format(link_style) + """.format( + link_style + ) def __init__(self, integration, redirect_uri, test_mode=False): super(GoogleOAuthAdminAuthenticationProvider, self).__init__(integration) @@ -78,8 +96,8 @@ def client(self): config["auth_uri"] = self.integration.url config["client_id"] = self.integration.username config["client_secret"] = self.integration.password - config['redirect_uri'] = self.redirect_uri - config['scope'] = "https://www.googleapis.com/auth/userinfo.email" + config["redirect_uri"] = self.redirect_uri + config["scope"] = "https://www.googleapis.com/auth/userinfo.email" return GoogleClient.OAuth2WebServerFlow(**config) @property @@ -89,14 +107,15 @@ def domains(self): _db = Session.object_session(self.integration) for library in self.integration.libraries: setting = ConfigurationSetting.for_library_and_externalintegration( - _db, self.DOMAINS, library, self.integration) + _db, self.DOMAINS, library, self.integration + ) if setting.json_value: for domain in setting.json_value: domains[domain.lower()].append(library) return domains def sign_in_template(self, redirect_url): - return self.TEMPLATE % dict(auth_uri = self.auth_uri(redirect_url)) + return self.TEMPLATE % dict(auth_uri=self.auth_uri(redirect_url)) def auth_uri(self, redirect_url): return self.client.step1_get_authorize_url(state=redirect_url) @@ -106,29 +125,34 @@ def callback(self, _db, request={}): # The Google OAuth client sometimes hits the callback with an error. # These will be returned as a problem detail. - error = request.get('error') + error = request.get("error") if error: return self.google_error_problem_detail(error), None - auth_code = request.get('code') + auth_code = request.get("code") if auth_code: redirect_url = request.get("state") try: credentials = self.client.step2_exchange(auth_code) except GoogleClient.FlowExchangeError as e: return self.google_error_problem_detail(str(e)), None - email = credentials.id_token.get('email') + email = credentials.id_token.get("email") if not self.staff_email(_db, email): return INVALID_ADMIN_CREDENTIALS, None - domain = email[email.index('@')+1:].lower() + domain = email[email.index("@") + 1 :].lower() roles = [] for library in self.domains[domain]: - roles.append({ "role": AdminRole.LIBRARIAN, "library": library.short_name }) - return dict( - email=email, - credentials=credentials.to_json(), - type=self.NAME, - roles=roles, - ), redirect_url + roles.append( + {"role": AdminRole.LIBRARIAN, "library": library.short_name} + ) + return ( + dict( + email=email, + credentials=credentials.to_json(), + type=self.NAME, + roles=roles, + ), + redirect_url, + ) def google_error_problem_detail(self, error): error_detail = _("Error: %(error)s", error=error) @@ -146,7 +170,9 @@ def active_credentials(self, admin): """Check that existing credentials aren't expired""" if admin.credential: - oauth_credentials = GoogleClient.OAuth2Credentials.from_json(admin.credential) + oauth_credentials = GoogleClient.OAuth2Credentials.from_json( + admin.credential + ) return not oauth_credentials.access_token_expired return False @@ -159,8 +185,11 @@ def staff_email(self, _db, email): # Otherwise, their email must match one of the configured domains. staff_domains = list(self.domains.keys()) - domain = email[email.index('@')+1:] - return domain.lower() in [staff_domain.lower() for staff_domain in staff_domains] + domain = email[email.index("@") + 1 :] + return domain.lower() in [ + staff_domain.lower() for staff_domain in staff_domains + ] + class DummyGoogleClient(object): """Mock Google OAuth client for testing""" @@ -173,8 +202,8 @@ class Credentials(object): access_token_expired = False def __init__(self, email): - domain = email[email.index('@')+1:] - self.id_token = {"hd" : domain, "email" : email} + domain = email[email.index("@") + 1 :] + self.id_token = {"hd": domain, "email": email} def to_json(self): return json.dumps(dict(id_token=self.id_token)) @@ -182,7 +211,7 @@ def to_json(self): def from_json(self, credentials): return self - def __init__(self, email='example@nypl.org'): + def __init__(self, email="example@nypl.org"): self.credentials = self.Credentials(email=email) self.OAuth2Credentials = self.credentials diff --git a/api/admin/opds.py b/api/admin/opds.py index 96df4fed54..df9190b1c4 100644 --- a/api/admin/opds.py +++ b/api/admin/opds.py @@ -1,36 +1,41 @@ - from sqlalchemy import and_ + +from api.config import CannotLoadConfiguration +from api.metadata_wrangler import MetadataWranglerCollectionRegistrar from api.opds import LibraryAnnotator -from core.opds import VerboseAnnotator from core.lane import Facets, Pagination -from core.model import ( - DataSource, - LicensePool, - Measurement, - Session, -) +from core.mirror import MirrorUploader +from core.model import DataSource, LicensePool, Measurement, Session from core.model.configuration import ExternalIntegrationLink -from core.opds import AcquisitionFeed +from core.opds import AcquisitionFeed, VerboseAnnotator from core.util.opds_writer import AtomFeed -from core.mirror import MirrorUploader -from api.metadata_wrangler import MetadataWranglerCollectionRegistrar -from api.config import CannotLoadConfiguration -class AdminAnnotator(LibraryAnnotator): +class AdminAnnotator(LibraryAnnotator): def __init__(self, circulation, library, test_mode=False): - super(AdminAnnotator, self).__init__(circulation, None, library, test_mode=test_mode) + super(AdminAnnotator, self).__init__( + circulation, None, library, test_mode=test_mode + ) self.opds_cache_field = None - def annotate_work_entry(self, work, active_license_pool, edition, identifier, feed, entry): + def annotate_work_entry( + self, work, active_license_pool, edition, identifier, feed, entry + ): - super(AdminAnnotator, self).annotate_work_entry(work, active_license_pool, edition, identifier, feed, entry) + super(AdminAnnotator, self).annotate_work_entry( + work, active_license_pool, edition, identifier, feed, entry + ) VerboseAnnotator.add_ratings(work, entry) # Find staff rating and add a tag for it. for measurement in identifier.measurements: - if measurement.data_source.name == DataSource.LIBRARY_STAFF and measurement.is_most_recent: - entry.append(self.rating_tag(measurement.quantity_measured, measurement.value)) + if ( + measurement.data_source.name == DataSource.LIBRARY_STAFF + and measurement.is_most_recent + ): + entry.append( + self.rating_tag(measurement.quantity_measured, measurement.value) + ) try: MetadataWranglerCollectionRegistrar(work.license_pools[0].collection) @@ -40,7 +45,9 @@ def annotate_work_entry(self, work, active_license_pool, edition, identifier, fe href=self.url_for( "refresh", identifier_type=identifier.type, - identifier=identifier.identifier, _external=True) + identifier=identifier.identifier, + _external=True, + ), ) except CannotLoadConfiguration: # Leave out the refresh link if there's no metadata wrangler @@ -54,7 +61,9 @@ def annotate_work_entry(self, work, active_license_pool, edition, identifier, fe href=self.url_for( "unsuppress", identifier_type=identifier.type, - identifier=identifier.identifier, _external=True) + identifier=identifier.identifier, + _external=True, + ), ) else: feed.add_link_to_entry( @@ -63,7 +72,9 @@ def annotate_work_entry(self, work, active_license_pool, edition, identifier, fe href=self.url_for( "suppress", identifier_type=identifier.type, - identifier=identifier.identifier, _external=True) + identifier=identifier.identifier, + _external=True, + ), ) feed.add_link_to_entry( @@ -72,7 +83,9 @@ def annotate_work_entry(self, work, active_license_pool, edition, identifier, fe href=self.url_for( "edit", identifier_type=identifier.type, - identifier=identifier.identifier, _external=True) + identifier=identifier.identifier, + _external=True, + ), ) # If there is a storage integration for the collection, changing the cover is allowed. @@ -86,7 +99,9 @@ def annotate_work_entry(self, work, active_license_pool, edition, identifier, fe href=self.url_for( "work_change_book_cover", identifier_type=identifier.type, - identifier=identifier.identifier, _external=True) + identifier=identifier.identifier, + _external=True, + ), ) def complaints_url(self, facets, pagination): @@ -100,20 +115,14 @@ def suppressed_url(self, pagination): def annotate_feed(self, feed): # Add a 'search' link. - search_url = self.url_for( - 'lane_search', languages=None, - _external=True - ) + search_url = self.url_for("lane_search", languages=None, _external=True) search_link = dict( - rel="search", - type="application/opensearchdescription+xml", - href=search_url + rel="search", type="application/opensearchdescription+xml", href=search_url ) feed.add_link_to_feed(feed.feed, **search_link) class AdminFeed(AcquisitionFeed): - @classmethod def complaints(cls, library, title, url, annotator, pagination=None): _db = Session.object_session(library) @@ -134,21 +143,37 @@ def complaints(cls, library, title, url, annotator, pagination=None): # Render a 'start' link top_level_title = annotator.top_level_title() start_uri = annotator.groups_url(None) - AdminFeed.add_link_to_feed(feed.feed, href=start_uri, rel="start", title=top_level_title) + AdminFeed.add_link_to_feed( + feed.feed, href=start_uri, rel="start", title=top_level_title + ) # Render an 'up' link, same as the 'start' link to indicate top-level feed - AdminFeed.add_link_to_feed(feed.feed, href=start_uri, rel="up", title=top_level_title) + AdminFeed.add_link_to_feed( + feed.feed, href=start_uri, rel="up", title=top_level_title + ) if len(works) > 0: # There are works in this list. Add a 'next' link. - AdminFeed.add_link_to_feed(feed.feed, rel="next", href=annotator.complaints_url(facets, pagination.next_page)) + AdminFeed.add_link_to_feed( + feed.feed, + rel="next", + href=annotator.complaints_url(facets, pagination.next_page), + ) if pagination.offset > 0: - AdminFeed.add_link_to_feed(feed.feed, rel="first", href=annotator.complaints_url(facets, pagination.first_page)) + AdminFeed.add_link_to_feed( + feed.feed, + rel="first", + href=annotator.complaints_url(facets, pagination.first_page), + ) previous_page = pagination.previous_page if previous_page: - AdminFeed.add_link_to_feed(feed.feed, rel="previous", href=annotator.complaints_url(facets, previous_page)) + AdminFeed.add_link_to_feed( + feed.feed, + rel="previous", + href=annotator.complaints_url(facets, previous_page), + ) annotator.annotate_feed(feed) return str(feed) @@ -157,13 +182,15 @@ def complaints(cls, library, title, url, annotator, pagination=None): def suppressed(cls, _db, title, url, annotator, pagination=None): pagination = pagination or Pagination.default() - q = _db.query(LicensePool).filter( - and_( - LicensePool.suppressed == True, - LicensePool.superceded == False, + q = ( + _db.query(LicensePool) + .filter( + and_( + LicensePool.suppressed == True, + LicensePool.superceded == False, + ) ) - ).order_by( - LicensePool.id + .order_by(LicensePool.id) ) pools = pagination.modify_database_query(_db, q).all() @@ -173,22 +200,35 @@ def suppressed(cls, _db, title, url, annotator, pagination=None): # Render a 'start' link top_level_title = annotator.top_level_title() start_uri = annotator.groups_url(None) - AdminFeed.add_link_to_feed(feed.feed, href=start_uri, rel="start", title=top_level_title) + AdminFeed.add_link_to_feed( + feed.feed, href=start_uri, rel="start", title=top_level_title + ) # Render an 'up' link, same as the 'start' link to indicate top-level feed - AdminFeed.add_link_to_feed(feed.feed, href=start_uri, rel="up", title=top_level_title) + AdminFeed.add_link_to_feed( + feed.feed, href=start_uri, rel="up", title=top_level_title + ) if len(works) > 0: # There are works in this list. Add a 'next' link. - AdminFeed.add_link_to_feed(feed.feed, rel="next", href=annotator.suppressed_url(pagination.next_page)) + AdminFeed.add_link_to_feed( + feed.feed, + rel="next", + href=annotator.suppressed_url(pagination.next_page), + ) if pagination.offset > 0: - AdminFeed.add_link_to_feed(feed.feed, rel="first", href=annotator.suppressed_url(pagination.first_page)) + AdminFeed.add_link_to_feed( + feed.feed, + rel="first", + href=annotator.suppressed_url(pagination.first_page), + ) previous_page = pagination.previous_page if previous_page: - AdminFeed.add_link_to_feed(feed.feed, rel="previous", href=annotator.suppressed_url(previous_page)) + AdminFeed.add_link_to_feed( + feed.feed, rel="previous", href=annotator.suppressed_url(previous_page) + ) annotator.annotate_feed(feed) return str(feed) - diff --git a/api/admin/password_admin_authentication_provider.py b/api/admin/password_admin_authentication_provider.py index dfbc427b73..ae6b10fdb2 100644 --- a/api/admin/password_admin_authentication_provider.py +++ b/api/admin/password_admin_authentication_provider.py @@ -1,14 +1,12 @@ - from flask import url_for +from core.model import Admin, Session + from .admin_authentication_provider import AdminAuthenticationProvider -from core.model import ( - Admin, - Session, -) from .problem_details import * from .template_styles import * + class PasswordAdminAuthenticationProvider(AdminAuthenticationProvider): NAME = "Password Auth" @@ -19,11 +17,15 @@ class PasswordAdminAuthenticationProvider(AdminAuthenticationProvider): -""".format(label=label_style, input=input_style, button=button_style) +""".format( + label=label_style, input=input_style, button=button_style + ) def sign_in_template(self, redirect): password_sign_in_url = url_for("password_auth") - return self.TEMPLATE % dict(redirect=redirect, password_sign_in_url=password_sign_in_url) + return self.TEMPLATE % dict( + redirect=redirect, password_sign_in_url=password_sign_in_url + ) def sign_in(self, _db, request={}): email = request.get("email") @@ -33,10 +35,13 @@ def sign_in(self, _db, request={}): if email and password: match = Admin.authenticate(_db, email, password) if match: - return dict( - email=email, - type=self.NAME, - ), redirect_url + return ( + dict( + email=email, + type=self.NAME, + ), + redirect_url, + ) return INVALID_ADMIN_CREDENTIALS, None diff --git a/api/admin/problem_details.py b/api/admin/problem_details.py index e54e3b05e0..2c561bf120 100644 --- a/api/admin/problem_details.py +++ b/api/admin/problem_details.py @@ -1,7 +1,8 @@ -from core.util.problem_detail import ProblemDetail as pd -from api.problem_details import * from flask_babel import lazy_gettext as _ +from api.problem_details import * +from core.util.problem_detail import ProblemDetail as pd + ADMIN_AUTH_NOT_CONFIGURED = pd( "http://librarysimplified.org/terms/problem/admin-auth-not-configured", 500, @@ -13,14 +14,16 @@ "http://librarysimplified.org/terms/problem/admin-auth-mechanism-not-configured", 400, _("Admin auth mechanism not configured"), - _("This circulation manager has not been configured to authenticate admins with the mechanism you used"), + _( + "This circulation manager has not been configured to authenticate admins with the mechanism you used" + ), ) INVALID_ADMIN_CREDENTIALS = pd( - "http://librarysimplified.org/terms/problem/admin-credentials-invalid", - 401, - _("Invalid admin credentials"), - _("Valid library staff credentials are required."), + "http://librarysimplified.org/terms/problem/admin-credentials-invalid", + 401, + _("Invalid admin credentials"), + _("Valid library staff credentials are required."), ) ADMIN_NOT_AUTHORIZED = pd( @@ -31,38 +34,38 @@ ) GOOGLE_OAUTH_FAILURE = pd( - "http://librarysimplified.org/terms/problem/google-oauth-failure", - 400, - _("Google OAuth Error"), - _("There was an error connecting with Google OAuth."), + "http://librarysimplified.org/terms/problem/google-oauth-failure", + 400, + _("Google OAuth Error"), + _("There was an error connecting with Google OAuth."), ) INVALID_CSRF_TOKEN = pd( - "http://librarysimplified.org/terms/problem/invalid-csrf-token", - 400, - _("Invalid CSRF token"), - _("There was an error saving your changes."), + "http://librarysimplified.org/terms/problem/invalid-csrf-token", + 400, + _("Invalid CSRF token"), + _("There was an error saving your changes."), ) INVALID_EDIT = pd( - "http://librarysimplified.org/terms/problem/invalid-edit", - 400, - _("Invalid edit"), - _("There was a problem with the edited metadata."), + "http://librarysimplified.org/terms/problem/invalid-edit", + 400, + _("Invalid edit"), + _("There was a problem with the edited metadata."), ) METADATA_REFRESH_PENDING = pd( - "http://librarysimplified.org/terms/problem/metadata-refresh-pending", - 201, - _("Metadata refresh pending."), - _("The Metadata Wrangler is looking for new data. Check back later."), + "http://librarysimplified.org/terms/problem/metadata-refresh-pending", + 201, + _("Metadata refresh pending."), + _("The Metadata Wrangler is looking for new data. Check back later."), ) METADATA_REFRESH_FAILURE = pd( - "http://librarysimplified.org/terms/problem/metadata-refresh-failure", - 400, - _("Metadata could not be refreshed."), - _("Metadata could not be refreshed."), + "http://librarysimplified.org/terms/problem/metadata-refresh-failure", + 400, + _("Metadata could not be refreshed."), + _("Metadata could not be refreshed."), ) UNRECOGNIZED_COMPLAINT = pd( @@ -125,7 +128,7 @@ "http://librarysimplified.org/terms/problem/unknown-location", status_code=400, title=_("Unknown location."), - detail=_("The submitted geographic location cannot be found.") + detail=_("The submitted geographic location cannot be found."), ) UNKNOWN_ROLE = pd( @@ -167,7 +170,9 @@ "http://librarysimplified.org/terms/problem/library-short-name-already-in-use", status_code=400, title=_("Library short name already in use"), - detail=_("The library short name must be unique, and there's already a library with the specified short name."), + detail=_( + "The library short name must be unique, and there's already a library with the specified short name." + ), ) MISSING_COLLECTION = pd( @@ -195,21 +200,27 @@ "http://librarysimplified.org/terms/problem/collection-name-already-in-use", status_code=400, title=_("Collection name already in use"), - detail=_("The collection name must be unique, and there's already a collection with the specified name."), + detail=_( + "The collection name must be unique, and there's already a collection with the specified name." + ), ) CANNOT_DELETE_COLLECTION_WITH_CHILDREN = pd( "http://librarysimplified.org/terms/problem/cannot-delete-collection-with-children", status_code=400, title=_("Cannot delete collection with children"), - detail=_("The collection is the parent of at least one other collection, so it can't be deleted."), + detail=_( + "The collection is the parent of at least one other collection, so it can't be deleted." + ), ) NO_PROTOCOL_FOR_NEW_SERVICE = pd( "http://librarysimplified.org/terms/problem/no-protocol-for-new-service", status_code=400, title=_("No protocol for new service"), - detail=_("The specified service doesn't exist. You can create it, but you must specify a protocol."), + detail=_( + "The specified service doesn't exist. You can create it, but you must specify a protocol." + ), ) UNKNOWN_PROTOCOL = pd( @@ -230,7 +241,9 @@ "http://librarysimplified.org/terms/problem/protocol-does-not-support-parents", status_code=400, title=_("Protocol does not support parents"), - detail=_("You attempted to add a parent but the protocol does not support parents."), + detail=_( + "You attempted to add a parent but the protocol does not support parents." + ), ) MISSING_PARENT = pd( @@ -258,14 +271,16 @@ "http://librarysimplified.org/terms/problem/duplicate-integration", status_code=400, title=_("Duplicate integration"), - detail=_("A given site can only support one integration of this type.") + detail=_("A given site can only support one integration of this type."), ) INTEGRATION_NAME_ALREADY_IN_USE = pd( "http://librarysimplified.org/terms/problem/integration-name-already-in-use", status_code=400, title=_("Integration name already in use"), - detail=_("The integration name must be unique, and there's already an integration with the specified name."), + detail=_( + "The integration name must be unique, and there's already an integration with the specified name." + ), ) INTEGRATION_URL_ALREADY_IN_USE = pd( @@ -279,7 +294,9 @@ "http://librarysimplified.org/terms/problem/integration-goal-conflict", status_code=409, title=_("Incompatible use of integration"), - detail=_("You tried to use an integration in a way incompatible with the goal of that integration"), + detail=_( + "You tried to use an integration in a way incompatible with the goal of that integration" + ), ) MISSING_INTEGRATION = pd( @@ -293,7 +310,9 @@ "http://librarysimplified.org/terms/problem/missing-pgcrypto-extension", status_code=500, title=_("Missing pgcrypto database extension"), - detail=_("You tried to store a password for an individual admin, but the database does not have the pgcrypto extension installed."), + detail=_( + "You tried to store a password for an individual admin, but the database does not have the pgcrypto extension installed." + ), ) MISSING_ADMIN = pd( @@ -349,21 +368,27 @@ "http://librarysimplified.org/terms/problem/invalid-library-identifier-restriction-regular-expression", status_code=400, title=_("Invalid library identifier restriction regular expression"), - detail=_("The specified library identifier restriction regular expression does not compile."), + detail=_( + "The specified library identifier restriction regular expression does not compile." + ), ) MULTIPLE_BASIC_AUTH_SERVICES = pd( "http://librarysimplified.org/terms/problem/multiple-basic-auth-services", status_code=400, title=_("Multiple basic authentication services"), - detail=_("Each library can only have one patron authentication service using basic auth."), + detail=_( + "Each library can only have one patron authentication service using basic auth." + ), ) NO_SUCH_PATRON = pd( "http://librarysimplified.org/terms/problem/no-such-patron", status_code=404, title=_("No such patron"), - detail=_("The specified patron doesn't exist, or is associated with a different library."), + detail=_( + "The specified patron doesn't exist, or is associated with a different library." + ), ) MISSING_SITEWIDE_SETTING_KEY = pd( @@ -384,7 +409,9 @@ "http://librarysimplified.org/terms/problem/multiple-search-services", status_code=400, title=_("Multiple sitewide services"), - detail=_("You tried to create a new sitewide service, but a sitewide service of the same type is already configured."), + detail=_( + "You tried to create a new sitewide service, but a sitewide service of the same type is already configured." + ), ) MULTIPLE_SERVICES_FOR_LIBRARY = pd( @@ -419,14 +446,18 @@ "http://librarysimplified.org/terms/problem/collection-not-associated-with-library", status_code=400, title=_("Collection not associated with library"), - detail=_("You can't add a collection to a list unless it is associated with the list's library."), + detail=_( + "You can't add a collection to a list unless it is associated with the list's library." + ), ) MISSING_LANE = pd( "http://librarysimplified.org/terms/problem/missing-lane", status_code=404, title=_("Missing lane"), - detail=_("The specified lane doesn't exist, or is associated with a different library."), + detail=_( + "The specified lane doesn't exist, or is associated with a different library." + ), ) CANNOT_EDIT_DEFAULT_LANE = pd( @@ -454,7 +485,9 @@ "http://librarysimplified.org/terms/problem/lane-with-parent-and-display-name-already-exists", status_code=400, title=_("Lane with parent and display name already exists"), - detail=_("You cannot create a lane with the same parent and display name as an existing lane."), + detail=_( + "You cannot create a lane with the same parent and display name as an existing lane." + ), ) CANNOT_SHOW_LANE_WITH_HIDDEN_PARENT = pd( @@ -475,12 +508,12 @@ "http://librarysimplified.org/terms/problem/failed-to-run-self-tests", status_code=400, title=_("Failed to run self tests."), - detail=_("Failed to run self tests.") + detail=_("Failed to run self tests."), ) MISSING_IDENTIFIER = pd( "http://librarysimplified.org/terms/problem/missing-identifier", status_code=400, title=_("Missing identifier"), - detail=_("No identifier was used.") + detail=_("No identifier was used."), ) diff --git a/api/admin/routes.py b/api/admin/routes.py index ab50211c17..63059c80b8 100644 --- a/api/admin/routes.py +++ b/api/admin/routes.py @@ -2,79 +2,81 @@ from functools import wraps import flask -from flask import (make_response, redirect, Response) +from flask import Response, make_response, redirect from api.admin.config import Configuration as AdminClientConfig from api.app import app from api.config import Configuration -from api.routes import (allows_library, has_library, library_route) +from api.routes import allows_library, has_library, library_route from core.app_server import returns_problem_detail from core.local_analytics_provider import LocalAnalyticsProvider -from core.model import ( - ConfigurationSetting, -) +from core.model import ConfigurationSetting from core.util.problem_detail import ProblemDetail -from .controller import ( - setup_admin_controllers, -) -from .templates import ( - admin_sign_in_again as sign_in_again_template, -) + +from .controller import setup_admin_controllers +from .templates import admin_sign_in_again as sign_in_again_template # An admin's session will expire after this amount of time and # the admin will have to log in again. app.permanent_session_lifetime = timedelta(hours=9) + @app.before_first_request def setup_admin(_db=None): - if getattr(app, 'manager', None) is not None: + if getattr(app, "manager", None) is not None: setup_admin_controllers(app.manager) _db = _db or app._db # The secret key is used for signing cookies for admin login - app.secret_key = ConfigurationSetting.sitewide_secret( - _db, Configuration.SECRET_KEY - ) + app.secret_key = ConfigurationSetting.sitewide_secret(_db, Configuration.SECRET_KEY) # Create a default Local Analytics service if one does not # already exist. local_analytics = LocalAnalyticsProvider.initialize(_db) + def allows_admin_auth_setup(f): @wraps(f) def decorated(*args, **kwargs): - setting_up = (app.manager.admin_sign_in_controller.admin_auth_providers == []) + setting_up = app.manager.admin_sign_in_controller.admin_auth_providers == [] return f(*args, setting_up=setting_up, **kwargs) + return decorated + def requires_admin(f): @wraps(f) def decorated(*args, **kwargs): - if 'setting_up' in kwargs: + if "setting_up" in kwargs: # If the function also requires a CSRF token, # setting_up needs to stay in the arguments for # the next decorator. Otherwise, it should be # removed before the route function. if f.__dict__.get("requires_csrf_token"): - setting_up = kwargs.get('setting_up') + setting_up = kwargs.get("setting_up") else: - setting_up = kwargs.pop('setting_up') + setting_up = kwargs.pop("setting_up") else: setting_up = False if not setting_up: - admin = app.manager.admin_sign_in_controller.authenticated_admin_from_request() + admin = ( + app.manager.admin_sign_in_controller.authenticated_admin_from_request() + ) if isinstance(admin, ProblemDetail): return app.manager.admin_sign_in_controller.error_response(admin) elif isinstance(admin, Response): return admin return f(*args, **kwargs) + return decorated + def requires_csrf_token(f): f.__dict__["requires_csrf_token"] = True + @wraps(f) def decorated(*args, **kwargs): - if 'setting_up' in kwargs: - setting_up = kwargs.pop('setting_up') + if "setting_up" in kwargs: + setting_up = kwargs.pop("setting_up") else: setting_up = False if not setting_up and flask.request.method in ["POST", "PUT", "DELETE"]: @@ -82,8 +84,10 @@ def decorated(*args, **kwargs): if isinstance(token, ProblemDetail): return token return f(*args, **kwargs) + return decorated + def returns_json_or_response_or_problem_detail(f): @wraps(f) def decorated(*args, **kwargs): @@ -93,71 +97,101 @@ def decorated(*args, **kwargs): if isinstance(v, Response): return v return flask.jsonify(**v) + return decorated -@app.route('/admin/GoogleAuth/callback') + +@app.route("/admin/GoogleAuth/callback") @returns_problem_detail def google_auth_callback(): return app.manager.admin_sign_in_controller.redirect_after_google_sign_in() + @app.route("/admin/sign_in_with_password", methods=["POST"]) @returns_problem_detail def password_auth(): return app.manager.admin_sign_in_controller.password_sign_in() -@app.route('/admin/sign_in') + +@app.route("/admin/sign_in") @returns_problem_detail def admin_sign_in(): return app.manager.admin_sign_in_controller.sign_in() -@app.route('/admin/sign_out') + +@app.route("/admin/sign_out") @returns_problem_detail @requires_admin def admin_sign_out(): return app.manager.admin_sign_in_controller.sign_out() -@app.route('/admin/change_password', methods=["POST"]) + +@app.route("/admin/change_password", methods=["POST"]) @returns_problem_detail @requires_admin def admin_change_password(): return app.manager.admin_sign_in_controller.change_password() -@library_route('/admin/works//', methods=['GET']) + +@library_route("/admin/works//", methods=["GET"]) @has_library @returns_problem_detail @requires_admin def work_details(identifier_type, identifier): return app.manager.admin_work_controller.details(identifier_type, identifier) -@library_route('/admin/works///classifications', methods=['GET']) + +@library_route( + "/admin/works///classifications", methods=["GET"] +) @has_library @returns_json_or_response_or_problem_detail @requires_admin def work_classifications(identifier_type, identifier): - return app.manager.admin_work_controller.classifications(identifier_type, identifier) + return app.manager.admin_work_controller.classifications( + identifier_type, identifier + ) -@library_route('/admin/works///preview_book_cover', methods=['POST']) + +@library_route( + "/admin/works///preview_book_cover", + methods=["POST"], +) @has_library @returns_problem_detail @requires_admin def work_preview_book_cover(identifier_type, identifier): - return app.manager.admin_work_controller.preview_book_cover(identifier_type, identifier) + return app.manager.admin_work_controller.preview_book_cover( + identifier_type, identifier + ) + -@library_route('/admin/works///change_book_cover', methods=['POST']) +@library_route( + "/admin/works///change_book_cover", + methods=["POST"], +) @has_library @returns_problem_detail @requires_admin def work_change_book_cover(identifier_type, identifier): - return app.manager.admin_work_controller.change_book_cover(identifier_type, identifier) + return app.manager.admin_work_controller.change_book_cover( + identifier_type, identifier + ) + -@library_route('/admin/works///complaints', methods=['GET']) +@library_route( + "/admin/works///complaints", methods=["GET"] +) @has_library @returns_json_or_response_or_problem_detail @requires_admin def work_complaints(identifier_type, identifier): return app.manager.admin_work_controller.complaints(identifier_type, identifier) -@library_route('/admin/works///lists', methods=['GET', 'POST']) + +@library_route( + "/admin/works///lists", methods=["GET", "POST"] +) @has_library @returns_json_or_response_or_problem_detail @requires_admin @@ -165,7 +199,10 @@ def work_complaints(identifier_type, identifier): def work_custom_lists(identifier_type, identifier): return app.manager.admin_work_controller.custom_lists(identifier_type, identifier) -@library_route('/admin/works///edit', methods=['POST']) + +@library_route( + "/admin/works///edit", methods=["POST"] +) @has_library @returns_problem_detail @requires_admin @@ -173,7 +210,10 @@ def work_custom_lists(identifier_type, identifier): def edit(identifier_type, identifier): return app.manager.admin_work_controller.edit(identifier_type, identifier) -@library_route('/admin/works///suppress', methods=['POST']) + +@library_route( + "/admin/works///suppress", methods=["POST"] +) @has_library @returns_problem_detail @requires_admin @@ -181,7 +221,10 @@ def edit(identifier_type, identifier): def suppress(identifier_type, identifier): return app.manager.admin_work_controller.suppress(identifier_type, identifier) -@library_route('/admin/works///unsuppress', methods=['POST']) + +@library_route( + "/admin/works///unsuppress", methods=["POST"] +) @has_library @returns_problem_detail @requires_admin @@ -189,58 +232,79 @@ def suppress(identifier_type, identifier): def unsuppress(identifier_type, identifier): return app.manager.admin_work_controller.unsuppress(identifier_type, identifier) -@library_route('/works///refresh', methods=['POST']) + +@library_route("/works///refresh", methods=["POST"]) @has_library @returns_problem_detail @requires_admin @requires_csrf_token def refresh(identifier_type, identifier): - return app.manager.admin_work_controller.refresh_metadata(identifier_type, identifier) + return app.manager.admin_work_controller.refresh_metadata( + identifier_type, identifier + ) + -@library_route('/admin/works///resolve_complaints', methods=['POST']) +@library_route( + "/admin/works///resolve_complaints", + methods=["POST"], +) @has_library @returns_problem_detail @requires_admin @requires_csrf_token def resolve_complaints(identifier_type, identifier): - return app.manager.admin_work_controller.resolve_complaints(identifier_type, identifier) + return app.manager.admin_work_controller.resolve_complaints( + identifier_type, identifier + ) + -@library_route('/admin/works///edit_classifications', methods=['POST']) +@library_route( + "/admin/works///edit_classifications", + methods=["POST"], +) @has_library @returns_problem_detail @requires_admin @requires_csrf_token def edit_classifications(identifier_type, identifier): - return app.manager.admin_work_controller.edit_classifications(identifier_type, identifier) + return app.manager.admin_work_controller.edit_classifications( + identifier_type, identifier + ) -@app.route('/admin/roles') + +@app.route("/admin/roles") @returns_json_or_response_or_problem_detail def roles(): return app.manager.admin_work_controller.roles() -@app.route('/admin/languages') + +@app.route("/admin/languages") @returns_json_or_response_or_problem_detail def languages(): return app.manager.admin_work_controller.languages() -@app.route('/admin/media') + +@app.route("/admin/media") @returns_json_or_response_or_problem_detail def media(): return app.manager.admin_work_controller.media() -@app.route('/admin/rights_status') + +@app.route("/admin/rights_status") @returns_json_or_response_or_problem_detail def rights_status(): return app.manager.admin_work_controller.rights_status() -@library_route('/admin/complaints') + +@library_route("/admin/complaints") @has_library @returns_problem_detail @requires_admin def complaints(): return app.manager.admin_feed_controller.complaints() -@library_route('/admin/suppressed') + +@library_route("/admin/suppressed") @has_library @returns_problem_detail @requires_admin @@ -248,21 +312,28 @@ def suppressed(): """Returns a feed of suppressed works.""" return app.manager.admin_feed_controller.suppressed() -@app.route('/admin/genres') + +@app.route("/admin/genres") @returns_json_or_response_or_problem_detail @requires_admin def genres(): """Returns a JSON representation of complete genre tree.""" return app.manager.admin_feed_controller.genres() -@library_route('/admin/bulk_circulation_events') + +@library_route("/admin/bulk_circulation_events") @returns_problem_detail @allows_library @requires_admin def bulk_circulation_events(): """Returns a CSV representation of all circulation events with optional start and end times.""" - data, date, date_end, library = app.manager.admin_dashboard_controller.bulk_circulation_events() + ( + data, + date, + date_end, + library, + ) = app.manager.admin_dashboard_controller.bulk_circulation_events() if isinstance(data, ProblemDetail): return data @@ -272,11 +343,14 @@ def bulk_circulation_events(): # for convenience. The start and end dates will always be included. filename = library + "-" if library else "" filename += date + "-to-" + date_end if date_end and date != date_end else date - response.headers['Content-Disposition'] = "attachment; filename=circulation_events_" + filename + ".csv" + response.headers["Content-Disposition"] = ( + "attachment; filename=circulation_events_" + filename + ".csv" + ) response.headers["Content-type"] = "text/csv" return response -@library_route('/admin/circulation_events') + +@library_route("/admin/circulation_events") @has_library @returns_json_or_response_or_problem_detail @requires_admin @@ -284,19 +358,22 @@ def circulation_events(): """Returns a JSON representation of the most recent circulation events.""" return app.manager.admin_dashboard_controller.circulation_events() -@app.route('/admin/stats') + +@app.route("/admin/stats") @returns_json_or_response_or_problem_detail @requires_admin def stats(): return app.manager.admin_dashboard_controller.stats() -@app.route('/admin/libraries', methods=['GET', 'POST']) + +@app.route("/admin/libraries", methods=["GET", "POST"]) @returns_json_or_response_or_problem_detail @requires_admin @requires_csrf_token def libraries(): return app.manager.admin_library_settings_controller.process_libraries() + @app.route("/admin/library/", methods=["DELETE"]) @returns_json_or_response_or_problem_detail @requires_admin @@ -304,41 +381,53 @@ def libraries(): def library(library_uuid): return app.manager.admin_library_settings_controller.process_delete(library_uuid) -@app.route("/admin/collections", methods=['GET', 'POST']) + +@app.route("/admin/collections", methods=["GET", "POST"]) @returns_json_or_response_or_problem_detail @requires_admin @requires_csrf_token def collections(): return app.manager.admin_collection_settings_controller.process_collections() + @app.route("/admin/collection/", methods=["DELETE"]) @returns_json_or_response_or_problem_detail @requires_admin @requires_csrf_token def collection(collection_id): - return app.manager.admin_collection_settings_controller.process_delete(collection_id) + return app.manager.admin_collection_settings_controller.process_delete( + collection_id + ) + @app.route("/admin/collection_self_tests/", methods=["GET", "POST"]) @returns_json_or_response_or_problem_detail @requires_admin @requires_csrf_token def collection_self_tests(identifier): - return app.manager.admin_collection_self_tests_controller.process_collection_self_tests(identifier) + return app.manager.admin_collection_self_tests_controller.process_collection_self_tests( + identifier + ) + -@app.route("/admin/collection_library_registrations", methods=['GET', 'POST']) +@app.route("/admin/collection_library_registrations", methods=["GET", "POST"]) @returns_json_or_response_or_problem_detail @requires_admin @requires_csrf_token def collection_library_registrations(): - return app.manager.admin_collection_library_registrations_controller.process_collection_library_registrations() + return ( + app.manager.admin_collection_library_registrations_controller.process_collection_library_registrations() + ) + -@app.route("/admin/admin_auth_services", methods=['GET', 'POST']) +@app.route("/admin/admin_auth_services", methods=["GET", "POST"]) @returns_json_or_response_or_problem_detail @requires_admin @requires_csrf_token def admin_auth_services(): return app.manager.admin_auth_services_controller.process_admin_auth_services() + @app.route("/admin/admin_auth_service/", methods=["DELETE"]) @returns_json_or_response_or_problem_detail @requires_admin @@ -346,13 +435,17 @@ def admin_auth_services(): def admin_auth_service(protocol): return app.manager.admin_auth_services_controller.process_delete(protocol) -@app.route("/admin/individual_admins", methods=['GET', 'POST']) + +@app.route("/admin/individual_admins", methods=["GET", "POST"]) @returns_json_or_response_or_problem_detail @allows_admin_auth_setup @requires_admin @requires_csrf_token def individual_admins(): - return app.manager.admin_individual_admin_settings_controller.process_individual_admins() + return ( + app.manager.admin_individual_admin_settings_controller.process_individual_admins() + ) + @app.route("/admin/individual_admin/", methods=["DELETE"]) @returns_json_or_response_or_problem_detail @@ -361,12 +454,16 @@ def individual_admins(): def individual_admin(email): return app.manager.admin_individual_admin_settings_controller.process_delete(email) -@app.route("/admin/patron_auth_services", methods=['GET', 'POST']) + +@app.route("/admin/patron_auth_services", methods=["GET", "POST"]) @returns_json_or_response_or_problem_detail @requires_admin @requires_csrf_token def patron_auth_services(): - return app.manager.admin_patron_auth_services_controller.process_patron_auth_services() + return ( + app.manager.admin_patron_auth_services_controller.process_patron_auth_services() + ) + @app.route("/admin/patron_auth_service/", methods=["DELETE"]) @returns_json_or_response_or_problem_detail @@ -375,14 +472,20 @@ def patron_auth_services(): def patron_auth_service(service_id): return app.manager.admin_patron_auth_services_controller.process_delete(service_id) -@app.route("/admin/patron_auth_service_self_tests/", methods=["GET", "POST"]) + +@app.route( + "/admin/patron_auth_service_self_tests/", methods=["GET", "POST"] +) @returns_json_or_response_or_problem_detail @requires_admin @requires_csrf_token def patron_auth_self_tests(identifier): - return app.manager.admin_patron_auth_service_self_tests_controller.process_patron_auth_service_self_tests(identifier) + return app.manager.admin_patron_auth_service_self_tests_controller.process_patron_auth_service_self_tests( + identifier + ) + -@library_route("/admin/manage_patrons", methods=['POST']) +@library_route("/admin/manage_patrons", methods=["POST"]) @has_library @returns_json_or_response_or_problem_detail @requires_admin @@ -390,7 +493,8 @@ def patron_auth_self_tests(identifier): def lookup_patron(): return app.manager.admin_patron_controller.lookup_patron() -@library_route("/admin/manage_patrons/reset_adobe_id", methods=['POST']) + +@library_route("/admin/manage_patrons/reset_adobe_id", methods=["POST"]) @has_library @returns_json_or_response_or_problem_detail @requires_admin @@ -398,13 +502,15 @@ def lookup_patron(): def reset_adobe_id(): return app.manager.admin_patron_controller.reset_adobe_id() -@app.route("/admin/metadata_services", methods=['GET', 'POST']) + +@app.route("/admin/metadata_services", methods=["GET", "POST"]) @returns_json_or_response_or_problem_detail @requires_admin @requires_csrf_token def metadata_services(): return app.manager.admin_metadata_services_controller.process_metadata_services() + @app.route("/admin/metadata_service/", methods=["DELETE"]) @returns_json_or_response_or_problem_detail @requires_admin @@ -412,20 +518,25 @@ def metadata_services(): def metadata_service(service_id): return app.manager.admin_metadata_services_controller.process_delete(service_id) + @app.route("/admin/metadata_service_self_tests/", methods=["GET", "POST"]) @returns_json_or_response_or_problem_detail @requires_admin @requires_csrf_token def metadata_service_self_tests(identifier): - return app.manager.admin_metadata_service_self_tests_controller.process_metadata_service_self_tests(identifier) + return app.manager.admin_metadata_service_self_tests_controller.process_metadata_service_self_tests( + identifier + ) + -@app.route("/admin/analytics_services", methods=['GET', 'POST']) +@app.route("/admin/analytics_services", methods=["GET", "POST"]) @returns_json_or_response_or_problem_detail @requires_admin @requires_csrf_token def analytics_services(): return app.manager.admin_analytics_services_controller.process_analytics_services() + @app.route("/admin/analytics_service/", methods=["DELETE"]) @returns_json_or_response_or_problem_detail @requires_admin @@ -433,6 +544,7 @@ def analytics_services(): def analytics_service(service_id): return app.manager.admin_analytics_services_controller.process_delete(service_id) + @app.route("/admin/cdn_services", methods=["GET", "POST"]) @returns_json_or_response_or_problem_detail @requires_admin @@ -440,6 +552,7 @@ def analytics_service(service_id): def cdn_services(): return app.manager.admin_cdn_services_controller.process_cdn_services() + @app.route("/admin/cdn_service/", methods=["DELETE"]) @returns_json_or_response_or_problem_detail @requires_admin @@ -447,6 +560,7 @@ def cdn_services(): def cdn_service(service_id): return app.manager.admin_cdn_services_controller.process_delete(service_id) + @app.route("/admin/search_services", methods=["GET", "POST"]) @returns_json_or_response_or_problem_detail @requires_admin @@ -454,6 +568,7 @@ def cdn_service(service_id): def search_services(): return app.manager.admin_search_services_controller.process_services() + @app.route("/admin/search_service/", methods=["DELETE"]) @returns_json_or_response_or_problem_detail @requires_admin @@ -461,12 +576,15 @@ def search_services(): def search_service(service_id): return app.manager.admin_search_services_controller.process_delete(service_id) + @app.route("/admin/search_service_self_tests/", methods=["GET", "POST"]) @returns_json_or_response_or_problem_detail @requires_admin @requires_csrf_token def search_service_self_tests(identifier): - return app.manager.admin_search_service_self_tests_controller.process_search_service_self_tests(identifier) + return app.manager.admin_search_service_self_tests_controller.process_search_service_self_tests( + identifier + ) @app.route("/admin/storage_services", methods=["GET", "POST"]) @@ -476,6 +594,7 @@ def search_service_self_tests(identifier): def storage_services(): return app.manager.admin_storage_services_controller.process_services() + @app.route("/admin/storage_service/", methods=["DELETE"]) @returns_json_or_response_or_problem_detail @requires_admin @@ -483,13 +602,15 @@ def storage_services(): def storage_service(service_id): return app.manager.admin_storage_services_controller.process_delete(service_id) -@app.route("/admin/catalog_services", methods=['GET', 'POST']) + +@app.route("/admin/catalog_services", methods=["GET", "POST"]) @returns_json_or_response_or_problem_detail @requires_admin @requires_csrf_token def catalog_services(): return app.manager.admin_catalog_services_controller.process_catalog_services() + @app.route("/admin/catalog_service/", methods=["DELETE"]) @returns_json_or_response_or_problem_detail @requires_admin @@ -497,6 +618,7 @@ def catalog_services(): def catalog_service(service_id): return app.manager.admin_catalog_services_controller.process_delete(service_id) + @app.route("/admin/discovery_services", methods=["GET", "POST"]) @returns_json_or_response_or_problem_detail @requires_admin @@ -504,6 +626,7 @@ def catalog_service(service_id): def discovery_services(): return app.manager.admin_discovery_services_controller.process_discovery_services() + @app.route("/admin/discovery_service/", methods=["DELETE"]) @returns_json_or_response_or_problem_detail @requires_admin @@ -511,27 +634,35 @@ def discovery_services(): def discovery_service(service_id): return app.manager.admin_discovery_services_controller.process_delete(service_id) -@app.route("/admin/sitewide_settings", methods=['GET', 'POST']) + +@app.route("/admin/sitewide_settings", methods=["GET", "POST"]) @returns_json_or_response_or_problem_detail @requires_admin @requires_csrf_token def sitewide_settings(): - return app.manager.admin_sitewide_configuration_settings_controller.process_services() + return ( + app.manager.admin_sitewide_configuration_settings_controller.process_services() + ) + @app.route("/admin/sitewide_setting/", methods=["DELETE"]) @returns_json_or_response_or_problem_detail @requires_admin @requires_csrf_token def sitewide_setting(key): - return app.manager.admin_sitewide_configuration_settings_controller.process_delete(key) + return app.manager.admin_sitewide_configuration_settings_controller.process_delete( + key + ) + -@app.route("/admin/logging_services", methods=['GET', 'POST']) +@app.route("/admin/logging_services", methods=["GET", "POST"]) @returns_json_or_response_or_problem_detail @requires_admin @requires_csrf_token def logging_services(): return app.manager.admin_logging_services_controller.process_services() + @app.route("/admin/logging_service/", methods=["DELETE"]) @returns_json_or_response_or_problem_detail @requires_admin @@ -539,12 +670,16 @@ def logging_services(): def logging_service(key): return app.manager.admin_logging_services_controller.process_delete(key) -@app.route("/admin/discovery_service_library_registrations", methods=['GET', 'POST']) + +@app.route("/admin/discovery_service_library_registrations", methods=["GET", "POST"]) @returns_json_or_response_or_problem_detail @requires_admin @requires_csrf_token def discovery_service_library_registrations(): - return app.manager.admin_discovery_service_library_registrations_controller.process_discovery_service_library_registrations() + return ( + app.manager.admin_discovery_service_library_registrations_controller.process_discovery_service_library_registrations() + ) + @library_route("/admin/custom_lists", methods=["GET", "POST"]) @has_library @@ -554,6 +689,7 @@ def discovery_service_library_registrations(): def custom_lists(): return app.manager.admin_custom_lists_controller.custom_lists() + @library_route("/admin/custom_list/", methods=["GET", "POST", "DELETE"]) @has_library @returns_json_or_response_or_problem_detail @@ -562,6 +698,7 @@ def custom_lists(): def custom_list(list_id): return app.manager.admin_custom_lists_controller.custom_list(list_id) + @library_route("/admin/lanes", methods=["GET", "POST"]) @has_library @returns_json_or_response_or_problem_detail @@ -570,6 +707,7 @@ def custom_list(list_id): def lanes(): return app.manager.admin_lanes_controller.lanes() + @library_route("/admin/lane/", methods=["DELETE"]) @has_library @returns_json_or_response_or_problem_detail @@ -578,6 +716,7 @@ def lanes(): def lane(lane_identifier): return app.manager.admin_lanes_controller.lane(lane_identifier) + @library_route("/admin/lane//show", methods=["POST"]) @has_library @returns_json_or_response_or_problem_detail @@ -586,6 +725,7 @@ def lane(lane_identifier): def lane_show(lane_identifier): return app.manager.admin_lanes_controller.show_lane(lane_identifier) + @library_route("/admin/lane//hide", methods=["POST"]) @has_library @returns_json_or_response_or_problem_detail @@ -594,6 +734,7 @@ def lane_show(lane_identifier): def lane_hide(lane_identifier): return app.manager.admin_lanes_controller.hide_lane(lane_identifier) + @library_route("/admin/lanes/reset", methods=["POST"]) @has_library @returns_json_or_response_or_problem_detail @@ -602,6 +743,7 @@ def lane_hide(lane_identifier): def reset_lanes(): return app.manager.admin_lanes_controller.reset() + @library_route("/admin/lanes/change_order", methods=["POST"]) @has_library @returns_json_or_response_or_problem_detail @@ -610,39 +752,49 @@ def reset_lanes(): def change_lane_order(): return app.manager.admin_lanes_controller.change_order() + @app.route("/admin/diagnostics") @requires_admin @returns_json_or_response_or_problem_detail def diagnostics(): return app.manager.timestamps_controller.diagnostics() -@app.route('/admin/sign_in_again') + +@app.route("/admin/sign_in_again") def admin_sign_in_again(): """Allows an admin with expired credentials to sign back in from a new browser tab so they won't lose changes. """ admin = app.manager.admin_sign_in_controller.authenticated_admin_from_request() csrf_token = app.manager.admin_sign_in_controller.get_csrf_token() - if isinstance(admin, ProblemDetail) or csrf_token is None or isinstance(csrf_token, ProblemDetail): + if ( + isinstance(admin, ProblemDetail) + or csrf_token is None + or isinstance(csrf_token, ProblemDetail) + ): redirect_url = flask.request.url - return redirect(app.manager.url_for('admin_sign_in', redirect=redirect_url)) + return redirect(app.manager.url_for("admin_sign_in", redirect=redirect_url)) return flask.render_template_string(sign_in_again_template) -@app.route('/admin/web/', strict_slashes=False) -@app.route('/admin/web/collection//book/') -@app.route('/admin/web/collection/') -@app.route('/admin/web/book/') -@app.route('/admin/web/') # catchall for single-page URLs + +@app.route("/admin/web/", strict_slashes=False) +@app.route("/admin/web/collection//book/") +@app.route("/admin/web/collection/") +@app.route("/admin/web/book/") +@app.route("/admin/web/") # catchall for single-page URLs def admin_view(collection=None, book=None, etc=None, **kwargs): return app.manager.admin_view_controller(collection, book, path=etc) -@app.route('/admin/', strict_slashes=False) + +@app.route("/admin/", strict_slashes=False) def admin_base(**kwargs): - return redirect(app.manager.url_for('admin_view')) + return redirect(app.manager.url_for("admin_view")) + # This path is used only in debug mode to serve frontend assets. -@app.route('/admin/static/') +@app.route("/admin/static/") @returns_problem_detail def admin_static_file(filename): - return app.manager.static_files.static_file(AdminClientConfig.static_files_directory(), filename) - + return app.manager.static_files.static_file( + AdminClientConfig.static_files_directory(), filename + ) diff --git a/api/admin/template_styles.py b/api/admin/template_styles.py index 7189bbab94..87f82f16ec 100644 --- a/api/admin/template_styles.py +++ b/api/admin/template_styles.py @@ -15,9 +15,12 @@ font-weight: 700; """ -error_style = body_style + """ +error_style = ( + body_style + + """ border-color: #D0343A; """ +) input_style = """ border-radius: .25em; display: block; @@ -65,10 +68,13 @@ margin: 2vh auto; """ -small_link_style = link_style + """ +small_link_style = ( + link_style + + """ width: 5vw; margin-bottom: 0; """ +) hr_style = """ width: 10vw; diff --git a/api/admin/validator.py b/api/admin/validator.py index f3d14257e4..50f4dc42cf 100644 --- a/api/admin/validator.py +++ b/api/admin/validator.py @@ -10,7 +10,6 @@ class Validator(object): - def validate(self, settings, content): validators = [ self.validate_email, @@ -25,7 +24,9 @@ def validate(self, settings, content): if error: return error - def _extract_inputs(self, settings, value, form, key="format", is_list=False, should_zip=False): + def _extract_inputs( + self, settings, value, form, key="format", is_list=False, should_zip=False + ): if not (isinstance(settings, list)): return [] @@ -62,7 +63,9 @@ def validate_email(self, settings, content): emails = [emails] for email in emails: if not self._is_email(email): - return INVALID_EMAIL.detailed(_('"%(email)s" is not a valid email address.', email=email)) + return INVALID_EMAIL.detailed( + _('"%(email)s" is not a valid email address.', email=email) + ) def _is_email(self, email): """Email addresses must be in the format 'x@y.z'.""" @@ -73,7 +76,9 @@ def validate_url(self, settings, content): """Find any URLs that the user has submitted, and make sure that they are in a valid format.""" # Find the fields that have to do with URLs and are not blank. - url_inputs = self._extract_inputs(settings, "url", content.get("form"), should_zip=True) + url_inputs = self._extract_inputs( + settings, "url", content.get("form"), should_zip=True + ) for field, urls in url_inputs: if not isinstance(urls, list): @@ -83,20 +88,26 @@ def validate_url(self, settings, content): # for example, the patron web client URL can be set to "*". allowed = field.get("allowed") or [] if not self._is_url(url, allowed): - return INVALID_URL.detailed(_('"%(url)s" is not a valid URL.', url=url)) + return INVALID_URL.detailed( + _('"%(url)s" is not a valid URL.', url=url) + ) @classmethod def _is_url(cls, url, allowed): if not url: return False - has_protocol = any([url.startswith(protocol + "://") for protocol in ("http", "https")]) + has_protocol = any( + [url.startswith(protocol + "://") for protocol in ("http", "https")] + ) return has_protocol or (url in allowed) def validate_number(self, settings, content): """Find any numbers that the user has submitted, and make sure that they are 1) actually numbers, 2) positive, and 3) lower than the specified maximum, if there is one.""" # Find the fields that should have numeric input and are not blank. - number_inputs = self._extract_inputs(settings, "number", content.get("form"), key="type", should_zip=True) + number_inputs = self._extract_inputs( + settings, "number", content.get("form"), key="type", should_zip=True + ) for field, number in number_inputs: error = self._number_error(field, number) if error: @@ -109,19 +120,37 @@ def _number_error(self, field, number): try: number = float(number) except ValueError: - return INVALID_NUMBER.detailed(_('"%(number)s" is not a number.', number=number)) + return INVALID_NUMBER.detailed( + _('"%(number)s" is not a number.', number=number) + ) if number < min: - return INVALID_NUMBER.detailed(_('%(field)s must be greater than %(min)s.', field=field.get("label"), min=min)) + return INVALID_NUMBER.detailed( + _( + "%(field)s must be greater than %(min)s.", + field=field.get("label"), + min=min, + ) + ) if max and number > max: - return INVALID_NUMBER.detailed(_('%(field)s cannot be greater than %(max)s.', field=field.get("label"), max=max)) + return INVALID_NUMBER.detailed( + _( + "%(field)s cannot be greater than %(max)s.", + field=field.get("label"), + max=max, + ) + ) def validate_language_code(self, settings, content): # Find the fields that should contain language codes and are not blank. - language_inputs = self._extract_inputs(settings, "language-code", content.get("form"), is_list=True) + language_inputs = self._extract_inputs( + settings, "language-code", content.get("form"), is_list=True + ) for language in language_inputs: if not self._is_language(language): - return UNKNOWN_LANGUAGE.detailed(_('"%(language)s" is not a valid language code.', language=language)) + return UNKNOWN_LANGUAGE.detailed( + _('"%(language)s" is not a valid language code.', language=language) + ) def _is_language(self, language): # Check that the input string is in the list of recognized language codes. @@ -131,19 +160,28 @@ def validate_image(self, settings, content): # Find the fields that contain image uploads and are not blank. files = content.get("files") if files: - image_inputs = self._extract_inputs(settings, "image", files, key="type", should_zip=True) + image_inputs = self._extract_inputs( + settings, "image", files, key="type", should_zip=True + ) for setting, image in image_inputs: invalid_format = self._image_format_error(image) if invalid_format: - return INVALID_CONFIGURATION_OPTION.detailed(_( - "Upload for %(setting)s must be in GIF, PNG, or JPG format. (Upload was %(format)s.)", - setting=setting.get("label"), - format=invalid_format)) + return INVALID_CONFIGURATION_OPTION.detailed( + _( + "Upload for %(setting)s must be in GIF, PNG, or JPG format. (Upload was %(format)s.)", + setting=setting.get("label"), + format=invalid_format, + ) + ) def _image_format_error(self, image_file): # Check that the uploaded image is in an acceptable format. - allowed_types = [Representation.JPEG_MEDIA_TYPE, Representation.PNG_MEDIA_TYPE, Representation.GIF_MEDIA_TYPE] + allowed_types = [ + Representation.JPEG_MEDIA_TYPE, + Representation.PNG_MEDIA_TYPE, + Representation.GIF_MEDIA_TYPE, + ] image_type = image_file.headers.get("Content-Type") if not image_type in allowed_types: return image_type @@ -170,8 +208,8 @@ def _value(self, field, form): class PatronAuthenticationValidatorFactory(object): """Creates Validator instances for particular authentication protocols""" - VALIDATOR_CLASS_NAME = 'Validator' - VALIDATOR_FACTORY = 'validator_factory' + VALIDATOR_CLASS_NAME = "Validator" + VALIDATOR_FACTORY = "validator_factory" def __init__(self): """Initializes a new instance of ValidatorFactory class""" @@ -221,6 +259,12 @@ def create(self, protocol): if validator: return validator except: - self._logger.warning(_('Could not load a validator defined in {0} module'.format(module_name))) + self._logger.warning( + _( + "Could not load a validator defined in {0} module".format( + module_name + ) + ) + ) return None diff --git a/api/adobe_vendor_id.py b/api/adobe_vendor_id.py index 94a33202ce..c756873b13 100644 --- a/api/adobe_vendor_id.py +++ b/api/adobe_vendor_id.py @@ -1,34 +1,22 @@ import argparse +import base64 +import datetime import json import logging -import uuid -import base64 import os -import datetime -import jwt -from jwt.algorithms import HMACAlgorithm import sys +import uuid + import flask +import jwt from flask import Response from flask_babel import lazy_gettext as _ -from .config import ( - CannotLoadConfiguration, - Configuration, -) +from jwt.algorithms import HMACAlgorithm +from sqlalchemy.orm.session import Session from api.base_controller import BaseCirculationManagerController -from .problem_details import * -from sqlalchemy.orm.session import Session -from core.util.datetime_helpers import ( - datetime_utc, - utc_now, -) -from core.util.xmlparser import XMLParser -from core.util.problem_detail import ProblemDetail from core.app_server import url_for from core.model import ( - create, - get_one, ConfigurationSetting, Credential, DataSource, @@ -36,14 +24,24 @@ ExternalIntegration, Library, Patron, + create, + get_one, ) from core.scripts import Script +from core.util.datetime_helpers import datetime_utc, utc_now +from core.util.problem_detail import ProblemDetail +from core.util.xmlparser import XMLParser + +from .config import CannotLoadConfiguration, Configuration +from .problem_details import * + class AdobeVendorIDController(object): """Flask controllers that implement the Account Service and Authorization Service portions of the Adobe Vendor ID protocol. """ + def __init__(self, _db, library, vendor_id, node_value, authenticator): self._db = _db self.library = library @@ -66,15 +64,16 @@ def signin_handler(self): """Process an incoming signInRequest document.""" __transaction = self._db.begin_nested() output = self.request_handler.handle_signin_request( - flask.request.data, self.model.standard_lookup, - self.model.authdata_lookup) + flask.request.data, self.model.standard_lookup, self.model.authdata_lookup + ) __transaction.commit() return Response(output, 200, {"Content-Type": "application/xml"}) def userinfo_handler(self): """Process an incoming userInfoRequest document.""" output = self.request_handler.handle_accountinfo_request( - flask.request.data, self.model.urn_to_label) + flask.request.data, self.model.urn_to_label + ) return Response(output, 200, {"Content-Type": "application/xml"}) def status_handler(self): @@ -86,8 +85,9 @@ class DeviceManagementProtocolController(BaseCirculationManagerController): The code that does the actual work is in DeviceManagementRequestHandler. """ + DEVICE_ID_LIST_MEDIA_TYPE = "vnd.librarysimplified/drm-device-id-list" - PLAIN_TEXT_HEADERS = {"Content-Type" : "text/plain"} + PLAIN_TEXT_HEADERS = {"Content-Type": "text/plain"} @property def link_template_header(self): @@ -95,7 +95,12 @@ def link_template_header(self): a specific DRM device ID. """ library = flask.request.library - url = url_for("adobe_drm_device", library_short_name=library.short_name, device_id="{id}", _external=True) + url = url_for( + "adobe_drm_device", + library_short_name=library.short_name, + device_id="{id}", + _external=True, + ) # The curly brackets in {id} were escaped. Un-escape them to # get a Link Template. url = url.replace("%7Bid%7D", "{id}") @@ -122,37 +127,34 @@ def device_id_list_handler(self): return handler device_ids = self.DEVICE_ID_LIST_MEDIA_TYPE - if flask.request.method=='GET': + if flask.request.method == "GET": # Serve a list of device IDs. output = handler.device_list() if isinstance(output, ProblemDetail): return output headers = self.link_template_header - headers['Content-Type'] = device_ids + headers["Content-Type"] = device_ids return Response(output, 200, headers) - elif flask.request.method=='POST': + elif flask.request.method == "POST": # Add a device ID to the list. - incoming_media_type = flask.request.headers.get('Content-Type') + incoming_media_type = flask.request.headers.get("Content-Type") if incoming_media_type != device_ids: return UNSUPPORTED_MEDIA_TYPE.detailed( - _("Expected %(media_type)s document.", - media_type=device_ids) + _("Expected %(media_type)s document.", media_type=device_ids) ) output = handler.register_device(flask.request.get_data(as_text=True)) if isinstance(output, ProblemDetail): return output return Response(output, 200, self.PLAIN_TEXT_HEADERS) - return METHOD_NOT_ALLOWED.detailed( - _("Only GET and POST are supported.") - ) + return METHOD_NOT_ALLOWED.detailed(_("Only GET and POST are supported.")) def device_id_handler(self, device_id): """Manage one of the device IDs associated with an Adobe ID.""" - handler = self._request_handler(getattr(flask.request, 'patron', None)) + handler = self._request_handler(getattr(flask.request, "patron", None)) if isinstance(handler, ProblemDetail): return handler - if flask.request.method != 'DELETE': + if flask.request.method != "DELETE": return METHOD_NOT_ALLOWED.detailed(_("Only DELETE is supported.")) # Delete the specified device ID. @@ -182,8 +184,8 @@ class AdobeVendorIDRequestHandler(object): ERROR_RESPONSE_TEMPLATE = '' - TOKEN_FAILURE = 'Incorrect token.' - AUTHENTICATION_FAILURE = 'Incorrect barcode or PIN.' + TOKEN_FAILURE = "Incorrect token." + AUTHENTICATION_FAILURE = "Incorrect barcode or PIN." URN_LOOKUP_FAILURE = "Could not identify patron from '%s'." def __init__(self, vendor_id): @@ -199,22 +201,21 @@ def handle_signin_request(self, data, standard_lookup, authdata_lookup): user_id = label = None if not data: return self.error_document( - self.AUTH_ERROR_TYPE, "Request document in wrong format.") - if not 'method' in data: - return self.error_document( - self.AUTH_ERROR_TYPE, "No method specified") - if data['method'] == parser.STANDARD: + self.AUTH_ERROR_TYPE, "Request document in wrong format." + ) + if not "method" in data: + return self.error_document(self.AUTH_ERROR_TYPE, "No method specified") + if data["method"] == parser.STANDARD: user_id, label = standard_lookup(data) failure = self.AUTHENTICATION_FAILURE - elif data['method'] == parser.AUTH_DATA: + elif data["method"] == parser.AUTH_DATA: authdata = data[parser.AUTH_DATA] user_id, label = authdata_lookup(authdata) failure = self.TOKEN_FAILURE if user_id is None: return self.error_document(self.AUTH_ERROR_TYPE, failure) else: - return self.SIGN_IN_RESPONSE_TEMPLATE % dict( - user=user_id, label=label) + return self.SIGN_IN_RESPONSE_TEMPLATE % dict(user=user_id, label=label) def handle_accountinfo_request(self, data, urn_to_label): parser = AdobeAccountInfoRequestParser() @@ -223,28 +224,28 @@ def handle_accountinfo_request(self, data, urn_to_label): data = parser.process(data) if not data: return self.error_document( - self.ACCOUNT_INFO_ERROR_TYPE, - "Request document in wrong format.") - if not 'user' in data: + self.ACCOUNT_INFO_ERROR_TYPE, "Request document in wrong format." + ) + if not "user" in data: return self.error_document( self.ACCOUNT_INFO_ERROR_TYPE, - "Could not find user identifer in request document.") - label = urn_to_label(data['user']) + "Could not find user identifer in request document.", + ) + label = urn_to_label(data["user"]) except Exception as e: - return self.error_document( - self.ACCOUNT_INFO_ERROR_TYPE, str(e)) + return self.error_document(self.ACCOUNT_INFO_ERROR_TYPE, str(e)) if label: return self.ACCOUNT_INFO_RESPONSE_TEMPLATE % dict(label=label) else: return self.error_document( - self.ACCOUNT_INFO_ERROR_TYPE, - self.URN_LOOKUP_FAILURE % data['user'] + self.ACCOUNT_INFO_ERROR_TYPE, self.URN_LOOKUP_FAILURE % data["user"] ) def error_document(self, type, message): return self.ERROR_RESPONSE_TEMPLATE % dict( - vendor_id=self.vendor_id, type=type, message=message) + vendor_id=self.vendor_id, type=type, message=message + ) class DeviceManagementRequestHandler(object): @@ -255,10 +256,7 @@ def __init__(self, credential): def device_list(self): return "\n".join( - sorted( - x.device_identifier - for x in self.credential.drm_device_identifiers - ) + sorted(x.device_identifier for x in self.credential.drm_device_identifiers) ) def register_device(self, data): @@ -270,20 +268,19 @@ def register_device(self, data): for device_id in device_ids: if device_id: self.credential.register_drm_device_identifier(device_id) - return 'Success' + return "Success" def deregister_device(self, device_id): self.credential.deregister_drm_device_identifier(device_id) - return 'Success' + return "Success" class AdobeRequestParser(XMLParser): - NAMESPACES = { "adept" : "http://ns.adobe.com/adept" } + NAMESPACES = {"adept": "http://ns.adobe.com/adept"} def process(self, data): - requests = list(self.process_all( - data, self.REQUEST_XPATH, self.NAMESPACES)) + requests = list(self.process_all(data, self.REQUEST_XPATH, self.NAMESPACES)) if not requests: return None # There should only be one request tag, but if there's more than @@ -291,7 +288,7 @@ def process(self, data): return requests[0] def _add(self, d, tag, key, namespaces, transform=None): - v = self._xpath1(tag, 'adept:' + key, namespaces) + v = self._xpath1(tag, "adept:" + key, namespaces) if v is not None: v = v.text if v is not None: @@ -302,36 +299,38 @@ def _add(self, d, tag, key, namespaces, transform=None): v = v.decode("utf-8") d[key] = v + class AdobeSignInRequestParser(AdobeRequestParser): REQUEST_XPATH = "/adept:signInRequest" - STANDARD = 'standard' - AUTH_DATA = 'authData' + STANDARD = "standard" + AUTH_DATA = "authData" def process_one(self, tag, namespaces): - method = tag.attrib.get('method') + method = tag.attrib.get("method") if not method: raise ValueError("No signin method specified") data = dict(method=method) if method == self.STANDARD: - self._add(data, tag, 'username', namespaces) - self._add(data, tag, 'password', namespaces) + self._add(data, tag, "username", namespaces) + self._add(data, tag, "password", namespaces) elif method == self.AUTH_DATA: self._add(data, tag, self.AUTH_DATA, namespaces, base64.b64decode) else: raise ValueError("Unknown signin method: %s" % method) return data + class AdobeAccountInfoRequestParser(AdobeRequestParser): REQUEST_XPATH = "/adept:accountInfoRequest" def process_one(self, tag, namespaces): - method = tag.attrib.get('method') + method = tag.attrib.get("method") data = dict(method=method) - self._add(data, tag, 'user', namespaces) + self._add(data, tag, "user", namespaces) return data @@ -344,13 +343,15 @@ class AdobeVendorIDModel(object): AUTHDATA_TOKEN_TYPE = "Authdata for Adobe Vendor ID" VENDOR_ID_UUID_TOKEN_TYPE = "Vendor ID UUID" - def __init__(self, _db, library, authenticator, node_value, - temporary_token_duration=None): + def __init__( + self, _db, library, authenticator, node_value, temporary_token_duration=None + ): self.library = library self._db = _db self.authenticator = authenticator - self.temporary_token_duration = ( - temporary_token_duration or datetime.timedelta(minutes=10)) + self.temporary_token_duration = temporary_token_duration or datetime.timedelta( + minutes=10 + ) if isinstance(node_value, (bytes, str)): node_value = int(node_value, 16) self.node_value = node_value @@ -378,8 +379,8 @@ def uuid_and_label(self, patron): # First, find or create a Credential containing the patron's # anonymized key into the DelegatedPatronIdentifier database. - adobe_account_id_patron_identifier_credential = self.get_or_create_patron_identifier_credential( - patron + adobe_account_id_patron_identifier_credential = ( + self.get_or_create_patron_identifier_credential(patron) ) # Look up a Credential containing the patron's Adobe account @@ -387,8 +388,11 @@ def uuid_and_label(self, patron): # Credential.lookup because we don't want to create a # Credential if it doesn't exist. old_style_adobe_account_id_credential = get_one( - self._db, Credential, patron=patron, data_source=self.data_source, - type=self.VENDOR_ID_UUID_TOKEN_TYPE + self._db, + Credential, + patron=patron, + data_source=self.data_source, + type=self.VENDOR_ID_UUID_TOKEN_TYPE, ) if old_style_adobe_account_id_credential: @@ -397,6 +401,7 @@ def uuid_and_label(self, patron): # we have to create one. def new_value(): return old_style_adobe_account_id_credential.credential + else: # There is no old-style credential. If we have to create a # new DelegatedPatronIdentifier we will give it a value @@ -407,14 +412,14 @@ def new_value(): # anonymized patron identifier we just looked up or created. utility = AuthdataUtility.from_config(patron.library, self._db) return self.to_delegated_patron_identifier_uuid( - utility.library_uri, adobe_account_id_patron_identifier_credential.credential, - value_generator=new_value + utility.library_uri, + adobe_account_id_patron_identifier_credential.credential, + value_generator=new_value, ) def create_authdata(self, patron): credential, is_new = Credential.persistent_token_create( - self._db, self.data_source, self.AUTHDATA_TOKEN_TYPE, - patron + self._db, self.data_source, self.AUTHDATA_TOKEN_TYPE, patron ) return credential @@ -423,14 +428,14 @@ def standard_lookup(self, authorization_data): UUID and their human-readable label, creating a Credential object to hold the UUID if necessary. """ - username = authorization_data.get('username') - password = authorization_data.get('password') + username = authorization_data.get("username") + password = authorization_data.get("password") if username and not password: # The absence of a password indicates the username might # be a persistent authdata token smuggled to get around a # broken Adobe client-side API. Try treating the # 'username' as a token. - possible_authdata_token = authorization_data['username'] + possible_authdata_token = authorization_data["username"] return self.authdata_lookup(possible_authdata_token) if username and password: @@ -443,9 +448,7 @@ def standard_lookup(self, authorization_data): # Last ditch effort: try a normal username/password lookup. # This should almost never be used. - patron = self.authenticator.authenticated_patron( - self._db, authorization_data - ) + patron = self.authenticator.authenticated_patron(self._db, authorization_data) return self.uuid_and_label(patron) def authdata_lookup(self, authdata): @@ -470,9 +473,7 @@ def authdata_lookup(self, authdata): # Hopefully this is an authdata JWT generated by another # library's circulation manager. try: - library_uri, foreign_patron_identifier = utility.decode( - authdata - ) + library_uri, foreign_patron_identifier = utility.decode(authdata) except Exception as e: # Not a problem -- we'll try the old system. pass @@ -507,7 +508,10 @@ def short_client_token_lookup(self, token, signature): # Hopefully this is a short client token generated by # another library's circulation manager. try: - library_uri, foreign_patron_identifier = utility.decode_two_part_short_client_token(token, signature) + ( + library_uri, + foreign_patron_identifier, + ) = utility.decode_two_part_short_client_token(token, signature) except Exception as e: # This didn't work--either the incoming data was wrong # or this technique wasn't the right one to use. @@ -529,7 +533,7 @@ def short_client_token_lookup(self, token, signature): return uuid_and_label def to_delegated_patron_identifier_uuid( - self, library_uri, foreign_patron_identifier, value_generator=None + self, library_uri, foreign_patron_identifier, value_generator=None ): """Create or lookup a DelegatedPatronIdentifier containing an Adobe account ID for the given library and foreign patron ID. @@ -540,20 +544,28 @@ def to_delegated_patron_identifier_uuid( return None, None value_generator = value_generator or self.uuid identifier, is_new = DelegatedPatronIdentifier.get_one_or_create( - self._db, library_uri, foreign_patron_identifier, - DelegatedPatronIdentifier.ADOBE_ACCOUNT_ID, value_generator + self._db, + library_uri, + foreign_patron_identifier, + DelegatedPatronIdentifier.ADOBE_ACCOUNT_ID, + value_generator, ) if identifier is None: return None, None - return (identifier.delegated_identifier, - self.urn_to_label(identifier.delegated_identifier)) + return ( + identifier.delegated_identifier, + self.urn_to_label(identifier.delegated_identifier), + ) def patron_from_authdata_lookup(self, authdata): """Look up a patron by their persistent authdata token.""" credential = Credential.lookup_by_token( - self._db, self.data_source, self.AUTHDATA_TOKEN_TYPE, - authdata, allow_persistent_token=True + self._db, + self.data_source, + self.AUTHDATA_TOKEN_TYPE, + authdata, + allow_persistent_token=True, ) if not credential: return None @@ -575,13 +587,18 @@ def uuid(self): @classmethod def get_or_create_patron_identifier_credential(cls, patron): _db = Session.object_session(patron) + def refresh(credential): credential.credential = str(uuid.uuid1()) + data_source = DataSource.lookup(_db, DataSource.INTERNAL_PROCESSING) patron_identifier_credential = Credential.lookup( - _db, data_source, + _db, + data_source, AuthdataUtility.ADOBE_ACCOUNT_ID_PATRON_IDENTIFIER, - patron, refresher_method=refresh, allow_persistent_token=True + patron, + refresher_method=refresh, + allow_persistent_token=True, ) return patron_identifier_credential @@ -607,10 +624,11 @@ class AuthdataUtility(object): # consequences other than losing their currently checked-in books. ADOBE_ACCOUNT_ID_PATRON_IDENTIFIER = "Identifier for Adobe account ID purposes" - ALGORITHM = 'HS256' + ALGORITHM = "HS256" - def __init__(self, vendor_id, library_uri, library_short_name, secret, - other_libraries={}): + def __init__( + self, vendor_id, library_uri, library_short_name, secret, other_libraries={} + ): """Basic constructor. :param vendor_id: The Adobe Vendor ID that should accompany authdata @@ -664,21 +682,17 @@ def __init__(self, vendor_id, library_uri, library_short_name, secret, if short_name in self.library_uris_by_short_name: # This can happen if the same library is in the list # twice, capitalized differently. - raise ValueError( - "Duplicate short name: %s" % short_name - ) + raise ValueError("Duplicate short name: %s" % short_name) self.library_uris_by_short_name[short_name] = uri self.secrets_by_library_uri[uri] = secret self.log = logging.getLogger("Adobe authdata utility") self.short_token_signer = HMACAlgorithm(HMACAlgorithm.SHA256) - self.short_token_signing_key = self.short_token_signer.prepare_key( - self.secret - ) + self.short_token_signing_key = self.short_token_signer.prepare_key(self.secret) - VENDOR_ID_KEY = 'vendor_id' - OTHER_LIBRARIES_KEY = 'other_libraries' + VENDOR_ID_KEY = "vendor_id" + OTHER_LIBRARIES_KEY = "other_libraries" @classmethod def from_config(cls, library, _db=None): @@ -698,26 +712,28 @@ def from_config(cls, library, _db=None): library = _db.merge(library, load=False) # Try to find an external integration with a configured Vendor ID. - integrations = _db.query( - ExternalIntegration - ).outerjoin( - ExternalIntegration.libraries - ).filter( - ExternalIntegration.protocol==ExternalIntegration.OPDS_REGISTRATION, - ExternalIntegration.goal==ExternalIntegration.DISCOVERY_GOAL, - Library.id==library.id + integrations = ( + _db.query(ExternalIntegration) + .outerjoin(ExternalIntegration.libraries) + .filter( + ExternalIntegration.protocol == ExternalIntegration.OPDS_REGISTRATION, + ExternalIntegration.goal == ExternalIntegration.DISCOVERY_GOAL, + Library.id == library.id, + ) ) integration = None for possible_integration in integrations: vendor_id = ConfigurationSetting.for_externalintegration( - cls.VENDOR_ID_KEY, possible_integration).value + cls.VENDOR_ID_KEY, possible_integration + ).value if vendor_id: integration = possible_integration break library_uri = ConfigurationSetting.for_library( - Configuration.WEBSITE_URL, library).value + Configuration.WEBSITE_URL, library + ).value if not integration: return None @@ -732,29 +748,29 @@ def from_config(cls, library, _db=None): other_libraries = None adobe_integration = ExternalIntegration.lookup( - _db, ExternalIntegration.ADOBE_VENDOR_ID, - ExternalIntegration.DRM_GOAL, library=library + _db, + ExternalIntegration.ADOBE_VENDOR_ID, + ExternalIntegration.DRM_GOAL, + library=library, ) if adobe_integration: - other_libraries = adobe_integration.setting(cls.OTHER_LIBRARIES_KEY).json_value + other_libraries = adobe_integration.setting( + cls.OTHER_LIBRARIES_KEY + ).json_value other_libraries = other_libraries or dict() - if (not vendor_id or not library_uri - or not library_short_name or not secret - ): + if not vendor_id or not library_uri or not library_short_name or not secret: raise CannotLoadConfiguration( "Short Client Token configuration is incomplete. " "vendor_id (%s), username (%s), password (%s) and " - "Library website_url (%s) must all be defined." % ( - vendor_id, library_uri, library_short_name, secret - ) + "Library website_url (%s) must all be defined." + % (vendor_id, library_uri, library_short_name, secret) ) - if '|' in library_short_name: + if "|" in library_short_name: raise CannotLoadConfiguration( "Library short name cannot contain the pipe character." ) - return cls(vendor_id, library_uri, library_short_name, secret, - other_libraries) + return cls(vendor_id, library_uri, library_short_name, secret, other_libraries) @classmethod def adobe_relevant_credentials(self, patron): @@ -768,12 +784,15 @@ def adobe_relevant_credentials(self, patron): :return: A SQLAlchemy query """ _db = Session.object_session(patron) - types = (AdobeVendorIDModel.VENDOR_ID_UUID_TOKEN_TYPE, - AuthdataUtility.ADOBE_ACCOUNT_ID_PATRON_IDENTIFIER) - return _db.query( - Credential).filter(Credential.patron==patron).filter( - Credential.type.in_(types) - ) + types = ( + AdobeVendorIDModel.VENDOR_ID_UUID_TOKEN_TYPE, + AuthdataUtility.ADOBE_ACCOUNT_ID_PATRON_IDENTIFIER, + ) + return ( + _db.query(Credential) + .filter(Credential.patron == patron) + .filter(Credential.type.in_(types)) + ) def encode(self, patron_identifier): """Generate an authdata JWT suitable for putting in an OPDS feed, where @@ -786,20 +805,18 @@ def encode(self, patron_identifier): raise ValueError("No patron identifier specified") now = utc_now() expires = now + datetime.timedelta(minutes=60) - authdata = self._encode( - self.library_uri, patron_identifier, now, expires - ) + authdata = self._encode(self.library_uri, patron_identifier, now, expires) return self.vendor_id, authdata def _encode(self, iss=None, sub=None, iat=None, exp=None): """Helper method split out separately for use in tests.""" - payload = dict(iss=iss) # Issuer + payload = dict(iss=iss) # Issuer if sub: - payload['sub'] = sub # Subject + payload["sub"] = sub # Subject if iat: - payload['iat'] = self.numericdate(iat) # Issued At + payload["iat"] = self.numericdate(iat) # Issued At if exp: - payload['exp'] = self.numericdate(exp) # Expiration Time + payload["exp"] = self.numericdate(exp) # Expiration Time return base64.encodebytes( jwt.encode(payload, self.secret, algorithm=self.ALGORITHM) ) @@ -862,26 +879,23 @@ def decode(self, authdata): def _decode(self, authdata): # First, decode the authdata without checking the signature. decoded = jwt.decode( - authdata, algorithm=self.ALGORITHM, - options=dict(verify_signature=False) + authdata, algorithm=self.ALGORITHM, options=dict(verify_signature=False) ) # This lets us get the library URI, which lets us get the secret. - library_uri = decoded.get('iss') + library_uri = decoded.get("iss") if not library_uri in self.secrets_by_library_uri: # The request came in without a library specified # or with an unknown library specified. - raise jwt.exceptions.DecodeError( - "Unknown library: %s" % library_uri - ) + raise jwt.exceptions.DecodeError("Unknown library: %s" % library_uri) # We know the secret for this library, so we can re-decode the # secret and require signature valudation this time. secret = self.secrets_by_library_uri[library_uri] decoded = jwt.decode(authdata, secret, algorithm=self.ALGORITHM) - if not 'sub' in decoded: + if not "sub" in decoded: raise jwt.exceptions.DecodeError("No subject specified.") - return library_uri, decoded['sub'] + return library_uri, decoded["sub"] @classmethod def _adobe_patron_identifier(cls, patron): @@ -891,21 +905,24 @@ def _adobe_patron_identifier(cls, patron): def refresh(credential): credential.credential = str(uuid.uuid1()) + patron_identifier = Credential.lookup( - _db, internal, AuthdataUtility.ADOBE_ACCOUNT_ID_PATRON_IDENTIFIER, patron, - refresher_method=refresh, allow_persistent_token=True + _db, + internal, + AuthdataUtility.ADOBE_ACCOUNT_ID_PATRON_IDENTIFIER, + patron, + refresher_method=refresh, + allow_persistent_token=True, ) return patron_identifier.credential def short_client_token_for_patron(self, patron_information): """Generate short client token for patron, or for a patron's identifier - for Adobe ID purposes""" + for Adobe ID purposes""" if isinstance(patron_information, Patron): # Find the patron's identifier for Adobe ID purposes. - patron_identifier = self._adobe_patron_identifier( - patron_information - ) + patron_identifier = self._adobe_patron_identifier(patron_information) else: patron_identifier = patron_information @@ -924,7 +941,7 @@ def encode_short_client_token(self, patron_identifier, expires=None): :return: A 2-tuple (vendor ID, token) """ if expires is None: - expires = {'minutes': 60} + expires = {"minutes": 60} if not patron_identifier: raise ValueError("No patron identifier specified") expires = int(self.numericdate(self._now() + datetime.timedelta(**expires))) @@ -933,8 +950,9 @@ def encode_short_client_token(self, patron_identifier, expires=None): ) return self.vendor_id, authdata - def _encode_short_client_token(self, library_short_name, - patron_identifier, expires): + def _encode_short_client_token( + self, library_short_name, patron_identifier, expires + ): base = library_short_name + "|" + str(expires) + "|" + patron_identifier signature = self.short_token_signer.sign( base.encode("utf-8"), self.short_token_signing_key @@ -958,12 +976,12 @@ def decode_short_client_token(self, token): :raise ValueError: When the token is not valid for any reason. """ - if not '|' in token: + if not "|" in token: raise ValueError( 'Supposed client token "%s" does not contain a pipe.' % token ) - username, password = token.rsplit('|', 1) + username, password = token.rsplit("|", 1) return self.decode_two_part_short_client_token(username, password) def decode_two_part_short_client_token(self, username, password): @@ -977,7 +995,7 @@ def _decode_short_client_token(self, token, supposed_signature): """Make sure a client token is properly formatted, correctly signed, and not expired. """ - if token.count('|') < 2: + if token.count("|") < 2: raise ValueError("Invalid client token: %s" % token) library_short_name, expiration, patron_identifier = token.split("|", 2) @@ -990,19 +1008,16 @@ def _decode_short_client_token(self, token, supposed_signature): # We don't police the content of the patron identifier but there # has to be _something_ there. if not patron_identifier: - raise ValueError( - "Token %s has empty patron identifier" % token - ) + raise ValueError("Token %s has empty patron identifier" % token) if not library_short_name in self.library_uris_by_short_name: raise ValueError( - "I don't know how to handle tokens from library \"%s\"" % library_short_name + 'I don\'t know how to handle tokens from library "%s"' + % library_short_name ) library_uri = self.library_uris_by_short_name[library_short_name] if not library_uri in self.secrets_by_library_uri: - raise ValueError( - "I don't know the secret for library %s" % library_uri - ) + raise ValueError("I don't know the secret for library %s" % library_uri) secret = self.secrets_by_library_uri[library_uri] # Don't bother checking an expired token. @@ -1010,9 +1025,7 @@ def _decode_short_client_token(self, token, supposed_signature): expiration = self.EPOCH + datetime.timedelta(seconds=expiration) if expiration < now: raise ValueError( - "Token %s expired at %s (now is %s)." % ( - token, expiration, now - ) + "Token %s expired at %s (now is %s)." % (token, expiration, now) ) # Sign the token and check against the provided signature. @@ -1020,9 +1033,7 @@ def _decode_short_client_token(self, token, supposed_signature): actual_signature = self.short_token_signer.sign(token.encode("utf-8"), key) if actual_signature != supposed_signature: - raise ValueError( - "Invalid signature for %s." % token - ) + raise ValueError("Invalid signature for %s." % token) return library_uri, patron_identifier @@ -1031,7 +1042,7 @@ def _decode_short_client_token(self, token, supposed_signature): @classmethod def numericdate(cls, d): """Turn a datetime object into a NumericDate as per RFC 7519.""" - return (d-cls.EPOCH).total_seconds() + return (d - cls.EPOCH).total_seconds() def migrate_adobe_id(self, patron): """If the given patron has an Adobe ID stored as a Credential, also @@ -1045,8 +1056,10 @@ def migrate_adobe_id(self, patron): _db = Session.object_session(patron) credential = get_one( - _db, Credential, - patron=patron, type=AdobeVendorIDModel.VENDOR_ID_UUID_TOKEN_TYPE + _db, + Credential, + patron=patron, + type=AdobeVendorIDModel.VENDOR_ID_UUID_TOKEN_TYPE, ) if not credential: # This patron has no Adobe ID. Do nothing. @@ -1054,8 +1067,8 @@ def migrate_adobe_id(self, patron): adobe_id = credential.credential # Create a new Credential containing an anonymized patron ID. - patron_identifier_credential = AdobeVendorIDModel.get_or_create_patron_identifier_credential( - patron + patron_identifier_credential = ( + AdobeVendorIDModel.get_or_create_patron_identifier_credential(patron) ) # Then create a DelegatedPatronIdentifier mapping that @@ -1066,29 +1079,32 @@ def create_function(): want to store it in the DPI. """ return adobe_id + delegated_identifier, is_new = DelegatedPatronIdentifier.get_one_or_create( - _db, self.library_uri, patron_identifier_credential.credential, - DelegatedPatronIdentifier.ADOBE_ACCOUNT_ID, create_function + _db, + self.library_uri, + patron_identifier_credential.credential, + DelegatedPatronIdentifier.ADOBE_ACCOUNT_ID, + create_function, ) return patron_identifier_credential, delegated_identifier class VendorIDLibraryConfigurationScript(Script): - @classmethod def arg_parser(cls): parser = argparse.ArgumentParser() parser.add_argument( - '--website-url', - help="The URL to this library's patron-facing website (not their circulation manager), e.g. \"https://nypl.org/\". This is used to uniquely identify a library." + "--website-url", + help='The URL to this library\'s patron-facing website (not their circulation manager), e.g. "https://nypl.org/". This is used to uniquely identify a library.', ) parser.add_argument( - '--short-name', - help="The short name the library will use in Short Client Tokens, e.g. \"NYNYPL\"." + "--short-name", + help='The short name the library will use in Short Client Tokens, e.g. "NYNYPL".', ) parser.add_argument( - '--secret', - help="The secret the library will use to sign Short Client Tokens." + "--secret", + help="The secret the library will use to sign Short Client Tokens.", ) return parser @@ -1098,19 +1114,19 @@ def do_run(self, _db=None, cmd_args=None, output=sys.stdout): default_library = Library.default(_db) adobe_integration = ExternalIntegration.lookup( - _db, ExternalIntegration.ADOBE_VENDOR_ID, - ExternalIntegration.DRM_GOAL, library=default_library + _db, + ExternalIntegration.ADOBE_VENDOR_ID, + ExternalIntegration.DRM_GOAL, + library=default_library, ) if not adobe_integration: output.write( - "Could not find an Adobe Vendor ID integration for default library %s.\n" % - default_library.short_name + "Could not find an Adobe Vendor ID integration for default library %s.\n" + % default_library.short_name ) return - setting = adobe_integration.setting( - AuthdataUtility.OTHER_LIBRARIES_KEY - ) + setting = adobe_integration.setting(AuthdataUtility.OTHER_LIBRARIES_KEY) other_libraries = setting.json_value chosen_website = args.website_url @@ -1119,12 +1135,14 @@ def do_run(self, _db=None, cmd_args=None, output=sys.stdout): self.explain(output, other_libraries, website) return - if (not args.short_name and not args.secret): + if not args.short_name and not args.secret: self.explain(output, other_libraries, chosen_website) return if not args.short_name or not args.secret: - output.write("To configure a library you must provide both --short_name and --secret.\n") + output.write( + "To configure a library you must provide both --short_name and --secret.\n" + ) return # All three arguments are specified. Set or modify the library's @@ -1134,9 +1152,8 @@ def do_run(self, _db=None, cmd_args=None, output=sys.stdout): else: what = "set" output.write( - "About to %s the Short Client Token configuration for %s.\n" % ( - what, chosen_website - ) + "About to %s the Short Client Token configuration for %s.\n" + % (what, chosen_website) ) if chosen_website in other_libraries: output.write("Old configuration:\n") @@ -1159,26 +1176,25 @@ def explain(self, output, libraries, website): class ShortClientTokenLibraryConfigurationScript(Script): - @classmethod def arg_parser(cls): parser = argparse.ArgumentParser() parser.add_argument( - '--website-url', - help="The URL to this library's patron-facing website (not their circulation manager), e.g. \"https://nypl.org/\". This is used to uniquely identify a library.", + "--website-url", + help='The URL to this library\'s patron-facing website (not their circulation manager), e.g. "https://nypl.org/". This is used to uniquely identify a library.', required=True, ) parser.add_argument( - '--vendor-id', + "--vendor-id", help="The name of the vendor ID the library will use. The default of 'NYPL' is probably what you want.", - default='NYPL' + default="NYPL", ) parser.add_argument( - '--short-name', - help="The short name the library will use in Short Client Tokens, e.g. \"NYBPL\".", + "--short-name", + help='The short name the library will use in Short Client Tokens, e.g. "NYBPL".', ) parser.add_argument( - '--secret', + "--secret", help="The secret the library will use to sign Short Client Tokens.", ) return parser @@ -1188,41 +1204,41 @@ def do_run(self, _db=None, cmd_args=None, output=sys.stdout): args = self.parse_command_line(self._db, cmd_args=cmd_args) self.set_secret( - _db, args.website_url, args.vendor_id, args.short_name, - args.secret, output + _db, args.website_url, args.vendor_id, args.short_name, args.secret, output ) _db.commit() - def set_secret(self, _db, website_url, vendor_id, short_name, - secret, output): + def set_secret(self, _db, website_url, vendor_id, short_name, secret, output): # Look up a library by its url setting. library_setting = get_one( - _db, ConfigurationSetting, + _db, + ConfigurationSetting, key=Configuration.WEBSITE_URL, value=website_url, ) if not library_setting: - available_urls = _db.query( - ConfigurationSetting - ).filter( - ConfigurationSetting.key==Configuration.WEBSITE_URL - ).filter( - ConfigurationSetting.library!=None + available_urls = ( + _db.query(ConfigurationSetting) + .filter(ConfigurationSetting.key == Configuration.WEBSITE_URL) + .filter(ConfigurationSetting.library != None) ) raise Exception( - "Could not locate library with URL %s. Available URLs: %s" % - (website_url, ",".join(x.value for x in available_urls)) + "Could not locate library with URL %s. Available URLs: %s" + % (website_url, ",".join(x.value for x in available_urls)) ) library = library_setting.library integration = ExternalIntegration.lookup( - _db, ExternalIntegration.OPDS_REGISTRATION, - ExternalIntegration.DISCOVERY_GOAL, library=library + _db, + ExternalIntegration.OPDS_REGISTRATION, + ExternalIntegration.DISCOVERY_GOAL, + library=library, ) if not integration: integration, ignore = create( - _db, ExternalIntegration, + _db, + ExternalIntegration, protocol=ExternalIntegration.OPDS_REGISTRATION, - goal=ExternalIntegration.DISCOVERY_GOAL + goal=ExternalIntegration.DISCOVERY_GOAL, ) library.integrations.append(integration) @@ -1239,10 +1255,7 @@ def set_secret(self, _db, website_url, vendor_id, short_name, username_s.value = short_name password_s.value = secret - output.write( - "Current Short Client Token configuration for %s:\n" - % website_url - ) + output.write("Current Short Client Token configuration for %s:\n" % website_url) output.write(" Vendor ID: %s\n" % vendor_id_s.value) output.write(" Library name: %s\n" % username_s.value) output.write(" Shared secret: %s\n" % password_s.value) diff --git a/api/annotations.py b/api/annotations.py index 0e6c714f9d..0d029a43c6 100644 --- a/api/annotations.py +++ b/api/annotations.py @@ -1,43 +1,36 @@ -from pyld import jsonld import json import os -from core.model import ( - Annotation, - Identifier, - get_one_or_create, -) +from pyld import jsonld -from core.app_server import ( - url_for, -) +from core.app_server import url_for +from core.model import Annotation, Identifier, get_one_or_create from core.util.datetime_helpers import utc_now from .problem_details import * + def load_document(url): """Retrieves JSON-LD for the given URL from a local file if available, and falls back to the network. """ files = { AnnotationWriter.JSONLD_CONTEXT: "anno.jsonld", - AnnotationWriter.LDP_CONTEXT: "ldp.jsonld" + AnnotationWriter.LDP_CONTEXT: "ldp.jsonld", } if url in files: - base_path = os.path.join(os.path.split(__file__)[0], 'jsonld') + base_path = os.path.join(os.path.split(__file__)[0], "jsonld") jsonld_file = os.path.join(base_path, files[url]) data = open(jsonld_file).read() - doc = { - "contextUrl": None, - "documentUrl": url, - "document": data - } + doc = {"contextUrl": None, "documentUrl": url, "document": data} return doc else: return jsonld.load_document(url) + jsonld.set_document_loader(load_document) + class AnnotationWriter(object): CONTENT_TYPE = 'application/ld+json; profile="http://www.w3.org/ns/anno.jsonld"' @@ -47,21 +40,33 @@ class AnnotationWriter(object): @classmethod def annotations_for(cls, patron, identifier=None): - annotations = [annotation for annotation in patron.annotations if annotation.active] + annotations = [ + annotation for annotation in patron.annotations if annotation.active + ] if identifier: - annotations = [annotation for annotation in annotations if annotation.identifier == identifier] + annotations = [ + annotation + for annotation in annotations + if annotation.identifier == identifier + ] return annotations @classmethod def annotation_container_for(cls, patron, identifier=None): if identifier: - url = url_for('annotations_for_work', - identifier_type=identifier.type, - identifier=identifier.identifier, - library_short_name=patron.library.short_name, - _external=True) + url = url_for( + "annotations_for_work", + identifier_type=identifier.type, + identifier=identifier.identifier, + library_short_name=patron.library.short_name, + _external=True, + ) else: - url = url_for("annotations", library_short_name=patron.library.short_name, _external=True) + url = url_for( + "annotations", + library_short_name=patron.library.short_name, + _external=True, + ) annotations = cls.annotations_for(patron, identifier=identifier) latest_timestamp = None @@ -75,22 +80,32 @@ def annotation_container_for(cls, patron, identifier=None): container["id"] = url container["type"] = ["BasicContainer", "AnnotationCollection"] container["total"] = len(annotations) - container["first"] = cls.annotation_page_for(patron, identifier=identifier, with_context=False) + container["first"] = cls.annotation_page_for( + patron, identifier=identifier, with_context=False + ) return container, latest_timestamp - @classmethod def annotation_page_for(cls, patron, identifier=None, with_context=True): if identifier: - url = url_for('annotations_for_work', - identifier_type=identifier.type, - identifier=identifier.identifier, - library_short_name=patron.library.short_name, - _external=True) + url = url_for( + "annotations_for_work", + identifier_type=identifier.type, + identifier=identifier.identifier, + library_short_name=patron.library.short_name, + _external=True, + ) else: - url = url_for("annotations", library_short_name=patron.library.short_name, _external=True) + url = url_for( + "annotations", + library_short_name=patron.library.short_name, + _external=True, + ) annotations = cls.annotations_for(patron, identifier=identifier) - details = [cls.detail(annotation, with_context=with_context) for annotation in annotations] + details = [ + cls.detail(annotation, with_context=with_context) + for annotation in annotations + ] page = dict() if with_context: @@ -105,9 +120,12 @@ def detail(cls, annotation, with_context=True): item = dict() if with_context: item["@context"] = cls.JSONLD_CONTEXT - item["id"] = url_for("annotation_detail", annotation_id=annotation.id, - library_short_name=annotation.patron.library.short_name, - _external=True) + item["id"] = url_for( + "annotation_detail", + annotation_id=annotation.id, + library_short_name=annotation.patron.library.short_name, + _external=True, + ) item["type"] = "Annotation" item["motivation"] = annotation.motivation item["body"] = annotation.content @@ -124,8 +142,8 @@ def detail(cls, annotation, with_context=True): return item -class AnnotationParser(object): +class AnnotationParser(object): @classmethod def parse(cls, _db, data, patron): if patron.synchronize_annotations != True: @@ -133,8 +151,8 @@ def parse(cls, _db, data, patron): try: data = json.loads(data) - if 'id' in data and data['id'] is None: - del data['id'] + if "id" in data and data["id"] is None: + del data["id"] data = jsonld.expand(data) except ValueError as e: return INVALID_ANNOTATION_FORMAT @@ -152,7 +170,7 @@ def parse(cls, _db, data, patron): if not source or not len(source) == 1: return INVALID_ANNOTATION_TARGET - source = source[0].get('@id') + source = source[0].get("@id") try: identifier, ignore = Identifier.parse_urn(_db, source) @@ -162,7 +180,7 @@ def parse(cls, _db, data, patron): motivation = data.get("http://www.w3.org/ns/oa#motivatedBy") if not motivation or not len(motivation) == 1: return INVALID_ANNOTATION_MOTIVATION - motivation = motivation[0].get('@id') + motivation = motivation[0].get("@id") if motivation not in Annotation.MOTIVATIONS: return INVALID_ANNOTATION_MOTIVATION @@ -185,11 +203,14 @@ def parse(cls, _db, data, patron): elif motivation == Annotation.BOOKMARKING: # A given book can only have one 'bookmarking' annotation # per target. - extra_kwargs['target'] = target + extra_kwargs["target"] = target annotation, ignore = Annotation.get_one_or_create( - _db, patron=patron, identifier=identifier, - motivation=motivation, on_multiple='interchangeable', + _db, + patron=patron, + identifier=identifier, + motivation=motivation, + on_multiple="interchangeable", **extra_kwargs ) annotation.target = target diff --git a/api/announcements.py b/api/announcements.py index 1a72b87a14..467e8f4c9c 100644 --- a/api/announcements.py +++ b/api/announcements.py @@ -4,12 +4,14 @@ from .admin.announcement_list_validator import AnnouncementListValidator + class Announcements(object): """Data model class for a library's announcements. This entire list is stored as a single ConfigurationSetting, which is why this isn't in core/model. """ + SETTING_NAME = "announcements" @classmethod @@ -49,6 +51,7 @@ def active(self): class Announcement(object): """Data model class for a single library-wide announcement.""" + def __init__(self, **kwargs): """Instantiate an Announcement from a dictionary of data. @@ -61,16 +64,17 @@ def __init__(self, **kwargs): :param finish: The date (relative to the time zone of the server) on which the announcement should stop being published. """ - self.id = kwargs.pop('id') - self.content = kwargs.pop('content') - self.start = AnnouncementListValidator.validate_date("", kwargs.pop('start')) - self.finish = AnnouncementListValidator.validate_date("", kwargs.pop('finish')) + self.id = kwargs.pop("id") + self.content = kwargs.pop("content") + self.start = AnnouncementListValidator.validate_date("", kwargs.pop("start")) + self.finish = AnnouncementListValidator.validate_date("", kwargs.pop("finish")) @property def json_ready(self): format = AnnouncementListValidator.DATE_FORMAT return dict( - id=self.id, content=self.content, + id=self.id, + content=self.content, start=datetime.datetime.strftime(self.start, format), finish=datetime.datetime.strftime(self.finish, format), ) diff --git a/api/app.py b/api/app.py index ae1e25c841..69ec0e2813 100644 --- a/api/app.py +++ b/api/app.py @@ -1,34 +1,31 @@ -import os import logging +import os import urllib.parse import flask -from flask import ( - Flask, - Response, - redirect, -) +from flask import Flask, Response, redirect +from flask_babel import Babel from flask_sqlalchemy_session import flask_scoped_session from sqlalchemy.orm import sessionmaker -from .config import Configuration -from core.model import ( - ConfigurationSetting, - SessionManager, -) + from core.log import LogConfiguration +from core.model import ConfigurationSetting, SessionManager from core.util import LanguageCodes -from flask_babel import Babel +from .config import Configuration app = Flask(__name__) app._db = None -app.config['BABEL_DEFAULT_LOCALE'] = LanguageCodes.three_to_two[Configuration.localization_languages()[0]] -app.config['BABEL_TRANSLATION_DIRECTORIES'] = "../translations" +app.config["BABEL_DEFAULT_LOCALE"] = LanguageCodes.three_to_two[ + Configuration.localization_languages()[0] +] +app.config["BABEL_TRANSLATION_DIRECTORIES"] = "../translations" babel = Babel(app) + @app.before_first_request def initialize_database(autoinitialize=True): - testing = 'TESTING' in os.environ + testing = "TESTING" in os.environ db_url = Configuration.database_url() if autoinitialize: @@ -38,27 +35,29 @@ def initialize_database(autoinitialize=True): app._db = _db log_level = LogConfiguration.initialize(_db, testing=testing) - debug = log_level == 'DEBUG' - app.config['DEBUG'] = debug + debug = log_level == "DEBUG" + app.config["DEBUG"] = debug app.debug = debug _db.commit() logging.getLogger().info("Application debug mode==%r" % app.debug) + from . import routes from .admin import routes + def run(url=None): - base_url = url or 'http://localhost:6500/' + base_url = url or "http://localhost:6500/" scheme, netloc, path, parameters, query, fragment = urllib.parse.urlparse(base_url) - if ':' in netloc: - host, port = netloc.split(':') + if ":" in netloc: + host, port = netloc.split(":") port = int(port) else: host = netloc port = 80 # Required for subdomain support. - app.config['SERVER_NAME'] = netloc + app.config["SERVER_NAME"] = netloc debug = True @@ -66,8 +65,9 @@ def run(url=None): # running in debug mode with the global socket timeout set by isbnlib if debug: import socket + socket.setdefaulttimeout(None) logging.info("Starting app on %s:%s", host, port) - sslContext = 'adhoc' if scheme == 'https' else None + sslContext = "adhoc" if scheme == "https" else None app.run(debug=debug, host=host, port=port, threaded=True, ssl_context=sslContext) diff --git a/api/authenticator.py b/api/authenticator.py index 8264a86303..47007f6769 100644 --- a/api/authenticator.py +++ b/api/authenticator.py @@ -3,14 +3,14 @@ import json import logging import re -import urllib.request, urllib.parse, urllib.error +import urllib.error +import urllib.parse +import urllib.request from abc import ABCMeta import flask import jwt -from flask import ( - redirect, - url_for) +from flask import redirect, url_for from flask_babel import lazy_gettext as _ from money import Money from sqlalchemy.ext.hybrid import hybrid_property @@ -24,14 +24,7 @@ from api.custom_patron_catalog import CustomPatronCatalog from api.opds import LibraryAnnotator from api.saml.configuration.model import SAMLSettings -from .config import ( - Configuration, - CannotLoadConfiguration, - IntegrationException, -) from core.model import ( - get_one, - get_one_or_create, CirculationEvent, ConfigurationSetting, Credential, @@ -41,11 +34,11 @@ Patron, PatronProfileStorage, Session, + get_one, + get_one_or_create, ) from core.opds import OPDSFeed -from core.selftest import ( - HasSelfTests, -) +from core.selftest import HasSelfTests from core.user_profile import ProfileController from core.util.authentication_for_opds import ( AuthenticationForOPDSDocument, @@ -53,10 +46,10 @@ ) from core.util.datetime_helpers import utc_now from core.util.http import RemoteIntegrationException -from core.util.problem_detail import ( - ProblemDetail, - json as pd_json, -) +from core.util.problem_detail import ProblemDetail +from core.util.problem_detail import json as pd_json + +from .config import CannotLoadConfiguration, Configuration, IntegrationException from .problem_details import * from .util.patron import PatronUtility @@ -86,42 +79,44 @@ class NoValue(object): def __bool__(self): """We want this object to act like None or False.""" return False + NO_VALUE = NoValue() # Reasons why a patron might be blocked. - UNKNOWN_BLOCK = 'unknown' - CARD_REPORTED_LOST = 'card reported lost' - EXCESSIVE_FINES = 'excessive fines' - EXCESSIVE_FEES = 'excessive fees' - NO_BORROWING_PRIVILEGES = 'no borrowing privileges' - TOO_MANY_LOANS = 'too many active loans' - TOO_MANY_RENEWALS = 'too many renewals' - TOO_MANY_OVERDUE = 'too many items overdue' - TOO_MANY_LOST = 'too many items lost' + UNKNOWN_BLOCK = "unknown" + CARD_REPORTED_LOST = "card reported lost" + EXCESSIVE_FINES = "excessive fines" + EXCESSIVE_FEES = "excessive fees" + NO_BORROWING_PRIVILEGES = "no borrowing privileges" + TOO_MANY_LOANS = "too many active loans" + TOO_MANY_RENEWALS = "too many renewals" + TOO_MANY_OVERDUE = "too many items overdue" + TOO_MANY_LOST = "too many items lost" # Patron is being billed for too many items (as opposed to # excessive fines, which means patron's fines have exceeded a # certain amount). - TOO_MANY_ITEMS_BILLED = 'too many items billed' + TOO_MANY_ITEMS_BILLED = "too many items billed" # Patron was asked to return an item so someone else could borrow it, # but didn't return the item. - RECALL_OVERDUE = 'recall overdue' - - def __init__(self, - permanent_id=None, - authorization_identifier=None, - username=None, - personal_name=None, - email_address=None, - authorization_expires=None, - external_type=None, - fines=None, - block_reason=None, - library_identifier=None, - neighborhood=None, - cached_neighborhood=None, - complete=True, + RECALL_OVERDUE = "recall overdue" + + def __init__( + self, + permanent_id=None, + authorization_identifier=None, + username=None, + personal_name=None, + email_address=None, + authorization_expires=None, + external_type=None, + fines=None, + block_reason=None, + library_identifier=None, + neighborhood=None, + cached_neighborhood=None, + complete=True, ): """Store basic information about a patron. @@ -240,24 +235,25 @@ def __eq__(self, other): if not isinstance(other, PatronData): return False - return \ - self.permanent_id == other.permanent_id and \ - self.username == other.username and \ - self.authorization_expires == other.authorization_expires and \ - self.external_type == other.external_type and \ - self.fines == other.fines and \ - self.block_reason == other.block_reason and \ - self.library_identifier == other.library_identifier and \ - self.complete == other.complete and \ - self.personal_name == other.personal_name and \ - self.email_address == other.email_address and \ - self.neighborhood == other.neighborhood and \ - self.cached_neighborhood == other.cached_neighborhood + return ( + self.permanent_id == other.permanent_id + and self.username == other.username + and self.authorization_expires == other.authorization_expires + and self.external_type == other.external_type + and self.fines == other.fines + and self.block_reason == other.block_reason + and self.library_identifier == other.library_identifier + and self.complete == other.complete + and self.personal_name == other.personal_name + and self.email_address == other.email_address + and self.neighborhood == other.neighborhood + and self.cached_neighborhood == other.cached_neighborhood + ) def __repr__(self): - return "" % ( - self.permanent_id, self.authorization_identifier, - self.username + return ( + "" + % (self.permanent_id, self.authorization_identifier, self.username) ) @hybrid_property @@ -280,14 +276,13 @@ def apply(self, patron): # First, handle the easy stuff -- everything except authorization # identifier. - self.set_value(patron, 'external_identifier', self.permanent_id) - self.set_value(patron, 'username', self.username) - self.set_value(patron, 'external_type', self.external_type) - self.set_value(patron, 'authorization_expires', - self.authorization_expires) - self.set_value(patron, 'fines', self.fines) - self.set_value(patron, 'block_reason', self.block_reason) - self.set_value(patron, 'cached_neighborhood', self.cached_neighborhood) + self.set_value(patron, "external_identifier", self.permanent_id) + self.set_value(patron, "username", self.username) + self.set_value(patron, "external_type", self.external_type) + self.set_value(patron, "authorization_expires", self.authorization_expires) + self.set_value(patron, "fines", self.fines) + self.set_value(patron, "block_reason", self.block_reason) + self.set_value(patron, "cached_neighborhood", self.cached_neighborhood) # Patron neighborhood (not a database field) is set as a # convenience. @@ -298,13 +293,15 @@ def apply(self, patron): # We have a complete picture of data from the ILS, # so we can be comfortable setting the authorization # identifier if necessary. - if (patron.authorization_identifier is None or - patron.authorization_identifier not in - self.authorization_identifiers): + if ( + patron.authorization_identifier is None + or patron.authorization_identifier not in self.authorization_identifiers + ): # The patron's authorization_identifier is not set, or is # set to a value that is no longer valid. Set it again. - self.set_value(patron, 'authorization_identifier', - self.authorization_identifier) + self.set_value( + patron, "authorization_identifier", self.authorization_identifier + ) elif patron.authorization_identifier != self.authorization_identifier: # It looks like we need to change # Patron.authorization_identifier. However, we do not @@ -315,8 +312,9 @@ def apply(self, patron): # However, we can provisionally # Patron.authorization_identifier if it's not already set. if not patron.authorization_identifier: - self.set_value(patron, 'authorization_identifier', - self.authorization_identifier) + self.set_value( + patron, "authorization_identifier", self.authorization_identifier + ) if patron.username and self.authorization_identifier == patron.username: # This should be fine. It looks like the patron's @@ -385,22 +383,19 @@ def get_or_create_patron(self, _db, library_id, analytics=None): elif self.username: search_by = dict(username=self.username) elif self.authorization_identifier: - search_by = dict( - authorization_identifier=self.authorization_identifier - ) + search_by = dict(authorization_identifier=self.authorization_identifier) else: raise CannotCreateLocalPatron( "Cannot create patron without some way of identifying them uniquely." ) - search_by['library_id'] = library_id + search_by["library_id"] = library_id __transaction = _db.begin_nested() patron, is_new = get_one_or_create(_db, Patron, **search_by) if is_new and analytics: # Send out an analytics event to record the fact # that a new patron was created. - analytics.collect_event(patron.library, None, - CirculationEvent.NEW_PATRON) + analytics.collect_event(patron.library, None, CirculationEvent.NEW_PATRON) # This makes sure the Patron is brought into sync with the # other fields of this PatronData object, regardless of @@ -428,10 +423,12 @@ def to_dict(self): """Convert the information in this PatronData to a dictionary which can be converted to JSON and sent out to a client. """ + def scrub(value, default=None): if value is self.NO_VALUE: return default return value + data = dict( permanent_id=self.permanent_id, authorization_identifier=self.authorization_identifier, @@ -439,7 +436,7 @@ def scrub(value, default=None): external_type=self.external_type, block_reason=self.block_reason, personal_name=self.personal_name, - email_address = self.email_address + email_address=self.email_address, ) data = dict((k, scrub(v)) for k, v in list(data.items())) @@ -449,18 +446,16 @@ def scrub(value, default=None): expires = scrub(self.authorization_expires) if expires: expires = self.authorization_expires.strftime("%Y-%m-%d") - data['authorization_expires'] = expires + data["authorization_expires"] = expires # A Money fines = scrub(self.fines) if fines is not None: fines = str(fines) - data['fines'] = fines + data["fines"] = fines # A list - data['authorization_identifiers'] = scrub( - self.authorization_identifiers, [] - ) + data["authorization_identifiers"] = scrub(self.authorization_identifiers, []) return data def set_authorization_identifier(self, authorization_identifier): @@ -484,8 +479,10 @@ def set_authorization_identifier(self, authorization_identifier): self.authorization_identifier = authorization_identifier self.authorization_identifiers = authorization_identifiers + class CirculationPatronProfileStorage(PatronProfileStorage): """A patron profile storage that can also provide short client tokens""" + @property def profile_document(self): doc = super(CirculationPatronProfileStorage, self).profile_document @@ -497,34 +494,42 @@ def profile_document(self): if authdata: vendor_id, token = authdata.short_client_token_for_patron(self.patron) adobe_drm = {} - adobe_drm['drm:vendor'] = vendor_id - adobe_drm['drm:clientToken'] = token - adobe_drm['drm:scheme'] = "http://librarysimplified.org/terms/drm/scheme/ACS" + adobe_drm["drm:vendor"] = vendor_id + adobe_drm["drm:clientToken"] = token + adobe_drm[ + "drm:scheme" + ] = "http://librarysimplified.org/terms/drm/scheme/ACS" drm.append(adobe_drm) - device_link['rel'] = 'http://librarysimplified.org/terms/drm/rel/devices' - device_link['href'] = self.url_for( - "adobe_drm_devices", library_short_name=self.patron.library.short_name, _external=True + device_link["rel"] = "http://librarysimplified.org/terms/drm/rel/devices" + device_link["href"] = self.url_for( + "adobe_drm_devices", + library_short_name=self.patron.library.short_name, + _external=True, ) links.append(device_link) annotations_link = dict( rel="http://www.w3.org/ns/oa#annotationService", type=AnnotationWriter.CONTENT_TYPE, - href=self.url_for('annotations', library_short_name=self.patron.library.short_name, _external=True) + href=self.url_for( + "annotations", + library_short_name=self.patron.library.short_name, + _external=True, + ), ) links.append(annotations_link) - doc['links'] = links + doc["links"] = links if drm: - doc['drm'] = drm + doc["drm"] = drm return doc + class Authenticator(object): - """Route requests to the appropriate LibraryAuthenticator. - """ + """Route requests to the appropriate LibraryAuthenticator.""" def __init__(self, _db, analytics=None): self.library_authenticators = {} @@ -537,13 +542,17 @@ def current_library_short_name(self): def populate_authenticators(self, _db, analytics): for library in _db.query(Library): - self.library_authenticators[library.short_name] = LibraryAuthenticator.from_config(_db, library, analytics) + self.library_authenticators[ + library.short_name + ] = LibraryAuthenticator.from_config(_db, library, analytics) def invoke_authenticator_method(self, method_name, *args, **kwargs): short_name = self.current_library_short_name if short_name not in self.library_authenticators: return LIBRARY_NOT_FOUND - return getattr(self.library_authenticators[short_name], method_name)(*args, **kwargs) + return getattr(self.library_authenticators[short_name], method_name)( + *args, **kwargs + ) def authenticated_patron(self, _db, header): return self.invoke_authenticator_method("authenticated_patron", _db, header) @@ -558,9 +567,7 @@ def get_credential_from_header(self, header): return self.invoke_authenticator_method("get_credential_from_header", header) def create_bearer_token(self, *args, **kwargs): - return self.invoke_authenticator_method( - "create_bearer_token", *args, **kwargs - ) + return self.invoke_authenticator_method("create_bearer_token", *args, **kwargs) def oauth_provider_lookup(self, *args, **kwargs): return self.invoke_authenticator_method( @@ -568,14 +575,10 @@ def oauth_provider_lookup(self, *args, **kwargs): ) def saml_provider_lookup(self, *args, **kwargs): - return self.invoke_authenticator_method( - "saml_provider_lookup", *args, **kwargs - ) + return self.invoke_authenticator_method("saml_provider_lookup", *args, **kwargs) def decode_bearer_token(self, *args, **kwargs): - return self.invoke_authenticator_method( - "decode_bearer_token", *args, **kwargs - ) + return self.invoke_authenticator_method("decode_bearer_token", *args, **kwargs) class LibraryAuthenticator(object): @@ -584,7 +587,9 @@ class LibraryAuthenticator(object): """ @classmethod - def from_config(cls, _db, library, analytics=None, custom_catalog_source=CustomPatronCatalog): + def from_config( + cls, _db, library, analytics=None, custom_catalog_source=CustomPatronCatalog + ): """Initialize an Authenticator for the given Library based on its configured ExternalIntegrations. @@ -596,8 +601,7 @@ def from_config(cls, _db, library, analytics=None, custom_catalog_source=CustomP # Start with an empty list of authenticators. authenticator = cls( - _db=_db, library=library, - authentication_document_annotator=custom_catalog + _db=_db, library=library, authentication_document_annotator=custom_catalog ) # Find all of this library's ExternalIntegrations set up with @@ -615,29 +619,37 @@ def from_config(cls, _db, library, analytics=None, custom_catalog_source=CustomP # by misconfiguration, as opposed to bad code. logging.error( "Error registering authentication provider %r (%s)", - integration.name, integration.protocol, - exc_info=e + integration.name, + integration.protocol, + exc_info=e, ) authenticator.initialization_exceptions[integration.id] = e - if authenticator.oauth_providers_by_name or authenticator.saml_providers_by_name: + if ( + authenticator.oauth_providers_by_name + or authenticator.saml_providers_by_name + ): # NOTE: this will immediately commit the database session, # which may not be what you want during a test. To avoid # this, you can create the bearer token signing secret as # a regular site-wide ConfigurationSetting. - authenticator.bearer_token_signing_secret = BearerTokenSigner.bearer_token_signing_secret( - _db + authenticator.bearer_token_signing_secret = ( + BearerTokenSigner.bearer_token_signing_secret(_db) ) authenticator.assert_ready_for_token_signing() return authenticator - def __init__(self, _db, library, basic_auth_provider=None, - oauth_providers=None, - saml_providers=None, - bearer_token_signing_secret=None, - authentication_document_annotator=None, + def __init__( + self, + _db, + library, + basic_auth_provider=None, + oauth_providers=None, + saml_providers=None, + bearer_token_signing_secret=None, + authentication_document_annotator=None, ): """Initialize a LibraryAuthenticator from a list of AuthenticationProviders. @@ -665,7 +677,7 @@ def __init__(self, _db, library, basic_auth_provider=None, self.library_uuid = library.uuid self.library_name = library.name self.library_short_name = library.short_name - self.authentication_document_annotator=authentication_document_annotator + self.authentication_document_annotator = authentication_document_annotator self.basic_auth_provider = basic_auth_provider self.oauth_providers_by_name = dict() @@ -692,7 +704,11 @@ def __init__(self, _db, library, basic_auth_provider=None, @property def supports_patron_authentication(self): """Does this library have any way of authenticating patrons at all?""" - if self.basic_auth_provider or self.oauth_providers_by_name or self.saml_providers_by_name: + if ( + self.basic_auth_provider + or self.oauth_providers_by_name + or self.saml_providers_by_name + ): return True return False @@ -711,9 +727,7 @@ def identifies_individuals(self): if not self.supports_patron_authentication: return False matches = list(self.providers) - return matches and all( - [x.IDENTIFIES_INDIVIDUALS for x in matches] - ) + return matches and all([x.IDENTIFIES_INDIVIDUALS for x in matches]) @property def library(self): @@ -725,12 +739,16 @@ def assert_ready_for_token_signing(self): """ if self.oauth_providers_by_name and not self.bearer_token_signing_secret: raise CannotLoadConfiguration( - _("OAuth providers are configured, but secret for signing bearer tokens is not.") + _( + "OAuth providers are configured, but secret for signing bearer tokens is not." + ) ) if self.saml_providers_by_name and not self.bearer_token_signing_secret: raise CannotLoadConfiguration( - _("SAML providers are configured, but secret for signing bearer tokens is not.") + _( + "SAML providers are configured, but secret for signing bearer tokens is not." + ) ) def register_provider(self, integration, analytics=None): @@ -742,13 +760,15 @@ def register_provider(self, integration, analytics=None): """ if integration.goal != integration.PATRON_AUTH_GOAL: raise CannotLoadConfiguration( - "Was asked to register an integration with goal=%s as though it were a way of authenticating patrons." % integration.goal + "Was asked to register an integration with goal=%s as though it were a way of authenticating patrons." + % integration.goal ) library = self.library if library not in integration.libraries: raise CannotLoadConfiguration( - "Was asked to register an integration with library %s, which doesn't use it." % library.name + "Was asked to register an integration with library %s, which doesn't use it." + % library.name ) module_name = integration.protocol @@ -761,15 +781,15 @@ def register_provider(self, integration, analytics=None): provider_class = getattr(provider_module, "AuthenticationProvider", None) if not provider_class: raise CannotLoadConfiguration( - "Loaded module %s but could not find a class called AuthenticationProvider inside." % module_name + "Loaded module %s but could not find a class called AuthenticationProvider inside." + % module_name ) try: provider = provider_class(self.library, integration, analytics) except RemoteIntegrationException as e: raise CannotLoadConfiguration( - "Could not instantiate %s authentication provider for library %s, possibly due to a network connection problem." % ( - provider_class, self.library.short_name - ) + "Could not instantiate %s authentication provider for library %s, possibly due to a network connection problem." + % (provider_class, self.library.short_name) ) return if issubclass(provider_class, BasicAuthenticationProvider): @@ -782,38 +802,28 @@ def register_provider(self, integration, analytics=None): self.register_saml_provider(provider) else: raise CannotLoadConfiguration( - "Authentication provider %s is neither a BasicAuthenticationProvider nor an OAuthAuthenticationProvider. I can create it, but not sure where to put it." % provider_class + "Authentication provider %s is neither a BasicAuthenticationProvider nor an OAuthAuthenticationProvider. I can create it, but not sure where to put it." + % provider_class ) def register_basic_auth_provider(self, provider): - if (self.basic_auth_provider != None - and self.basic_auth_provider != provider): - raise CannotLoadConfiguration( - "Two basic auth providers configured" - ) + if self.basic_auth_provider != None and self.basic_auth_provider != provider: + raise CannotLoadConfiguration("Two basic auth providers configured") self.basic_auth_provider = provider def register_oauth_provider(self, provider): - already_registered = self.oauth_providers_by_name.get( - provider.NAME - ) + already_registered = self.oauth_providers_by_name.get(provider.NAME) if already_registered and already_registered != provider: raise CannotLoadConfiguration( - 'Two different OAuth providers claim the name "%s"' % ( - provider.NAME - ) + 'Two different OAuth providers claim the name "%s"' % (provider.NAME) ) self.oauth_providers_by_name[provider.NAME] = provider def register_saml_provider(self, provider): - already_registered = self.saml_providers_by_name.get( - provider.NAME - ) + already_registered = self.saml_providers_by_name.get(provider.NAME) if already_registered and already_registered != provider: raise CannotLoadConfiguration( - 'Two different SAML providers claim the name "%s"' % ( - provider.NAME - ) + 'Two different SAML providers claim the name "%s"' % (provider.NAME) ) self.saml_providers_by_name[provider.NAME] = provider @@ -839,14 +849,19 @@ def authenticated_patron(self, _db, header): credentials do not authenticate any particular patron. A ProblemDetail if an error occurs. """ - if (self.basic_auth_provider - and isinstance(header, dict) and 'username' in header): + if ( + self.basic_auth_provider + and isinstance(header, dict) + and "username" in header + ): # The patron wants to authenticate with the # BasicAuthenticationProvider. return self.basic_auth_provider.authenticated_patron(_db, header) - elif (self.oauth_providers_by_name - and isinstance(header, (bytes, str)) - and 'bearer' in header.lower()): + elif ( + self.oauth_providers_by_name + and isinstance(header, (bytes, str)) + and "bearer" in header.lower() + ): # The patron wants to use an # OAuthAuthenticationProvider. Figure out which one. @@ -865,9 +880,11 @@ def authenticated_patron(self, _db, header): # Ask the OAuthAuthenticationProvider to turn its token # into a Patron. return provider.authenticated_patron(_db, provider_token) - elif (self.saml_providers_by_name - and isinstance(header, (bytes, str)) - and 'bearer' in header.lower()): + elif ( + self.saml_providers_by_name + and isinstance(header, (bytes, str)) + and "bearer" in header.lower() + ): # The patron wants to use an # SAMLAuthenticationProvider. Figure out which one. @@ -919,14 +936,13 @@ def oauth_provider_lookup(self, provider_name): _("No OAuth providers are configured.") ) - if (not provider_name - or not provider_name in self.oauth_providers_by_name): + if not provider_name or not provider_name in self.oauth_providers_by_name: # The patron neglected to specify a provider, or specified # one we don't support. possibilities = ", ".join(list(self.oauth_providers_by_name.keys())) return UNKNOWN_OAUTH_PROVIDER.detailed( - UNKNOWN_OAUTH_PROVIDER.detail + - _(" The known providers are: %s") % possibilities + UNKNOWN_OAUTH_PROVIDER.detail + + _(" The known providers are: %s") % possibilities ) return self.oauth_providers_by_name[provider_name] @@ -940,14 +956,13 @@ def saml_provider_lookup(self, provider_name): _("No SAML providers are configured.") ) - if (not provider_name - or not provider_name in self.saml_providers_by_name): + if not provider_name or not provider_name in self.saml_providers_by_name: # The patron neglected to specify a provider, or specified # one we don't support. possibilities = ", ".join(list(self.saml_providers_by_name.keys())) return UNKNOWN_SAML_PROVIDER.detailed( - UNKNOWN_SAML_PROVIDER.detail + - _(" The known providers are: %s") % possibilities + UNKNOWN_SAML_PROVIDER.detail + + _(" The known providers are: %s") % possibilities ) return self.saml_providers_by_name[provider_name] @@ -971,22 +986,23 @@ def create_bearer_token(self, provider_name, provider_token): iss=provider_name, ) return jwt.encode( - payload, self.bearer_token_signing_secret, algorithm='HS256' + payload, self.bearer_token_signing_secret, algorithm="HS256" ).decode("utf-8") def decode_bearer_token_from_header(self, header): """Extract auth provider name and access token from an Authenticate header value. """ - simplified_token = header.split(' ')[1] + simplified_token = header.split(" ")[1] return self.decode_bearer_token(simplified_token) def decode_bearer_token(self, token): """Extract auth provider name and access token from JSON web token.""" - decoded = jwt.decode(token, self.bearer_token_signing_secret, - algorithms=['HS256']) - provider_name = decoded['iss'] - token = decoded['token'] + decoded = jwt.decode( + token, self.bearer_token_signing_secret, algorithms=["HS256"] + ) + provider_name = decoded["iss"] + token = decoded["token"] return (provider_name, token) def authentication_document_url(self, library): @@ -994,8 +1010,9 @@ def authentication_document_url(self, library): given library. """ return url_for( - "authentication_document", library_short_name=library.short_name, - _external=True + "authentication_document", + library_short_name=library.short_name, + _external=True, ) def create_authentication_document(self): @@ -1008,48 +1025,57 @@ def create_authentication_document(self): # Add the same links that we would show in an OPDS feed, plus # some extra like 'registration' that are specific to Authentication # For OPDS. - for rel in (LibraryAnnotator.CONFIGURATION_LINKS + - Configuration.AUTHENTICATION_FOR_OPDS_LINKS): + for rel in ( + LibraryAnnotator.CONFIGURATION_LINKS + + Configuration.AUTHENTICATION_FOR_OPDS_LINKS + ): value = ConfigurationSetting.for_library(rel, library).value if not value: continue link = dict(rel=rel, href=value) - if any(value.startswith(x) for x in ('http:', 'https:')): + if any(value.startswith(x) for x in ("http:", "https:")): # We assume that HTTP URLs lead to HTML, but we don't # assume anything about other URL schemes. - link['type'] = "text/html" + link["type"] = "text/html" links.append(link) # Add a rel="start" link pointing to the root OPDS feed. - index_url = url_for("index", _external=True, - library_short_name=library.short_name) - loans_url = url_for("active_loans", _external=True, - library_short_name=library.short_name) - profile_url = url_for("patron_profile", _external=True, - library_short_name=library.short_name) + index_url = url_for( + "index", _external=True, library_short_name=library.short_name + ) + loans_url = url_for( + "active_loans", _external=True, library_short_name=library.short_name + ) + profile_url = url_for( + "patron_profile", _external=True, library_short_name=library.short_name + ) links.append( - dict(rel="start", href=index_url, - type=OPDSFeed.ACQUISITION_FEED_TYPE) + dict(rel="start", href=index_url, type=OPDSFeed.ACQUISITION_FEED_TYPE) ) links.append( - dict(rel="http://opds-spec.org/shelf", href=loans_url, - type=OPDSFeed.ACQUISITION_FEED_TYPE) + dict( + rel="http://opds-spec.org/shelf", + href=loans_url, + type=OPDSFeed.ACQUISITION_FEED_TYPE, + ) ) links.append( - dict(rel=ProfileController.LINK_RELATION, href=profile_url, - type=ProfileController.MEDIA_TYPE) + dict( + rel=ProfileController.LINK_RELATION, + href=profile_url, + type=ProfileController.MEDIA_TYPE, + ) ) # If there is a Designated Agent email address, add it as a # link. - designated_agent_uri = Configuration.copyright_designated_agent_uri( - library - ) + designated_agent_uri = Configuration.copyright_designated_agent_uri(library) if designated_agent_uri: links.append( - dict(rel=Configuration.COPYRIGHT_DESIGNATED_AGENT_REL, - href=designated_agent_uri + dict( + rel=Configuration.COPYRIGHT_DESIGNATED_AGENT_REL, + href=designated_agent_uri, ) ) @@ -1060,60 +1086,69 @@ def create_authentication_document(self): # Add a link to the web page of the library itself. library_uri = ConfigurationSetting.for_library( - Configuration.WEBSITE_URL, library).value + Configuration.WEBSITE_URL, library + ).value if library_uri: - links.append( - dict(rel="alternate", type="text/html", href=library_uri) - ) + links.append(dict(rel="alternate", type="text/html", href=library_uri)) # Add the library's logo, if it has one. - logo = ConfigurationSetting.for_library( - Configuration.LOGO, library).value + logo = ConfigurationSetting.for_library(Configuration.LOGO, library).value if logo: links.append(dict(rel="logo", type="image/png", href=logo)) # Add the library's custom CSS file, if it has one. css_file = ConfigurationSetting.for_library( - Configuration.WEB_CSS_FILE, library).value + Configuration.WEB_CSS_FILE, library + ).value if css_file: links.append(dict(rel="stylesheet", type="text/css", href=css_file)) library_name = self.library_name or str(_("Library")) auth_doc_url = self.authentication_document_url(library) doc = AuthenticationForOPDSDocument( - id=auth_doc_url, title=library_name, + id=auth_doc_url, + title=library_name, authentication_flows=list(self.providers), - links=links + links=links, ).to_dict(self._db) # Add the library's mobile color scheme, if it has one. description = ConfigurationSetting.for_library( - Configuration.COLOR_SCHEME, library).value + Configuration.COLOR_SCHEME, library + ).value if description: - doc['color_scheme'] = description + doc["color_scheme"] = description # Add the library's web colors, if it has any. primary = ConfigurationSetting.for_library( - Configuration.WEB_PRIMARY_COLOR, library).value + Configuration.WEB_PRIMARY_COLOR, library + ).value secondary = ConfigurationSetting.for_library( - Configuration.WEB_SECONDARY_COLOR, library).value + Configuration.WEB_SECONDARY_COLOR, library + ).value if primary or secondary: - doc["web_color_scheme"] = dict(primary=primary, secondary=secondary, background=primary, foreground=secondary) + doc["web_color_scheme"] = dict( + primary=primary, + secondary=secondary, + background=primary, + foreground=secondary, + ) # Add the description of the library as the OPDS feed's # service_description. description = ConfigurationSetting.for_library( - Configuration.LIBRARY_DESCRIPTION, library).value + Configuration.LIBRARY_DESCRIPTION, library + ).value if description: - doc['service_description'] = description + doc["service_description"] = description # Add the library's focus area and service area, if either is # specified. focus_area, service_area = self._geographic_areas(library) if focus_area: - doc['focus_area'] = focus_area + doc["focus_area"] = focus_area if service_area: - doc['service_area'] = service_area + doc["service_area"] = service_area # Add the library's public key. doc["public_key"] = dict(type="RSA", value=self.public_key) @@ -1127,31 +1162,30 @@ def create_authentication_document(self): else: bucket = disabled bucket.append(Configuration.RESERVATIONS_FEATURE) - doc['features'] = dict(enabled=enabled, disabled=disabled) + doc["features"] = dict(enabled=enabled, disabled=disabled) # Add any active announcements for the library. announcements = [ x.for_authentication_document for x in Announcements.for_library(library).active ] - doc['announcements'] = announcements + doc["announcements"] = announcements # Finally, give the active annotator a chance to modify the document. if self.authentication_document_annotator: - doc = self.authentication_document_annotator.annotate_authentication_document( - library, doc, url_for + doc = ( + self.authentication_document_annotator.annotate_authentication_document( + library, doc, url_for + ) ) return json.dumps(doc) @property def key_pair(self): - """Look up or create a public/private key pair for use by this library. - """ - setting = ConfigurationSetting.for_library( - Configuration.KEY_PAIR, self.library - ) + """Look up or create a public/private key pair for use by this library.""" + setting = ConfigurationSetting.for_library(Configuration.KEY_PAIR, self.library) return Configuration.key_pair(setting) @classmethod @@ -1161,12 +1195,8 @@ def _geographic_areas(cls, library): :param library: A Library :return: A 2-tuple (focus_area, service_area) """ - focus_area = cls._geographic_area( - Configuration.LIBRARY_FOCUS_AREA, library - ) - service_area = cls._geographic_area( - Configuration.LIBRARY_SERVICE_AREA, library - ) + focus_area = cls._geographic_area(Configuration.LIBRARY_FOCUS_AREA, library) + service_area = cls._geographic_area(Configuration.LIBRARY_SERVICE_AREA, library) # If only one value is provided, both values are considered to # be the same. @@ -1186,7 +1216,7 @@ def _geographic_area(cls, key, library): setting = ConfigurationSetting.for_library(key, library).value if not setting: return setting - if setting == 'everywhere': + if setting == "everywhere": # This literal string may be served as is. return setting try: @@ -1204,15 +1234,24 @@ def create_authentication_headers(self): authentication document.""" library = Library.by_id(self._db, self.library_id) headers = Headers() - headers.add('Content-Type', AuthenticationForOPDSDocument.MEDIA_TYPE) - headers.add('Link', "<%s>; rel=%s" % ( - self.authentication_document_url(library), - AuthenticationForOPDSDocument.LINK_RELATION - )) + headers.add("Content-Type", AuthenticationForOPDSDocument.MEDIA_TYPE) + headers.add( + "Link", + "<%s>; rel=%s" + % ( + self.authentication_document_url(library), + AuthenticationForOPDSDocument.LINK_RELATION, + ), + ) # if requested from a web client, don't include WWW-Authenticate header, # which forces the default browser authentication prompt - if self.basic_auth_provider and not flask.request.headers.get("X-Requested-With") == "XMLHttpRequest": - headers.add('WWW-Authenticate', self.basic_auth_provider.authentication_header) + if ( + self.basic_auth_provider + and not flask.request.headers.get("X-Requested-With") == "XMLHttpRequest" + ): + headers.add( + "WWW-Authenticate", self.basic_auth_provider.authentication_header + ) # TODO: We're leaving out headers for other providers to avoid breaking iOS # clients that don't support multiple auth headers. It's not clear what @@ -1224,8 +1263,7 @@ def create_authentication_headers(self): class AuthenticationProvider(OPDSAuthenticationFlow): - """Handle a specific patron authentication scheme. - """ + """Handle a specific patron authentication scheme.""" # NOTE: Each subclass MUST define an attribute called NAME, which # is displayed in the admin interface when configuring patron auth, @@ -1273,76 +1311,105 @@ class AuthenticationProvider(OPDSAuthenticationFlow): # Each library and authentication mechanism may have a regular # expression for deriving a patron's external type from their # authentication identifier. - EXTERNAL_TYPE_REGULAR_EXPRESSION = 'external_type_regular_expression' + EXTERNAL_TYPE_REGULAR_EXPRESSION = "external_type_regular_expression" # When multiple libraries share an ILS, a person may be able to # authenticate with the ILS but not be considered a patron of # _this_ library. This setting contains the rule for determining # whether an identifier is valid for a specific library. - LIBRARY_IDENTIFIER_RESTRICTION_TYPE = 'library_identifier_restriction_type' + LIBRARY_IDENTIFIER_RESTRICTION_TYPE = "library_identifier_restriction_type" # This field lets the user choose the data source for the patron match. - LIBRARY_IDENTIFIER_FIELD = 'library_identifier_field' + LIBRARY_IDENTIFIER_FIELD = "library_identifier_field" # Usually this is a string which is compared against the # patron's identifiers using the comparison method chosen in # LIBRARY_IDENTIFIER_RESTRICTION_TYPE. - LIBRARY_IDENTIFIER_RESTRICTION = 'library_identifier_restriction' + LIBRARY_IDENTIFIER_RESTRICTION = "library_identifier_restriction" # Different types of patron restrictions. - LIBRARY_IDENTIFIER_RESTRICTION_BARCODE = 'barcode' - LIBRARY_IDENTIFIER_RESTRICTION_TYPE_NONE = 'none' - LIBRARY_IDENTIFIER_RESTRICTION_TYPE_REGEX = 'regex' - LIBRARY_IDENTIFIER_RESTRICTION_TYPE_PREFIX = 'prefix' - LIBRARY_IDENTIFIER_RESTRICTION_TYPE_STRING = 'string' - LIBRARY_IDENTIFIER_RESTRICTION_TYPE_LIST = 'list' + LIBRARY_IDENTIFIER_RESTRICTION_BARCODE = "barcode" + LIBRARY_IDENTIFIER_RESTRICTION_TYPE_NONE = "none" + LIBRARY_IDENTIFIER_RESTRICTION_TYPE_REGEX = "regex" + LIBRARY_IDENTIFIER_RESTRICTION_TYPE_PREFIX = "prefix" + LIBRARY_IDENTIFIER_RESTRICTION_TYPE_STRING = "string" + LIBRARY_IDENTIFIER_RESTRICTION_TYPE_LIST = "list" LIBRARY_SETTINGS = [ - { "key": EXTERNAL_TYPE_REGULAR_EXPRESSION, - "label": _("External Type Regular Expression"), - "description": _("Derive a patron's type from their identifier."), + { + "key": EXTERNAL_TYPE_REGULAR_EXPRESSION, + "label": _("External Type Regular Expression"), + "description": _("Derive a patron's type from their identifier."), }, - { "key": LIBRARY_IDENTIFIER_RESTRICTION_TYPE, - "label": _("Library Identifier Restriction Type"), - "type": "select", - "description": _("When multiple libraries share an ILS, a person may be able to " + - "authenticate with the ILS but not be considered a patron of " + - "this library. This setting contains the rule for determining " + - "whether an identifier is valid for this specific library.

" + - "If this setting it set to 'No Restriction' then the values for " + - "Library Identifier Field and Library Identifier " + - "Restriction will not be used."), - "options": [ - {"key": LIBRARY_IDENTIFIER_RESTRICTION_TYPE_NONE, "label": _("No restriction")}, - {"key": LIBRARY_IDENTIFIER_RESTRICTION_TYPE_PREFIX, "label": _("Prefix Match")}, - {"key": LIBRARY_IDENTIFIER_RESTRICTION_TYPE_STRING, "label": _("Exact Match")}, - {"key": LIBRARY_IDENTIFIER_RESTRICTION_TYPE_REGEX, "label": _("Regex Match")}, - {"key": LIBRARY_IDENTIFIER_RESTRICTION_TYPE_LIST, "label": _("Exact Match, comma separated list")}, - ], - "default": LIBRARY_IDENTIFIER_RESTRICTION_TYPE_NONE + { + "key": LIBRARY_IDENTIFIER_RESTRICTION_TYPE, + "label": _("Library Identifier Restriction Type"), + "type": "select", + "description": _( + "When multiple libraries share an ILS, a person may be able to " + + "authenticate with the ILS but not be considered a patron of " + + "this library. This setting contains the rule for determining " + + "whether an identifier is valid for this specific library.

" + + "If this setting it set to 'No Restriction' then the values for " + + "Library Identifier Field and Library Identifier " + + "Restriction will not be used." + ), + "options": [ + { + "key": LIBRARY_IDENTIFIER_RESTRICTION_TYPE_NONE, + "label": _("No restriction"), + }, + { + "key": LIBRARY_IDENTIFIER_RESTRICTION_TYPE_PREFIX, + "label": _("Prefix Match"), + }, + { + "key": LIBRARY_IDENTIFIER_RESTRICTION_TYPE_STRING, + "label": _("Exact Match"), + }, + { + "key": LIBRARY_IDENTIFIER_RESTRICTION_TYPE_REGEX, + "label": _("Regex Match"), + }, + { + "key": LIBRARY_IDENTIFIER_RESTRICTION_TYPE_LIST, + "label": _("Exact Match, comma separated list"), + }, + ], + "default": LIBRARY_IDENTIFIER_RESTRICTION_TYPE_NONE, }, - { "key": LIBRARY_IDENTIFIER_FIELD, - "label": _("Library Identifier Field"), - "type": "select", - "options": [ - {"key": LIBRARY_IDENTIFIER_RESTRICTION_BARCODE, "label": _("Barcode")}, - ], - "description": _("This is the field on the patron record that the Library Identifier Restriction " + - "Type is applied to, different patron authentication methods provide different " + - "values here. This value is not used if Library Identifier Restriction Type " + - "is set to 'No restriction'."), - "default": LIBRARY_IDENTIFIER_RESTRICTION_BARCODE + { + "key": LIBRARY_IDENTIFIER_FIELD, + "label": _("Library Identifier Field"), + "type": "select", + "options": [ + {"key": LIBRARY_IDENTIFIER_RESTRICTION_BARCODE, "label": _("Barcode")}, + ], + "description": _( + "This is the field on the patron record that the Library Identifier Restriction " + + "Type is applied to, different patron authentication methods provide different " + + "values here. This value is not used if Library Identifier Restriction Type " + + "is set to 'No restriction'." + ), + "default": LIBRARY_IDENTIFIER_RESTRICTION_BARCODE, }, - { "key": LIBRARY_IDENTIFIER_RESTRICTION, - "label": _("Library Identifier Restriction"), - "description": _("This is the restriction applied to the Library Identifier Field " + - "using the method chosen in Library Identifier Restriction Type. " + - "This value is not used if Library Identifier Restriction Type " + - "is set to 'No restriction'."), + { + "key": LIBRARY_IDENTIFIER_RESTRICTION, + "label": _("Library Identifier Restriction"), + "description": _( + "This is the restriction applied to the Library Identifier Field " + + "using the method chosen in Library Identifier Restriction Type. " + + "This value is not used if Library Identifier Restriction Type " + + "is set to 'No restriction'." + ), + }, + { + "key": INSTITUTION_ID, + "label": _("Institution ID"), + "description": _( + "A specific identifier for the library or branch, if used in patron authentication" + ), }, - { "key": INSTITUTION_ID, "label": _("Institution ID"), - "description": _("A specific identifier for the library or branch, if used in patron authentication") - } ] def __init__(self, library, integration, analytics=None): @@ -1359,12 +1426,11 @@ def __init__(self, library, integration, analytics=None): pull normal Python objects out of it. """ if not isinstance(library, Library): - raise Exception( - "Expected library to be a Library, got %r" % library - ) + raise Exception("Expected library to be a Library, got %r" % library) if not isinstance(integration, ExternalIntegration): raise Exception( - "Expected integration to be an ExternalIntegration, got %r" % integration + "Expected integration to be an ExternalIntegration, got %r" + % integration ) self.library_id = library.id @@ -1394,22 +1460,35 @@ def __init__(self, library, integration, analytics=None): field = field.strip() self.library_identifier_field = field - self.library_identifier_restriction_type = ConfigurationSetting.for_library_and_externalintegration( - _db, self.LIBRARY_IDENTIFIER_RESTRICTION_TYPE, library, integration - ).value + self.library_identifier_restriction_type = ( + ConfigurationSetting.for_library_and_externalintegration( + _db, self.LIBRARY_IDENTIFIER_RESTRICTION_TYPE, library, integration + ).value + ) if not self.library_identifier_restriction_type: - self.library_identifier_restriction_type = self.LIBRARY_IDENTIFIER_RESTRICTION_TYPE_NONE + self.library_identifier_restriction_type = ( + self.LIBRARY_IDENTIFIER_RESTRICTION_TYPE_NONE + ) restriction = ConfigurationSetting.for_library_and_externalintegration( _db, self.LIBRARY_IDENTIFIER_RESTRICTION, library, integration ).value - if self.library_identifier_restriction_type == self.LIBRARY_IDENTIFIER_RESTRICTION_TYPE_REGEX: + if ( + self.library_identifier_restriction_type + == self.LIBRARY_IDENTIFIER_RESTRICTION_TYPE_REGEX + ): self.library_identifier_restriction = re.compile(restriction) - elif self.library_identifier_restriction_type == self.LIBRARY_IDENTIFIER_RESTRICTION_TYPE_LIST: + elif ( + self.library_identifier_restriction_type + == self.LIBRARY_IDENTIFIER_RESTRICTION_TYPE_LIST + ): restriction = restriction.split(",") self.library_identifier_restriction = [item.strip() for item in restriction] - elif self.library_identifier_restriction_type == self.LIBRARY_IDENTIFIER_RESTRICTION_TYPE_NONE: + elif ( + self.library_identifier_restriction_type + == self.LIBRARY_IDENTIFIER_RESTRICTION_TYPE_NONE + ): self.library_identifier_restriction = None else: if isinstance(restriction, (bytes, str)): @@ -1417,9 +1496,12 @@ def __init__(self, library, integration, analytics=None): else: self.library_identifier_restriction = restriction - self.institution_id = ConfigurationSetting.for_library_and_externalintegration( - _db, self.INSTITUTION_ID, library, integration - ).value or '' + self.institution_id = ( + ConfigurationSetting.for_library_and_externalintegration( + _db, self.INSTITUTION_ID, library, integration + ).value + or "" + ) @classmethod def _restriction_matches(cls, field, restriction, match_type): @@ -1454,7 +1536,11 @@ def enforce_library_identifier_restriction(self, identifier, patrondata): else: return False - if not self.library_identifier_restriction_type or self.library_identifier_restriction_type == self.LIBRARY_IDENTIFIER_RESTRICTION_TYPE_NONE: + if ( + not self.library_identifier_restriction_type + or self.library_identifier_restriction_type + == self.LIBRARY_IDENTIFIER_RESTRICTION_TYPE_NONE + ): # No restriction to enforce. return patrondata @@ -1462,7 +1548,10 @@ def enforce_library_identifier_restriction(self, identifier, patrondata): # Restriction field is blank, so everything matches. return patrondata - if self.library_identifier_field.lower() == self.LIBRARY_IDENTIFIER_RESTRICTION_BARCODE: + if ( + self.library_identifier_field.lower() + == self.LIBRARY_IDENTIFIER_RESTRICTION_BARCODE + ): field = identifier else: if not patrondata.complete: @@ -1470,7 +1559,11 @@ def enforce_library_identifier_restriction(self, identifier, patrondata): patrondata = self.remote_patron_lookup(patrondata) field = patrondata.library_identifier - if self._restriction_matches(field, self.library_identifier_restriction, self.library_identifier_restriction_type): + if self._restriction_matches( + field, + self.library_identifier_restriction, + self.library_identifier_restriction_type, + ): return patrondata else: return False @@ -1561,7 +1654,6 @@ def get_credential_from_header(self, header): """ return None - def remote_patron_lookup(self, patron_or_patrondata): """Ask the remote for detailed information about a patron's account. @@ -1582,18 +1674,16 @@ def remote_patron_lookup(self, patron_or_patrondata): """ if not patron_or_patrondata: return None - if (isinstance(patron_or_patrondata, PatronData) - or isinstance(patron_or_patrondata, Patron)): - - + if isinstance(patron_or_patrondata, PatronData) or isinstance( + patron_or_patrondata, Patron + ): return patron_or_patrondata raise ValueError( - "Unexpected object %r passed into remote_patron_lookup." % - patron_or_patrondata + "Unexpected object %r passed into remote_patron_lookup." + % patron_or_patrondata ) - # BasicAuthenticationProvider defines remote_patron_lookup to call this # method and then do something additional; by default, we want the core # lookup mechanism to work the same way as AuthenticationProvider.remote_patron_lookup. @@ -1623,6 +1713,7 @@ class BasicAuthenticationProvider(AuthenticationProvider, HasSelfTests): """Verify a username/password, obtained through HTTP Basic Auth, with a remote source of truth. """ + # NOTE: Each subclass MUST define an attribute called NAME, which # is used to configure that subclass in the configuration file, # used to create the name of the log channel used by this @@ -1652,7 +1743,7 @@ class BasicAuthenticationProvider(AuthenticationProvider, HasSelfTests): DISPLAY_NAME = _("Library Barcode") AUTHENTICATION_REALM = _("Library card") FLOW_TYPE = "http://opds-spec.org/auth/basic" - NAME = 'Generic Basic Authentication provider' + NAME = "Generic Basic Authentication provider" # By default, patron identifiers can only contain alphanumerics and # a few other characters. By default, there are no restrictions on @@ -1667,15 +1758,15 @@ class BasicAuthenticationProvider(AuthenticationProvider, HasSelfTests): # Identifiers can be presumed invalid if they don't match # this regular expression. - IDENTIFIER_REGULAR_EXPRESSION = 'identifier_regular_expression' + IDENTIFIER_REGULAR_EXPRESSION = "identifier_regular_expression" # Passwords can be presumed invalid if they don't match this regular # expression. - PASSWORD_REGULAR_EXPRESSION = 'password_regular_expression' + PASSWORD_REGULAR_EXPRESSION = "password_regular_expression" # The client should prefer one keyboard over another. - IDENTIFIER_KEYBOARD = 'identifier_keyboard' - PASSWORD_KEYBOARD = 'password_keyboard' + IDENTIFIER_KEYBOARD = "identifier_keyboard" + PASSWORD_KEYBOARD = "password_keyboard" # Constants describing different types of keyboards. DEFAULT_KEYBOARD = "Default" @@ -1690,8 +1781,8 @@ class BasicAuthenticationProvider(AuthenticationProvider, HasSelfTests): # The client should use a certain string when asking for a patron's # "identifier" and "password" - IDENTIFIER_LABEL = 'identifier_label' - PASSWORD_LABEL = 'password_label' + IDENTIFIER_LABEL = "identifier_label" + PASSWORD_LABEL = "password_label" DEFAULT_IDENTIFIER_LABEL = "Barcode" DEFAULT_PASSWORD_LABEL = "PIN" @@ -1715,89 +1806,122 @@ class BasicAuthenticationProvider(AuthenticationProvider, HasSelfTests): } IDENTIFIER_BARCODE_FORMAT = "identifier_barcode_format" - BARCODE_FORMAT_CODABAR = "Codabar" # Constant defined in the extension + BARCODE_FORMAT_CODABAR = "Codabar" # Constant defined in the extension BARCODE_FORMAT_NONE = "" # These identifier and password are supposed to be valid # credentials. If there's a problem using them, there's a problem # with the authenticator or with the way we have it configured. - TEST_IDENTIFIER = 'test_identifier' - TEST_PASSWORD = 'test_password' + TEST_IDENTIFIER = "test_identifier" + TEST_PASSWORD = "test_password" TEST_IDENTIFIER_DESCRIPTION_FOR_REQUIRED_PASSWORD = _( "A valid identifier that can be used to test that patron authentication is working." ) - TEST_IDENTIFIER_DESCRIPTION_FOR_OPTIONAL_PASSWORD = _("{} {}".format( - TEST_IDENTIFIER_DESCRIPTION_FOR_REQUIRED_PASSWORD, - "An optional Test Password for this identifier can be set in the next section.", - )) + TEST_IDENTIFIER_DESCRIPTION_FOR_OPTIONAL_PASSWORD = _( + "{} {}".format( + TEST_IDENTIFIER_DESCRIPTION_FOR_REQUIRED_PASSWORD, + "An optional Test Password for this identifier can be set in the next section.", + ) + ) TEST_PASSWORD_DESCRIPTION_REQUIRED = _("The password for the Test Identifier.") - TEST_PASSWORD_DESCRIPTION_OPTIONAL = _("The password for the Test Identifier (above, in previous section).") + TEST_PASSWORD_DESCRIPTION_OPTIONAL = _( + "The password for the Test Identifier (above, in previous section)." + ) SETTINGS = [ - { "key": TEST_IDENTIFIER, - "label": _("Test Identifier"), - "description": TEST_IDENTIFIER_DESCRIPTION_FOR_OPTIONAL_PASSWORD, - "required": True, + { + "key": TEST_IDENTIFIER, + "label": _("Test Identifier"), + "description": TEST_IDENTIFIER_DESCRIPTION_FOR_OPTIONAL_PASSWORD, + "required": True, }, - { "key": TEST_PASSWORD, - "label": _("Test Password"), - "description": TEST_PASSWORD_DESCRIPTION_OPTIONAL, + { + "key": TEST_PASSWORD, + "label": _("Test Password"), + "description": TEST_PASSWORD_DESCRIPTION_OPTIONAL, }, - { "key" : IDENTIFIER_BARCODE_FORMAT, - "label": _("Patron identifier barcode format"), - "description": _("Many libraries render patron identifiers as barcodes on physical library cards. If you specify the barcode format, patrons will be able to scan their library cards with a camera instead of manually typing in their identifiers."), - "type": "select", - "options": [ - { "key": BARCODE_FORMAT_CODABAR, "label": _("Patron identifiers are are rendered as barcodes in Codabar format") }, - { "key": BARCODE_FORMAT_NONE, "label": _("Patron identifiers are not rendered as barcodes") }, - ], - "default": BARCODE_FORMAT_NONE, - "required": True, + { + "key": IDENTIFIER_BARCODE_FORMAT, + "label": _("Patron identifier barcode format"), + "description": _( + "Many libraries render patron identifiers as barcodes on physical library cards. If you specify the barcode format, patrons will be able to scan their library cards with a camera instead of manually typing in their identifiers." + ), + "type": "select", + "options": [ + { + "key": BARCODE_FORMAT_CODABAR, + "label": _( + "Patron identifiers are are rendered as barcodes in Codabar format" + ), + }, + { + "key": BARCODE_FORMAT_NONE, + "label": _("Patron identifiers are not rendered as barcodes"), + }, + ], + "default": BARCODE_FORMAT_NONE, + "required": True, }, - { "key": IDENTIFIER_REGULAR_EXPRESSION, - "label": _("Identifier Regular Expression"), - "description": _("A patron's identifier will be immediately rejected if it doesn't match this regular expression."), + { + "key": IDENTIFIER_REGULAR_EXPRESSION, + "label": _("Identifier Regular Expression"), + "description": _( + "A patron's identifier will be immediately rejected if it doesn't match this regular expression." + ), }, - { "key": PASSWORD_REGULAR_EXPRESSION, - "label": _("Password Regular Expression"), - "description": _("A patron's password will be immediately rejected if it doesn't match this regular expression."), + { + "key": PASSWORD_REGULAR_EXPRESSION, + "label": _("Password Regular Expression"), + "description": _( + "A patron's password will be immediately rejected if it doesn't match this regular expression." + ), }, - { "key": IDENTIFIER_KEYBOARD, - "label": _("Keyboard for identifier entry"), - "type": "select", - "options": [ - { "key": DEFAULT_KEYBOARD, "label": _("System default") }, - { "key": EMAIL_ADDRESS_KEYBOARD, - "label": _("Email address entry") }, - { "key": NUMBER_PAD, "label": _("Number pad") }, - ], - "default": DEFAULT_KEYBOARD, - "required": True, + { + "key": IDENTIFIER_KEYBOARD, + "label": _("Keyboard for identifier entry"), + "type": "select", + "options": [ + {"key": DEFAULT_KEYBOARD, "label": _("System default")}, + {"key": EMAIL_ADDRESS_KEYBOARD, "label": _("Email address entry")}, + {"key": NUMBER_PAD, "label": _("Number pad")}, + ], + "default": DEFAULT_KEYBOARD, + "required": True, }, - { "key": PASSWORD_KEYBOARD, - "label": _("Keyboard for password entry"), - "type": "select", - "options": [ - { "key": DEFAULT_KEYBOARD, "label": _("System default") }, - { "key": NUMBER_PAD, "label": _("Number pad") }, - { "key": NULL_KEYBOARD, "label": _("Patrons have no password and should not be prompted for one.") }, - ], - "default": DEFAULT_KEYBOARD + { + "key": PASSWORD_KEYBOARD, + "label": _("Keyboard for password entry"), + "type": "select", + "options": [ + {"key": DEFAULT_KEYBOARD, "label": _("System default")}, + {"key": NUMBER_PAD, "label": _("Number pad")}, + { + "key": NULL_KEYBOARD, + "label": _( + "Patrons have no password and should not be prompted for one." + ), + }, + ], + "default": DEFAULT_KEYBOARD, }, - { "key": IDENTIFIER_MAXIMUM_LENGTH, - "label": _("Maximum identifier length"), - "type": "number", + { + "key": IDENTIFIER_MAXIMUM_LENGTH, + "label": _("Maximum identifier length"), + "type": "number", }, - { "key": PASSWORD_MAXIMUM_LENGTH, - "label": _("Maximum password length"), - "type": "number", + { + "key": PASSWORD_MAXIMUM_LENGTH, + "label": _("Maximum password length"), + "type": "number", }, - { "key": IDENTIFIER_LABEL, - "label": _("Label for identifier entry"), + { + "key": IDENTIFIER_LABEL, + "label": _("Label for identifier entry"), }, - { "key": PASSWORD_LABEL, - "label": _("Label for password entry"), + { + "key": PASSWORD_LABEL, + "label": _("Label for password entry"), }, ] + AuthenticationProvider.SETTINGS @@ -1819,41 +1943,46 @@ def __init__(self, library, integration, analytics=None): object! It's associated with a scoped database session. Just pull normal Python objects out of it. """ - super(BasicAuthenticationProvider, self).__init__(library, integration, analytics) - identifier_regular_expression = integration.setting( - self.IDENTIFIER_REGULAR_EXPRESSION - ).value or self.DEFAULT_IDENTIFIER_REGULAR_EXPRESSION + super(BasicAuthenticationProvider, self).__init__( + library, integration, analytics + ) + identifier_regular_expression = ( + integration.setting(self.IDENTIFIER_REGULAR_EXPRESSION).value + or self.DEFAULT_IDENTIFIER_REGULAR_EXPRESSION + ) if identifier_regular_expression: - identifier_regular_expression = re.compile( - identifier_regular_expression - ) + identifier_regular_expression = re.compile(identifier_regular_expression) self.identifier_re = identifier_regular_expression - password_regular_expression = integration.setting( - self.PASSWORD_REGULAR_EXPRESSION - ).value or self.DEFAULT_PASSWORD_REGULAR_EXPRESSION + password_regular_expression = ( + integration.setting(self.PASSWORD_REGULAR_EXPRESSION).value + or self.DEFAULT_PASSWORD_REGULAR_EXPRESSION + ) if password_regular_expression: - password_regular_expression = re.compile( - password_regular_expression - ) + password_regular_expression = re.compile(password_regular_expression) self.password_re = password_regular_expression self.test_username = integration.setting(self.TEST_IDENTIFIER).value self.test_password = integration.setting(self.TEST_PASSWORD).value self.identifier_maximum_length = integration.setting( - self.IDENTIFIER_MAXIMUM_LENGTH).int_value + self.IDENTIFIER_MAXIMUM_LENGTH + ).int_value self.password_maximum_length = integration.setting( - self.PASSWORD_MAXIMUM_LENGTH).int_value - self.identifier_keyboard = integration.setting( - self.IDENTIFIER_KEYBOARD).value or self.DEFAULT_KEYBOARD - self.password_keyboard = integration.setting( - self.PASSWORD_KEYBOARD).value or self.DEFAULT_KEYBOARD + self.PASSWORD_MAXIMUM_LENGTH + ).int_value + self.identifier_keyboard = ( + integration.setting(self.IDENTIFIER_KEYBOARD).value or self.DEFAULT_KEYBOARD + ) + self.password_keyboard = ( + integration.setting(self.PASSWORD_KEYBOARD).value or self.DEFAULT_KEYBOARD + ) - self.identifier_barcode_format = integration.setting( - self.IDENTIFIER_BARCODE_FORMAT - ).value or self.BARCODE_FORMAT_NONE + self.identifier_barcode_format = ( + integration.setting(self.IDENTIFIER_BARCODE_FORMAT).value + or self.BARCODE_FORMAT_NONE + ) self.identifier_label = ( integration.setting(self.IDENTIFIER_LABEL).value @@ -1870,7 +1999,9 @@ def remote_patron_lookup(self, patron_or_patrondata): patron_info = self._remote_patron_lookup(patron_or_patrondata) if patron_info: - return self.enforce_library_identifier_restriction(patron_info.authorization_identifier, patron_info) + return self.enforce_library_identifier_restriction( + patron_info.authorization_identifier, patron_info + ) else: return patron_info @@ -1899,28 +2030,31 @@ def testing_patron_or_bust(self, _db): :return: A 2-tuple (Patron, password) """ if self.test_username is None: - raise CannotLoadConfiguration( - "No test patron identifier is configured." - ) + raise CannotLoadConfiguration("No test patron identifier is configured.") patron, password = self.testing_patron(_db) if isinstance(patron, Patron): return patron, password if not patron: - message = ( + message = ( "Remote declined to authenticate the test patron.", - "The patron may not exist or its password may be wrong." + "The patron may not exist or its password may be wrong.", ) elif isinstance(patron, ProblemDetail): - message = "Test patron lookup returned a problem detail - {}: {} ({})".format( + message = ( + "Test patron lookup returned a problem detail - {}: {} ({})".format( patron.title, patron.detail, patron.uri + ) ) else: - message = "Test patron lookup returned invalid value for patron: {!r}".format(patron) + message = ( + "Test patron lookup returned invalid value for patron: {!r}".format( + patron + ) + ) raise IntegrationException(message) - def _run_self_tests(self, _db): """Verify the credentials of the test patron for this integration, and update its metadata. @@ -1936,8 +2070,7 @@ def _run_self_tests(self, _db): patron, password = patron_test.result yield self.run_test( - "Syncing patron metadata", self.update_patron_metadata, - patron + "Syncing patron metadata", self.update_patron_metadata, patron ) def scrub_credential(self, value): @@ -1954,16 +2087,15 @@ def authenticate(self, _db, credentials): :return: A Patron if one can be authenticated; a ProblemDetail if an error occurs; None if the credentials are missing or wrong. """ - username = self.scrub_credential(credentials.get('username')) - password = self.scrub_credential(credentials.get('password')) - server_side_validation_result = self.server_side_validation( - username, password - ) + username = self.scrub_credential(credentials.get("username")) + password = self.scrub_credential(credentials.get("password")) + server_side_validation_result = self.server_side_validation(username, password) if not server_side_validation_result: # False => None server_side_validation_result = None - if (not server_side_validation_result - or isinstance(server_side_validation_result, ProblemDetail)): + if not server_side_validation_result or isinstance( + server_side_validation_result, ProblemDetail + ): # The credentials are prima facie invalid and do not # need to be checked with the source of truth. return server_side_validation_result @@ -2073,7 +2205,7 @@ def get_credential_from_header(self, header): """ if not isinstance(header, dict): return None - return header.get('password', None) + return header.get("password", None) def server_side_validation(self, username, password): """Do these credentials even look right? @@ -2083,20 +2215,28 @@ def server_side_validation(self, username, password): """ valid = True if self.identifier_re: - valid = valid and username is not None and ( - self.identifier_re.match(username) is not None + valid = ( + valid + and username is not None + and (self.identifier_re.match(username) is not None) ) if not self.collects_password: # The only legal password is an empty one. - valid = valid and password in (None, '') + valid = valid and password in (None, "") else: if self.password_re: - valid = valid and password is not None and ( - self.password_re.match(password) is not None + valid = ( + valid + and password is not None + and (self.password_re.match(password) is not None) ) if self.password_maximum_length: - valid = valid and password and (len(password) <= self.password_maximum_length) + valid = ( + valid + and password + and (len(password) <= self.password_maximum_length) + ) if self.identifier_maximum_length: valid = valid and (len(username) <= self.identifier_maximum_length) @@ -2144,29 +2284,22 @@ def local_patron_lookup(self, _db, username, patrondata): # Permanent ID is the most reliable way of identifying # a patron, since this is supposed to be an internal # ID that never changes. - lookups.append( - dict(external_identifier=patrondata.permanent_id) - ) + lookups.append(dict(external_identifier=patrondata.permanent_id)) if patrondata.username: # Username is fairly reliable, since the patron # generally has to decide to change it. - lookups.append( - dict(username=patrondata.username) - ) + lookups.append(dict(username=patrondata.username)) if patrondata.authorization_identifier: # Authorization identifiers change all the time so # they're not terribly reliable. lookups.append( - dict( - authorization_identifier= - patrondata.authorization_identifier - ) + dict(authorization_identifier=patrondata.authorization_identifier) ) patron = None for lookup in lookups: - lookup['library_id'] = self.library_id + lookup["library_id"] = self.library_id patron = get_one(_db, Patron, **lookup) if patron: # We found them! @@ -2183,10 +2316,15 @@ def local_patron_lookup(self, _db, username, patrondata): # undefined which Patron is returned from this query. If # this happens, it's a problem with the ILS and needs to # be resolved over there. - clause = or_(Patron.authorization_identifier==username, - Patron.username==username) - qu = _db.query(Patron).filter(clause).filter( - Patron.library_id==self.library_id).limit(1) + clause = or_( + Patron.authorization_identifier == username, Patron.username == username + ) + qu = ( + _db.query(Patron) + .filter(clause) + .filter(Patron.library_id == self.library_id) + .limit(1) + ) try: patron = qu.one() except NoResultFound: @@ -2204,35 +2342,41 @@ def _authentication_flow_document(self, _db): login_inputs = dict(keyboard=self.identifier_keyboard) if self.identifier_maximum_length: - login_inputs['maximum_length'] = self.identifier_maximum_length + login_inputs["maximum_length"] = self.identifier_maximum_length if self.identifier_barcode_format: - login_inputs['barcode_format'] = self.identifier_barcode_format + login_inputs["barcode_format"] = self.identifier_barcode_format password_inputs = dict(keyboard=self.password_keyboard) if self.password_maximum_length: - password_inputs['maximum_length'] = self.password_maximum_length + password_inputs["maximum_length"] = self.password_maximum_length # Localize the labels if possible. localized_identifier_label = self.COMMON_IDENTIFIER_LABELS.get( - self.identifier_label, - self.identifier_label + self.identifier_label, self.identifier_label ) localized_password_label = self.COMMON_PASSWORD_LABELS.get( - self.password_label, - self.password_label + self.password_label, self.password_label ) flow_doc = dict( description=str(self.DISPLAY_NAME), - labels=dict(login=str(localized_identifier_label), - password=str(localized_password_label)), - inputs = dict(login=login_inputs, - password=password_inputs) + labels=dict( + login=str(localized_identifier_label), + password=str(localized_password_label), + ), + inputs=dict(login=login_inputs, password=password_inputs), ) if self.LOGIN_BUTTON_IMAGE: # TODO: I'm not sure if logo is appropriate for this, since it's a button # with the logo on it rather than a plain logo. Perhaps we should use plain # logos instead. - flow_doc["links"] = [dict(rel="logo", href=url_for("static_image", filename=self.LOGIN_BUTTON_IMAGE, _external=True))] + flow_doc["links"] = [ + dict( + rel="logo", + href=url_for( + "static_image", filename=self.LOGIN_BUTTON_IMAGE, _external=True + ), + ) + ] return flow_doc @@ -2253,9 +2397,7 @@ def bearer_token_signing_secret(cls, db): :return: ConfigurationSetting object containing the signing secret :rtype: ConfigurationSetting """ - return ConfigurationSetting.sitewide_secret( - db, cls.BEARER_TOKEN_SIGNING_SECRET - ) + return ConfigurationSetting.sitewide_secret(db, cls.BEARER_TOKEN_SIGNING_SECRET) class OAuthAuthenticationProvider(AuthenticationProvider, BearerTokenSigner): @@ -2305,13 +2447,17 @@ class OAuthAuthenticationProvider(AuthenticationProvider, BearerTokenSigner): # After verifying the patron's OAuth credentials, we send them a # token. This configuration setting controls how long they can use # that token before we check their OAuth credentials again. - OAUTH_TOKEN_EXPIRATION_DAYS = 'token_expiration_days' + OAUTH_TOKEN_EXPIRATION_DAYS = "token_expiration_days" # This is the default value for that configuration setting. DEFAULT_TOKEN_EXPIRATION_DAYS = 42 SETTINGS = [ - { "key": OAUTH_TOKEN_EXPIRATION_DAYS, "type": "number", "label": _("Days until OAuth token expires") }, + { + "key": OAUTH_TOKEN_EXPIRATION_DAYS, + "type": "number", + "label": _("Days until OAuth token expires"), + }, ] + AuthenticationProvider.SETTINGS def __init__(self, library, integration, analytics=None): @@ -2338,9 +2484,10 @@ def __init__(self, library, integration, analytics=None): ) self.client_id = integration.username self.client_secret = integration.password - self.token_expiration_days = integration.setting( - self.OAUTH_TOKEN_EXPIRATION_DAYS - ).int_value or self.DEFAULT_TOKEN_EXPIRATION_DAYS + self.token_expiration_days = ( + integration.setting(self.OAUTH_TOKEN_EXPIRATION_DAYS).int_value + or self.DEFAULT_TOKEN_EXPIRATION_DAYS + ) def authenticated_patron(self, _db, token): """Go from an OAuth provider token to an authenticated Patron. @@ -2397,15 +2544,16 @@ def external_authenticate_url(self, state, _db): return template % arguments def external_authenticate_url_parameters(self, state, _db): - """Arguments used to fill in the template EXTERNAL_AUTHENTICATE_URL. - """ + """Arguments used to fill in the template EXTERNAL_AUTHENTICATE_URL.""" library_short_name = self.library(_db).short_name return dict( client_id=self.client_id, state=state, # When the patron finishes logging in to the OAuth provider, # we want them to send the patron to this URL. - oauth_callback_url=OAuthController.oauth_authentication_callback_url(library_short_name) + oauth_callback_url=OAuthController.oauth_authentication_callback_url( + library_short_name + ), ) def oauth_callback(self, _db, code): @@ -2488,9 +2636,12 @@ def _internal_authenticate_url(self, _db): """ library = self.library(_db) - return url_for('oauth_authenticate', _external=True, - provider=self.NAME, - library_short_name=library.short_name) + return url_for( + "oauth_authenticate", + _external=True, + provider=self.NAME, + library_short_name=library.short_name, + ) def _authentication_flow_document(self, _db): """Create a Authentication Flow object for use in an Authentication for @@ -2508,20 +2659,24 @@ def _authentication_flow_document(self, _db): """ flow_doc = dict( description=self.NAME, - links=[dict(rel="authenticate", - href=self._internal_authenticate_url(_db))] + links=[dict(rel="authenticate", href=self._internal_authenticate_url(_db))], ) if self.LOGIN_BUTTON_IMAGE: # TODO: I'm not sure if logo is appropriate for this, since it's a button # with the logo on it rather than a plain logo. Perhaps we should use plain # logos instead. - flow_doc["links"] += [dict(rel="logo", href=url_for("static_image", filename=self.LOGIN_BUTTON_IMAGE, _external=True))] + flow_doc["links"] += [ + dict( + rel="logo", + href=url_for( + "static_image", filename=self.LOGIN_BUTTON_IMAGE, _external=True + ), + ) + ] return flow_doc def token_data_source(self, _db): - return get_one_or_create( - _db, DataSource, name=self.TOKEN_DATA_SOURCE_NAME - ) + return get_one_or_create(_db, DataSource, name=self.TOKEN_DATA_SOURCE_NAME) class OAuthController(object): @@ -2542,7 +2697,9 @@ def oauth_authentication_callback_url(cls, library_short_name): provider to demonstrate that it knows which URL a patron was redirected to. """ - return url_for('oauth_callback', library_short_name=library_short_name, _external=True) + return url_for( + "oauth_callback", library_short_name=library_short_name, _external=True + ) def oauth_authentication_redirect(self, params, _db): """Redirect an unauthenticated patron to the authentication URL of the @@ -2552,14 +2709,12 @@ def oauth_authentication_redirect(self, params, _db): redirected back to the circulation manager, ending up in oauth_authentication_callback. """ - redirect_uri = params.get('redirect_uri', '') - provider_name = params.get('provider') + redirect_uri = params.get("redirect_uri", "") + provider_name = params.get("provider") provider = self.authenticator.oauth_provider_lookup(provider_name) if isinstance(provider, ProblemDetail): return self._redirect_with_error(redirect_uri, provider) - state = dict( - provider=provider.NAME, redirect_uri=redirect_uri - ) + state = dict(provider=provider.NAME, redirect_uri=redirect_uri) state = json.dumps(state) state = urllib.parse.quote(state) return redirect(provider.external_authenticate_url(state, _db)) @@ -2584,14 +2739,14 @@ def oauth_authentication_callback(self, _db, params): start using it as a bearer token, and make sense of the patron_info. """ - code = params.get('code') - state = params.get('state') + code = params.get("code") + state = params.get("state") if not code or not state: return INVALID_OAUTH_CALLBACK_PARAMETERS state = json.loads(urllib.parse.unquote(state)) - client_redirect_uri = state.get('redirect_uri') or "" - provider_name = state.get('provider') + client_redirect_uri = state.get("redirect_uri") or "" + provider_name = state.get("provider") provider = self.authenticator.oauth_provider_lookup(provider_name) if isinstance(provider, ProblemDetail): return self._redirect_with_error(client_redirect_uri, provider) @@ -2606,9 +2761,7 @@ def oauth_authentication_callback(self, _db, params): if isinstance(response, ProblemDetail): # Most likely the OAuth provider didn't like the credentials # we sent. - return self._redirect_with_error( - client_redirect_uri, response - ) + return self._redirect_with_error(client_redirect_uri, response) provider_token, patron, patrondata = response # Turn the provider token into a bearer token we can give to @@ -2618,10 +2771,7 @@ def oauth_authentication_callback(self, _db, params): ) patron_info = json.dumps(patrondata.to_response_parameters) - params = dict( - access_token=simplified_token, - patron_info=patron_info - ) + params = dict(access_token=simplified_token, patron_info=patron_info) return redirect(client_redirect_uri + "#" + urllib.parse.urlencode(params)) def _redirect_with_error(self, redirect_uri, pd): @@ -2635,28 +2785,29 @@ def _error_uri(self, redirect_uri, pd): of the given URI. """ problem_detail_json = pd_json( - pd.uri, pd.status_code, pd.title, pd.detail, pd.instance, - pd.debug_message + pd.uri, pd.status_code, pd.title, pd.detail, pd.instance, pd.debug_message ) params = dict(error=problem_detail_json) return redirect_uri + "#" + urllib.parse.urlencode(params) -class BaseSAMLAuthenticationProvider(AuthenticationProvider, BearerTokenSigner, metaclass=ABCMeta): +class BaseSAMLAuthenticationProvider( + AuthenticationProvider, BearerTokenSigner, metaclass=ABCMeta +): """ Base class for SAML authentication providers """ - NAME = 'SAML 2.0' + NAME = "SAML 2.0" - DESCRIPTION = _('SAML 2.0 authentication provider') + DESCRIPTION = _("SAML 2.0 authentication provider") DISPLAY_NAME = NAME - FLOW_TYPE = 'http://librarysimplified.org/authtype/SAML-2.0' + FLOW_TYPE = "http://librarysimplified.org/authtype/SAML-2.0" TOKEN_TYPE = "SAML 2.0 token" - TOKEN_DATA_SOURCE_NAME = 'SAML 2.0' + TOKEN_DATA_SOURCE_NAME = "SAML 2.0" SETTINGS = SAMLSettings() diff --git a/api/axis.py b/api/axis.py index f4a097a9e3..0ede3ef4d8 100644 --- a/api/axis.py +++ b/api/axis.py @@ -68,7 +68,9 @@ class Axis360APIConstants: VERIFY_SSL = "verify_certificate" -class Axis360API(Authenticator, BaseCirculationAPI, HasCollectionSelfTests, Axis360APIConstants): +class Axis360API( + Authenticator, BaseCirculationAPI, HasCollectionSelfTests, Axis360APIConstants +): NAME = ExternalIntegration.AXIS_360 @@ -78,40 +80,43 @@ class Axis360API(Authenticator, BaseCirculationAPI, HasCollectionSelfTests, Axis PRODUCTION_BASE_URL = "https://axis360api.baker-taylor.com/Services/VendorAPI/" QA_BASE_URL = "http://axis360apiqa.baker-taylor.com/Services/VendorAPI/" SERVER_NICKNAMES = { - "production" : PRODUCTION_BASE_URL, - "qa" : QA_BASE_URL, + "production": PRODUCTION_BASE_URL, + "qa": QA_BASE_URL, } DATE_FORMAT = "%m-%d-%Y %H:%M:%S" SETTINGS = [ - { "key": ExternalIntegration.USERNAME, "label": _("Username"), "required": True }, - { "key": ExternalIntegration.PASSWORD, "label": _("Password"), "required": True }, - { "key": Collection.EXTERNAL_ACCOUNT_ID_KEY, "label": _("Library ID"), "required": True }, - { "key": ExternalIntegration.URL, - "label": _("Server"), - "default": PRODUCTION_BASE_URL, - "required": True, - "format": "url", - "allowed": list(SERVER_NICKNAMES.keys()), + {"key": ExternalIntegration.USERNAME, "label": _("Username"), "required": True}, + {"key": ExternalIntegration.PASSWORD, "label": _("Password"), "required": True}, + { + "key": Collection.EXTERNAL_ACCOUNT_ID_KEY, + "label": _("Library ID"), + "required": True, + }, + { + "key": ExternalIntegration.URL, + "label": _("Server"), + "default": PRODUCTION_BASE_URL, + "required": True, + "format": "url", + "allowed": list(SERVER_NICKNAMES.keys()), }, { - "key": Axis360APIConstants.VERIFY_SSL, - "label": _("Verify SSL Certificate"), - "description": _( - "This should always be True in production, it may need to be set to False to use the" - "Axis 360 QA Environment."), - "type": "select", - "options": [ - { - "label": _("True"), - "key": "True" - }, - { - "label": _("False"), - "key": "False", - } - ], - "default": True, + "key": Axis360APIConstants.VERIFY_SSL, + "label": _("Verify SSL Certificate"), + "description": _( + "This should always be True in production, it may need to be set to False to use the" + "Axis 360 QA Environment." + ), + "type": "select", + "options": [ + {"label": _("True"), "key": "True"}, + { + "label": _("False"), + "key": "False", + }, + ], + "default": True, }, ] + BaseCirculationAPI.SETTINGS @@ -119,10 +124,10 @@ class Axis360API(Authenticator, BaseCirculationAPI, HasCollectionSelfTests, Axis BaseCirculationAPI.DEFAULT_LOAN_DURATION_SETTING ] - access_token_endpoint = 'accesstoken' - availability_endpoint = 'availability/v2' - fulfillment_endpoint = 'getfullfillmentInfo/v2' - audiobook_metadata_endpoint = 'getaudiobookmetadata/v2' + access_token_endpoint = "accesstoken" + availability_endpoint = "availability/v2" + fulfillment_endpoint = "getfullfillmentInfo/v2" + audiobook_metadata_endpoint = "getaudiobookmetadata/v2" log = logging.getLogger("Axis 360 API") @@ -140,19 +145,19 @@ class Axis360API(Authenticator, BaseCirculationAPI, HasCollectionSelfTests, Axis AXISNOW = "AxisNow" delivery_mechanism_to_internal_format = { - (epub, no_drm): 'ePub', - (epub, adobe_drm): 'ePub', - (pdf, no_drm): 'PDF', - (pdf, adobe_drm): 'PDF', - (None, findaway_drm): 'Acoustik', + (epub, no_drm): "ePub", + (epub, adobe_drm): "ePub", + (pdf, no_drm): "PDF", + (pdf, adobe_drm): "PDF", + (None, findaway_drm): "Acoustik", (None, axisnow_drm): AXISNOW, } def __init__(self, _db, collection): if collection.protocol != ExternalIntegration.AXIS_360: raise ValueError( - "Collection protocol is %s, but passed into Axis360API!" % - collection.protocol + "Collection protocol is %s, but passed into Axis360API!" + % collection.protocol ) self._db = _db self.library_id = collection.external_account_id @@ -163,26 +168,27 @@ def __init__(self, _db, collection): base_url = collection.external_integration.url or self.PRODUCTION_BASE_URL if base_url in self.SERVER_NICKNAMES: base_url = self.SERVER_NICKNAMES[base_url] - if not base_url.endswith('/'): - base_url += '/' + if not base_url.endswith("/"): + base_url += "/" self.base_url = base_url - if (not self.library_id or not self.username - or not self.password): - raise CannotLoadConfiguration( - "Axis 360 configuration is incomplete." - ) + if not self.library_id or not self.username or not self.password: + raise CannotLoadConfiguration("Axis 360 configuration is incomplete.") # Use utf8 instead of unicode encoding settings = [self.library_id, self.username, self.password] self.library_id, self.username, self.password = ( - setting.encode('utf8') for setting in settings + setting.encode("utf8") for setting in settings ) self.token = None self.collection_id = collection.id - verify_certificate = collection.external_integration.setting(self.VERIFY_SSL).bool_value - self.verify_certificate: bool = verify_certificate if verify_certificate is not None else True + verify_certificate = collection.external_integration.setting( + self.VERIFY_SSL + ).bool_value + self.verify_certificate: bool = ( + verify_certificate if verify_certificate is not None else True + ) @property def collection(self): @@ -203,9 +209,7 @@ def external_integration(self, _db): return self.collection.external_integration def _run_self_tests(self, _db): - result = self.run_test( - "Refreshing bearer token", self.refresh_bearer_token - ) + result = self.run_test("Refreshing bearer token", self.refresh_bearer_token) yield result if not result.success: # If we can't get a bearer token, there's no point running @@ -219,8 +223,7 @@ def _count_events(): return "Found %d event(s)" % count yield self.run_test( - "Asking for circulation events for the last five minutes", - _count_events + "Asking for circulation events for the last five minutes", _count_events ) for result in self.default_patrons(self.collection): @@ -228,12 +231,14 @@ def _count_events(): yield result continue library, patron, pin = result + def _count_activity(): result = self.patron_activity(patron, pin) return "Found %d loans/holds" % len(result) + yield self.run_test( "Checking activity for test patron for library %s" % library.name, - _count_activity + _count_activity, ) # Run the tests defined by HasCollectionSelfTests @@ -244,27 +249,38 @@ def refresh_bearer_token(self): url = self.base_url + self.access_token_endpoint headers = self.authorization_headers response = self._make_request( - url, 'post', headers, allowed_response_codes=[200] + url, "post", headers, allowed_response_codes=[200] ) return self.parse_token(response.content) - def request(self, url, method='get', extra_headers={}, data=None, - params=None, exception_on_401=False, **kwargs): + def request( + self, + url, + method="get", + extra_headers={}, + data=None, + params=None, + exception_on_401=False, + **kwargs + ): """Make an HTTP request, acquiring/refreshing a bearer token if necessary. """ if not self.token: self.token = self.refresh_bearer_token() headers = dict(extra_headers) - headers['Authorization'] = "Bearer " + self.token - headers['Library'] = self.library_id + headers["Authorization"] = "Bearer " + self.token + headers["Library"] = self.library_id if exception_on_401: disallowed_response_codes = ["401"] else: disallowed_response_codes = None response = self._make_request( - url=url, method=method, headers=headers, - data=data, params=params, + url=url, + method=method, + headers=headers, + data=data, + params=params, disallowed_response_codes=disallowed_response_codes, **kwargs ) @@ -275,8 +291,12 @@ def request(self, url, method='get', extra_headers={}, data=None, # The token has expired. Get a new token and try again. self.token = None return self.request( - url=url, method=method, extra_headers=extra_headers, - data=data, params=params, exception_on_401=True, + url=url, + method=method, + extra_headers=extra_headers, + data=data, + params=params, + exception_on_401=True, **kwargs ) else: @@ -287,11 +307,11 @@ def availability(self, patron_id=None, since=None, title_ids=[]): args = dict() if since: since = since.strftime(self.DATE_FORMAT) - args['updatedDate'] = since + args["updatedDate"] = since if patron_id: - args['patronId'] = patron_id + args["patronId"] = patron_id if title_ids: - args['titleIds'] = ','.join(title_ids) + args["titleIds"] = ",".join(title_ids) response = self.request(url, params=args, timeout=None) return response @@ -322,19 +342,17 @@ def checkin(self, patron, pin, licensepool): patron_id = patron.authorization_identifier response = self._checkin(title_id, patron_id) try: - return CheckinResponseParser( - licensepool.collection - ).process_all(response.content) - except etree.XMLSyntaxError as e: - raise RemoteInitiatedServerError( - response.content, self.SERVICE_NAME + return CheckinResponseParser(licensepool.collection).process_all( + response.content ) + except etree.XMLSyntaxError as e: + raise RemoteInitiatedServerError(response.content, self.SERVICE_NAME) def _checkin(self, title_id, patron_id): """Make a request to the EarlyCheckInTitle endpoint.""" url = self.base_url + "EarlyCheckInTitle/v3?itemID=%s&patronID=%s" % ( urllib.parse.quote(title_id), - urllib.parse.quote(patron_id) + urllib.parse.quote(patron_id), ) return self.request(url, method="GET", verbose=True) @@ -343,17 +361,15 @@ def checkout(self, patron, pin, licensepool, internal_format): patron_id = patron.authorization_identifier response = self._checkout(title_id, patron_id, internal_format) try: - return CheckoutResponseParser( - licensepool.collection).process_all(response.content) - except etree.XMLSyntaxError as e: - raise RemoteInitiatedServerError( - response.content, self.SERVICE_NAME + return CheckoutResponseParser(licensepool.collection).process_all( + response.content ) + except etree.XMLSyntaxError as e: + raise RemoteInitiatedServerError(response.content, self.SERVICE_NAME) def _checkout(self, title_id, patron_id, internal_format): url = self.base_url + "checkout/v2" - args = dict(titleId=title_id, patronId=patron_id, - format=internal_format) + args = dict(titleId=title_id, patronId=patron_id, format=internal_format) response = self.request(url, data=args, method="POST") return response @@ -367,12 +383,16 @@ def fulfill(self, patron, pin, licensepool, internal_format, **kwargs): """ identifier = licensepool.identifier # This should include only one 'activity'. - activities = self.patron_activity(patron, pin, licensepool.identifier, internal_format) + activities = self.patron_activity( + patron, pin, licensepool.identifier, internal_format + ) for loan in activities: if not isinstance(loan, LoanInfo): continue - if not (loan.identifier_type == identifier.type - and loan.identifier == identifier.identifier): + if not ( + loan.identifier_type == identifier.type + and loan.identifier == identifier.identifier + ): continue # We've found the remote loan corresponding to this # license pool. @@ -394,11 +414,13 @@ def place_hold(self, patron, pin, licensepool, hold_notification_email): identifier = licensepool.identifier title_id = identifier.identifier patron_id = patron.authorization_identifier - params = dict(titleId=title_id, patronId=patron_id, - email=hold_notification_email) + params = dict( + titleId=title_id, patronId=patron_id, email=hold_notification_email + ) response = self.request(url, params=params) hold_info = HoldResponseParser(licensepool.collection).process_all( - response.content) + response.content + ) if not hold_info.identifier: # The Axis 360 API doesn't return the identifier of the # item that was placed on hold, so we have to fill it in @@ -416,7 +438,8 @@ def release_hold(self, patron, pin, licensepool): response = self.request(url, params=params) try: HoldReleaseResponseParser(licensepool.collection).process_all( - response.content) + response.content + ) except NotOnHold: # Fine, it wasn't on hold and now it's still not on hold. pass @@ -429,10 +452,13 @@ def patron_activity(self, patron, pin, identifier=None, internal_format=None): else: title_ids = None availability = self.availability( - patron_id=patron.authorization_identifier, - title_ids=title_ids) - return list(AvailabilityResponseParser(self, internal_format).process_all( - availability.content)) + patron_id=patron.authorization_identifier, title_ids=title_ids + ) + return list( + AvailabilityResponseParser(self, internal_format).process_all( + availability.content + ) + ) def update_availability(self, licensepool): """Update the availability information for a single LicensePool. @@ -452,9 +478,7 @@ def update_licensepools_for_identifiers(self, identifiers): circulation information. """ remainder = set(identifiers) - for bibliographic, availability in self._fetch_remote_availability( - identifiers - ): + for bibliographic, availability in self._fetch_remote_availability(identifiers): edition, ignore1, license_pool, ignore2 = self.update_book( bibliographic, availability ) @@ -519,7 +543,7 @@ def _reap(self, identifier): if not pool: self.log.warn( "Was about to reap %r but no local license pool in this collection.", - identifier + identifier, ) return if pool.licenses_owned == 0: @@ -536,8 +560,7 @@ def _reap(self, identifier): patrons_in_hold_queue=0, ) availability.apply( - self._db, collection, - ReplacementPolicy.from_license_source(self._db) + self._db, collection, ReplacementPolicy.from_license_source(self._db) ) def recent_activity(self, since): @@ -547,8 +570,9 @@ def recent_activity(self, since): """ availability = self.availability(since=since) content = availability.content - for bibliographic, circulation in BibliographicParser(self.collection).process_all( - content): + for bibliographic, circulation in BibliographicParser( + self.collection + ).process_all(content): yield bibliographic, circulation @classmethod @@ -566,20 +590,19 @@ def create_identifier_strings(cls, identifiers): @classmethod def parse_token(cls, token): data = json.loads(token) - return data['access_token'] + return data["access_token"] - def _make_request(self, url, method, headers, data=None, params=None, - **kwargs): + def _make_request(self, url, method, headers, data=None, params=None, **kwargs): """Actually make an HTTP request.""" return HTTP.request_with_timeout( - method, url, headers=headers, data=data, - params=params, **kwargs + method, url, headers=headers, data=data, params=params, **kwargs ) + class Axis360CirculationMonitor(CollectionMonitor, TimelineMonitor): - """Maintain LicensePools for Axis 360 titles. - """ + """Maintain LicensePools for Axis 360 titles.""" + SERVICE_NAME = "Axis 360 Circulation Monitor" INTERVAL_SECONDS = 60 DEFAULT_BATCH_SIZE = 50 @@ -598,8 +621,8 @@ def __init__(self, _db, collection, api_class=Axis360API): self.api = api_class(_db, collection) self.batch_size = self.DEFAULT_BATCH_SIZE - self.bibliographic_coverage_provider = ( - Axis360BibliographicCoverageProvider(collection, api_class=self.api) + self.bibliographic_coverage_provider = Axis360BibliographicCoverageProvider( + collection, api_class=self.api ) def catch_up_from(self, start, cutoff, progress): @@ -626,10 +649,7 @@ def process_book(self, bibliographic, circulation): # work has been done so we don't have to do it again. identifier = edition.primary_identifier self.bibliographic_coverage_provider.handle_success(identifier) - self.bibliographic_coverage_provider.add_coverage_record_for( - identifier - ) - + self.bibliographic_coverage_provider.add_coverage_record_for(identifier) return edition, license_pool @@ -640,17 +660,18 @@ def mock_collection(cls, _db, name="Test Axis 360 Collection"): """Create a mock Axis 360 collection for use in tests.""" library = DatabaseTest.make_default_library(_db) collection, ignore = get_one_or_create( - _db, Collection, + _db, + Collection, name=name, create_method_kwargs=dict( - external_account_id='c', - ) + external_account_id="c", + ), ) integration = collection.create_external_integration( protocol=ExternalIntegration.AXIS_360 ) - integration.username = 'a' - integration.password = 'b' + integration.username = "a" + integration.password = "b" integration.url = "http://axis.test/" library.collections.append(collection) return collection @@ -673,18 +694,20 @@ def __init__(self, _db, collection, with_token=True, **kwargs): def queue_response(self, status_code, headers={}, content=None): from core.testing import MockRequestsResponse - self.responses.insert( - 0, MockRequestsResponse(status_code, headers, content) - ) + + self.responses.insert(0, MockRequestsResponse(status_code, headers, content)) def _make_request(self, url, *args, **kwargs): self.requests.append([url, args, kwargs]) response = self.responses.pop() return HTTP._process_response( - url, response, kwargs.get('allowed_response_codes'), - kwargs.get('disallowed_response_codes') + url, + response, + kwargs.get("allowed_response_codes"), + kwargs.get("disallowed_response_codes"), ) + class Axis360BibliographicCoverageProvider(BibliographicCoverageProvider): """Fill in bibliographic metadata for Axis 360 records. @@ -708,9 +731,7 @@ def __init__(self, collection, api_class=Axis360API, **kwargs): :param api_class: Instantiate this class with the given Collection, rather than instantiating Axis360API. """ - super(Axis360BibliographicCoverageProvider, self).__init__( - collection, **kwargs - ) + super(Axis360BibliographicCoverageProvider, self).__init__(collection, **kwargs) if isinstance(api_class, Axis360API): # We were given a specific Axis360API instance to use. self.api = api_class @@ -759,12 +780,14 @@ def process_item(self, identifier): results = self.process_batch([identifier]) return results[0] + class AxisCollectionReaper(IdentifierSweepMonitor): """Check for books that are in the local collection but have left our Axis 360 collection. """ + SERVICE_NAME = "Axis Collection Reaper" - INTERVAL_SECONDS = 3600*12 + INTERVAL_SECONDS = 3600 * 12 PROTOCOL = ExternalIntegration.AXIS_360 def __init__(self, _db, collection, api_class=Axis360API): @@ -802,7 +825,7 @@ def _xpath1_boolean(self, e, target, ns, default=False): text = self.text_of_optional_subtag(e, target, ns) if text is None: return default - if text == 'true': + if text == "true": return True else: return False @@ -811,14 +834,15 @@ def _xpath1_date(self, e, target, ns): value = self.text_of_optional_subtag(e, target, ns) return self._pd(value) + class BibliographicParser(Axis360Parser): DELIVERY_DATA_FOR_AXIS_FORMAT = { - "Blio" : None, # Legacy format, handled the same way as AxisNow - "Acoustik" : (None, DeliveryMechanism.FINDAWAY_DRM), # Audiobooks - "AxisNow" : None, # Handled specially, for ebooks only. - "ePub" : (Representation.EPUB_MEDIA_TYPE, DeliveryMechanism.ADOBE_DRM), - "PDF" : (Representation.PDF_MEDIA_TYPE, DeliveryMechanism.ADOBE_DRM), + "Blio": None, # Legacy format, handled the same way as AxisNow + "Acoustik": (None, DeliveryMechanism.FINDAWAY_DRM), # Audiobooks + "AxisNow": None, # Handled specially, for ebooks only. + "ePub": (Representation.EPUB_MEDIA_TYPE, DeliveryMechanism.ADOBE_DRM), + "PDF": (Representation.PDF_MEDIA_TYPE, DeliveryMechanism.ADOBE_DRM), } log = logging.getLogger("Axis 360 Bibliographic Parser") @@ -838,11 +862,12 @@ def __init__(self, include_availability=True, include_bibliographic=True): def process_all(self, string): for i in super(BibliographicParser, self).process_all( - string, "//axis:title", self.NS): + string, "//axis:title", self.NS + ): yield i def extract_availability(self, circulation_data, element, ns): - identifier = self.text_of_subtag(element, 'axis:titleId', ns) + identifier = self.text_of_subtag(element, "axis:titleId", ns) primary_identifier = IdentifierData(Identifier.AXIS_360_ID, identifier) if not circulation_data: circulation_data = CirculationData( @@ -850,27 +875,25 @@ def extract_availability(self, circulation_data, element, ns): primary_identifier=primary_identifier, ) - availability = self._xpath1(element, 'axis:availability', ns) - total_copies = self.int_of_subtag(availability, 'axis:totalCopies', ns) - available_copies = self.int_of_subtag( - availability, 'axis:availableCopies', ns) - size_of_hold_queue = self.int_of_subtag( - availability, 'axis:holdsQueueSize', ns) + availability = self._xpath1(element, "axis:availability", ns) + total_copies = self.int_of_subtag(availability, "axis:totalCopies", ns) + available_copies = self.int_of_subtag(availability, "axis:availableCopies", ns) + size_of_hold_queue = self.int_of_subtag(availability, "axis:holdsQueueSize", ns) availability_updated = self.text_of_optional_subtag( - availability, 'axis:updateDate', ns) + availability, "axis:updateDate", ns + ) if availability_updated: # NOTE: We don't actually do anything with this. availability_updated = self._pd(availability_updated) - circulation_data.licenses_owned=total_copies - circulation_data.licenses_available=available_copies - circulation_data.licenses_reserved=0 - circulation_data.patrons_in_hold_queue=size_of_hold_queue + circulation_data.licenses_owned = total_copies + circulation_data.licenses_available = available_copies + circulation_data.licenses_reserved = 0 + circulation_data.patrons_in_hold_queue = size_of_hold_queue return circulation_data - # Axis authors with a special role have an abbreviation after their names, # e.g. "San Ruby (FRW)" role_abbreviation = re.compile("\(([A-Z][A-Z][A-Z])\)$") @@ -882,13 +905,12 @@ def extract_availability(self, circulation_data, element, ns): ILT=Contributor.ILLUSTRATOR_ROLE, TRN=Contributor.TRANSLATOR_ROLE, FRW=Contributor.FOREWORD_ROLE, - ADP=generic_author, # Author of adaptation - COR=generic_author, # Corporate author + ADP=generic_author, # Author of adaptation + COR=generic_author, # Corporate author ) @classmethod - def parse_contributor(cls, author, primary_author_found=False, - force_role=None): + def parse_contributor(cls, author, primary_author_found=False, force_role=None): """Parse an Axis 360 contributor string. The contributor string looks like "Butler, Octavia" or "Walt @@ -917,15 +939,14 @@ def parse_contributor(cls, author, primary_author_found=False, if match: role_type = match.groups()[0] role = cls.role_abbreviation_to_role.get( - role_type, Contributor.UNKNOWN_ROLE) + role_type, Contributor.UNKNOWN_ROLE + ) if role is cls.generic_author: role = default_author_role author = author[:-5].strip() if force_role: role = force_role - return ContributorData( - sort_name=author, roles=[role] - ) + return ContributorData(sort_name=author, roles=[role]) def extract_bibliographic(self, element, ns): """Turn bibliographic metadata into a Metadata and a CirculationData objects, @@ -938,25 +959,21 @@ def extract_bibliographic(self, element, ns): # edition # runtime - identifier = self.text_of_subtag(element, 'axis:titleId', ns) - isbn = self.text_of_optional_subtag(element, 'axis:isbn', ns) - title = self.text_of_subtag(element, 'axis:productTitle', ns) + identifier = self.text_of_subtag(element, "axis:titleId", ns) + isbn = self.text_of_optional_subtag(element, "axis:isbn", ns) + title = self.text_of_subtag(element, "axis:productTitle", ns) - contributor = self.text_of_optional_subtag( - element, 'axis:contributor', ns) + contributor = self.text_of_optional_subtag(element, "axis:contributor", ns) contributors = [] found_primary_author = False if contributor: for c in self.parse_list(contributor): - contributor = self.parse_contributor( - c, found_primary_author) + contributor = self.parse_contributor(c, found_primary_author) if Contributor.PRIMARY_AUTHOR_ROLE in contributor.roles: found_primary_author = True contributors.append(contributor) - narrator = self.text_of_optional_subtag( - element, 'axis:narrator', ns - ) + narrator = self.text_of_optional_subtag(element, "axis:narrator", ns) if narrator: for n in self.parse_list(narrator): contributor = self.parse_contributor( @@ -965,9 +982,7 @@ def extract_bibliographic(self, element, ns): contributors.append(contributor) links = [] - description = self.text_of_optional_subtag( - element, 'axis:annotation', ns - ) + description = self.text_of_optional_subtag(element, "axis:annotation", ns) if description: links.append( LinkData( @@ -977,30 +992,30 @@ def extract_bibliographic(self, element, ns): ) ) - subject = self.text_of_optional_subtag(element, 'axis:subject', ns) + subject = self.text_of_optional_subtag(element, "axis:subject", ns) subjects = [] if subject: for subject_identifier in self.parse_list(subject): subjects.append( SubjectData( - type=Subject.BISAC, identifier=None, + type=Subject.BISAC, + identifier=None, name=subject_identifier, - weight=Classification.TRUSTED_DISTRIBUTOR_WEIGHT + weight=Classification.TRUSTED_DISTRIBUTOR_WEIGHT, ) ) publication_date = self.text_of_optional_subtag( - element, 'axis:publicationDate', ns) + element, "axis:publicationDate", ns + ) if publication_date: - publication_date = strptime_utc( - publication_date, self.SHORT_DATE_FORMAT - ) + publication_date = strptime_utc(publication_date, self.SHORT_DATE_FORMAT) - series = self.text_of_optional_subtag(element, 'axis:series', ns) - publisher = self.text_of_optional_subtag(element, 'axis:publisher', ns) - imprint = self.text_of_optional_subtag(element, 'axis:imprint', ns) + series = self.text_of_optional_subtag(element, "axis:series", ns) + publisher = self.text_of_optional_subtag(element, "axis:publisher", ns) + imprint = self.text_of_optional_subtag(element, "axis:imprint", ns) - audience = self.text_of_optional_subtag(element, 'axis:audience', ns) + audience = self.text_of_optional_subtag(element, "axis:audience", ns) if audience: subjects.append( SubjectData( @@ -1010,15 +1025,13 @@ def extract_bibliographic(self, element, ns): ) ) - language = self.text_of_subtag(element, 'axis:language', ns) + language = self.text_of_subtag(element, "axis:language", ns) - thumbnail_url = self.text_of_optional_subtag( - element, 'axis:imageUrl', ns - ) + thumbnail_url = self.text_of_optional_subtag(element, "axis:imageUrl", ns) if thumbnail_url: # We presume all images from this service are JPEGs. media_type = MediaTypes.JPEG_MEDIA_TYPE - if '/Medium/' in thumbnail_url: + if "/Medium/" in thumbnail_url: # We know about a URL hack for this service that lets us # get a larger image. full_size_url = thumbnail_url.replace("/Medium/", "/Large/") @@ -1031,13 +1044,13 @@ def extract_bibliographic(self, element, ns): thumbnail = LinkData( rel=LinkRelations.THUMBNAIL_IMAGE, href=thumbnail_url, - media_type=media_type + media_type=media_type, ) image = LinkData( rel=LinkRelations.IMAGE, href=full_size_url, media_type=media_type, - thumbnail=thumbnail + thumbnail=thumbnail, ) links.append(image) @@ -1065,8 +1078,7 @@ def extract_bibliographic(self, element, ns): blio_seen = False for format_tag in self._xpath( - element, 'axis:availability/axis:availableFormats/axis:formatName', - ns + element, "axis:availability/axis:availableFormats/axis:formatName", ns ): informal_name = format_tag.text seen_formats.append(informal_name) @@ -1082,9 +1094,10 @@ def extract_bibliographic(self, element, ns): continue if informal_name not in self.DELIVERY_DATA_FOR_AXIS_FORMAT: - self.log.warn("Unrecognized Axis format name for %s: %s" % ( - identifier, informal_name - )) + self.log.warn( + "Unrecognized Axis format name for %s: %s" + % (identifier, informal_name) + ) elif self.DELIVERY_DATA_FOR_AXIS_FORMAT.get(informal_name): content_type, drm_scheme = self.DELIVERY_DATA_FOR_AXIS_FORMAT[ informal_name @@ -1097,7 +1110,7 @@ def extract_bibliographic(self, element, ns): medium = Edition.AUDIO_MEDIUM else: medium = Edition.BOOK_MEDIUM - if (blio_seen or (axisnow_seen and medium == Edition.BOOK_MEDIUM)): + if blio_seen or (axisnow_seen and medium == Edition.BOOK_MEDIUM): # This ebook is available through AxisNow. Add an # appropriate FormatData. # @@ -1109,8 +1122,10 @@ def extract_bibliographic(self, element, ns): if not formats: self.log.error( - "No supported format for %s (%s)! Saw: %s", identifier, - title, ", ".join(seen_formats) + "No supported format for %s (%s)! Saw: %s", + identifier, + title, + ", ".join(seen_formats), ) metadata = Metadata( @@ -1138,7 +1153,6 @@ def extract_bibliographic(self, element, ns): metadata.circulation = circulationdata return metadata - def process_one(self, element, ns): if self.include_bibliographic: bibliographic = self.extract_bibliographic(element, ns) @@ -1150,12 +1164,15 @@ def process_one(self, element, ns): passed_availability = bibliographic.circulation if self.include_availability: - availability = self.extract_availability(circulation_data=passed_availability, element=element, ns=ns) + availability = self.extract_availability( + circulation_data=passed_availability, element=element, ns=ns + ) else: availability = None return bibliographic, availability + class ResponseParser(Axis360Parser): id_type = Identifier.AXIS_360_ID @@ -1164,58 +1181,58 @@ class ResponseParser(Axis360Parser): # Map Axis 360 error codes to our circulation exceptions. code_to_exception = { - 315 : InvalidInputException, # Bad password - 316 : InvalidInputException, # DRM account already exists - 1000 : PatronAuthorizationFailedException, - 1001 : PatronAuthorizationFailedException, - 1002 : PatronAuthorizationFailedException, - 1003 : PatronAuthorizationFailedException, - 2000 : LibraryAuthorizationFailedException, - 2001 : LibraryAuthorizationFailedException, - 2002 : LibraryAuthorizationFailedException, - 2003 : LibraryAuthorizationFailedException, # "Encoded input parameters exceed limit", whatever that meaus - 2004 : LibraryAuthorizationFailedException, - 2005 : LibraryAuthorizationFailedException, # Invalid credentials - 2005 : LibraryAuthorizationFailedException, # Wrong library ID - 2007 : LibraryAuthorizationFailedException, # Invalid library ID - 2008 : LibraryAuthorizationFailedException, # Invalid library ID - 3100 : LibraryInvalidInputException, # Missing title ID - 3101 : LibraryInvalidInputException, # Missing patron ID - 3102 : LibraryInvalidInputException, # Missing email address (for hold notification) - 3103 : NotFoundOnRemote, # Invalid title ID - 3104 : LibraryInvalidInputException, # Invalid Email Address (for hold notification) - 3105 : PatronAuthorizationFailedException, # Invalid Account Credentials - 3106 : InvalidInputException, # Loan Period is out of bounds - 3108 : InvalidInputException, # DRM Credentials Required - 3109 : InvalidInputException, # Hold already exists or hold does not exist, depending. - 3110 : AlreadyCheckedOut, - 3111 : CurrentlyAvailable, - 3112 : CannotFulfill, - 3113 : CannotLoan, - (3113, "Title ID is not available for checkout") : NoAvailableCopies, - 3114 : PatronLoanLimitReached, - 3115 : LibraryInvalidInputException, # Missing DRM format - 3116 : LibraryInvalidInputException, # No patron session ID provided -- we don't use this - 3117 : LibraryInvalidInputException, # Invalid DRM format - 3118 : LibraryInvalidInputException, # Invalid Patron credentials - 3119 : LibraryAuthorizationFailedException, # No Blio account - 3120 : LibraryAuthorizationFailedException, # No Acoustikaccount - 3123 : PatronAuthorizationFailedException, # Patron Session ID expired - 3124 : PatronAuthorizationFailedException, # Patron SessionID is required - 3126 : LibraryInvalidInputException, # Invalid checkout format - 3127 : InvalidInputException, # First name is required - 3128 : InvalidInputException, # Last name is required - 3129 : PatronAuthorizationFailedException, # Invalid Patron Session Id - 3130 : LibraryInvalidInputException, # Invalid hold format (?) - 3131 : RemoteInitiatedServerError, # Custom error message (?) - 3132 : LibraryInvalidInputException, # Invalid delta datetime format - 3134 : LibraryInvalidInputException, # Delta datetime format must not be in the future - 3135 : NoAcceptableFormat, - 3136 : LibraryInvalidInputException, # Missing checkout format - 4058 : NoActiveLoan, # No checkout is associated with patron for the title. - 5000 : RemoteInitiatedServerError, - 5003 : LibraryInvalidInputException, # Missing TransactionID - 5004 : LibraryInvalidInputException, # Missing TransactionID + 315: InvalidInputException, # Bad password + 316: InvalidInputException, # DRM account already exists + 1000: PatronAuthorizationFailedException, + 1001: PatronAuthorizationFailedException, + 1002: PatronAuthorizationFailedException, + 1003: PatronAuthorizationFailedException, + 2000: LibraryAuthorizationFailedException, + 2001: LibraryAuthorizationFailedException, + 2002: LibraryAuthorizationFailedException, + 2003: LibraryAuthorizationFailedException, # "Encoded input parameters exceed limit", whatever that meaus + 2004: LibraryAuthorizationFailedException, + 2005: LibraryAuthorizationFailedException, # Invalid credentials + 2005: LibraryAuthorizationFailedException, # Wrong library ID + 2007: LibraryAuthorizationFailedException, # Invalid library ID + 2008: LibraryAuthorizationFailedException, # Invalid library ID + 3100: LibraryInvalidInputException, # Missing title ID + 3101: LibraryInvalidInputException, # Missing patron ID + 3102: LibraryInvalidInputException, # Missing email address (for hold notification) + 3103: NotFoundOnRemote, # Invalid title ID + 3104: LibraryInvalidInputException, # Invalid Email Address (for hold notification) + 3105: PatronAuthorizationFailedException, # Invalid Account Credentials + 3106: InvalidInputException, # Loan Period is out of bounds + 3108: InvalidInputException, # DRM Credentials Required + 3109: InvalidInputException, # Hold already exists or hold does not exist, depending. + 3110: AlreadyCheckedOut, + 3111: CurrentlyAvailable, + 3112: CannotFulfill, + 3113: CannotLoan, + (3113, "Title ID is not available for checkout"): NoAvailableCopies, + 3114: PatronLoanLimitReached, + 3115: LibraryInvalidInputException, # Missing DRM format + 3116: LibraryInvalidInputException, # No patron session ID provided -- we don't use this + 3117: LibraryInvalidInputException, # Invalid DRM format + 3118: LibraryInvalidInputException, # Invalid Patron credentials + 3119: LibraryAuthorizationFailedException, # No Blio account + 3120: LibraryAuthorizationFailedException, # No Acoustikaccount + 3123: PatronAuthorizationFailedException, # Patron Session ID expired + 3124: PatronAuthorizationFailedException, # Patron SessionID is required + 3126: LibraryInvalidInputException, # Invalid checkout format + 3127: InvalidInputException, # First name is required + 3128: InvalidInputException, # Last name is required + 3129: PatronAuthorizationFailedException, # Invalid Patron Session Id + 3130: LibraryInvalidInputException, # Invalid hold format (?) + 3131: RemoteInitiatedServerError, # Custom error message (?) + 3132: LibraryInvalidInputException, # Invalid delta datetime format + 3134: LibraryInvalidInputException, # Delta datetime format must not be in the future + 3135: NoAcceptableFormat, + 3136: LibraryInvalidInputException, # Missing checkout format + 4058: NoActiveLoan, # No checkout is associated with patron for the title. + 5000: RemoteInitiatedServerError, + 5003: LibraryInvalidInputException, # Missing TransactionID + 5004: LibraryInvalidInputException, # Missing TransactionID } def __init__(self, collection): @@ -1226,8 +1243,9 @@ def __init__(self, collection): """ self.collection = collection - def raise_exception_on_error(self, e, ns, custom_error_classes={}, - ignore_error_codes=None): + def raise_exception_on_error( + self, e, ns, custom_error_classes={}, ignore_error_codes=None + ): """Raise an error if the given lxml node represents an Axis 360 error condition. @@ -1238,8 +1256,8 @@ def raise_exception_on_error(self, e, ns, custom_error_classes={}, :param ignore_error_codes: A list of error codes to treat as success rather than as cause to raise an exception. """ - code = self._xpath1(e, '//axis:status/axis:code', ns) - message = self._xpath1(e, '//axis:status/axis:statusMessage', ns) + code = self._xpath1(e, "//axis:status/axis:code", ns) + message = self._xpath1(e, "//axis:status/axis:statusMessage", ns) if message is None: message = etree.tostring(e) else: @@ -1252,15 +1270,15 @@ def raise_exception_on_error(self, e, ns, custom_error_classes={}, ) @classmethod - def _raise_exception_on_error(cls, code, message, custom_error_classes={}, - ignore_error_codes=None): + def _raise_exception_on_error( + cls, code, message, custom_error_classes={}, ignore_error_codes=None + ): try: code = int(code) except ValueError: # Non-numeric code? Inconcievable! raise RemoteInitiatedServerError( - "Invalid response code from Axis 360: %s" % code, - cls.SERVICE_NAME + "Invalid response code from Axis 360: %s" % code, cls.SERVICE_NAME ) if ignore_error_codes and code in ignore_error_codes: @@ -1282,26 +1300,23 @@ def _raise_exception_on_error(cls, code, message, custom_error_classes={}, class CheckinResponseParser(ResponseParser): - def process_all(self, string): for i in super(CheckinResponseParser, self).process_all( - string, "//axis:EarlyCheckinRestResult", self.NS): + string, "//axis:EarlyCheckinRestResult", self.NS + ): return i def process_one(self, e, namespaces): - """Either raise an appropriate exception, or do nothing. - """ - self.raise_exception_on_error( - e, namespaces, ignore_error_codes = [4058] - ) + """Either raise an appropriate exception, or do nothing.""" + self.raise_exception_on_error(e, namespaces, ignore_error_codes=[4058]) return True class CheckoutResponseParser(ResponseParser): - def process_all(self, string): for i in super(CheckoutResponseParser, self).process_all( - string, "//axis:checkoutResult", self.NS): + string, "//axis:checkoutResult", self.NS + ): return i def process_one(self, e, namespaces): @@ -1312,8 +1327,8 @@ def process_one(self, e, namespaces): self.raise_exception_on_error(e, namespaces) # If we get to this point it's because the checkout succeeded. - expiration_date = self._xpath1(e, '//axis:expirationDate', namespaces) - fulfillment_url = self._xpath1(e, '//axis:url', namespaces) + expiration_date = self._xpath1(e, "//axis:expirationDate", namespaces) + fulfillment_url = self._xpath1(e, "//axis:url", namespaces) if fulfillment_url is not None: fulfillment_url = fulfillment_url.text @@ -1323,30 +1338,31 @@ def process_one(self, e, namespaces): loan_start = utc_now() loan = LoanInfo( - collection=self.collection, data_source_name=DataSource.AXIS_360, - identifier_type=self.id_type, identifier=None, + collection=self.collection, + data_source_name=DataSource.AXIS_360, + identifier_type=self.id_type, + identifier=None, start_date=loan_start, end_date=expiration_date, ) return loan -class HoldResponseParser(ResponseParser): +class HoldResponseParser(ResponseParser): def process_all(self, string): for i in super(HoldResponseParser, self).process_all( - string, "//axis:addtoholdResult", self.NS): + string, "//axis:addtoholdResult", self.NS + ): return i def process_one(self, e, namespaces): """Either turn the given document into a HoldInfo object, or raise an appropriate exception. """ - self.raise_exception_on_error( - e, namespaces, {3109 : AlreadyOnHold}) + self.raise_exception_on_error(e, namespaces, {3109: AlreadyOnHold}) # If we get to this point it's because the hold place succeeded. - queue_position = self._xpath1( - e, '//axis:holdsQueuePosition', namespaces) + queue_position = self._xpath1(e, "//axis:holdsQueuePosition", namespaces) if queue_position is None: queue_position = None else: @@ -1360,33 +1376,38 @@ def process_one(self, e, namespaces): # NOTE: The caller needs to fill in Collection -- we have no idea # what collection this is. hold = HoldInfo( - collection=self.collection, data_source_name=DataSource.AXIS_360, - identifier_type=self.id_type, identifier=None, - start_date=hold_start, end_date=None, hold_position=queue_position) + collection=self.collection, + data_source_name=DataSource.AXIS_360, + identifier_type=self.id_type, + identifier=None, + start_date=hold_start, + end_date=None, + hold_position=queue_position, + ) return hold -class HoldReleaseResponseParser(ResponseParser): +class HoldReleaseResponseParser(ResponseParser): def process_all(self, string): for i in super(HoldReleaseResponseParser, self).process_all( - string, "//axis:removeholdResult", self.NS): + string, "//axis:removeholdResult", self.NS + ): return i def post_process(self, i): """Unlike other ResponseParser subclasses, we don't return any type of - \*Info object, so there's no need to do any post-processing. + \*Info object, so there's no need to do any post-processing. """ return i def process_one(self, e, namespaces): # There's no data to gather here. Either there was an error # or we were successful. - self.raise_exception_on_error( - e, namespaces, {3109 : NotOnHold}) + self.raise_exception_on_error(e, namespaces, {3109: NotOnHold}) return True -class AvailabilityResponseParser(ResponseParser): +class AvailabilityResponseParser(ResponseParser): def __init__(self, api, internal_format=None): """Constructor. @@ -1404,7 +1425,8 @@ def __init__(self, api, internal_format=None): def process_all(self, string): for info in super(AvailabilityResponseParser, self).process_all( - string, "//axis:title", self.NS): + string, "//axis:title", self.NS + ): # Filter out books where nothing in particular is # happening. if info: @@ -1414,30 +1436,31 @@ def process_one(self, e, ns): # Figure out which book we're talking about. axis_identifier = self.text_of_subtag(e, "axis:titleId", ns) - availability = self._xpath1(e, 'axis:availability', ns) + availability = self._xpath1(e, "axis:availability", ns) if availability is None: return None - reserved = self._xpath1_boolean(availability, 'axis:isReserved', ns) - checked_out = self._xpath1_boolean(availability, 'axis:isCheckedout', ns) - on_hold = self._xpath1_boolean(availability, 'axis:isInHoldQueue', ns) + reserved = self._xpath1_boolean(availability, "axis:isReserved", ns) + checked_out = self._xpath1_boolean(availability, "axis:isCheckedout", ns) + on_hold = self._xpath1_boolean(availability, "axis:isInHoldQueue", ns) info = None if checked_out: - start_date = self._xpath1_date( - availability, 'axis:checkoutStartDate', ns) - end_date = self._xpath1_date( - availability, 'axis:checkoutEndDate', ns) + start_date = self._xpath1_date(availability, "axis:checkoutStartDate", ns) + end_date = self._xpath1_date(availability, "axis:checkoutEndDate", ns) download_url = self.text_of_optional_subtag( - availability, 'axis:downloadUrl', ns) - transaction_id = self.text_of_optional_subtag( - availability, 'axis:transactionID', ns) or "" + availability, "axis:downloadUrl", ns + ) + transaction_id = ( + self.text_of_optional_subtag(availability, "axis:transactionID", ns) + or "" + ) # Arguments common to FulfillmentInfo and # Axis360FulfillmentInfo. kwargs = dict( data_source_name=DataSource.AXIS_360, identifier_type=self.id_type, - identifier=axis_identifier + identifier=axis_identifier, ) if download_url and self.internal_format != self.api.AXISNOW: @@ -1477,13 +1500,13 @@ def process_one(self, e, ns): data_source_name=DataSource.AXIS_360, identifier_type=self.id_type, identifier=axis_identifier, - start_date=start_date, end_date=end_date, - fulfillment_info=fulfillment + start_date=start_date, + end_date=end_date, + fulfillment_info=fulfillment, ) elif reserved: - end_date = self._xpath1_date( - availability, 'axis:reservedEndDate', ns) + end_date = self._xpath1_date(availability, "axis:reservedEndDate", ns) info = HoldInfo( collection=self.collection, data_source_name=DataSource.AXIS_360, @@ -1491,18 +1514,21 @@ def process_one(self, e, ns): identifier=axis_identifier, start_date=None, end_date=end_date, - hold_position=0 + hold_position=0, ) elif on_hold: position = self.int_of_optional_subtag( - availability, 'axis:holdsQueuePosition', ns) + availability, "axis:holdsQueuePosition", ns + ) info = HoldInfo( collection=self.collection, data_source_name=DataSource.AXIS_360, identifier_type=self.id_type, identifier=axis_identifier, - start_date=None, end_date=None, - hold_position=position) + start_date=None, + end_date=None, + hold_position=position, + ) return info @@ -1521,10 +1547,12 @@ def _required_key(cls, key, json_obj): """ if json_obj is None or key not in json_obj: raise RemoteInitiatedServerError( - "Required key %s not present in Axis 360 fulfillment document: %s" % ( - key, json_obj, + "Required key %s not present in Axis 360 fulfillment document: %s" + % ( + key, + json_obj, ), - cls.SERVICE_NAME + cls.SERVICE_NAME, ) return json_obj[key] @@ -1534,9 +1562,9 @@ def verify_status_code(cls, parsed): response. """ k = cls._required_key - status = k('Status', parsed) - code = k('Code', status) - message = status.get('Message') + status = k("Status", parsed) + code = k("Code", status) + message = status.get("Message") # If the document describes an error condition, raise # an appropriate exception immediately. @@ -1545,7 +1573,7 @@ def verify_status_code(cls, parsed): def parse(self, data, *args, **kwargs): """Parse a JSON document.""" if isinstance(data, dict): - parsed = data # already parsed + parsed = data # already parsed else: try: parsed = json.loads(data) @@ -1553,7 +1581,7 @@ def parse(self, data, *args, **kwargs): # It's not JSON. raise RemoteInitiatedServerError( "Invalid response from Axis 360 (was expecting JSON): %s" % data, - self.SERVICE_NAME + self.SERVICE_NAME, ) # If the response indicates an error condition, don't continue -- @@ -1579,9 +1607,7 @@ def __init__(self, api): a fulfillment document triggers additional API requests. """ self.api = api - super(Axis360FulfillmentInfoResponseParser, self).__init__( - self.api.collection - ) + super(Axis360FulfillmentInfoResponseParser, self).__init__(self.api.collection) def _parse(self, parsed, license_pool): """Extract all useful information from a parsed FulfillmentInfo @@ -1596,10 +1622,10 @@ def _parse(self, parsed, license_pool): :return: A 2-tuple (manifest, expiration_date). `manifest` is either a FindawayManifest (for an audiobook) or an AxisNowManifest (for an ebook). """ - expiration_date = self._required_key('ExpirationDate', parsed) + expiration_date = self._required_key("ExpirationDate", parsed) expiration_date = self.parse_date(expiration_date) - if 'FNDTransactionID' in parsed: + if "FNDTransactionID" in parsed: manifest = self.parse_findaway(parsed, license_pool) else: manifest = self.parse_axisnow(parsed) @@ -1607,28 +1633,26 @@ def _parse(self, parsed, license_pool): return manifest, expiration_date def parse_date(self, date): - if '.' in date: + if "." in date: # Remove 7(?!) decimal places of precision and # UTC timezone, which are more trouble to parse # than they're worth. - date = date[:date.rindex('.')] + date = date[: date.rindex(".")] try: date = strptime_utc(date, "%Y-%m-%d %H:%M:%S") except ValueError: raise RemoteInitiatedServerError( - "Could not parse expiration date: %s" % date, - self.SERVICE_NAME + "Could not parse expiration date: %s" % date, self.SERVICE_NAME ) return date - def parse_findaway(self, parsed, license_pool): k = self._required_key - fulfillmentId = k('FNDContentID', parsed) - licenseId = k('FNDLicenseID', parsed) - sessionKey = k('FNDSessionKey', parsed) - checkoutId = k('FNDTransactionID', parsed) + fulfillmentId = k("FNDContentID", parsed) + licenseId = k("FNDLicenseID", parsed) + sessionKey = k("FNDSessionKey", parsed) + checkoutId = k("FNDTransactionID", parsed) # Acquire the TOC information metadata_response = self.api.get_audiobook_metadata(fulfillmentId) @@ -1636,27 +1660,30 @@ def parse_findaway(self, parsed, license_pool): accountId, spine_items = parser.parse(metadata_response.content) return FindawayManifest( - license_pool, accountId=accountId, checkoutId=checkoutId, - fulfillmentId=fulfillmentId, licenseId=licenseId, - sessionKey=sessionKey, spine_items=spine_items + license_pool, + accountId=accountId, + checkoutId=checkoutId, + fulfillmentId=fulfillmentId, + licenseId=licenseId, + sessionKey=sessionKey, + spine_items=spine_items, ) def parse_axisnow(self, parsed): k = self._required_key - isbn = k('ISBN', parsed) - book_vault_uuid = k('BookVaultUUID', parsed) + isbn = k("ISBN", parsed) + book_vault_uuid = k("BookVaultUUID", parsed) return AxisNowManifest(book_vault_uuid, isbn) class AudiobookMetadataParser(JSONResponseParser): - """Parse the results of Axis 360's audiobook metadata API call. - """ + """Parse the results of Axis 360's audiobook metadata API call.""" @classmethod def _parse(cls, parsed): spine_items = [] - accountId = parsed.get('fndaccountid', None) - for item in parsed.get('readingOrder', []): + accountId = parsed.get("fndaccountid", None) + for item in parsed.get("readingOrder", []): spine_item = cls._extract_spine_item(item) if spine_item: spine_items.append(spine_item) @@ -1665,11 +1692,11 @@ def _parse(cls, parsed): @classmethod def _extract_spine_item(cls, part): """Convert an element of the 'readingOrder' list to a SpineItem.""" - title = part.get('title') + title = part.get("title") # Incoming duration is measured in seconds. - duration = part.get('duration', 0) - part_number = int(part.get('fndpart', 0)) - sequence = int(part.get('fndsequence', 0)) + duration = part.get("duration", 0) + part_number = int(part.get("fndpart", 0)) + sequence = int(part.get("fndsequence", 0)) return SpineItem(title, duration, part_number, sequence) @@ -1703,6 +1730,7 @@ class Axis360FulfillmentInfo(APIAwareFulfillmentInfo): one or two extra HTTP requests, and there's often no need to make those requests. """ + def do_fetch(self): _db = self.api._db license_pool = self.license_pool(_db) @@ -1746,6 +1774,7 @@ class Axis360AcsFulfillmentInfo(FulfillmentInfo): to the Axis 360 API, sidestepping the problem, but taking a different code path than most of our external HTTP requests. """ + logger = logging.getLogger(__name__) def __init__(self, verify: bool, **kwargs): @@ -1758,7 +1787,7 @@ def problem_detail_document(self, error_details: str) -> ProblemDetail: return INTEGRATION_ERROR.detailed( _(RequestNetworkException.detail, service=service_name), title=RequestNetworkException.title, - debug_message=error_details + debug_message=error_details, ) @property @@ -1775,7 +1804,9 @@ def as_response(self) -> Union[Response, ProblemDetail]: # Default context does no ssl verification ssl_context = ssl.SSLContext() req = urllib.request.Request(self.content_link) - with urllib.request.urlopen(req, timeout=20, context=ssl_context) as response: + with urllib.request.urlopen( + req, timeout=20, context=ssl_context + ) as response: content = response.read() status = response.status headers = response.headers @@ -1784,12 +1815,18 @@ def as_response(self) -> Union[Response, ProblemDetail]: # wrap the exceptions thrown by urllib and ssl returning a ProblemDetail document. except urllib.error.HTTPError as e: return self.problem_detail_document( - "The server received a bad status code ({}) while contacting {}".format(e.code, service_name) + "The server received a bad status code ({}) while contacting {}".format( + e.code, service_name + ) ) except socket.timeout: - return self.problem_detail_document("Error connecting to {}. Timeout occurred.".format(service_name)) + return self.problem_detail_document( + "Error connecting to {}. Timeout occurred.".format(service_name) + ) except (urllib.error.URLError, ssl.SSLError) as e: reason = getattr(e, "reason", e.__class__.__name__) - return self.problem_detail_document("Error connecting to {}. {}.".format(service_name, reason)) + return self.problem_detail_document( + "Error connecting to {}. {}.".format(service_name, reason) + ) return Response(response=content, status=status, headers=headers) diff --git a/api/base_controller.py b/api/base_controller.py index 8f6b08bf37..3776a4e4ea 100644 --- a/api/base_controller.py +++ b/api/base_controller.py @@ -4,14 +4,10 @@ from flask import Response from flask_babel import lazy_gettext as _ -from .circulation_exceptions import * -from core.model import ( - Library, - Loan, - Patron, - get_one, -) +from core.model import Library, Loan, Patron, get_one from core.util.problem_detail import ProblemDetail + +from .circulation_exceptions import * from .problem_details import * @@ -35,8 +31,8 @@ def authorization_header(self): # If we're using a token instead, flask doesn't extract it for us. if not header: - if 'Authorization' in flask.request.headers: - header = flask.request.headers['Authorization'] + if "Authorization" in flask.request.headers: + header = flask.request.headers["Authorization"] return header @@ -53,7 +49,7 @@ def request_patron(self): :return: A Patron, if one could be authenticated; None otherwise. """ - if not hasattr(flask.request, 'patron'): + if not hasattr(flask.request, "patron"): # Call authenticated_patron_from_request for its side effect # of setting flask.request.patron self.authenticated_patron_from_request() @@ -104,9 +100,7 @@ def authenticated_patron(self, authorization_header): If there's no problem, return a Patron object. """ - patron = self.manager.auth.authenticated_patron( - self._db, authorization_header - ) + patron = self.manager.auth.authenticated_patron(self._db, authorization_header) if not patron: return INVALID_CREDENTIALS diff --git a/api/bibliotheca.py b/api/bibliotheca.py index 5ebb45f0df..7c009108d8 100644 --- a/api/bibliotheca.py +++ b/api/bibliotheca.py @@ -8,10 +8,7 @@ import time import urllib.parse from datetime import datetime, timedelta -from io import ( - BytesIO, - StringIO, -) +from io import BytesIO, StringIO import dateutil.parser from flask_babel import lazy_gettext as _ @@ -19,20 +16,16 @@ from pymarc import parse_xml_to_array from core.analytics import Analytics -from core.config import ( - CannotLoadConfiguration, -) -from core.coverage import ( - BibliographicCoverageProvider -) +from core.config import CannotLoadConfiguration +from core.coverage import BibliographicCoverageProvider from core.metadata_layer import ( - ContributorData, CirculationData, - Metadata, - LinkData, - IdentifierData, + ContributorData, FormatData, + IdentifierData, + LinkData, MeasurementData, + Metadata, ReplacementPolicy, SubjectData, ) @@ -45,8 +38,6 @@ DeliveryMechanism, Edition, ExternalIntegration, - get_one, - get_one_or_create, Hyperlink, Identifier, LicensePool, @@ -55,40 +46,21 @@ Session, Subject, Timestamp, + get_one, + get_one_or_create, ) -from core.monitor import ( - CollectionMonitor, - IdentifierSweepMonitor, - TimelineMonitor, -) +from core.monitor import CollectionMonitor, IdentifierSweepMonitor, TimelineMonitor from core.scripts import RunCollectionMonitorScript from core.testing import DatabaseTest -from core.util.datetime_helpers import ( - datetime_utc, - strptime_utc, - to_utc, - utc_now, -) -from core.util.http import ( - HTTP -) +from core.util.datetime_helpers import datetime_utc, strptime_utc, to_utc, utc_now +from core.util.http import HTTP from core.util.string_helpers import base64 from core.util.xmlparser import XMLParser -from .circulation import ( - FulfillmentInfo, - HoldInfo, - LoanInfo, - BaseCirculationAPI, -) + +from .circulation import BaseCirculationAPI, FulfillmentInfo, HoldInfo, LoanInfo from .circulation_exceptions import * -from .selftest import ( - HasSelfTests, - SelfTestResult, -) -from .web_publication_manifest import ( - FindawayManifest, - SpineItem, -) +from .selftest import HasSelfTests, SelfTestResult +from .web_publication_manifest import FindawayManifest, SpineItem class BibliothecaAPI(BaseCirculationAPI, HasSelfTests): @@ -108,9 +80,21 @@ class BibliothecaAPI(BaseCirculationAPI, HasSelfTests): DEFAULT_BASE_URL = "https://partner.yourcloudlibrary.com/" SETTINGS = [ - { "key": ExternalIntegration.USERNAME, "label": _("Account ID"), "required": True }, - { "key": ExternalIntegration.PASSWORD, "label": _("Account Key"), "required": True }, - { "key": Collection.EXTERNAL_ACCOUNT_ID_KEY, "label": _("Library ID"), "required": True }, + { + "key": ExternalIntegration.USERNAME, + "label": _("Account ID"), + "required": True, + }, + { + "key": ExternalIntegration.PASSWORD, + "label": _("Account Key"), + "required": True, + }, + { + "key": Collection.EXTERNAL_ACCOUNT_ID_KEY, + "label": _("Library ID"), + "required": True, + }, ] + BaseCirculationAPI.SETTINGS LIBRARY_SETTINGS = BaseCirculationAPI.LIBRARY_SETTINGS + [ @@ -128,25 +112,26 @@ class BibliothecaAPI(BaseCirculationAPI, HasSelfTests): adobe_drm = DeliveryMechanism.ADOBE_DRM findaway_drm = DeliveryMechanism.FINDAWAY_DRM delivery_mechanism_to_internal_format = { - (Representation.EPUB_MEDIA_TYPE, adobe_drm): 'ePub', - (Representation.PDF_MEDIA_TYPE, adobe_drm): 'PDF', - (None, findaway_drm) : 'MP3' + (Representation.EPUB_MEDIA_TYPE, adobe_drm): "ePub", + (Representation.PDF_MEDIA_TYPE, adobe_drm): "PDF", + (None, findaway_drm): "MP3", } internal_format_to_delivery_mechanism = dict( - [v,k] for k, v in list(delivery_mechanism_to_internal_format.items()) + [v, k] for k, v in list(delivery_mechanism_to_internal_format.items()) ) def __init__(self, _db, collection): if collection.protocol != ExternalIntegration.BIBLIOTHECA: raise ValueError( - "Collection protocol is %s, but passed into BibliothecaAPI!" % - collection.protocol + "Collection protocol is %s, but passed into BibliothecaAPI!" + % collection.protocol ) self._db = _db self.version = ( - collection.external_integration.setting('version').value or self.DEFAULT_VERSION + collection.external_integration.setting("version").value + or self.DEFAULT_VERSION ) self.account_id = collection.external_integration.username self.account_key = collection.external_integration.password @@ -154,9 +139,7 @@ def __init__(self, _db, collection): self.base_url = collection.external_integration.url or self.DEFAULT_BASE_URL if not self.account_id or not self.account_key or not self.library_id: - raise CannotLoadConfiguration( - "Bibliotheca configuration is incomplete." - ) + raise CannotLoadConfiguration("Bibliotheca configuration is incomplete.") self.item_list_parser = ItemListParser() self.collection_id = collection.id @@ -190,9 +173,10 @@ def signature(self, method, path): signature_components = [now, method, path] signature_string = "\n".join(signature_components) digest = hmac.new( - self.account_key.encode("utf-8"), - msg=signature_string.encode("utf-8"), - digestmod=hashlib.sha256).digest() + self.account_key.encode("utf-8"), + msg=signature_string.encode("utf-8"), + digestmod=hashlib.sha256, + ).digest() signature = base64.standard_b64encode(digest) return signature, now @@ -215,30 +199,36 @@ def replacement_policy(cls, _db, analytics=None): policy.analytics = analytics return policy - def request(self, path, body=None, method="GET", identifier=None, - max_age=None): + def request(self, path, body=None, method="GET", identifier=None, max_age=None): path = self.full_path(path) url = self.full_url(path) - if method == 'GET': - headers = {"Accept" : "application/xml"} + if method == "GET": + headers = {"Accept": "application/xml"} else: - headers = {"Content-Type" : "application/xml"} + headers = {"Content-Type": "application/xml"} self.sign(method, headers, path) # print headers # self.log.debug("3M request: %s %s", method, url) - if max_age and method=='GET': + if max_age and method == "GET": representation, cached = Representation.get( - self._db, url, extra_request_headers=headers, - do_get=self._simple_http_get, max_age=max_age, + self._db, + url, + extra_request_headers=headers, + do_get=self._simple_http_get, + max_age=max_age, exception_handler=Representation.reraise_exception, - timeout=60 + timeout=60, ) content = representation.content return content else: return self._request_with_timeout( - method, url, data=body, headers=headers, - allow_redirects=False, timeout=60 + method, + url, + data=body, + headers=headers, + allow_redirects=False, + timeout=60, ) def get_bibliographic_info_for(self, editions, max_age=None): @@ -265,7 +255,10 @@ def marc_request(self, start, end, offset=1, limit=50): start = start.strftime(self.ARGUMENT_TIME_FORMAT) end = end.strftime(self.ARGUMENT_TIME_FORMAT) url = "data/marc?startdate=%s&enddate=%s&offset=%d&limit=%d" % ( - start, end, offset, limit + start, + end, + offset, + limit, ) response = self.request(url) if response.status_code != 200: @@ -322,8 +315,7 @@ def _count_events(): return "Found %d event(s)" % count yield self.run_test( - "Asking for circulation events for the last five minutes", - _count_events + "Asking for circulation events for the last five minutes", _count_events ) for result in self.default_patrons(self.collection): @@ -331,12 +323,14 @@ def _count_events(): yield result continue library, patron, pin = result + def _count_activity(): result = self.patron_activity(patron, pin) return "Found %d loans/holds" % len(result) + yield self.run_test( "Checking activity for test patron for library %s" % library.name, - _count_activity + _count_activity, ) def get_events_between(self, start, end, cache_result=False, no_events_error=False): @@ -355,8 +349,9 @@ def get_events_between(self, start, end, cache_result=False, no_events_error=Fal events = EventParser().process_all(response.content, no_events_error) except Exception as e: self.log.error( - "Error parsing Bibliotheca response content: %s", response.content, - exc_info=e + "Error parsing Bibliotheca response content: %s", + response.content, + exc_info=e, ) raise e return events @@ -380,10 +375,7 @@ def patron_activity(self, patron, pin): TEMPLATE = "<%(request_type)s>%(item_id)s%(patron_id)s" - def checkout( - self, patron_obj, patron_password, licensepool, - delivery_mechanism - ): + def checkout(self, patron_obj, patron_password, licensepool, delivery_mechanism): """Check out a book on behalf of a patron. @@ -399,10 +391,13 @@ def checkout( """ bibliotheca_id = licensepool.identifier.identifier patron_identifier = patron_obj.authorization_identifier - args = dict(request_type='CheckoutRequest', - item_id=bibliotheca_id, patron_id=patron_identifier) + args = dict( + request_type="CheckoutRequest", + item_id=bibliotheca_id, + patron_id=patron_identifier, + ) body = self.TEMPLATE % args - response = self.request('checkout', body, method="PUT") + response = self.request("checkout", body, method="PUT") if response.status_code == 201: # New loan start_date = utc_now() @@ -421,7 +416,8 @@ def checkout( # At this point we know we have a loan. loan_expires = CheckoutResponseParser().process_all(response.content) loan = LoanInfo( - licensepool.collection, DataSource.BIBLIOTHECA, + licensepool.collection, + DataSource.BIBLIOTHECA, licensepool.identifier.type, licensepool.identifier.identifier, start_date=None, @@ -453,56 +449,59 @@ def fulfill(self, patron, password, pool, internal_format, **kwargs): content_type = None if content_transformation: try: - content_type, content = ( - content_transformation(pool, content) - ) + content_type, content = content_transformation(pool, content) except Exception as e: self.log.error( "Error transforming fulfillment document: %s", - response.content, exc_info=e + response.content, + exc_info=e, ) return FulfillmentInfo( - pool.collection, DataSource.BIBLIOTHECA, + pool.collection, + DataSource.BIBLIOTHECA, pool.identifier.type, pool.identifier.identifier, content_link=None, - content_type=content_type or response.headers.get('Content-Type'), + content_type=content_type or response.headers.get("Content-Type"), content=content, content_expires=None, ) def get_fulfillment_file(self, patron_id, bibliotheca_id): - args = dict(request_type='ACSMRequest', - item_id=bibliotheca_id, patron_id=patron_id) + args = dict( + request_type="ACSMRequest", item_id=bibliotheca_id, patron_id=patron_id + ) body = self.TEMPLATE % args - return self.request('GetItemACSM', body, method="PUT") + return self.request("GetItemACSM", body, method="PUT") def get_audio_fulfillment_file(self, patron_id, bibliotheca_id): - args = dict(request_type='AudioFulfillmentRequest', - item_id=bibliotheca_id, patron_id=patron_id) + args = dict( + request_type="AudioFulfillmentRequest", + item_id=bibliotheca_id, + patron_id=patron_id, + ) body = self.TEMPLATE % args - return self.request('GetItemAudioFulfillment', body, method="POST") + return self.request("GetItemAudioFulfillment", body, method="POST") def checkin(self, patron, pin, licensepool): patron_id = patron.authorization_identifier item_id = licensepool.identifier.identifier - args = dict(request_type='CheckinRequest', - item_id=item_id, patron_id=patron_id) + args = dict(request_type="CheckinRequest", item_id=item_id, patron_id=patron_id) body = self.TEMPLATE % args - return self.request('checkin', body, method="PUT") + return self.request("checkin", body, method="PUT") - def place_hold(self, patron, pin, licensepool, - hold_notification_email=None): + def place_hold(self, patron, pin, licensepool, hold_notification_email=None): """Place a hold. :return: a HoldInfo object. """ patron_id = patron.authorization_identifier item_id = licensepool.identifier.identifier - args = dict(request_type='PlaceHoldRequest', - item_id=item_id, patron_id=patron_id) + args = dict( + request_type="PlaceHoldRequest", item_id=item_id, patron_id=patron_id + ) body = self.TEMPLATE % args - response = self.request('placehold', body, method="PUT") + response = self.request("placehold", body, method="PUT") # The response comes in as a byte string that we must # convert into a string. response_content = None @@ -512,12 +511,13 @@ def place_hold(self, patron, pin, licensepool, start_date = utc_now() end_date = HoldResponseParser().process_all(response_content) return HoldInfo( - licensepool.collection, DataSource.BIBLIOTHECA, + licensepool.collection, + DataSource.BIBLIOTHECA, licensepool.identifier.type, licensepool.identifier.identifier, start_date=start_date, end_date=end_date, - hold_position=None + hold_position=None, ) else: if not response_content: @@ -531,19 +531,18 @@ def place_hold(self, patron, pin, licensepool, def release_hold(self, patron, pin, licensepool): patron_id = patron.authorization_identifier item_id = licensepool.identifier.identifier - args = dict(request_type='CancelHoldRequest', - item_id=item_id, patron_id=patron_id) + args = dict( + request_type="CancelHoldRequest", item_id=item_id, patron_id=patron_id + ) body = self.TEMPLATE % args - response = self.request('cancelhold', body, method="PUT") + response = self.request("cancelhold", body, method="PUT") if response.status_code in (200, 404): return True else: raise CannotReleaseHold() @classmethod - def findaway_license_to_webpub_manifest( - cls, license_pool, findaway_license - ): + def findaway_license_to_webpub_manifest(cls, license_pool, findaway_license): """Convert a Bibliotheca license document to a FindawayManifest suitable for serving to a mobile client. @@ -559,39 +558,39 @@ def findaway_license_to_webpub_manifest( kwargs = {} for findaway_extension in [ - 'accountId', 'checkoutId', 'fulfillmentId', 'licenseId', - 'sessionKey' + "accountId", + "checkoutId", + "fulfillmentId", + "licenseId", + "sessionKey", ]: value = findaway_license.get(findaway_extension, None) kwargs[findaway_extension] = value # Create the SpineItem objects. - audio_format = findaway_license.get('format') - if audio_format == 'MP3': + audio_format = findaway_license.get("format") + if audio_format == "MP3": part_media_type = Representation.MP3_MEDIA_TYPE else: - logging.error("Unknown Findaway audio format encountered: %s", - audio_format) + logging.error("Unknown Findaway audio format encountered: %s", audio_format) part_media_type = None spine_items = [] - for part in findaway_license.get('items'): - title = part.get('title') + for part in findaway_license.get("items"): + title = part.get("title") # TODO: Incoming duration appears to be measured in # milliseconds. This assumption makes our example # audiobook take about 7.9 hours, and no other reasonable # assumption is in the right order of magnitude. But this # needs to be explicitly verified. - duration = part.get('duration', 0) / 1000.0 + duration = part.get("duration", 0) / 1000.0 - part_number = int(part.get('part', 0)) + part_number = int(part.get("part", 0)) - sequence = int(part.get('sequence', 0)) + sequence = int(part.get("sequence", 0)) - spine_items.append( - SpineItem(title, duration, part_number, sequence) - ) + spine_items.append(SpineItem(title, duration, part_number, sequence)) # Create a FindawayManifest object and then convert it # to a string. @@ -603,29 +602,30 @@ def findaway_license_to_webpub_manifest( class DummyBibliothecaAPIResponse(object): - def __init__(self, response_code, headers, content): self.status_code = response_code self.headers = headers self.content = content -class MockBibliothecaAPI(BibliothecaAPI): +class MockBibliothecaAPI(BibliothecaAPI): @classmethod def mock_collection(self, _db, name="Test Bibliotheca Collection"): """Create a mock Bibliotheca collection for use in tests.""" library = DatabaseTest.make_default_library(_db) collection, ignore = get_one_or_create( - _db, Collection, - name=name, create_method_kwargs=dict( - external_account_id='c', - ) + _db, + Collection, + name=name, + create_method_kwargs=dict( + external_account_id="c", + ), ) integration = collection.create_external_integration( protocol=ExternalIntegration.BIBLIOTHECA ) - integration.username = 'a' - integration.password = 'b' + integration.username = "a" + integration.password = "b" integration.url = "http://bibliotheca.test" library.collections.append(collection) return collection @@ -633,36 +633,34 @@ def mock_collection(self, _db, name="Test Bibliotheca Collection"): def __init__(self, _db, collection, *args, **kwargs): self.responses = [] self.requests = [] - super(MockBibliothecaAPI, self).__init__( - _db, collection, *args, **kwargs - ) + super(MockBibliothecaAPI, self).__init__(_db, collection, *args, **kwargs) def now(self): """Return an unvarying time in the format Bibliotheca expects.""" - return datetime.strftime( - datetime(2016, 1, 1), self.AUTH_TIME_FORMAT - ) + return datetime.strftime(datetime(2016, 1, 1), self.AUTH_TIME_FORMAT) def queue_response(self, status_code, headers={}, content=None): from core.testing import MockRequestsResponse - self.responses.insert( - 0, MockRequestsResponse(status_code, headers, content) - ) + + self.responses.insert(0, MockRequestsResponse(status_code, headers, content)) def _request_with_timeout(self, method, url, *args, **kwargs): """Simulate HTTP.request_with_timeout.""" self.requests.append([method, url, args, kwargs]) response = self.responses.pop() return HTTP._process_response( - url, response, kwargs.get('allowed_response_codes'), - kwargs.get('disallowed_response_codes') + url, + response, + kwargs.get("allowed_response_codes"), + kwargs.get("disallowed_response_codes"), ) def _simple_http_get(self, url, headers, *args, **kwargs): """Simulate Representation.simple_http_get.""" - response = self._request_with_timeout('GET', url, *args, **kwargs) + response = self._request_with_timeout("GET", url, *args, **kwargs) return response.status_code, response.headers, response.content + class ItemListParser(XMLParser): DATE_FORMAT = "%Y-%m-%d" @@ -678,20 +676,11 @@ def parse(self, xml): parenthetical = re.compile(" \([^)]+\)$") - format_data_for_bibliotheca_format = { - "EPUB" : ( - Representation.EPUB_MEDIA_TYPE, DeliveryMechanism.ADOBE_DRM - ), - "EPUB3" : ( - Representation.EPUB_MEDIA_TYPE, DeliveryMechanism.ADOBE_DRM - ), - "PDF" : ( - Representation.PDF_MEDIA_TYPE, DeliveryMechanism.ADOBE_DRM - ), - "MP3" : ( - None, DeliveryMechanism.FINDAWAY_DRM - ), + "EPUB": (Representation.EPUB_MEDIA_TYPE, DeliveryMechanism.ADOBE_DRM), + "EPUB3": (Representation.EPUB_MEDIA_TYPE, DeliveryMechanism.ADOBE_DRM), + "PDF": (Representation.PDF_MEDIA_TYPE, DeliveryMechanism.ADOBE_DRM), + "MP3": (None, DeliveryMechanism.FINDAWAY_DRM), } @classmethod @@ -705,13 +694,10 @@ def contributors_from_string(cls, string, role=Contributor.AUTHOR_ROLE): # We handle the potential need for a second unescaping here. string = cls.unescape_entity_references(string) - for sort_name in string.split(';'): + for sort_name in string.split(";"): sort_name = cls.parenthetical.sub("", sort_name.strip()) contributors.append( - ContributorData( - sort_name=sort_name.strip(), - roles=[role] - ) + ContributorData(sort_name=sort_name.strip(), roles=[role]) ) return contributors @@ -724,11 +710,21 @@ def parse_genre_string(self, s): i = i.strip() if not i: continue - i = i.replace("&amp;", "&").replace("&", "&").replace("'", "'") - genres.append(SubjectData(Subject.BISAC, None, i, weight=Classification.TRUSTED_DISTRIBUTOR_WEIGHT)) + i = ( + i.replace("&amp;", "&") + .replace("&", "&") + .replace("'", "'") + ) + genres.append( + SubjectData( + Subject.BISAC, + None, + i, + weight=Classification.TRUSTED_DISTRIBUTOR_WEIGHT, + ) + ) return genres - def process_one(self, tag, namespaces): """Turn an tag into a Metadata and an encompassed CirculationData objects, and return the Metadata.""" @@ -740,17 +736,13 @@ def value(bibliotheca_key): identifiers = dict() subjects = [] - primary_identifier = IdentifierData( - Identifier.BIBLIOTHECA_ID, value("ItemId") - ) + primary_identifier = IdentifierData(Identifier.BIBLIOTHECA_ID, value("ItemId")) identifiers = [] - for key in ('ISBN13', 'PhysicalISBN'): + for key in ("ISBN13", "PhysicalISBN"): v = value(key) if v: - identifiers.append( - IdentifierData(Identifier.ISBN, v) - ) + identifiers.append(IdentifierData(Identifier.ISBN, v)) subjects = self.parse_genre_string(value("Genre")) @@ -759,11 +751,9 @@ def value(bibliotheca_key): publisher = value("Publisher") language = value("Language") - authors = list(self.contributors_from_string(value('Authors'))) + authors = list(self.contributors_from_string(value("Authors"))) narrators = list( - self.contributors_from_string( - value('Narrator'), Contributor.NARRATOR_ROLE - ) + self.contributors_from_string(value("Narrator"), Contributor.NARRATOR_ROLE) ) published_date = None @@ -783,16 +773,13 @@ def value(bibliotheca_key): links = [] description = value("Description") if description: - links.append( - LinkData(rel=Hyperlink.DESCRIPTION, content=description) - ) + links.append(LinkData(rel=Hyperlink.DESCRIPTION, content=description)) # Presume all images from Bibliotheca are JPEG. media_type = Representation.JPEG_MEDIA_TYPE cover_url = value("CoverLinkURL").replace("&", "&") cover_link = LinkData( - rel=Hyperlink.IMAGE, href=cover_url, - media_type=media_type + rel=Hyperlink.IMAGE, href=cover_url, media_type=media_type ) # Unless the URL format has drastically changed, we should be @@ -801,26 +788,23 @@ def value(bibliotheca_key): # # NOTE: this is an undocumented feature of the Bibliotheca API # which was discovered by investigating the BookLinkURL. - if '/delivery/img' in cover_url: + if "/delivery/img" in cover_url: thumbnail_url = cover_url + "&size=NORMAL" thumbnail = LinkData( - rel=Hyperlink.THUMBNAIL_IMAGE, - href=thumbnail_url, - media_type=media_type + rel=Hyperlink.THUMBNAIL_IMAGE, href=thumbnail_url, media_type=media_type ) cover_link.thumbnail = thumbnail links.append(cover_link) alternate_url = value("BookLinkURL").replace("&", "&") - links.append(LinkData(rel='alternate', href=alternate_url)) + links.append(LinkData(rel="alternate", href=alternate_url)) measurements = [] pages = value("NumberOfPages") if pages: pages = int(pages) measurements.append( - MeasurementData(quantity_measured=Measurement.PAGE_COUNT, - value=pages) + MeasurementData(quantity_measured=Measurement.PAGE_COUNT, value=pages) ) circulation, medium = self._make_circulation_data( @@ -838,7 +822,7 @@ def value(bibliotheca_key): primary_identifier=primary_identifier, identifiers=identifiers, subjects=subjects, - contributors=authors+narrators, + contributors=authors + narrators, measurements=measurements, links=links, circulation=circulation, @@ -865,7 +849,7 @@ def intvalue(key): except IndexError: logging.warn( "No information on available copies for %s", - primary_identifier.identifier + primary_identifier.identifier, ) licenses_available = 0 @@ -894,17 +878,16 @@ def internal_formats(cls, book_format): logging.error("Unrecognized BookFormat: %s", book_format) return medium, [] - content_type, drm_scheme = cls.format_data_for_bibliotheca_format[ - book_format - ] + content_type, drm_scheme = cls.format_data_for_bibliotheca_format[book_format] format = FormatData(content_type=content_type, drm_scheme=drm_scheme) - if book_format == 'MP3': + if book_format == "MP3": medium = Edition.AUDIO_MEDIUM else: medium = Edition.BOOK_MEDIUM return medium, [format] + class BibliothecaParser(XMLParser): INPUT_TIME_FORMAT = "%Y-%m-%dT%H:%M:%S" @@ -921,8 +904,7 @@ def parse_date(self, value): value = strptime_utc(value, self.INPUT_TIME_FORMAT) except ValueError as e: logging.error( - 'Unable to parse Bibliotheca date: "%s"', value, - exc_info=e + 'Unable to parse Bibliotheca date: "%s"', value, exc_info=e ) value = None return to_utc(value) @@ -946,55 +928,54 @@ def __init__(self, actual_status, statuses_that_would_work): def __str__(self): return "Book status is %s, must be: %s" % ( - self.actual_status, ", ".join(self.statuses_that_would_work)) + self.actual_status, + ", ".join(self.statuses_that_would_work), + ) + class ErrorParser(BibliothecaParser): """Turn an error document from the Bibliotheca web service into a CheckoutException""" wrong_status = re.compile( - "the patron document status was ([^ ]+) and not one of ([^ ]+)") - - loan_limit_reached = re.compile( - "Patron cannot loan more than [0-9]+ document" + "the patron document status was ([^ ]+) and not one of ([^ ]+)" ) - hold_limit_reached = re.compile( - "Patron cannot have more than [0-9]+ hold" - ) + loan_limit_reached = re.compile("Patron cannot loan more than [0-9]+ document") + + hold_limit_reached = re.compile("Patron cannot have more than [0-9]+ hold") error_mapping = { - "The patron does not have the book on hold" : NotOnHold, - "The patron has no eBooks checked out" : NotCheckedOut, + "The patron does not have the book on hold": NotOnHold, + "The patron has no eBooks checked out": NotCheckedOut, } def process_all(self, string): try: - for i in super(ErrorParser, self).process_all( - string, "//Error"): + for i in super(ErrorParser, self).process_all(string, "//Error"): return i except Exception as e: # The server sent us an error with an incorrect or # nonstandard syntax. - return RemoteInitiatedServerError( - string, BibliothecaAPI.SERVICE_NAME - ) + return RemoteInitiatedServerError(string, BibliothecaAPI.SERVICE_NAME) # We were not able to interpret the result as an error. # The most likely cause is that the Bibliotheca app server is down. return RemoteInitiatedServerError( - "Unknown error", BibliothecaAPI.SERVICE_NAME, + "Unknown error", + BibliothecaAPI.SERVICE_NAME, ) def process_one(self, error_tag, namespaces): message = self.text_of_optional_subtag(error_tag, "Message") if not message: return RemoteInitiatedServerError( - "Unknown error", BibliothecaAPI.SERVICE_NAME, + "Unknown error", + BibliothecaAPI.SERVICE_NAME, ) if message in self.error_mapping: return self.error_mapping[message](message) - if message in ('Authentication failed', 'Unknown error'): + if message in ("Authentication failed", "Unknown error"): # 'Unknown error' is an unknown error on the Bibliotheca side. # # 'Authentication failed' could _in theory_ be an error on @@ -1004,9 +985,7 @@ def process_one(self, error_tag, namespaces): # transient error on the Bibliotheca side. Possibly some # authentication internal to Bibliotheca has failed? Anyway, it # happens relatively frequently. - return RemoteInitiatedServerError( - message, BibliothecaAPI.SERVICE_NAME - ) + return RemoteInitiatedServerError(message, BibliothecaAPI.SERVICE_NAME) m = self.loan_limit_reached.search(message) if m: @@ -1022,32 +1001,33 @@ def process_one(self, error_tag, namespaces): actual, expected = m.groups() expected = expected.split(",") - if actual == 'CAN_WISH': + if actual == "CAN_WISH": return NoLicenses(message) - if 'CAN_LOAN' in expected and actual == 'CAN_HOLD': + if "CAN_LOAN" in expected and actual == "CAN_HOLD": return NoAvailableCopies(message) - if 'CAN_LOAN' in expected and actual == 'HOLD': + if "CAN_LOAN" in expected and actual == "HOLD": return AlreadyOnHold(message) - if 'CAN_LOAN' in expected and actual == 'LOAN': + if "CAN_LOAN" in expected and actual == "LOAN": return AlreadyCheckedOut(message) - if 'CAN_HOLD' in expected and actual == 'CAN_LOAN': + if "CAN_HOLD" in expected and actual == "CAN_LOAN": return CurrentlyAvailable(message) - if 'CAN_HOLD' in expected and actual == 'HOLD': + if "CAN_HOLD" in expected and actual == "HOLD": return AlreadyOnHold(message) - if 'CAN_HOLD' in expected: + if "CAN_HOLD" in expected: return CannotHold(message) - if 'CAN_LOAN' in expected: + if "CAN_LOAN" in expected: return CannotLoan(message) return BibliothecaException(message) + class PatronCirculationParser(BibliothecaParser): """Parse Bibliotheca's patron circulation status document into a list of @@ -1068,12 +1048,11 @@ def process_all(self, string): string = string.decode("utf-8") root = etree.parse(StringIO(string), parser) sup = super(PatronCirculationParser, self) - loans = sup.process_all( - root, "//Checkouts/Item", handler=self.process_one_loan) - holds = sup.process_all( - root, "//Holds/Item", handler=self.process_one_hold) + loans = sup.process_all(root, "//Checkouts/Item", handler=self.process_one_loan) + holds = sup.process_all(root, "//Holds/Item", handler=self.process_one_hold) reserves = sup.process_all( - root, "//Reserves/Item", handler=self.process_one_reserve) + root, "//Reserves/Item", handler=self.process_one_reserve + ) everything = itertools.chain(loans, holds, reserves) return [x for x in everything if x] @@ -1097,15 +1076,19 @@ def process_one(self, tag, namespaces, source_class): def datevalue(key): value = self.text_of_subtag(tag, key) - return strptime_utc( - value, BibliothecaAPI.ARGUMENT_TIME_FORMAT - ) + return strptime_utc(value, BibliothecaAPI.ARGUMENT_TIME_FORMAT) identifier = self.text_of_subtag(tag, "ItemId") start_date = datevalue("EventStartDateInUTC") end_date = datevalue("EventEndDateInUTC") - a = [self.collection, DataSource.BIBLIOTHECA, self.id_type, identifier, - start_date, end_date] + a = [ + self.collection, + DataSource.BIBLIOTHECA, + self.id_type, + identifier, + start_date, + end_date, + ] if source_class is HoldInfo: hold_position = self.int_of_subtag(tag, "Position") a.append(hold_position) @@ -1114,8 +1097,10 @@ def datevalue(key): a.append(None) return source_class(*a) + class DateResponseParser(BibliothecaParser): """Extract a date from a response.""" + RESULT_TAG_NAME = None DATE_TAG_NAME = None @@ -1138,6 +1123,7 @@ def process_all(self, string): class CheckoutResponseParser(DateResponseParser): """Extract due date from a checkout response.""" + RESULT_TAG_NAME = "CheckoutResult" DATE_TAG_NAME = "DueDateInUTC" @@ -1145,6 +1131,7 @@ class CheckoutResponseParser(DateResponseParser): class HoldResponseParser(DateResponseParser): """Extract availability date from a hold response.""" + RESULT_TAG_NAME = "PlaceHoldResult" DATE_TAG_NAME = "AvailabilityDateInUTC" @@ -1159,18 +1146,17 @@ class EventParser(BibliothecaParser): # Map Bibliotheca's event names to our names. EVENT_NAMES = { - "CHECKOUT" : CirculationEvent.DISTRIBUTOR_CHECKOUT, - "CHECKIN" : CirculationEvent.DISTRIBUTOR_CHECKIN, - "HOLD" : CirculationEvent.DISTRIBUTOR_HOLD_PLACE, - "RESERVED" : CirculationEvent.DISTRIBUTOR_AVAILABILITY_NOTIFY, - "PURCHASE" : CirculationEvent.DISTRIBUTOR_LICENSE_ADD, - "REMOVED" : CirculationEvent.DISTRIBUTOR_LICENSE_REMOVE, + "CHECKOUT": CirculationEvent.DISTRIBUTOR_CHECKOUT, + "CHECKIN": CirculationEvent.DISTRIBUTOR_CHECKIN, + "HOLD": CirculationEvent.DISTRIBUTOR_HOLD_PLACE, + "RESERVED": CirculationEvent.DISTRIBUTOR_AVAILABILITY_NOTIFY, + "PURCHASE": CirculationEvent.DISTRIBUTOR_LICENSE_ADD, + "REMOVED": CirculationEvent.DISTRIBUTOR_LICENSE_REMOVE, } def process_all(self, string, no_events_error=False): has_events = False - for i in super(EventParser, self).process_all( - string, "//CloudLibraryEvent"): + for i in super(EventParser, self).process_all(string, "//CloudLibraryEvent"): yield i has_events = True @@ -1186,7 +1172,7 @@ def process_all(self, string, no_events_error=False): # will resume as soon as something happens. raise RemoteInitiatedServerError( "No events returned from server. This may not be an error, but treating it as one to be safe.", - BibliothecaAPI.SERVICE_NAME + BibliothecaAPI.SERVICE_NAME, ) def process_one(self, tag, namespaces): @@ -1195,15 +1181,19 @@ def process_one(self, tag, namespaces): patron_id = self.text_of_optional_subtag(tag, "PatronId") start_time = self.date_from_subtag(tag, "EventStartDateTimeInUTC") - end_time = self.date_from_subtag( - tag, "EventEndDateTimeInUTC", required=False - ) + end_time = self.date_from_subtag(tag, "EventEndDateTimeInUTC", required=False) bibliotheca_event_type = self.text_of_subtag(tag, "EventType") internal_event_type = self.EVENT_NAMES[bibliotheca_event_type] - return (bibliotheca_id, isbn, patron_id, start_time, end_time, - internal_event_type) + return ( + bibliotheca_id, + isbn, + patron_id, + start_time, + end_time, + internal_event_type, + ) class BibliothecaCirculationSweep(IdentifierSweepMonitor): @@ -1222,15 +1212,14 @@ class BibliothecaCirculationSweep(IdentifierSweepMonitor): whatever reason, we'll find out about it here, because Bibliotheca will act like they never heard of it. """ + SERVICE_NAME = "Bibliotheca Circulation Sweep" DEFAULT_BATCH_SIZE = 25 PROTOCOL = ExternalIntegration.BIBLIOTHECA def __init__(self, _db, collection, api_class=BibliothecaAPI, **kwargs): _db = Session.object_session(collection) - super(BibliothecaCirculationSweep, self).__init__( - _db, collection, **kwargs - ) + super(BibliothecaCirculationSweep, self).__init__(_db, collection, **kwargs) if isinstance(api_class, BibliothecaAPI): self.api = api_class else: @@ -1249,7 +1238,8 @@ def process_items(self, identifiers): now = utc_now() for metadata in self.api.bibliographic_lookup(bibliotheca_ids): self._process_metadata( - metadata, identifiers_by_bibliotheca_id, + metadata, + identifiers_by_bibliotheca_id, identifiers_not_mentioned_by_bibliotheca, ) @@ -1258,23 +1248,25 @@ def process_items(self, identifiers): # indication that we no longer own any licenses to the # book. for identifier in identifiers_not_mentioned_by_bibliotheca: - pools = [lp for lp in identifier.licensed_through - if lp.data_source.name==DataSource.BIBLIOTHECA - and lp.collection == self.collection] + pools = [ + lp + for lp in identifier.licensed_through + if lp.data_source.name == DataSource.BIBLIOTHECA + and lp.collection == self.collection + ] if pools: [pool] = pools else: continue if pool.licenses_owned > 0: - self.log.warn( - "Removing %s from circulation.", - identifier.identifier - ) + self.log.warn("Removing %s from circulation.", identifier.identifier) pool.update_availability(0, 0, 0, 0, self.analytics, as_of=now) def _process_metadata( - self, metadata, identifiers_by_bibliotheca_id, - identifiers_not_mentioned_by_bibliotheca + self, + metadata, + identifiers_by_bibliotheca_id, + identifiers_not_mentioned_by_bibliotheca, ): """Process a single Metadata object (containing CirculationData) retrieved from Bibliotheca. @@ -1294,20 +1286,20 @@ def _process_metadata( # identifier?--but now we do. for library in self.collection.libraries: self.analytics.collect_event( - library, pool, CirculationEvent.DISTRIBUTOR_TITLE_ADD, - utc_now() + library, pool, CirculationEvent.DISTRIBUTOR_TITLE_ADD, utc_now() ) - edition, ignore = metadata.apply(edition, collection=self.collection, - replace=self.replacement_policy) + edition, ignore = metadata.apply( + edition, collection=self.collection, replace=self.replacement_policy + ) class BibliothecaTimelineMonitor(CollectionMonitor, TimelineMonitor): """Common superclass for our two TimelineMonitors.""" + PROTOCOL = ExternalIntegration.BIBLIOTHECA - LOG_DATE_FORMAT = '%Y-%m-%dT%H:%M:%S' + LOG_DATE_FORMAT = "%Y-%m-%dT%H:%M:%S" - def __init__(self, _db, collection, api_class=BibliothecaAPI, - analytics=None): + def __init__(self, _db, collection, api_class=BibliothecaAPI, analytics=None): """Initializer. :param _db: Database session object. @@ -1327,9 +1319,7 @@ def __init__(self, _db, collection, api_class=BibliothecaAPI, self.api = api_class else: self.api = api_class(_db, collection) - self.replacement_policy = BibliothecaAPI.replacement_policy( - _db, self.analytics - ) + self.replacement_policy = BibliothecaAPI.replacement_policy(_db, self.analytics) self.bibliographic_coverage_provider = BibliothecaBibliographicCoverageProvider( collection, self.api, replacement_policy=self.replacement_policy ) @@ -1349,12 +1339,18 @@ class BibliothecaPurchaseMonitor(BibliothecaTimelineMonitor): and special capabilities for customizing the start_time to go back even further. """ + SERVICE_NAME = "Bibliotheca Purchase Monitor" DEFAULT_START_TIME = datetime_utc(2014, 1, 1) - def __init__(self, _db, collection, api_class=BibliothecaAPI, - default_start=None, override_timestamp=False, - analytics=None + def __init__( + self, + _db, + collection, + api_class=BibliothecaAPI, + default_start=None, + override_timestamp=False, + analytics=None, ): """Initializer. @@ -1380,8 +1376,7 @@ def __init__(self, _db, collection, api_class=BibliothecaAPI, :type override_timestamp: bool """ super(BibliothecaPurchaseMonitor, self).__init__( - _db=_db, collection=collection, api_class=api_class, - analytics=analytics + _db=_db, collection=collection, api_class=api_class, analytics=analytics ) # We should only force the use of `default_start` as the actual @@ -1389,8 +1384,9 @@ def __init__(self, _db, collection, api_class=BibliothecaAPI, self.override_timestamp = override_timestamp if default_start else False # A specified `default_start` takes precedence over the # monitor's intrinsic default start date/time. - self.default_start_time = (self._optional_iso_date(default_start) or - self._intrinsic_start_time(_db)) + self.default_start_time = self._optional_iso_date( + default_start + ) or self._intrinsic_start_time(_db) def _optional_iso_date(self, date): """Return the date in `datetime` format. @@ -1409,7 +1405,8 @@ def _optional_iso_date(self, date): except ValueError as e: self.log.warn( '%r. Date argument "%s" was not in a valid format. Use an ISO 8601 string or a datetime.', - e, date, + e, + date, ) raise return dt_date @@ -1431,13 +1428,17 @@ def _intrinsic_start_time(self, _db): # or not it exists. default_start_time = self.DEFAULT_START_TIME initialized = get_one( - _db, Timestamp, service=self.service_name, - service_type=Timestamp.MONITOR_TYPE, collection=self.collection + _db, + Timestamp, + service=self.service_name, + service_type=Timestamp.MONITOR_TYPE, + collection=self.collection, ) if not initialized: self.log.info( - "Initializing %s from date: %s.", self.service_name, - default_start_time.strftime(self.LOG_DATE_FORMAT) + "Initializing %s from date: %s.", + self.service_name, + default_start_time.strftime(self.LOG_DATE_FORMAT), ) return default_start_time @@ -1452,8 +1453,10 @@ def timestamp(self): """ timestamp = super(BibliothecaPurchaseMonitor, self).timestamp() if self.override_timestamp: - self.log.info('Overriding timestamp and starting at %s.', - datetime.strftime(self.default_start_time, self.LOG_DATE_FORMAT)) + self.log.info( + "Overriding timestamp and starting at %s.", + datetime.strftime(self.default_start_time, self.LOG_DATE_FORMAT), + ) timestamp.finish = None return timestamp @@ -1492,8 +1495,7 @@ def catch_up_from(self, start, cutoff, progress): # We're playing it safe by using slice_start instead # of slice_end here -- slice_end should be fine. self._checkpoint( - progress, start, slice_start, - achievement_template % num_records + progress, start, slice_start, achievement_template % num_records ) # We're all caught up. The superclass will take care of # finalizing the dates, so there's no need to explicitly @@ -1537,13 +1539,11 @@ def purchases(self, start, end): :yield: A sequence of pymarc Record objects """ - offset = 1 # Smallest allowed offset - page_size = 50 # Maximum supported size. + offset = 1 # Smallest allowed offset + page_size = 50 # Maximum supported size. records = None while records is None or len(records) >= page_size: - records = [ - x for x in self.api.marc_request(start, end, offset, page_size) - ] + records = [x for x in self.api.marc_request(start, end, offset, page_size)] for record in records: yield record offset += page_size @@ -1554,7 +1554,7 @@ def process_record(self, record, purchase_time): :param record: Bibliographic information about the new title. :type record: pymarc.Record - :param purchase_time: Put down this time as the time the + :param purchase_time: Put down this time as the time the purchase happened. :type start_time: datetime.datetime @@ -1563,7 +1563,7 @@ def process_record(self, record, purchase_time): """ # The control number associated with the MARC record is what # we call the Bibliotheca ID. - control_numbers = [x for x in record.fields if x.tag == '001'] + control_numbers = [x for x in record.fields if x.tag == "001"] # These errors should not happen in real usage. error = None if not control_numbers: @@ -1580,8 +1580,11 @@ def process_record(self, record, purchase_time): # Find or lookup a LicensePool from the control number. license_pool, is_new = LicensePool.for_foreign_id( - self._db, self.api.source, Identifier.BIBLIOTHECA_ID, - bibliotheca_id, collection=self.collection + self._db, + self.api.source, + Identifier.BIBLIOTHECA_ID, + bibliotheca_id, + collection=self.collection, ) if is_new: @@ -1607,8 +1610,7 @@ def process_record(self, record, purchase_time): # potentially a long time ago -- since `start_time` is # provided. license_pool.collect_analytics_event( - self.analytics, CirculationEvent.DISTRIBUTOR_TITLE_ADD, - purchase_time, 0, 1 + self.analytics, CirculationEvent.DISTRIBUTOR_TITLE_ADD, purchase_time, 0, 1 ) return license_pool @@ -1649,7 +1651,7 @@ def catch_up_from(self, start, cutoff, progress): self.log.info( "Requesting events between %s and %s", start.strftime(self.LOG_DATE_FORMAT), - cutoff.strftime(self.LOG_DATE_FORMAT) + cutoff.strftime(self.LOG_DATE_FORMAT), ) events_handled = 0 @@ -1666,12 +1668,22 @@ def catch_up_from(self, start, cutoff, progress): self._db.commit() progress.achievements = "Events handled: %d." % events_handled - def handle_event(self, bibliotheca_id, isbn, foreign_patron_id, - start_time, end_time, internal_event_type): + def handle_event( + self, + bibliotheca_id, + isbn, + foreign_patron_id, + start_time, + end_time, + internal_event_type, + ): # Find or lookup the LicensePool for this event. license_pool, is_new = LicensePool.for_foreign_id( - self._db, self.api.source, Identifier.BIBLIOTHECA_ID, - bibliotheca_id, collection=self.collection + self._db, + self.api.source, + Identifier.BIBLIOTHECA_ID, + bibliotheca_id, + collection=self.collection, ) if is_new: @@ -1689,11 +1701,11 @@ def handle_event(self, bibliotheca_id, isbn, foreign_patron_id, ) bibliotheca_identifier = license_pool.identifier - isbn, ignore = Identifier.for_foreign_id( - self._db, Identifier.ISBN, isbn) + isbn, ignore = Identifier.for_foreign_id(self._db, Identifier.ISBN, isbn) edition, ignore = Edition.for_foreign_id( - self._db, self.api.source, Identifier.BIBLIOTHECA_ID, bibliotheca_id) + self._db, self.api.source, Identifier.BIBLIOTHECA_ID, bibliotheca_id + ) # The ISBN and the Bibliotheca identifier are exactly equivalent. bibliotheca_identifier.equivalent_to(self.api.source, isbn, strength=1) @@ -1711,8 +1723,12 @@ def handle_event(self, bibliotheca_id, isbn, foreign_patron_id, ) title = edition.title or "[no title]" - self.log.info("%s %s: %s", start_time.strftime(self.LOG_DATE_FORMAT), - title, internal_event_type) + self.log.info( + "%s %s: %s", + start_time.strftime(self.LOG_DATE_FORMAT), + title, + internal_event_type, + ) return start_time @@ -1726,15 +1742,21 @@ class RunBibliothecaPurchaseMonitorScript(RunCollectionMonitorScript): @classmethod def arg_parser(cls): parser = super(RunBibliothecaPurchaseMonitorScript, cls).arg_parser() - parser.add_argument('--default-start', metavar='DATETIME', - default=None, type=dateutil.parser.isoparse, - help='Default start date/time to be used for uninitialized (no timestamp) monitors.' - ' Use ISO 8601 format (e.g., "yyyy-mm-dd", "yyyy-mm-ddThh:mm:ss").' - ' Do not specify a time zone or offset.', - ) - parser.add_argument('--override-timestamp', action='store_true', - help='Use the specified `--default-start` as the actual' - ' start date, even if a monitor is already initialized.') + parser.add_argument( + "--default-start", + metavar="DATETIME", + default=None, + type=dateutil.parser.isoparse, + help="Default start date/time to be used for uninitialized (no timestamp) monitors." + ' Use ISO 8601 format (e.g., "yyyy-mm-dd", "yyyy-mm-ddThh:mm:ss").' + " Do not specify a time zone or offset.", + ) + parser.add_argument( + "--override-timestamp", + action="store_true", + help="Use the specified `--default-start` as the actual" + " start date, even if a monitor is already initialized.", + ) return parser @classmethod @@ -1757,6 +1779,7 @@ class BibliothecaBibliographicCoverageProvider(BibliographicCoverageProvider): single Collection, but we rely on Monitors to keep availability information up to date for all Collections. """ + SERVICE_NAME = "Bibliotheca Bibliographic Coverage Provider" DATA_SOURCE_NAME = DataSource.BIBLIOTHECA PROTOCOL = ExternalIntegration.BIBLIOTHECA @@ -1791,8 +1814,6 @@ def __init__(self, collection, api_class=BibliothecaAPI, **kwargs): def process_item(self, identifier): metadata = self.api.bibliographic_lookup(identifier) if not metadata: - return self.failure( - identifier, "Bibliotheca bibliographic lookup failed." - ) + return self.failure(identifier, "Bibliotheca bibliographic lookup failed.") [metadata] = metadata return self.set_metadata(identifier, metadata) diff --git a/api/circulation.py b/api/circulation.py index 055e718728..ab6be35d78 100644 --- a/api/circulation.py +++ b/api/circulation.py @@ -7,35 +7,35 @@ import flask from flask_babel import lazy_gettext as _ -from .circulation_exceptions import * -from .config import Configuration from core.cdn import cdnify from core.config import CannotLoadConfiguration from core.mirror import MirrorUploader from core.model import ( - get_one, CirculationEvent, Collection, ConfigurationSetting, DeliveryMechanism, ExternalIntegration, + ExternalIntegrationLink, + Hold, Library, - LicensePoolDeliveryMechanism, LicensePool, + LicensePoolDeliveryMechanism, Loan, - Hold, Patron, RightsStatus, Session, - ExternalIntegrationLink) + get_one, +) from core.util.datetime_helpers import utc_now + +from .circulation_exceptions import * +from .config import Configuration from .util.patron import PatronUtility class CirculationInfo(object): - - def __init__(self, collection, data_source_name, identifier_type, - identifier): + def __init__(self, collection, data_source_name, identifier_type, identifier): """A loan, hold, or whatever. :param collection: The Collection that gives us the right to @@ -68,8 +68,11 @@ def license_pool(self, _db): """Find the LicensePool model object corresponding to this object.""" collection = self.collection(_db) pool, is_new = LicensePool.for_foreign_id( - _db, self.data_source_name, self.identifier_type, self.identifier, - collection=collection + _db, + self.data_source_name, + self.identifier_type, + self.identifier, + collection=collection, ) return pool @@ -95,8 +98,14 @@ class DeliveryMechanismInfo(CirculationInfo): information needs to be stored in a `CirculationData` and applied to the LicensePool separately. """ - def __init__(self, content_type, drm_scheme, - rights_uri=RightsStatus.IN_COPYRIGHT, resource=None): + + def __init__( + self, + content_type, + drm_scheme, + rights_uri=RightsStatus.IN_COPYRIGHT, + resource=None, + ): """Constructor. :param content_type: Once the loan is fulfilled, the resulting document @@ -128,12 +137,13 @@ def apply(self, loan, autocommit=True): # Create or update the DeliveryMechanism. delivery_mechanism, is_new = DeliveryMechanism.lookup( - _db, self.content_type, - self.drm_scheme + _db, self.content_type, self.drm_scheme ) - if (loan.fulfillment - and loan.fulfillment.delivery_mechanism == delivery_mechanism): + if ( + loan.fulfillment + and loan.fulfillment.delivery_mechanism == delivery_mechanism + ): # The work has already been done. Do nothing. return @@ -151,9 +161,13 @@ def apply(self, loan, autocommit=True): # We set autocommit=False because we're probably in the middle # of a nested transaction. lpdm = LicensePoolDeliveryMechanism.set( - pool.data_source, pool.identifier, self.content_type, - self.drm_scheme, self.rights_uri, self.resource, - autocommit=autocommit + pool.data_source, + pool.identifier, + self.content_type, + self.drm_scheme, + self.rights_uri, + self.resource, + autocommit=autocommit, ) loan.fulfillment = lpdm return lpdm @@ -164,9 +178,17 @@ class FulfillmentInfo(CirculationInfo): a loan. """ - def __init__(self, collection, data_source_name, identifier_type, - identifier, content_link, content_type, content, - content_expires): + def __init__( + self, + collection, + data_source_name, + identifier_type, + identifier, + content_link, + content_type, + content, + content_expires, + ): """Constructor. One and only one of `content_link` and `content` should be @@ -207,8 +229,11 @@ def __repr__(self): else: blength = 0 return "" % ( - self.content_link, self.content_type, blength, - self.fd(self.content_expires)) + self.content_link, + self.content_type, + blength, + self.fd(self.content_expires), + ) @property def as_response(self): @@ -230,6 +255,7 @@ class APIAwareFulfillmentInfo(FulfillmentInfo): cost when the patron wants to fulfill this title and is not just looking at their loans. """ + def __init__(self, api, data_source_name, identifier_type, identifier, key): """Constructor. @@ -291,14 +317,21 @@ def content_expires(self): return self._content_expires - class LoanInfo(CirculationInfo): """A record of a loan.""" - def __init__(self, collection, data_source_name, identifier_type, - identifier, start_date, end_date, - fulfillment_info=None, external_identifier=None, - locked_to=None): + def __init__( + self, + collection, + data_source_name, + identifier_type, + identifier, + start_date, + end_date, + fulfillment_info=None, + external_identifier=None, + locked_to=None, + ): """Constructor. :param start_date: A datetime reflecting when the patron borrowed the book. @@ -324,9 +357,11 @@ def __repr__(self): fulfillment = "" f = "%Y/%m/%d" return "%s" % ( - self.identifier_type, self.identifier, - self.fd(self.start_date), self.fd(self.end_date), - fulfillment + self.identifier_type, + self.identifier, + self.fd(self.start_date), + self.fd(self.end_date), + fulfillment, ) @@ -342,9 +377,17 @@ class HoldInfo(CirculationInfo): default to be passed is None, which is equivalent to "first in line". """ - def __init__(self, collection, data_source_name, identifier_type, - identifier, start_date, end_date, hold_position, - external_identifier=None): + def __init__( + self, + collection, + data_source_name, + identifier_type, + identifier, + start_date, + end_date, + hold_position, + external_identifier=None, + ): super(HoldInfo, self).__init__( collection, data_source_name, identifier_type, identifier ) @@ -355,9 +398,11 @@ def __init__(self, collection, data_source_name, identifier_type, def __repr__(self): return "" % ( - self.identifier_type, self.identifier, - self.fd(self.start_date), self.fd(self.end_date), - self.hold_position + self.identifier_type, + self.identifier, + self.fd(self.start_date), + self.fd(self.end_date), + self.hold_position, ) @@ -413,7 +458,8 @@ def __init__(self, _db, library, analytics=None, api_map=None): except CannotLoadConfiguration as exception: self.log.exception( "Error loading configuration for {0}: {1}".format( - collection.name, str(exception)) + collection.name, str(exception) + ) ) self.initialization_exceptions[collection.id] = exception if api: @@ -429,29 +475,30 @@ def default_api_map(self): """When you see a Collection that implements protocol X, instantiate API class Y to handle that collection. """ - from .overdrive import OverdriveAPI - from .odilo import OdiloAPI - from .bibliotheca import BibliothecaAPI + from api.lcp.collection import LCPAPI + from api.proquest.importer import ProQuestOPDS2Importer + from .axis import Axis360API + from .bibliotheca import BibliothecaAPI from .enki import EnkiAPI - from .opds_for_distributors import OPDSForDistributorsAPI + from .odilo import OdiloAPI from .odl import ODLAPI, SharedODLAPI from .odl2 import ODL2API - from api.lcp.collection import LCPAPI - from api.proquest.importer import ProQuestOPDS2Importer + from .opds_for_distributors import OPDSForDistributorsAPI + from .overdrive import OverdriveAPI return { - ExternalIntegration.OVERDRIVE : OverdriveAPI, - ExternalIntegration.ODILO : OdiloAPI, - ExternalIntegration.BIBLIOTHECA : BibliothecaAPI, - ExternalIntegration.AXIS_360 : Axis360API, - EnkiAPI.ENKI_EXTERNAL : EnkiAPI, + ExternalIntegration.OVERDRIVE: OverdriveAPI, + ExternalIntegration.ODILO: OdiloAPI, + ExternalIntegration.BIBLIOTHECA: BibliothecaAPI, + ExternalIntegration.AXIS_360: Axis360API, + EnkiAPI.ENKI_EXTERNAL: EnkiAPI, OPDSForDistributorsAPI.NAME: OPDSForDistributorsAPI, ODLAPI.NAME: ODLAPI, ODL2API.NAME: ODL2API, SharedODLAPI.NAME: SharedODLAPI, LCPAPI.NAME: LCPAPI, - ProQuestOPDS2Importer.NAME: ProQuestOPDS2Importer + ProQuestOPDS2Importer.NAME: ProQuestOPDS2Importer, } def api_for_license_pool(self, licensepool): @@ -483,25 +530,29 @@ def _try_to_sign_fulfillment_link(self, licensepool, fulfillment): :rtype: FulfillmentInfo """ mirror_types = [ExternalIntegrationLink.PROTECTED_ACCESS_BOOKS] - mirror = next(iter([ - MirrorUploader.for_collection(licensepool.collection, mirror_type) - for mirror_type in mirror_types - ])) + mirror = next( + iter( + [ + MirrorUploader.for_collection(licensepool.collection, mirror_type) + for mirror_type in mirror_types + ] + ) + ) if mirror: signed_url = mirror.sign_url(fulfillment.content_link) self.log.info( - 'Fulfilment link {0} has been signed and translated into {1}'.format( - fulfillment.content_link, signed_url) + "Fulfilment link {0} has been signed and translated into {1}".format( + fulfillment.content_link, signed_url + ) ) fulfillment.content_link = signed_url return fulfillment - def _collect_event(self, patron, licensepool, name, - include_neighborhood=False): + def _collect_event(self, patron, licensepool, name, include_neighborhood=False): """Collect an analytics event. :param patron: The Patron associated with the event. If this @@ -523,7 +574,7 @@ def _collect_event(self, patron, licensepool, name, # this event -- this will help us get a library and # potentially a neighborhood. if flask.request: - request_patron = getattr(flask.request, 'patron', None) + request_patron = getattr(flask.request, "patron", None) else: request_patron = None patron = patron or request_patron @@ -541,9 +592,13 @@ def _collect_event(self, patron, licensepool, name, library = self.library neighborhood = None - if (include_neighborhood and flask.request - and request_patron and request_patron == patron): - neighborhood = getattr(request_patron, 'neighborhood', None) + if ( + include_neighborhood + and flask.request + and request_patron + and request_patron == patron + ): + neighborhood = getattr(request_patron, "neighborhood", None) return self.analytics.collect_event( library, licensepool, name, neighborhood=neighborhood ) @@ -555,12 +610,12 @@ def _collect_checkout_event(self, patron, licensepool): licensed books and one when 'loaning' open-access books. """ return self._collect_event( - patron, licensepool, CirculationEvent.CM_CHECKOUT, - include_neighborhood=True + patron, licensepool, CirculationEvent.CM_CHECKOUT, include_neighborhood=True ) - def borrow(self, patron, pin, licensepool, delivery_mechanism, - hold_notification_email=None): + def borrow( + self, patron, pin, licensepool, delivery_mechanism, hold_notification_email=None + ): """Either borrow a book or put it on hold. Don't worry about fulfilling the loan yet. @@ -575,8 +630,11 @@ def borrow(self, patron, pin, licensepool, delivery_mechanism, now = utc_now() api = self.api_for_license_pool(licensepool) - if licensepool.open_access or licensepool.self_hosted or \ - (not api and licensepool.unlimited_access): + if ( + licensepool.open_access + or licensepool.self_hosted + or (not api and licensepool.unlimited_access) + ): # We can 'loan' open-access content ourselves just by # putting a row in the database. __transaction = self._db.begin_nested() @@ -594,7 +652,8 @@ def borrow(self, patron, pin, licensepool, delivery_mechanism, raise NoLicenses() must_set_delivery_mechanism = ( - api.SET_DELIVERY_MECHANISM_AT == BaseCirculationAPI.BORROW_STEP) + api.SET_DELIVERY_MECHANISM_AT == BaseCirculationAPI.BORROW_STEP + ) if must_set_delivery_mechanism and not delivery_mechanism: raise DeliveryMechanismMissing() @@ -605,8 +664,11 @@ def borrow(self, patron, pin, licensepool, delivery_mechanism, # Do we (think we) already have this book out on loan? existing_loan = get_one( - self._db, Loan, patron=patron, license_pool=licensepool, - on_multiple='interchangeable' + self._db, + Loan, + patron=patron, + license_pool=licensepool, + on_multiple="interchangeable", ) loan_info = None @@ -623,8 +685,11 @@ def borrow(self, patron, pin, licensepool, delivery_mechanism, # single API that needs to be synced. self.sync_bookshelf(patron, pin, force=True) existing_loan = get_one( - self._db, Loan, patron=patron, license_pool=licensepool, - on_multiple='interchangeable' + self._db, + Loan, + patron=patron, + license_pool=licensepool, + on_multiple="interchangeable", ) new_loan = False @@ -658,9 +723,7 @@ def borrow(self, patron, pin, licensepool, delivery_mechanism, # available -- someone else may have checked it in since we # last looked. try: - loan_info = api.checkout( - patron, pin, licensepool, internal_format - ) + loan_info = api.checkout(patron, pin, licensepool, internal_format) if isinstance(loan_info, HoldInfo): # If the API couldn't give us a loan, it may have given us @@ -687,16 +750,20 @@ def borrow(self, patron, pin, licensepool, delivery_mechanism, identifier.type, identifier.identifier, start_date=None, - end_date=now + datetime.timedelta(hours=1) + end_date=now + datetime.timedelta(hours=1), ) if existing_loan: - loan_info.external_identifier=existing_loan.external_identifier + loan_info.external_identifier = existing_loan.external_identifier except AlreadyOnHold: # We're trying to check out a book that we already have on hold. hold_info = HoldInfo( - licensepool.collection, licensepool.data_source, - licensepool.identifier.type, licensepool.identifier.identifier, - None, None, None + licensepool.collection, + licensepool.data_source, + licensepool.identifier.type, + licensepool.identifier.identifier, + None, + None, + None, ) except NoAvailableCopies: if existing_loan: @@ -733,15 +800,20 @@ def borrow(self, patron, pin, licensepool, delivery_mechanism, # database. __transaction = self._db.begin_nested() loan, new_loan_record = licensepool.loan_to( - patron, start=loan_info.start_date or now, + patron, + start=loan_info.start_date or now, end=loan_info.end_date, - external_identifier=loan_info.external_identifier) + external_identifier=loan_info.external_identifier, + ) if must_set_delivery_mechanism: loan.fulfillment = delivery_mechanism existing_hold = get_one( - self._db, Hold, patron=patron, license_pool=licensepool, - on_multiple='interchangeable' + self._db, + Hold, + patron=patron, + license_pool=licensepool, + on_multiple="interchangeable", ) if existing_hold: # The book was on hold, and now we have a loan. @@ -764,14 +836,17 @@ def borrow(self, patron, pin, licensepool, delivery_mechanism, if not hold_info: try: hold_info = api.place_hold( - patron, pin, licensepool, - hold_notification_email + patron, pin, licensepool, hold_notification_email ) except AlreadyOnHold as e: hold_info = HoldInfo( - licensepool.collection, licensepool.data_source, - licensepool.identifier.type, licensepool.identifier.identifier, - None, None, None + licensepool.collection, + licensepool.data_source, + licensepool.identifier.type, + licensepool.identifier.identifier, + None, + None, + None, ) except CurrentlyAvailable: if loan_exception: @@ -804,9 +879,7 @@ def borrow(self, patron, pin, licensepool, delivery_mechanism, # Send out an analytics event to record the fact that # a hold was initiated through the circulation # manager. - self._collect_event( - patron, licensepool, CirculationEvent.CM_HOLD_PLACE - ) + self._collect_event(patron, licensepool, CirculationEvent.CM_HOLD_PLACE) if existing_loan: self._db.delete(existing_loan) @@ -848,7 +921,7 @@ def enforce_limits(self, patron, pool): currently_available = pool.licenses_available > 0 if currently_available and at_loan_limit: - raise PatronLoanLimitReached(library=patron.library) + raise PatronLoanLimitReached(library=patron.library) if not currently_available and at_hold_limit: raise PatronHoldLimitReached(library=patron.library) @@ -867,7 +940,8 @@ def patron_at_loan_limit(self, patron): # Open-access loans, and loans of indefinite duration, don't count towards the loan limit # because they don't block anyone else. non_open_access_loans_with_end_date = [ - loan for loan in patron.loans + loan + for loan in patron.loans if loan.license_pool and loan.license_pool.open_access == False and loan.end ] return loan_limit and len(non_open_access_loans_with_end_date) >= loan_limit @@ -909,7 +983,16 @@ def can_fulfill_without_loan(self, patron, pool, lpdm): return False return api.can_fulfill_without_loan(patron, pool, lpdm) - def fulfill(self, patron, pin, licensepool, delivery_mechanism, part=None, fulfill_part_url=None, sync_on_failure=True): + def fulfill( + self, + patron, + pin, + licensepool, + delivery_mechanism, + part=None, + fulfill_part_url=None, + sync_on_failure=True, + ): """Fulfil a book that a patron has previously checked out. :param delivery_mechanism: A LicensePoolDeliveryMechanism @@ -931,8 +1014,11 @@ def fulfill(self, patron, pin, licensepool, delivery_mechanism, part=None, fulfi """ fulfillment = None loan = get_one( - self._db, Loan, patron=patron, license_pool=licensepool, - on_multiple='interchangeable' + self._db, + Loan, + patron=patron, + license_pool=licensepool, + on_multiple="interchangeable", ) if not loan and not self.can_fulfill_without_loan( patron, licensepool, delivery_mechanism @@ -943,33 +1029,48 @@ def fulfill(self, patron, pin, licensepool, delivery_mechanism, part=None, fulfi # that needs to be synced. self.sync_bookshelf(patron, pin, force=True) return self.fulfill( - patron, pin, licensepool=licensepool, + patron, + pin, + licensepool=licensepool, delivery_mechanism=delivery_mechanism, - part=part, fulfill_part_url=fulfill_part_url, - sync_on_failure=False + part=part, + fulfill_part_url=fulfill_part_url, + sync_on_failure=False, ) else: raise NoActiveLoan(_("Cannot find your active loan for this work.")) - if loan and loan.fulfillment is not None and not loan.fulfillment.compatible_with(delivery_mechanism): + if ( + loan + and loan.fulfillment is not None + and not loan.fulfillment.compatible_with(delivery_mechanism) + ): raise DeliveryMechanismConflict( - _("You already fulfilled this loan as %(loan_delivery_mechanism)s, you can't also do it as %(requested_delivery_mechanism)s", - loan_delivery_mechanism=loan.fulfillment.delivery_mechanism.name, - requested_delivery_mechanism=delivery_mechanism.delivery_mechanism.name) + _( + "You already fulfilled this loan as %(loan_delivery_mechanism)s, you can't also do it as %(requested_delivery_mechanism)s", + loan_delivery_mechanism=loan.fulfillment.delivery_mechanism.name, + requested_delivery_mechanism=delivery_mechanism.delivery_mechanism.name, + ) ) api = self.api_for_license_pool(licensepool) - if licensepool.open_access or licensepool.self_hosted or \ - (not api and licensepool.unlimited_access): + if ( + licensepool.open_access + or licensepool.self_hosted + or (not api and licensepool.unlimited_access) + ): # We ignore the vendor-specific arguments when doing # open-access fulfillment, because we just don't support # partial fulfillment of open-access content. fulfillment = self.fulfill_open_access( - licensepool, delivery_mechanism.delivery_mechanism, + licensepool, + delivery_mechanism.delivery_mechanism, ) if licensepool.self_hosted: - fulfillment = self._try_to_sign_fulfillment_link(licensepool, fulfillment) + fulfillment = self._try_to_sign_fulfillment_link( + licensepool, fulfillment + ) else: if not api: raise CannotFulfill() @@ -981,25 +1082,30 @@ def fulfill(self, patron, pin, licensepool, delivery_mechanism, part=None, fulfi # impact on implementation signatures. Most vendor APIs # will ignore one or more of these arguments. fulfillment = api.fulfill( - patron, pin, licensepool, internal_format=internal_format, - part=part, fulfill_part_url=fulfill_part_url + patron, + pin, + licensepool, + internal_format=internal_format, + part=part, + fulfill_part_url=fulfill_part_url, ) - if not fulfillment or not ( - fulfillment.content_link or fulfillment.content - ): + if not fulfillment or not (fulfillment.content_link or fulfillment.content): raise NoAcceptableFormat() # Send out an analytics event to record the fact that # a fulfillment was initiated through the circulation # manager. self._collect_event( - patron, licensepool, CirculationEvent.CM_FULFILL, - include_neighborhood=True + patron, licensepool, CirculationEvent.CM_FULFILL, include_neighborhood=True ) # Make sure the delivery mechanism we just used is associated # with the loan, if any. - if loan and loan.fulfillment is None and not delivery_mechanism.delivery_mechanism.is_streaming: + if ( + loan + and loan.fulfillment is None + and not delivery_mechanism.delivery_mechanism.is_streaming + ): __transaction = self._db.begin_nested() loan.fulfillment = delivery_mechanism __transaction.commit() @@ -1014,12 +1120,17 @@ def fulfill_open_access(self, licensepool, delivery_mechanism): :param delivery_mechanism: A DeliveryMechanism. """ if isinstance(delivery_mechanism, LicensePoolDeliveryMechanism): - self.log.warn("LicensePoolDeliveryMechanism passed into fulfill_open_access, should be DeliveryMechanism.") + self.log.warn( + "LicensePoolDeliveryMechanism passed into fulfill_open_access, should be DeliveryMechanism." + ) delivery_mechanism = delivery_mechanism.delivery_mechanism fulfillment = None for lpdm in licensepool.delivery_mechanisms: - if not (lpdm.resource and lpdm.resource.representation - and lpdm.resource.representation.url): + if not ( + lpdm.resource + and lpdm.resource.representation + and lpdm.resource.representation.url + ): # This LicensePoolDeliveryMechanism can't actually # be used for fulfillment. continue @@ -1041,18 +1152,24 @@ def fulfill_open_access(self, licensepool, delivery_mechanism): content_link = cdnify(fulfillment.resource.url) media_type = rep.media_type return FulfillmentInfo( - licensepool.collection, licensepool.data_source, + licensepool.collection, + licensepool.data_source, identifier_type=licensepool.identifier.type, identifier=licensepool.identifier.identifier, - content_link=content_link, content_type=media_type, content=None, + content_link=content_link, + content_type=media_type, + content=None, content_expires=None, ) def revoke_loan(self, patron, pin, licensepool): """Revoke a patron's loan for a book.""" loan = get_one( - self._db, Loan, patron=patron, license_pool=licensepool, - on_multiple='interchangeable' + self._db, + Loan, + patron=patron, + license_pool=licensepool, + on_multiple="interchangeable", ) if loan: api = self.api_for_license_pool(licensepool) @@ -1073,9 +1190,7 @@ def revoke_loan(self, patron, pin, licensepool): # Send out an analytics event to record the fact that # a loan was revoked through the circulation # manager. - self._collect_event( - patron, licensepool, CirculationEvent.CM_CHECKIN - ) + self._collect_event(patron, licensepool, CirculationEvent.CM_CHECKIN) # Any other CannotReturn exception will be propagated upwards # at this point. @@ -1084,8 +1199,11 @@ def revoke_loan(self, patron, pin, licensepool): def release_hold(self, patron, pin, licensepool): """Remove a patron's hold on a book.""" hold = get_one( - self._db, Hold, patron=patron, license_pool=licensepool, - on_multiple='interchangeable' + self._db, + Hold, + patron=patron, + license_pool=licensepool, + on_multiple="interchangeable", ) if not licensepool.open_access and not licensepool.self_hosted: api = self.api_for_license_pool(licensepool) @@ -1107,7 +1225,9 @@ def release_hold(self, patron, pin, licensepool): # a hold was revoked through the circulation # manager. self._collect_event( - patron, licensepool, CirculationEvent.CM_HOLD_RELEASE, + patron, + licensepool, + CirculationEvent.CM_HOLD_RELEASE, ) return True @@ -1122,6 +1242,7 @@ def patron_activity(self, patron, pin): `LoanInfo` objects. """ log = self.log + class PatronActivityThread(Thread): def __init__(self, api, patron, pin): self.api = api @@ -1135,15 +1256,13 @@ def __init__(self, api, patron, pin): def run(self): before = time.time() try: - self.activity = self.api.patron_activity( - self.patron, self.pin) + self.activity = self.api.patron_activity(self.patron, self.pin) except Exception as e: self.exception = e self.trace = sys.exc_info() after = time.time() log.debug( - "Synced %s in %.2f sec", self.api.__class__.__name__, - after-before + "Synced %s in %.2f sec", self.api.__class__.__name__, after - before ) threads = [] @@ -1164,9 +1283,10 @@ def run(self): # picture of the patron's loans. complete = False self.log.error( - "%s errored out: %s", thread.api.__class__.__name__, + "%s errored out: %s", + thread.api.__class__.__name__, thread.exception, - exc_info=thread.trace + exc_info=thread.trace, ) if thread.activity: for i in thread.activity: @@ -1178,26 +1298,28 @@ def run(self): else: self.log.warn( "value %r from patron_activity is neither a loan nor a hold.", - i + i, ) if l is not None: l.append(i) after = time.time() - self.log.debug("Full sync took %.2f sec", after-before) + self.log.debug("Full sync took %.2f sec", after - before) return loans, holds, complete def local_loans(self, patron): - return self._db.query(Loan).join(Loan.license_pool).filter( - LicensePool.collection_id.in_(self.collection_ids_for_sync) - ).filter( - Loan.patron==patron + return ( + self._db.query(Loan) + .join(Loan.license_pool) + .filter(LicensePool.collection_id.in_(self.collection_ids_for_sync)) + .filter(Loan.patron == patron) ) def local_holds(self, patron): - return self._db.query(Hold).join(Hold.license_pool).filter( - LicensePool.collection_id.in_(self.collection_ids_for_sync) - ).filter( - Hold.patron==patron + return ( + self._db.query(Hold) + .join(Hold.license_pool) + .filter(LicensePool.collection_id.in_(self.collection_ids_for_sync)) + .filter(Hold.patron == patron) ) def sync_bookshelf(self, patron, pin, force=False): @@ -1247,7 +1369,7 @@ def sync_bookshelf(self, patron, pin, force=False): if not i: self.log.error( "Active loan on license pool %s, which has no identifier!", - l.license_pool + l.license_pool, ) continue key = (i.type, i.identifier) @@ -1260,7 +1382,7 @@ def sync_bookshelf(self, patron, pin, force=False): if not i: self.log.error( "Active hold on license pool %r, which has no identifier!", - h.license_pool + h.license_pool, ) continue key = (i.type, i.identifier) @@ -1344,10 +1466,20 @@ def sync_bookshelf(self, patron, pin, force=False): if loan.license_pool.collection_id in self.collection_ids_for_sync: one_minute_ago = utc_now() - datetime.timedelta(minutes=1) if loan.start < one_minute_ago: - logging.info("In sync_bookshelf for patron %s, deleting loan %d (patron %s)" % (patron.authorization_identifier, loan.id, loan.patron.authorization_identifier)) + logging.info( + "In sync_bookshelf for patron %s, deleting loan %d (patron %s)" + % ( + patron.authorization_identifier, + loan.id, + loan.patron.authorization_identifier, + ) + ) self._db.delete(loan) else: - logging.info("In sync_bookshelf for patron %s, found local loan %d created in the past minute that wasn't in remote loans" % (patron.authorization_identifier, loan.id)) + logging.info( + "In sync_bookshelf for patron %s, found local loan %d created in the past minute that wasn't in remote loans" + % (patron.authorization_identifier, loan.id) + ) # Every hold remaining in holds_by_identifier is a hold that # the provider doesn't know about, which means it's expired @@ -1372,22 +1504,26 @@ class BaseCirculationAPI(object): # distributor which includes ebooks and allows clients to specify # their own loan lengths. EBOOK_LOAN_DURATION_SETTING = { - "key" : Collection.EBOOK_LOAN_DURATION_KEY, + "key": Collection.EBOOK_LOAN_DURATION_KEY, "label": _("Ebook Loan Duration (in Days)"), "default": Collection.STANDARD_DEFAULT_LOAN_PERIOD, "type": "number", - "description": _("When a patron uses SimplyE to borrow an ebook from this collection, SimplyE will ask for a loan that lasts this number of days. This must be equal to or less than the maximum loan duration negotiated with the distributor.") + "description": _( + "When a patron uses SimplyE to borrow an ebook from this collection, SimplyE will ask for a loan that lasts this number of days. This must be equal to or less than the maximum loan duration negotiated with the distributor." + ), } # Add to LIBRARY_SETTINGS if your circulation API is for a # distributor which includes audiobooks and allows clients to # specify their own loan lengths. AUDIOBOOK_LOAN_DURATION_SETTING = { - "key" : Collection.AUDIOBOOK_LOAN_DURATION_KEY, + "key": Collection.AUDIOBOOK_LOAN_DURATION_KEY, "label": _("Audiobook Loan Duration (in Days)"), "default": Collection.STANDARD_DEFAULT_LOAN_PERIOD, "type": "number", - "description": _("When a patron uses SimplyE to borrow an audiobook from this collection, SimplyE will ask for a loan that lasts this number of days. This must be equal to or less than the maximum loan duration negotiated with the distributor.") + "description": _( + "When a patron uses SimplyE to borrow an audiobook from this collection, SimplyE will ask for a loan that lasts this number of days. This must be equal to or less than the maximum loan duration negotiated with the distributor." + ), } # Add to LIBRARY_SETTINGS if your circulation API is for a @@ -1399,7 +1535,9 @@ class BaseCirculationAPI(object): "label": _("Default Loan Period (in Days)"), "default": Collection.STANDARD_DEFAULT_LOAN_PERIOD, "type": "number", - "description": _("Until it hears otherwise from the distributor, this server will assume that any given loan for this library from this collection will last this number of days. This number is usually a negotiated value between the library and the distributor. This only affects estimates—it cannot affect the actual length of loans.") + "description": _( + "Until it hears otherwise from the distributor, this server will assume that any given loan for this library from this collection will last this number of days. This number is usually a negotiated value between the library and the distributor. This only affects estimates—it cannot affect the actual length of loans." + ), } # These collection-specific settings should be inherited by all @@ -1410,8 +1548,8 @@ class BaseCirculationAPI(object): # inherited by all distributors. LIBRARY_SETTINGS = [] - BORROW_STEP = 'borrow' - FULFILL_STEP = 'fulfill' + BORROW_STEP = "borrow" + FULFILL_STEP = "fulfill" # In 3M only, when a book is in the 'reserved' state the patron # cannot revoke their hold on the book. @@ -1446,7 +1584,10 @@ def internal_format(self, delivery_mechanism): internal_format = self.delivery_mechanism_to_internal_format.get(key) if not internal_format: raise DeliveryMechanismError( - _("Could not map Simplified delivery mechanism %(mechanism_name)s to internal delivery mechanism!", mechanism_name=d.name) + _( + "Could not map Simplified delivery mechanism %(mechanism_name)s to internal delivery mechanism!", + mechanism_name=d.name, + ) ) return internal_format @@ -1460,14 +1601,14 @@ def default_notification_email_address(self, library_or_patron, pin): if isinstance(library_or_patron, Patron): library_or_patron = library_or_patron.library return ConfigurationSetting.for_library( - Configuration.DEFAULT_NOTIFICATION_EMAIL_ADDRESS, - library_or_patron + Configuration.DEFAULT_NOTIFICATION_EMAIL_ADDRESS, library_or_patron ).value @classmethod def _library_authenticator(self, library): """Create a LibraryAuthenticator for the given library.""" from .authenticator import LibraryAuthenticator + _db = Session.object_session(library) return LibraryAuthenticator.from_config(_db, library) @@ -1505,7 +1646,7 @@ def patron_email_address(self, patron, library_authenticator=None): return email_address def checkin(self, patron, pin, licensepool): - """ Return a book early. + """Return a book early. :param patron: a Patron object for the patron who wants to check out the book. :param pin: The patron's alleged password. @@ -1513,7 +1654,6 @@ def checkin(self, patron, pin, licensepool): """ pass - def checkout(self, patron, pin, licensepool, internal_format): """Check out a book on behalf of a patron. @@ -1530,8 +1670,15 @@ def can_fulfill_without_loan(self, patron, pool, lpdm): """In general, you can't fulfill a book without a loan.""" return False - def fulfill(self, patron, pin, licensepool, internal_format=None, - part=None, fulfill_part_url=None): + def fulfill( + self, + patron, + pin, + licensepool, + internal_format=None, + part=None, + fulfill_part_url=None, + ): """Get the actual resource file to the patron. Implementations are encouraged to define ``**kwargs`` as a container @@ -1553,13 +1700,10 @@ def fulfill(self, patron, pin, licensepool, internal_format=None, """ raise NotImplementedError() - def patron_activity(self, patron, pin): - """ Return a patron's current checkouts and holds. - """ + """Return a patron's current checkouts and holds.""" raise NotImplementedError() - def place_hold(self, patron, pin, licensepool, notification_email_address): """Place a book on hold. @@ -1567,7 +1711,6 @@ def place_hold(self, patron, pin, licensepool, notification_email_address): """ raise NotImplementedError() - def release_hold(self, patron, pin, licensepool): """Release a patron's hold on a book. @@ -1578,6 +1721,5 @@ def release_hold(self, patron, pin, licensepool): raise NotImplementedError() def update_availability(self, licensepool): - """Update availability information for a book. - """ + """Update availability information for a book.""" pass diff --git a/api/circulation_exceptions.py b/api/circulation_exceptions.py index 5ebdec5aaf..00cc70c6c5 100644 --- a/api/circulation_exceptions.py +++ b/api/circulation_exceptions.py @@ -2,10 +2,8 @@ from api.config import Configuration from core.config import IntegrationException -from core.problem_details import ( - INTEGRATION_ERROR, - INTERNAL_SERVER_ERROR, -) +from core.problem_details import INTEGRATION_ERROR, INTERNAL_SERVER_ERROR + from .problem_details import * @@ -14,6 +12,7 @@ class CirculationException(IntegrationException): `status_code` is the status code that should be returned to the patron. """ + status_code = 400 def __init__(self, message=None, debug_info=None): @@ -28,8 +27,10 @@ def as_problem_detail_document(self, debug=False): """Return a suitable problem detail document.""" return INTERNAL_SERVER_ERROR + class RemoteInitiatedServerError(InternalServerError): """One of the servers we communicate with had an internal error.""" + status_code = 502 def __init__(self, message, service_name): @@ -38,73 +39,97 @@ def __init__(self, message, service_name): def as_problem_detail_document(self, debug=False): """Return a suitable problem detail document.""" - msg = _("Integration error communicating with %(service_name)s", service_name=self.service_name) + msg = _( + "Integration error communicating with %(service_name)s", + service_name=self.service_name, + ) return INTEGRATION_ERROR.detailed(msg) + class NoOpenAccessDownload(CirculationException): """We expected a book to have an open-access download, but it didn't.""" + status_code = 500 + class AuthorizationFailedException(CirculationException): status_code = 401 + class PatronAuthorizationFailedException(AuthorizationFailedException): status_code = 400 + class RemotePatronCreationFailedException(CirculationException): status_code = 500 + class LibraryAuthorizationFailedException(CirculationException): status_code = 500 + class InvalidInputException(CirculationException): """The patron gave invalid input to the library.""" + status_code = 400 + class LibraryInvalidInputException(InvalidInputException): """The library gave invalid input to the book provider.""" + status_code = 500 + class DeliveryMechanismError(InvalidInputException): status_code = 400 """The patron broke the rules about delivery mechanisms.""" + class DeliveryMechanismMissing(DeliveryMechanismError): """The patron needed to specify a delivery mechanism and didn't.""" + class DeliveryMechanismConflict(DeliveryMechanismError): """The patron specified a delivery mechanism that conflicted with one already set in stone. """ + class CannotLoan(CirculationException): status_code = 500 + class OutstandingFines(CannotLoan): """The patron has outstanding fines above the limit in the library's policy.""" + status_code = 403 + class AuthorizationExpired(CannotLoan): """The patron's authorization has expired.""" + status_code = 403 def as_problem_detail_document(self, debug=False): """Return a suitable problem detail document.""" return EXPIRED_CREDENTIALS + class AuthorizationBlocked(CannotLoan): """The patron's authorization is blocked for some reason other than fines or an expired card. For instance, the patron has been banned from the library. """ + status_code = 403 def as_problem_detail_document(self, debug=False): """Return a suitable problem detail document.""" return BLOCKED_CREDENTIALS + class LimitReached(CirculationException): """The patron cannot carry out an operation because it would push them above some limit set by library policy. @@ -117,6 +142,7 @@ class LimitReached(CirculationException): * `MESSAGE_WITH_LIMIT` A string containing the interpolation value "%(limit)s", which offers a more specific explanation of the limit exceeded. """ + status_code = 403 BASE_DOC = None SETTING_NAME = None @@ -137,40 +163,52 @@ def as_problem_detail_document(self, debug=False): detail = self.MESSAGE_WITH_LIMIT % dict(limit=self.limit) return doc.detailed(detail=detail) + class PatronLoanLimitReached(CannotLoan, LimitReached): BASE_DOC = LOAN_LIMIT_REACHED MESSAGE_WITH_LIMIT = SPECIFIC_LOAN_LIMIT_MESSAGE SETTING_NAME = Configuration.LOAN_LIMIT + class CannotReturn(CirculationException): status_code = 500 + class CannotHold(CirculationException): status_code = 500 + class PatronHoldLimitReached(CannotHold, LimitReached): BASE_DOC = HOLD_LIMIT_REACHED MESSAGE_WITH_LIMIT = SPECIFIC_HOLD_LIMIT_MESSAGE SETTING_NAME = Configuration.HOLD_LIMIT + class CannotReleaseHold(CirculationException): status_code = 500 + class CannotFulfill(CirculationException): status_code = 500 + class CannotPartiallyFulfill(CannotFulfill): status_code = 400 + class FormatNotAvailable(CannotFulfill): """Our format information for this book was outdated, and it's no longer available in the requested format.""" + status_code = 502 + class NotFoundOnRemote(CirculationException): """We know about this book but the remote site doesn't seem to.""" + status_code = 404 + class NoLicenses(NotFoundOnRemote): """The library no longer has licenses for this book.""" @@ -178,72 +216,92 @@ def as_problem_detail_document(self, debug=False): """Return a suitable problem detail document.""" return NO_LICENSES + class CannotRenew(CirculationException): """The patron can't renew their loan on this book. Probably because it's not available for renewal. """ + status_code = 400 + class NoAvailableCopies(CannotLoan): """The patron can't check this book out because all available copies are already checked out. """ + status_code = 400 + class AlreadyCheckedOut(CannotLoan): """The patron can't put check this book out because they already have it checked out. """ + status_code = 400 + class AlreadyOnHold(CannotHold): """The patron can't put this book on hold because they already have it on hold. """ + status_code = 400 + class NotCheckedOut(CannotReturn): """The patron can't return this book because they don't have it checked out in the first place. """ + status_code = 400 + class RemoteRefusedReturn(CannotReturn): - """The remote refused to count this book as returned. - """ + """The remote refused to count this book as returned.""" + status_code = 500 + class NotOnHold(CannotReleaseHold): """The patron can't release a hold for this book because they don't have it on hold in the first place. """ + status_code = 400 + class CurrentlyAvailable(CannotHold): """The patron can't put this book on hold because it's available now.""" + status_code = 400 + class NoAcceptableFormat(CannotFulfill): """We can't fulfill the patron's loan because the book is not available in an acceptable format. """ + status_code = 400 + class FulfilledOnIncompatiblePlatform(CannotFulfill): """We can't fulfill the patron's loan because the loan was already fulfilled on an incompatible platform (i.e. Kindle) in a way that's exclusive to that platform. """ + status_code = 451 + class NoActiveLoan(CannotFulfill): """We can't fulfill the patron's loan because they don't have an active loan. """ + status_code = 400 + class PatronNotFoundOnRemote(NotFoundOnRemote): status_code = 404 - - diff --git a/api/clever/__init__.py b/api/clever/__init__.py index 167666820b..bd56d8faff 100644 --- a/api/clever/__init__.py +++ b/api/clever/__init__.py @@ -3,23 +3,20 @@ from flask_babel import lazy_gettext as lgt -from api.authenticator import ( - OAuthAuthenticationProvider, - OAuthController, - PatronData, -) +from api.authenticator import OAuthAuthenticationProvider, OAuthController, PatronData +from api.problem_details import INVALID_CREDENTIALS from core.model import ExternalIntegration from core.util.http import HTTP from core.util.problem_detail import ProblemDetail from core.util.string_helpers import base64 -from api.problem_details import INVALID_CREDENTIALS - UNSUPPORTED_CLEVER_USER_TYPE = ProblemDetail( "http://librarysimplified.org/terms/problem/unsupported-clever-user-type", 401, lgt("Your Clever user type is not supported."), - lgt("Your Clever user type is not supported. You can request a code from First Book instead"), + lgt( + "Your Clever user type is not supported. You can request a code from First Book instead" + ), ) CLEVER_NOT_ELIGIBLE = ProblemDetail( @@ -32,20 +29,24 @@ CLEVER_UNKNOWN_SCHOOL = ProblemDetail( "http://librarysimplified.org/terms/problem/clever-unknown-school", 401, - lgt("Clever did not provide the necessary information about your school to verify eligibility."), - lgt("Clever did not provide the necessary information about your school to verify eligibility."), + lgt( + "Clever did not provide the necessary information about your school to verify eligibility." + ), + lgt( + "Clever did not provide the necessary information about your school to verify eligibility." + ), ) # Load Title I NCES ID data from json. TITLE_I_NCES_IDS = None clever_dir = os.path.split(__file__)[0] -with open('%s/title_i.json' % clever_dir) as f: +with open("%s/title_i.json" % clever_dir) as f: json_data = f.read() TITLE_I_NCES_IDS = json.loads(json_data) CLEVER_GRADE_TO_EXTERNAL_TYPE_MAP = { - "InfantToddler": "E", # Early + "InfantToddler": "E", # Early "Preschool": "E", "PreKindergarten": "E", "TransitionalKindergarten": "E", @@ -53,18 +54,18 @@ "1": "E", "2": "E", "3": "E", - "4": "M", # Middle + "4": "M", # Middle "5": "M", "6": "M", "7": "M", "8": "M", - "9": "H", # High + "9": "H", # High "10": "H", "11": "H", "12": "H", "13": "H", "PostGraduate": "H", - "Other": None, # Indeterminate + "Other": None, # Indeterminate "Ungraded": None, } @@ -78,17 +79,27 @@ class CleverAuthenticationAPI(OAuthAuthenticationProvider): URI = "http://librarysimplified.org/terms/auth/clever" - NAME = 'Clever' + NAME = "Clever" - DESCRIPTION = lgt(""" + DESCRIPTION = lgt( + """ An authentication service for Open eBooks that uses Clever as an - OAuth provider.""") + OAuth provider.""" + ) LOGIN_BUTTON_IMAGE = "CleverLoginButton280.png" SETTINGS = [ - {"key": ExternalIntegration.USERNAME, "label": lgt("Client ID"), "required": True}, - {"key": ExternalIntegration.PASSWORD, "label": lgt("Client Secret"), "required": True}, + { + "key": ExternalIntegration.USERNAME, + "label": lgt("Client ID"), + "required": True, + }, + { + "key": ExternalIntegration.PASSWORD, + "label": lgt("Client Secret"), + "required": True, + }, ] + OAuthAuthenticationProvider.SETTINGS # Unlike other authentication providers, external type regular expression @@ -97,7 +108,7 @@ class CleverAuthenticationAPI(OAuthAuthenticationProvider): LIBRARY_SETTINGS = [] TOKEN_TYPE = "Clever token" - TOKEN_DATA_SOURCE_NAME = 'Clever' + TOKEN_DATA_SOURCE_NAME = "Clever" EXTERNAL_AUTHENTICATE_URL = ( "https://clever.com/oauth/authorize" @@ -114,7 +125,7 @@ class CleverAuthenticationAPI(OAuthAuthenticationProvider): # To check Title I status we need state, which is associated with # a school in Clever's API. Any users at the district-level will # need to get a code from First Book instead. - SUPPORTED_USER_TYPES = ['student', 'teacher'] + SUPPORTED_USER_TYPES = ["student", "teacher"] # Begin implementations of OAuthAuthenticationProvider abstract # methods. @@ -176,15 +187,15 @@ def remote_exchange_code_for_bearer_token(self, _db, code): payload = self._remote_exchange_payload(_db, code) authorization = base64.b64encode(self.client_id + ":" + self.client_secret) headers = { - 'Authorization': 'Basic %s' % authorization, - 'Content-Type': 'application/json', + "Authorization": "Basic %s" % authorization, + "Content-Type": "application/json", } response = self._get_token(payload, headers) invalid = INVALID_CREDENTIALS.detailed(lgt("A valid Clever login is required.")) if not response: return invalid - token = response.get('access_token', None) + token = response.get("access_token", None) if not token: return invalid @@ -195,10 +206,10 @@ def _remote_exchange_payload(self, _db, code): library = self.library(_db) return dict( code=code, - grant_type='authorization_code', + grant_type="authorization_code", redirect_uri=OAuthController.oauth_authentication_callback_url( library.short_name - ) + ), ) def remote_patron_lookup(self, token): @@ -250,28 +261,32 @@ def remote_patron_lookup(self, token): with the data listed above. """ - bearer_headers = {'Authorization': 'Bearer %s' % token} - result = self._get(self.CLEVER_API_BASE_URL + '/me', bearer_headers) - data = result.get('data', {}) or {} + bearer_headers = {"Authorization": "Bearer %s" % token} + result = self._get(self.CLEVER_API_BASE_URL + "/me", bearer_headers) + data = result.get("data", {}) or {} - identifier = data.get('id', None) + identifier = data.get("id", None) if not identifier: - return INVALID_CREDENTIALS.detailed(lgt("A valid Clever login is required.")) + return INVALID_CREDENTIALS.detailed( + lgt("A valid Clever login is required.") + ) - if result.get('type') not in self.SUPPORTED_USER_TYPES: + if result.get("type") not in self.SUPPORTED_USER_TYPES: return UNSUPPORTED_CLEVER_USER_TYPE - links = result['links'] + links = result["links"] - user_link = [link for link in links if link['rel'] == 'canonical'][0]['uri'] + user_link = [link for link in links if link["rel"] == "canonical"][0]["uri"] # The canonical link includes the API version, so we use the base URL. user = self._get(self.CLEVER_API_BASE_URL + user_link, bearer_headers) - user_data = user['data'] - school_id = user_data['school'] - school = self._get(f"{self.CLEVER_API_BASE_URL}/v1.1/schools/{school_id}", bearer_headers) - school_nces_id = school['data'].get('nces_id') + user_data = user["data"] + school_id = user_data["school"] + school = self._get( + f"{self.CLEVER_API_BASE_URL}/v1.1/schools/{school_id}", bearer_headers + ) + school_nces_id = school["data"].get("nces_id") # TODO: check student free and reduced lunch status as well @@ -285,28 +300,30 @@ def remote_patron_lookup(self, token): external_type = None - if result['type'] == 'student': + if result["type"] == "student": # We need to be able to assign an external_type to students, so that they # get the correct content level. To do so we rely on the grade field in the # user data we get back from Clever. Their API doesn't guarantee that the # grade field is present, so we supply a default. - student_grade = user_data.get('grade', None) + student_grade = user_data.get("grade", None) - if not student_grade: # If no grade was supplied, log the school/student - msg = (f"CLEVER_UNKNOWN_PATRON_GRADE: School with NCES ID {school_nces_id} " - f"did not supply grade for student {user_data.get('id')}") + if not student_grade: # If no grade was supplied, log the school/student + msg = ( + f"CLEVER_UNKNOWN_PATRON_GRADE: School with NCES ID {school_nces_id} " + f"did not supply grade for student {user_data.get('id')}" + ) self.log.info(msg) # If we can't determine a type from the grade level, set to "A" external_type = external_type_from_clever_grade(student_grade) else: - external_type = "A" # Non-students get content level "A" + external_type = "A" # Non-students get content level "A" patrondata = PatronData( permanent_id=identifier, authorization_identifier=identifier, external_type=external_type, - complete=True + complete=True, ) return patrondata diff --git a/api/config.py b/api/config.py index 7a0f96fcef..d848d40e13 100644 --- a/api/config.py +++ b/api/config.py @@ -1,25 +1,23 @@ +import contextlib import json import re -import contextlib from copy import deepcopy -from Crypto.PublicKey import RSA from Crypto.Cipher import PKCS1_OAEP - +from Crypto.PublicKey import RSA from flask_babel import lazy_gettext as _ -from .announcements import Announcements - -from core.config import ( - Configuration as CoreConfiguration, - CannotLoadConfiguration, - IntegrationException, - empty_config as core_empty_config, - temp_config as core_temp_config, -) -from core.util import MoneyUtility +from core.config import CannotLoadConfiguration +from core.config import Configuration as CoreConfiguration +from core.config import IntegrationException +from core.config import empty_config as core_empty_config +from core.config import temp_config as core_temp_config from core.lane import Facets from core.model import ConfigurationSetting +from core.util import MoneyUtility + +from .announcements import Announcements + class Configuration(CoreConfiguration): @@ -56,11 +54,13 @@ class Configuration(CoreConfiguration): # Custom text for the link defined in CUSTOM_TOS_LINK. CUSTOM_TOS_TEXT = "tos_text" - DEFAULT_TOS_TEXT = "Terms of Service for presenting content through the Palace client applications" + DEFAULT_TOS_TEXT = ( + "Terms of Service for presenting content through the Palace client applications" + ) # A short description of the library, used in its Authentication # for OPDS document. - LIBRARY_DESCRIPTION = 'library_description' + LIBRARY_DESCRIPTION = "library_description" # The name of the per-library setting that sets the maximum amount # of fines a patron can have before losing lending privileges. @@ -85,7 +85,9 @@ class Configuration(CoreConfiguration): COPYRIGHT_DESIGNATED_AGENT_EMAIL = "copyright_designated_agent_email_address" # This is the link relation used to indicate - COPYRIGHT_DESIGNATED_AGENT_REL = "http://librarysimplified.org/rel/designated-agent/copyright" + COPYRIGHT_DESIGNATED_AGENT_REL = ( + "http://librarysimplified.org/rel/designated-agent/copyright" + ) # The name of the per-library setting that sets the contact address # for problems with the library configuration itself. @@ -101,7 +103,9 @@ class Configuration(CoreConfiguration): SMALL_COLLECTION_LANGUAGES = "small_collections" TINY_COLLECTION_LANGUAGES = "tiny_collections" - LANGUAGE_DESCRIPTION = _('Each value can be either the full name of a language or an ISO-639-2 language code.') + LANGUAGE_DESCRIPTION = _( + 'Each value can be either the full name of a language or an ISO-639-2 language code.' + ) HIDDEN_CONTENT_TYPES = "hidden_content_types" @@ -144,12 +148,12 @@ class Configuration(CoreConfiguration): ) # Names of the library-wide link settings. - TERMS_OF_SERVICE = 'terms-of-service' - PRIVACY_POLICY = 'privacy-policy' - COPYRIGHT = 'copyright' - ABOUT = 'about' - LICENSE = 'license' - REGISTER = 'register' + TERMS_OF_SERVICE = "terms-of-service" + PRIVACY_POLICY = "privacy-policy" + COPYRIGHT = "copyright" + ABOUT = "about" + LICENSE = "license" + REGISTER = "register" # A library with this many titles in a given language will be given # a large, detailed lane configuration for that language. @@ -162,13 +166,13 @@ class Configuration(CoreConfiguration): # These are link relations that are valid in Authentication for # OPDS documents but are not registered with IANA. - AUTHENTICATION_FOR_OPDS_LINKS = ['register'] + AUTHENTICATION_FOR_OPDS_LINKS = ["register"] # We support three different ways of integrating help processes. # All three of these will be sent out as links with rel='help' - HELP_EMAIL = 'help-email' - HELP_WEB = 'help-web' - HELP_URI = 'help-uri' + HELP_EMAIL = "help-email" + HELP_WEB = "help-web" + HELP_URI = "help-uri" HELP_LINKS = [HELP_EMAIL, HELP_WEB, HELP_URI] # Features of an OPDS client which a library may want to enable or @@ -196,11 +200,15 @@ class Configuration(CoreConfiguration): "key": PATRON_WEB_HOSTNAMES, "label": _("Hostnames for web application access"), "required": True, - "description": _("Only web applications from these hosts can access this circulation manager. This can be a single hostname (http://catalog.library.org) or a pipe-separated list of hostnames (http://catalog.library.org|https://beta.library.org). You must include the scheme part of the URI (http:// or https://). You can also set this to '*' to allow access from any host, but you must not do this in a production environment -- only during development.") + "description": _( + "Only web applications from these hosts can access this circulation manager. This can be a single hostname (http://catalog.library.org) or a pipe-separated list of hostnames (http://catalog.library.org|https://beta.library.org). You must include the scheme part of the URI (http:// or https://). You can also set this to '*' to allow access from any host, but you must not do this in a production environment -- only during development." + ), }, { "key": STATIC_FILE_CACHE_TIME, - "label": _("Cache time for static images and JS and CSS files (in seconds)"), + "label": _( + "Cache time for static images and JS and CSS files (in seconds)" + ), "required": True, "type": "number", }, @@ -216,15 +224,19 @@ class Configuration(CoreConfiguration): "label": _("Custom Terms of Service link"), "required": False, "default": DEFAULT_TOS_HREF, - "description": _("If your inclusion in the SimplyE mobile app is governed by terms other than the default, put the URL to those terms in this link so that librarians will have access to them. This URL will be used for all libraries on this circulation manager.") + "description": _( + "If your inclusion in the SimplyE mobile app is governed by terms other than the default, put the URL to those terms in this link so that librarians will have access to them. This URL will be used for all libraries on this circulation manager." + ), }, { "key": CUSTOM_TOS_TEXT, "label": _("Custom Terms of Service link text"), "required": False, "default": DEFAULT_TOS_TEXT, - "description": _("Custom text for the Terms of Service link in the footer of these administrative interface pages. This is primarily useful if you're not connecting this circulation manager to the SimplyE mobile app. This text will be used for all libraries on this circulation manager.") - } + "description": _( + "Custom text for the Terms of Service link in the footer of these administrative interface pages. This is primarily useful if you're not connecting this circulation manager to the SimplyE mobile app. This text will be used for all libraries on this circulation manager." + ), + }, ] # The "level" property determines which admins will be able to modify the setting. Level 1 settings can be modified by anyone. @@ -235,25 +247,31 @@ class Configuration(CoreConfiguration): { "key": LIBRARY_DESCRIPTION, "label": _("A short description of this library"), - "description": _("This will be shown to people who aren't sure they've chosen the right library."), + "description": _( + "This will be shown to people who aren't sure they've chosen the right library." + ), "category": "Basic Information", - "level": CoreConfiguration.SYS_ADMIN_ONLY + "level": CoreConfiguration.SYS_ADMIN_ONLY, }, { "key": Announcements.SETTING_NAME, "label": _("Scheduled announcements"), - "description": _("Announcements will be displayed to authenticated patrons."), + "description": _( + "Announcements will be displayed to authenticated patrons." + ), "category": "Announcements", "type": "announcements", - "level": CoreConfiguration.ALL_ACCESS + "level": CoreConfiguration.ALL_ACCESS, }, { "key": HELP_EMAIL, "label": _("Patron support email address"), - "description": _("An email address a patron can use if they need help, e.g. 'simplyehelp@yourlibrary.org'."), + "description": _( + "An email address a patron can use if they need help, e.g. 'simplyehelp@yourlibrary.org'." + ), "required": True, "format": "email", - "level": CoreConfiguration.SYS_ADMIN_ONLY + "level": CoreConfiguration.SYS_ADMIN_ONLY, }, { "key": HELP_WEB, @@ -261,44 +279,56 @@ class Configuration(CoreConfiguration): "description": _("A URL for patrons to get help."), "format": "url", "category": "Patron Support", - "level": CoreConfiguration.ALL_ACCESS + "level": CoreConfiguration.ALL_ACCESS, }, { "key": HELP_URI, "label": _("Patron support custom integration URI"), - "description": _("A custom help integration like Helpstack, e.g. 'helpstack:nypl.desk.com'."), + "description": _( + "A custom help integration like Helpstack, e.g. 'helpstack:nypl.desk.com'." + ), "category": "Patron Support", - "level": CoreConfiguration.SYS_ADMIN_ONLY + "level": CoreConfiguration.SYS_ADMIN_ONLY, }, { "key": COPYRIGHT_DESIGNATED_AGENT_EMAIL, "label": _("Copyright designated agent email"), - "description": _("Patrons of this library should use this email address to send a DMCA notification (or other copyright complaint) to the library.
If no value is specified here, the general patron support address will be used."), + "description": _( + "Patrons of this library should use this email address to send a DMCA notification (or other copyright complaint) to the library.
If no value is specified here, the general patron support address will be used." + ), "format": "email", "category": "Patron Support", - "level": CoreConfiguration.SYS_ADMIN_OR_MANAGER + "level": CoreConfiguration.SYS_ADMIN_OR_MANAGER, }, { "key": CONFIGURATION_CONTACT_EMAIL, - "label": _("A point of contact for the organization reponsible for configuring this library"), - "description": _("This email address will be shared as part of integrations that you set up through this interface. It will not be shared with the general public. This gives the administrator of the remote integration a way to contact you about problems with this library's use of that integration.
If no value is specified here, the general patron support address will be used."), + "label": _( + "A point of contact for the organization reponsible for configuring this library" + ), + "description": _( + "This email address will be shared as part of integrations that you set up through this interface. It will not be shared with the general public. This gives the administrator of the remote integration a way to contact you about problems with this library's use of that integration.
If no value is specified here, the general patron support address will be used." + ), "format": "email", "category": "Patron Support", - "level": CoreConfiguration.SYS_ADMIN_OR_MANAGER + "level": CoreConfiguration.SYS_ADMIN_OR_MANAGER, }, { "key": DEFAULT_NOTIFICATION_EMAIL_ADDRESS, "label": _("Write-only email address for vendor hold notifications"), - "description": _('This address must trash all email sent to it. Vendor hold notifications contain sensitive patron information, but cannot be forwarded to patrons because they contain vendor-specific instructions.
The default address will work, but for greater security, set up your own address that trashes all incoming email.'), + "description": _( + 'This address must trash all email sent to it. Vendor hold notifications contain sensitive patron information, but cannot be forwarded to patrons because they contain vendor-specific instructions.
The default address will work, but for greater security, set up your own address that trashes all incoming email.' + ), "default": STANDARD_NOREPLY_EMAIL_ADDRESS, "required": True, "format": "email", - "level": CoreConfiguration.SYS_ADMIN_ONLY + "level": CoreConfiguration.SYS_ADMIN_ONLY, }, { "key": COLOR_SCHEME, "label": _("Mobile color scheme"), - "description": _("This tells mobile applications what color scheme to use when rendering this library's OPDS feed."), + "description": _( + "This tells mobile applications what color scheme to use when rendering this library's OPDS feed." + ), "options": [ dict(key="amber", label=_("Amber")), dict(key="black", label=_("Black")), @@ -321,41 +351,49 @@ class Configuration(CoreConfiguration): "type": "select", "default": DEFAULT_COLOR_SCHEME, "category": "Client Interface Customization", - "level": CoreConfiguration.SYS_ADMIN_OR_MANAGER + "level": CoreConfiguration.SYS_ADMIN_OR_MANAGER, }, { "key": WEB_PRIMARY_COLOR, "label": _("Web primary color"), - "description": _("This is the brand primary color for the web application. Must have sufficient contrast with white."), + "description": _( + "This is the brand primary color for the web application. Must have sufficient contrast with white." + ), "type": "color-picker", "default": DEFAULT_WEB_PRIMARY_COLOR, "category": "Client Interface Customization", - "level": CoreConfiguration.SYS_ADMIN_OR_MANAGER + "level": CoreConfiguration.SYS_ADMIN_OR_MANAGER, }, { "key": WEB_SECONDARY_COLOR, "label": _("Web secondary color"), - "description": _("This is the brand secondary color for the web application. Must have sufficient contrast with white."), + "description": _( + "This is the brand secondary color for the web application. Must have sufficient contrast with white." + ), "type": "color-picker", "default": DEFAULT_WEB_SECONDARY_COLOR, "category": "Client Interface Customization", - "level": CoreConfiguration.SYS_ADMIN_OR_MANAGER + "level": CoreConfiguration.SYS_ADMIN_OR_MANAGER, }, { "key": WEB_CSS_FILE, "label": _("Custom CSS file for web"), - "description": _("Give web applications a CSS file to customize the catalog display."), + "description": _( + "Give web applications a CSS file to customize the catalog display." + ), "format": "url", "category": "Client Interface Customization", - "level": CoreConfiguration.SYS_ADMIN_ONLY + "level": CoreConfiguration.SYS_ADMIN_ONLY, }, { "key": WEB_HEADER_LINKS, "label": _("Web header links"), - "description": _("This gives web applications a list of links to display in the header. Specify labels for each link in the same order under 'Web header labels'."), + "description": _( + "This gives web applications a list of links to display in the header. Specify labels for each link in the same order under 'Web header labels'." + ), "type": "list", "category": "Client Interface Customization", - "level": CoreConfiguration.SYS_ADMIN_OR_MANAGER + "level": CoreConfiguration.SYS_ADMIN_OR_MANAGER, }, { "key": WEB_HEADER_LABELS, @@ -363,7 +401,7 @@ class Configuration(CoreConfiguration): "description": _("Labels for each link under 'Web header links'."), "type": "list", "category": "Client Interface Customization", - "level": CoreConfiguration.SYS_ADMIN_OR_MANAGER + "level": CoreConfiguration.SYS_ADMIN_OR_MANAGER, }, { "key": LOGO, @@ -377,124 +415,142 @@ class Configuration(CoreConfiguration): f"the longest dimension does not excede {LOGO_MAX_DIMENSION} pixels." ), "category": "Client Interface Customization", - "level": CoreConfiguration.ALL_ACCESS + "level": CoreConfiguration.ALL_ACCESS, }, { "key": HIDDEN_CONTENT_TYPES, "label": _("Hidden content types"), "type": "text", - "description": _('A list of content types to hide from all clients, e.g. ["application/pdf"]. This can be left blank except to solve specific problems.'), + "description": _( + 'A list of content types to hide from all clients, e.g. ["application/pdf"]. This can be left blank except to solve specific problems.' + ), "category": "Client Interface Customization", - "level": CoreConfiguration.SYS_ADMIN_ONLY + "level": CoreConfiguration.SYS_ADMIN_ONLY, }, { "key": LIBRARY_FOCUS_AREA, "label": _("Focus area"), "type": "list", - "description": _("The library focuses on serving patrons in this geographic area. In most cases this will be a city name like Springfield, OR."), + "description": _( + "The library focuses on serving patrons in this geographic area. In most cases this will be a city name like Springfield, OR." + ), "category": "Geographic Areas", "format": "geographic", "instructions": AREA_INPUT_INSTRUCTIONS, "capitalize": True, - "level": CoreConfiguration.ALL_ACCESS + "level": CoreConfiguration.ALL_ACCESS, }, { "key": LIBRARY_SERVICE_AREA, "label": _("Service area"), "type": "list", - "description": _("The full geographic area served by this library. In most cases this is the same as the focus area and can be left blank, but it may be a larger area such as a US state (which should be indicated by its abbreviation, like OR)."), + "description": _( + "The full geographic area served by this library. In most cases this is the same as the focus area and can be left blank, but it may be a larger area such as a US state (which should be indicated by its abbreviation, like OR)." + ), "category": "Geographic Areas", "format": "geographic", "instructions": AREA_INPUT_INSTRUCTIONS, "capitalize": True, - "level": CoreConfiguration.ALL_ACCESS + "level": CoreConfiguration.ALL_ACCESS, }, { "key": MAX_OUTSTANDING_FINES, - "label": _("Maximum amount in fines a patron can have before losing lending privileges"), + "label": _( + "Maximum amount in fines a patron can have before losing lending privileges" + ), "type": "number", "category": "Loans, Holds, & Fines", - "level": CoreConfiguration.ALL_ACCESS + "level": CoreConfiguration.ALL_ACCESS, }, { "key": LOAN_LIMIT, "label": _("Maximum number of books a patron can have on loan at once"), - "description": _("(Note: depending on distributor settings, a patron may be able to exceed the limit by checking out books directly from a distributor's app. They may also get a limit exceeded error before they reach these limits if a distributor has a smaller limit.)"), + "description": _( + "(Note: depending on distributor settings, a patron may be able to exceed the limit by checking out books directly from a distributor's app. They may also get a limit exceeded error before they reach these limits if a distributor has a smaller limit.)" + ), "type": "number", "category": "Loans, Holds, & Fines", - "level": CoreConfiguration.ALL_ACCESS + "level": CoreConfiguration.ALL_ACCESS, }, { "key": HOLD_LIMIT, "label": _("Maximum number of books a patron can have on hold at once"), - "description": _("(Note: depending on distributor settings, a patron may be able to exceed the limit by checking out books directly from a distributor's app. They may also get a limit exceeded error before they reach these limits if a distributor has a smaller limit.)"), + "description": _( + "(Note: depending on distributor settings, a patron may be able to exceed the limit by checking out books directly from a distributor's app. They may also get a limit exceeded error before they reach these limits if a distributor has a smaller limit.)" + ), "type": "number", "category": "Loans, Holds, & Fines", - "level": CoreConfiguration.ALL_ACCESS + "level": CoreConfiguration.ALL_ACCESS, }, { "key": TERMS_OF_SERVICE, "label": _("Terms of Service URL"), "format": "url", "category": "Links", - "level": CoreConfiguration.ALL_ACCESS + "level": CoreConfiguration.ALL_ACCESS, }, { "key": PRIVACY_POLICY, "label": _("Privacy Policy URL"), "format": "url", "category": "Links", - "level": CoreConfiguration.ALL_ACCESS + "level": CoreConfiguration.ALL_ACCESS, }, { "key": COPYRIGHT, "label": _("Copyright URL"), "format": "url", "category": "Links", - "level": CoreConfiguration.SYS_ADMIN_OR_MANAGER + "level": CoreConfiguration.SYS_ADMIN_OR_MANAGER, }, { "key": ABOUT, "label": _("About URL"), "format": "url", "category": "Links", - "level": CoreConfiguration.ALL_ACCESS + "level": CoreConfiguration.ALL_ACCESS, }, { "key": LICENSE, "label": _("License URL"), "format": "url", "category": "Links", - "level": CoreConfiguration.SYS_ADMIN_OR_MANAGER + "level": CoreConfiguration.SYS_ADMIN_OR_MANAGER, }, { "key": REGISTER, "label": _("Patron registration URL"), - "description": _("A URL where someone who doesn't have a library card yet can sign up for one."), + "description": _( + "A URL where someone who doesn't have a library card yet can sign up for one." + ), "format": "url", "category": "Patron Support", "allowed": ["nypl.card-creator:https://patrons.librarysimplified.org/"], - "level": CoreConfiguration.ALL_ACCESS + "level": CoreConfiguration.ALL_ACCESS, }, { "key": LARGE_COLLECTION_LANGUAGES, - "label": _("The primary languages represented in this library's collection"), + "label": _( + "The primary languages represented in this library's collection" + ), "type": "list", "format": "language-code", "description": LANGUAGE_DESCRIPTION, "optional": True, "category": "Languages", - "level": CoreConfiguration.ALL_ACCESS + "level": CoreConfiguration.ALL_ACCESS, }, { "key": SMALL_COLLECTION_LANGUAGES, - "label": _("Other major languages represented in this library's collection"), + "label": _( + "Other major languages represented in this library's collection" + ), "type": "list", "format": "language-code", "description": LANGUAGE_DESCRIPTION, "optional": True, "category": "Languages", - "level": CoreConfiguration.ALL_ACCESS + "level": CoreConfiguration.ALL_ACCESS, }, { "key": TINY_COLLECTION_LANGUAGES, @@ -504,7 +560,7 @@ class Configuration(CoreConfiguration): "description": LANGUAGE_DESCRIPTION, "optional": True, "category": "Languages", - "level": CoreConfiguration.ALL_ACCESS + "level": CoreConfiguration.ALL_ACCESS, }, ] @@ -536,27 +592,19 @@ def _collection_languages(cls, library, key): @classmethod def large_collection_languages(cls, library): - return cls._collection_languages( - library, cls.LARGE_COLLECTION_LANGUAGES - ) + return cls._collection_languages(library, cls.LARGE_COLLECTION_LANGUAGES) @classmethod def small_collection_languages(cls, library): - return cls._collection_languages( - library, cls.SMALL_COLLECTION_LANGUAGES - ) + return cls._collection_languages(library, cls.SMALL_COLLECTION_LANGUAGES) @classmethod def tiny_collection_languages(cls, library): - return cls._collection_languages( - library, cls.TINY_COLLECTION_LANGUAGES - ) + return cls._collection_languages(library, cls.TINY_COLLECTION_LANGUAGES) @classmethod def max_outstanding_fines(cls, library): - max_fines = ConfigurationSetting.for_library( - cls.MAX_OUTSTANDING_FINES, library - ) + max_fines = ConfigurationSetting.for_library(cls.MAX_OUTSTANDING_FINES, library) if max_fines.value is None: return None return MoneyUtility.parse(max_fines.value) @@ -577,12 +625,11 @@ def estimate_language_collections_for_library(cls, library): holdings = library.estimated_holdings_by_language() large, small, tiny = cls.classify_holdings(holdings) for setting, value in ( - (cls.LARGE_COLLECTION_LANGUAGES, large), - (cls.SMALL_COLLECTION_LANGUAGES, small), - (cls.TINY_COLLECTION_LANGUAGES, tiny), + (cls.LARGE_COLLECTION_LANGUAGES, large), + (cls.SMALL_COLLECTION_LANGUAGES, small), + (cls.TINY_COLLECTION_LANGUAGES, tiny), ): - ConfigurationSetting.for_library( - setting, library).value = json.dumps(value) + ConfigurationSetting.for_library(setting, library).value = json.dumps(value) @classmethod def classify_holdings(cls, works_by_language): @@ -604,7 +651,7 @@ def classify_holdings(cls, works_by_language): if not works_by_language: # In the absence of any information, assume we have an # English collection and nothing else. - large.append('eng') + large.append("eng") return result # The single most common language always gets a large @@ -650,7 +697,7 @@ def help_uris(cls, library): if name == cls.HELP_EMAIL: value = cls._as_mailto(value) if name == cls.HELP_WEB: - type = 'text/html' + type = "text/html" yield type, value @classmethod @@ -722,6 +769,7 @@ def cipher(cls, key): """ return PKCS1_OAEP.new(RSA.import_key(key)) + # We changed Configuration.DEFAULT_OPDS_FORMAT, but the Configuration # class from core still has the old value. Change that one to match, # so that core code that checks this constant will get the right @@ -732,11 +780,13 @@ def cipher(cls, key): # appropriate one in any situation. This is a source of subtle bugs. CoreConfiguration.DEFAULT_OPDS_FORMAT = Configuration.DEFAULT_OPDS_FORMAT + @contextlib.contextmanager def empty_config(): with core_empty_config({}, [CoreConfiguration, Configuration]) as i: yield i + @contextlib.contextmanager def temp_config(new_config=None, replacement_classes=None): all_replacement_classes = [CoreConfiguration, Configuration] diff --git a/api/controller.py b/api/controller.py index a3792ca225..b1ba8b9010 100644 --- a/api/controller.py +++ b/api/controller.py @@ -3,7 +3,6 @@ import json import logging import os -import pytz import sys import urllib.parse from collections import defaultdict @@ -11,47 +10,22 @@ from wsgiref.handlers import format_date_time import flask +import pytz from expiringdict import ExpiringDict -from flask import ( - make_response, - Response, - redirect, -) +from flask import Response, make_response, redirect from flask_babel import lazy_gettext as _ from lxml import etree from sqlalchemy.orm import eagerload -from .adobe_vendor_id import ( - AdobeVendorIDController, - DeviceManagementProtocolController, - AuthdataUtility, -) -from .annotations import ( - AnnotationWriter, - AnnotationParser, -) from api.saml.controller import SAMLController -from .authenticator import ( - Authenticator, - CirculationPatronProfileStorage, - OAuthController, -) -from .base_controller import BaseCirculationManagerController -from .circulation import CirculationAPI, FulfillmentInfo -from .circulation_exceptions import * -from .config import ( - Configuration, - CannotLoadConfiguration, -) from core.analytics import Analytics +from core.app_server import ComplaintController, HeartbeatController +from core.app_server import URNLookupController as CoreURNLookupController from core.app_server import ( cdn_url_for, - url_for, load_facets_from_request, load_pagination_from_request, - ComplaintController, - HeartbeatController, - URNLookupController as CoreURNLookupController, + url_for, ) from core.entrypoint import EverythingEntryPoint from core.external_search import ( @@ -62,8 +36,8 @@ from core.lane import ( BaseFacets, FeaturedFacets, - Pagination, Lane, + Pagination, SearchFacets, WorkList, ) @@ -71,7 +45,6 @@ from core.marc import MARCExporter from core.metadata_layer import ContributorData from core.model import ( - get_one, Admin, Annotation, CachedFeed, @@ -88,38 +61,45 @@ IntegrationClient, Library, LicensePool, - Loan, LicensePoolDeliveryMechanism, + Loan, Patron, Representation, Session, + get_one, ) -from core.opds import ( - AcquisitionFeed, - NavigationFacets, - NavigationFeed, -) +from core.opds import AcquisitionFeed, NavigationFacets, NavigationFeed from core.opensearch import OpenSearchDocument from core.user_profile import ProfileController as CoreProfileController from core.util.authentication_for_opds import AuthenticationForOPDSDocument -from core.util.datetime_helpers import ( - from_timestamp, - utc_now, -) -from core.util.http import ( - HTTP, - RemoteIntegrationException, -) -from core.util.opds_writer import ( - OPDSFeed, -) +from core.util.datetime_helpers import from_timestamp, utc_now +from core.util.http import HTTP, RemoteIntegrationException +from core.util.opds_writer import OPDSFeed from core.util.problem_detail import ProblemDetail from core.util.string_helpers import base64 + +from .adobe_vendor_id import ( + AdobeVendorIDController, + AuthdataUtility, + DeviceManagementProtocolController, +) +from .annotations import AnnotationParser, AnnotationWriter +from .authenticator import ( + Authenticator, + CirculationPatronProfileStorage, + OAuthController, +) +from .base_controller import BaseCirculationManagerController +from .circulation import CirculationAPI, FulfillmentInfo +from .circulation_exceptions import * +from .config import CannotLoadConfiguration, Configuration from .custom_index import CustomIndexView from .lanes import ( - load_lanes, ContributorFacets, ContributorLane, + CrawlableCollectionBasedLane, + CrawlableCustomListBasedLane, + CrawlableFacets, HasSeriesFacets, JackpotFacets, JackpotWorkList, @@ -127,24 +107,22 @@ RelatedBooksLane, SeriesFacets, SeriesLane, - CrawlableCollectionBasedLane, - CrawlableCustomListBasedLane, - CrawlableFacets, + load_lanes, ) from .odl import ODLAPI from .opds import ( CirculationManagerAnnotator, LibraryAnnotator, - SharedCollectionAnnotator, LibraryLoanAndHoldAnnotator, + SharedCollectionAnnotator, SharedCollectionLoanAndHoldAnnotator, ) from .problem_details import * from .shared_collection import SharedCollectionAPI from .testing import MockCirculationAPI, MockSharedCollectionAPI -class CirculationManager(object): +class CirculationManager(object): def __init__(self, _db, testing=False): self.log = logging.getLogger("Circulation manager web app") @@ -154,7 +132,9 @@ def __init__(self, _db, testing=False): try: self.config = Configuration.load(_db) except CannotLoadConfiguration as exception: - self.log.exception("Could not load configuration file: {0}".format(exception)) + self.log.exception( + "Could not load configuration file: {0}".format(exception) + ) sys.exit() self.testing = testing @@ -174,18 +154,22 @@ def load_facets_from_request(self, *args, **kwargs): facets = load_facets_from_request(*args, **kwargs) - worklist = kwargs.get('worklist') + worklist = kwargs.get("worklist") if worklist is not None: # Try to get the index controller. If it's not initialized # for any reason, don't run this check -- we have bigger # problems. - index_controller = getattr(self, 'index_controller', None) - if (index_controller and not - worklist.accessible_to(index_controller.request_patron)): + index_controller = getattr(self, "index_controller", None) + if index_controller and not worklist.accessible_to( + index_controller.request_patron + ): return NO_SUCH_LANE.detailed(_("Lane does not exist")) - if isinstance(facets, BaseFacets) and getattr(facets, 'max_cache_age', None) is not None: + if ( + isinstance(facets, BaseFacets) + and getattr(facets, "max_cache_age", None) is not None + ): # A faceting object was loaded, and it tried to do something nonstandard # with caching. @@ -197,7 +181,7 @@ def load_facets_from_request(self, *args, **kwargs): # reason, we'll default to assuming the user is not an # authenticated admin. authenticated = False - controller = getattr(self, 'admin_sign_in_controller', None) + controller = getattr(self, "admin_sign_in_controller", None) if controller: admin = controller.authenticated_admin_from_request() # If authenticated_admin_from_request returns anything other than an admin (probably @@ -250,9 +234,7 @@ def load_settings(self): new_top_level_lanes[library.id] = lanes - new_custom_index_views[library.id] = CustomIndexView.for_library( - library - ) + new_custom_index_views[library.id] = CustomIndexView.for_library(library) new_circulation_apis[library.id] = self.setup_circulation( library, self.analytics @@ -278,24 +260,28 @@ def get_domain(url): url = url.strip() if url == "*": return url - scheme, netloc, path, parameters, query, fragment = urllib.parse.urlparse(url) + scheme, netloc, path, parameters, query, fragment = urllib.parse.urlparse( + url + ) if scheme and netloc: return scheme + "://" + netloc else: return None sitewide_patron_web_client_urls = ConfigurationSetting.sitewide( - self._db, Configuration.PATRON_WEB_HOSTNAMES).value + self._db, Configuration.PATRON_WEB_HOSTNAMES + ).value if sitewide_patron_web_client_urls: - for url in sitewide_patron_web_client_urls.split('|'): + for url in sitewide_patron_web_client_urls.split("|"): domain = get_domain(url) if domain: patron_web_domains.add(domain) from .registry import Registration - for setting in self._db.query( - ConfigurationSetting).filter( - ConfigurationSetting.key==Registration.LIBRARY_REGISTRATION_WEB_CLIENT): + + for setting in self._db.query(ConfigurationSetting).filter( + ConfigurationSetting.key == Registration.LIBRARY_REGISTRATION_WEB_CLIENT + ): if setting.value: patron_web_domains.add(get_domain(setting.value)) @@ -309,9 +295,12 @@ def get_domain(url): self.authentication_for_opds_documents = ExpiringDict( max_len=1000, max_age_seconds=authentication_document_cache_time ) - self.wsgi_debug = ConfigurationSetting.sitewide( - self._db, Configuration.WSGI_DEBUG_KEY - ).bool_value or False + self.wsgi_debug = ( + ConfigurationSetting.sitewide( + self._db, Configuration.WSGI_DEBUG_KEY + ).bool_value + or False + ) @property def external_search(self): @@ -330,9 +319,7 @@ def setup_external_search(self): self._external_search = self.setup_search() self.external_search_initialization_exception = None except Exception as e: - self.log.error( - "Exception initializing search engine: %s", e - ) + self.log.error("Exception initializing search engine: %s", e) self._external_search = None self.external_search_initialization_exception = e return self._external_search @@ -350,7 +337,7 @@ def cdn_url_for(self, view, *args, **kwargs): :param kwargs: Keyword arguments to the view function. """ url_for = self._cdn_url_for - facets = kwargs.pop('_facets', None) + facets = kwargs.pop("_facets", None) if facets and facets.max_cache_age is CachedFeed.IGNORE_CACHE: # The faceting object in play has disabled cache # checking. A CDN is also a cache, so we should disable @@ -368,9 +355,8 @@ def _cdn_url_for(self, *args, **kwargs): return cdn_url_for(*args, **kwargs) def url_for(self, view, *args, **kwargs): - """Call the url_for function, ensuring that Flask generates an absolute URL. - """ - kwargs['_external'] = True + """Call the url_for function, ensuring that Flask generates an absolute URL.""" + kwargs["_external"] = True return url_for(view, *args, **kwargs) def log_lanes(self, lanelist=None, level=0): @@ -379,7 +365,7 @@ def log_lanes(self, lanelist=None, level=0): for lane in lanelist: self.log.debug("%s%r", "-" * level, lane) if lane.sublanes: - self.log_lanes(lane.sublanes, level+1) + self.log_lanes(lane.sublanes, level + 1) def setup_search(self): """Set up a search client.""" @@ -428,6 +414,7 @@ def setup_one_time_controllers(self): self.static_files = StaticFileController(self) from api.lcp.controller import LCPController + self.lcp_controller = LCPController(self) def setup_configuration_dependent_controllers(self): @@ -448,13 +435,15 @@ def setup_adobe_vendor_id(self, _db, library): """ short_client_token_initialization_exceptions = dict() adobe = ExternalIntegration.lookup( - _db, ExternalIntegration.ADOBE_VENDOR_ID, - ExternalIntegration.DRM_GOAL, library=library + _db, + ExternalIntegration.ADOBE_VENDOR_ID, + ExternalIntegration.DRM_GOAL, + library=library, ) warning = ( - 'Adobe Vendor ID controller is disabled due to missing or' - ' incomplete configuration. This is probably nothing to' - ' worry about.' + "Adobe Vendor ID controller is disabled due to missing or" + " incomplete configuration. This is probably nothing to" + " worry about." ) new_adobe_vendor_id = None @@ -468,14 +457,12 @@ def setup_adobe_vendor_id(self, _db, library): "Multiple libraries define an Adobe Vendor ID integration. This is not supported and the last library seen will take precedence." ) new_adobe_vendor_id = AdobeVendorIDController( - _db, - library, - vendor_id, - node_value, - self.auth + _db, library, vendor_id, node_value, self.auth ) else: - self.log.warn("Adobe Vendor ID controller is disabled due to missing or incomplete configuration. This is probably nothing to worry about.") + self.log.warn( + "Adobe Vendor ID controller is disabled due to missing or incomplete configuration. This is probably nothing to worry about." + ) if new_adobe_vendor_id: self.adobe_vendor_id = new_adobe_vendor_id @@ -484,8 +471,10 @@ def setup_adobe_vendor_id(self, _db, library): # information for the calling code to have so it knows # whether or not we should support the Device Management Protocol. registry = ExternalIntegration.lookup( - _db, ExternalIntegration.OPDS_REGISTRATION, - ExternalIntegration.DISCOVERY_GOAL, library=library + _db, + ExternalIntegration.OPDS_REGISTRATION, + ExternalIntegration.DISCOVERY_GOAL, + library=library, ) authdata = None if registry: @@ -495,9 +484,12 @@ def setup_adobe_vendor_id(self, _db, library): short_client_token_initialization_exceptions[library.id] = e self.log.error( "Short Client Token configuration for %s is present but not working. This may be cause for concern. Original error: %s", - library.name, str(e) + library.name, + str(e), ) - self.short_client_token_initialization_exceptions = short_client_token_initialization_exceptions + self.short_client_token_initialization_exceptions = ( + short_client_token_initialization_exceptions + ) return authdata def annotator(self, lane, facets=None, *args, **kwargs): @@ -513,7 +505,7 @@ def annotator(self, lane, facets=None, *args, **kwargs): library = lane.library elif lane and isinstance(lane, WorkList): library = lane.get_library(self._db) - if not library and hasattr(flask.request, 'library'): + if not library and hasattr(flask.request, "library"): library = flask.request.library # If no library is provided, the best we can do is a generic @@ -530,12 +522,16 @@ def annotator(self, lane, facets=None, *args, **kwargs): library_identifies_patrons = ( authenticator is not None and authenticator.identifies_individuals ) - annotator_class = kwargs.pop('annotator_class', LibraryAnnotator) + annotator_class = kwargs.pop("annotator_class", LibraryAnnotator) return annotator_class( - self.circulation_apis[library.id], lane, - library, top_level_title='All Books', + self.circulation_apis[library.id], + lane, + library, + top_level_title="All Books", library_identifies_patrons=library_identifies_patrons, - facets=facets, *args, **kwargs + facets=facets, + *args, + **kwargs ) @property @@ -562,15 +558,13 @@ def authentication_for_opds_document(self): value = self.auth.create_authentication_document() self.authentication_for_opds_documents[name] = value - if self.wsgi_debug and 'debug' in flask.request.args: + if self.wsgi_debug and "debug" in flask.request.args: # Annotate with debugging information about the WSGI # environment and the authentication document cache # itself. value = json.loads(value) - value['_debug'] = dict( - url=self.url_for( - 'authentication_document', library_short_name=name - ), + value["_debug"] = dict( + url=self.url_for("authentication_document", library_short_name=name), environ=str(dict(flask.request.environ)), cache=str(self.authentication_for_opds_documents), ) @@ -580,33 +574,37 @@ def authentication_for_opds_document(self): @property def sitewide_key_pair(self): """Look up or create the sitewide public/private key pair.""" - setting = ConfigurationSetting.sitewide( - self._db, Configuration.KEY_PAIR - ) + setting = ConfigurationSetting.sitewide(self._db, Configuration.KEY_PAIR) return Configuration.key_pair(setting) @property def public_key_integration_document(self): """Serve a document with the sitewide public key.""" - site_id = ConfigurationSetting.sitewide(self._db, Configuration.BASE_URL_KEY).value + site_id = ConfigurationSetting.sitewide( + self._db, Configuration.BASE_URL_KEY + ).value document = dict(id=site_id) public, private = self.sitewide_key_pair - document['public_key'] = dict(type='RSA', value=public) + document["public_key"] = dict(type="RSA", value=public) return json.dumps(document) class CirculationManagerController(BaseCirculationManagerController): - def get_patron_circ_objects(self, object_class, patron, license_pools): if not patron: return [] pool_ids = [pool.id for pool in license_pools] - return self._db.query(object_class).filter( - object_class.patron_id==patron.id, - object_class.license_pool_id.in_(pool_ids) - ).options(eagerload(object_class.license_pool)).all() + return ( + self._db.query(object_class) + .filter( + object_class.patron_id == patron.id, + object_class.license_pool_id.in_(pool_ids), + ) + .options(eagerload(object_class.license_pool)) + .all() + ) def get_patron_loan(self, patron, license_pools): loans = self.get_patron_circ_objects(Loan, patron, license_pools) @@ -663,7 +661,7 @@ def handle_conditional_request(self, last_modified=None): if last_modified.microsecond: last_modified = last_modified.replace(microsecond=0) - if_modified_since = flask.request.headers.get('If-Modified-Since') + if_modified_since = flask.request.headers.get("If-Modified-Since") if not if_modified_since: return None @@ -717,8 +715,10 @@ def load_lane(self, lane_identifier): if not lane: return NO_SUCH_LANE.detailed( - _("Lane %(lane_identifier)s does not exist or is not associated with library %(library_id)s", - lane_identifier=lane_identifier, library_id=library_id + _( + "Lane %(lane_identifier)s does not exist or is not associated with library %(library_id)s", + lane_identifier=lane_identifier, + library_id=library_id, ) ) @@ -749,19 +749,20 @@ def load_licensepools(self, library, identifier_type, identifier): to look up an Identifier. """ _db = Session.object_session(library) - pools = _db.query(LicensePool).join(LicensePool.collection).join( - LicensePool.identifier).join(Collection.libraries).filter( - Identifier.type==identifier_type - ).filter( - Identifier.identifier==identifier - ).filter( - Library.id==library.id - ).all() + pools = ( + _db.query(LicensePool) + .join(LicensePool.collection) + .join(LicensePool.identifier) + .join(Collection.libraries) + .filter(Identifier.type == identifier_type) + .filter(Identifier.identifier == identifier) + .filter(Library.id == library.id) + .all() + ) if not pools: return NO_LICENSES.detailed( - _("The item you're asking about (%s/%s) isn't in this collection.") % ( - identifier_type, identifier - ) + _("The item you're asking about (%s/%s) isn't in this collection.") + % (identifier_type, identifier) ) return pools @@ -778,9 +779,12 @@ def load_licensepool(self, license_pool_id): def load_licensepooldelivery(self, pool, mechanism_id): """Turn user input into a LicensePoolDeliveryMechanism object.""" mechanism = get_one( - self._db, LicensePoolDeliveryMechanism, - data_source=pool.data_source, identifier=pool.identifier, - delivery_mechanism_id=mechanism_id, on_multiple='interchangeable' + self._db, + LicensePoolDeliveryMechanism, + data_source=pool.data_source, + identifier=pool.identifier, + delivery_mechanism_id=mechanism_id, + on_multiple="interchangeable", ) return mechanism or BAD_DELIVERY_MECHANISM @@ -809,15 +813,15 @@ def apply_borrowing_policy(self, patron, license_pool): if work is not None and not work.age_appropriate_for_patron(patron): return NOT_AGE_APPROPRIATE - if (not patron.library.allow_holds and - license_pool.licenses_available == 0 and - not license_pool.open_access and - not license_pool.unlimited_access and - not license_pool.self_hosted + if ( + not patron.library.allow_holds + and license_pool.licenses_available == 0 + and not license_pool.open_access + and not license_pool.unlimited_access + and not license_pool.self_hosted ): return FORBIDDEN_BY_POLICY.detailed( - _("Library policy prohibits the placement of holds."), - status_code=403 + _("Library policy prohibits the placement of holds."), status_code=403 ) return None @@ -836,7 +840,11 @@ def __call__(self): # The simple case: the app is equally open to all clients. library_short_name = flask.request.library.short_name if not self.has_root_lanes(): - return redirect(self.cdn_url_for('acquisition_groups', library_short_name=library_short_name)) + return redirect( + self.cdn_url_for( + "acquisition_groups", library_short_name=library_short_name + ) + ) # The more complex case. We must authorize the patron, check # their type, and redirect them to an appropriate feed. @@ -847,9 +855,7 @@ def authentication_document(self): return Response( self.manager.authentication_for_opds_document, 200, - { - "Content-Type" : AuthenticationForOPDSDocument.MEDIA_TYPE - } + {"Content-Type": AuthenticationForOPDSDocument.MEDIA_TYPE}, ) def has_root_lanes(self): @@ -878,14 +884,14 @@ def appropriate_index_for_patron_type(self): if root_lane is None: return redirect( self.cdn_url_for( - 'acquisition_groups', + "acquisition_groups", library_short_name=library_short_name, ) ) return redirect( self.cdn_url_for( - 'acquisition_groups', + "acquisition_groups", library_short_name=library_short_name, lane_identifier=root_lane.id, ) @@ -895,11 +901,12 @@ def public_key_document(self): """Serves a sitewide public key document""" return Response( self.manager.public_key_integration_document, - 200, { 'Content-Type' : 'application/opds+json' } + 200, + {"Content-Type": "application/opds+json"}, ) -class OPDSFeedController(CirculationManagerController): +class OPDSFeedController(CirculationManagerController): def groups(self, lane_identifier, feed_class=AcquisitionFeed): """Build or retrieve a grouped acquisition feed. @@ -919,10 +926,10 @@ def groups(self, lane_identifier, feed_class=AcquisitionFeed): if patron is not None and patron.root_lane: return redirect( self.cdn_url_for( - 'acquisition_groups', + "acquisition_groups", library_short_name=library.short_name, lane_identifier=patron.root_lane.id, - _external=True + _external=True, ) ) @@ -941,8 +948,9 @@ def groups(self, lane_identifier, feed_class=AcquisitionFeed): minimum_featured_quality=library.minimum_featured_quality, ) facets = self.manager.load_facets_from_request( - worklist=lane, base_class=FeaturedFacets, - base_class_constructor_kwargs=facet_class_kwargs + worklist=lane, + base_class=FeaturedFacets, + base_class_constructor_kwargs=facet_class_kwargs, ) if isinstance(facets, ProblemDetail): return facets @@ -952,14 +960,21 @@ def groups(self, lane_identifier, feed_class=AcquisitionFeed): return search_engine url = self.cdn_url_for( - "acquisition_groups", lane_identifier=lane_identifier, - library_short_name=library.short_name, _facets=facets + "acquisition_groups", + lane_identifier=lane_identifier, + library_short_name=library.short_name, + _facets=facets, ) annotator = self.manager.annotator(lane, facets) return feed_class.groups( - _db=self._db, title=lane.display_name, url=url, worklist=lane, - annotator=annotator, facets=facets, search_engine=search_engine + _db=self._db, + title=lane.display_name, + url=url, + worklist=lane, + annotator=annotator, + facets=facets, + search_engine=search_engine, ) def feed(self, lane_identifier, feed_class=AcquisitionFeed): @@ -985,16 +1000,22 @@ def feed(self, lane_identifier, feed_class=AcquisitionFeed): library_short_name = flask.request.library.short_name url = self.cdn_url_for( - "feed", lane_identifier=lane_identifier, - library_short_name=library_short_name, _facets=facets + "feed", + lane_identifier=lane_identifier, + library_short_name=library_short_name, + _facets=facets, ) annotator = self.manager.annotator(lane, facets=facets) return feed_class.page( - _db=self._db, title=lane.display_name, - url=url, worklist=lane, annotator=annotator, - facets=facets, pagination=pagination, - search_engine=search_engine + _db=self._db, + title=lane.display_name, + url=url, + worklist=lane, + annotator=annotator, + facets=facets, + pagination=pagination, + search_engine=search_engine, ) def navigation(self, lane_identifier): @@ -1006,7 +1027,9 @@ def navigation(self, lane_identifier): library = flask.request.library library_short_name = library.short_name url = self.cdn_url_for( - "navigation_feed", lane_identifier=lane_identifier, library_short_name=library_short_name, + "navigation_feed", + lane_identifier=lane_identifier, + library_short_name=library_short_name, ) title = lane.display_name @@ -1014,8 +1037,9 @@ def navigation(self, lane_identifier): minimum_featured_quality=library.minimum_featured_quality, ) facets = self.manager.load_facets_from_request( - worklist=lane, base_class=NavigationFacets, - base_class_constructor_kwargs=facet_class_kwargs + worklist=lane, + base_class=NavigationFacets, + base_class_constructor_kwargs=facet_class_kwargs, ) annotator = self.manager.annotator(lane, facets) return NavigationFeed.navigation( @@ -1045,8 +1069,7 @@ def crawlable_collection_feed(self, collection_name): return NO_SUCH_COLLECTION title = collection.name url = self.cdn_url_for( - "crawlable_collection_feed", - collection_name=collection.name + "crawlable_collection_feed", collection_name=collection.name ) lane = CrawlableCollectionBasedLane() lane.initialize([collection]) @@ -1073,15 +1096,17 @@ def crawlable_list_feed(self, list_name): library_short_name = library.short_name title = list.name url = self.cdn_url_for( - "crawlable_list_feed", list_name=list.name, + "crawlable_list_feed", + list_name=list.name, library_short_name=library_short_name, ) lane = CrawlableCustomListBasedLane() lane.initialize(library, list) return self._crawlable_feed(title=title, url=url, worklist=lane) - def _crawlable_feed(self, title, url, worklist, annotator=None, - feed_class=AcquisitionFeed): + def _crawlable_feed( + self, title, url, worklist, annotator=None, feed_class=AcquisitionFeed + ): """Helper method to create a crawlable feed. :param title: The title to use for the feed. @@ -1109,10 +1134,14 @@ def _crawlable_feed(self, title, url, worklist, annotator=None, facets = CrawlableFacets.default(None) return feed_class.page( - _db=self._db, title=title, url=url, worklist=worklist, + _db=self._db, + title=title, + url=url, + worklist=worklist, annotator=annotator, - facets=facets, pagination=pagination, - search_engine=search_engine + facets=facets, + pagination=pagination, + search_engine=search_engine, ) def _load_search_facets(self, lane): @@ -1126,7 +1155,8 @@ def _load_search_facets(self, lane): # and no need for a special default. default_entrypoint = None return self.manager.load_facets_from_request( - worklist=lane, base_class=SearchFacets, + worklist=lane, + base_class=SearchFacets, default_entrypoint=default_entrypoint, ) @@ -1155,7 +1185,7 @@ def search(self, lane_identifier, feed_class=AcquisitionFeed): # Check whether there is a query string -- if not, we want to # send an OpenSearch document explaining how to search. - query = flask.request.args.get('q') + query = flask.request.args.get("q") library_short_name = flask.request.library.short_name # Create a function that, when called, generates a URL to the @@ -1166,32 +1196,39 @@ def search(self, lane_identifier, feed_class=AcquisitionFeed): # string. make_url_kwargs = dict(list(facets.items())) make_url = lambda: self.url_for( - 'lane_search', lane_identifier=lane_identifier, + "lane_search", + lane_identifier=lane_identifier, library_short_name=library_short_name, **make_url_kwargs ) if not query: # Send the search form open_search_doc = OpenSearchDocument.for_lane(lane, make_url()) - headers = { "Content-Type" : "application/opensearchdescription+xml" } + headers = {"Content-Type": "application/opensearchdescription+xml"} return Response(open_search_doc, 200, headers) # We have a query -- add it to the keyword arguments used when # generating a URL. - make_url_kwargs['q'] = query.encode("utf8") + make_url_kwargs["q"] = query.encode("utf8") # Run a search. annotator = self.manager.annotator(lane, facets) info = OpenSearchDocument.search_info(lane) return feed_class.search( - _db=self._db, title=info['name'], - url=make_url(), lane=lane, search_engine=search_engine, - query=query, annotator=annotator, pagination=pagination, - facets=facets + _db=self._db, + title=info["name"], + url=make_url(), + lane=lane, + search_engine=search_engine, + query=query, + annotator=annotator, + pagination=pagination, + facets=facets, ) - def _qa_feed(self, feed_factory, feed_title, controller_name, facet_class, - worklist_factory): + def _qa_feed( + self, feed_factory, feed_title, controller_name, facet_class, worklist_factory + ): """Create some kind of OPDS feed designed for consumption by an automated QA process. @@ -1232,9 +1269,15 @@ def _qa_feed(self, feed_factory, feed_title, controller_name, facet_class, # reason to put more than a single item in each group. pagination = Pagination(size=1) return feed_factory( - _db=self._db, title=feed_title, url=url, pagination=pagination, - worklist=worklist, annotator=annotator, search_engine=search_engine, - facets=facets, max_age=CachedFeed.IGNORE_CACHE + _db=self._db, + title=feed_title, + url=url, + pagination=pagination, + worklist=worklist, + annotator=annotator, + search_engine=search_engine, + facets=facets, + max_age=CachedFeed.IGNORE_CACHE, ) def qa_feed(self, feed_class=AcquisitionFeed): @@ -1245,6 +1288,7 @@ def qa_feed(self, feed_class=AcquisitionFeed): :param feed_class: Class to substitute for AcquisitionFeed during tests. """ + def factory(library, facets): return JackpotWorkList(library, facets) @@ -1253,7 +1297,7 @@ def factory(library, facets): feed_title="QA test feed", controller_name="qa_feed", facet_class=JackpotFacets, - worklist_factory=factory + worklist_factory=factory, ) def qa_series_feed(self, feed_class=AcquisitionFeed): @@ -1263,6 +1307,7 @@ def qa_series_feed(self, feed_class=AcquisitionFeed): :param feed_class: Class to substitute for AcquisitionFeed during tests. """ + def factory(library, facets): wl = WorkList() wl.initialize(library) @@ -1273,7 +1318,7 @@ def factory(library, facets): feed_title="QA series test feed", controller_name="qa_series_feed", facet_class=HasSeriesFacets, - worklist_factory=factory + worklist_factory=factory, ) @@ -1297,7 +1342,11 @@ def download_page(self): try: exporter = MARCExporter.from_config(library) except CannotLoadConfiguration as e: - body += "

" + _("No MARC exporter is currently configured for this library.") + "

" + body += ( + "

" + + _("No MARC exporter is currently configured for this library.") + + "

" + ) if len(library.cachedmarcfiles) < 1 and exporter: body += "

" + _("MARC files aren't ready to download yet.") + "

" @@ -1313,7 +1362,9 @@ def download_page(self): # TODO: By default the MARC script only caches one level of lanes, # so sorting by priority is good enough. - lanes = sorted(list(files_by_lane.keys()), key=lambda x: x.priority if x else -1) + lanes = sorted( + list(files_by_lane.keys()), key=lambda x: x.priority if x else -1 + ) for lane in lanes: files = files_by_lane[lane] @@ -1322,8 +1373,14 @@ def download_page(self): if files.get("full"): file = files.get("full") full_url = file.representation.mirror_url - full_label = _("Full file - last updated %(update_time)s", update_time=file.end_time.strftime(time_format)) - body += '%s' % (files.get("full").representation.mirror_url, full_label) + full_label = _( + "Full file - last updated %(update_time)s", + update_time=file.end_time.strftime(time_format), + ) + body += '%s' % ( + files.get("full").representation.mirror_url, + full_label, + ) if files.get("updates"): body += "

%s

" % _("Update-only files") @@ -1331,10 +1388,15 @@ def download_page(self): files.get("updates").sort(key=lambda x: x.end_time) for update in files.get("updates"): update_url = update.representation.mirror_url - update_label = _("Updates from %(start_time)s to %(end_time)s", - start_time=update.start_time.strftime(time_format), - end_time=update.end_time.strftime(time_format)) - body += '
  • %s
  • ' % (update_url, update_label) + update_label = _( + "Updates from %(start_time)s to %(end_time)s", + start_time=update.start_time.strftime(time_format), + end_time=update.end_time.strftime(time_format), + ) + body += '
  • %s
  • ' % ( + update_url, + update_label, + ) body += "" body += "" @@ -1342,13 +1404,11 @@ def download_page(self): html = self.DOWNLOAD_TEMPLATE % dict(body=body) headers = dict() - headers['Content-Type'] = "text/html" + headers["Content-Type"] = "text/html" return Response(html, 200, headers) -class LoanController(CirculationManagerController): - - +class LoanController(CirculationManagerController): def sync(self): """Sync the authenticated patron's loans and holds with all third-party providers. @@ -1359,9 +1419,7 @@ def sync(self): # Save some time if we don't believe the patron's loans or holds have # changed since the last time the client requested this feed. - response = self.handle_conditional_request( - patron.last_loan_activity_sync - ) + response = self.handle_conditional_request(patron.last_loan_activity_sync) if isinstance(response, Response): return response @@ -1369,7 +1427,7 @@ def sync(self): # as a quick way of checking authentication. Does this still happen? # It shouldn't -- the patron profile feed should be used instead. # If it's not used, we can take this out. - if flask.request.method=='HEAD': + if flask.request.method == "HEAD": return Response() # First synchronize our local list of loans and holds with all @@ -1387,9 +1445,7 @@ def sync(self): ) # Then make the feed. - return LibraryLoanAndHoldAnnotator.active_loans_for( - self.circulation, patron - ) + return LibraryLoanAndHoldAnnotator.active_loans_for(self.circulation, patron) def borrow(self, identifier_type, identifier, mechanism_id=None): """Create a new loan or hold for a book. @@ -1410,9 +1466,7 @@ def borrow(self, identifier_type, identifier, mechanism_id=None): if not result: # No LicensePools were found and no ProblemDetail # was returned. Send a generic ProblemDetail. - return NO_LICENSES.detailed( - _("I've never heard of this work.") - ) + return NO_LICENSES.detailed(_("I've never heard of this work.")) if isinstance(result, ProblemDetail): # There was a problem determining the appropriate # LicensePool to use. @@ -1436,9 +1490,9 @@ def borrow(self, identifier_type, identifier, mechanism_id=None): # serve a feed that talks about the hold. response_kwargs = {} if is_new: - response_kwargs['status'] = 201 + response_kwargs["status"] = 201 else: - response_kwargs['status'] = 200 + response_kwargs["status"] = 200 return LibraryLoanAndHoldAnnotator.single_item_feed( self.circulation, loan_or_hold, **response_kwargs ) @@ -1468,7 +1522,7 @@ def _borrow(self, patron, credential, pool, mechanism): except NoOpenAccessDownload as e: result = NO_LICENSES.detailed( _("Couldn't find an open-access download link for this book."), - status_code=404 + status_code=404, ) except PatronAuthorizationFailedException as e: result = INVALID_CREDENTIALS @@ -1480,7 +1534,10 @@ def _borrow(self, patron, credential, pool, mechanism): ) except OutstandingFines as e: result = OUTSTANDING_FINES.detailed( - _("You must pay your $%(fine_amount).2f outstanding fines before you can borrow more books.", fine_amount=patron.fines) + _( + "You must pay your $%(fine_amount).2f outstanding fines before you can borrow more books.", + fine_amount=patron.fines, + ) ) except AuthorizationExpired as e: result = e.as_problem_detail_document(debug=False) @@ -1505,7 +1562,9 @@ def _borrow(self, patron, credential, pool, mechanism): result = HOLD_FAILED return result, is_new - def best_lendable_pool(self, library, patron, identifier_type, identifier, mechanism_id): + def best_lendable_pool( + self, library, patron, identifier_type, identifier, mechanism_id + ): """ Of the available LicensePools for the given Identifier, return the one that's the best candidate for loaning out right now. @@ -1513,9 +1572,7 @@ def best_lendable_pool(self, library, patron, identifier_type, identifier, mecha :return: A Loan if this patron already has an active loan, otherwise a LicensePool. """ # Turn source + identifier into a set of LicensePools - pools = self.load_licensepools( - library, identifier_type, identifier - ) + pools = self.load_licensepools(library, identifier_type, identifier) if isinstance(pools, ProblemDetail): # Something went wrong. return pools @@ -1524,10 +1581,13 @@ def best_lendable_pool(self, library, patron, identifier_type, identifier, mecha mechanism = None problem_doc = None - existing_loans = self._db.query(Loan).filter( - Loan.license_pool_id.in_([lp.id for lp in pools]), - Loan.patron==patron - ).all() + existing_loans = ( + self._db.query(Loan) + .filter( + Loan.license_pool_id.in_([lp.id for lp in pools]), Loan.patron == patron + ) + .all() + ) if existing_loans: # The patron already has at least one loan on this book already. # To make the "borrow" operation idempotent, return one of @@ -1561,9 +1621,11 @@ def best_lendable_pool(self, library, patron, identifier_type, identifier, mecha # But there might be many such LicensePools, and we want # to pick the one that will get the book to the patron # with the shortest wait. - if (not best + if ( + not best or pool.licenses_available > best.licenses_available - or pool.patrons_in_hold_queue < best.patrons_in_hold_queue): + or pool.patrons_in_hold_queue < best.patrons_in_hold_queue + ): best = pool if not best: @@ -1652,7 +1714,7 @@ def fulfill(self, license_pool_id, mechanism_id=None, part=None, do_get=None): if not mechanism: # See if the loan already has a mechanism set. We can use that. if loan and loan.fulfillment: - mechanism = loan.fulfillment + mechanism = loan.fulfillment else: return BAD_DELIVERY_MECHANISM.detailed( _("You must specify a delivery mechanism to fulfill this loan.") @@ -1662,36 +1724,36 @@ def fulfill(self, license_pool_id, mechanism_id=None, part=None, do_get=None): # an appropriate link to this controller. def fulfill_part_url(part): return url_for( - "fulfill", license_pool_id=requested_license_pool.id, + "fulfill", + license_pool_id=requested_license_pool.id, mechanism_id=mechanism.delivery_mechanism.id, library_short_name=library.short_name, - part=str(part), _external=True + part=str(part), + _external=True, ) try: fulfillment = self.circulation.fulfill( - patron, credential, requested_license_pool, mechanism, - part=part, fulfill_part_url=fulfill_part_url + patron, + credential, + requested_license_pool, + mechanism, + part=part, + fulfill_part_url=fulfill_part_url, ) except DeliveryMechanismConflict as e: return DELIVERY_CONFLICT.detailed(str(e)) except NoActiveLoan as e: return NO_ACTIVE_LOAN.detailed( - _('Can\'t fulfill loan because you have no active loan for this book.'), - status_code=e.status_code + _("Can't fulfill loan because you have no active loan for this book."), + status_code=e.status_code, ) except CannotFulfill as e: - return CANNOT_FULFILL.with_debug( - str(e), status_code=e.status_code - ) + return CANNOT_FULFILL.with_debug(str(e), status_code=e.status_code) except FormatNotAvailable as e: - return NO_ACCEPTABLE_FORMAT.with_debug( - str(e), status_code=e.status_code - ) + return NO_ACCEPTABLE_FORMAT.with_debug(str(e), status_code=e.status_code) except DeliveryMechanismError as e: - return BAD_DELIVERY_MECHANISM.with_debug( - str(e), status_code=e.status_code - ) + return BAD_DELIVERY_MECHANISM.with_debug(str(e), status_code=e.status_code) # A subclass of FulfillmentInfo may want to bypass the whole # response creation process. @@ -1701,9 +1763,12 @@ def fulfill_part_url(part): headers = dict() encoding_header = dict() - if (fulfillment.data_source_name == DataSource.ENKI - and mechanism.delivery_mechanism.drm_scheme_media_type == DeliveryMechanism.NO_DRM): - encoding_header["Accept-Encoding"] = "deflate" + if ( + fulfillment.data_source_name == DataSource.ENKI + and mechanism.delivery_mechanism.drm_scheme_media_type + == DeliveryMechanism.NO_DRM + ): + encoding_header["Accept-Encoding"] = "deflate" if mechanism.delivery_mechanism.is_streaming: # If this is a streaming delivery mechanism, create an OPDS entry @@ -1734,14 +1799,16 @@ def fulfill_part_url(part): # of redirecting to it, since it may be downloaded through an # indirect acquisition link. try: - status_code, headers, content = do_get(fulfillment.content_link, headers=encoding_header) + status_code, headers, content = do_get( + fulfillment.content_link, headers=encoding_header + ) headers = dict(headers) except RemoteIntegrationException as e: return e.as_problem_detail_document(debug=False) else: status_code = 200 if fulfillment.content_type: - headers['Content-Type'] = fulfillment.content_type + headers["Content-Type"] = fulfillment.content_type return Response(response=content, status=status_code, headers=headers) @@ -1788,12 +1855,15 @@ def revoke(self, license_pool_id): if not loan and not hold: if not pool.work: - title = 'this book' + title = "this book" else: title = '"%s"' % pool.work.title return NO_ACTIVE_LOAN_OR_HOLD.detailed( - _('Can\'t revoke because you have no active loan or hold for "%(title)s".', title=title), - status_code=404 + _( + 'Can\'t revoke because you have no active loan or hold for "%(title)s".', + title=title, + ), + status_code=404, ) header = self.authorization_header() @@ -1802,11 +1872,15 @@ def revoke(self, license_pool_id): try: self.circulation.revoke_loan(patron, credential, pool) except RemoteRefusedReturn as e: - title = _("Loan deleted locally but remote refused. Loan is likely to show up again on next sync.") + title = _( + "Loan deleted locally but remote refused. Loan is likely to show up again on next sync." + ) return COULD_NOT_MIRROR_TO_REMOTE.detailed(title, status_code=503) except CannotReturn as e: title = _("Loan deleted locally but remote failed.") - return COULD_NOT_MIRROR_TO_REMOTE.detailed(title, 503).with_debug(str(e)) + return COULD_NOT_MIRROR_TO_REMOTE.detailed(title, 503).with_debug( + str(e) + ) elif hold: if not self.circulation.can_revoke_hold(pool, hold): title = _("Cannot release a hold once it enters reserved state.") @@ -1822,7 +1896,7 @@ def revoke(self, license_pool_id): return AcquisitionFeed.single_entry(self._db, work, annotator) def detail(self, identifier_type, identifier): - if flask.request.method=='DELETE': + if flask.request.method == "DELETE": return self.revoke_loan_or_hold(identifier_type, identifier) patron = flask.request.patron @@ -1839,45 +1913,52 @@ def detail(self, identifier_type, identifier): if not loan and not hold: return NO_ACTIVE_LOAN_OR_HOLD.detailed( - _('You have no active loan or hold for "%(title)s".', title=pool.work.title), - status_code=404 + _( + 'You have no active loan or hold for "%(title)s".', + title=pool.work.title, + ), + status_code=404, ) - if flask.request.method == 'GET': + if flask.request.method == "GET": if loan: item = loan else: item = hold - return LibraryLoanAndHoldAnnotator.single_item_feed( - self.circulation, item - ) + return LibraryLoanAndHoldAnnotator.single_item_feed(self.circulation, item) -class AnnotationController(CirculationManagerController): +class AnnotationController(CirculationManagerController): def container(self, identifier=None, accept_post=True): headers = dict() if accept_post: - headers['Allow'] = 'GET,HEAD,OPTIONS,POST' - headers['Accept-Post'] = AnnotationWriter.CONTENT_TYPE + headers["Allow"] = "GET,HEAD,OPTIONS,POST" + headers["Accept-Post"] = AnnotationWriter.CONTENT_TYPE else: - headers['Allow'] = 'GET,HEAD,OPTIONS' + headers["Allow"] = "GET,HEAD,OPTIONS" - if flask.request.method=='HEAD': + if flask.request.method == "HEAD": return Response(status=200, headers=headers) patron = flask.request.patron - if flask.request.method == 'GET': - headers['Link'] = ['; rel="type"', - '; rel="http://www.w3.org/ns/ldp#constrainedBy"'] - headers['Content-Type'] = AnnotationWriter.CONTENT_TYPE + if flask.request.method == "GET": + headers["Link"] = [ + '; rel="type"', + '; rel="http://www.w3.org/ns/ldp#constrainedBy"', + ] + headers["Content-Type"] = AnnotationWriter.CONTENT_TYPE - container, timestamp = AnnotationWriter.annotation_container_for(patron, identifier=identifier) + container, timestamp = AnnotationWriter.annotation_container_for( + patron, identifier=identifier + ) etag = 'W/""' if timestamp: etag = 'W/"%s"' % timestamp - headers['Last-Modified'] = format_date_time(mktime(timestamp.timetuple())) - headers['ETag'] = etag + headers["Last-Modified"] = format_date_time( + mktime(timestamp.timetuple()) + ) + headers["ETag"] = etag content = json.dumps(container) return Response(content, status=200, headers=headers) @@ -1890,57 +1971,54 @@ def container(self, identifier=None, accept_post=True): content = json.dumps(AnnotationWriter.detail(annotation)) status_code = 200 - headers['Link'] = '; rel="type"' - headers['Content-Type'] = AnnotationWriter.CONTENT_TYPE + headers["Link"] = '; rel="type"' + headers["Content-Type"] = AnnotationWriter.CONTENT_TYPE return Response(content, status_code, headers) def container_for_work(self, identifier_type, identifier): id_obj, ignore = Identifier.for_foreign_id( - self._db, identifier_type, identifier) + self._db, identifier_type, identifier + ) return self.container(identifier=id_obj, accept_post=False) def detail(self, annotation_id): headers = dict() - headers['Allow'] = 'GET,HEAD,OPTIONS,DELETE' + headers["Allow"] = "GET,HEAD,OPTIONS,DELETE" - if flask.request.method=='HEAD': + if flask.request.method == "HEAD": return Response(status=200, headers=headers) patron = flask.request.patron annotation = get_one( - self._db, Annotation, - patron=patron, - id=annotation_id, - active=True) + self._db, Annotation, patron=patron, id=annotation_id, active=True + ) if not annotation: return NO_ANNOTATION - if flask.request.method == 'DELETE': + if flask.request.method == "DELETE": annotation.set_inactive() return Response() content = json.dumps(AnnotationWriter.detail(annotation)) status_code = 200 - headers['Link'] = '; rel="type"' - headers['Content-Type'] = AnnotationWriter.CONTENT_TYPE + headers["Link"] = '; rel="type"' + headers["Content-Type"] = AnnotationWriter.CONTENT_TYPE return Response(content, status_code, headers) class WorkController(CirculationManagerController): - def _lane_details(self, languages, audiences): if languages: - languages = languages.split(',') + languages = languages.split(",") if audiences: - audiences = [urllib.parse.unquote_plus(a) for a in audiences.split(',')] + audiences = [urllib.parse.unquote_plus(a) for a in audiences.split(",")] return languages, audiences def contributor( - self, contributor_name, languages, audiences, - feed_class=AcquisitionFeed + self, contributor_name, languages, audiences, feed_class=AcquisitionFeed ): """Serve a feed of books written by a particular author""" library = flask.request.library @@ -1986,9 +2064,14 @@ def contributor( ) return feed_class.page( - _db=self._db, title=lane.display_name, url=url, worklist=lane, - facets=facets, pagination=pagination, - annotator=annotator, search_engine=search_engine + _db=self._db, + title=lane.display_name, + url=url, + worklist=lane, + facets=facets, + pagination=pagination, + annotator=annotator, + search_engine=search_engine, ) def permalink(self, identifier_type, identifier): @@ -2029,12 +2112,12 @@ def permalink(self, identifier_type, identifier): annotator = self.manager.annotator(lane=None) return AcquisitionFeed.single_entry( - self._db, work, annotator, - max_age=OPDSFeed.DEFAULT_MAX_AGE + self._db, work, annotator, max_age=OPDSFeed.DEFAULT_MAX_AGE ) - def related(self, identifier_type, identifier, novelist_api=None, - feed_class=AcquisitionFeed): + def related( + self, identifier_type, identifier, novelist_api=None, feed_class=AcquisitionFeed + ): """Serve a groups feed of books related to a given book.""" library = flask.request.library @@ -2047,21 +2130,18 @@ def related(self, identifier_type, identifier, novelist_api=None, return search_engine try: - lane_name = "Books Related to %s by %s" % ( - work.title, work.author - ) - lane = RelatedBooksLane( - library, work, lane_name, novelist_api=novelist_api - ) + lane_name = "Books Related to %s by %s" % (work.title, work.author) + lane = RelatedBooksLane(library, work, lane_name, novelist_api=novelist_api) except ValueError as e: # No related books were found. return NO_SUCH_LANE.detailed(str(e)) facets = self.manager.load_facets_from_request( - worklist=lane, base_class=FeaturedFacets, + worklist=lane, + base_class=FeaturedFacets, base_class_constructor_kwargs=dict( minimum_featured_quality=library.minimum_featured_quality - ) + ), ) if isinstance(facets, ProblemDetail): return facets @@ -2073,13 +2153,18 @@ def related(self, identifier_type, identifier, novelist_api=None, ) return feed_class.groups( - _db=self._db, title=lane.DISPLAY_NAME, - url=url, worklist=lane, annotator=annotator, - facets=facets, search_engine=search_engine + _db=self._db, + title=lane.DISPLAY_NAME, + url=url, + worklist=lane, + annotator=annotator, + facets=facets, + search_engine=search_engine, ) - def recommendations(self, identifier_type, identifier, novelist_api=None, - feed_class=AcquisitionFeed): + def recommendations( + self, identifier_type, identifier, novelist_api=None, feed_class=AcquisitionFeed + ): """Serve a feed of recommendations related to a given book.""" library = flask.request.library @@ -2094,8 +2179,10 @@ def recommendations(self, identifier_type, identifier, novelist_api=None, lane_name = "Recommendations for %s by %s" % (work.title, work.author) try: lane = RecommendationLane( - library=library, work=work, display_name=lane_name, - novelist_api=novelist_api + library=library, + work=work, + display_name=lane_name, + novelist_api=novelist_api, ) except CannotLoadConfiguration as e: # NoveList isn't configured. @@ -2120,9 +2207,14 @@ def recommendations(self, identifier_type, identifier, novelist_api=None, ) return feed_class.page( - _db=self._db, title=lane.DISPLAY_NAME, url=url, worklist=lane, - facets=facets, pagination=pagination, - annotator=annotator, search_engine=search_engine + _db=self._db, + title=lane.DISPLAY_NAME, + url=url, + worklist=lane, + facets=facets, + pagination=pagination, + annotator=annotator, + search_engine=search_engine, ) def report(self, identifier_type, identifier): @@ -2139,11 +2231,11 @@ def report(self, identifier_type, identifier): # Something went wrong. return pools - if flask.request.method == 'GET': + if flask.request.method == "GET": # Return a list of valid URIs to use as the type of a problem detail # document. data = "\n".join(Complaint.VALID_TYPES) - return Response(data, 200, {"Content-Type" : "text/uri-list"}) + return Response(data, 200, {"Content-Type": "text/uri-list"}) data = flask.request.data controller = ComplaintController() @@ -2161,8 +2253,7 @@ def series(self, series_name, languages, audiences, feed_class=AcquisitionFeed): languages, audiences = self._lane_details(languages, audiences) lane = SeriesLane( - library, series_name=series_name, languages=languages, - audiences=audiences + library, series_name=series_name, languages=languages, audiences=audiences ) facets = self.manager.load_facets_from_request( @@ -2179,9 +2270,14 @@ def series(self, series_name, languages, audiences, feed_class=AcquisitionFeed): url = annotator.feed_url(lane, facets=facets, pagination=pagination) return feed_class.page( - _db=self._db, title=lane.display_name, url=url, worklist=lane, - facets=facets, pagination=pagination, - annotator=annotator, search_engine=search_engine + _db=self._db, + title=lane.display_name, + url=url, + worklist=lane, + facets=facets, + pagination=pagination, + annotator=annotator, + search_engine=search_engine, ) @@ -2190,8 +2286,7 @@ class ProfileController(CirculationManagerController): @property def _controller(self): - """Instantiate a CoreProfileController that actually does the work. - """ + """Instantiate a CoreProfileController that actually does the work.""" # TODO: Probably better to use request_patron and check for # None here. patron = self.authenticated_patron_from_request() @@ -2201,7 +2296,7 @@ def _controller(self): def protocol(self): """Handle a UPMP request.""" controller = self._controller - if flask.request.method == 'GET': + if flask.request.method == "GET": result = controller.get() else: result = controller.put(flask.request.headers, flask.request.data) @@ -2211,7 +2306,6 @@ def protocol(self): class URNLookupController(CoreURNLookupController): - def __init__(self, manager): self.manager = manager super(URNLookupController, self).__init__(manager._db) @@ -2224,13 +2318,10 @@ def work_lookup(self, route_name): library = flask.request.library top_level_worklist = self.manager.top_level_lanes[library.id] annotator = CirculationManagerAnnotator(top_level_worklist) - return super(URNLookupController, self).work_lookup( - annotator, route_name - ) + return super(URNLookupController, self).work_lookup(annotator, route_name) class AnalyticsController(CirculationManagerController): - def track_event(self, identifier_type, identifier, event_type): # TODO: It usually doesn't matter, but there should be # a way to distinguish between different LicensePools for the @@ -2239,16 +2330,15 @@ def track_event(self, identifier_type, identifier, event_type): library = flask.request.library # Authentication on the AnalyticsController is optional, # so flask.request.patron may or may not be set. - patron = getattr(flask.request, 'patron', None) + patron = getattr(flask.request, "patron", None) neighborhood = None if patron: - neighborhood = getattr(patron, 'neighborhood', None) + neighborhood = getattr(patron, "neighborhood", None) pools = self.load_licensepools(library, identifier_type, identifier) if isinstance(pools, ProblemDetail): return pools self.manager.analytics.collect_event( - library, pools[0], event_type, utc_now(), - neighborhood=neighborhood + library, pools[0], event_type, utc_now(), neighborhood=neighborhood ) return Response({}, 200) else: @@ -2272,26 +2362,33 @@ def notify(self, loan_id): if collection.protocol != ODLAPI.NAME: return INVALID_LOAN_FOR_ODL_NOTIFICATION - api = self.manager.circulation_apis[library.id].api_for_license_pool(loan.license_pool) + api = self.manager.circulation_apis[library.id].api_for_license_pool( + loan.license_pool + ) api.update_loan(loan, json.loads(status_doc)) - return Response(_('Success'), 200) + return Response(_("Success"), 200) + class SharedCollectionController(CirculationManagerController): """Enable this circulation manager to share its collections with libraries on other circulation managers, for collection types that support it.""" + def info(self, collection_name): """Return an OPDS2 catalog-like document with a link to register.""" collection = get_one(self._db, Collection, name=collection_name) if not collection: return NO_SUCH_COLLECTION - register_url = self.url_for('shared_collection_register', - collection_name=collection_name) - register_link = dict(href=register_url, rel='register') + register_url = self.url_for( + "shared_collection_register", collection_name=collection_name + ) + register_link = dict(href=register_url, rel="register") content = json.dumps(dict(links=[register_link])) headers = dict() - headers["Content-Type"] = "application/opds+json;profile=https://librarysimplified.org/rel/profile/directory" + headers[ + "Content-Type" + ] = "application/opds+json;profile=https://librarysimplified.org/rel/profile/directory" return Response(content, 200, headers) def load_collection(self, collection_name): @@ -2317,9 +2414,9 @@ def register(self, collection_name): return Response(json.dumps(response), 200) def authenticated_client_from_request(self): - header = flask.request.headers.get('Authorization') - if header and 'bearer' in header.lower(): - shared_secret = base64.b64decode(header.split(' ')[1]) + header = flask.request.headers.get("Authorization") + if header and "bearer" in header.lower(): + shared_secret = base64.b64decode(header.split(" ")[1]) client = IntegrationClient.authenticate(self._db, shared_secret) if client: return client @@ -2336,9 +2433,7 @@ def loan_info(self, collection_name, loan_id): if not loan or loan.license_pool.collection != collection: return LOAN_NOT_FOUND - return SharedCollectionLoanAndHoldAnnotator.single_item_feed( - collection, loan - ) + return SharedCollectionLoanAndHoldAnnotator.single_item_feed(collection, loan) def borrow(self, collection_name, identifier_type, identifier, hold_id): collection = self.load_collection(collection_name) @@ -2348,19 +2443,18 @@ def borrow(self, collection_name, identifier_type, identifier, hold_id): if isinstance(client, ProblemDetail): return client if identifier_type and identifier: - pools = self._db.query(LicensePool).join( - LicensePool.identifier).filter( - Identifier.type==identifier_type - ).filter( - Identifier.identifier==identifier - ).filter( - LicensePool.collection_id==collection.id - ).all() + pools = ( + self._db.query(LicensePool) + .join(LicensePool.identifier) + .filter(Identifier.type == identifier_type) + .filter(Identifier.identifier == identifier) + .filter(LicensePool.collection_id == collection.id) + .all() + ) if not pools: return NO_LICENSES.detailed( - _("The item you're asking about (%s/%s) isn't in this collection.") % ( - identifier_type, identifier - ) + _("The item you're asking about (%s/%s) isn't in this collection.") + % (identifier_type, identifier) ) pool = pools[0] hold = None @@ -2404,7 +2498,9 @@ def revoke_loan(self, collection_name, loan_id): return COULD_NOT_MIRROR_TO_REMOTE.detailed(str(e)) return Response(_("Success"), 200) - def fulfill(self, collection_name, loan_id, mechanism_id, do_get=HTTP.get_with_timeout): + def fulfill( + self, collection_name, loan_id, mechanism_id, do_get=HTTP.get_with_timeout + ): collection = self.load_collection(collection_name) if isinstance(collection, ProblemDetail): return collection @@ -2417,9 +2513,7 @@ def fulfill(self, collection_name, loan_id, mechanism_id, do_get=HTTP.get_with_t mechanism = None if mechanism_id: - mechanism = self.load_licensepooldelivery( - loan.license_pool, mechanism_id - ) + mechanism = self.load_licensepooldelivery(loan.license_pool, mechanism_id) if isinstance(mechanism, ProblemDetail): return mechanism @@ -2433,7 +2527,9 @@ def fulfill(self, collection_name, loan_id, mechanism_id, do_get=HTTP.get_with_t ) try: - fulfillment = self.shared_collection.fulfill(collection, client, loan, mechanism) + fulfillment = self.shared_collection.fulfill( + collection, client, loan, mechanism + ) except AuthorizationFailedException as e: return INVALID_CREDENTIALS.detailed(str(e)) except CannotFulfill as e: @@ -2456,7 +2552,7 @@ def fulfill(self, collection_name, loan_id, mechanism_id, do_get=HTTP.get_with_t else: status_code = 200 if fulfillment.content_type: - headers['Content-Type'] = fulfillment.content_type + headers["Content-Type"] = fulfillment.content_type return Response(content, status_code, headers) @@ -2471,9 +2567,7 @@ def hold_info(self, collection_name, hold_id): if not hold or not hold.license_pool.collection == collection: return HOLD_NOT_FOUND - return SharedCollectionLoanAndHoldAnnotator.single_item_feed( - collection, hold - ) + return SharedCollectionLoanAndHoldAnnotator.single_item_feed(collection, hold) def revoke_hold(self, collection_name, hold_id): collection = self.load_collection(collection_name) @@ -2496,13 +2590,18 @@ def revoke_hold(self, collection_name, hold_id): return CANNOT_RELEASE_HOLD.detailed(str(e)) return Response(_("Success"), 200) + class StaticFileController(CirculationManagerController): def static_file(self, directory, filename): cache_timeout = ConfigurationSetting.sitewide( self._db, Configuration.STATIC_FILE_CACHE_TIME ).int_value - return flask.send_from_directory(directory, filename, cache_timeout=cache_timeout) + return flask.send_from_directory( + directory, filename, cache_timeout=cache_timeout + ) def image(self, filename): - directory = os.path.join(os.path.abspath(os.path.dirname(__file__)), "..", "resources", "images") + directory = os.path.join( + os.path.abspath(os.path.dirname(__file__)), "..", "resources", "images" + ) return self.static_file(directory, filename) diff --git a/api/coverage.py b/api/coverage.py index e940d2617e..1b40fa9e6f 100644 --- a/api/coverage.py +++ b/api/coverage.py @@ -5,11 +5,13 @@ so on. """ import logging -from lxml import etree from io import StringIO + +from lxml import etree + from core.coverage import ( - CoverageFailure, CollectionCoverageProvider, + CoverageFailure, WorkCoverageProvider, ) from core.model import ( @@ -23,9 +25,6 @@ LicensePool, WorkCoverageRecord, ) -from core.util.opds_writer import ( - OPDSFeed -) from core.opds_import import ( AccessNotAuthenticated, MetadataWranglerOPDSLookup, @@ -33,15 +32,15 @@ OPDSXMLParser, SimplifiedOPDSLookup, ) -from core.util.http import ( - RemoteIntegrationException, -) +from core.util.http import RemoteIntegrationException +from core.util.opds_writer import OPDSFeed class RegistrarImporter(OPDSImporter): """We are successful whenever the metadata wrangler puts an identifier into the catalog, even if no metadata is immediately available. """ + SUCCESS_STATUS_CODES = [200, 201, 202] @@ -50,6 +49,7 @@ class ReaperImporter(OPDSImporter): identifier has been removed, and also if the identifier wasn't in the catalog in the first place. """ + SUCCESS_STATUS_CODES = [200, 404] @@ -57,6 +57,7 @@ class OPDSImportCoverageProvider(CollectionCoverageProvider): """Provide coverage for identifiers by looking them up, in batches, using the Simplified lookup protocol. """ + DEFAULT_BATCH_SIZE = 25 OPDS_IMPORTER_CLASS = OPDSImporter @@ -70,8 +71,12 @@ def __init__(self, collection, lookup_client, **kwargs): def process_batch(self, batch): """Perform a Simplified lookup and import the resulting OPDS feed.""" - (imported_editions, pools, works, - error_messages_by_id) = self.lookup_and_import_batch(batch) + ( + imported_editions, + pools, + works, + error_messages_by_id, + ) = self.lookup_and_import_batch(batch) results = [] imported_identifiers = set() @@ -89,9 +94,7 @@ def process_batch(self, batch): self.finalize_license_pool(pool) else: msg = "OPDS import operation imported LicensePool, but no Edition." - results.append( - self.failure(identifier, msg, transient=True) - ) + results.append(self.failure(identifier, msg, transient=True)) # Anything left over is either a CoverageFailure, or an # Identifier that used to be a CoverageFailure, indicating @@ -121,8 +124,7 @@ def finalize_license_pool(self, pool): @property def api_method(self): - """The method to call to fetch an OPDS feed from the remote server. - """ + """The method to call to fetch an OPDS feed from the remote server.""" return self.lookup_client.lookup def lookup_and_import_batch(self, batch): @@ -158,9 +160,10 @@ def import_feed_response(self, response, id_mapping): """ self.lookup_client.check_content_type(response) importer = self.OPDS_IMPORTER_CLASS( - self._db, self.collection, + self._db, + self.collection, identifier_mapping=id_mapping, - data_source_name=self.data_source.name + data_source_name=self.data_source.name, ) return importer.import_from_feed(response.text) @@ -183,9 +186,7 @@ def queue_import_results(self, editions, pools, works, messages_by_id): def finalize_license_pool(self, license_pool): self.finalized.append(license_pool) - super(MockOPDSImportCoverageProvider, self).finalize_license_pool( - license_pool - ) + super(MockOPDSImportCoverageProvider, self).finalize_license_pool(license_pool) def lookup_and_import_batch(self, batch): self.batches.append(batch) diff --git a/api/custom_index.py b/api/custom_index.py index 97c2602311..5902cffd52 100644 --- a/api/custom_index.py +++ b/api/custom_index.py @@ -8,22 +8,17 @@ from flask import Response from flask_babel import lazy_gettext as _ - from sqlalchemy.orm.session import Session -from .config import CannotLoadConfiguration from core.app_server import cdn_url_for -from core.model import ( - get_one, -) from core.lane import Lane -from core.model import ( - ConfigurationSetting, - ExternalIntegration, -) +from core.model import ConfigurationSetting, ExternalIntegration, get_one from core.util.datetime_helpers import utc_now from core.util.opds_writer import OPDSFeed +from .config import CannotLoadConfiguration + + class CustomIndexView(object): """A custom view that replaces the default OPDS view for a library. @@ -36,6 +31,7 @@ class CustomIndexView(object): should not store any objects obtained from the database without disconnecting them from their session. """ + BY_PROTOCOL = {} GOAL = "custom_index" @@ -98,11 +94,13 @@ class COPPAGate(CustomIndexView): REQUIREMENT_NOT_MET_LANE = "requirement_not_met_lane" SETTINGS = [ - { "key": REQUIREMENT_MET_LANE, - "label": _("ID of lane for patrons who are 13 or older"), + { + "key": REQUIREMENT_MET_LANE, + "label": _("ID of lane for patrons who are 13 or older"), }, - { "key": REQUIREMENT_NOT_MET_LANE, - "label": _("ID of lane for patrons who are under 13"), + { + "key": REQUIREMENT_NOT_MET_LANE, + "label": _("ID of lane for patrons who are under 13"), }, ] @@ -130,8 +128,8 @@ def _load_lane(self, library, lane_id): raise CannotLoadConfiguration("No lane with ID: %s" % lane_id) if lane.library != library: raise CannotLoadConfiguration( - "Lane %d is for the wrong library (%s, I need %s)" % - (lane.id, lane.library.name, library.name) + "Lane %d is for the wrong library (%s, I need %s)" + % (lane.id, lane.library.name, library.name) ) return lane @@ -139,40 +137,34 @@ def __call__(self, library, annotator, url_for=None): """Render an OPDS navigation feed that lets the patron choose a root lane on their own, without providing any credentials. """ - if not hasattr(self, 'navigation_feed'): - self.navigation_feed = self._navigation_feed( - library, annotator, url_for - ) - headers = { "Content-Type": OPDSFeed.NAVIGATION_FEED_TYPE } + if not hasattr(self, "navigation_feed"): + self.navigation_feed = self._navigation_feed(library, annotator, url_for) + headers = {"Content-Type": OPDSFeed.NAVIGATION_FEED_TYPE} return Response(str(self.navigation_feed), 200, headers) def _navigation_feed(self, library, annotator, url_for=None): """Generate an OPDS feed for navigating the COPPA age gate.""" url_for = url_for or cdn_url_for - base_url = url_for('index', library_short_name=library.short_name) + base_url = url_for("index", library_short_name=library.short_name) # An entry for grown-ups. feed = OPDSFeed(title=library.name, url=base_url) opds = feed.feed yes_url = url_for( - 'acquisition_groups', + "acquisition_groups", library_short_name=library.short_name, - lane_identifier=self.yes_lane_id - ) - opds.append( - self.navigation_entry(yes_url, self.YES_TITLE, self.YES_CONTENT) + lane_identifier=self.yes_lane_id, ) + opds.append(self.navigation_entry(yes_url, self.YES_TITLE, self.YES_CONTENT)) # An entry for children. no_url = url_for( - 'acquisition_groups', + "acquisition_groups", library_short_name=library.short_name, - lane_identifier=self.no_lane_id - ) - opds.append( - self.navigation_entry(no_url, self.NO_TITLE, self.NO_CONTENT) + lane_identifier=self.no_lane_id, ) + opds.append(self.navigation_entry(no_url, self.NO_TITLE, self.NO_CONTENT)) # The gate tag is the thing that the SimplyE client actually uses. opds.append(self.gate_tag(self.URI, yes_url, no_url)) @@ -197,11 +189,10 @@ def navigation_entry(cls, href, title, content): E.id(href), E.title(str(title)), content_tag, - E.updated(OPDSFeed._strftime(now)) + E.updated(OPDSFeed._strftime(now)), ) OPDSFeed.add_link_to_entry( - entry, href=href, rel="subsection", - type=OPDSFeed.ACQUISITION_FEED_TYPE + entry, href=href, rel="subsection", type=OPDSFeed.ACQUISITION_FEED_TYPE ) return entry @@ -211,9 +202,10 @@ def gate_tag(cls, restriction, met_url, not_met_url): the client is faced with. """ tag = OPDSFeed.SIMPLIFIED.gate() - tag.attrib['restriction-met'] = met_url - tag.attrib['restriction-not-met'] = not_met_url - tag.attrib['restriction'] = restriction + tag.attrib["restriction-met"] = met_url + tag.attrib["restriction-not-met"] = not_met_url + tag.attrib["restriction"] = restriction return tag + CustomIndexView.register(COPPAGate) diff --git a/api/custom_patron_catalog.py b/api/custom_patron_catalog.py index ff270b00f4..fc851dd32f 100644 --- a/api/custom_patron_catalog.py +++ b/api/custom_patron_catalog.py @@ -4,20 +4,15 @@ from flask import Response from flask_babel import lazy_gettext as _ - from sqlalchemy.orm.session import Session -from .config import CannotLoadConfiguration -from core.model import ( - get_one, -) from core.lane import Lane -from core.model import ( - ConfigurationSetting, - ExternalIntegration, -) +from core.model import ConfigurationSetting, ExternalIntegration, get_one from core.util.opds_writer import OPDSFeed +from .config import CannotLoadConfiguration + + class CustomPatronCatalog(object): """An annotator for a library's authentication document. @@ -30,6 +25,7 @@ class CustomPatronCatalog(object): any objects obtained from the database without disconnecting them from their session. """ + BY_PROTOCOL = {} GOAL = "custom_patron_catalog" @@ -93,8 +89,8 @@ def _load_lane(cls, library, lane_id): raise CannotLoadConfiguration("No lane with ID: %s" % lane_id) if lane.library != library: raise CannotLoadConfiguration( - "Lane %d is for the wrong library (%s, I need %s)" % - (lane.id, lane.library.name, library.name) + "Lane %d is for the wrong library (%s, I need %s)" + % (lane.id, lane.library.name, library.name) ) return lane @@ -108,21 +104,23 @@ def replace_link(cls, doc, rel, **kwargs): :param kwargs: Add a new link with these attributes. :return: A modified authentication document. """ - links = [x for x in doc['links'] if x['rel'] != rel] + links = [x for x in doc["links"] if x["rel"] != rel] links.append(dict(rel=rel, **kwargs)) - doc['links'] = links + doc["links"] = links return doc class CustomRootLane(CustomPatronCatalog): """Send library patrons to a lane other than the root lane.""" + PROTOCOL = "Custom Root Lane" LANE = "lane" SETTINGS = [ - { "key": LANE, - "label": _("Send patrons to the lane with this ID."), + { + "key": LANE, + "label": _("Send patrons to the lane with this ID."), }, ] @@ -140,13 +138,17 @@ def __init__(self, library, integration): def annotate_authentication_document(self, library, doc, url_for): """Replace the 'start' link with a link to the configured Lane.""" root_url = url_for( - "acquisition_groups", library_short_name=library.short_name, - lane_identifier=self.lane_id, _external=True + "acquisition_groups", + library_short_name=library.short_name, + lane_identifier=self.lane_id, + _external=True, ) self.replace_link( - doc, 'start', href=root_url, type=OPDSFeed.ACQUISITION_FEED_TYPE + doc, "start", href=root_url, type=OPDSFeed.ACQUISITION_FEED_TYPE ) return doc + + CustomPatronCatalog.register(CustomRootLane) @@ -155,18 +157,24 @@ class COPPAGate(CustomPatronCatalog): PROTOCOL = "COPPA Age Gate" AUTHENTICATION_TYPE = "http://librarysimplified.org/terms/authentication/gate/coppa" - AUTHENTICATION_YES_REL = "http://librarysimplified.org/terms/rel/authentication/restriction-met" - AUTHENTICATION_NO_REL = "http://librarysimplified.org/terms/rel/authentication/restriction-not-met" + AUTHENTICATION_YES_REL = ( + "http://librarysimplified.org/terms/rel/authentication/restriction-met" + ) + AUTHENTICATION_NO_REL = ( + "http://librarysimplified.org/terms/rel/authentication/restriction-not-met" + ) REQUIREMENT_MET_LANE = "requirement_met_lane" REQUIREMENT_NOT_MET_LANE = "requirement_not_met_lane" SETTINGS = [ - { "key": REQUIREMENT_MET_LANE, - "label": _("ID of lane for patrons who are 13 or older"), + { + "key": REQUIREMENT_MET_LANE, + "label": _("ID of lane for patrons who are 13 or older"), }, - { "key": REQUIREMENT_NOT_MET_LANE, - "label": _("ID of lane for patrons who are under 13"), + { + "key": REQUIREMENT_NOT_MET_LANE, + "label": _("ID of lane for patrons who are under 13"), }, ] @@ -191,40 +199,40 @@ def annotate_authentication_document(self, library, doc, url_for): # A lane for grown-ups. yes_url = url_for( - 'acquisition_groups', library_short_name=library.short_name, - lane_identifier=self.yes_lane_id, _external=True + "acquisition_groups", + library_short_name=library.short_name, + lane_identifier=self.yes_lane_id, + _external=True, ) # A lane for children. no_url = url_for( - 'acquisition_groups', library_short_name=library.short_name, - lane_identifier=self.no_lane_id, _external=True + "acquisition_groups", + library_short_name=library.short_name, + lane_identifier=self.no_lane_id, + _external=True, ) # Replace the 'start' link with the childrens link. Any client # that doesn't understand the extensions will be safe from # grown-up content. feed = OPDSFeed.ACQUISITION_FEED_TYPE - self.replace_link(doc, 'start', href=no_url, type=feed) + self.replace_link(doc, "start", href=no_url, type=feed) # Add a custom authentication technique that # explains the COPPA gate. links = [ - dict(rel=self.AUTHENTICATION_YES_REL, href=yes_url, - type=feed), - dict(rel=self.AUTHENTICATION_NO_REL, href=no_url, - type=feed), + dict(rel=self.AUTHENTICATION_YES_REL, href=yes_url, type=feed), + dict(rel=self.AUTHENTICATION_NO_REL, href=no_url, type=feed), ] - authentication = dict( - type=self.AUTHENTICATION_TYPE, - links=links - ) + authentication = dict(type=self.AUTHENTICATION_TYPE, links=links) # It's an academic question whether this is replacing the existing # auth mechanisms or just adding another one, but for the moment # let's go with "adding another one". - doc.setdefault('authentication', []).append(authentication) + doc.setdefault("authentication", []).append(authentication) return doc -CustomPatronCatalog.register(COPPAGate) + +CustomPatronCatalog.register(COPPAGate) diff --git a/api/enki.py b/api/enki.py index 89bebb700d..3a8d47dbb8 100644 --- a/api/enki.py +++ b/api/enki.py @@ -1,33 +1,22 @@ -import time import datetime -import os import json import logging -from flask_babel import lazy_gettext as _ - -from .config import ( - CannotLoadConfiguration, -) - -from .circulation import ( - LoanInfo, - FulfillmentInfo, - BaseCirculationAPI -) - -from .circulation_exceptions import * +import os +import time -from .selftest import ( - HasSelfTests, - SelfTestResult, -) +from flask_babel import lazy_gettext as _ -from core.util.http import ( - HTTP, - RemoteIntegrationException, - RequestTimedOut, +from core.analytics import Analytics +from core.metadata_layer import ( + CirculationData, + ContributorData, + FormatData, + IdentifierData, + LinkData, + Metadata, + ReplacementPolicy, + SubjectData, ) - from core.model import ( CirculationEvent, Classification, @@ -44,45 +33,40 @@ Session, Subject, ) - -from core.metadata_layer import ( - CirculationData, - ContributorData, - FormatData, - IdentifierData, - LinkData, - Metadata, - ReplacementPolicy, - SubjectData, -) - from core.monitor import ( - Monitor, - IdentifierSweepMonitor, CollectionMonitor, + IdentifierSweepMonitor, + Monitor, TimelineMonitor, ) - -from core.analytics import Analytics from core.testing import DatabaseTest -from core.util.datetime_helpers import ( - from_timestamp, - strptime_utc, - utc_now, -) +from core.util.datetime_helpers import from_timestamp, strptime_utc, utc_now +from core.util.http import HTTP, RemoteIntegrationException, RequestTimedOut + +from .circulation import BaseCirculationAPI, FulfillmentInfo, LoanInfo +from .circulation_exceptions import * +from .config import CannotLoadConfiguration +from .selftest import HasSelfTests, SelfTestResult + class EnkiAPI(BaseCirculationAPI, HasSelfTests): PRODUCTION_BASE_URL = "https://enkilibrary.org/API/" - ENKI_LIBRARY_ID_KEY = 'enki_library_id' + ENKI_LIBRARY_ID_KEY = "enki_library_id" DESCRIPTION = _("Integrate an Enki collection.") SETTINGS = [ - { "key": ExternalIntegration.URL, "label": _("URL"), "default": PRODUCTION_BASE_URL, "required": True, "format": "url" }, + { + "key": ExternalIntegration.URL, + "label": _("URL"), + "default": PRODUCTION_BASE_URL, + "required": True, + "format": "url", + }, ] + BaseCirculationAPI.SETTINGS LIBRARY_SETTINGS = [ - { "key": ENKI_LIBRARY_ID_KEY, "label": _("Library ID"), "required": True }, + {"key": ENKI_LIBRARY_ID_KEY, "label": _("Library ID"), "required": True}, ] list_endpoint = "ListAPI" @@ -101,15 +85,15 @@ class EnkiAPI(BaseCirculationAPI, HasSelfTests): no_drm = DeliveryMechanism.NO_DRM delivery_mechanism_to_internal_format = { - (epub, no_drm): 'free', - (epub, adobe_drm): 'acs', + (epub, no_drm): "free", + (epub, adobe_drm): "acs", } # Enki API serves all responses with a 200 error code and a # text/html Content-Type. However, there's a string that # reliably shows up in error pages which is unlikely to show up # in normal API operation. - ERROR_INDICATOR = '

    Oops, an error occurred

    ' + ERROR_INDICATOR = "

    Oops, an error occurred

    " SET_DELIVERY_MECHANISM_AT = BaseCirculationAPI.FULFILL_STEP SERVICE_NAME = "Enki" @@ -119,8 +103,8 @@ def __init__(self, _db, collection): self._db = _db if collection.protocol != self.ENKI: raise ValueError( - "Collection protocol is %s, but passed into EnkiAPI!" % - collection.protocol + "Collection protocol is %s, but passed into EnkiAPI!" + % collection.protocol ) self.collection_id = collection.id @@ -133,8 +117,7 @@ def enki_library_id(self, library): """Find the Enki library ID for the given library.""" _db = Session.object_session(library) return ConfigurationSetting.for_library_and_externalintegration( - _db, self.ENKI_LIBRARY_ID_KEY, library, - self.external_integration(_db) + _db, self.ENKI_LIBRARY_ID_KEY, library, self.external_integration(_db) ).value @property @@ -146,15 +129,13 @@ def _run_self_tests(self, _db): now = utc_now() def count_loans_and_holds(): - """Count recent circulation events that affected loans or holds. - """ + """Count recent circulation events that affected loans or holds.""" one_hour_ago = now - datetime.timedelta(hours=1) count = len(list(self.recent_activity(one_hour_ago, now))) return "%s circulation events in the last hour" % count yield self.run_test( - "Counting recent circulation changes.", - count_loans_and_holds + "Counting recent circulation changes.", count_loans_and_holds ) def count_title_changes(): @@ -176,40 +157,53 @@ def count_title_changes(): yield result continue library, patron, pin = result - task = "Checking patron activity, using test patron for library %s" % library.name + task = ( + "Checking patron activity, using test patron for library %s" + % library.name + ) + def count_loans_and_holds(patron, pin): activity = list(self.patron_activity(patron, pin)) return "Total loans and holds: %s" % len(activity) - yield self.run_test( - task, count_loans_and_holds, patron, pin - ) - def request(self, url, method='get', extra_headers={}, data=None, - params=None, retry_on_timeout=True, **kwargs): + yield self.run_test(task, count_loans_and_holds, patron, pin) + + def request( + self, + url, + method="get", + extra_headers={}, + data=None, + params=None, + retry_on_timeout=True, + **kwargs + ): """Make an HTTP request to the Enki API.""" headers = dict(extra_headers) response = None try: response = self._request( - method, url, headers=headers, data=data, - params=params, - **kwargs + method, url, headers=headers, data=data, params=params, **kwargs ) except RequestTimedOut as e: if not retry_on_timeout: raise e - self.log.info( - "Request to %s timed out once. Trying a second time.", url - ) + self.log.info("Request to %s timed out once. Trying a second time.", url) return self.request( - url, method, extra_headers, - data, params, retry_on_timeout=False, + url, + method, + extra_headers, + data, + params, + retry_on_timeout=False, **kwargs ) # Look for the error indicator and raise # RemoteIntegrationException if it appears. - if response.content and self.ERROR_INDICATOR in response.content.decode("utf-8"): + if response.content and self.ERROR_INDICATOR in response.content.decode( + "utf-8" + ): raise RemoteIntegrationException(url, "An unknown error occured") return response @@ -219,8 +213,13 @@ def _request(self, method, url, headers, data, params, **kwargs): MockEnkiAPI overrides this method. """ return HTTP.request_with_timeout( - method, url, headers=headers, data=data, - params=params, timeout=90, disallowed_response_codes=None, + method, + url, + headers=headers, + data=data, + params=params, + timeout=90, + disallowed_response_codes=None, **kwargs ) @@ -246,18 +245,16 @@ def recent_activity(self, start, end): end = int((end - epoch).total_seconds()) url = self.base_url + self.item_endpoint - args = dict( - method='getRecentActivityTime', - stime=str(start), - etime=str(end) - ) + args = dict(method="getRecentActivityTime", stime=str(start), etime=str(end)) response = self.request(url, params=args) data = json.loads(response.content) parser = BibliographicParser() - for element in data['result']['recentactivity']: - identifier = IdentifierData(Identifier.ENKI_ID, element['id']) + for element in data["result"]["recentactivity"]: + identifier = IdentifierData(Identifier.ENKI_ID, element["id"]) yield parser.extract_circulation( - identifier, element['availability'], None # The recent activity API does not include format info + identifier, + element["availability"], + None, # The recent activity API does not include format info ) def updated_titles(self, since): @@ -273,12 +270,12 @@ def updated_titles(self, since): minutes = self._minutes_since(since) url = self.base_url + self.list_endpoint args = dict( - method='getUpdateTitles', + method="getUpdateTitles", minutes=minutes, - id='secontent', - lib='0', # This is a stand-in value -- it doesn't matter - # which library we ask about since they all have - # the same collection. + id="secontent", + lib="0", # This is a stand-in value -- it doesn't matter + # which library we ask about since they all have + # the same collection. ) response = self.request(url, params=args) for metadata in BibliographicParser().process_all(response.content): @@ -297,9 +294,9 @@ def get_item(self, enki_id): method="getItem", recordid=enki_id, size="large", - lib='0', # This is a stand-in value -- it doesn't matter - # which library we ask about since they all have - # the same collection. + lib="0", # This is a stand-in value -- it doesn't matter + # which library we ask about since they all have + # the same collection. ) response = self.request(url, params=args) try: @@ -308,7 +305,7 @@ def get_item(self, enki_id): # This is most likely a 'not found' error. return None - book = data.get('result', {}) + book = data.get("result", {}) if book: return BibliographicParser().extract_bibliographic(book) return None @@ -322,13 +319,15 @@ def get_all_titles(self, strt=0, qty=10): :yield: A sequence of Metadata objects, each with a CirculationData attached. """ - self.log.debug ("requesting : "+ str(qty) + " books starting at econtentRecord" + str(strt)) + self.log.debug( + "requesting : " + str(qty) + " books starting at econtentRecord" + str(strt) + ) url = str(self.base_url) + str(self.list_endpoint) args = dict() - args['method'] = "getAllTitles" - args['id'] = "secontent" - args['strt'] = strt - args['qty'] = qty + args["method"] = "getAllTitles" + args["id"] = "secontent" + args["strt"] = strt + args["qty"] = qty response = self.request(url, params=args) for metadata in BibliographicParser().process_all(response.content): yield metadata @@ -339,8 +338,7 @@ def _epoch_to_struct(cls, epoch_string): # struct that the Circulation Manager can make use of. time_format = "%Y-%m-%dT%H:%M:%S" return strptime_utc( - time.strftime(time_format, time.gmtime(float(epoch_string))), - time_format + time.strftime(time_format, time.gmtime(float(epoch_string))), time_format ) def checkout(self, patron, pin, licensepool, internal_format): @@ -348,22 +346,23 @@ def checkout(self, patron, pin, licensepool, internal_format): enki_id = identifier.identifier enki_library_id = self.enki_library_id(patron.library) response = self.loan_request( - patron.authorization_identifier, pin, enki_id, - enki_library_id + patron.authorization_identifier, pin, enki_id, enki_library_id ) if response.status_code != 200: raise CannotLoan(response.status_code) - result = json.loads(response.content)['result'] - if not result['success']: - message = result['message'] + result = json.loads(response.content)["result"] + if not result["success"]: + message = result["message"] if "There are no available copies" in message: self.log.error("There are no copies of book %s available." % enki_id) raise NoAvailableCopies() elif "Login unsuccessful" in message: - self.log.error("User validation against Enki server with %s / %s was unsuccessful." - % (patron.authorization_identifier, pin)) + self.log.error( + "User validation against Enki server with %s / %s was unsuccessful." + % (patron.authorization_identifier, pin) + ) raise AuthorizationFailedException() - due_date = result['checkedOutItems'][0]['duedate'] + due_date = result["checkedOutItems"][0]["duedate"] expires = self._epoch_to_struct(due_date) # Create the loan info. @@ -379,16 +378,16 @@ def checkout(self, patron, pin, licensepool, internal_format): return loan def loan_request(self, barcode, pin, book_id, enki_library_id): - self.log.debug ("Sending checkout request for %s" % book_id) + self.log.debug("Sending checkout request for %s" % book_id) url = str(self.base_url) + str(self.user_endpoint) args = dict() - args['method'] = "getSELink" - args['username'] = barcode - args['password'] = pin - args['lib'] = enki_library_id - args['id'] = book_id + args["method"] = "getSELink" + args["username"] = barcode + args["password"] = pin + args["lib"] = enki_library_id + args["id"] = book_id - response = self.request(url, method='get', params=args) + response = self.request(url, method="get", params=args) return response def fulfill(self, patron, pin, licensepool, internal_format, **kwargs): @@ -406,15 +405,17 @@ def fulfill(self, patron, pin, licensepool, internal_format, **kwargs): ) if response.status_code != 200: raise CannotFulfill(response.status_code) - result = json.loads(response.content)['result'] - if not result['success']: - message = result['message'] + result = json.loads(response.content)["result"] + if not result["success"]: + message = result["message"] if "There are no available copies" in message: self.log.error("There are no copies of book %s available." % book_id) raise NoAvailableCopies() elif "Login unsuccessful" in message: - self.log.error("User validation against Enki server with %s / %s was unsuccessful." - % (patron.authorization_identifier, pin)) + self.log.error( + "User validation against Enki server with %s / %s was unsuccessful." + % (patron.authorization_identifier, pin) + ) raise AuthorizationFailedException() url, item_type, expires = self.parse_fulfill_result(result) @@ -437,14 +438,14 @@ def fulfill(self, patron, pin, licensepool, internal_format, **kwargs): content_link=url, content_type=drm_type, content=None, - content_expires=expires + content_expires=expires, ) def parse_fulfill_result(self, result): - links = result['checkedOutItems'][0]['links'][0] - url = links['url'] - item_type = links['item_type'] - due_date = result['checkedOutItems'][0]['duedate'] + links = result["checkedOutItems"][0]["links"][0] + url = links["url"] + item_type = links["item_type"] + due_date = result["checkedOutItems"][0]["duedate"] expires = self._epoch_to_struct(due_date) return (url, item_type, expires) @@ -455,39 +456,38 @@ def patron_activity(self, patron, pin): ) if response.status_code != 200: raise PatronNotFoundOnRemote(response.status_code) - result = json.loads(response.content).get('result', {}) - if not result.get('success'): - message = result.get('message', '') + result = json.loads(response.content).get("result", {}) + if not result.get("success"): + message = result.get("message", "") if "Login unsuccessful" in message: raise AuthorizationFailedException() else: self.log.error( - "Unexpected error in patron_activity: %r", - response.content + "Unexpected error in patron_activity: %r", response.content ) raise CirculationException(response.content) - for loan in result['checkedOutItems']: + for loan in result["checkedOutItems"]: yield self.parse_patron_loans(loan) - for type, holds in list(result['holds'].items()): + for type, holds in list(result["holds"].items()): for hold in holds: yield self.parse_patron_holds(hold) def patron_request(self, patron, pin, enki_library_id): - self.log.debug ("Querying Enki for information on patron %s" % patron) + self.log.debug("Querying Enki for information on patron %s" % patron) url = str(self.base_url) + str(self.user_endpoint) args = dict() - args['method'] = "getSEPatronData" - args['username'] = patron - args['password'] = pin - args['lib'] = enki_library_id + args["method"] = "getSEPatronData" + args["username"] = patron + args["password"] = pin + args["lib"] = enki_library_id - return self.request(url, method='get', params=args) + return self.request(url, method="get", params=args) def parse_patron_loans(self, checkout_data): # We should receive a list of JSON objects - enki_id = checkout_data['id'] - start_date = self._epoch_to_struct(checkout_data['checkoutdate']) - end_date = self._epoch_to_struct(checkout_data['duedate']) + enki_id = checkout_data["id"] + start_date = self._epoch_to_struct(checkout_data["checkoutdate"]) + end_date = self._epoch_to_struct(checkout_data["duedate"]) return LoanInfo( self.collection, DataSource.ENKI, @@ -495,7 +495,7 @@ def parse_patron_loans(self, checkout_data): enki_id, start_date=start_date, end_date=end_date, - fulfillment_info=None + fulfillment_info=None, ) def parse_patron_holds(self, hold_data): @@ -518,26 +518,22 @@ def __init__(self, _db, collection=None, *args, **kwargs): collection, ignore = Collection.by_name_and_protocol( _db, name="Test Enki Collection", protocol=EnkiAPI.ENKI ) - collection.protocol=EnkiAPI.ENKI + collection.protocol = EnkiAPI.ENKI if collection not in library.collections: library.collections.append(collection) # Set the "Enki library ID" variable between the default library # and this Enki collection. ConfigurationSetting.for_library_and_externalintegration( - _db, self.ENKI_LIBRARY_ID_KEY, library, - collection.external_integration - ).value = 'c' + _db, self.ENKI_LIBRARY_ID_KEY, library, collection.external_integration + ).value = "c" - super(MockEnkiAPI, self).__init__( - _db, collection, *args, **kwargs - ) + super(MockEnkiAPI, self).__init__(_db, collection, *args, **kwargs) def queue_response(self, status_code, headers={}, content=None): from core.testing import MockRequestsResponse - self.responses.insert( - 0, MockRequestsResponse(status_code, headers, content) - ) + + self.responses.insert(0, MockRequestsResponse(status_code, headers, content)) def _request(self, method, url, headers, data, params, **kwargs): """Override EnkiAPI._request to pull responses from a @@ -546,8 +542,10 @@ def _request(self, method, url, headers, data, params, **kwargs): self.requests.append([method, url, headers, data, params, kwargs]) response = self.responses.pop() return HTTP._process_response( - url, response, kwargs.get('allowed_response_codes'), - kwargs.get('disallowed_response_codes'), + url, + response, + kwargs.get("allowed_response_codes"), + kwargs.get("disallowed_response_codes"), ) @@ -562,7 +560,7 @@ class BibliographicParser(object): # the codes we use internally. LANGUAGE_CODES = { "English": "eng", - "French" : "fre", + "French": "fre", "Spanish": "spa", } @@ -596,11 +594,14 @@ def extract_bibliographic(self, element): contributors.append(ContributorData(sort_name=sort_name)) links = [] - description = element.get('description') + description = element.get("description") if description: links.append( - LinkData(rel=Hyperlink.DESCRIPTION, content=description, - media_type="text/html") + LinkData( + rel=Hyperlink.DESCRIPTION, + content=description, + media_type="text/html", + ) ) # NOTE: When this method is called by, e.g. updated_titles(), @@ -612,16 +613,14 @@ def extract_bibliographic(self, element): full_image = None thumbnail_image = None for key, rel in ( - ('cover', Hyperlink.IMAGE), - ('small_image', Hyperlink.THUMBNAIL_IMAGE), - ('large_image', Hyperlink.IMAGE) + ("cover", Hyperlink.IMAGE), + ("small_image", Hyperlink.THUMBNAIL_IMAGE), + ("large_image", Hyperlink.IMAGE), ): url = element.get(key) if not url: continue - link = LinkData( - rel=rel, href=url, media_type=Representation.PNG_MEDIA_TYPE - ) + link = LinkData(rel=rel, href=url, media_type=Representation.PNG_MEDIA_TYPE) if rel == Hyperlink.THUMBNAIL_IMAGE: # Don't add a thumbnail to the list of links -- wait # until the end and then make it a thumbnail of the @@ -646,14 +645,15 @@ def extract_bibliographic(self, element): # presented in a form that can be parsed as BISAC. subjects = [] seen_topics = set() - for key in ('subject', 'topic', 'genre'): + for key in ("subject", "topic", "genre"): for topic in element.get(key, []): if not topic or topic in seen_topics: continue subjects.append( SubjectData( - Subject.TAG, topic, - weight=Classification.TRUSTED_DISTRIBUTOR_WEIGHT + Subject.TAG, + topic, + weight=Classification.TRUSTED_DISTRIBUTOR_WEIGHT, ) ) seen_topics.add(topic) @@ -674,7 +674,9 @@ def extract_bibliographic(self, element): subjects=subjects, ) circulationdata = self.extract_circulation( - primary_identifier, element.get('availability', {}), element.get('formattype', None) + primary_identifier, + element.get("availability", {}), + element.get("formattype", None), ) metadata.circulation = circulationdata return metadata @@ -685,26 +687,21 @@ def extract_circulation(self, primary_identifier, availability, formattype): """ if not availability: return None - licenses_owned=availability.get("totalCopies", 0) - licenses_available=availability.get("availableCopies", 0) - hold=availability.get("onHold", 0) + licenses_owned = availability.get("totalCopies", 0) + licenses_available = availability.get("availableCopies", 0) + hold = availability.get("onHold", 0) drm_type = EnkiAPI.no_drm - if availability.get('accessType') == 'acs': + if availability.get("accessType") == "acs": drm_type = EnkiAPI.adobe_drm formats = [] content_type = None - if formattype == 'PDF': + if formattype == "PDF": content_type = Representation.PDF_MEDIA_TYPE - elif formattype == 'EPUB': - content_type=Representation.EPUB_MEDIA_TYPE + elif formattype == "EPUB": + content_type = Representation.EPUB_MEDIA_TYPE if content_type != None: - formats.append( - FormatData( - content_type, - drm_scheme=drm_type - ) - ) + formats.append(FormatData(content_type, drm_scheme=drm_type)) else: self.log.error("Unrecognized formattype: %s", formattype) @@ -712,10 +709,10 @@ def extract_circulation(self, primary_identifier, availability, formattype): data_source=DataSource.ENKI, primary_identifier=primary_identifier, formats=formats, - licenses_owned = int(licenses_owned), - licenses_available = int(licenses_available), - licenses_reserved = 0, - patrons_in_hold_queue = int(hold) + licenses_owned=int(licenses_owned), + licenses_available=int(licenses_available), + licenses_reserved=0, + patrons_in_hold_queue=int(hold), ) return circulationdata @@ -724,6 +721,7 @@ class EnkiImport(CollectionMonitor, TimelineMonitor): """Make sure our local collection is up-to-date with the remote Enki collection. """ + SERVICE_NAME = "Enki Circulation Monitor" INTERVAL_SECONDS = 500 PROTOCOL = EnkiAPI.ENKI_EXTERNAL @@ -767,9 +765,8 @@ def catch_up_from(self, start, cutoff, progress): new_titles, circulation_updates = self.incremental_import(start) progress.achievements = ( - "New or modified titles: %d. Titles with circulation changes: %d." % ( - new_titles, circulation_updates - ) + "New or modified titles: %d. Titles with circulation changes: %d." + % (new_titles, circulation_updates) ) def full_import(self): @@ -779,9 +776,7 @@ def full_import(self): total_items = 0 while True: items_this_page = 0 - for bibliographic in self.api.get_all_titles( - strt=id_start, qty=batch_size - ): + for bibliographic in self.api.get_all_titles(strt=id_start, qty=batch_size): self.process_book(bibliographic) items_this_page += 1 total_items += 1 @@ -832,9 +827,7 @@ def _update_circulation(self, start, end): circulation_changes = 0 for circulation in self.api.recent_activity(start, end): circulation_changes += 1 - license_pool, is_new = circulation.license_pool( - self._db, self.collection - ) + license_pool, is_new = circulation.license_pool(self._db, self.collection) if not license_pool.work: # Either this is the first time we've heard about this # title, or we never made a Work for this @@ -870,13 +863,13 @@ def process_book(self, bibliographic): formats=True, ) bibliographic.apply(edition, self.collection, replace=policy) - license_pool, ignore = availability.license_pool( - self._db, self.collection - ) + license_pool, ignore = availability.license_pool(self._db, self.collection) if new_edition: for library in self.collection.libraries: - self.analytics.collect_event(library, license_pool, CirculationEvent.DISTRIBUTOR_TITLE_ADD, now) + self.analytics.collect_event( + library, license_pool, CirculationEvent.DISTRIBUTOR_TITLE_ADD, now + ) return edition, license_pool @@ -885,7 +878,7 @@ class EnkiCollectionReaper(IdentifierSweepMonitor): """Check for books that are in the local collection but have left the Enki collection.""" SERVICE_NAME = "Enki Collection Reaper" - INTERVAL_SECONDS = 3600*4 + INTERVAL_SECONDS = 3600 * 4 PROTOCOL = "Enki" def __init__(self, _db, collection, api_class=EnkiAPI): @@ -898,9 +891,7 @@ def __init__(self, _db, collection, api_class=EnkiAPI): self.api = api def process_item(self, identifier): - self.log.debug( - "Seeing if %s needs reaping", identifier.identifier - ) + self.log.debug("Seeing if %s needs reaping", identifier.identifier) metadata = self.api.get_item(identifier.identifier) if metadata: # This title is still in the collection. Do nothing. @@ -915,31 +906,25 @@ def process_item(self, identifier): return if pool.presentation_edition: - self.log.warn( - "Removing %r from circulation", - pool.presentation_edition - ) + self.log.warn("Removing %r from circulation", pool.presentation_edition) else: self.log.warn( - "Removing unknown title %s from circulation.", - identifier.identifier + "Removing unknown title %s from circulation.", identifier.identifier ) now = utc_now() circulationdata = CirculationData( data_source=DataSource.ENKI, - primary_identifier= IdentifierData( - identifier.type, identifier.identifier - ), - licenses_owned = 0, - licenses_available = 0, - patrons_in_hold_queue = 0, - last_checked = now + primary_identifier=IdentifierData(identifier.type, identifier.identifier), + licenses_owned=0, + licenses_available=0, + patrons_in_hold_queue=0, + last_checked=now, ) circulationdata.apply( self._db, self.collection, - replace=ReplacementPolicy.from_license_source(self._db) + replace=ReplacementPolicy.from_license_source(self._db), ) return circulationdata diff --git a/api/feedbooks.py b/api/feedbooks.py index c6fe646e9e..44d11cdad9 100644 --- a/api/feedbooks.py +++ b/api/feedbooks.py @@ -1,33 +1,30 @@ import datetime -import feedparser +import os from io import BytesIO from zipfile import ZipFile -from lxml import etree -import os + +import feedparser from flask_babel import lazy_gettext as _ +from lxml import etree -from core.opds import OPDSFeed -from core.opds_import import ( - OPDSImporter, - OPDSImportMonitor, - OPDSXMLParser, -) from core.model import ( Collection, DataSource, ExternalIntegration, Hyperlink, - Resource, Representation, + Resource, RightsStatus, ) +from core.opds import OPDSFeed +from core.opds_import import OPDSImporter, OPDSImportMonitor, OPDSXMLParser from core.util.epub import EpubAccessor class FeedbooksOPDSImporter(OPDSImporter): - REALLY_IMPORT_KEY = 'really_import' - REPLACEMENT_CSS_KEY = 'replacement_css' + REALLY_IMPORT_KEY = "really_import" + REPLACEMENT_CSS_KEY = "replacement_css" NAME = ExternalIntegration.FEEDBOOKS DESCRIPTION = _("Import open-access books from FeedBooks.") @@ -36,37 +33,50 @@ class FeedbooksOPDSImporter(OPDSImporter): "key": REALLY_IMPORT_KEY, "type": "select", "label": _("Really?"), - "description": _("Most libraries are better off importing free Feedbooks titles via an OPDS Import integration from NYPL's open-access content server or DPLA's Open Bookshelf. This setting makes sure you didn't create this collection by accident and really want to import directly from Feedbooks."), + "description": _( + "Most libraries are better off importing free Feedbooks titles via an OPDS Import integration from NYPL's open-access content server or DPLA's Open Bookshelf. This setting makes sure you didn't create this collection by accident and really want to import directly from Feedbooks." + ), "options": [ - { "key": "false", "label": _("Don't actually import directly from Feedbooks.") }, - { "key": "true", "label": _("I know what I'm doing; import directly from Feedbooks.") }, - ], - "default": "false" + { + "key": "false", + "label": _("Don't actually import directly from Feedbooks."), + }, + { + "key": "true", + "label": _( + "I know what I'm doing; import directly from Feedbooks." + ), + }, + ], + "default": "false", }, { "key": Collection.EXTERNAL_ACCOUNT_ID_KEY, "label": _("Import books in this language"), - "description": _("Feedbooks offers separate feeds for different languages. Each one can be made into a separate collection."), + "description": _( + "Feedbooks offers separate feeds for different languages. Each one can be made into a separate collection." + ), "type": "select", "options": [ - { "key": "en", "label": _("English") }, - { "key": "es", "label": _("Spanish") }, - { "key": "fr", "label": _("French") }, - { "key": "it", "label": _("Italian") }, - { "key": "de", "label": _("German") }, + {"key": "en", "label": _("English")}, + {"key": "es", "label": _("Spanish")}, + {"key": "fr", "label": _("French")}, + {"key": "it", "label": _("Italian")}, + {"key": "de", "label": _("German")}, ], "default": "en", }, { - "key" : REPLACEMENT_CSS_KEY, + "key": REPLACEMENT_CSS_KEY, "label": _("Replacement stylesheet"), - "description": _("If you are mirroring the Feedbooks titles, you may replace the Feedbooks stylesheet with an alternate stylesheet in the mirrored copies. The default value is an accessibility-focused stylesheet produced by the DAISY consortium. If you mirror Feedbooks titles but leave this empty, the Feedbooks titles will be mirrored as-is."), + "description": _( + "If you are mirroring the Feedbooks titles, you may replace the Feedbooks stylesheet with an alternate stylesheet in the mirrored copies. The default value is an accessibility-focused stylesheet produced by the DAISY consortium. If you mirror Feedbooks titles but leave this empty, the Feedbooks titles will be mirrored as-is." + ), "default": "http://www.daisy.org/z3986/2005/dtbook.2005.basic.css", }, - ] - BASE_OPDS_URL = 'http://www.feedbooks.com/books/recent.atom?lang=%(language)s' + BASE_OPDS_URL = "http://www.feedbooks.com/books/recent.atom?lang=%(language)s" THIRTY_DAYS = datetime.timedelta(days=30) @@ -75,12 +85,14 @@ def __init__(self, _db, collection, *args, **kwargs): new_css_url = integration.setting(self.REPLACEMENT_CSS_KEY).value if new_css_url: # We may need to modify incoming content to replace CSS. - kwargs['content_modifier'] = self.replace_css - kwargs['data_source_name'] = DataSource.FEEDBOOKS + kwargs["content_modifier"] = self.replace_css + kwargs["data_source_name"] = DataSource.FEEDBOOKS really_import = integration.setting(self.REALLY_IMPORT_KEY).bool_value if not really_import: - raise Exception("Refusing to instantiate a Feedbooks importer because it's configured to not actually do an import.") + raise Exception( + "Refusing to instantiate a Feedbooks importer because it's configured to not actually do an import." + ) self.language = collection.external_account_id @@ -91,16 +103,16 @@ def __init__(self, _db, collection, *args, **kwargs): status_code, headers, content = self.http_get(new_css_url, {}) if status_code != 200: raise IOError( - "Replacement stylesheet URL returned %r response code." % status_code + "Replacement stylesheet URL returned %r response code." + % status_code ) - content_type = headers.get('content-type', '') - if not content_type.startswith('text/css'): + content_type = headers.get("content-type", "") + if not content_type.startswith("text/css"): raise IOError( "Replacement stylesheet is %r, not a CSS document." % content_type ) self.new_css = content - def extract_feed_data(self, feed, feed_url=None): metadata, failures = super(FeedbooksOPDSImporter, self).extract_feed_data( feed, feed_url @@ -124,19 +136,21 @@ def rights_uri_from_entry_tag(cls, entry): """Determine the URI that best encapsulates the rights status of the downloads associated with this book. """ - rights = OPDSXMLParser._xpath1(entry, 'atom:rights') + rights = OPDSXMLParser._xpath1(entry, "atom:rights") if rights is not None: rights = rights.text - source = OPDSXMLParser._xpath1(entry, 'dcterms:source') + source = OPDSXMLParser._xpath1(entry, "dcterms:source") if source is not None: source = source.text - publication_year = OPDSXMLParser._xpath1(entry, 'dcterms:issued') + publication_year = OPDSXMLParser._xpath1(entry, "dcterms:issued") if publication_year is not None: publication_year = publication_year.text return RehostingPolicy.rights_uri(rights, source, publication_year) @classmethod - def _detail_for_elementtree_entry(cls, parser, entry_tag, feed_url=None, do_get=None): + def _detail_for_elementtree_entry( + cls, parser, entry_tag, feed_url=None, do_get=None + ): """Determine a more accurate value for this entry's default rights URI. @@ -149,13 +163,14 @@ def _detail_for_elementtree_entry(cls, parser, entry_tag, feed_url=None, do_get= parser, entry_tag, feed_url, do_get=do_get ) rights_uri = cls.rights_uri_from_entry_tag(entry_tag) - circulation = detail.setdefault('circulation', {}) - circulation['default_rights_uri'] =rights_uri + circulation = detail.setdefault("circulation", {}) + circulation["default_rights_uri"] = rights_uri return detail @classmethod - def make_link_data(cls, rel, href=None, media_type=None, rights_uri=None, - content=None): + def make_link_data( + cls, rel, href=None, media_type=None, rights_uri=None, content=None + ): """Turn basic link information into a LinkData object. FeedBooks puts open-access content behind generic @@ -163,9 +178,8 @@ def make_link_data(cls, rel, href=None, media_type=None, rights_uri=None, links and (at the request of FeedBooks) ignore the other formats. """ - if rel==Hyperlink.GENERIC_OPDS_ACQUISITION: - if (media_type - and media_type.startswith(Representation.EPUB_MEDIA_TYPE)): + if rel == Hyperlink.GENERIC_OPDS_ACQUISITION: + if media_type and media_type.startswith(Representation.EPUB_MEDIA_TYPE): # Treat this generic acquisition link as an # open-access link. rel = Hyperlink.OPEN_ACCESS_DOWNLOAD @@ -190,8 +204,11 @@ def improve_description(self, id, metadata): existing_descriptions = [] everything_except_descriptions = [] for x in metadata.links: - if (x.rel == Hyperlink.ALTERNATE and x.href - and x.media_type == OPDSFeed.ENTRY_TYPE): + if ( + x.rel == Hyperlink.ALTERNATE + and x.href + and x.media_type == OPDSFeed.ENTRY_TYPE + ): alternate_links.append(x) if x.rel == Hyperlink.DESCRIPTION: existing_descriptions.append((x.media_type, x.content)) @@ -205,8 +222,10 @@ def improve_description(self, id, metadata): # Fetch the alternate entry. representation, is_new = Representation.get( - self._db, alternate_link.href, max_age=self.THIRTY_DAYS, - do_get=self.http_get + self._db, + alternate_link.href, + max_age=self.THIRTY_DAYS, + do_get=self.http_get, ) if representation.status_code != 200: @@ -215,10 +234,10 @@ def improve_description(self, id, metadata): # Parse the alternate entry with feedparser and run it through # data_detail_for_feedparser_entry(). parsed = feedparser.parse(representation.content) - if len(parsed['entries']) != 1: + if len(parsed["entries"]) != 1: # This is supposed to be a single entry, and it's not. continue - [entry] = parsed['entries'] + [entry] = parsed["entries"] data_source = self.data_source detail_id, new_detail, failure = self.data_detail_for_feedparser_entry( entry, data_source @@ -237,16 +256,15 @@ def improve_description(self, id, metadata): # Find any descriptions present in the alternate view which # are not present in the original. new_descriptions = [ - x for x in new_detail['links'] + x + for x in new_detail["links"] if x.rel == Hyperlink.DESCRIPTION and (x.media_type, x.content) not in existing_descriptions ] if new_descriptions: # Replace old descriptions with new descriptions. - metadata.links = ( - everything_except_descriptions + new_descriptions - ) + metadata.links = everything_except_descriptions + new_descriptions break return metadata @@ -255,7 +273,10 @@ def replace_css(self, representation): """This function will replace the content of every CSS file listed in an epub's manifest with the value in self.new_css. The rest of the file is not changed. """ - if not (representation.media_type == Representation.EPUB_MEDIA_TYPE and representation.content): + if not ( + representation.media_type == Representation.EPUB_MEDIA_TYPE + and representation.content + ): return if not self.new_css: @@ -263,10 +284,12 @@ def replace_css(self, representation): return new_zip_content = BytesIO() - with EpubAccessor.open_epub(representation.url, content=representation.content) as (zip_file, package_path): + with EpubAccessor.open_epub( + representation.url, content=representation.content + ) as (zip_file, package_path): try: manifest_element = EpubAccessor.get_element_from_package( - zip_file, package_path, 'manifest' + zip_file, package_path, "manifest" ) except ValueError as e: # Invalid EPUB @@ -276,8 +299,10 @@ def replace_css(self, representation): css_paths = [] for child in manifest_element: if child.tag == ("{%s}item" % EpubAccessor.IDPF_NAMESPACE): - if child.get('media-type') == "text/css": - href = package_path.replace(os.path.basename(package_path), child.get("href")) + if child.get("media-type") == "text/css": + href = package_path.replace( + os.path.basename(package_path), child.get("href") + ) css_paths.append(href) with ZipFile(new_zip_content, "w") as new_zip: @@ -316,40 +341,43 @@ class RehostingPolicy(object): # These are the licenses that need to be preserved. RIGHTS_DICT = { - "Attribution Share Alike (cc by-sa)" : RightsStatus.CC_BY_SA, - "Attribution Non-Commercial No Derivatives (cc by-nc-nd)" : RightsStatus.CC_BY_NC_ND, - "Attribution Non-Commercial Share Alike (cc by-nc-sa)" : RightsStatus.CC_BY_NC_SA, + "Attribution Share Alike (cc by-sa)": RightsStatus.CC_BY_SA, + "Attribution Non-Commercial No Derivatives (cc by-nc-nd)": RightsStatus.CC_BY_NC_ND, + "Attribution Non-Commercial Share Alike (cc by-nc-sa)": RightsStatus.CC_BY_NC_SA, } # Feedbooks rights statuses indicating books that can be rehosted # in the US. - CAN_REHOST_IN_US = set([ - "This work was published before 1923 and is in the public domain in the USA only.", - "This work is available for countries where copyright is Life+70 and in the USA.", - 'This work is available for countries where copyright is Life+50 or in the USA (published before 1923).', - "Attribution (cc by)", - "Attribution Non-Commercial (cc by-nc)", - - "Attribution Share Alike (cc by-sa)", - "Attribution Non-Commercial No Derivatives (cc by-nc-nd)", - "Attribution Non-Commercial Share Alike (cc by-nc-sa)", - ]) + CAN_REHOST_IN_US = set( + [ + "This work was published before 1923 and is in the public domain in the USA only.", + "This work is available for countries where copyright is Life+70 and in the USA.", + "This work is available for countries where copyright is Life+50 or in the USA (published before 1923).", + "Attribution (cc by)", + "Attribution Non-Commercial (cc by-nc)", + "Attribution Share Alike (cc by-sa)", + "Attribution Non-Commercial No Derivatives (cc by-nc-nd)", + "Attribution Non-Commercial Share Alike (cc by-nc-sa)", + ] + ) RIGHTS_UNKNOWN = "Please read the legal notice included in this e-book and/or check the copyright status in your country." # These websites are hosted in the US and specialize in # open-access content. We will accept all FeedBooks titles taken # from these sites, even post-1923 titles. - US_SITES = set([ - "archive.org", - "craphound.com", - "en.wikipedia.org", - "en.wikisource.org", - "futurismic.com", - "gutenberg.org", - "project gutenberg", - "shakespeare.mit.edu", - ]) + US_SITES = set( + [ + "archive.org", + "craphound.com", + "en.wikipedia.org", + "en.wikisource.org", + "futurismic.com", + "gutenberg.org", + "project gutenberg", + "shakespeare.mit.edu", + ] + ) @classmethod def rights_uri(cls, rights, source, publication_year): @@ -412,7 +440,7 @@ def can_rehost_us(cls, rights, source, publication_year): # to rehost it. return True - if source in ('wikisource', 'gutenberg'): + if source in ("wikisource", "gutenberg"): # Presumably en.wikisource and Project Gutenberg US. We # special case these to avoid confusing the US versions of # these sites with other countries'. @@ -420,7 +448,7 @@ def can_rehost_us(cls, rights, source, publication_year): # And we special-case this one to avoid confusing Australian # Project Gutenberg with US Project Gutenberg. - if ('gutenberg.net' in source and not 'gutenberg.net.au' in source): + if "gutenberg.net" in source and not "gutenberg.net.au" in source: return True # Unless one of the above conditions is met, we must assume @@ -435,6 +463,7 @@ def can_rehost_us(cls, rights, source, publication_year): # Life+70) and it's not a pre-1923 book. return False + class FeedbooksImportMonitor(OPDSImportMonitor): """The same as OPDSImportMonitor, but uses FeedbooksOPDSImporter instead. @@ -451,5 +480,5 @@ def opds_url(self, collection): This is the base URL plus the language setting. """ - language = collection.external_account_id or 'en' + language = collection.external_account_id or "en" return FeedbooksOPDSImporter.BASE_OPDS_URL % dict(language=language) diff --git a/api/firstbook.py b/api/firstbook.py index d3ae75eaec..50955c7949 100644 --- a/api/firstbook.py +++ b/api/firstbook.py @@ -1,29 +1,25 @@ -from flask_babel import lazy_gettext as _ -import requests import logging -from .authenticator import ( - BasicAuthenticationProvider, - PatronData, -) -from .config import ( - Configuration, - CannotLoadConfiguration, -) -from .circulation_exceptions import RemoteInitiatedServerError import urllib.parse -from core.model import ( - get_one_or_create, - ExternalIntegration, - Patron, -) + +import requests +from flask_babel import lazy_gettext as _ + +from core.model import ExternalIntegration, Patron, get_one_or_create + +from .authenticator import BasicAuthenticationProvider, PatronData +from .circulation_exceptions import RemoteInitiatedServerError +from .config import CannotLoadConfiguration, Configuration + class FirstBookAuthenticationAPI(BasicAuthenticationProvider): - NAME = 'First Book (deprecated)' + NAME = "First Book (deprecated)" - DESCRIPTION = _(""" + DESCRIPTION = _( + """ An authentication service for Open eBooks that authenticates - using access codes and PINs. (This is the old version.)""") + using access codes and PINs. (This is the old version.)""" + ) DISPLAY_NAME = NAME DEFAULT_IDENTIFIER_LABEL = _("Access Code") @@ -31,35 +27,40 @@ class FirstBookAuthenticationAPI(BasicAuthenticationProvider): # If FirstBook sends this message it means they accepted the # patron's credentials. - SUCCESS_MESSAGE = 'Valid Code Pin Pair' + SUCCESS_MESSAGE = "Valid Code Pin Pair" # Server-side validation happens before the identifier # is converted to uppercase, which means lowercase characters # are valid. - DEFAULT_IDENTIFIER_REGULAR_EXPRESSION = '^[A-Za-z0-9@]+$' - DEFAULT_PASSWORD_REGULAR_EXPRESSION = '^[0-9]+$' + DEFAULT_IDENTIFIER_REGULAR_EXPRESSION = "^[A-Za-z0-9@]+$" + DEFAULT_PASSWORD_REGULAR_EXPRESSION = "^[0-9]+$" SETTINGS = [ - { "key": ExternalIntegration.URL, "format": "url", "label": _("URL"), "required": True }, - { "key": ExternalIntegration.PASSWORD, "label": _("Key"), "required": True }, + { + "key": ExternalIntegration.URL, + "format": "url", + "label": _("URL"), + "required": True, + }, + {"key": ExternalIntegration.PASSWORD, "label": _("Key"), "required": True}, ] + BasicAuthenticationProvider.SETTINGS log = logging.getLogger("First Book authentication API") def __init__(self, library_id, integration, analytics=None, root=None): - super(FirstBookAuthenticationAPI, self).__init__(library_id, integration, analytics) + super(FirstBookAuthenticationAPI, self).__init__( + library_id, integration, analytics + ) if not root: url = integration.url key = integration.password if not (url and key): - raise CannotLoadConfiguration( - "First Book server not configured." - ) - if '?' in url: - url += '&' + raise CannotLoadConfiguration("First Book server not configured.") + if "?" in url: + url += "&" else: - url += '?' - root = url + 'key=' + key + url += "?" + root = url + "key=" + key self.root = root # Begin implementation of BasicAuthenticationProvider abstract @@ -84,20 +85,18 @@ def remote_authenticate(self, username, password): # End implementation of BasicAuthenticationProvider abstract methods. def remote_pin_test(self, barcode, pin): - url = self.root + "&accesscode=%s&pin=%s" % tuple(map( - urllib.parse.quote, (barcode, pin) - )) + url = self.root + "&accesscode=%s&pin=%s" % tuple( + map(urllib.parse.quote, (barcode, pin)) + ) try: response = self.request(url) except requests.exceptions.ConnectionError as e: - raise RemoteInitiatedServerError( - str(e), - self.NAME - ) + raise RemoteInitiatedServerError(str(e), self.NAME) content = response.content.decode("utf8") if response.status_code != 200: msg = "Got unexpected response code %d. Content: %s" % ( - response.status_code, content + response.status_code, + content, ) raise RemoteInitiatedServerError(msg, self.NAME) if self.SUCCESS_MESSAGE in content: @@ -113,7 +112,6 @@ def request(self, url): class MockFirstBookResponse(object): - def __init__(self, status_code, content): self.status_code = status_code # Guarantee that the response content is always a bytestring, @@ -122,13 +120,20 @@ def __init__(self, status_code, content): content = content.encode("utf8") self.content = content + class MockFirstBookAuthenticationAPI(FirstBookAuthenticationAPI): SUCCESS = '"Valid Code Pin Pair"' FAILURE = '{"code":404,"message":"Access Code Pin Pair not found"}' - def __init__(self, library, integration, valid={}, bad_connection=False, - failure_status_code=None): + def __init__( + self, + library, + integration, + valid={}, + bad_connection=False, + failure_status_code=None, + ): super(MockFirstBookAuthenticationAPI, self).__init__( library, integration, root="http://example.com/" ) @@ -148,9 +153,9 @@ def request(self, url): self.failure_status_code, "Error %s" % self.failure_status_code ) qa = urllib.parse.parse_qs(url) - if 'accesscode' in qa and 'pin' in qa: - [code] = qa['accesscode'] - [pin] = qa['pin'] + if "accesscode" in qa and "pin" in qa: + [code] = qa["accesscode"] + [pin] = qa["pin"] if code in self.valid and self.valid[code] == pin: return MockFirstBookResponse(200, self.SUCCESS) else: diff --git a/api/firstbook2.py b/api/firstbook2.py index b1982fded5..46abc36864 100644 --- a/api/firstbook2.py +++ b/api/firstbook2.py @@ -1,71 +1,67 @@ -from flask_babel import lazy_gettext as _ -import jwt -from jwt.algorithms import HMACAlgorithm -import requests import logging import time +import urllib.parse + +import jwt +import requests +from flask_babel import lazy_gettext as _ +from jwt.algorithms import HMACAlgorithm + +from core.model import ExternalIntegration, Patron, get_one_or_create -from .authenticator import ( - BasicAuthenticationProvider, - PatronData, -) -from .config import ( - Configuration, - CannotLoadConfiguration, -) +from .authenticator import BasicAuthenticationProvider, PatronData from .circulation_exceptions import RemoteInitiatedServerError -import urllib.parse -from core.model import ( - get_one_or_create, - ExternalIntegration, - Patron, -) +from .config import CannotLoadConfiguration, Configuration + class FirstBookAuthenticationAPI(BasicAuthenticationProvider): - NAME = 'First Book' + NAME = "First Book" - DESCRIPTION = _(""" + DESCRIPTION = _( + """ An authentication service for Open eBooks that authenticates - using access codes and PINs. (This is the new version.)""") + using access codes and PINs. (This is the new version.)""" + ) DISPLAY_NAME = NAME DEFAULT_IDENTIFIER_LABEL = _("Access Code") LOGIN_BUTTON_IMAGE = "FirstBookLoginButton280.png" # The algorithm used to sign JWTs. - ALGORITHM = 'HS256' + ALGORITHM = "HS256" # If FirstBook sends this message it means they accepted the # patron's credentials. - SUCCESS_MESSAGE = 'Valid Code Pin Pair' + SUCCESS_MESSAGE = "Valid Code Pin Pair" # Server-side validation happens before the identifier # is converted to uppercase, which means lowercase characters # are valid. - DEFAULT_IDENTIFIER_REGULAR_EXPRESSION = '^[A-Za-z0-9@]+$' - DEFAULT_PASSWORD_REGULAR_EXPRESSION = '^[0-9]+$' + DEFAULT_IDENTIFIER_REGULAR_EXPRESSION = "^[A-Za-z0-9@]+$" + DEFAULT_PASSWORD_REGULAR_EXPRESSION = "^[0-9]+$" SETTINGS = [ { - "key": ExternalIntegration.URL, "format": "url", "label": _("URL"), + "key": ExternalIntegration.URL, + "format": "url", + "label": _("URL"), "default": "https://ebooksprod.firstbook.org/api/", - "required": True + "required": True, }, - { "key": ExternalIntegration.PASSWORD, "label": _("Key"), "required": True }, + {"key": ExternalIntegration.PASSWORD, "label": _("Key"), "required": True}, ] + BasicAuthenticationProvider.SETTINGS log = logging.getLogger("First Book JWT authentication API") - def __init__(self, library_id, integration, analytics=None, root=None, - secret=None): - super(FirstBookAuthenticationAPI, self).__init__(library_id, integration, analytics) + def __init__(self, library_id, integration, analytics=None, root=None, secret=None): + super(FirstBookAuthenticationAPI, self).__init__( + library_id, integration, analytics + ) root = root or integration.url secret = secret or integration.password if not (root and secret): - raise CannotLoadConfiguration( - "First Book server not configured." - ) + raise CannotLoadConfiguration("First Book server not configured.") self.root = root self.secret = secret @@ -96,14 +92,12 @@ def remote_pin_test(self, barcode, pin): try: response = self.request(url) except requests.exceptions.ConnectionError as e: - raise RemoteInitiatedServerError( - str(e), - self.NAME - ) + raise RemoteInitiatedServerError(str(e), self.NAME) content = response.content.decode("utf8") if response.status_code != 200: msg = "Got unexpected response code %d. Content: %s" % ( - response.status_code, content + response.status_code, + content, ) raise RemoteInitiatedServerError(msg, self.NAME) if self.SUCCESS_MESSAGE in content: @@ -120,7 +114,9 @@ def jwt(self, barcode, pin): pin=pin, iat=now, ) - return jwt.encode(payload, self.secret, algorithm=self.ALGORITHM).decode("utf-8") + return jwt.encode(payload, self.secret, algorithm=self.ALGORITHM).decode( + "utf-8" + ) def request(self, url): """Make an HTTP request. @@ -131,7 +127,6 @@ def request(self, url): class MockFirstBookResponse(object): - def __init__(self, status_code, content): self.status_code = status_code # Guarantee that the response content is always a bytestring, @@ -140,16 +135,22 @@ def __init__(self, status_code, content): content = content.encode("utf8") self.content = content + class MockFirstBookAuthenticationAPI(FirstBookAuthenticationAPI): SUCCESS = '"Valid Code Pin Pair"' FAILURE = '{"code":404,"message":"Access Code Pin Pair not found"}' - def __init__(self, library, integration, valid={}, bad_connection=False, - failure_status_code=None): + def __init__( + self, + library, + integration, + valid={}, + bad_connection=False, + failure_status_code=None, + ): super(MockFirstBookAuthenticationAPI, self).__init__( - library, integration, root="http://example.com/", - secret="secret" + library, integration, root="http://example.com/", secret="secret" ) self.identifier_re = None self.password_re = None @@ -188,10 +189,9 @@ def _decode(self, token): payload = jwt.decode(token, self.secret, algorithms=self.ALGORITHM) # The 'iat' field in the payload must be a recent timestamp. - assert (time.time()-int(payload['iat'])) < 2 - - return payload['barcode'], payload['pin'] + assert (time.time() - int(payload["iat"])) < 2 + return payload["barcode"], payload["pin"] # Specify which of the classes defined in this module is the diff --git a/api/google_analytics_provider.py b/api/google_analytics_provider.py index 5568f1a99b..420743e34b 100644 --- a/api/google_analytics_provider.py +++ b/api/google_analytics_provider.py @@ -1,73 +1,88 @@ -from .config import CannotLoadConfiguration -import uuid +import re import unicodedata import urllib.parse -import re +import uuid + from flask_babel import lazy_gettext as _ + +from core.model import ConfigurationSetting, ExternalIntegration, Session, get_one from core.util.http import HTTP -from core.model import ( - ConfigurationSetting, - ExternalIntegration, - Session, - get_one, -) + +from .config import CannotLoadConfiguration + class GoogleAnalyticsProvider(object): NAME = _("Google Analytics") DESCRIPTION = _("How to Configure a Google Analytics Integration") - INSTRUCTIONS = _("

    In order to track usage statistics, you can configure the Circulation Manager " + - "to connect to Google Analytics.

    " + - "

    Create a Google Analytics account, " + - "or sign into your existing one.

    " + - "

    To capture data from the Library Simplified Circulation Manager in your Google Analytics account, " + - "you must set up a property in Google Analytics for Library Simplified. In your Google Analytics " + - "account, on the administration page for the property, go to Custom Definitions > Custom Dimensions, " + - "and add the following dimensions, in this order:

      " + - "
    1. time
    2. " + - "
    3. identifier
    4. " + - "
    5. identifier_type
    6. " + - "
    7. title
    8. " + - "
    9. author
    10. " + - "
    11. fiction
    12. " + - "
    13. audience
    14. " + - "
    15. target_age
    16. " + - "
    17. publisher
    18. " + - "
    19. language
    20. " + - "
    21. genre
    22. " + - "
    23. open_access
    24. " + - "
    25. distributor
    26. " + - "
    27. medium
    28. " + - "
    29. library
    30. " + - "

    " + - "

    Each dimension should have the scope set to 'Hit' and the 'Active' box checked.

    " + - "

    Then go to Tracking Info and get the tracking id for the property. Select your " + - "library from the dropdown below, and enter the tracking id into the form.

    ") + INSTRUCTIONS = _( + "

    In order to track usage statistics, you can configure the Circulation Manager " + + "to connect to Google Analytics.

    " + + "

    Create a Google Analytics account, " + + "or sign into your existing one.

    " + + "

    To capture data from the Library Simplified Circulation Manager in your Google Analytics account, " + + "you must set up a property in Google Analytics for Library Simplified. In your Google Analytics " + + "account, on the administration page for the property, go to Custom Definitions > Custom Dimensions, " + + "and add the following dimensions, in this order:

      " + + "
    1. time
    2. " + + "
    3. identifier
    4. " + + "
    5. identifier_type
    6. " + + "
    7. title
    8. " + + "
    9. author
    10. " + + "
    11. fiction
    12. " + + "
    13. audience
    14. " + + "
    15. target_age
    16. " + + "
    17. publisher
    18. " + + "
    19. language
    20. " + + "
    21. genre
    22. " + + "
    23. open_access
    24. " + + "
    25. distributor
    26. " + + "
    27. medium
    28. " + + "
    29. library
    30. " + + "

    " + + "

    Each dimension should have the scope set to 'Hit' and the 'Active' box checked.

    " + + "

    Then go to Tracking Info and get the tracking id for the property. Select your " + + "library from the dropdown below, and enter the tracking id into the form.

    " + ) TRACKING_ID = "tracking_id" DEFAULT_URL = "http://www.google-analytics.com/collect" SETTINGS = [ - { "key": ExternalIntegration.URL, "label": _("URL"), "default": DEFAULT_URL, "required": True, "format": "url" }, + { + "key": ExternalIntegration.URL, + "label": _("URL"), + "default": DEFAULT_URL, + "required": True, + "format": "url", + }, ] LIBRARY_SETTINGS = [ - { "key": TRACKING_ID, "label": _("Tracking ID"), "required": True }, + {"key": TRACKING_ID, "label": _("Tracking ID"), "required": True}, ] def __init__(self, integration, library=None): _db = Session.object_session(integration) if not library: - raise CannotLoadConfiguration("Google Analytics can't be configured without a library.") - url_setting = ConfigurationSetting.for_externalintegration(ExternalIntegration.URL, integration) + raise CannotLoadConfiguration( + "Google Analytics can't be configured without a library." + ) + url_setting = ConfigurationSetting.for_externalintegration( + ExternalIntegration.URL, integration + ) self.url = url_setting.value or self.DEFAULT_URL self.tracking_id = ConfigurationSetting.for_library_and_externalintegration( - _db, self.TRACKING_ID, library, integration, + _db, + self.TRACKING_ID, + library, + integration, ).value if not self.tracking_id: - raise CannotLoadConfiguration("Missing tracking id for library %s" % library.short_name) - + raise CannotLoadConfiguration( + "Missing tracking id for library %s" % library.short_name + ) def collect_event(self, library, license_pool, event_type, time, **kwargs): @@ -77,50 +92,57 @@ def collect_event(self, library, license_pool, event_type, time, **kwargs): client_id = uuid.uuid4() fields = { - 'v': 1, - 'tid': self.tracking_id, - 'cid': client_id, - 'aip': 1, # anonymize IP - 'ds': "Circulation Manager", - 't': 'event', - 'ec': 'circulation', - 'ea': event_type, - 'cd1': time, + "v": 1, + "tid": self.tracking_id, + "cid": client_id, + "aip": 1, # anonymize IP + "ds": "Circulation Manager", + "t": "event", + "ec": "circulation", + "ea": event_type, + "cd1": time, } if license_pool: - fields.update({ - 'cd2': license_pool.identifier.identifier, - 'cd3': license_pool.identifier.type - }) + fields.update( + { + "cd2": license_pool.identifier.identifier, + "cd3": license_pool.identifier.type, + } + ) work = license_pool.work edition = license_pool.presentation_edition if work and edition: - fields.update({ - 'cd4': edition.title, - 'cd5': edition.author, - 'cd6': "fiction" if work.fiction else "nonfiction", - 'cd7': work.audience, - 'cd8': work.target_age_string, - 'cd9': edition.publisher, - 'cd10': edition.language, - 'cd11': work.top_genre(), - 'cd12': "true" if license_pool.open_access else "false", - }) + fields.update( + { + "cd4": edition.title, + "cd5": edition.author, + "cd6": "fiction" if work.fiction else "nonfiction", + "cd7": work.audience, + "cd8": work.target_age_string, + "cd9": edition.publisher, + "cd10": edition.language, + "cd11": work.top_genre(), + "cd12": "true" if license_pool.open_access else "false", + } + ) # Backwards compatibility requires that new dimensions be # added to the end of the list. For the sake of # consistency, this code that sets values for those new # dimensions runs after the original implementation. - fields.update({'cd13' : license_pool.data_source.name}) + fields.update({"cd13": license_pool.data_source.name}) if work and edition: - fields.update({'cd14' : edition.medium}) + fields.update({"cd14": edition.medium}) if library: - fields.update({'cd15' : library.short_name}) + fields.update({"cd15": library.short_name}) # urlencode doesn't like unicode strings so we convert them to utf8 - fields = {k: unicodedata.normalize("NFKD", str(v)).encode("utf8") for k, v in list(fields.items())} + fields = { + k: unicodedata.normalize("NFKD", str(v)).encode("utf8") + for k, v in list(fields.items()) + } params = re.sub(r"=None(&?)", r"=\1", urllib.parse.urlencode(fields)) self.post(self.url, params) diff --git a/api/kansas_patron.py b/api/kansas_patron.py index df35784192..e1638125c5 100644 --- a/api/kansas_patron.py +++ b/api/kansas_patron.py @@ -1,22 +1,24 @@ -from flask_babel import lazy_gettext as _ import logging -from .authenticator import ( - BasicAuthenticationProvider, - PatronData, -) -from .config import CannotLoadConfiguration -from core.model import ExternalIntegration + +from flask_babel import lazy_gettext as _ from lxml import etree + +from core.model import ExternalIntegration from core.util.http import HTTP +from .authenticator import BasicAuthenticationProvider, PatronData +from .config import CannotLoadConfiguration + class KansasAuthenticationAPI(BasicAuthenticationProvider): - NAME = 'Kansas' + NAME = "Kansas" - DESCRIPTION = _(""" + DESCRIPTION = _( + """ An authentication service for the Kansas State Library. - """) + """ + ) DISPLAY_NAME = NAME @@ -26,20 +28,20 @@ class KansasAuthenticationAPI(BasicAuthenticationProvider): "format": "url", "label": _("URL"), "default": "https://ks-kansaslibrary3m.civicplus.com/api/UserDetails", - "required": True + "required": True, }, ] + BasicAuthenticationProvider.SETTINGS log = logging.getLogger("Kansas authentication API") def __init__(self, library_id, integration, analytics=None, base_url=None): - super(KansasAuthenticationAPI, self).__init__(library_id, integration, analytics) + super(KansasAuthenticationAPI, self).__init__( + library_id, integration, analytics + ) if base_url is None: base_url = integration.url if not base_url: - raise CannotLoadConfiguration( - "Kansas server url not configured." - ) + raise CannotLoadConfiguration("Kansas server url not configured.") self.base_url = base_url # Begin implementation of BasicAuthenticationProvider abstract @@ -51,7 +53,9 @@ def remote_authenticate(self, username, password): # Post request to the server response = self.post_request(authorization_request) # Parse response from server - authorized, patron_name, library_identifier = self.parse_authorize_response(response.content) + authorized, patron_name, library_identifier = self.parse_authorize_response( + response.content + ) if not authorized: return False # Kansas auth gives very little data about the patron. Only name and a library identifier. @@ -60,7 +64,7 @@ def remote_authenticate(self, username, password): authorization_identifier=username, personal_name=patron_name, library_identifier=library_identifier, - complete=True + complete=True, ) # End implementation of BasicAuthenticationProvider abstract methods. @@ -75,25 +79,30 @@ def create_authorize_request(barcode, pin): password.text = pin authorize_request.append(user_id) authorize_request.append(password) - return etree.tostring(authorize_request, encoding='utf8') + return etree.tostring(authorize_request, encoding="utf8") def parse_authorize_response(self, response): try: authorize_response = etree.fromstring(response) except etree.XMLSyntaxError: - self.log.error("Unable to parse response from API. Deny Access. Response: \n%s", response) + self.log.error( + "Unable to parse response from API. Deny Access. Response: \n%s", + response, + ) return False, None, None patron_names = [] for tag in ["FirstName", "LastName"]: element = authorize_response.find(tag) if element is not None and element.text is not None: patron_names.append(element.text) - patron_name = ' '.join(patron_names) if len(patron_names) != 0 else None + patron_name = " ".join(patron_names) if len(patron_names) != 0 else None element = authorize_response.find("LibraryID") library_identifier = element.text if element is not None else None - element = authorize_response.find('Status') + element = authorize_response.find("Status") if element is None: - self.log.info("Status element not found in response from server. Deny Access.") + self.log.info( + "Status element not found in response from server. Deny Access." + ) authorized = True if element is not None and element.text == "1" else False return authorized, patron_name, library_identifier @@ -106,7 +115,7 @@ def post_request(self, data): self.base_url, data, headers={"Content-Type": "application/xml"}, - allowed_response_codes=['2xx'], + allowed_response_codes=["2xx"], ) diff --git a/api/lanes.py b/api/lanes.py index 69a3b5a16a..1baa4ff189 100644 --- a/api/lanes.py +++ b/api/lanes.py @@ -1,39 +1,24 @@ -from sqlalchemy import ( - and_, - func, - or_, -) -from sqlalchemy.orm import aliased -from flask_babel import lazy_gettext as _ -import elasticsearch import logging +import elasticsearch +from flask_babel import lazy_gettext as _ +from sqlalchemy import and_, func, or_ +from sqlalchemy.orm import aliased + import core.classifier as genres -from .config import ( - CannotLoadConfiguration, - Configuration, -) -from core.classifier import ( - Classifier, - fiction_genres, - nonfiction_genres, - GenreData, -) from core import classifier - +from core.classifier import Classifier, GenreData, fiction_genres, nonfiction_genres from core.lane import ( BaseFacets, DatabaseBackedWorkList, DefaultSortOrderFacets, Facets, FacetsWithEntryPoint, - Pagination, Lane, + Pagination, WorkList, ) from core.model import ( - get_one, - create, CachedFeed, Contribution, Contributor, @@ -44,11 +29,15 @@ LicensePool, Session, Work, + create, + get_one, ) - from core.util import LanguageCodes + +from .config import CannotLoadConfiguration, Configuration from .novelist import NoveListAPI + def load_lanes(_db, library): """Return a WorkList that reflects the current lane structure of the Library. @@ -95,7 +84,7 @@ def _lane_configuration_from_collection_sizes(estimates): if not estimates: # There are no holdings. Assume we have a large English # collection and nothing else. - return ['eng'], [], [] + return ["eng"], [], [] large = [] small = [] @@ -137,7 +126,7 @@ def create_default_lanes(_db, library): """ # Delete existing lanes. - for lane in _db.query(Lane).filter(Lane.library_id==library.id): + for lane in _db.query(Lane).filter(Lane.library_id == library.id): _db.delete(lane) top_level_lanes = [] @@ -155,23 +144,34 @@ def create_default_lanes(_db, library): large, small, tiny = _lane_configuration_from_collection_sizes(estimates) priority = 0 for language in large: - priority = create_lanes_for_large_collection(_db, library, language, priority=priority) + priority = create_lanes_for_large_collection( + _db, library, language, priority=priority + ) create_world_languages_lane(_db, library, small, tiny, priority) -def lane_from_genres(_db, library, genres, display_name=None, - exclude_genres=None, priority=0, audiences=None, **extra_args): + +def lane_from_genres( + _db, + library, + genres, + display_name=None, + exclude_genres=None, + priority=0, + audiences=None, + **extra_args +): """Turn genre info into a Lane object.""" genre_lane_instructions = { "Dystopian SF": dict(display_name="Dystopian"), "Erotica": dict(audiences=[Classifier.AUDIENCE_ADULTS_ONLY]), - "Humorous Fiction" : dict(display_name="Humor"), - "Media Tie-in SF" : dict(display_name="Movie and TV Novelizations"), - "Suspense/Thriller" : dict(display_name="Thriller"), - "Humorous Nonfiction" : dict(display_name="Humor"), - "Political Science" : dict(display_name="Politics & Current Events"), - "Periodicals" : dict(visible=False) + "Humorous Fiction": dict(display_name="Humor"), + "Media Tie-in SF": dict(display_name="Movie and TV Novelizations"), + "Suspense/Thriller": dict(display_name="Thriller"), + "Humorous Nonfiction": dict(display_name="Humor"), + "Political Science": dict(display_name="Politics & Current Events"), + "Periodicals": dict(visible=False), } # Create sublanes first. @@ -180,21 +180,35 @@ def lane_from_genres(_db, library, genres, display_name=None, if isinstance(genre, dict): sublane_priority = 0 for subgenre in genre.get("subgenres", []): - sublanes.append(lane_from_genres( - _db, library, [subgenre], - priority=sublane_priority, **extra_args)) + sublanes.append( + lane_from_genres( + _db, + library, + [subgenre], + priority=sublane_priority, + **extra_args + ) + ) sublane_priority += 1 # Now that we have sublanes we don't care about subgenres anymore. - genres = [genre.get("name") if isinstance(genre, dict) - else genre.name if isinstance(genre, GenreData) - else genre - for genre in genres] - - exclude_genres = [genre.get("name") if isinstance(genre, dict) - else genre.name if isinstance(genre, GenreData) - else genre - for genre in exclude_genres or []] + genres = [ + genre.get("name") + if isinstance(genre, dict) + else genre.name + if isinstance(genre, GenreData) + else genre + for genre in genres + ] + + exclude_genres = [ + genre.get("name") + if isinstance(genre, dict) + else genre.name + if isinstance(genre, GenreData) + else genre + for genre in exclude_genres or [] + ] fiction = None visible = True @@ -208,7 +222,7 @@ def lane_from_genres(_db, library, genres, display_name=None, if genres[0] in list(genre_lane_instructions.keys()): instructions = genre_lane_instructions[genres[0]] if not display_name and "display_name" in instructions: - display_name = instructions.get('display_name') + display_name = instructions.get("display_name") if "audiences" in instructions: audiences = instructions.get("audiences") if "visible" in instructions: @@ -217,11 +231,17 @@ def lane_from_genres(_db, library, genres, display_name=None, if not display_name: display_name = ", ".join(sorted(genres)) - lane, ignore = create(_db, Lane, library_id=library.id, - display_name=display_name, - fiction=fiction, audiences=audiences, - sublanes=sublanes, priority=priority, - **extra_args) + lane, ignore = create( + _db, + Lane, + library_id=library.id, + display_name=display_name, + fiction=fiction, + audiences=audiences, + sublanes=sublanes, + priority=priority, + **extra_args + ) lane.visible = visible for genre in genres: lane.add_genre(genre) @@ -229,6 +249,7 @@ def lane_from_genres(_db, library, genres, display_name=None, lane.add_genre(genre, inclusive=False) return lane + def create_lanes_for_large_collection(_db, library, languages, priority=0): """Ensure that the lanes appropriate to a large collection are all present. @@ -262,17 +283,15 @@ def create_lanes_for_large_collection(_db, library, languages, priority=0): YA = [Classifier.AUDIENCE_YOUNG_ADULT] CHILDREN = [Classifier.AUDIENCE_CHILDREN] - common_args = dict( - languages=languages, - media=None - ) + common_args = dict(languages=languages, media=None) adult_common_args = dict(common_args) - adult_common_args['audiences'] = ADULT + adult_common_args["audiences"] = ADULT include_best_sellers = False nyt_data_source = DataSource.lookup(_db, DataSource.NYT) nyt_integration = get_one( - _db, ExternalIntegration, + _db, + ExternalIntegration, goal=ExternalIntegration.METADATA_GOAL, protocol=ExternalIntegration.NYT, ) @@ -282,7 +301,9 @@ def create_lanes_for_large_collection(_db, library, languages, priority=0): sublanes = [] if include_best_sellers: best_sellers, ignore = create( - _db, Lane, library=library, + _db, + Lane, + library=library, display_name="Best Sellers", priority=priority, **common_args @@ -291,12 +312,13 @@ def create_lanes_for_large_collection(_db, library, languages, priority=0): best_sellers.list_datasource = nyt_data_source sublanes.append(best_sellers) - adult_fiction_sublanes = [] adult_fiction_priority = 0 if include_best_sellers: adult_fiction_best_sellers, ignore = create( - _db, Lane, library=library, + _db, + Lane, + library=library, display_name="Best Sellers", fiction=True, priority=adult_fiction_priority, @@ -312,14 +334,15 @@ def create_lanes_for_large_collection(_db, library, languages, priority=0): else: genre_name = genre.get("name") genre_lane = lane_from_genres( - _db, library, [genre], - priority=adult_fiction_priority, - **adult_common_args) + _db, library, [genre], priority=adult_fiction_priority, **adult_common_args + ) adult_fiction_priority += 1 adult_fiction_sublanes.append(genre_lane) adult_fiction, ignore = create( - _db, Lane, library=library, + _db, + Lane, + library=library, display_name="Fiction", genres=[], sublanes=adult_fiction_sublanes, @@ -334,7 +357,9 @@ def create_lanes_for_large_collection(_db, library, languages, priority=0): adult_nonfiction_priority = 0 if include_best_sellers: adult_nonfiction_best_sellers, ignore = create( - _db, Lane, library=library, + _db, + Lane, + library=library, display_name="Best Sellers", fiction=False, priority=adult_nonfiction_priority, @@ -353,14 +378,19 @@ def create_lanes_for_large_collection(_db, library, languages, priority=0): else: genre_name = genre.get("name") genre_lane = lane_from_genres( - _db, library, [genre], + _db, + library, + [genre], priority=adult_nonfiction_priority, - **adult_common_args) + **adult_common_args + ) adult_nonfiction_priority += 1 adult_nonfiction_sublanes.append(genre_lane) adult_nonfiction, ignore = create( - _db, Lane, library=library, + _db, + Lane, + library=library, display_name="Nonfiction", genres=[], sublanes=adult_nonfiction_sublanes, @@ -372,12 +402,15 @@ def create_lanes_for_large_collection(_db, library, languages, priority=0): sublanes.append(adult_nonfiction) ya_common_args = dict(common_args) - ya_common_args['audiences'] = YA + ya_common_args["audiences"] = YA ya_fiction, ignore = create( - _db, Lane, library=library, + _db, + Lane, + library=library, display_name="Young Adult Fiction", - genres=[], fiction=True, + genres=[], + fiction=True, sublanes=[], priority=priority, **ya_common_args @@ -388,7 +421,9 @@ def create_lanes_for_large_collection(_db, library, languages, priority=0): ya_fiction_priority = 0 if include_best_sellers: ya_fiction_best_sellers, ignore = create( - _db, Lane, library=library, + _db, + Lane, + library=library, display_name="Best Sellers", fiction=True, priority=ya_fiction_priority, @@ -399,49 +434,106 @@ def create_lanes_for_large_collection(_db, library, languages, priority=0): ya_fiction.sublanes.append(ya_fiction_best_sellers) ya_fiction.sublanes.append( - lane_from_genres(_db, library, [genres.Dystopian_SF], - priority=ya_fiction_priority, **ya_common_args)) + lane_from_genres( + _db, + library, + [genres.Dystopian_SF], + priority=ya_fiction_priority, + **ya_common_args + ) + ) ya_fiction_priority += 1 ya_fiction.sublanes.append( - lane_from_genres(_db, library, [genres.Fantasy], - priority=ya_fiction_priority, **ya_common_args)) + lane_from_genres( + _db, + library, + [genres.Fantasy], + priority=ya_fiction_priority, + **ya_common_args + ) + ) ya_fiction_priority += 1 ya_fiction.sublanes.append( - lane_from_genres(_db, library, [genres.Comics_Graphic_Novels], - priority=ya_fiction_priority, **ya_common_args)) + lane_from_genres( + _db, + library, + [genres.Comics_Graphic_Novels], + priority=ya_fiction_priority, + **ya_common_args + ) + ) ya_fiction_priority += 1 ya_fiction.sublanes.append( - lane_from_genres(_db, library, [genres.Literary_Fiction], - display_name="Contemporary Fiction", - priority=ya_fiction_priority, **ya_common_args)) + lane_from_genres( + _db, + library, + [genres.Literary_Fiction], + display_name="Contemporary Fiction", + priority=ya_fiction_priority, + **ya_common_args + ) + ) ya_fiction_priority += 1 ya_fiction.sublanes.append( - lane_from_genres(_db, library, [genres.LGBTQ_Fiction], - priority=ya_fiction_priority, **ya_common_args)) + lane_from_genres( + _db, + library, + [genres.LGBTQ_Fiction], + priority=ya_fiction_priority, + **ya_common_args + ) + ) ya_fiction_priority += 1 ya_fiction.sublanes.append( - lane_from_genres(_db, library, [genres.Suspense_Thriller, genres.Mystery], - display_name="Mystery & Thriller", - priority=ya_fiction_priority, **ya_common_args)) + lane_from_genres( + _db, + library, + [genres.Suspense_Thriller, genres.Mystery], + display_name="Mystery & Thriller", + priority=ya_fiction_priority, + **ya_common_args + ) + ) ya_fiction_priority += 1 ya_fiction.sublanes.append( - lane_from_genres(_db, library, [genres.Romance], - priority=ya_fiction_priority, **ya_common_args)) + lane_from_genres( + _db, + library, + [genres.Romance], + priority=ya_fiction_priority, + **ya_common_args + ) + ) ya_fiction_priority += 1 ya_fiction.sublanes.append( - lane_from_genres(_db, library, [genres.Science_Fiction], - exclude_genres=[genres.Dystopian_SF, genres.Steampunk], - priority=ya_fiction_priority, **ya_common_args)) + lane_from_genres( + _db, + library, + [genres.Science_Fiction], + exclude_genres=[genres.Dystopian_SF, genres.Steampunk], + priority=ya_fiction_priority, + **ya_common_args + ) + ) ya_fiction_priority += 1 ya_fiction.sublanes.append( - lane_from_genres(_db, library, [genres.Steampunk], - priority=ya_fiction_priority, **ya_common_args)) + lane_from_genres( + _db, + library, + [genres.Steampunk], + priority=ya_fiction_priority, + **ya_common_args + ) + ) ya_fiction_priority += 1 ya_nonfiction, ignore = create( - _db, Lane, library=library, + _db, + Lane, + library=library, display_name="Young Adult Nonfiction", - genres=[], fiction=False, + genres=[], + fiction=False, sublanes=[], priority=priority, **ya_common_args @@ -452,7 +544,9 @@ def create_lanes_for_large_collection(_db, library, languages, priority=0): ya_nonfiction_priority = 0 if include_best_sellers: ya_nonfiction_best_sellers, ignore = create( - _db, Lane, library=library, + _db, + Lane, + library=library, display_name="Best Sellers", fiction=False, priority=ya_nonfiction_priority, @@ -463,32 +557,58 @@ def create_lanes_for_large_collection(_db, library, languages, priority=0): ya_nonfiction.sublanes.append(ya_nonfiction_best_sellers) ya_nonfiction.sublanes.append( - lane_from_genres(_db, library, [genres.Biography_Memoir], - display_name="Biography", - priority=ya_nonfiction_priority, **ya_common_args)) + lane_from_genres( + _db, + library, + [genres.Biography_Memoir], + display_name="Biography", + priority=ya_nonfiction_priority, + **ya_common_args + ) + ) ya_nonfiction_priority += 1 ya_nonfiction.sublanes.append( - lane_from_genres(_db, library, [genres.History, genres.Social_Sciences], - display_name="History & Sociology", - priority=ya_nonfiction_priority, **ya_common_args)) + lane_from_genres( + _db, + library, + [genres.History, genres.Social_Sciences], + display_name="History & Sociology", + priority=ya_nonfiction_priority, + **ya_common_args + ) + ) ya_nonfiction_priority += 1 ya_nonfiction.sublanes.append( - lane_from_genres(_db, library, [genres.Life_Strategies], - priority=ya_nonfiction_priority, **ya_common_args)) + lane_from_genres( + _db, + library, + [genres.Life_Strategies], + priority=ya_nonfiction_priority, + **ya_common_args + ) + ) ya_nonfiction_priority += 1 ya_nonfiction.sublanes.append( - lane_from_genres(_db, library, [genres.Religion_Spirituality], - priority=ya_nonfiction_priority, **ya_common_args)) + lane_from_genres( + _db, + library, + [genres.Religion_Spirituality], + priority=ya_nonfiction_priority, + **ya_common_args + ) + ) ya_nonfiction_priority += 1 - children_common_args = dict(common_args) - children_common_args['audiences'] = CHILDREN + children_common_args["audiences"] = CHILDREN children, ignore = create( - _db, Lane, library=library, + _db, + Lane, + library=library, display_name="Children and Middle Grade", - genres=[], fiction=None, + genres=[], + fiction=None, sublanes=[], priority=priority, **children_common_args @@ -499,7 +619,9 @@ def create_lanes_for_large_collection(_db, library, languages, priority=0): children_priority = 0 if include_best_sellers: children_best_sellers, ignore = create( - _db, Lane, library=library, + _db, + Lane, + library=library, display_name="Best Sellers", priority=children_priority, **children_common_args @@ -509,9 +631,13 @@ def create_lanes_for_large_collection(_db, library, languages, priority=0): children.sublanes.append(children_best_sellers) picture_books, ignore = create( - _db, Lane, library=library, + _db, + Lane, + library=library, display_name="Picture Books", - target_age=(0,4), genres=[], fiction=None, + target_age=(0, 4), + genres=[], + fiction=None, priority=children_priority, languages=languages, ) @@ -519,9 +645,13 @@ def create_lanes_for_large_collection(_db, library, languages, priority=0): children.sublanes.append(picture_books) easy_readers, ignore = create( - _db, Lane, library=library, + _db, + Lane, + library=library, display_name="Easy Readers", - target_age=(5,8), genres=[], fiction=None, + target_age=(5, 8), + genres=[], + fiction=None, priority=children_priority, languages=languages, ) @@ -529,9 +659,13 @@ def create_lanes_for_large_collection(_db, library, languages, priority=0): children.sublanes.append(easy_readers) chapter_books, ignore = create( - _db, Lane, library=library, + _db, + Lane, + library=library, display_name="Chapter Books", - target_age=(9,12), genres=[], fiction=None, + target_age=(9, 12), + genres=[], + fiction=None, priority=children_priority, languages=languages, ) @@ -539,7 +673,9 @@ def create_lanes_for_large_collection(_db, library, languages, priority=0): children.sublanes.append(chapter_books) children_poetry, ignore = create( - _db, Lane, library=library, + _db, + Lane, + library=library, display_name="Poetry Books", priority=children_priority, **children_common_args @@ -549,7 +685,9 @@ def create_lanes_for_large_collection(_db, library, languages, priority=0): children.sublanes.append(children_poetry) children_folklore, ignore = create( - _db, Lane, library=library, + _db, + Lane, + library=library, display_name="Folklore", priority=children_priority, **children_common_args @@ -559,7 +697,9 @@ def create_lanes_for_large_collection(_db, library, languages, priority=0): children.sublanes.append(children_folklore) children_fantasy, ignore = create( - _db, Lane, library=library, + _db, + Lane, + library=library, display_name="Fantasy", fiction=True, priority=children_priority, @@ -570,7 +710,9 @@ def create_lanes_for_large_collection(_db, library, languages, priority=0): children.sublanes.append(children_fantasy) children_sf, ignore = create( - _db, Lane, library=library, + _db, + Lane, + library=library, display_name="Science Fiction", fiction=True, priority=children_priority, @@ -581,7 +723,9 @@ def create_lanes_for_large_collection(_db, library, languages, priority=0): children.sublanes.append(children_sf) realistic_fiction, ignore = create( - _db, Lane, library=library, + _db, + Lane, + library=library, display_name="Realistic Fiction", fiction=True, priority=children_priority, @@ -592,7 +736,9 @@ def create_lanes_for_large_collection(_db, library, languages, priority=0): children.sublanes.append(realistic_fiction) children_graphic_novels, ignore = create( - _db, Lane, library=library, + _db, + Lane, + library=library, display_name="Comics & Graphic Novels", priority=children_priority, **children_common_args @@ -602,7 +748,9 @@ def create_lanes_for_large_collection(_db, library, languages, priority=0): children.sublanes.append(children_graphic_novels) children_biography, ignore = create( - _db, Lane, library=library, + _db, + Lane, + library=library, display_name="Biography", priority=children_priority, **children_common_args @@ -612,7 +760,9 @@ def create_lanes_for_large_collection(_db, library, languages, priority=0): children.sublanes.append(children_biography) children_historical_fiction, ignore = create( - _db, Lane, library=library, + _db, + Lane, + library=library, display_name="Historical Fiction", priority=children_priority, **children_common_args @@ -622,9 +772,12 @@ def create_lanes_for_large_collection(_db, library, languages, priority=0): children.sublanes.append(children_historical_fiction) informational, ignore = create( - _db, Lane, library=library, + _db, + Lane, + library=library, display_name="Informational Books", - fiction=False, genres=[], + fiction=False, + genres=[], priority=children_priority, **children_common_args ) @@ -634,8 +787,13 @@ def create_lanes_for_large_collection(_db, library, languages, priority=0): return priority + def create_world_languages_lane( - _db, library, small_languages, tiny_languages, priority=0, + _db, + library, + small_languages, + tiny_languages, + priority=0, ): """Create a lane called 'World Languages' whose sublanes represent the non-large language collections available to this library. @@ -653,15 +811,16 @@ def create_world_languages_lane( else: complete_language_set.update(languageset) - world_languages, ignore = create( - _db, Lane, library=library, + _db, + Lane, + library=library, display_name="World Languages", fiction=None, priority=priority, languages=complete_language_set, media=[Edition.BOOK_MEDIUM], - genres=[] + genres=[], ) priority += 1 @@ -678,6 +837,7 @@ def create_world_languages_lane( ) return priority + def create_lane_for_small_collection(_db, library, parent, languages, priority=0): """Create a lane (with sublanes) for a small collection based on language, if the language exists in the lookup table. @@ -707,7 +867,9 @@ def create_lane_for_small_collection(_db, library, parent, languages, priority=0 sublane_priority = 0 adult_fiction, ignore = create( - _db, Lane, library=library, + _db, + Lane, + library=library, display_name="Fiction", fiction=True, audiences=ADULT, @@ -717,7 +879,9 @@ def create_lane_for_small_collection(_db, library, parent, languages, priority=0 sublane_priority += 1 adult_nonfiction, ignore = create( - _db, Lane, library=library, + _db, + Lane, + library=library, display_name="Nonfiction", fiction=False, audiences=ADULT, @@ -727,7 +891,9 @@ def create_lane_for_small_collection(_db, library, parent, languages, priority=0 sublane_priority += 1 ya_children, ignore = create( - _db, Lane, library=library, + _db, + Lane, + library=library, display_name="Children & Young Adult", fiction=None, audiences=YA_CHILDREN, @@ -737,7 +903,9 @@ def create_lane_for_small_collection(_db, library, parent, languages, priority=0 sublane_priority += 1 lane, ignore = create( - _db, Lane, library=library, + _db, + Lane, + library=library, display_name=language_identifier, parent=parent, sublanes=[adult_fiction, adult_nonfiction, ya_children], @@ -747,6 +915,7 @@ def create_lane_for_small_collection(_db, library, parent, languages, priority=0 priority += 1 return priority + def create_lane_for_tiny_collection(_db, library, parent, languages, priority=0): """Create a single lane for a tiny collection based on language, if the language exists in the lookup table. @@ -768,7 +937,9 @@ def create_lane_for_tiny_collection(_db, library, parent, languages, priority=0) return 0 language_lane, ignore = create( - _db, Lane, library=library, + _db, + Lane, + library=library, display_name=name, parent=parent, genres=[], @@ -784,19 +955,21 @@ class DynamicLane(WorkList): """A WorkList that's used to from an OPDS lane, but isn't a Lane in the database.""" + class DatabaseExclusiveWorkList(DatabaseBackedWorkList): """A DatabaseBackedWorkList that can _only_ get Works through the database.""" + def works(self, *args, **kwargs): return self.works_from_database(*args, **kwargs) + class WorkBasedLane(DynamicLane): """A lane that shows works related to one particular Work.""" DISPLAY_NAME = None ROUTE = None - def __init__(self, library, work, display_name=None, - children=None, **kwargs): + def __init__(self, library, work, display_name=None, children=None, **kwargs): self.work = work self.edition = work.presentation_edition @@ -805,22 +978,21 @@ def __init__(self, library, work, display_name=None, # language of the work. All children of this lane will be put # under a similar restriction. self.source_language = self.edition.language - kwargs['languages'] = [self.source_language] + kwargs["languages"] = [self.source_language] # To avoid showing inappropriate material, the value of this # lane's .audiences setting is always derived from the # audience of the work. All children of this lane will be # under a similar restriction. self.source_audience = self.work.audience - kwargs['audiences'] = self.audiences_list_from_source() + kwargs["audiences"] = self.audiences_list_from_source() display_name = display_name or self.DISPLAY_NAME children = children or list() super(WorkBasedLane, self).initialize( - library, display_name=display_name, children=children, - **kwargs + library, display_name=display_name, children=children, **kwargs ) @property @@ -828,15 +1000,14 @@ def url_arguments(self): if not self.ROUTE: raise NotImplementedError() identifier = self.edition.primary_identifier - kwargs = dict( - identifier_type=identifier.type, - identifier=identifier.identifier - ) + kwargs = dict(identifier_type=identifier.type, identifier=identifier.identifier) return self.ROUTE, kwargs def audiences_list_from_source(self): - if (not self.source_audience or - self.source_audience in Classifier.AUDIENCES_ADULT): + if ( + not self.source_audience + or self.source_audience in Classifier.AUDIENCES_ADULT + ): return Classifier.AUDIENCES if self.source_audience == Classifier.AUDIENCE_YOUNG_ADULT: return Classifier.AUDIENCES_JUVENILE @@ -873,18 +1044,21 @@ class RecommendationLane(WorkBasedLane): # Cache for 24 hours -- would ideally be much longer but availability # information goes stale. - MAX_CACHE_AGE = 24*60*60 + MAX_CACHE_AGE = 24 * 60 * 60 CACHED_FEED_TYPE = CachedFeed.RECOMMENDATIONS_TYPE - def __init__(self, library, work, display_name=None, - novelist_api=None, parent=None): + def __init__( + self, library, work, display_name=None, novelist_api=None, parent=None + ): """Constructor. :raises: CannotLoadConfiguration if `novelist_api` is not provided and no Novelist integration is configured for this library. """ super(RecommendationLane, self).__init__( - library, work, display_name=display_name, + library, + work, + display_name=display_name, ) self.novelist_api = novelist_api or NoveListAPI.from_config(library) if parent: @@ -922,8 +1096,10 @@ def overview_facets(self, _db, facets): # they come from the recommendation engine, since presumably # the best recommendations are in the front. return Facets.default( - self.get_library(_db), collection=facets.COLLECTION_FULL, - availability=facets.AVAILABLE_ALL, entrypoint=facets.entrypoint, + self.get_library(_db), + collection=facets.COLLECTION_FULL, + availability=facets.AVAILABLE_ALL, + entrypoint=facets.entrypoint, ) def modify_search_filter_hook(self, filter): @@ -945,24 +1121,23 @@ class SeriesFacets(DefaultSortOrderFacets): """A list with a series restriction is ordered by series position by default. """ + DEFAULT_SORT_ORDER = Facets.ORDER_SERIES_POSITION class SeriesLane(DynamicLane): """A lane of Works in a particular series.""" - ROUTE = 'series' + ROUTE = "series" # Cache for 24 hours -- would ideally be longer but availability # information goes stale. - MAX_CACHE_AGE = 24*60*60 + MAX_CACHE_AGE = 24 * 60 * 60 CACHED_FEED_TYPE = CachedFeed.SERIES_TYPE def __init__(self, library, series_name, parent=None, **kwargs): if not series_name: raise ValueError("SeriesLane can't be created without series") - super(SeriesLane, self).initialize( - library, display_name=series_name, **kwargs - ) + super(SeriesLane, self).initialize(library, display_name=series_name, **kwargs) self.series = series_name if parent: parent.append_child(self) @@ -981,9 +1156,9 @@ def __init__(self, library, series_name, parent=None, **kwargs): def url_arguments(self): kwargs = dict(series_name=self.series) if self.language_key: - kwargs['languages'] = self.language_key + kwargs["languages"] = self.language_key if self.audience_key: - kwargs['audiences'] = self.audience_key + kwargs["audiences"] = self.audience_key return self.ROUTE, kwargs def overview_facets(self, _db, facets): @@ -992,8 +1167,10 @@ def overview_facets(self, _db, facets): be ordered by series position. """ return SeriesFacets.default( - self.get_library(_db), collection=facets.COLLECTION_FULL, - availability=facets.AVAILABLE_ALL, entrypoint=facets.entrypoint, + self.get_library(_db), + collection=facets.COLLECTION_FULL, + availability=facets.AVAILABLE_ALL, + entrypoint=facets.entrypoint, ) def modify_search_filter_hook(self, filter): @@ -1005,20 +1182,22 @@ class ContributorFacets(DefaultSortOrderFacets): """A list with a contributor restriction is, by default, sorted by title. """ + DEFAULT_SORT_ORDER = Facets.ORDER_TITLE class ContributorLane(DynamicLane): """A lane of Works written by a particular contributor""" - ROUTE = 'contributor' + ROUTE = "contributor" # Cache for 24 hours -- would ideally be longer but availability # information goes stale. - MAX_CACHE_AGE = 24*60*60 + MAX_CACHE_AGE = 24 * 60 * 60 CACHED_FEED_TYPE = CachedFeed.CONTRIBUTOR_TYPE - def __init__(self, library, contributor, - parent=None, languages=None, audiences=None): + def __init__( + self, library, contributor, parent=None, languages=None, audiences=None + ): """Constructor. :param library: A Library. @@ -1028,17 +1207,17 @@ def __init__(self, library, contributor, :param audiences: An extra restriction on the audience for Works. """ if not contributor: - raise ValueError( - "ContributorLane can't be created without contributor" - ) + raise ValueError("ContributorLane can't be created without contributor") self.contributor = contributor self.contributor_key = ( self.contributor.display_name or self.contributor.sort_name ) super(ContributorLane, self).initialize( - library, display_name=self.contributor_key, - audiences=audiences, languages=languages, + library, + display_name=self.contributor_key, + audiences=audiences, + languages=languages, ) if parent: parent.append_child(self) @@ -1048,7 +1227,7 @@ def url_arguments(self): kwargs = dict( contributor_name=self.contributor_key, languages=self.language_key, - audiences=self.audience_key + audiences=self.audience_key, ) return self.ROUTE, kwargs @@ -1057,8 +1236,10 @@ def overview_facets(self, _db, facets): use in a grouped feed. """ return ContributorFacets.default( - self.get_library(_db), collection=facets.COLLECTION_FULL, - availability=facets.AVAILABLE_ALL, entrypoint=facets.entrypoint, + self.get_library(_db), + collection=facets.COLLECTION_FULL, + availability=facets.AVAILABLE_ALL, + entrypoint=facets.entrypoint, ) def modify_search_filter_hook(self, filter): @@ -1076,20 +1257,24 @@ class RelatedBooksLane(WorkBasedLane): * RecommendationLane: Works provided by a third-party recommendation service. """ + CACHED_FEED_TYPE = CachedFeed.RELATED_TYPE DISPLAY_NAME = "Related Books" - ROUTE = 'related_books' + ROUTE = "related_books" # Cache this lane for the shortest amount of time any of its # component lane should be cached. - MAX_CACHE_AGE = min(ContributorLane.MAX_CACHE_AGE, - SeriesLane.MAX_CACHE_AGE, - RecommendationLane.MAX_CACHE_AGE) + MAX_CACHE_AGE = min( + ContributorLane.MAX_CACHE_AGE, + SeriesLane.MAX_CACHE_AGE, + RecommendationLane.MAX_CACHE_AGE, + ) - def __init__(self, library, work, display_name=None, - novelist_api=None): + def __init__(self, library, work, display_name=None, novelist_api=None): super(RelatedBooksLane, self).__init__( - library, work, display_name=display_name, + library, + work, + display_name=display_name, ) _db = Session.object_session(library) sublanes = self._get_sublanes(_db, novelist_api) @@ -1118,7 +1303,14 @@ def _get_sublanes(self, _db, novelist_api): # Create a series sublane. series_name = self.edition.series if series_name: - sublanes.append(SeriesLane(self.get_library(_db), series_name, parent=self, languages=self.languages)) + sublanes.append( + SeriesLane( + self.get_library(_db), + series_name, + parent=self, + languages=self.languages, + ) + ) return sublanes @@ -1129,15 +1321,20 @@ def _contributor_sublanes(self, _db): while roles_by_priority and not viable_contributors: author_roles = roles_by_priority.pop(0) - viable_contributors = [c.contributor - for c in self.edition.contributions - if c.role in author_roles] + viable_contributors = [ + c.contributor + for c in self.edition.contributions + if c.role in author_roles + ] library = self.get_library(_db) for contributor in viable_contributors: contributor_lane = ContributorLane( - library, contributor, parent=self, - languages=self.languages, audiences=self.audiences, + library, + contributor, + parent=self, + languages=self.languages, + audiences=self.audiences, ) yield contributor_lane @@ -1146,8 +1343,10 @@ def _recommendation_sublane(self, _db, novelist_api): lane_name = "Similar titles recommended by NoveList" try: recommendation_lane = RecommendationLane( - library=self.get_library(_db), work=self.work, - display_name=lane_name, novelist_api=novelist_api, + library=self.get_library(_db), + work=self.work, + display_name=lane_name, + novelist_api=novelist_api, parent=self, ) if recommendation_lane.recommendations: @@ -1166,7 +1365,7 @@ class CrawlableFacets(Facets): # These facet settings are definitive of a crawlable feed. # Library configuration settings don't matter. SETTINGS = { - Facets.ORDER_FACET_GROUP_NAME : Facets.ORDER_LAST_UPDATE, + Facets.ORDER_FACET_GROUP_NAME: Facets.ORDER_LAST_UPDATE, Facets.AVAILABILITY_FACET_GROUP_NAME: Facets.AVAILABLE_ALL, Facets.COLLECTION_FACET_GROUP_NAME: Facets.COLLECTION_FULL, } @@ -1216,7 +1415,8 @@ def initialize(self, library_or_collections): self.collection_name = collections[0].name super(CrawlableCollectionBasedLane, self).initialize( - library, "Crawlable feed: %s" % identifier, + library, + "Crawlable feed: %s" % identifier, ) if collections is not None: # initialize() set the collection IDs to all collections @@ -1245,8 +1445,9 @@ class CrawlableCustomListBasedLane(CrawlableLane): def initialize(self, library, customlist): self.customlist_name = customlist.name super(CrawlableCustomListBasedLane, self).initialize( - library, "Crawlable feed: %s" % self.customlist_name, - customlists=[customlist] + library, + "Crawlable feed: %s" % self.customlist_name, + customlists=[customlist], ) @property @@ -1254,10 +1455,12 @@ def url_arguments(self): kwargs = dict(list_name=self.customlist_name) return self.ROUTE, kwargs + class KnownOverviewFacetsWorkList(WorkList): """A WorkList whose defining feature is that the Facets object to be used when generating a grouped feed is known in advance. """ + def __init__(self, facets, *args, **kwargs): """Constructor. @@ -1287,25 +1490,27 @@ class JackpotFacets(Facets): @classmethod def default_facet(cls, config, facet_group_name): if facet_group_name != cls.AVAILABILITY_FACET_GROUP_NAME: - return super(JackpotFacets, cls).default_facet( - config, facet_group_name - ) + return super(JackpotFacets, cls).default_facet(config, facet_group_name) return cls.AVAILABLE_NOW @classmethod def available_facets(cls, config, facet_group_name): if facet_group_name != cls.AVAILABILITY_FACET_GROUP_NAME: - return super(JackpotFacets, cls).available_facets( - config, facet_group_name - ) + return super(JackpotFacets, cls).available_facets(config, facet_group_name) + + return [ + cls.AVAILABLE_NOW, + cls.AVAILABLE_NOT_NOW, + cls.AVAILABLE_ALL, + cls.AVAILABLE_OPEN_ACCESS, + ] - return [cls.AVAILABLE_NOW, cls.AVAILABLE_NOT_NOW, - cls.AVAILABLE_ALL, cls.AVAILABLE_OPEN_ACCESS] class HasSeriesFacets(Facets): """A faceting object for a feed containg books guaranteed to belong to _some_ series. """ + def modify_search_filter(self, filter): filter.series = True @@ -1341,11 +1546,12 @@ def __init__(self, library, facets): data_source_name = collection.data_source.name else: data_source_name = "[Unknown]" - display_name = "License source {%s} - Medium {%s} - Collection name {%s}" % (data_source_name, medium, collection.name) - child = KnownOverviewFacetsWorkList(facets) - child.initialize( - library, media=[medium], display_name=display_name + display_name = ( + "License source {%s} - Medium {%s} - Collection name {%s}" + % (data_source_name, medium, collection.name) ) + child = KnownOverviewFacetsWorkList(facets) + child.initialize(library, media=[medium], display_name=display_name) child.collection_ids = [collection.id] self.children.append(child) diff --git a/api/lcp/collection.py b/api/lcp/collection.py index 59e6d38360..f60d44712e 100644 --- a/api/lcp/collection.py +++ b/api/lcp/collection.py @@ -5,32 +5,42 @@ from flask import send_file from sqlalchemy import or_ -from api.circulation import FulfillmentInfo, BaseCirculationAPI, LoanInfo +from api.circulation import BaseCirculationAPI, FulfillmentInfo, LoanInfo from api.lcp.encrypt import LCPEncryptionConfiguration from api.lcp.hash import HasherFactory -from api.lcp.server import LCPServerConfiguration, LCPServer +from api.lcp.server import LCPServer, LCPServerConfiguration from core.lcp.credential import LCPCredentialFactory -from core.model import ExternalIntegration, LicensePoolDeliveryMechanism, get_one, Loan, Collection, LicensePool, \ - DeliveryMechanism -from core.model.configuration import HasExternalIntegration, ConfigurationStorage, ConfigurationFactory -from core.util.datetime_helpers import ( - utc_now, +from core.model import ( + Collection, + DeliveryMechanism, + ExternalIntegration, + LicensePool, + LicensePoolDeliveryMechanism, + Loan, + get_one, ) +from core.model.configuration import ( + ConfigurationFactory, + ConfigurationStorage, + HasExternalIntegration, +) +from core.util.datetime_helpers import utc_now class LCPFulfilmentInfo(FulfillmentInfo): """Sends LCP licenses as fulfilment info""" def __init__( - self, - identifier, - collection, - data_source_name, - identifier_type, - content_link=None, - content_type=None, - content=None, - content_expires=None): + self, + identifier, + collection, + data_source_name, + identifier_type, + content_link=None, + content_type=None, + content=None, + content_expires=None, + ): """Initializes a new instance of LCPFulfilmentInfo class :param identifier: Identifier @@ -65,7 +75,7 @@ def __init__( content_link, content_type, content, - content_expires + content_expires, ) @property @@ -79,7 +89,7 @@ def as_response(self): BytesIO(json.dumps(self.content)), mimetype=DeliveryMechanism.LCP_DRM, as_attachment=True, - attachment_filename='{0}.lcpl'.format(self.identifier) + attachment_filename="{0}.lcpl".format(self.identifier), ) @@ -87,10 +97,12 @@ class LCPAPI(BaseCirculationAPI, HasExternalIntegration): """Implements LCP workflow""" NAME = ExternalIntegration.LCP - SERVICE_NAME = 'LCP' - DESCRIPTION = 'Manually imported collection protected using Readium LCP DRM' + SERVICE_NAME = "LCP" + DESCRIPTION = "Manually imported collection protected using Readium LCP DRM" - SETTINGS = LCPServerConfiguration.to_settings() + LCPEncryptionConfiguration.to_settings() + SETTINGS = ( + LCPServerConfiguration.to_settings() + LCPEncryptionConfiguration.to_settings() + ) def __init__(self, db, collection): """Initializes a new instance of LCPAPI class @@ -103,7 +115,9 @@ def __init__(self, db, collection): """ if collection.protocol != ExternalIntegration.LCP: raise ValueError( - 'Collection protocol is {0} but must be LCPAPI'.format(collection.protocol) + "Collection protocol is {0} but must be LCPAPI".format( + collection.protocol + ) ) self._db = db @@ -149,7 +163,12 @@ def _create_lcp_server(self): configuration_factory = ConfigurationFactory() hasher_factory = HasherFactory() credential_factory = LCPCredentialFactory() - lcp_server = LCPServer(configuration_storage, configuration_factory, hasher_factory, credential_factory) + lcp_server = LCPServer( + configuration_storage, + configuration_factory, + hasher_factory, + credential_factory, + ) return lcp_server @@ -186,17 +205,21 @@ def checkout(self, patron, pin, licensepool, internal_format): days = self.collection.default_loan_period(patron.library) today = utc_now() expires = today + datetime.timedelta(days=days) - loan = get_one(self._db, Loan, patron=patron, license_pool=licensepool, on_multiple='interchangeable') + loan = get_one( + self._db, + Loan, + patron=patron, + license_pool=licensepool, + on_multiple="interchangeable", + ) if loan: - license = self._lcp_server.get_license(self._db, loan.external_identifier, patron) + license = self._lcp_server.get_license( + self._db, loan.external_identifier, patron + ) else: license = self._lcp_server.generate_license( - self._db, - licensepool.identifier.identifier, - patron, - today, - expires + self._db, licensepool.identifier.identifier, patron, today, expires ) loan = LoanInfo( @@ -207,12 +230,20 @@ def checkout(self, patron, pin, licensepool, internal_format): start_date=today, end_date=expires, fulfillment_info=None, - external_identifier=license['id'] + external_identifier=license["id"], ) return loan - def fulfill(self, patron, pin, licensepool, internal_format=None, part=None, fulfill_part_url=None): + def fulfill( + self, + patron, + pin, + licensepool, + internal_format=None, + part=None, + fulfill_part_url=None, + ): """Get the actual resource file to the patron. :param patron: A Patron object for the patron who wants to check out the book @@ -240,8 +271,16 @@ def fulfill(self, patron, pin, licensepool, internal_format=None, part=None, ful :return: a FulfillmentInfo object :rtype: FulfillmentInfo """ - loan = get_one(self._db, Loan, patron=patron, license_pool=licensepool, on_multiple='interchangeable') - license = self._lcp_server.get_license(self._db, loan.external_identifier, patron) + loan = get_one( + self._db, + Loan, + patron=patron, + license_pool=licensepool, + on_multiple="interchangeable", + ) + license = self._lcp_server.get_license( + self._db, loan.external_identifier, patron + ) fulfillment_info = LCPFulfilmentInfo( licensepool.identifier.identifier, licensepool.collection, @@ -268,38 +307,35 @@ def patron_activity(self, patron, pin): :rtype: List[LoanInfo] """ now = utc_now() - loans = self._db\ - .query(Loan)\ - .join(LicensePool)\ - .join(Collection)\ + loans = ( + self._db.query(Loan) + .join(LicensePool) + .join(Collection) .filter( Collection.id == self._collection_id, Loan.patron == patron, - or_( - Loan.start is None, - Loan.start <= now - ), - or_( - Loan.end is None, - Loan.end > now - ) + or_(Loan.start is None, Loan.start <= now), + or_(Loan.end is None, Loan.end > now), ) + ) loan_info_objects = [] for loan in loans: licensepool = get_one(self._db, LicensePool, id=loan.license_pool_id) - loan_info_objects.append(LoanInfo( - collection=self.collection, - data_source_name=licensepool.data_source.name, - identifier_type=licensepool.identifier.type, - identifier=licensepool.identifier.identifier, - start_date=loan.start, - end_date=loan.end, - fulfillment_info=None, - external_identifier=loan.external_identifier - )) + loan_info_objects.append( + LoanInfo( + collection=self.collection, + data_source_name=licensepool.data_source.name, + identifier_type=licensepool.identifier.type, + identifier=licensepool.identifier.identifier, + start_date=loan.start, + end_date=loan.end, + fulfillment_info=None, + external_identifier=loan.external_identifier, + ) + ) return loan_info_objects diff --git a/api/lcp/controller.py b/api/lcp/controller.py index caa44b71fd..80db3a4233 100644 --- a/api/lcp/controller.py +++ b/api/lcp/controller.py @@ -7,7 +7,7 @@ from api.controller import CirculationManagerController from api.lcp.factory import LCPServerFactory from core.lcp.credential import LCPCredentialFactory -from core.model import Session, ExternalIntegration, Collection +from core.model import Collection, ExternalIntegration, Session from core.util.problem_detail import ProblemDetail @@ -32,11 +32,17 @@ def _get_patron(self): :return: Patron associated with the request (if any) :rtype: core.model.patron.Patron """ - self._logger.info('Started fetching an authenticated patron associated with the request') + self._logger.info( + "Started fetching an authenticated patron associated with the request" + ) patron = self.authenticated_patron_from_request() - self._logger.info('Finished fetching an authenticated patron associated with the request: {0}'.format(patron)) + self._logger.info( + "Finished fetching an authenticated patron associated with the request: {0}".format( + patron + ) + ) return patron @@ -48,11 +54,13 @@ def _get_lcp_passphrase(self, patron): """ db = Session.object_session(patron) - self._logger.info('Started fetching a patron\'s LCP passphrase') + self._logger.info("Started fetching a patron's LCP passphrase") lcp_passphrase = self._credential_factory.get_patron_passphrase(db, patron) - self._logger.info('Finished fetching a patron\'s LCP passphrase: {0}'.format(lcp_passphrase)) + self._logger.info( + "Finished fetching a patron's LCP passphrase: {0}".format(lcp_passphrase) + ) return lcp_passphrase @@ -70,7 +78,9 @@ def _get_lcp_collection(self, patron, collection_name): :rtype: core.model.collection.Collection """ db = Session.object_session(patron) - lcp_collection, _ = Collection.by_name_and_protocol(db, collection_name, ExternalIntegration.LCP) + lcp_collection, _ = Collection.by_name_and_protocol( + db, collection_name, ExternalIntegration.LCP + ) if not lcp_collection or lcp_collection not in patron.library.collections: return MISSING_COLLECTION @@ -83,16 +93,16 @@ def get_lcp_passphrase(self): :return: Flask response containing the LCP passphrase for the authenticated patron :rtype: Response """ - self._logger.info('Started fetching a patron\'s LCP passphrase') + self._logger.info("Started fetching a patron's LCP passphrase") patron = self._get_patron() lcp_passphrase = self._get_lcp_passphrase(patron) - self._logger.info('Finished fetching a patron\'s LCP passphrase: {0}'.format(lcp_passphrase)) + self._logger.info( + "Finished fetching a patron's LCP passphrase: {0}".format(lcp_passphrase) + ) - response = flask.jsonify({ - 'passphrase': lcp_passphrase - }) + response = flask.jsonify({"passphrase": lcp_passphrase}) return response @@ -108,7 +118,7 @@ def get_lcp_license(self, collection_name, license_id): :return: Flask response containing the LCP license with the specified ID :rtype: string """ - self._logger.info('Started fetching license # {0}'.format(license_id)) + self._logger.info("Started fetching license # {0}".format(license_id)) patron = self._get_patron() lcp_collection = self._get_lcp_collection(patron, collection_name) @@ -122,6 +132,8 @@ def get_lcp_license(self, collection_name, license_id): db = Session.object_session(patron) lcp_license = lcp_server.get_license(db, license_id, patron) - self._logger.info('Finished fetching license # {0}: {1}'.format(license_id, lcp_license)) + self._logger.info( + "Finished fetching license # {0}: {1}".format(license_id, lcp_license) + ) return flask.jsonify(lcp_license) diff --git a/api/lcp/encrypt.py b/api/lcp/encrypt.py index eacba7426a..7c2ab78150 100644 --- a/api/lcp/encrypt.py +++ b/api/lcp/encrypt.py @@ -9,7 +9,11 @@ from api.lcp import utils from core.exceptions import BaseError -from core.model.configuration import ConfigurationGrouping, ConfigurationMetadata, ConfigurationAttributeType +from core.model.configuration import ( + ConfigurationAttributeType, + ConfigurationGrouping, + ConfigurationMetadata, +) class LCPEncryptionException(BaseError): @@ -19,54 +23,54 @@ class LCPEncryptionException(BaseError): class LCPEncryptionConfiguration(ConfigurationGrouping): """Contains different settings required by LCPEncryptor""" - DEFAULT_LCPENCRYPT_LOCATION = '/go/bin/lcpencrypt' - DEFAULT_LCPENCRYPT_DOCKER_IMAGE = 'readium/lcpencrypt' + DEFAULT_LCPENCRYPT_LOCATION = "/go/bin/lcpencrypt" + DEFAULT_LCPENCRYPT_DOCKER_IMAGE = "readium/lcpencrypt" lcpencrypt_location = ConfigurationMetadata( - key='lcpencrypt_location', - label=_('lcpencrypt\'s location'), + key="lcpencrypt_location", + label=_("lcpencrypt's location"), description=_( - 'Full path to the local lcpencrypt binary. ' - 'The default value is {0}'.format( - DEFAULT_LCPENCRYPT_LOCATION - ) + "Full path to the local lcpencrypt binary. " + "The default value is {0}".format(DEFAULT_LCPENCRYPT_LOCATION) ), type=ConfigurationAttributeType.TEXT, required=False, - default=DEFAULT_LCPENCRYPT_LOCATION + default=DEFAULT_LCPENCRYPT_LOCATION, ) lcpencrypt_output_directory = ConfigurationMetadata( - key='lcpencrypt_output_directory', - label=_('lcpencrypt\'s output directory'), + key="lcpencrypt_output_directory", + label=_("lcpencrypt's output directory"), description=_( - 'Full path to the directory where lcpencrypt stores encrypted content. ' - 'If not set encrypted books will be stored in lcpencrypt\'s working directory'), + "Full path to the directory where lcpencrypt stores encrypted content. " + "If not set encrypted books will be stored in lcpencrypt's working directory" + ), type=ConfigurationAttributeType.TEXT, - required=False + required=False, ) class LCPEncryptionResult(object): """Represents an output sent by lcpencrypt""" - CONTENT_ID = 'content-id' - CONTENT_ENCRYPTION_KEY = 'content-encryption-key' - PROTECTED_CONTENT_LOCATION = 'protected-content-location' - PROTECTED_CONTENT_LENGTH = 'protected-content-length' - PROTECTED_CONTENT_SHA256 = 'protected-content-sha256' - PROTECTED_CONTENT_DISPOSITION = 'protected-content-disposition' - PROTECTED_CONTENT_TYPE = 'protected-content-type' + CONTENT_ID = "content-id" + CONTENT_ENCRYPTION_KEY = "content-encryption-key" + PROTECTED_CONTENT_LOCATION = "protected-content-location" + PROTECTED_CONTENT_LENGTH = "protected-content-length" + PROTECTED_CONTENT_SHA256 = "protected-content-sha256" + PROTECTED_CONTENT_DISPOSITION = "protected-content-disposition" + PROTECTED_CONTENT_TYPE = "protected-content-type" def __init__( - self, - content_id, - content_encryption_key, - protected_content_location, - protected_content_disposition, - protected_content_type, - protected_content_length, - protected_content_sha256): + self, + content_id, + content_encryption_key, + protected_content_location, + protected_content_disposition, + protected_content_type, + protected_content_length, + protected_content_sha256, + ): """Initializes a new instance of LCPEncryptorResult class :param: content_id: Content identifier @@ -176,7 +180,9 @@ def from_dict(cls, result_dict): protected_content_location = result_dict.get(cls.PROTECTED_CONTENT_LOCATION) protected_content_length = result_dict.get(cls.PROTECTED_CONTENT_LENGTH) protected_content_sha256 = result_dict.get(cls.PROTECTED_CONTENT_SHA256) - protected_content_disposition = result_dict.get(cls.PROTECTED_CONTENT_DISPOSITION) + protected_content_disposition = result_dict.get( + cls.PROTECTED_CONTENT_DISPOSITION + ) protected_content_type = result_dict.get(cls.PROTECTED_CONTENT_TYPE) return cls( @@ -186,7 +192,7 @@ def from_dict(cls, result_dict): protected_content_disposition=protected_content_disposition, protected_content_type=protected_content_type, protected_content_length=protected_content_length, - protected_content_sha256=protected_content_sha256 + protected_content_sha256=protected_content_sha256, ) def __eq__(self, other): @@ -201,14 +207,16 @@ def __eq__(self, other): if not isinstance(other, LCPEncryptionResult): return False - return \ - self.content_id == other.content_id and \ - self.content_encryption_key == other.content_encryption_key and \ - self.protected_content_location == other.protected_content_location and \ - self.protected_content_length == other.protected_content_length and \ - self.protected_content_sha256 == other.protected_content_sha256 and \ - self.protected_content_disposition == other.protected_content_disposition and \ - self.protected_content_type == other.protected_content_type + return ( + self.content_id == other.content_id + and self.content_encryption_key == other.content_encryption_key + and self.protected_content_location == other.protected_content_location + and self.protected_content_length == other.protected_content_length + and self.protected_content_sha256 == other.protected_content_sha256 + and self.protected_content_disposition + == other.protected_content_disposition + and self.protected_content_type == other.protected_content_type + ) def __repr__(self): """Returns a string representation of a LCPEncryptorResult object @@ -216,23 +224,24 @@ def __repr__(self): :return: string representation of a LCPEncryptorResult object :rtype: string """ - return \ - ''.format( + return ( + "".format( self.content_id, self.content_encryption_key, self.protected_content_location, self.protected_content_length, self.protected_content_sha256, self.protected_content_disposition, - self.protected_content_type + self.protected_content_type, ) + ) class LCPEncryptorResultJSONEncoder(JSONEncoder): @@ -248,16 +257,16 @@ def default(self, result): :rtype: string """ if not isinstance(result, LCPEncryptionResult): - raise ValueError('result must have type LCPEncryptorResult') + raise ValueError("result must have type LCPEncryptorResult") result = { - 'content-id': result.content_id, - 'content-encryption-key': result.content_encryption_key, - 'protected-content-location': result.protected_content_location, - 'protected-content-length': result.protected_content_length, - 'protected-content-sha256': result.protected_content_sha256, - 'protected-content-disposition': result.protected_content_disposition, - 'protected-content-type': result.protected_content_type + "content-id": result.content_id, + "content-encryption-key": result.content_encryption_key, + "protected-content-location": result.protected_content_location, + "protected-content-length": result.protected_content_length, + "protected-content-sha256": result.protected_content_sha256, + "protected-content-disposition": result.protected_content_disposition, + "protected-content-type": result.protected_content_type, } return result @@ -296,7 +305,7 @@ def __init__(self, file_path, identifier, configuration): output_directory, identifier + target_extension if target_extension not in identifier - else identifier + else identifier, ) self._output_file_path = output_file_path @@ -345,21 +354,18 @@ def to_array(self): """ parameters = [ self._lcpencrypt_location, - '-input', + "-input", self._input_file_path, - '-contentid', - self._content_id + "-contentid", + self._content_id, ] if self._output_file_path: - parameters.extend([ - '-output', - self._output_file_path - ]) + parameters.extend(["-output", self._output_file_path]) return parameters - OUTPUT_REGEX = re.compile(r'(\{.+\})?(.+)', re.DOTALL) + OUTPUT_REGEX = re.compile(r"(\{.+\})?(.+)", re.DOTALL) def __init__(self, configuration_storage, configuration_factory): """Initializes a new instance of LCPEncryptor class @@ -394,7 +400,7 @@ def _parse_output(self, output): :return: Encryption result :rtype: LCPEncryptionResult """ - bracket_index = output.find('{') + bracket_index = output.find("{") if bracket_index > 0: output = output[bracket_index:] @@ -402,12 +408,12 @@ def _parse_output(self, output): match = self.OUTPUT_REGEX.match(output) if not match: - raise LCPEncryptionException('Output has a wrong format') + raise LCPEncryptionException("Output has a wrong format") match_groups = match.groups() if not match_groups: - raise LCPEncryptionException('Output has a wrong format') + raise LCPEncryptionException("Output has a wrong format") if not match_groups[0]: raise LCPEncryptionException(match_groups[1].strip()) @@ -416,10 +422,12 @@ def _parse_output(self, output): json_result = json.loads(json_output) result = LCPEncryptionResult.from_dict(json_result) - if not result.protected_content_length or \ - not result.protected_content_sha256 or \ - not result.content_encryption_key: - raise LCPEncryptionException('Encryption failed') + if ( + not result.protected_content_length + or not result.protected_content_sha256 + or not result.content_encryption_key + ): + raise LCPEncryptionException("Encryption failed") return result @@ -439,7 +447,7 @@ def _run_lcpencrypt_locally(self, file_path, identifier, configuration): :rtype: LCPEncryptionResult """ self._logger.info( - 'Started running a local lcpencrypt binary. File path: {0}. Identifier: {1}'.format( + "Started running a local lcpencrypt binary. File path: {0}. Identifier: {1}".format( file_path, identifier ) ) @@ -448,26 +456,40 @@ def _run_lcpencrypt_locally(self, file_path, identifier, configuration): try: if parameters.output_file_path: - self._logger.info('Creating a directory tree for {0}'.format(parameters.output_file_path)) + self._logger.info( + "Creating a directory tree for {0}".format( + parameters.output_file_path + ) + ) output_directory = os.path.dirname(parameters.output_file_path) if not os.path.exists(output_directory): os.makedirs(output_directory) - self._logger.info('Directory tree {0} has been successfully created'.format(output_directory)) + self._logger.info( + "Directory tree {0} has been successfully created".format( + output_directory + ) + ) - self._logger.info('Running lcpencrypt using the following parameters: {0}'.format(parameters.to_array())) + self._logger.info( + "Running lcpencrypt using the following parameters: {0}".format( + parameters.to_array() + ) + ) output = subprocess.check_output(parameters.to_array()) result = self._parse_output(output) except Exception as exception: - self._logger.exception('An unhandled exception occurred during running a local lcpencrypt binary') + self._logger.exception( + "An unhandled exception occurred during running a local lcpencrypt binary" + ) raise LCPEncryptionException(str(exception), inner_exception=exception) self._logger.info( - 'Finished running a local lcpencrypt binary. File path: {0}. Identifier: {1}. Result: {2}'.format( + "Finished running a local lcpencrypt binary. File path: {0}. Identifier: {1}. Result: {2}".format( file_path, identifier, result ) ) @@ -490,9 +512,12 @@ def encrypt(self, db, file_path, identifier): :rtype: LCPEncryptionResult """ with self._configuration_factory.create( - self._configuration_storage, db, LCPEncryptionConfiguration) as configuration: + self._configuration_storage, db, LCPEncryptionConfiguration + ) as configuration: if self._lcpencrypt_exists_locally(configuration): - result = self._run_lcpencrypt_locally(file_path, identifier, configuration) + result = self._run_lcpencrypt_locally( + file_path, identifier, configuration + ) return result else: diff --git a/api/lcp/factory.py b/api/lcp/factory.py index b8784b7ad3..fcdbccadf6 100644 --- a/api/lcp/factory.py +++ b/api/lcp/factory.py @@ -1,7 +1,7 @@ from api.lcp.hash import HasherFactory from api.lcp.server import LCPServer from core.lcp.credential import LCPCredentialFactory -from core.model.configuration import ConfigurationStorage, ConfigurationFactory +from core.model.configuration import ConfigurationFactory, ConfigurationStorage class LCPServerFactory(object): @@ -20,6 +20,11 @@ def create(self, integration_association): configuration_factory = ConfigurationFactory() hasher_factory = HasherFactory() credential_factory = LCPCredentialFactory() - lcp_server = LCPServer(configuration_storage, configuration_factory, hasher_factory, credential_factory) + lcp_server = LCPServer( + configuration_storage, + configuration_factory, + hasher_factory, + credential_factory, + ) return lcp_server diff --git a/api/lcp/hash.py b/api/lcp/hash.py index 0d538ceb31..f90eb6957f 100644 --- a/api/lcp/hash.py +++ b/api/lcp/hash.py @@ -1,14 +1,13 @@ import hashlib from abc import ABCMeta, abstractmethod - from enum import Enum from core.exceptions import BaseError class HashingAlgorithm(Enum): - SHA256 = 'http://www.w3.org/2001/04/xmlenc#sha256' - SHA512 = 'http://www.w3.org/2001/04/xmlenc#sha512' + SHA256 = "http://www.w3.org/2001/04/xmlenc#sha256" + SHA512 = "http://www.w3.org/2001/04/xmlenc#sha512" class HashingError(BaseError): @@ -33,12 +32,20 @@ def hash(self, value): class UniversalHasher(Hasher): def hash(self, value): - if self._hashing_algorithm in [HashingAlgorithm.SHA256, HashingAlgorithm.SHA256.value]: + if self._hashing_algorithm in [ + HashingAlgorithm.SHA256, + HashingAlgorithm.SHA256.value, + ]: return hashlib.sha256(value.encode("utf-8")).hexdigest() - elif self._hashing_algorithm in [HashingAlgorithm.SHA512, HashingAlgorithm.SHA512.value]: + elif self._hashing_algorithm in [ + HashingAlgorithm.SHA512, + HashingAlgorithm.SHA512.value, + ]: return hashlib.sha512(value.encode("utf-8")).hexdigest() else: - raise HashingError('Unknown hashing algorithm {0}'.format(self._hashing_algorithm)) + raise HashingError( + "Unknown hashing algorithm {0}".format(self._hashing_algorithm) + ) class HasherFactory(object): diff --git a/api/lcp/mirror.py b/api/lcp/mirror.py index acd18e7ee6..9c1745663c 100644 --- a/api/lcp/mirror.py +++ b/api/lcp/mirror.py @@ -9,22 +9,26 @@ from api.lcp.server import LCPServer from core.lcp.credential import LCPCredentialFactory from core.mirror import MirrorUploader -from core.model import ExternalIntegration, Collection -from core.model.collection import HasExternalIntegrationPerCollection, CollectionConfigurationStorage -from core.model.configuration import ConfigurationAttributeType, \ - ConfigurationMetadata, ConfigurationFactory -from core.s3 import MinIOUploader, S3UploaderConfiguration, MinIOUploaderConfiguration +from core.model import Collection, ExternalIntegration +from core.model.collection import ( + CollectionConfigurationStorage, + HasExternalIntegrationPerCollection, +) +from core.model.configuration import ( + ConfigurationAttributeType, + ConfigurationFactory, + ConfigurationMetadata, +) +from core.s3 import MinIOUploader, MinIOUploaderConfiguration, S3UploaderConfiguration class LCPMirrorConfiguration(S3UploaderConfiguration): endpoint_url = ConfigurationMetadata( key=MinIOUploaderConfiguration.endpoint_url.key, - label=_('Endpoint URL'), - description=_( - 'S3 endpoint URL' - ), + label=_("Endpoint URL"), + description=_("S3 endpoint URL"), type=ConfigurationAttributeType.TEXT, - required=False + required=False, ) @@ -44,7 +48,7 @@ class LCPMirror(MinIOUploader, HasExternalIntegrationPerCollection): S3UploaderConfiguration.s3_addressing_style.to_settings(), S3UploaderConfiguration.s3_presigned_url_expiration.to_settings(), S3UploaderConfiguration.url_template.to_settings(), - LCPMirrorConfiguration.endpoint_url.to_settings() + LCPMirrorConfiguration.endpoint_url.to_settings(), ] def __init__(self, integration): @@ -71,7 +75,12 @@ def _create_lcp_importer(self, collection): hasher_factory = HasherFactory() credential_factory = LCPCredentialFactory() lcp_encryptor = LCPEncryptor(configuration_storage, configuration_factory) - lcp_server = LCPServer(configuration_storage, configuration_factory, hasher_factory, credential_factory) + lcp_server = LCPServer( + configuration_storage, + configuration_factory, + hasher_factory, + credential_factory, + ) lcp_importer = LCPImporter(lcp_encryptor, lcp_server) return lcp_importer @@ -86,13 +95,12 @@ def collection_external_integration(self, collection): :rtype: core.model.configuration.ExternalIntegration """ db = Session.object_session(collection) - external_integration = db \ - .query(ExternalIntegration) \ - .join(Collection) \ - .filter( - Collection.id == collection.id - ) \ + external_integration = ( + db.query(ExternalIntegration) + .join(Collection) + .filter(Collection.id == collection.id) .one() + ) return external_integration @@ -102,11 +110,20 @@ def cover_image_root(self, bucket, data_source, scaled_size=None): def marc_file_root(self, bucket, library): raise NotImplementedError() - def book_url(self, identifier, extension='.epub', open_access=False, data_source=None, title=None): + def book_url( + self, + identifier, + extension=".epub", + open_access=False, + data_source=None, + title=None, + ): """Returns the path to the hosted EPUB file for the given identifier.""" bucket = self.get_bucket( - S3UploaderConfiguration.OA_CONTENT_BUCKET_KEY if open_access - else S3UploaderConfiguration.PROTECTED_CONTENT_BUCKET_KEY) + S3UploaderConfiguration.OA_CONTENT_BUCKET_KEY + if open_access + else S3UploaderConfiguration.PROTECTED_CONTENT_BUCKET_KEY + ) root = self.content_root(bucket) book_url = root + self.key_join([identifier.identifier]) @@ -133,11 +150,13 @@ def mirror_one(self, representation, mirror_to, collection=None): db = Session.object_session(representation) bucket = self.get_bucket(S3UploaderConfiguration.PROTECTED_CONTENT_BUCKET_KEY) content_root = self.content_root(bucket) - identifier = mirror_to.replace(content_root, '') + identifier = mirror_to.replace(content_root, "") lcp_importer = self._create_lcp_importer(collection) # First, we need to copy unencrypted book's content to a temporary file - with tempfile.NamedTemporaryFile(suffix=representation.extension(representation.media_type)) as temporary_file: + with tempfile.NamedTemporaryFile( + suffix=representation.extension(representation.media_type) + ) as temporary_file: temporary_file.write(representation.content_fh().read()) temporary_file.flush() diff --git a/api/lcp/server.py b/api/lcp/server.py index 556620cce5..c418f47fd5 100644 --- a/api/lcp/server.py +++ b/api/lcp/server.py @@ -9,111 +9,125 @@ from api.lcp import utils from api.lcp.encrypt import LCPEncryptionResult, LCPEncryptorResultJSONEncoder from api.lcp.hash import HashingAlgorithm -from core.model.configuration import ConfigurationGrouping, ConfigurationMetadata, ConfigurationAttributeType, \ - ConfigurationOption +from core.model.configuration import ( + ConfigurationAttributeType, + ConfigurationGrouping, + ConfigurationMetadata, + ConfigurationOption, +) class LCPServerConfiguration(ConfigurationGrouping): """Contains LCP License Server's settings""" DEFAULT_PAGE_SIZE = 100 - DEFAULT_PASSPHRASE_HINT = 'If you do not remember your passphrase, please contact your administrator' + DEFAULT_PASSPHRASE_HINT = ( + "If you do not remember your passphrase, please contact your administrator" + ) DEFAULT_ENCRYPTION_ALGORITHM = HashingAlgorithm.SHA256.value lcpserver_url = ConfigurationMetadata( - key='lcpserver_url', - label=_('LCP License Server\'s URL'), - description=_('URL of the LCP License Server'), + key="lcpserver_url", + label=_("LCP License Server's URL"), + description=_("URL of the LCP License Server"), type=ConfigurationAttributeType.TEXT, - required=True + required=True, ) lcpserver_user = ConfigurationMetadata( - key='lcpserver_user', - label=_('LCP License Server\'s user'), - description=_('Name of the user used to connect to the LCP License Server'), + key="lcpserver_user", + label=_("LCP License Server's user"), + description=_("Name of the user used to connect to the LCP License Server"), type=ConfigurationAttributeType.TEXT, - required=True + required=True, ) lcpserver_password = ConfigurationMetadata( - key='lcpserver_password', - label=_('LCP License Server\'s password'), - description=_('Password of the user used to connect to the LCP License Server'), + key="lcpserver_password", + label=_("LCP License Server's password"), + description=_("Password of the user used to connect to the LCP License Server"), type=ConfigurationAttributeType.TEXT, - required=True + required=True, ) lcpserver_input_directory = ConfigurationMetadata( - key='lcpserver_input_directory', - label=_('LCP License Server\'s input directory'), + key="lcpserver_input_directory", + label=_("LCP License Server's input directory"), description=_( - 'Full path to the directory containing encrypted books. ' - 'This directory should be the same as lcpencrypt\'s output directory' + "Full path to the directory containing encrypted books. " + "This directory should be the same as lcpencrypt's output directory" ), type=ConfigurationAttributeType.TEXT, - required=True + required=True, ) lcpserver_page_size = ConfigurationMetadata( - key='lcpserver_page_size', - label=_('LCP License Server\'s page size'), - description=_('Number of licences returned by the server'), + key="lcpserver_page_size", + label=_("LCP License Server's page size"), + description=_("Number of licences returned by the server"), type=ConfigurationAttributeType.NUMBER, required=False, - default=DEFAULT_PAGE_SIZE + default=DEFAULT_PAGE_SIZE, ) provider_name = ConfigurationMetadata( - key='provider_name', - label=_('LCP service provider\'s identifier'), - description=_( - 'URI that identifies the provider in an unambiguous way' - ), + key="provider_name", + label=_("LCP service provider's identifier"), + description=_("URI that identifies the provider in an unambiguous way"), type=ConfigurationAttributeType.TEXT, - required=True + required=True, ) passphrase_hint = ConfigurationMetadata( - key='passphrase_hint', - label=_('Passphrase hint'), - description=_('Hint proposed to the user for selecting their passphrase'), + key="passphrase_hint", + label=_("Passphrase hint"), + description=_("Hint proposed to the user for selecting their passphrase"), type=ConfigurationAttributeType.TEXT, required=False, - default=DEFAULT_PASSPHRASE_HINT + default=DEFAULT_PASSPHRASE_HINT, ) encryption_algorithm = ConfigurationMetadata( - key='encryption_algorithm', - label=_('Passphrase encryption algorithm'), - description=_('Algorithm used for encrypting the passphrase'), + key="encryption_algorithm", + label=_("Passphrase encryption algorithm"), + description=_("Algorithm used for encrypting the passphrase"), type=ConfigurationAttributeType.SELECT, required=False, default=DEFAULT_ENCRYPTION_ALGORITHM, - options=ConfigurationOption.from_enum(HashingAlgorithm) + options=ConfigurationOption.from_enum(HashingAlgorithm), ) max_printable_pages = ConfigurationMetadata( - key='max_printable_pages', - label=_('Maximum number or printable pages'), - description=_('Maximum number of pages that can be printed over the lifetime of the license'), + key="max_printable_pages", + label=_("Maximum number or printable pages"), + description=_( + "Maximum number of pages that can be printed over the lifetime of the license" + ), type=ConfigurationAttributeType.NUMBER, - required=False + required=False, ) max_copiable_pages = ConfigurationMetadata( - key='max_copiable_pages', - label=_('Maximum number or copiable characters'), - description=_('Maximum number of characters that can be copied to the clipboard'), + key="max_copiable_pages", + label=_("Maximum number or copiable characters"), + description=_( + "Maximum number of characters that can be copied to the clipboard" + ), type=ConfigurationAttributeType.NUMBER, - required=False + required=False, ) class LCPServer(object): """Wrapper around LCP License Server's API""" - def __init__(self, configuration_storage, configuration_factory, hasher_factory, credential_factory): + def __init__( + self, + configuration_storage, + configuration_factory, + hasher_factory, + credential_factory, + ): """Initializes a new instance of LCPServer class :param configuration_storage: ConfigurationStorage object @@ -145,11 +159,14 @@ def _get_hasher(self, configuration): """ if self._hasher_instance is None: self._hasher_instance = self._hasher_factory.create( - configuration.encryption_algorithm) + configuration.encryption_algorithm + ) return self._hasher_instance - def _create_partial_license(self, db, configuration, patron, license_start=None, license_end=None): + def _create_partial_license( + self, db, configuration, patron, license_start=None, license_end=None + ): """Creates a partial LCP license used an input by the LCP License Server for generation of LCP licenses :param configuration: Configuration object @@ -168,39 +185,56 @@ def _create_partial_license(self, db, configuration, patron, license_start=None, :rtype: Dict """ hasher = self._get_hasher(configuration) - hashed_passphrase = hasher.hash(self._credential_factory.get_patron_passphrase(db, patron)) + hashed_passphrase = hasher.hash( + self._credential_factory.get_patron_passphrase(db, patron) + ) self._credential_factory.set_hashed_passphrase(db, patron, hashed_passphrase) partial_license = { - 'provider': configuration.provider_name, - 'encryption': { - 'user_key': { - 'text_hint': configuration.passphrase_hint, - 'hex_value': hashed_passphrase, + "provider": configuration.provider_name, + "encryption": { + "user_key": { + "text_hint": configuration.passphrase_hint, + "hex_value": hashed_passphrase, } - } + }, } if patron: - partial_license['user'] = { - 'id': self._credential_factory.get_patron_id(db, patron) + partial_license["user"] = { + "id": self._credential_factory.get_patron_id(db, patron) } rights_fields = [ - license_start, license_end, configuration.max_printable_pages, configuration.max_copiable_pages] - - if any([rights_field is not None and rights_field != '' for rights_field in rights_fields]): - partial_license['rights'] = {} + license_start, + license_end, + configuration.max_printable_pages, + configuration.max_copiable_pages, + ] + + if any( + [ + rights_field is not None and rights_field != "" + for rights_field in rights_fields + ] + ): + partial_license["rights"] = {} if license_start: - partial_license['rights']['start'] = utils.format_datetime(license_start) + partial_license["rights"]["start"] = utils.format_datetime(license_start) if license_end: - partial_license['rights']['end'] = utils.format_datetime(license_end) - if configuration.max_printable_pages is not None and configuration.max_printable_pages != '': - partial_license['rights']['print'] = int(configuration.max_printable_pages) - if configuration.max_copiable_pages is not None and configuration.max_copiable_pages != '': - partial_license['rights']['copy'] = int(configuration.max_copiable_pages) + partial_license["rights"]["end"] = utils.format_datetime(license_end) + if ( + configuration.max_printable_pages is not None + and configuration.max_printable_pages != "" + ): + partial_license["rights"]["print"] = int(configuration.max_printable_pages) + if ( + configuration.max_copiable_pages is not None + and configuration.max_copiable_pages != "" + ): + partial_license["rights"]["copy"] = int(configuration.max_copiable_pages) return partial_license @@ -229,11 +263,10 @@ def _send_request(configuration, method, path, payload, json_encoder=None): method, url, data=json_payload, - headers={'Content-Type': 'application/json'}, + headers={"Content-Type": "application/json"}, auth=HTTPBasicAuth( - configuration.lcpserver_user, - configuration.lcpserver_password - ) + configuration.lcpserver_user, configuration.lcpserver_password + ), ) response.raise_for_status() @@ -250,9 +283,12 @@ def add_content(self, db, encrypted_content): :type encrypted_content: LCPEncryptionResult """ with self._configuration_factory.create( - self._configuration_storage, db, LCPServerConfiguration) as configuration: + self._configuration_storage, db, LCPServerConfiguration + ) as configuration: content_location = os.path.join( - configuration.lcpserver_input_directory, encrypted_content.protected_content_disposition) + configuration.lcpserver_input_directory, + encrypted_content.protected_content_disposition, + ) payload = LCPEncryptionResult( content_id=encrypted_content.content_id, content_encryption_key=encrypted_content.content_encryption_key, @@ -260,11 +296,13 @@ def add_content(self, db, encrypted_content): protected_content_disposition=encrypted_content.protected_content_disposition, protected_content_type=encrypted_content.protected_content_type, protected_content_length=encrypted_content.protected_content_length, - protected_content_sha256=encrypted_content.protected_content_sha256 + protected_content_sha256=encrypted_content.protected_content_sha256, ) - path = '/contents/{0}'.format(encrypted_content.content_id) + path = "/contents/{0}".format(encrypted_content.content_id) - self._send_request(configuration, 'put', path, payload, LCPEncryptorResultJSONEncoder) + self._send_request( + configuration, "put", path, payload, LCPEncryptorResultJSONEncoder + ) def generate_license(self, db, content_id, patron, license_start, license_end): """Generates a new LCP license @@ -291,11 +329,15 @@ def generate_license(self, db, content_id, patron, license_start, license_end): :rtype: Dict """ with self._configuration_factory.create( - self._configuration_storage, db, LCPServerConfiguration) as configuration: + self._configuration_storage, db, LCPServerConfiguration + ) as configuration: partial_license_payload = self._create_partial_license( - db, configuration, patron, license_start, license_end) - path = 'contents/{0}/license'.format(content_id) - response = self._send_request(configuration, 'post', path, partial_license_payload) + db, configuration, patron, license_start, license_end + ) + path = "contents/{0}/license".format(content_id) + response = self._send_request( + configuration, "post", path, partial_license_payload + ) return response.json() @@ -315,10 +357,15 @@ def get_license(self, db, license_id, patron): :rtype: string """ with self._configuration_factory.create( - self._configuration_storage, db, LCPServerConfiguration) as configuration: - partial_license_payload = self._create_partial_license(db, configuration, patron) - path = 'licenses/{0}'.format(license_id) + self._configuration_storage, db, LCPServerConfiguration + ) as configuration: + partial_license_payload = self._create_partial_license( + db, configuration, patron + ) + path = "licenses/{0}".format(license_id) - response = self._send_request(configuration, 'post', path, partial_license_payload) + response = self._send_request( + configuration, "post", path, partial_license_payload + ) return response.json() diff --git a/api/lcp/utils.py b/api/lcp/utils.py index 557bce1c0e..290b4f87fe 100644 --- a/api/lcp/utils.py +++ b/api/lcp/utils.py @@ -10,24 +10,24 @@ def format_datetime(datetime_value): :return: String representation of the datetime value :rtype: string """ - datetime_string_value = datetime_value.strftime('%Y-%m-%dT%H:%M:%S') + datetime_string_value = datetime_value.strftime("%Y-%m-%dT%H:%M:%S") # NOTE: Go can parse only strings where the timezone contains a colon (e.g., -07:00) # Unfortunately, Python doesn't support such format and we have to do it manually # We assume that all the dates are in UTC - datetime_string_value += '+00:00' + datetime_string_value += "+00:00" return datetime_string_value def get_target_extension(input_extension): - if input_extension == '.epub': - target_extension = '.epub' - elif input_extension == '.pdf': - target_extension = '.lcpdf' - elif input_extension == '.lpf': + if input_extension == ".epub": + target_extension = ".epub" + elif input_extension == ".pdf": + target_extension = ".lcpdf" + elif input_extension == ".lpf": target_extension = ".audiobook" - elif input_extension == '.audiobook': + elif input_extension == ".audiobook": target_extension = ".audiobook" else: raise LCPError('Unknown extension "{0}"'.format(input_extension)) diff --git a/api/local_analytics_exporter.py b/api/local_analytics_exporter.py index 8637b179ca..47427702af 100644 --- a/api/local_analytics_exporter.py +++ b/api/local_analytics_exporter.py @@ -1,18 +1,9 @@ import logging -import unicodecsv as csv from io import BytesIO -from sqlalchemy.sql import ( - func, - select, -) -from sqlalchemy.sql.expression import ( - and_, - case, - literal_column, - join, - or_, -) +import unicodecsv as csv +from sqlalchemy.sql import func, select +from sqlalchemy.sql.expression import and_, case, join, literal_column, or_ from core.model import ( CirculationEvent, @@ -24,6 +15,7 @@ WorkGenre, ) + class LocalAnalyticsExporter(object): """Export large numbers of analytics events in CSV format.""" @@ -35,9 +27,20 @@ def export(self, _db, start, end, locations=None, library=None): # Write the CSV file to a BytesIO. header = [ - "time", "event", "identifier", "identifier_type", "title", "author", - "fiction", "audience", "publisher", "imprint", "language", - "target_age", "genres", "location" + "time", + "event", + "identifier", + "identifier_type", + "title", + "author", + "fiction", + "audience", + "publisher", + "imprint", + "language", + "target_age", + "genres", + "location", ] output = BytesIO() writer = csv.writer(output, encoding="utf-8") @@ -45,7 +48,7 @@ def export(self, _db, start, end, locations=None, library=None): writer.writerows(results) return output.getvalue().decode("utf-8") - def analytics_query(self, start, end, locations=None, library=None): + def analytics_query(self, start, end, locations=None, library=None): """Build a database query that fetches rows of analytics data. This method uses low-level SQLAlchemy code to do all @@ -66,7 +69,7 @@ def analytics_query(self, start, end, locations=None, library=None): event_types = [ CirculationEvent.CM_CHECKOUT, CirculationEvent.CM_FULFILL, - CirculationEvent.OPEN_BOOK + CirculationEvent.OPEN_BOOK, ] locations = locations.strip().split(",") @@ -76,52 +79,48 @@ def analytics_query(self, start, end, locations=None, library=None): ] if library: - clauses += [ - CirculationEvent.library == library - ] + clauses += [CirculationEvent.library == library] # Build the primary query. This is a query against the # CirculationEvent table and a few other tables joined against # it. This makes up the bulk of the data. - events_alias = select( - [ - func.to_char( - CirculationEvent.start, "YYYY-MM-DD HH24:MI:SS" - ).label("start"), - CirculationEvent.type.label("event_type"), - Identifier.identifier, - Identifier.type.label("identifier_type"), - Edition.sort_title, - Edition.sort_author, - case( - [(Work.fiction==True, literal_column("'fiction'"))], - else_=literal_column("'nonfiction'") - ).label("fiction"), - Work.id.label("work_id"), - Work.audience, - Edition.publisher, - Edition.imprint, - Edition.language, - CirculationEvent.location, - ], - ).select_from( - join( - CirculationEvent, LicensePool, - CirculationEvent.license_pool_id==LicensePool.id - ).join( - Identifier, - LicensePool.identifier_id==Identifier.id - ).join( - Work, - Work.id==LicensePool.work_id - ).join( - Edition, Work.presentation_edition_id==Edition.id + events_alias = ( + select( + [ + func.to_char(CirculationEvent.start, "YYYY-MM-DD HH24:MI:SS").label( + "start" + ), + CirculationEvent.type.label("event_type"), + Identifier.identifier, + Identifier.type.label("identifier_type"), + Edition.sort_title, + Edition.sort_author, + case( + [(Work.fiction == True, literal_column("'fiction'"))], + else_=literal_column("'nonfiction'"), + ).label("fiction"), + Work.id.label("work_id"), + Work.audience, + Edition.publisher, + Edition.imprint, + Edition.language, + CirculationEvent.location, + ], ) - ).where( - and_(*clauses) - ).order_by( - CirculationEvent.start.asc() - ).alias("events_alias") + .select_from( + join( + CirculationEvent, + LicensePool, + CirculationEvent.license_pool_id == LicensePool.id, + ) + .join(Identifier, LicensePool.identifier_id == Identifier.id) + .join(Work, Work.id == LicensePool.work_id) + .join(Edition, Work.presentation_edition_id == Edition.id) + ) + .where(and_(*clauses)) + .order_by(CirculationEvent.start.asc()) + .alias("events_alias") + ) # A subquery can hook into the main query by referencing its # 'work_id' field in its WHERE clause. @@ -136,29 +135,20 @@ def analytics_query(self, start, end, locations=None, library=None): # This Alias selects some number of rows, each containing one # string column (Genre.name). Genres with higher affinities with # this work go first. - genres_alias = select( - [Genre.name.label("genre_name")] - ).select_from( - join( - WorkGenre, Genre, - WorkGenre.genre_id==Genre.id - ) - ).where( - WorkGenre.work_id==work_id_column - ).order_by( - WorkGenre.affinity.desc(), Genre.name - ).alias("genres_subquery") + genres_alias = ( + select([Genre.name.label("genre_name")]) + .select_from(join(WorkGenre, Genre, WorkGenre.genre_id == Genre.id)) + .where(WorkGenre.work_id == work_id_column) + .order_by(WorkGenre.affinity.desc(), Genre.name) + .alias("genres_subquery") + ) # Use array_agg() to consolidate the rows into one row -- this # gives us a single value, an array of strings, for each # Work. Then use array_to_string to convert the array into a # single comma-separated string. genres = select( - [ - func.array_to_string( - func.array_agg(genres_alias.c.genre_name), "," - ) - ] + [func.array_to_string(func.array_agg(genres_alias.c.genre_name), ",")] ).select_from(genres_alias) # This subquery gets the a Work's target age as a single string. @@ -167,25 +157,25 @@ def analytics_query(self, start, end, locations=None, library=None): # This Alias selects two fields: the lower and upper bounds of # the Work's target age. This reuses code originally written # for Work.to_search_documents(). - target_age = Work.target_age_query(work_id_column).alias( - "target_age_subquery" - ) + target_age = Work.target_age_query(work_id_column).alias("target_age_subquery") # Concatenate the lower and upper bounds with a dash in the # middle. If both lower and upper bound are empty, just give # the empty string. This simulates the behavior of # Work.target_age_string. - target_age_string = select([ - case( - [ - (or_(target_age.c.lower != None, - target_age.c.upper != None), - func.concat(target_age.c.lower, "-", target_age.c.upper)) - ], - else_=literal_column("''") - ) - ]).select_from(target_age) - + target_age_string = select( + [ + case( + [ + ( + or_(target_age.c.lower != None, target_age.c.upper != None), + func.concat(target_age.c.lower, "-", target_age.c.upper), + ) + ], + else_=literal_column("''"), + ) + ] + ).select_from(target_age) # Build the main query out of the subqueries. events = events_alias.c @@ -202,11 +192,9 @@ def analytics_query(self, start, end, locations=None, library=None): events.publisher, events.imprint, events.language, - target_age_string.label('target_age'), - genres.label('genres'), + target_age_string.label("target_age"), + genres.label("genres"), events.location, ] - ).select_from( - events_alias - ) + ).select_from(events_alias) return query diff --git a/api/marc.py b/api/marc.py index 73941cd7e0..9c581e3f79 100644 --- a/api/marc.py +++ b/api/marc.py @@ -1,38 +1,52 @@ +import urllib.error +import urllib.parse +import urllib.request + from pymarc import Field -import urllib.request, urllib.parse, urllib.error from core.config import Configuration -from core.marc import ( - Annotator, - MARCExporter, -) -from core.model import ( - ConfigurationSetting, - Session, -) +from core.marc import Annotator, MARCExporter +from core.model import ConfigurationSetting, Session + class LibraryAnnotator(Annotator): def __init__(self, library): super(LibraryAnnotator, self).__init__() self.library = library _db = Session.object_session(library) - self.base_url = ConfigurationSetting.sitewide(_db, Configuration.BASE_URL_KEY).value + self.base_url = ConfigurationSetting.sitewide( + _db, Configuration.BASE_URL_KEY + ).value def value(self, key, integration): _db = Session.object_session(integration) return ConfigurationSetting.for_library_and_externalintegration( - _db, key, self.library, integration).value - + _db, key, self.library, integration + ).value - def annotate_work_record(self, work, active_license_pool, edition, - identifier, record, integration=None, updated=None): + def annotate_work_record( + self, + work, + active_license_pool, + edition, + identifier, + record, + integration=None, + updated=None, + ): super(LibraryAnnotator, self).annotate_work_record( - work, active_license_pool, edition, identifier, record, integration, updated) + work, active_license_pool, edition, identifier, record, integration, updated + ) if integration: marc_org = self.value(MARCExporter.MARC_ORGANIZATION_CODE, integration) - include_summary = (self.value(MARCExporter.INCLUDE_SUMMARY, integration) == "true") - include_genres = (self.value(MARCExporter.INCLUDE_SIMPLIFIED_GENRES, integration) == "true") + include_summary = ( + self.value(MARCExporter.INCLUDE_SUMMARY, integration) == "true" + ) + include_genres = ( + self.value(MARCExporter.INCLUDE_SIMPLIFIED_GENRES, integration) + == "true" + ) if marc_org: self.add_marc_organization_code(record, marc_org) @@ -55,14 +69,20 @@ def add_web_client_urls(self, record, library, identifier, integration=None): settings.append(marc_setting) from api.registry import Registration - settings += [s.value for s in _db.query( - ConfigurationSetting - ).filter( - ConfigurationSetting.key==Registration.LIBRARY_REGISTRATION_WEB_CLIENT, - ConfigurationSetting.library_id==library.id - ) if s.value] - qualified_identifier = urllib.parse.quote(identifier.type + "/" + identifier.identifier, safe='') + settings += [ + s.value + for s in _db.query(ConfigurationSetting).filter( + ConfigurationSetting.key + == Registration.LIBRARY_REGISTRATION_WEB_CLIENT, + ConfigurationSetting.library_id == library.id, + ) + if s.value + ] + + qualified_identifier = urllib.parse.quote( + identifier.type + "/" + identifier.identifier, safe="" + ) for web_client_base_url in settings: link = "{}/{}/works/{}".format( @@ -70,14 +90,12 @@ def add_web_client_urls(self, record, library, identifier, integration=None): library.short_name, qualified_identifier, ) - encoded_link = urllib.parse.quote(link, safe='') - url = "{}/book/{}".format( - web_client_base_url, - encoded_link - ) + encoded_link = urllib.parse.quote(link, safe="") + url = "{}/book/{}".format(web_client_base_url, encoded_link) record.add_field( Field( tag="856", indicators=["4", "0"], subfields=["u", url], - )) + ) + ) diff --git a/api/metadata_wrangler.py b/api/metadata_wrangler.py index 69baee7672..86aa47b854 100644 --- a/api/metadata_wrangler.py +++ b/api/metadata_wrangler.py @@ -1,17 +1,14 @@ # Code relating to the interaction between the circulation manager # and the metadata wrangler. import datetime -import feedparser from io import StringIO -from lxml import etree +import feedparser +from lxml import etree from sqlalchemy import and_, func, or_ -from sqlalchemy.orm import ( - aliased, - contains_eager, -) +from sqlalchemy.orm import aliased, contains_eager -from .config import CannotLoadConfiguration +from api.coverage import OPDSImportCoverageProvider, ReaperImporter, RegistrarImporter from core.coverage import CoverageFailure from core.metadata_layer import TimestampData from core.model import ( @@ -22,23 +19,13 @@ Session, Timestamp, ) -from core.monitor import ( - CollectionMonitor, -) +from core.monitor import CollectionMonitor from core.opds import AcquisitionFeed -from core.opds_import import ( - MetadataWranglerOPDSLookup, - OPDSImporter, - OPDSXMLParser, -) - +from core.opds_import import MetadataWranglerOPDSLookup, OPDSImporter, OPDSXMLParser from core.util.http import RemoteIntegrationException -from api.coverage import ( - OPDSImportCoverageProvider, - RegistrarImporter, - ReaperImporter, -) +from .config import CannotLoadConfiguration + class MetadataWranglerCollectionMonitor(CollectionMonitor): @@ -47,16 +34,16 @@ class MetadataWranglerCollectionMonitor(CollectionMonitor): """ def __init__(self, _db, collection, lookup=None): - super(MetadataWranglerCollectionMonitor, self).__init__( - _db, collection - ) + super(MetadataWranglerCollectionMonitor, self).__init__(_db, collection) self.lookup = lookup or MetadataWranglerOPDSLookup.from_config( self._db, collection=collection ) self.importer = OPDSImporter( - self._db, self.collection, + self._db, + self.collection, data_source_name=DataSource.METADATA_WRANGLER, - metadata_client=self.lookup, map_from_collection=True, + metadata_client=self.lookup, + map_from_collection=True, ) def get_response(self, url=None, **kwargs): @@ -69,8 +56,7 @@ def get_response(self, url=None, **kwargs): return response except RemoteIntegrationException as e: self.log.error( - "Error getting feed for %r: %s", - self.collection, e.debug_message + "Error getting feed for %r: %s", self.collection, e.debug_message ) raise e @@ -125,8 +111,7 @@ def run_once(self, progress): total_editions += len(editions) achievements = "Editions processed: %s" % total_editions if not new_timestamp or ( - possible_new_timestamp - and possible_new_timestamp > new_timestamp + possible_new_timestamp and possible_new_timestamp > new_timestamp ): # We imported an OPDS feed that included an entry with # a certain 'last updated' timestamp (or was empty but @@ -179,8 +164,9 @@ def import_one_feed(self, timestamp, url): # Import the metadata raw_feed = response.text - (editions, licensepools, - works, errors) = self.importer.import_from_feed(raw_feed) + (editions, licensepools, works, errors) = self.importer.import_from_feed( + raw_feed + ) # TODO: this oughtn't be necessary, because import_from_feed # already parsed the feed, but there's no way to access the @@ -198,9 +184,7 @@ def import_one_feed(self, timestamp, url): timestamp = min(update_dates) else: # Look for a timestamp on the feed level. - feed_timestamp = self.importer._datetime( - parsed['feed'], 'updated_parsed' - ) + feed_timestamp = self.importer._datetime(parsed["feed"], "updated_parsed") # Subtract one day from the time to reduce the chance of # race conditions. Otherwise, work done but not committed @@ -229,9 +213,7 @@ class MWAuxiliaryMetadataMonitor(MetadataWranglerCollectionMonitor): DEFAULT_START_TIME = CollectionMonitor.NEVER def __init__(self, _db, collection, lookup=None, provider=None): - super(MWAuxiliaryMetadataMonitor, self).__init__( - _db, collection, lookup=lookup - ) + super(MWAuxiliaryMetadataMonitor, self).__init__(_db, collection, lookup=lookup) self.parser = OPDSXMLParser() self.provider = provider or MetadataUploadCoverageProvider( collection, lookup_client=lookup @@ -258,8 +240,9 @@ def run_once(self, progress): # they have a presentation-ready work. (This prevents creating # CoverageRecords for identifiers that don't actually have metadata # to send.) - identifiers = [i for i in identifiers - if i.work and i.work.simple_opds_entry] + identifiers = [ + i for i in identifiers if i.work and i.work.simple_opds_entry + ] total_identifiers_processed += len(identifiers) self.provider.bulk_register(identifiers) self.provider.run_on_specific_identifiers(identifiers) @@ -384,8 +367,9 @@ def items_that_need_coverage(self, identifiers=None, **kwargs): # Start with all items in this Collection that have not been # registered. - uncovered = super(MetadataWranglerCollectionRegistrar, self)\ - .items_that_need_coverage(identifiers, **kwargs) + uncovered = super( + MetadataWranglerCollectionRegistrar, self + ).items_that_need_coverage(identifiers, **kwargs) # Make sure they're actually available through this # collection. uncovered = uncovered.filter( @@ -394,30 +378,38 @@ def items_that_need_coverage(self, identifiers=None, **kwargs): # Exclude items that have been reaped because we stopped # having a license. - reaper_covered = self._db.query(Identifier)\ - .join(Identifier.coverage_records)\ + reaper_covered = ( + self._db.query(Identifier) + .join(Identifier.coverage_records) .filter( - CoverageRecord.data_source_id==self.data_source.id, - CoverageRecord.collection_id==self.collection_id, - CoverageRecord.operation==CoverageRecord.REAP_OPERATION + CoverageRecord.data_source_id == self.data_source.id, + CoverageRecord.collection_id == self.collection_id, + CoverageRecord.operation == CoverageRecord.REAP_OPERATION, ) + ) # If any items were reaped earlier but have since been # relicensed or otherwise added back to the collection, remove # their reaper CoverageRecords. This ensures we get Metadata # Wrangler coverage for books that have had their licenses # repurchased or extended. - relicensed = reaper_covered.join(Identifier.licensed_through).filter( - LicensePool.collection_id==self.collection_id, - or_(LicensePool.licenses_owned > 0, LicensePool.open_access) - ).options(contains_eager(Identifier.coverage_records)) + relicensed = ( + reaper_covered.join(Identifier.licensed_through) + .filter( + LicensePool.collection_id == self.collection_id, + or_(LicensePool.licenses_owned > 0, LicensePool.open_access), + ) + .options(contains_eager(Identifier.coverage_records)) + ) needs_commit = False for identifier in relicensed.all(): for record in identifier.coverage_records: - if (record.data_source_id==self.data_source.id and - record.collection_id==self.collection_id and - record.operation==CoverageRecord.REAP_OPERATION): + if ( + record.data_source_id == self.data_source.id + and record.collection_id == self.collection_id + and record.operation == CoverageRecord.REAP_OPERATION + ): # Delete any reaper CoverageRecord for this Identifier # in this Collection. self._db.delete(record) @@ -446,16 +438,19 @@ def api_method(self): return self.lookup_client.remove def items_that_need_coverage(self, identifiers=None, **kwargs): - """Retrieves Identifiers that were imported but are no longer licensed. - """ - qu = self._db.query(Identifier).select_from(LicensePool).\ - join(LicensePool.identifier).join(CoverageRecord).\ - filter(LicensePool.collection_id==self.collection_id).\ - filter(LicensePool.licenses_owned==0, LicensePool.open_access!=True).\ - filter(CoverageRecord.data_source_id==self.data_source.id).\ - filter(CoverageRecord.operation==CoverageRecord.IMPORT_OPERATION).\ - filter(CoverageRecord.status==CoverageRecord.SUCCESS).\ - filter(CoverageRecord.collection==self.collection) + """Retrieves Identifiers that were imported but are no longer licensed.""" + qu = ( + self._db.query(Identifier) + .select_from(LicensePool) + .join(LicensePool.identifier) + .join(CoverageRecord) + .filter(LicensePool.collection_id == self.collection_id) + .filter(LicensePool.licenses_owned == 0, LicensePool.open_access != True) + .filter(CoverageRecord.data_source_id == self.data_source.id) + .filter(CoverageRecord.operation == CoverageRecord.IMPORT_OPERATION) + .filter(CoverageRecord.status == CoverageRecord.SUCCESS) + .filter(CoverageRecord.collection == self.collection) + ) if identifiers: qu = qu.filter(Identifier.id.in_([x.id for x in identifiers])) @@ -471,22 +466,22 @@ def finalize_batch(self): # 'import' CoverageRecords that have been obviated by a # 'reaper' coverage record for the same Identifier. reaper_coverage = aliased(CoverageRecord) - qu = self._db.query(CoverageRecord).join( - reaper_coverage, - CoverageRecord.identifier_id==reaper_coverage.identifier_id - - # The CoverageRecords were selecting are 'import' records. - ).filter( - CoverageRecord.data_source_id==self.data_source.id - ).filter( - CoverageRecord.operation==CoverageRecord.IMPORT_OPERATION - - # And we're only selecting them if there's also a 'reaper' - # coverage record. - ).filter( - reaper_coverage.data_source_id==self.data_source.id - ).filter( - reaper_coverage.operation==CoverageRecord.REAP_OPERATION + qu = ( + self._db.query(CoverageRecord) + .join( + reaper_coverage, + CoverageRecord.identifier_id == reaper_coverage.identifier_id + # The CoverageRecords were selecting are 'import' records. + ) + .filter(CoverageRecord.data_source_id == self.data_source.id) + .filter( + CoverageRecord.operation + == CoverageRecord.IMPORT_OPERATION + # And we're only selecting them if there's also a 'reaper' + # coverage record. + ) + .filter(reaper_coverage.data_source_id == self.data_source.id) + .filter(reaper_coverage.operation == CoverageRecord.REAP_OPERATION) ) # Delete all 'import' CoverageRecords that have been reaped. @@ -499,13 +494,14 @@ class MetadataUploadCoverageProvider(BaseMetadataWranglerCoverageProvider): """Provide coverage for identifiers by uploading OPDS metadata to the metadata wrangler. """ + DEFAULT_BATCH_SIZE = 25 SERVICE_NAME = "Metadata Upload Coverage Provider" OPERATION = CoverageRecord.METADATA_UPLOAD_OPERATION DATA_SOURCE_NAME = DataSource.INTERNAL_PROCESSING def __init__(self, *args, **kwargs): - kwargs['registered_only'] = kwargs.get('registered_only', True) + kwargs["registered_only"] = kwargs.get("registered_only", True) super(MetadataUploadCoverageProvider, self).__init__(*args, **kwargs) def process_batch(self, batch): diff --git a/api/millenium_patron.py b/api/millenium_patron.py index 297b024ec4..d2c7e78c9c 100644 --- a/api/millenium_patron.py +++ b/api/millenium_patron.py @@ -1,125 +1,150 @@ -import dateutil +import datetime import logging -from lxml import etree +import os +import re from urllib import parse -import datetime + +import dateutil import requests -from money import Money from flask_babel import lazy_gettext as _ +from lxml import etree +from money import Money -from core.util.datetime_helpers import ( - datetime_utc, - utc_now, -) -from core.util.xmlparser import XMLParser -from .authenticator import ( - BasicAuthenticationProvider, - PatronData, -) -from .config import ( - Configuration, - CannotLoadConfiguration, -) -import os -import re -from core.model import ( - get_one, - get_one_or_create, - ExternalIntegration, - Patron, -) -from core.util.http import HTTP +from core.model import ExternalIntegration, Patron, get_one, get_one_or_create from core.util import MoneyUtility +from core.util.datetime_helpers import datetime_utc, utc_now +from core.util.http import HTTP +from core.util.xmlparser import XMLParser + +from .authenticator import BasicAuthenticationProvider, PatronData +from .config import CannotLoadConfiguration, Configuration + class MilleniumPatronAPI(BasicAuthenticationProvider, XMLParser): NAME = "Millenium" - RECORD_NUMBER_FIELD = 'RECORD #[p81]' - PATRON_TYPE_FIELD = 'P TYPE[p47]' - EXPIRATION_FIELD = 'EXP DATE[p43]' - HOME_BRANCH_FIELD = 'HOME LIBR[p53]' - ADDRESS_FIELD = 'ADDRESS[pa]' - BARCODE_FIELD = 'P BARCODE[pb]' - USERNAME_FIELD = 'ALT ID[pu]' - FINES_FIELD = 'MONEY OWED[p96]' - BLOCK_FIELD = 'MBLOCK[p56]' - ERROR_MESSAGE_FIELD = 'ERRMSG' - PERSONAL_NAME_FIELD = 'PATRN NAME[pn]' - EMAIL_ADDRESS_FIELD = 'EMAIL ADDR[pz]' - EXPIRATION_DATE_FORMAT = '%m-%d-%y' - - MULTIVALUE_FIELDS = set(['NOTE[px]', BARCODE_FIELD]) + RECORD_NUMBER_FIELD = "RECORD #[p81]" + PATRON_TYPE_FIELD = "P TYPE[p47]" + EXPIRATION_FIELD = "EXP DATE[p43]" + HOME_BRANCH_FIELD = "HOME LIBR[p53]" + ADDRESS_FIELD = "ADDRESS[pa]" + BARCODE_FIELD = "P BARCODE[pb]" + USERNAME_FIELD = "ALT ID[pu]" + FINES_FIELD = "MONEY OWED[p96]" + BLOCK_FIELD = "MBLOCK[p56]" + ERROR_MESSAGE_FIELD = "ERRMSG" + PERSONAL_NAME_FIELD = "PATRN NAME[pn]" + EMAIL_ADDRESS_FIELD = "EMAIL ADDR[pz]" + EXPIRATION_DATE_FORMAT = "%m-%d-%y" + + MULTIVALUE_FIELDS = set(["NOTE[px]", BARCODE_FIELD]) DEFAULT_CURRENCY = "USD" # Identifiers that contain any of these strings are ignored when # finding the "correct" identifier in a patron's record, even if # it means they end up with no identifier at all. - IDENTIFIER_BLACKLIST = 'identifier_blacklist' + IDENTIFIER_BLACKLIST = "identifier_blacklist" # A configuration value for whether or not to validate the SSL certificate # of the Millenium Patron API server. VERIFY_CERTIFICATE = "verify_certificate" # The field to use when validating a patron's credential. - AUTHENTICATION_MODE = 'auth_mode' - PIN_AUTHENTICATION_MODE = 'pin' - FAMILY_NAME_AUTHENTICATION_MODE = 'family_name' - - NEIGHBORHOOD_MODE = 'neighborhood_mode' - NO_NEIGHBORHOOD_MODE = 'disabled' - HOME_BRANCH_NEIGHBORHOOD_MODE = 'home_branch' - POSTAL_CODE_NEIGHBORHOOD_MODE = 'postal_code' + AUTHENTICATION_MODE = "auth_mode" + PIN_AUTHENTICATION_MODE = "pin" + FAMILY_NAME_AUTHENTICATION_MODE = "family_name" + + NEIGHBORHOOD_MODE = "neighborhood_mode" + NO_NEIGHBORHOOD_MODE = "disabled" + HOME_BRANCH_NEIGHBORHOOD_MODE = "home_branch" + POSTAL_CODE_NEIGHBORHOOD_MODE = "postal_code" NEIGHBORHOOD_MODES = set( - [NO_NEIGHBORHOOD_MODE, HOME_BRANCH_NEIGHBORHOOD_MODE, POSTAL_CODE_NEIGHBORHOOD_MODE] + [ + NO_NEIGHBORHOOD_MODE, + HOME_BRANCH_NEIGHBORHOOD_MODE, + POSTAL_CODE_NEIGHBORHOOD_MODE, + ] ) # The field to use when seeing which values of MBLOCK[p56] mean a patron # is blocked. By default, any value other than '-' indicates a block. - BLOCK_TYPES = 'block_types' + BLOCK_TYPES = "block_types" - AUTHENTICATION_MODES = [ - PIN_AUTHENTICATION_MODE, FAMILY_NAME_AUTHENTICATION_MODE - ] + AUTHENTICATION_MODES = [PIN_AUTHENTICATION_MODE, FAMILY_NAME_AUTHENTICATION_MODE] SETTINGS = [ - { "key": ExternalIntegration.URL, "format": "url", "label": _("URL"), "required": True }, - { "key": VERIFY_CERTIFICATE, "label": _("Certificate Verification"), - "type": "select", "options": [ - { "key": "true", "label": _("Verify Certificate Normally (Required for production)") }, - { "key": "false", "label": _("Ignore Certificate Problems (For temporary testing only)") }, - ], - "default": "true" + { + "key": ExternalIntegration.URL, + "format": "url", + "label": _("URL"), + "required": True, }, - { "key": BLOCK_TYPES, "label": _("Block types"), - "description": _("Values of MBLOCK[p56] which mean a patron is blocked. By default, any value other than '-' indicates a block."), + { + "key": VERIFY_CERTIFICATE, + "label": _("Certificate Verification"), + "type": "select", + "options": [ + { + "key": "true", + "label": _("Verify Certificate Normally (Required for production)"), + }, + { + "key": "false", + "label": _( + "Ignore Certificate Problems (For temporary testing only)" + ), + }, + ], + "default": "true", }, - { "key": IDENTIFIER_BLACKLIST, "label": _("Identifier Blacklist"), - "type": "list", - "description": _("Identifiers containing any of these strings are ignored when finding the 'correct' " + - "identifier for a patron's record, even if it means they end up with no identifier at all. " + - "If librarians invalidate library cards by adding strings like \"EXPIRED\" or \"INVALID\" " + - "on to the beginning of the card number, put those strings here so the Circulation Manager " + - "knows they do not represent real card numbers."), + { + "key": BLOCK_TYPES, + "label": _("Block types"), + "description": _( + "Values of MBLOCK[p56] which mean a patron is blocked. By default, any value other than '-' indicates a block." + ), }, - { "key": AUTHENTICATION_MODE, "label": _("Authentication Mode"), - "type": "select", - "options": [ - { "key": PIN_AUTHENTICATION_MODE, "label": _("PIN") }, - { "key": FAMILY_NAME_AUTHENTICATION_MODE, "label": _("Family Name") }, - ], - "default": PIN_AUTHENTICATION_MODE + { + "key": IDENTIFIER_BLACKLIST, + "label": _("Identifier Blacklist"), + "type": "list", + "description": _( + "Identifiers containing any of these strings are ignored when finding the 'correct' " + + "identifier for a patron's record, even if it means they end up with no identifier at all. " + + 'If librarians invalidate library cards by adding strings like "EXPIRED" or "INVALID" ' + + "on to the beginning of the card number, put those strings here so the Circulation Manager " + + "knows they do not represent real card numbers." + ), + }, + { + "key": AUTHENTICATION_MODE, + "label": _("Authentication Mode"), + "type": "select", + "options": [ + {"key": PIN_AUTHENTICATION_MODE, "label": _("PIN")}, + {"key": FAMILY_NAME_AUTHENTICATION_MODE, "label": _("Family Name")}, + ], + "default": PIN_AUTHENTICATION_MODE, }, { "key": NEIGHBORHOOD_MODE, "label": _("Patron neighborhood field"), - "description": _("It's sometimes possible to guess a patron's neighborhood from their ILS record. You can use this when analyzing circulation activity by neighborhood. If you don't need to do this, it's better for patron privacy to disable this feature."), + "description": _( + "It's sometimes possible to guess a patron's neighborhood from their ILS record. You can use this when analyzing circulation activity by neighborhood. If you don't need to do this, it's better for patron privacy to disable this feature." + ), "type": "select", "options": [ - { "key": NO_NEIGHBORHOOD_MODE, "label": _("Disable this feature") }, - { "key": HOME_BRANCH_NEIGHBORHOOD_MODE, "label": _("Patron's home library branch is their neighborhood.") }, - { "key": POSTAL_CODE_NEIGHBORHOOD_MODE, "label": _("Patron's postal code is their neighborhood.") }, + {"key": NO_NEIGHBORHOOD_MODE, "label": _("Disable this feature")}, + { + "key": HOME_BRANCH_NEIGHBORHOOD_MODE, + "label": _("Patron's home library branch is their neighborhood."), + }, + { + "key": POSTAL_CODE_NEIGHBORHOOD_MODE, + "label": _("Patron's postal code is their neighborhood."), + }, ], "default": NO_NEIGHBORHOOD_MODE, }, @@ -128,16 +153,20 @@ class MilleniumPatronAPI(BasicAuthenticationProvider, XMLParser): # Replace library settings to allow text in identifier field. LIBRARY_SETTINGS = [] for setting in BasicAuthenticationProvider.LIBRARY_SETTINGS: - if setting['key'] == BasicAuthenticationProvider.LIBRARY_IDENTIFIER_FIELD: - LIBRARY_SETTINGS.append({ - "key": BasicAuthenticationProvider.LIBRARY_IDENTIFIER_FIELD, - "label": _("Library Identifier Field"), - "description": _("This is the field on the patron record that the Library Identifier Restriction " + - "Type is applied to. The option 'barcode' matches the users barcode, other " + - "values are pulled directly from the patron record for example: 'P TYPE[p47]'. " + - "This value is not used if Library Identifier Restriction Type " + - "is set to 'No restriction'."), - }) + if setting["key"] == BasicAuthenticationProvider.LIBRARY_IDENTIFIER_FIELD: + LIBRARY_SETTINGS.append( + { + "key": BasicAuthenticationProvider.LIBRARY_IDENTIFIER_FIELD, + "label": _("Library Identifier Field"), + "description": _( + "This is the field on the patron record that the Library Identifier Restriction " + + "Type is applied to. The option 'barcode' matches the users barcode, other " + + "values are pulled directly from the patron record for example: 'P TYPE[p47]'. " + + "This value is not used if Library Identifier Restriction Type " + + "is set to 'No restriction'." + ), + } + ) else: LIBRARY_SETTINGS.append(setting) @@ -145,15 +174,14 @@ def __init__(self, library, integration, analytics=None): super(MilleniumPatronAPI, self).__init__(library, integration, analytics) url = integration.url if not url: - raise CannotLoadConfiguration( - "Millenium Patron API server not configured." - ) + raise CannotLoadConfiguration("Millenium Patron API server not configured.") - if not url.endswith('/'): + if not url.endswith("/"): url = url + "/" self.root = url self.verify_certificate = integration.setting( - self.VERIFY_CERTIFICATE).json_value + self.VERIFY_CERTIFICATE + ).json_value if self.verify_certificate is None: self.verify_certificate = True self.parser = etree.HTMLParser() @@ -162,13 +190,17 @@ def __init__(self, library, integration, analytics=None): # identifiers, some of which are not real library cards. A # blacklist allows us to exclude certain types of identifiers # from being considered as library cards. - authorization_identifier_blacklist = integration.setting( - self.IDENTIFIER_BLACKLIST).json_value or [] - self.blacklist = [re.compile(x, re.I) - for x in authorization_identifier_blacklist] + authorization_identifier_blacklist = ( + integration.setting(self.IDENTIFIER_BLACKLIST).json_value or [] + ) + self.blacklist = [ + re.compile(x, re.I) for x in authorization_identifier_blacklist + ] - auth_mode = integration.setting( - self.AUTHENTICATION_MODE).value or self.PIN_AUTHENTICATION_MODE + auth_mode = ( + integration.setting(self.AUTHENTICATION_MODE).value + or self.PIN_AUTHENTICATION_MODE + ) if auth_mode not in self.AUTHENTICATION_MODES: raise CannotLoadConfiguration( @@ -178,16 +210,17 @@ def __init__(self, library, integration, analytics=None): self.block_types = integration.setting(self.BLOCK_TYPES).value or None - neighborhood_mode = integration.setting( - self.NEIGHBORHOOD_MODE - ).value or self.NO_NEIGHBORHOOD_MODE + neighborhood_mode = ( + integration.setting(self.NEIGHBORHOOD_MODE).value + or self.NO_NEIGHBORHOOD_MODE + ) if neighborhood_mode not in self.NEIGHBORHOOD_MODES: raise CannotLoadConfiguration( - "Unrecognized Millenium Patron API neighborhood mode: %s." % neighborhood_mode + "Unrecognized Millenium Patron API neighborhood mode: %s." + % neighborhood_mode ) self.neighborhood_mode = neighborhood_mode - # Begin implementation of BasicAuthenticationProvider abstract # methods. @@ -215,14 +248,14 @@ def remote_authenticate(self, username, password): # The PIN is URL-encoded. The username is not: as far as # we can tell Millenium Patron doesn't even try to decode # it. - quoted_password = parse.quote(password, safe='') if password else password + quoted_password = parse.quote(password, safe="") if password else password path = "%(barcode)s/%(pin)s/pintest" % dict( barcode=username, pin=quoted_password ) url = self.root + path response = self.request(url) data = dict(self._extract_text_nodes(response.content)) - if data.get('RETCOD') == '0': + if data.get("RETCOD") == "0": return PatronData(authorization_identifier=username, complete=False) return False elif self.auth_mode == self.FAMILY_NAME_AUTHENTICATION_MODE: @@ -245,10 +278,10 @@ def family_name_match(self, actual_name, supposed_family_name): """Does `supposed_family_name` match `actual_name`?""" if actual_name is None or supposed_family_name is None: return False - if actual_name.find(',') != -1: - actual_family_name = actual_name.split(',')[0] + if actual_name.find(",") != -1: + actual_family_name = actual_name.split(",")[0] else: - actual_name_split = actual_name.split(' ') + actual_name_split = actual_name.split(" ") actual_family_name = actual_name_split[-1] if actual_family_name.upper() == supposed_family_name.upper(): return True @@ -265,7 +298,6 @@ def _remote_patron_lookup(self, patron_or_patrondata_or_identifier): response = self.request(url) return self.patron_dump_to_patrondata(identifier, response.content) - # End implementation of BasicAuthenticationProvider abstract # methods. @@ -280,7 +312,7 @@ def _update_request_kwargs(self, kwargs): """Modify the kwargs to HTTP.request_with_timeout to reflect the API configuration, in a testable way. """ - kwargs['verify'] = self.verify_certificate + kwargs["verify"] = self.verify_certificate @classmethod def _patron_block_reason(cls, block_types, mblock_value): @@ -292,7 +324,7 @@ def _patron_block_reason(cls, block_types, mblock_value): if not block_types: # Apply the default rules. - if not mblock_value or mblock_value.strip() in ('', '-'): + if not mblock_value or mblock_value.strip() in ("", "-"): # This patron is not blocked at all. return PatronData.NO_VALUE else: @@ -367,23 +399,26 @@ def patron_dump_to_patrondata(self, current_identifier, content): # Parse the expiration date according to server local # time, not UTC. expires_local = datetime.datetime.strptime( - v, self.EXPIRATION_DATE_FORMAT).replace( - tzinfo=dateutil.tz.tzlocal() - ) + v, self.EXPIRATION_DATE_FORMAT + ).replace(tzinfo=dateutil.tz.tzlocal()) expires_local = expires_local.date() authorization_expires = expires_local except ValueError: self.log.warn( 'Malformed expiration date for patron: "%s". Treating as unexpirable.', - v + v, ) elif k == self.PATRON_TYPE_FIELD: external_type = v - elif (k == self.HOME_BRANCH_FIELD - and self.neighborhood_mode == self.HOME_BRANCH_NEIGHBORHOOD_MODE): + elif ( + k == self.HOME_BRANCH_FIELD + and self.neighborhood_mode == self.HOME_BRANCH_NEIGHBORHOOD_MODE + ): neighborhood = v.strip() - elif (k == self.ADDRESS_FIELD - and self.neighborhood_mode == self.POSTAL_CODE_NEIGHBORHOOD_MODE): + elif ( + k == self.ADDRESS_FIELD + and self.neighborhood_mode == self.POSTAL_CODE_NEIGHBORHOOD_MODE + ): neighborhood = self.extract_postal_code(v) elif k == self.ERROR_MESSAGE_FIELD: # An error has occured. Most likely the patron lookup @@ -431,7 +466,7 @@ def patron_dump_to_patrondata(self, current_identifier, content): # database record because syncing with the ILS is so # expensive. cached_neighborhood=neighborhood, - complete=True + complete=True, ) return data @@ -440,25 +475,26 @@ def _extract_text_nodes(self, content): if isinstance(content, bytes): content = content.decode("utf8") for line in content.split("\n"): - if line.startswith(''): + if line.startswith(""): line = line[12:] - if not line.endswith('
    '): + if not line.endswith("
    "): continue kv = line[:-4] - if not '=' in kv: + if not "=" in kv: # This shouldn't happen, but there's no need to crash. self.log.warn("Unexpected line in patron dump: %s", line) continue - yield kv.split('=', 1) + yield kv.split("=", 1) # A number of regular expressions for finding postal codes in # freeform addresses, with more reliable techniques at the front. POSTAL_CODE_RES = [ - re.compile(x) for x in [ - "[^0-9]([0-9]{5})-[0-9]{4}$", # ZIP+4 at end - "[^0-9]([0-9]{5})$", # ZIP at end - ".*[^0-9]([0-9]{5})-[0-9]{4}[^0-9]", # ZIP+4 as close to end as possible without being at the end - ".*[^0-9]([0-9]{5})[^0-9]", # ZIP as close to end as possible without being at the end + re.compile(x) + for x in [ + "[^0-9]([0-9]{5})-[0-9]{4}$", # ZIP+4 at end + "[^0-9]([0-9]{5})$", # ZIP at end + ".*[^0-9]([0-9]{5})-[0-9]{4}[^0-9]", # ZIP+4 as close to end as possible without being at the end + ".*[^0-9]([0-9]{5})[^0-9]", # ZIP as close to end as possible without being at the end ] ] @@ -489,7 +525,7 @@ class MockMilleniumPatronAPI(MilleniumPatronAPI): permanent_id="12345", authorization_identifier="0", username="alice", - authorization_expires = datetime_utc(2015, 4, 1) + authorization_expires=datetime_utc(2015, 4, 1), ) # This user's card still has ten days on it. @@ -498,7 +534,7 @@ class MockMilleniumPatronAPI(MilleniumPatronAPI): permanent_id="67890", authorization_identifier="5", username="bob", - authorization_expires = the_future, + authorization_expires=the_future, ) users = [user1, user2] @@ -518,7 +554,7 @@ def remote_authenticate(self, barcode, pin): """ u = self.dump(barcode) - if 'ERRNUM' in u: + if "ERRNUM" in u: return False return len(barcode) == 14 or pin == barcode[0] * 4 @@ -530,4 +566,5 @@ def remote_patron_lookup(self, patron_or_patrondata): return u return None + AuthenticationProvider = MilleniumPatronAPI diff --git a/api/monitor.py b/api/monitor.py index 978cd6d527..6638c32120 100644 --- a/api/monitor.py +++ b/api/monitor.py @@ -2,15 +2,8 @@ import os import sys -from sqlalchemy import ( - and_, - or_, -) +from sqlalchemy import and_, or_ -from core.monitor import ( - EditionSweepMonitor, - ReaperMonitor, -) from core.model import ( Annotation, Collection, @@ -22,11 +15,10 @@ LicensePool, Loan, ) +from core.monitor import EditionSweepMonitor, ReaperMonitor from core.util.datetime_helpers import utc_now -from .odl import ( - ODLAPI, - SharedODLAPI, -) + +from .odl import ODLAPI, SharedODLAPI class LoanlikeReaperMonitor(ReaperMonitor): @@ -51,25 +43,26 @@ def where_clause(self): Subclasses will append extra clauses to this filter. """ source_of_truth = or_( - LicensePool.open_access==True, - ExternalIntegration.protocol.in_( - self.SOURCE_OF_TRUTH_PROTOCOLS - ) + LicensePool.open_access == True, + ExternalIntegration.protocol.in_(self.SOURCE_OF_TRUTH_PROTOCOLS), ) - source_of_truth_subquery = self._db.query(self.MODEL_CLASS.id).join( - self.MODEL_CLASS.license_pool).join( - LicensePool.collection).join( - ExternalIntegration, - Collection.external_integration_id==ExternalIntegration.id - ).filter( - source_of_truth - ) + source_of_truth_subquery = ( + self._db.query(self.MODEL_CLASS.id) + .join(self.MODEL_CLASS.license_pool) + .join(LicensePool.collection) + .join( + ExternalIntegration, + Collection.external_integration_id == ExternalIntegration.id, + ) + .filter(source_of_truth) + ) return ~self.MODEL_CLASS.id.in_(source_of_truth_subquery) class LoanReaper(LoanlikeReaperMonitor): """Remove expired and abandoned loans from the database.""" + MODEL_CLASS = Loan MAX_AGE = 90 @@ -84,15 +77,17 @@ def where_clause(self): now = utc_now() expired = end_field < now very_old_with_no_clear_end_date = and_( - start_field < self.cutoff, - end_field == None + start_field < self.cutoff, end_field == None ) return and_(superclause, or_(expired, very_old_with_no_clear_end_date)) + + ReaperMonitor.REGISTRY.append(LoanReaper) class HoldReaper(LoanlikeReaperMonitor): """Remove seemingly abandoned holds from the database.""" + MODEL_CLASS = Hold MAX_AGE = 365 @@ -109,10 +104,11 @@ def where_clause(self): superclause = super(HoldReaper, self).where_clause end_date_in_past = end_field < utc_now() probably_abandoned = and_( - start_field < self.cutoff, - or_(end_field == None, end_date_in_past) + start_field < self.cutoff, or_(end_field == None, end_date_in_past) ) return and_(superclause, probably_abandoned) + + ReaperMonitor.REGISTRY.append(HoldReaper) @@ -120,7 +116,7 @@ class IdlingAnnotationReaper(ReaperMonitor): """Remove idling annotations for inactive loans.""" MODEL_CLASS = Annotation - TIMESTAMP_FIELD = 'timestamp' + TIMESTAMP_FIELD = "timestamp" MAX_AGE = 60 @property @@ -134,22 +130,21 @@ def where_clause(self): restrictions = [] for t in Loan, Hold: - active_subquery = self._db.query( - Annotation.id - ).join( - t, - t.patron_id==Annotation.patron_id - ).join( - LicensePool, - and_(LicensePool.id==t.license_pool_id, - LicensePool.identifier_id==Annotation.identifier_id) - ) - restrictions.append( - ~Annotation.id.in_(active_subquery) + active_subquery = ( + self._db.query(Annotation.id) + .join(t, t.patron_id == Annotation.patron_id) + .join( + LicensePool, + and_( + LicensePool.id == t.license_pool_id, + LicensePool.identifier_id == Annotation.identifier_id, + ), + ) ) + restrictions.append(~Annotation.id.in_(active_subquery)) return and_( - superclause, - Annotation.motivation==Annotation.IDLING, - *restrictions + superclause, Annotation.motivation == Annotation.IDLING, *restrictions ) + + ReaperMonitor.REGISTRY.append(IdlingAnnotationReaper) diff --git a/api/novelist.py b/api/novelist.py index 063edbf3bf..c20bcc16b6 100644 --- a/api/novelist.py +++ b/api/novelist.py @@ -1,17 +1,16 @@ import json import logging -import urllib.request, urllib.parse, urllib.error +import urllib.error +import urllib.parse +import urllib.request from collections import Counter + from flask_babel import lazy_gettext as _ +from sqlalchemy.orm import aliased +from sqlalchemy.sql import and_, join, or_, select -from core.config import ( - CannotLoadConfiguration, - Configuration, -) -from core.coverage import ( - CoverageFailure, - IdentifierCoverageProvider, -) +from core.config import CannotLoadConfiguration, Configuration +from core.coverage import CoverageFailure, IdentifierCoverageProvider from core.metadata_layer import ( ContributorData, IdentifierData, @@ -21,32 +20,26 @@ SubjectData, ) from core.model import ( + Collection, + Contribution, + Contributor, DataSource, + Edition, + Equivalency, ExternalIntegration, Hyperlink, Identifier, + LicensePool, Measurement, Representation, Session, Subject, get_one, - Equivalency, - LicensePool, - Collection, - Edition, - Contributor, - Contribution, ) from core.util import TitleProcessor -from sqlalchemy.sql import ( - select, - join, - and_, - or_, -) -from sqlalchemy.orm import aliased from core.util.http import HTTP + class NoveListAPI(object): PROTOCOL = ExternalIntegration.NOVELIST @@ -58,8 +51,8 @@ class NoveListAPI(object): AUTHORIZED_IDENTIFIER = "62521fa1-bdbb-4939-84aa-aee2a52c8d59" SETTINGS = [ - { "key": ExternalIntegration.USERNAME, "label": _("Profile"), "required": True }, - { "key": ExternalIntegration.PASSWORD, "label": _("Password"), "required": True }, + {"key": ExternalIntegration.USERNAME, "label": _("Profile"), "required": True}, + {"key": ExternalIntegration.PASSWORD, "label": _("Password"), "required": True}, ] # Different libraries may have different NoveList integrations @@ -83,13 +76,13 @@ class NoveListAPI(object): ) COLLECTION_DATA_API = "http://www.noveListcollectiondata.com/api/collections" AUTH_PARAMS = "&profile=%(profile)s&password=%(password)s" - MAX_REPRESENTATION_AGE = 7*24*60*60 # one week + MAX_REPRESENTATION_AGE = 7 * 24 * 60 * 60 # one week currentQueryIdentifier = None medium_to_book_format_type_values = { - Edition.BOOK_MEDIUM : "EBook", - Edition.AUDIO_MEDIUM : "Audiobook", + Edition.BOOK_MEDIUM: "EBook", + Edition.AUDIO_MEDIUM: "Audiobook", } @classmethod @@ -97,7 +90,8 @@ def from_config(cls, library): profile, password = cls.values(library) if not (profile and password): raise CannotLoadConfiguration( - "No NoveList integration configured for library (%s)." % library.short_name + "No NoveList integration configured for library (%s)." + % library.short_name ) _db = Session.object_session(library) @@ -108,8 +102,10 @@ def values(cls, library): _db = Session.object_session(library) integration = ExternalIntegration.lookup( - _db, ExternalIntegration.NOVELIST, - ExternalIntegration.METADATA_GOAL, library=library + _db, + ExternalIntegration.NOVELIST, + ExternalIntegration.METADATA_GOAL, + library=library, ) if not integration: @@ -121,9 +117,7 @@ def values(cls, library): @classmethod def is_configured(cls, library): - if (cls.IS_CONFIGURED is None or - library.id != cls._configuration_library_id - ): + if cls.IS_CONFIGURED is None or library.id != cls._configuration_library_id: profile, password = cls.values(library) cls.IS_CONFIGURED = bool(profile and password) cls._configuration_library_id = library.id @@ -149,16 +143,23 @@ def lookup_equivalent_isbns(self, identifier): # Find strong ISBN equivalents. isbns = list() for license_source in license_sources: - isbns += [eq.output for eq in identifier.equivalencies if ( - eq.data_source==license_source and - eq.strength==1 and - eq.output.type==Identifier.ISBN - )] + isbns += [ + eq.output + for eq in identifier.equivalencies + if ( + eq.data_source == license_source + and eq.strength == 1 + and eq.output.type == Identifier.ISBN + ) + ] if not isbns: self.log.warn( - ("Identifiers without an ISBN equivalent can't" - "be looked up with NoveList: %r"), identifier + ( + "Identifiers without an ISBN equivalent can't" + "be looked up with NoveList: %r" + ), + identifier, ) return None @@ -171,8 +172,11 @@ def lookup_equivalent_isbns(self, identifier): if not lookup_metadata: self.log.warn( - ("No NoveList metadata found for Identifiers without an ISBN" - "equivalent can't be looked up with NoveList: %r"), identifier + ( + "No NoveList metadata found for Identifiers without an ISBN" + "equivalent can't be looked up with NoveList: %r" + ), + identifier, ) return None @@ -189,11 +193,10 @@ def lookup_equivalent_isbns(self, identifier): def _confirm_same_identifier(self, metadata_objects): """Ensures that all metadata objects have the same NoveList ID""" - novelist_ids = set([ - metadata.primary_identifier.identifier - for metadata in metadata_objects - ]) - return len(novelist_ids)==1 + novelist_ids = set( + [metadata.primary_identifier.identifier for metadata in metadata_objects] + ) + return len(novelist_ids) == 1 def choose_best_metadata(self, metadata_objects, identifier): """Chooses the most likely book metadata from a list of Metadata objects @@ -213,14 +216,17 @@ def choose_best_metadata(self, metadata_objects, identifier): for metadata in metadata_objects: counter[metadata.primary_identifier] += 1 - [(target_identifier, most_amount), - (ignore, secondmost)] = counter.most_common(2) - if most_amount==secondmost: + [(target_identifier, most_amount), (ignore, secondmost)] = counter.most_common( + 2 + ) + if most_amount == secondmost: # The counts are the same, and neither can be trusted. self.log.warn(self.NO_ISBN_EQUIVALENCY, identifier) return None, None confidence = most_amount / float(len(metadata_objects)) - target_metadata = [m for m in metadata_objects if m.primary_identifier==target_identifier] + target_metadata = [ + m for m in metadata_objects if m.primary_identifier == target_identifier + ] return target_metadata[0], confidence def lookup(self, identifier, **kwargs): @@ -235,13 +241,16 @@ def lookup(self, identifier, **kwargs): return self.lookup_equivalent_isbns(identifier) params = dict( - ClientIdentifier=client_identifier, ISBN=identifier.identifier, - version=self.version, profile=self.profile, password=self.password + ClientIdentifier=client_identifier, + ISBN=identifier.identifier, + version=self.version, + profile=self.profile, + password=self.password, ) scrubbed_url = str(self.scrubbed_url(params)) url = self.build_query_url(params) - self.log.debug("NoveList lookup: %s", url) + self.log.debug("NoveList lookup: %s", url) # We want to make an HTTP request for `url` but cache the # result under `scrubbed_url`. Define a 'URL normalization' @@ -250,10 +259,13 @@ def normalized_url(original): return scrubbed_url representation, from_cache = Representation.post( - _db=self._db, url=str(url), data='', + _db=self._db, + url=str(url), + data="", max_age=self.MAX_REPRESENTATION_AGE, response_reviewer=self.review_response, - url_normalizer=normalized_url, **kwargs + url_normalizer=normalized_url, + **kwargs ) # Commit to the database immediately to reduce the chance @@ -282,9 +294,9 @@ def scrubbed_url(cls, params): def _scrub_subtitle(cls, subtitle): """Removes common NoveList subtitle annoyances""" if subtitle: - subtitle = subtitle.replace('[electronic resource]', '') + subtitle = subtitle.replace("[electronic resource]", "") # Then get rid of any leading whitespace or punctuation. - subtitle = TitleProcessor.extract_subtitle('', subtitle) + subtitle = TitleProcessor.extract_subtitle("", subtitle) return subtitle @classmethod @@ -306,9 +318,9 @@ def lookup_info_to_metadata(self, lookup_representation): return None lookup_info = json.loads(lookup_representation.content) - book_info = lookup_info['TitleInfo'] + book_info = lookup_info["TitleInfo"] if book_info: - novelist_identifier = book_info.get('ui') + novelist_identifier = book_info.get("ui") if not book_info or not novelist_identifier: # NoveList didn't know the ISBN. return None @@ -321,28 +333,31 @@ def lookup_info_to_metadata(self, lookup_representation): # Get the equivalent ISBN identifiers. metadata.identifiers += self._extract_isbns(book_info) - author = book_info.get('author') + author = book_info.get("author") if author: metadata.contributors.append(ContributorData(sort_name=author)) - description = book_info.get('description') + description = book_info.get("description") if description: - metadata.links.append(LinkData( - rel=Hyperlink.DESCRIPTION, content=description, - media_type=Representation.TEXT_PLAIN - )) + metadata.links.append( + LinkData( + rel=Hyperlink.DESCRIPTION, + content=description, + media_type=Representation.TEXT_PLAIN, + ) + ) - audience_level = book_info.get('audience_level') + audience_level = book_info.get("audience_level") if audience_level: - metadata.subjects.append(SubjectData( - Subject.FREEFORM_AUDIENCE, audience_level - )) + metadata.subjects.append( + SubjectData(Subject.FREEFORM_AUDIENCE, audience_level) + ) - novelist_rating = book_info.get('rating') + novelist_rating = book_info.get("rating") if novelist_rating: - metadata.measurements.append(MeasurementData( - Measurement.RATING, novelist_rating - )) + metadata.measurements.append( + MeasurementData(Measurement.RATING, novelist_rating) + ) # Extract feature content if it is available. series_info = None @@ -350,20 +365,20 @@ def lookup_info_to_metadata(self, lookup_representation): lexile_info = None goodreads_info = None recommendations_info = None - feature_content = lookup_info.get('FeatureContent') + feature_content = lookup_info.get("FeatureContent") if feature_content: - series_info = feature_content.get('SeriesInfo') - appeals_info = feature_content.get('Appeals') - lexile_info = feature_content.get('LexileInfo') - goodreads_info = feature_content.get('GoodReads') - recommendations_info = feature_content.get('SimilarTitles') + series_info = feature_content.get("SeriesInfo") + appeals_info = feature_content.get("Appeals") + lexile_info = feature_content.get("LexileInfo") + goodreads_info = feature_content.get("GoodReads") + recommendations_info = feature_content.get("SimilarTitles") metadata, title_key = self.get_series_information( metadata, series_info, book_info ) metadata.title = book_info.get(title_key) subtitle = TitleProcessor.extract_subtitle( - metadata.title, book_info.get('full_title') + metadata.title, book_info.get("full_title") ) metadata.subtitle = self._scrub_subtitle(subtitle) @@ -372,32 +387,37 @@ def lookup_info_to_metadata(self, lookup_representation): if appeals_info: extracted_genres = False for appeal in appeals_info: - genres = appeal.get('genres') + genres = appeal.get("genres") if genres: for genre in genres: - metadata.subjects.append(SubjectData( - Subject.TAG, genre['Name'] - )) + metadata.subjects.append( + SubjectData(Subject.TAG, genre["Name"]) + ) extracted_genres = True if extracted_genres: break if lexile_info: - metadata.subjects.append(SubjectData( - Subject.LEXILE_SCORE, lexile_info['Lexile'] - )) + metadata.subjects.append( + SubjectData(Subject.LEXILE_SCORE, lexile_info["Lexile"]) + ) if goodreads_info: - metadata.measurements.append(MeasurementData( - Measurement.RATING, goodreads_info['average_rating'] - )) + metadata.measurements.append( + MeasurementData(Measurement.RATING, goodreads_info["average_rating"]) + ) metadata = self.get_recommendations(metadata, recommendations_info) # If nothing interesting comes from the API, ignore it. - if not (metadata.measurements or metadata.series_position or - metadata.series or metadata.subjects or metadata.links or - metadata.subtitle or metadata.recommendations + if not ( + metadata.measurements + or metadata.series_position + or metadata.series + or metadata.subjects + or metadata.links + or metadata.subtitle + or metadata.recommendations ): metadata = None return metadata @@ -405,41 +425,47 @@ def lookup_info_to_metadata(self, lookup_representation): def get_series_information(self, metadata, series_info, book_info): """Returns metadata object with series info and optimal title key""" - title_key = 'main_title' + title_key = "main_title" if series_info: - metadata.series = series_info['full_title'] - series_titles = series_info.get('series_titles') + metadata.series = series_info["full_title"] + series_titles = series_info.get("series_titles") if series_titles: - matching_series_volume = [volume for volume in series_titles - if volume.get('full_title')==book_info.get('full_title')] + matching_series_volume = [ + volume + for volume in series_titles + if volume.get("full_title") == book_info.get("full_title") + ] if not matching_series_volume: # If there's no full_title match, try the main_title. - matching_series_volume = [volume for volume in series_titles - if volume.get('main_title')==book_info.get('main_title')] + matching_series_volume = [ + volume + for volume in series_titles + if volume.get("main_title") == book_info.get("main_title") + ] if len(matching_series_volume) > 1: # This probably won't happen, but if it does, it will be # difficult to debug without an error. raise ValueError("Multiple matching volumes found.") - series_position = matching_series_volume[0].get('volume') + series_position = matching_series_volume[0].get("volume") if series_position: - if series_position.endswith('.'): + if series_position.endswith("."): series_position = series_position[:-1] metadata.series_position = int(series_position) # Sometimes all of the volumes in a series have the same # main_title so using the full_title is preferred. main_titles = [volume.get(title_key) for volume in series_titles] - if len(main_titles) > 1 and len(set(main_titles))==1: - title_key = 'full_title' + if len(main_titles) > 1 and len(set(main_titles)) == 1: + title_key = "full_title" return metadata, title_key def _extract_isbns(self, book_info): isbns = [] - synonymous_ids = book_info.get('manifestations') + synonymous_ids = book_info.get("manifestations") for synonymous_id in synonymous_ids: - isbn = synonymous_id.get('ISBN') + isbn = synonymous_id.get("ISBN") if isbn: isbn_data = IdentifierData(Identifier.ISBN, isbn) isbns.append(isbn_data) @@ -450,8 +476,8 @@ def get_recommendations(self, metadata, recommendations_info): if not recommendations_info: return metadata - related_books = recommendations_info.get('titles') - related_books = [b for b in related_books if b.get('is_held_locally')] + related_books = recommendations_info.get("titles") + related_books = [b for b in related_books if b.get("is_held_locally")] if related_books: for book_info in related_books: metadata.recommendations += self._extract_isbns(book_info) @@ -477,29 +503,44 @@ def get_items_from_query(self, library): roles = list(Contributor.AUTHOR_ROLES) roles.append(Contributor.NARRATOR_ROLE) - isbnQuery = select( - [i1.identifier, i1.type, i2.identifier, - Edition.title, Edition.medium, Edition.published, - Contribution.role, Contributor.sort_name, - DataSource.name], - ).select_from( - join(LicensePool, i1, i1.id==LicensePool.identifier_id) - .join(Equivalency, i1.id==Equivalency.input_id, LEFT_OUTER_JOIN) - .join(i2, Equivalency.output_id==i2.id, LEFT_OUTER_JOIN) - .join( - Edition, - or_(Edition.primary_identifier_id==i1.id, Edition.primary_identifier_id==i2.id) + isbnQuery = ( + select( + [ + i1.identifier, + i1.type, + i2.identifier, + Edition.title, + Edition.medium, + Edition.published, + Contribution.role, + Contributor.sort_name, + DataSource.name, + ], ) - .join(Contribution, Edition.id==Contribution.edition_id) - .join(Contributor, Contribution.contributor_id==Contributor.id) - .join(DataSource, DataSource.id==LicensePool.data_source_id) - ).where( - and_( - LicensePool.collection_id.in_(collectionList), - or_(i1.type=="ISBN", i2.type=="ISBN"), - or_(Contribution.role.in_(roles)) + .select_from( + join(LicensePool, i1, i1.id == LicensePool.identifier_id) + .join(Equivalency, i1.id == Equivalency.input_id, LEFT_OUTER_JOIN) + .join(i2, Equivalency.output_id == i2.id, LEFT_OUTER_JOIN) + .join( + Edition, + or_( + Edition.primary_identifier_id == i1.id, + Edition.primary_identifier_id == i2.id, + ), + ) + .join(Contribution, Edition.id == Contribution.edition_id) + .join(Contributor, Contribution.contributor_id == Contributor.id) + .join(DataSource, DataSource.id == LicensePool.data_source_id) ) - ).order_by(i1.identifier, i2.identifier) + .where( + and_( + LicensePool.collection_id.in_(collectionList), + or_(i1.type == "ISBN", i2.type == "ISBN"), + or_(Contribution.role.in_(roles)), + ) + ) + .order_by(i1.identifier, i2.identifier) + ) result = self._db.execute(isbnQuery) @@ -515,18 +556,21 @@ def get_items_from_query(self, library): for item in result: if newItem: existingItem = newItem - (currentIdentifier, existingItem, newItem, addItem) = ( - self.create_item_object(item, currentIdentifier, existingItem) - ) + ( + currentIdentifier, + existingItem, + newItem, + addItem, + ) = self.create_item_object(item, currentIdentifier, existingItem) if addItem and existingItem: # The Role property isn't needed in the actual request. - del existingItem['role'] + del existingItem["role"] items.append(existingItem) # For the case when there's only one item in `result` if newItem: - del newItem['role'] + del newItem["role"] items.append(newItem) return items @@ -551,7 +595,7 @@ def create_item_object(self, object, currentIdentifier, existingItem): if not object: return (None, None, None, False) - if (object[1] == Identifier.ISBN): + if object[1] == Identifier.ISBN: isbn = object[0] elif object[2] is not None: isbn = object[2] @@ -573,13 +617,13 @@ def create_item_object(self, object, currentIdentifier, existingItem): # If we encounter an existing ISBN and its role is "Primary Author", # then that value overrides the existing Author property. if isbn == currentIdentifier and existingItem: - if not existingItem.get('author') and role in Contributor.AUTHOR_ROLES: - existingItem['author'] = author_or_narrator - if not existingItem.get('narrator') and role == Contributor.NARRATOR_ROLE: - existingItem['narrator'] = author_or_narrator + if not existingItem.get("author") and role in Contributor.AUTHOR_ROLES: + existingItem["author"] = author_or_narrator + if not existingItem.get("narrator") and role == Contributor.NARRATOR_ROLE: + existingItem["narrator"] = author_or_narrator if role == Contributor.PRIMARY_AUTHOR_ROLE: - existingItem['author'] = author_or_narrator - existingItem['role'] = role + existingItem["author"] = author_or_narrator + existingItem["role"] = role # Always return False to keep processing the currentIdentifier until # we get a new ISBN to process. In that case, return and add all @@ -595,7 +639,7 @@ def create_item_object(self, object, currentIdentifier, existingItem): title=title, mediaType=mediaType, role=role, - distributor=distributor + distributor=distributor, ) publicationDate = object[5] @@ -608,9 +652,9 @@ def create_item_object(self, object, currentIdentifier, existingItem): # the current new item for further data aggregation. addItem = True if existingItem else False if role in Contributor.AUTHOR_ROLES: - newItem['author'] = author_or_narrator + newItem["author"] = author_or_narrator if role == Contributor.NARRATOR_ROLE: - newItem['narrator'] = author_or_narrator + newItem["narrator"] = author_or_narrator return (isbn, existingItem, newItem, addItem) @@ -619,25 +663,22 @@ def put_items_novelist(self, library): content = None if items: - data=json.dumps(self.make_novelist_data_object(items)) + data = json.dumps(self.make_novelist_data_object(items)) response = self.put( self.COLLECTION_DATA_API, { "AuthorizedIdentifier": self.AUTHORIZED_IDENTIFIER, - "Content-Type": "application/json; charset=utf-8" + "Content-Type": "application/json; charset=utf-8", }, - data=data + data=data, ) - if (response.status_code == 200): + if response.status_code == 200: content = json.loads(response.content) - logging.info( - "Success from NoveList: %r", response.content - ) + logging.info("Success from NoveList: %r", response.content) else: logging.error("Data sent was: %r", data) logging.error( - "Error %s from NoveList: %r", response.status_code, - response.content + "Error %s from NoveList: %r", response.status_code, response.content ) return content @@ -649,20 +690,17 @@ def make_novelist_data_object(self, items): } def put(self, url, headers, **kwargs): - data = kwargs.get('data') - if 'data' in kwargs: - del kwargs['data'] + data = kwargs.get("data") + if "data" in kwargs: + del kwargs["data"] # This might take a very long time -- disable the normal # timeout. - kwargs['timeout'] = None - response = HTTP.put_with_timeout( - url, data, headers=headers, **kwargs - ) + kwargs["timeout"] = None + response = HTTP.put_with_timeout(url, data, headers=headers, **kwargs) return response class MockNoveListAPI(NoveListAPI): - def __init__(self, _db, *args, **kwargs): self._db = _db self.responses = [] diff --git a/api/nyt.py b/api/nyt.py index d6fd6b635a..9c534fca68 100644 --- a/api/nyt.py +++ b/api/nyt.py @@ -1,41 +1,32 @@ """Interface to the New York Times APIs.""" +import json +import logging +import os from collections import Counter from datetime import datetime, timedelta + import dateutil import isbnlib -import os -import json -import logging -from sqlalchemy.orm.session import Session -from sqlalchemy.orm.exc import ( - NoResultFound, -) from flask_babel import lazy_gettext as _ +from sqlalchemy.orm.exc import NoResultFound +from sqlalchemy.orm.session import Session -from .config import ( - CannotLoadConfiguration, - IntegrationException, -) - -from core.selftest import ( - HasSelfTests, -) -from core.opds_import import MetadataWranglerOPDSLookup -from core.metadata_layer import ( - Metadata, - IdentifierData, - ContributorData, -) +from core.external_list import TitleFromExternalList +from core.metadata_layer import ContributorData, IdentifierData, Metadata from core.model import ( - get_one_or_create, CustomList, DataSource, Edition, ExternalIntegration, Identifier, Representation, + get_one_or_create, ) -from core.external_list import TitleFromExternalList +from core.opds_import import MetadataWranglerOPDSLookup +from core.selftest import HasSelfTests + +from .config import CannotLoadConfiguration, IntegrationException + class NYTAPI(object): @@ -57,9 +48,7 @@ def parse_datetime(cls, d): We take midnight Eastern time to be the publication time. """ - return datetime.strptime(d, cls.DATE_FORMAT).replace( - tzinfo=cls.TIME_ZONE - ) + return datetime.strptime(d, cls.DATE_FORMAT).replace(tzinfo=cls.TIME_ZONE) @classmethod def parse_date(cls, d): @@ -83,7 +72,7 @@ class NYTBestSellerAPI(NYTAPI, HasSelfTests): CARDINALITY = 1 SETTINGS = [ - { "key": ExternalIntegration.PASSWORD, "label": _("API key"), "required": True }, + {"key": ExternalIntegration.PASSWORD, "label": _("API key"), "required": True}, ] # An NYT integration is shared by all libraries in a circulation manager. @@ -117,9 +106,7 @@ def __init__(self, _db, api_key=None, do_get=None, metadata_client=None): self.do_get = do_get or Representation.simple_http_get if not metadata_client: try: - metadata_client = MetadataWranglerOPDSLookup.from_config( - self._db - ) + metadata_client = MetadataWranglerOPDSLookup.from_config(self._db) except CannotLoadConfiguration as e: self.log.error( "Metadata wrangler integration is not configured, proceeding without one." @@ -129,14 +116,11 @@ def __init__(self, _db, api_key=None, do_get=None, metadata_client=None): @classmethod def external_integration(cls, _db): return ExternalIntegration.lookup( - _db, ExternalIntegration.NYT, - ExternalIntegration.METADATA_GOAL + _db, ExternalIntegration.NYT, ExternalIntegration.METADATA_GOAL ) def _run_self_tests(self, _db): - yield self.run_test( - "Getting list of best-seller lists", self.list_of_lists - ) + yield self.run_test("Getting list of best-seller lists", self.list_of_lists) @property def source(self): @@ -149,13 +133,18 @@ def request(self, path, identifier=None, max_age=LIST_MAX_AGE): url = self.BASE_URL + path else: url = path - joiner = '?' - if '?' in url: - joiner = '&' + joiner = "?" + if "?" in url: + joiner = "&" url += joiner + "api-key=" + self.api_key representation, cached = Representation.get( - self._db, url, do_get=self.do_get, max_age=max_age, debug=True, - pause_before=0.1) + self._db, + url, + do_get=self.do_get, + max_age=max_age, + debug=True, + pause_before=0.1, + ) status = representation.status_code if status == 200: # Everything's fine. @@ -163,13 +152,14 @@ def request(self, path, identifier=None, max_age=LIST_MAX_AGE): return content diagnostic = "Response from %s was: %r" % ( - url, representation.content.decode("utf-8") if representation.content else "" + url, + representation.content.decode("utf-8") if representation.content else "", ) if status == 403: raise IntegrationException( "API authentication failed", - "API key is most likely wrong. %s" % diagnostic + "API key is most likely wrong. %s" % diagnostic, ) else: raise IntegrationException( @@ -181,8 +171,9 @@ def list_of_lists(self, max_age=LIST_OF_LISTS_MAX_AGE): def list_info(self, list_name): list_of_lists = self.list_of_lists() - list_info = [x for x in list_of_lists['results'] - if x['list_name_encoded'] == list_name] + list_info = [ + x for x in list_of_lists["results"] if x["list_name_encoded"] == list_name + ] if not list_info: raise ValueError("No such list: %s" % list_name) return list_info[0] @@ -211,15 +202,14 @@ def fill_in_history(self, list): class NYTBestSellerList(list): - def __init__(self, list_info, metadata_client): - self.name = list_info['display_name'] - self.created = NYTAPI.parse_datetime(list_info['oldest_published_date']) - self.updated = NYTAPI.parse_datetime(list_info['newest_published_date']) - self.foreign_identifier = list_info['list_name_encoded'] - if list_info['updated'] == 'WEEKLY': + self.name = list_info["display_name"] + self.created = NYTAPI.parse_datetime(list_info["oldest_published_date"]) + self.updated = NYTAPI.parse_datetime(list_info["newest_published_date"]) + self.foreign_identifier = list_info["list_name_encoded"] + if list_info["updated"] == "WEEKLY": frequency = 7 - elif list_info['updated'] == 'MONTHLY': + elif list_info["updated"] == "MONTHLY": frequency = 30 self.frequency = timedelta(frequency) self.items_by_isbn = dict() @@ -259,11 +249,10 @@ def all_dates(self): def update(self, json_data): """Update the list with information from the given JSON structure.""" - for li_data in json_data.get('results', []): + for li_data in json_data.get("results", []): try: - book = li_data['book_details'][0] - key = ( - book.get('primary_isbn13') or book.get('primary_isbn10')) + book = li_data["book_details"][0] + key = book.get("primary_isbn13") or book.get("primary_isbn10") if key in self.items_by_isbn: item = self.items_by_isbn[key] self.log.debug("Previously seen ISBN: %r", key) @@ -281,11 +270,13 @@ def update(self, json_data): # This is the date the *best-seller list* was published, # not the date the book was published. - list_date = NYTAPI.parse_datetime(li_data['published_date']) + list_date = NYTAPI.parse_datetime(li_data["published_date"]) if not item.first_appearance or list_date < item.first_appearance: item.first_appearance = list_date - if (not item.most_recent_appearance - or list_date > item.most_recent_appearance): + if ( + not item.most_recent_appearance + or list_date > item.most_recent_appearance + ): item.most_recent_appearance = list_date def to_customlist(self, _db): @@ -296,9 +287,9 @@ def to_customlist(self, _db): CustomList, data_source=data_source, foreign_identifier=self.foreign_identifier, - create_method_kwargs = dict( + create_method_kwargs=dict( created=self.created, - ) + ), ) l.name = self.name l.updated = self.updated @@ -314,19 +305,17 @@ def update_custom_list(self, custom_list): # Add new items to the list. for i in self: list_item, was_new = i.to_custom_list_entry( - custom_list, self.metadata_client) + custom_list, self.metadata_client + ) # If possible, associate the item with a Work. list_item.set_work() class NYTBestSellerListTitle(TitleFromExternalList): - def __init__(self, data, medium): data = data try: - bestsellers_date = NYTAPI.parse_datetime( - data.get('bestsellers_date') - ) + bestsellers_date = NYTAPI.parse_datetime(data.get("bestsellers_date")) first_appearance = bestsellers_date most_recent_appearance = bestsellers_date except ValueError as e: @@ -336,35 +325,32 @@ def __init__(self, data, medium): try: # This is the date the _book_ was published, not the date # the _bestseller list_ was published. - published_date = NYTAPI.parse_date(data.get('published_date')) + published_date = NYTAPI.parse_date(data.get("published_date")) except ValueError as e: published_date = None - details = data['book_details'] + details = data["book_details"] other_isbns = [] if len(details) == 0: publisher = annotation = primary_isbn10 = primary_isbn13 = title = None display_author = None else: d = details[0] - title = d.get('title', None) - display_author = d.get('author', None) - publisher = d.get('publisher', None) - annotation = d.get('description', None) - primary_isbn10 = d.get('primary_isbn10', None) - primary_isbn13 = d.get('primary_isbn13', None) + title = d.get("title", None) + display_author = d.get("author", None) + publisher = d.get("publisher", None) + annotation = d.get("description", None) + primary_isbn10 = d.get("primary_isbn10", None) + primary_isbn13 = d.get("primary_isbn13", None) # The list of other ISBNs frequently contains ISBNs for # other books in the same series, as well as ISBNs that # are just wrong. Assign these equivalencies at a low # level of confidence. - for isbn in d.get('isbns', []): - isbn13 = isbn.get('isbn13', None) + for isbn in d.get("isbns", []): + isbn13 = isbn.get("isbn13", None) if isbn13: - other_isbns.append( - IdentifierData(Identifier.ISBN, isbn13, 0.50) - ) - + other_isbns.append(IdentifierData(Identifier.ISBN, isbn13, 0.50)) primary_isbn = primary_isbn13 or primary_isbn10 if primary_isbn: @@ -372,15 +358,13 @@ def __init__(self, data, medium): contributors = [] if display_author: - contributors.append( - ContributorData(display_name=display_author) - ) + contributors.append(ContributorData(display_name=display_author)) metadata = Metadata( data_source=DataSource.NYT, title=title, medium=medium, - language='eng', + language="eng", published=published_date, publisher=publisher, contributors=contributors, @@ -389,6 +373,5 @@ def __init__(self, data, medium): ) super(NYTBestSellerListTitle, self).__init__( - metadata, first_appearance, most_recent_appearance, - annotation + metadata, first_appearance, most_recent_appearance, annotation ) diff --git a/api/odilo.py b/api/odilo.py index 8347f8e606..a547efa0a0 100644 --- a/api/odilo.py +++ b/api/odilo.py @@ -2,40 +2,26 @@ import base64 import datetime import json -import isbnlib import logging -from sqlalchemy.orm.session import Session +import isbnlib from flask_babel import lazy_gettext as _ +from sqlalchemy.orm.session import Session -from .circulation import ( - LoanInfo, - HoldInfo, - FulfillmentInfo, - BaseCirculationAPI, -) - -from core.model import ( - Credential, - DataSource, - ExternalIntegration, - Identifier -) - -from .selftest import ( - HasSelfTests, - SelfTestResult, -) -from core.monitor import ( - CollectionMonitor, - TimelineMonitor, +from core.analytics import Analytics +from core.config import CannotLoadConfiguration +from core.coverage import BibliographicCoverageProvider +from core.metadata_layer import ( + CirculationData, + ContributorData, + FormatData, + IdentifierData, + LinkData, + Metadata, + ReplacementPolicy, + SubjectData, ) -from core.util.http import HTTP - -from .circulation_exceptions import * - from core.model import ( - get_one_or_create, Classification, Collection, Contributor, @@ -48,79 +34,49 @@ Identifier, Representation, Subject, + get_one_or_create, ) - -from core.analytics import Analytics - -from core.metadata_layer import ( - CirculationData, - ContributorData, - FormatData, - IdentifierData, - Metadata, - LinkData, - ReplacementPolicy, - SubjectData, -) - -from core.coverage import ( - BibliographicCoverageProvider, -) - -from core.config import ( - CannotLoadConfiguration, -) - -from core.testing import DatabaseTest - -from core.util.datetime_helpers import ( - from_timestamp, - strptime_utc, - utc_now, -) -from core.util.http import ( - HTTP, - BadResponseException, -) - +from core.monitor import CollectionMonitor, TimelineMonitor +from core.testing import DatabaseTest, MockRequestsResponse +from core.util.datetime_helpers import from_timestamp, strptime_utc, utc_now +from core.util.http import HTTP, BadResponseException from core.util.personal_names import sort_name_to_display_name -from core.testing import MockRequestsResponse +from .circulation import BaseCirculationAPI, FulfillmentInfo, HoldInfo, LoanInfo +from .circulation_exceptions import * +from .selftest import HasSelfTests, SelfTestResult + class OdiloRepresentationExtractor(object): """Extract useful information from Odilo's JSON representations.""" log = logging.getLogger("OdiloRepresentationExtractor") - ACSM = 'ACSM' - ACSM_EPUB = 'ACSM_EPUB' - ACSM_PDF = 'ACSM_PDF' - EBOOK_STREAMING = 'EBOOK_STREAMING' + ACSM = "ACSM" + ACSM_EPUB = "ACSM_EPUB" + ACSM_PDF = "ACSM_PDF" + EBOOK_STREAMING = "EBOOK_STREAMING" format_data_for_odilo_format = { - ACSM_PDF: ( - Representation.PDF_MEDIA_TYPE, DeliveryMechanism.ADOBE_DRM - ), - ACSM_EPUB: ( - Representation.EPUB_MEDIA_TYPE, DeliveryMechanism.ADOBE_DRM - ), + ACSM_PDF: (Representation.PDF_MEDIA_TYPE, DeliveryMechanism.ADOBE_DRM), + ACSM_EPUB: (Representation.EPUB_MEDIA_TYPE, DeliveryMechanism.ADOBE_DRM), EBOOK_STREAMING: ( - Representation.TEXT_HTML_MEDIA_TYPE, DeliveryMechanism.STREAMING_TEXT_CONTENT_TYPE + Representation.TEXT_HTML_MEDIA_TYPE, + DeliveryMechanism.STREAMING_TEXT_CONTENT_TYPE, ), "MP3": ( - Representation.MP3_MEDIA_TYPE, DeliveryMechanism.STREAMING_AUDIO_CONTENT_TYPE + Representation.MP3_MEDIA_TYPE, + DeliveryMechanism.STREAMING_AUDIO_CONTENT_TYPE, ), "MP4": ( - Representation.MP4_MEDIA_TYPE, DeliveryMechanism.STREAMING_VIDEO_CONTENT_TYPE + Representation.MP4_MEDIA_TYPE, + DeliveryMechanism.STREAMING_VIDEO_CONTENT_TYPE, ), "WMV": ( - Representation.WMV_MEDIA_TYPE, DeliveryMechanism.STREAMING_VIDEO_CONTENT_TYPE - ), - "JPG": ( - Representation.JPEG_MEDIA_TYPE, DeliveryMechanism.NO_DRM + Representation.WMV_MEDIA_TYPE, + DeliveryMechanism.STREAMING_VIDEO_CONTENT_TYPE, ), - "SCORM": ( - Representation.SCORM_MEDIA_TYPE, DeliveryMechanism.NO_DRM - ) + "JPG": (Representation.JPEG_MEDIA_TYPE, DeliveryMechanism.NO_DRM), + "SCORM": (Representation.SCORM_MEDIA_TYPE, DeliveryMechanism.NO_DRM), } odilo_medium_to_simplified_medium = { @@ -131,34 +87,36 @@ class OdiloRepresentationExtractor(object): "MP4": Edition.VIDEO_MEDIUM, "WMV": Edition.VIDEO_MEDIUM, "JPG": Edition.IMAGE_MEDIUM, - "SCORM": Edition.COURSEWARE_MEDIUM + "SCORM": Edition.COURSEWARE_MEDIUM, } @classmethod def record_info_to_circulation(cls, availability): - """ Note: The json data passed into this method is from a different file/stream + """Note: The json data passed into this method is from a different file/stream from the json data that goes into the record_info_to_metadata() method. """ - if 'recordId' not in availability: + if "recordId" not in availability: return None - record_id = availability['recordId'] - primary_identifier = IdentifierData(Identifier.ODILO_ID, record_id) # We own this availability. + record_id = availability["recordId"] + primary_identifier = IdentifierData( + Identifier.ODILO_ID, record_id + ) # We own this availability. - licenses_owned = int(availability['totalCopies']) - licenses_available = int(availability['availableCopies']) + licenses_owned = int(availability["totalCopies"]) + licenses_available = int(availability["availableCopies"]) # 'licenses_reserved' is the number of patrons who put the book on hold earlier, # but who are now at the front of the queue and who could get the book right now if they wanted to. - if 'notifiedHolds' in availability: - licenses_reserved = int(availability['notifiedHolds']) + if "notifiedHolds" in availability: + licenses_reserved = int(availability["notifiedHolds"]) else: licenses_reserved = 0 # 'patrons_in_hold_queue' contains the number of patrons who are currently waiting for a copy of the book. - if 'holdsQueueSize' in availability: - patrons_in_hold_queue = int(availability['holdsQueueSize']) + if "holdsQueueSize" in availability: + patrons_in_hold_queue = int(availability["holdsQueueSize"]) else: patrons_in_hold_queue = 0 @@ -186,82 +144,106 @@ def record_info_to_metadata(cls, book, availability): Note: The json data passed into this method is from a different file/stream from the json data that goes into the book_info_to_circulation() method. """ - if 'id' not in book: + if "id" not in book: return None - odilo_id = book['id'] + odilo_id = book["id"] primary_identifier = IdentifierData(Identifier.ODILO_ID, odilo_id) - active = book.get('active') + active = book.get("active") - title = book.get('title') - subtitle = book.get('subtitle') - series = book.get('series').strip() or None - series_position = book.get('seriesPosition').strip() or None + title = book.get("title") + subtitle = book.get("subtitle") + series = book.get("series").strip() or None + series_position = book.get("seriesPosition").strip() or None contributors = [] - sort_author = book.get('author') + sort_author = book.get("author") if sort_author: roles = [Contributor.AUTHOR_ROLE] display_author = sort_name_to_display_name(sort_author) contributor = ContributorData( - sort_name=sort_author, display_name=display_author, - roles=roles, biography=None + sort_name=sort_author, + display_name=display_author, + roles=roles, + biography=None, ) contributors.append(contributor) - publisher = book.get('publisher') + publisher = book.get("publisher") # Metadata --> Marc21 260$c - published = book.get('publicationDate') + published = book.get("publicationDate") if not published: # yyyyMMdd --> record creation date - published = book.get('releaseDate') + published = book.get("releaseDate") if published: try: published = strptime_utc(published, "%Y%m%d") except ValueError as e: - cls.log.warn('Cannot parse publication date from: ' + published + ', message: ' + str(e)) + cls.log.warn( + "Cannot parse publication date from: " + + published + + ", message: " + + str(e) + ) # yyyyMMdd --> record last modification date - last_update = book.get('modificationDate') + last_update = book.get("modificationDate") if last_update: try: last_update = strptime_utc(last_update, "%Y%m%d") except ValueError as e: - cls.log.warn('Cannot parse last update date from: ' + last_update + ', message: ' + str(e)) + cls.log.warn( + "Cannot parse last update date from: " + + last_update + + ", message: " + + str(e) + ) - language = book.get('language', 'spa') + language = book.get("language", "spa") subjects = [] trusted_weight = Classification.TRUSTED_DISTRIBUTOR_WEIGHT - for subject in book.get('subjects', []): - subjects.append(SubjectData(type=Subject.TAG, identifier=subject, weight=trusted_weight)) + for subject in book.get("subjects", []): + subjects.append( + SubjectData(type=Subject.TAG, identifier=subject, weight=trusted_weight) + ) - for subjectBisacCode in book.get('subjectsBisacCodes', []): - subjects.append(SubjectData(type=Subject.BISAC, identifier=subjectBisacCode, weight=trusted_weight)) + for subjectBisacCode in book.get("subjectsBisacCodes", []): + subjects.append( + SubjectData( + type=Subject.BISAC, + identifier=subjectBisacCode, + weight=trusted_weight, + ) + ) - grade_level = book.get('gradeLevel') + grade_level = book.get("gradeLevel") if grade_level: - subject = SubjectData(type=Subject.GRADE_LEVEL, identifier=grade_level, weight=trusted_weight) + subject = SubjectData( + type=Subject.GRADE_LEVEL, identifier=grade_level, weight=trusted_weight + ) subjects.append(subject) medium = None - file_format = book.get('fileFormat') + file_format = book.get("fileFormat") formats = [] - for format_received in book.get('formats', []): + for format_received in book.get("formats", []): if format_received in cls.format_data_for_odilo_format: medium = cls.set_format(format_received, formats) elif format_received == cls.ACSM and file_format: - medium = cls.set_format(format_received + '_' + file_format.upper(), formats) + medium = cls.set_format( + format_received + "_" + file_format.upper(), formats + ) else: - cls.log.warn('Unrecognized format received: ' + format_received) + cls.log.warn("Unrecognized format received: " + format_received) if not medium: medium = Edition.BOOK_MEDIUM identifiers = [] - isbn = book.get('isbn') + isbn = book.get("isbn") if isbn: if isbnlib.is_isbn10(isbn): isbn = isbnlib.to_isbn13(isbn) @@ -269,22 +251,30 @@ def record_info_to_metadata(cls, book, availability): # A cover links = [] - cover_image_url = book.get('coverImageUrl') + cover_image_url = book.get("coverImageUrl") if cover_image_url: - image_data = cls.image_link_to_linkdata(cover_image_url, Hyperlink.THUMBNAIL_IMAGE) + image_data = cls.image_link_to_linkdata( + cover_image_url, Hyperlink.THUMBNAIL_IMAGE + ) if image_data: links.append(image_data) - original_image_url = book.get('originalImageUrl') + original_image_url = book.get("originalImageUrl") if original_image_url: image_data = cls.image_link_to_linkdata(original_image_url, Hyperlink.IMAGE) if image_data: links.append(image_data) # Descriptions become links. - description = book.get('description') + description = book.get("description") if description: - links.append(LinkData(rel=Hyperlink.DESCRIPTION, content=description, media_type="text/html")) + links.append( + LinkData( + rel=Hyperlink.DESCRIPTION, + content=description, + media_type="text/html", + ) + ) metadata = Metadata( data_source=DataSource.ODILO, @@ -301,10 +291,12 @@ def record_info_to_metadata(cls, book, availability): subjects=subjects, contributors=contributors, links=links, - data_source_last_updated=last_update + data_source_last_updated=last_update, ) - metadata.circulation = OdiloRepresentationExtractor.record_info_to_circulation(availability) + metadata.circulation = OdiloRepresentationExtractor.record_info_to_circulation( + availability + ) # 'active' --> means that the book exists but it's no longer in the collection # (it could be available again in the future) if metadata.circulation: @@ -328,17 +320,26 @@ class OdiloAPI(BaseCirculationAPI, HasSelfTests): NAME = ExternalIntegration.ODILO DESCRIPTION = _("Integrate an Odilo library collection.") SETTINGS = [ - { - "key": LIBRARY_API_BASE_URL, - "label": _("Library API base URL"), - "description": _("This might look like https://[library].odilo.us/api/v2."), - "required": True, - "format": "url", - }, - { "key": ExternalIntegration.USERNAME, "label": _("Client Key"), "required": True }, - { "key": ExternalIntegration.PASSWORD, "label": _("Client Secret"), "required": True }, - ] + BaseCirculationAPI.SETTINGS - + { + "key": LIBRARY_API_BASE_URL, + "label": _("Library API base URL"), + "description": _( + "This might look like https://[library].odilo.us/api/v2." + ), + "required": True, + "format": "url", + }, + { + "key": ExternalIntegration.USERNAME, + "label": _("Client Key"), + "required": True, + }, + { + "key": ExternalIntegration.PASSWORD, + "label": _("Client Secret"), + "required": True, + }, + ] + BaseCirculationAPI.SETTINGS # --- OAuth --- TOKEN_ENDPOINT = "/token" @@ -366,8 +367,11 @@ class OdiloAPI(BaseCirculationAPI, HasSelfTests): # maps a 2-tuple (media_type, drm_mechanism) to the internal string used in Odilo API to describe that setup. delivery_mechanism_to_internal_format = { - v: k for k, v in list(OdiloRepresentationExtractor.format_data_for_odilo_format.items()) - } + v: k + for k, v in list( + OdiloRepresentationExtractor.format_data_for_odilo_format.items() + ) + } error_to_exception = { "TitleNotCheckedOut": NoActiveLoan, @@ -378,13 +382,14 @@ class OdiloAPI(BaseCirculationAPI, HasSelfTests): } def __init__(self, _db, collection): - self.odilo_bibliographic_coverage_provider = ( - OdiloBibliographicCoverageProvider( - collection, api_class=self - ) + self.odilo_bibliographic_coverage_provider = OdiloBibliographicCoverageProvider( + collection, api_class=self ) if collection.protocol != ExternalIntegration.ODILO: - raise ValueError("Collection protocol is %s, but passed into OdiloAPI!" % collection.protocol) + raise ValueError( + "Collection protocol is %s, but passed into OdiloAPI!" + % collection.protocol + ) self._db = _db self.analytics = Analytics(self._db) @@ -393,22 +398,31 @@ def __init__(self, _db, collection): self.token = None self.client_key = collection.external_integration.username self.client_secret = collection.external_integration.password - self.library_api_base_url = collection.external_integration.setting(self.LIBRARY_API_BASE_URL).value - - if not self.client_key or not self.client_secret or not self.library_api_base_url: + self.library_api_base_url = collection.external_integration.setting( + self.LIBRARY_API_BASE_URL + ).value + + if ( + not self.client_key + or not self.client_secret + or not self.library_api_base_url + ): raise CannotLoadConfiguration("Odilo configuration is incomplete.") # Use utf8 instead of unicode encoding settings = [self.client_key, self.client_secret, self.library_api_base_url] self.client_key, self.client_secret, self.library_api_base_url = ( - setting.encode('utf8') for setting in settings + setting.encode("utf8") for setting in settings ) # Get set up with up-to-date credentials from the API. self.check_creds() if not self.token: - raise CannotLoadConfiguration("Invalid credentials for %s, cannot intialize API %s" - % (self.client_key, self.library_api_base_url)) + raise CannotLoadConfiguration( + "Invalid credentials for %s, cannot intialize API %s" + % (self.client_key, self.library_api_base_url) + ) + @property def collection(self): return Collection.by_id(self._db, id=self.collection_id) @@ -422,8 +436,7 @@ def external_integration(self, _db): def _run_self_tests(self, _db): result = self.run_test( - "Obtaining a sitewide access token", self.check_creds, - force_refresh=True + "Obtaining a sitewide access token", self.check_creds, force_refresh=True ) yield result if not result.success: @@ -437,10 +450,11 @@ def _run_self_tests(self, _db): continue library, patron, pin = result - task = "Viewing the active loans for the test patron for library %s" % library.name - yield self.run_test( - task, self.get_patron_checkouts, patron, pin + task = ( + "Viewing the active loans for the test patron for library %s" + % library.name ) + yield self.run_test(task, self.get_patron_checkouts, patron, pin) def check_creds(self, force_refresh=False): """If the Bearer Token has expired, update it.""" @@ -466,18 +480,17 @@ def refresh_creds(self, credential): response = self.token_post( self.TOKEN_ENDPOINT, dict(grant_type="client_credentials"), - allowed_response_codes=[200, 400] + allowed_response_codes=[200, 400], ) # If you put in the wrong URL, this is where you'll run into # problems, so it's useful to give a helpful error message if # Odilo doesn't provide anything more specific. - generic_error = "%s may not be the right base URL. Response document was: %r" % ( - self.library_api_base_url, response.content.decode("utf-8") - ) - generic_exception = BadResponseException( - self.TOKEN_ENDPOINT, generic_error + generic_error = ( + "%s may not be the right base URL. Response document was: %r" + % (self.library_api_base_url, response.content.decode("utf-8")) ) + generic_exception = BadResponseException(self.TOKEN_ENDPOINT, generic_error) try: data = response.json() @@ -488,36 +501,44 @@ def refresh_creds(self, credential): self.token = credential.credential return elif response.status_code == 400: - if data and 'errors' in data and len(data['errors']) > 0: - error = data['errors'][0] - if 'description' in error: - message = error['description'] + if data and "errors" in data and len(data["errors"]) > 0: + error = data["errors"][0] + if "description" in error: + message = error["description"] else: message = generic_error raise BadResponseException(self.TOKEN_ENDPOINT, message) raise generic_exception - def patron_request(self, patron, pin, url, extra_headers={}, data=None, exception_on_401=False, method=None): + def patron_request( + self, + patron, + pin, + url, + extra_headers={}, + data=None, + exception_on_401=False, + method=None, + ): """Make an HTTP request on behalf of a patron. The results are never cached. """ headers = dict(Authorization="Bearer %s" % self.token) - headers['Content-Type'] = 'application/json' + headers["Content-Type"] = "application/json" headers.update(extra_headers) - if method and method.lower() in ('get', 'post', 'put', 'delete'): + if method and method.lower() in ("get", "post", "put", "delete"): method = method.lower() else: if data: - method = 'post' + method = "post" else: - method = 'get' + method = "get" url = self._make_absolute_url(url) response = HTTP.request_with_timeout( - method, url, headers=headers, data=data, - timeout=60 + method, url, headers=headers, data=data, timeout=60 ) # TODO: If Odilo doesn't recognize the patron it will send @@ -537,8 +558,7 @@ def _make_absolute_url(self, url): """Prepend the API base URL onto `url` unless it is already an absolute HTTP URL. """ - if not any(url.startswith(protocol) - for protocol in ('http://', 'https://')): + if not any(url.startswith(protocol) for protocol in ("http://", "https://")): url = self.library_api_base_url.decode("utf-8") + url return url @@ -548,14 +568,16 @@ def get(self, url, extra_headers={}, exception_on_401=False): extra_headers = {} headers = dict(Authorization="Bearer %s" % self.token) headers.update(extra_headers) - status_code, headers, content = self._do_get(self.library_api_base_url.decode("utf-8") + url, headers) + status_code, headers, content = self._do_get( + self.library_api_base_url.decode("utf-8") + url, headers + ) if status_code == 401: if exception_on_401: # This is our second try. Give up. raise BadResponseException.from_response( url, "Something's wrong with the Odilo OAuth Bearer Token!", - (status_code, headers, content) + (status_code, headers, content), ) else: # Refresh the token and try again. @@ -569,10 +591,11 @@ def token_post(self, url, payload, headers={}, **kwargs): s = "%s:%s" % (self.client_key, self.client_secret) auth = base64.standard_b64encode(s).strip() headers = dict(headers) - headers['Authorization'] = "Basic %s" % auth - headers['Content-Type'] = "application/x-www-form-urlencoded" - return self._do_post(self.library_api_base_url + url, payload, headers, **kwargs) - + headers["Authorization"] = "Basic %s" % auth + headers["Content-Type"] = "application/x-www-form-urlencoded" + return self._do_post( + self.library_api_base_url + url, payload, headers, **kwargs + ) def checkout(self, patron, pin, licensepool, internal_format): """Check out a book on behalf of a patron. @@ -599,46 +622,62 @@ def checkout(self, patron, pin, licensepool, internal_format): ) response = self.patron_request( - patron, pin, self.CHECKOUT_ENDPOINT.format(recordId=record_id), - extra_headers={'Content-Type': 'application/x-www-form-urlencoded'}, - data=payload) + patron, + pin, + self.CHECKOUT_ENDPOINT.format(recordId=record_id), + extra_headers={"Content-Type": "application/x-www-form-urlencoded"}, + data=payload, + ) if response.content: response_json = response.json() if response.status_code == 404: - self.raise_exception_on_error(response_json, default_exception_class=CannotLoan) + self.raise_exception_on_error( + response_json, default_exception_class=CannotLoan + ) else: - return self.loan_info_from_odilo_checkout(licensepool.collection, response_json) + return self.loan_info_from_odilo_checkout( + licensepool.collection, response_json + ) # TODO: we need to improve this at the API and use an error code elif response.status_code == 400: - raise NoAcceptableFormat('record_id: %s, format: %s' % (record_id, internal_format)) + raise NoAcceptableFormat( + "record_id: %s, format: %s" % (record_id, internal_format) + ) - raise CannotLoan('patron: %s, record_id: %s, format: %s' % (patron, record_id, internal_format)) + raise CannotLoan( + "patron: %s, record_id: %s, format: %s" + % (patron, record_id, internal_format) + ) def loan_info_from_odilo_checkout(self, collection, checkout): - start_date = self.extract_date(checkout, 'startTime') - end_date = self.extract_date(checkout, 'endTime') + start_date = self.extract_date(checkout, "startTime") + end_date = self.extract_date(checkout, "endTime") return LoanInfo( collection, DataSource.ODILO, Identifier.ODILO_ID, - checkout['id'], + checkout["id"], start_date, end_date, - checkout['downloadUrl'] + checkout["downloadUrl"], ) def checkin(self, patron, pin, licensepool): record_id = licensepool.identifier.identifier loan = self.get_checkout(patron, pin, record_id) - url = self.CHECKIN_ENDPOINT.format(checkoutId=loan['id'], patronId=patron.authorization_identifier) + url = self.CHECKIN_ENDPOINT.format( + checkoutId=loan["id"], patronId=patron.authorization_identifier + ) - response = self.patron_request(patron, pin, url, method='POST') + response = self.patron_request(patron, pin, url, method="POST") if response.status_code == 200: return response - self.raise_exception_on_error(response.json(), default_exception_class=CannotReturn) + self.raise_exception_on_error( + response.json(), default_exception_class=CannotReturn + ) @classmethod def extract_date(cls, data, field_name): @@ -650,13 +689,15 @@ def extract_date(cls, data, field_name): return d @classmethod - def raise_exception_on_error(cls, data, default_exception_class=None, ignore_exception_codes=None): - if not data or 'errors' not in data or len(data['errors']) <= 0: - return '', '' + def raise_exception_on_error( + cls, data, default_exception_class=None, ignore_exception_codes=None + ): + if not data or "errors" not in data or len(data["errors"]) <= 0: + return "", "" - error = data['errors'][0] - error_code = error['id'] - message = ('description' in error and error['description']) or '' + error = data["errors"][0] + error_code = error["id"] + message = ("description" in error and error["description"]) or "" if not ignore_exception_codes or error_code not in ignore_exception_codes: if error_code in cls.error_to_exception: @@ -667,18 +708,25 @@ def raise_exception_on_error(cls, data, default_exception_class=None, ignore_exc def get_checkout(self, patron, pin, record_id): patron_checkouts = self.get_patron_checkouts(patron, pin) for checkout in patron_checkouts: - if checkout['recordId'] == record_id: + if checkout["recordId"] == record_id: return checkout - raise NotFoundOnRemote("Could not find active loan for patron %s, record %s" % (patron, record_id)) + raise NotFoundOnRemote( + "Could not find active loan for patron %s, record %s" % (patron, record_id) + ) def get_hold(self, patron, pin, record_id): patron_holds = self.get_patron_holds(patron, pin) for hold in patron_holds: - if hold['recordId'] == record_id and hold['status'] in ('informed', 'waiting'): + if hold["recordId"] == record_id and hold["status"] in ( + "informed", + "waiting", + ): return hold - raise NotFoundOnRemote("Could not find active hold for patron %s, record %s" % (patron, record_id)) + raise NotFoundOnRemote( + "Could not find active hold for patron %s, record %s" % (patron, record_id) + ) def fulfill(self, patron, pin, licensepool, internal_format, **kwargs): """Get the actual resource file to the patron. @@ -689,10 +737,15 @@ def fulfill(self, patron, pin, licensepool, internal_format, **kwargs): :return: a FulfillmentInfo object. """ record_id = licensepool.identifier.identifier - content_link, content, content_type = self.get_fulfillment_link(patron, pin, record_id, internal_format) + content_link, content, content_type = self.get_fulfillment_link( + patron, pin, record_id, internal_format + ) if not content_link and not content: - self.log.info("Odilo record_id %s was not available as %s" % (record_id, internal_format)) + self.log.info( + "Odilo record_id %s was not available as %s" + % (record_id, internal_format) + ) else: return FulfillmentInfo( licensepool.collection, @@ -702,26 +755,43 @@ def fulfill(self, patron, pin, licensepool, internal_format, **kwargs): content_link=content_link, content=content, content_type=content_type, - content_expires=None + content_expires=None, ) def get_fulfillment_link(self, patron, pin, record_id, format_type): - """Get the link corresponding to an existing checkout. - """ + """Get the link corresponding to an existing checkout.""" # Retrieve checkout with its download_ulr. It is necessary to generate a download token in our API checkout = self.get_checkout(patron, pin, record_id) - loan_format = checkout['format'] - if format_type and loan_format and ( - format_type == loan_format or - (loan_format == OdiloRepresentationExtractor.ACSM and format_type in (OdiloRepresentationExtractor.ACSM_EPUB, OdiloRepresentationExtractor.ACSM_PDF)) + loan_format = checkout["format"] + if ( + format_type + and loan_format + and ( + format_type == loan_format + or ( + loan_format == OdiloRepresentationExtractor.ACSM + and format_type + in ( + OdiloRepresentationExtractor.ACSM_EPUB, + OdiloRepresentationExtractor.ACSM_PDF, + ) + ) + ) ): - if 'downloadUrl' in checkout and checkout['downloadUrl']: - content_link = checkout['downloadUrl'] + if "downloadUrl" in checkout and checkout["downloadUrl"]: + content_link = checkout["downloadUrl"] content = None - content_type = OdiloRepresentationExtractor.format_data_for_odilo_format[format_type] + content_type = ( + OdiloRepresentationExtractor.format_data_for_odilo_format[ + format_type + ] + ) # Get also .acsm file - if format_type in (OdiloRepresentationExtractor.ACSM_EPUB, OdiloRepresentationExtractor.ACSM_PDF): + if format_type in ( + OdiloRepresentationExtractor.ACSM_EPUB, + OdiloRepresentationExtractor.ACSM_PDF, + ): response = self.patron_request(patron, pin, content_link) if response.status_code == 200: content = response.content @@ -730,16 +800,30 @@ def get_fulfillment_link(self, patron, pin, record_id, format_type): return content_link, content, content_type - raise CannotFulfill("Cannot obtain a download link for patron[%r], record_id[%s], format_type[%s].", patron, - record_id, format_type) + raise CannotFulfill( + "Cannot obtain a download link for patron[%r], record_id[%s], format_type[%s].", + patron, + record_id, + format_type, + ) def get_patron_checkouts(self, patron, pin): - data = self.patron_request(patron, pin, self.PATRON_CHECKOUTS_ENDPOINT.format(patronId=patron.authorization_identifier)).json() + data = self.patron_request( + patron, + pin, + self.PATRON_CHECKOUTS_ENDPOINT.format( + patronId=patron.authorization_identifier + ), + ).json() self.raise_exception_on_error(data) return data def get_patron_holds(self, patron, pin): - data = self.patron_request(patron, pin, self.PATRON_HOLDS_ENDPOINT.format(patronId=patron.authorization_identifier)).json() + data = self.patron_request( + patron, + pin, + self.PATRON_HOLDS_ENDPOINT.format(patronId=patron.authorization_identifier), + ).json() self.raise_exception_on_error(data) return data @@ -763,26 +847,26 @@ def patron_activity(self, patron, pin): return loans_info + holds_info def hold_from_odilo_hold(self, collection, hold): - start = self.extract_date(hold, 'startTime') + start = self.extract_date(hold, "startTime") # end_date: The estimated date the title will be available for the patron to borrow. - end = self.extract_date(hold, 'notifiedTime') - position = hold.get('holdQueuePosition') + end = self.extract_date(hold, "notifiedTime") + position = hold.get("holdQueuePosition") if position is not None: position = int(position) # Patron already notified to borrow the title - if 'informed' == hold['status']: + if "informed" == hold["status"]: position = 0 return HoldInfo( collection, DataSource.ODILO, Identifier.ODILO_ID, - hold['id'], + hold["id"], start_date=start, end_date=end, - hold_position=position + hold_position=position, ) def place_hold(self, patron, pin, licensepool, notification_email_address): @@ -797,9 +881,12 @@ def place_hold(self, patron, pin, licensepool, notification_email_address): payload = dict(patronId=patron.authorization_identifier) response = self.patron_request( - patron, pin, self.PLACE_HOLD_ENDPOINT.format(recordId=record_id), - extra_headers={'Content-Type': 'application/x-www-form-urlencoded'}, - data=payload) + patron, + pin, + self.PLACE_HOLD_ENDPOINT.format(recordId=record_id), + extra_headers={"Content-Type": "application/x-www-form-urlencoded"}, + data=payload, + ) data = response.json() if response.status_code == 200: @@ -808,31 +895,35 @@ def place_hold(self, patron, pin, licensepool, notification_email_address): self.raise_exception_on_error(data, CannotHold) def release_hold(self, patron, pin, licensepool): - """Release a patron's hold on a book. - """ + """Release a patron's hold on a book.""" record_id = licensepool.identifier.identifier hold = self.get_hold(patron, pin, record_id) - url = self.RELEASE_HOLD_ENDPOINT.format(holdId=hold['id']) + url = self.RELEASE_HOLD_ENDPOINT.format(holdId=hold["id"]) payload = json.dumps(dict(patronId=patron.authorization_identifier)) - response = self.patron_request(patron, pin, url, extra_headers={}, data=payload, method='POST') + response = self.patron_request( + patron, pin, url, extra_headers={}, data=payload, method="POST" + ) if response.status_code == 200: return True - self.raise_exception_on_error(response.json(), default_exception_class=CannotReleaseHold, - ignore_exception_codes=['HOLD_NOT_FOUND']) + self.raise_exception_on_error( + response.json(), + default_exception_class=CannotReleaseHold, + ignore_exception_codes=["HOLD_NOT_FOUND"], + ) return True @staticmethod def _update_credential(credential, odilo_data): """Copy Odilo OAuth data into a Credential object.""" - credential.credential = odilo_data['token'] - if odilo_data['expiresIn'] == -1: + credential.credential = odilo_data["token"] + if odilo_data["expiresIn"] == -1: # This token never expires. credential.expires = None else: - expires_in = (odilo_data['expiresIn'] * 0.9) + expires_in = odilo_data["expiresIn"] * 0.9 credential.expires = utc_now() + datetime.timedelta(seconds=expires_in) def get_metadata(self, record_id): @@ -846,9 +937,14 @@ def get_metadata(self, record_id): if status_code == 200 and content: return content else: - msg = 'Cannot retrieve metadata for record: ' + record_id + ' response http ' + status_code + msg = ( + "Cannot retrieve metadata for record: " + + record_id + + " response http " + + status_code + ) if content: - msg += ' content: ' + content + msg += " content: " + content self.log.warn(msg) return None @@ -860,20 +956,25 @@ def get_availability(self, record_id): if status_code == 200 and len(content) > 0: return content else: - msg = 'Cannot retrieve availability for record: ' + record_id + ' response http ' + status_code + msg = ( + "Cannot retrieve availability for record: " + + record_id + + " response http " + + status_code + ) if content: - msg += ' content: ' + content + msg += " content: " + content self.log.warn(msg) return None @staticmethod def _do_get(url, headers, **kwargs): # More time please - if 'timeout' not in kwargs: - kwargs['timeout'] = 60 + if "timeout" not in kwargs: + kwargs["timeout"] = 60 - if 'allow_redirects' not in kwargs: - kwargs['allow_redirects'] = True + if "allow_redirects" not in kwargs: + kwargs["allow_redirects"] = True response = HTTP.get_with_timeout(url, headers=headers, **kwargs) return response.status_code, response.headers, response.content @@ -881,16 +982,15 @@ def _do_get(url, headers, **kwargs): @staticmethod def _do_post(url, payload, headers, **kwargs): # More time please - if 'timeout' not in kwargs: - kwargs['timeout'] = 60 + if "timeout" not in kwargs: + kwargs["timeout"] = 60 return HTTP.post_with_timeout(url, payload, headers=headers, **kwargs) - class OdiloCirculationMonitor(CollectionMonitor, TimelineMonitor): - """Maintain LicensePools for recently changed Odilo titles - """ + """Maintain LicensePools for recently changed Odilo titles""" + SERVICE_NAME = "Odilo Circulation Monitor" INTERVAL_SECONDS = 500 PROTOCOL = ExternalIntegration.ODILO @@ -908,7 +1008,12 @@ def catch_up_from(self, start, cutoff, progress): covered by this Monitor. """ - self.log.info("Starting recently_changed_ids, start: " + str(start) + ", cutoff: " + str(cutoff)) + self.log.info( + "Starting recently_changed_ids, start: " + + str(start) + + ", cutoff: " + + str(cutoff) + ) start_time = utc_now() updated, new = self.all_ids(start) @@ -916,13 +1021,10 @@ def catch_up_from(self, start, cutoff, progress): time_elapsed = finish_time - start_time self.log.info("recently_changed_ids finished in: " + str(time_elapsed)) - progress.achievements = ( - "Updated records: %d. New records: %d." % (updated, new) - ) + progress.achievements = "Updated records: %d. New records: %d." % (updated, new) def all_ids(self, modification_date=None): - """Get IDs for every book in the system, from modification date if any - """ + """Get IDs for every book in the system, from modification date if any""" retrieved = 0 parsed = 0 @@ -931,7 +1033,9 @@ def all_ids(self, modification_date=None): limit = self.api.PAGE_SIZE_LIMIT if modification_date and isinstance(modification_date, datetime.date): - modification_date = modification_date.strftime('%Y-%m-%d') # Format YYYY-MM-DD + modification_date = modification_date.strftime( + "%Y-%m-%d" + ) # Format YYYY-MM-DD # Retrieve first group of records url = self.get_url(limit, modification_date, offset) @@ -942,13 +1046,18 @@ def all_ids(self, modification_date=None): while status_code == 200 and len(content) > 0: offset += limit retrieved += len(content) - self.log.info('Retrieved %i records' % retrieved) + self.log.info("Retrieved %i records" % retrieved) # Process a bunch of records retrieved for record in content: - record_id = record['id'] - self.log.info('Processing record %i/%i: %s' % (parsed, retrieved, record_id)) - identifier, is_new = self.api.odilo_bibliographic_coverage_provider.process_item( + record_id = record["id"] + self.log.info( + "Processing record %i/%i: %s" % (parsed, retrieved, record_id) + ) + ( + identifier, + is_new, + ) = self.api.odilo_bibliographic_coverage_provider.process_item( record_id, record ) @@ -966,11 +1075,17 @@ def all_ids(self, modification_date=None): content = json.loads(content) if status_code >= 400: - self.log.error('ERROR: Fail while retrieving data from remote source: HTTP ' + status_code) + self.log.error( + "ERROR: Fail while retrieving data from remote source: HTTP " + + status_code + ) if content: - self.log.error('ERROR response content: ' + str(content)) + self.log.error("ERROR response content: " + str(content)) else: - self.log.info('Retrieving all ids finished ok. Retrieved %i records. New records: %i!!' % (retrieved, new)) + self.log.info( + "Retrieving all ids finished ok. Retrieved %i records. New records: %i!!" + % (retrieved, new) + ) return retrieved, new def get_url(self, limit, modification_date, offset): @@ -980,6 +1095,7 @@ def get_url(self, limit, modification_date, offset): return url + class MockOdiloAPI(OdiloAPI): def patron_request(self, patron, pin, *args, **kwargs): response = self._make_request(*args, **kwargs) @@ -989,25 +1105,29 @@ def patron_request(self, patron, pin, *args, **kwargs): # The last item in the record of the request is keyword arguments. # Stick this information in there to minimize confusion. - original_data[-1]['_patron'] = patron - original_data[-1]['_pin'] = pin + original_data[-1]["_patron"] = patron + original_data[-1]["_pin"] = pin return response + @classmethod def mock_collection(cls, _db): library = DatabaseTest.make_default_library(_db) collection, ignore = get_one_or_create( - _db, Collection, + _db, + Collection, name="Test Odilo Collection", create_method_kwargs=dict( - external_account_id='library_id_123', - ) + external_account_id="library_id_123", + ), ) integration = collection.create_external_integration( protocol=ExternalIntegration.ODILO ) - integration.username = 'username' - integration.password = 'password' - integration.setting(OdiloAPI.LIBRARY_API_BASE_URL).value = 'http://library_api_base_url/api/v2' + integration.username = "username" + integration.password = "password" + integration.setting( + OdiloAPI.LIBRARY_API_BASE_URL + ).value = "http://library_api_base_url/api/v2" library.collections.append(collection) return collection @@ -1017,12 +1137,11 @@ def __init__(self, _db, collection, *args, **kwargs): self.requests = [] self.responses = [] - self.access_token_response = self.mock_access_token_response('bearer token') + self.access_token_response = self.mock_access_token_response("bearer token") super(MockOdiloAPI, self).__init__(_db, collection, *args, **kwargs) def token_post(self, url, payload, headers={}, **kwargs): - """Mock the request for an OAuth token. - """ + """Mock the request for an OAuth token.""" self.access_token_requests.append((url, payload, headers, kwargs)) response = self.access_token_response @@ -1033,9 +1152,7 @@ def mock_access_token_response(self, credential, expires_in=-1): return MockRequestsResponse(200, {}, json.dumps(token)) def queue_response(self, status_code, headers={}, content=None): - self.responses.insert( - 0, MockRequestsResponse(status_code, headers, content) - ) + self.responses.insert(0, MockRequestsResponse(status_code, headers, content)) def _do_get(self, url, *args, **kwargs): """Simulate Representation.simple_http_get.""" @@ -1049,8 +1166,10 @@ def _make_request(self, url, *args, **kwargs): response = self.responses.pop() self.requests.append((url, args, kwargs)) return HTTP._process_response( - url, response, kwargs.get('allowed_response_codes'), - kwargs.get('disallowed_response_codes') + url, + response, + kwargs.get("allowed_response_codes"), + kwargs.get("disallowed_response_codes"), ) @@ -1075,9 +1194,7 @@ def __init__(self, collection, api_class=OdiloAPI, **kwargs): :param api_class: Instantiate this class with the given Collection, rather than instantiating OdiloAPI. """ - super(OdiloBibliographicCoverageProvider, self).__init__( - collection, **kwargs - ) + super(OdiloBibliographicCoverageProvider, self).__init__(collection, **kwargs) if isinstance(api_class, OdiloAPI): # Use a previously instantiated OdiloAPI instance # rather than creating a new one. @@ -1097,7 +1214,7 @@ def __init__(self, collection, api_class=OdiloAPI, **kwargs): rights=True, link_content=True, # even_if_not_apparently_updated=False, - analytics=Analytics(self._db) + analytics=Analytics(self._db), ) def process_item(self, record_id, record=None): @@ -1105,12 +1222,14 @@ def process_item(self, record_id, record=None): record = self.api.get_metadata(record_id) if not record: - return self.failure(record_id, 'Record not found', transient=False) + return self.failure(record_id, "Record not found", transient=False) # Retrieve availability availability = self.api.get_availability(record_id) - metadata, is_active = OdiloRepresentationExtractor.record_info_to_metadata(record, availability) + metadata, is_active = OdiloRepresentationExtractor.record_info_to_metadata( + record, availability + ) if not metadata: e = "Could not extract metadata from Odilo data: %s" % record_id return self.failure(record_id, e) diff --git a/api/odl.py b/api/odl.py index b9b7ed5dbf..b4c983082f 100644 --- a/api/odl.py +++ b/api/odl.py @@ -15,18 +15,12 @@ from flask_babel import lazy_gettext as _ from lxml import etree from sqlalchemy.sql.expression import or_ -from typing import Optional from uritemplate import URITemplate from core import util from core.analytics import Analytics from core.lcp.credential import LCPCredentialFactory -from core.metadata_layer import ( - CirculationData, - FormatData, - LicenseData, - TimestampData, -) +from core.metadata_layer import CirculationData, FormatData, LicenseData, TimestampData from core.model import ( Collection, ConfigurationSetting, @@ -37,6 +31,7 @@ Hold, Hyperlink, LicensePool, + LicensePoolDeliveryMechanism, Loan, MediaTypes, Representation, @@ -44,20 +39,15 @@ Session, get_one, get_one_or_create, - Representation, LicensePoolDeliveryMechanism) -from core.model.configuration import ConfigurationFactory, ConfigurationStorage, HasExternalIntegration, \ - ConfigurationGrouping, ConfigurationMetadata, ConfigurationAttributeType, ConfigurationOption -from core.monitor import ( - CollectionMonitor, - IdentifierSweepMonitor) -from core.opds_import import ( - OPDSXMLParser, - OPDSImporter, - OPDSImportMonitor, ) -from core.testing import ( - DatabaseTest, - MockRequestsResponse, +from core.model.configuration import ( + ConfigurationAttributeType, + ConfigurationFactory, + ConfigurationGrouping, + ConfigurationMetadata, + ConfigurationOption, + ConfigurationStorage, + HasExternalIntegration, ) from core.monitor import CollectionMonitor, IdentifierSweepMonitor from core.opds_import import OPDSImporter, OPDSImportMonitor, OPDSXMLParser @@ -68,14 +58,14 @@ from .circulation import BaseCirculationAPI, FulfillmentInfo, HoldInfo, LoanInfo from .circulation_exceptions import * -from .lcp.hash import HasherFactory, Hasher, HashingAlgorithm +from .lcp.hash import Hasher, HasherFactory, HashingAlgorithm from .shared_collection import BaseSharedCollectionAPI class ODLAPIConfiguration(ConfigurationGrouping): """Contains LCP License Server's settings""" - DEFAULT_PASSPHRASE_HINT = 'Look in the Palace app.' + DEFAULT_PASSPHRASE_HINT = "Look in the Palace app." DEFAULT_ENCRYPTION_ALGORITHM = HashingAlgorithm.SHA256.value feed_url = ConfigurationMetadata( @@ -84,7 +74,7 @@ class ODLAPIConfiguration(ConfigurationGrouping): description="", type=ConfigurationAttributeType.TEXT, required=True, - format="url" + format="url", ) username = ConfigurationMetadata( @@ -114,28 +104,34 @@ class ODLAPIConfiguration(ConfigurationGrouping): default_reservation_period = ConfigurationMetadata( key=Collection.DEFAULT_RESERVATION_PERIOD_KEY, label=_("Default Reservation Period (in Days)"), - description=_("The number of days a patron has to check out a book after a hold becomes available."), + description=_( + "The number of days a patron has to check out a book after a hold becomes available." + ), type=ConfigurationAttributeType.NUMBER, required=False, - default=Collection.STANDARD_DEFAULT_RESERVATION_PERIOD + default=Collection.STANDARD_DEFAULT_RESERVATION_PERIOD, ) passphrase_hint = ConfigurationMetadata( key="passphrase_hint", label=_("Passphrase hint"), - description=_("Hint displayed to the user when opening an LCP protected publication."), + description=_( + "Hint displayed to the user when opening an LCP protected publication." + ), type=ConfigurationAttributeType.TEXT, required=False, - default=DEFAULT_PASSPHRASE_HINT + default=DEFAULT_PASSPHRASE_HINT, ) passphrase_hint_url = ConfigurationMetadata( key="passphrase_hint_url", label=_("Passphrase hint URL"), - description=_("Hint URL available to the user when opening an LCP protected publication."), + description=_( + "Hint URL available to the user when opening an LCP protected publication." + ), type=ConfigurationAttributeType.TEXT, required=False, - format="url" + format="url", ) encryption_algorithm = ConfigurationMetadata( @@ -145,7 +141,7 @@ class ODLAPIConfiguration(ConfigurationGrouping): type=ConfigurationAttributeType.SELECT, required=False, default=DEFAULT_ENCRYPTION_ALGORITHM, - options=ConfigurationOption.from_enum(HashingAlgorithm) + options=ConfigurationOption.from_enum(HashingAlgorithm), ) @@ -164,7 +160,9 @@ class ODLAPI(BaseCirculationAPI, BaseSharedCollectionAPI, HasExternalIntegration """ NAME = ExternalIntegration.ODL - DESCRIPTION = _("Import books from a distributor that uses ODL (Open Distribution to Libraries).") + DESCRIPTION = _( + "Import books from a distributor that uses ODL (Open Distribution to Libraries)." + ) SETTINGS = BaseSharedCollectionAPI.SETTINGS + ODLAPIConfiguration.to_settings() @@ -206,11 +204,13 @@ class ODLAPI(BaseCirculationAPI, BaseSharedCollectionAPI, HasExternalIntegration def __init__(self, _db, collection): if collection.protocol != self.NAME: raise ValueError( - "Collection protocol is %s, but passed into ODLAPI!" % - collection.protocol + "Collection protocol is %s, but passed into ODLAPI!" + % collection.protocol ) self.collection_id = collection.id - self.data_source_name = collection.external_integration.setting(Collection.DATA_SOURCE_NAME_SETTING).value + self.data_source_name = collection.external_integration.setting( + Collection.DATA_SOURCE_NAME_SETTING + ).value # Create the data source if it doesn't exist yet. DataSource.lookup(_db, self.data_source_name, autocreate=True) @@ -224,7 +224,9 @@ def __init__(self, _db, collection): self._credential_factory = LCPCredentialFactory() self._hasher_instance: Optional[Hasher] = None - def external_integration(self, db: sqlalchemy.orm.session.Session) -> ExternalIntegration: + def external_integration( + self, db: sqlalchemy.orm.session.Session + ) -> ExternalIntegration: """Return an external integration associated with this object. :param db: Database session @@ -259,7 +261,8 @@ def _get_hasher(self, configuration): self._hasher_instance = self._hasher_factory.create( configuration.encryption_algorithm if configuration.encryption_algorithm - else ODLAPIConfiguration.DEFAULT_ENCRYPTION_ALGORITHM) + else ODLAPIConfiguration.DEFAULT_ENCRYPTION_ALGORITHM + ) return self._hasher_instance @@ -272,13 +275,12 @@ def _get(self, url, headers=None): password = self.password headers = dict(headers or {}) auth_header = "Basic %s" % base64.b64encode("%s:%s" % (username, password)) - headers['Authorization'] = auth_header + headers["Authorization"] = auth_header return HTTP.get_with_timeout(url, headers=headers) def _url_for(self, *args, **kwargs): - """Wrapper around flask's url_for to be overridden for tests. - """ + """Wrapper around flask's url_for to be overridden for tests.""" return url_for(*args, **kwargs) def get_license_status_document(self, loan): @@ -306,10 +308,8 @@ def get_license_status_document(self, loan): # TODO: should integration clients be able to specify their own loan period? default_loan_period = self.collection(_db).default_loan_period( loan.integration_client - ) - expires = utc_now() + datetime.timedelta( - days=default_loan_period - ) + ) + expires = utc_now() + datetime.timedelta(days=default_loan_period) # The patron UUID is generated randomly on each loan, so the distributor # doesn't know when multiple loans come from the same patron. patron_id = str(uuid.uuid1()) @@ -324,12 +324,19 @@ def get_license_status_document(self, loan): patron = loan.patron with self._configuration_factory.create( - self._configuration_storage, db, ODLAPIConfiguration) as configuration: + self._configuration_storage, db, ODLAPIConfiguration + ) as configuration: hasher = self._get_hasher(configuration) - hashed_passphrase = hasher.hash(self._credential_factory.get_patron_passphrase(db, patron)) - encoded_passphrase = base64.b64encode(binascii.unhexlify(hashed_passphrase)) + hashed_passphrase = hasher.hash( + self._credential_factory.get_patron_passphrase(db, patron) + ) + encoded_passphrase = base64.b64encode( + binascii.unhexlify(hashed_passphrase) + ) - self._credential_factory.set_hashed_passphrase(db, patron, hashed_passphrase) + self._credential_factory.set_hashed_passphrase( + db, patron, hashed_passphrase + ) notification_url = self._url_for( "odl_notify", @@ -347,7 +354,7 @@ def get_license_status_document(self, loan): notification_url=notification_url, passphrase=encoded_passphrase, hint=configuration.passphrase_hint, - hint_url=configuration.passphrase_hint_url + hint_url=configuration.passphrase_hint_url, ) response = self._get(url) @@ -355,19 +362,23 @@ def get_license_status_document(self, loan): try: status_doc = json.loads(response.content) except ValueError as e: - raise BadResponseException(url, "License Status Document was not valid JSON.") + raise BadResponseException( + url, "License Status Document was not valid JSON." + ) if status_doc.get("status") not in self.STATUS_VALUES: - raise BadResponseException(url, "License Status Document had an unknown status value.") + raise BadResponseException( + url, "License Status Document had an unknown status value." + ) return status_doc def checkin(self, patron, pin, licensepool): """Return a loan early.""" _db = Session.object_session(patron) - loan = _db.query(Loan).filter( - Loan.patron==patron - ).filter( - Loan.license_pool_id==licensepool.id + loan = ( + _db.query(Loan) + .filter(Loan.patron == patron) + .filter(Loan.license_pool_id == licensepool.id) ) if loan.count() < 1: raise NotCheckedOut() @@ -378,7 +389,12 @@ def _checkin(self, loan): _db = Session.object_session(loan) doc = self.get_license_status_document(loan) status = doc.get("status") - if status in [self.REVOKED_STATUS, self.RETURNED_STATUS, self.CANCELLED_STATUS, self.EXPIRED_STATUS]: + if status in [ + self.REVOKED_STATUS, + self.RETURNED_STATUS, + self.CANCELLED_STATUS, + self.EXPIRED_STATUS, + ]: # This loan was already returned early or revoked by the distributor, or it expired. self.update_loan(loan, doc) raise NotCheckedOut() @@ -417,10 +433,10 @@ def checkout(self, patron, pin, licensepool, internal_format): """Create a new loan.""" _db = Session.object_session(patron) - loan = _db.query(Loan).filter( - Loan.patron==patron - ).filter( - Loan.license_pool_id==licensepool.id + loan = ( + _db.query(Loan) + .filter(Loan.patron == patron) + .filter(Loan.license_pool_id == licensepool.id) ) if loan.count() > 0: raise AlreadyCheckedOut() @@ -452,11 +468,9 @@ def _checkout(self, patron_or_client, licensepool, hold=None): # If there's a holds queue, the patron or client must have a non-expired hold # with position 0 to check out the book. - if ((not hold or - hold.position > 0 or - (hold.end and hold.end < utc_now())) and - licensepool.licenses_available < 1 - ): + if ( + not hold or hold.position > 0 or (hold.end and hold.end < utc_now()) + ) and licensepool.licenses_available < 1: raise NoAvailableCopies() # Create a local loan so its database id can be used to @@ -517,18 +531,18 @@ def fulfill(self, patron, pin, licensepool, internal_format, **kwargs): """ _db = Session.object_session(patron) - loan = _db.query(Loan).filter( - Loan.patron==patron - ).filter( - Loan.license_pool_id==licensepool.id + loan = ( + _db.query(Loan) + .filter(Loan.patron == patron) + .filter(Loan.license_pool_id == licensepool.id) ) loan = loan.one() return self._fulfill(loan, internal_format) def _fulfill( - self, - loan: Loan, - delivery_mechanism: Optional[LicensePoolDeliveryMechanism] = None + self, + loan: Loan, + delivery_mechanism: Optional[LicensePoolDeliveryMechanism] = None, ) -> FulfillmentInfo: licensepool = loan.license_pool doc = self.get_license_status_document(loan) @@ -573,8 +587,12 @@ def _fulfill( elif isinstance(delivery_mechanism, LicensePoolDeliveryMechanism): # If we have a LicensePoolDeliveryMechanism instance, # then we use it to find a link with the correct content and DRM types. - if candidate_content_type == delivery_mechanism.delivery_mechanism.drm_scheme or \ - candidate_content_type == delivery_mechanism.delivery_mechanism.content_type: + if ( + candidate_content_type + == delivery_mechanism.delivery_mechanism.drm_scheme + or candidate_content_type + == delivery_mechanism.delivery_mechanism.content_type + ): content_link = candidate_content_link content_type = candidate_content_type break @@ -594,17 +612,19 @@ def _count_holds_before(self, hold): # Count holds on the license pool that started before this hold and # aren't expired. _db = Session.object_session(hold) - return _db.query(Hold).filter( - Hold.license_pool_id==hold.license_pool_id - ).filter( - Hold.startutc_now(), - Hold.position>0, + return ( + _db.query(Hold) + .filter(Hold.license_pool_id == hold.license_pool_id) + .filter(Hold.start < hold.start) + .filter( + or_( + Hold.end == None, + Hold.end > utc_now(), + Hold.position > 0, + ) ) - ).count() + .count() + ) def _update_hold_end_date(self, hold): _db = Session.object_session(hold) @@ -631,24 +651,29 @@ def _update_hold_end_date(self, hold): # but we're still calculating the worst case. elif hold.position > 0: # Find the current loans and reserved holds for the licenses. - current_loans = _db.query(Loan).filter( - Loan.license_pool_id==pool.id - ).filter( - or_( - Loan.end==None, - Loan.end>utc_now() - ) - ).order_by(Loan.start).all() - current_holds = _db.query(Hold).filter( - Hold.license_pool_id==pool.id - ).filter( - or_( - Hold.end==None, - Hold.end>utc_now(), - Hold.position>0, + current_loans = ( + _db.query(Loan) + .filter(Loan.license_pool_id == pool.id) + .filter(or_(Loan.end == None, Loan.end > utc_now())) + .order_by(Loan.start) + .all() + ) + current_holds = ( + _db.query(Hold) + .filter(Hold.license_pool_id == pool.id) + .filter( + or_( + Hold.end == None, + Hold.end > utc_now(), + Hold.position > 0, + ) ) - ).order_by(Hold.start).all() - licenses_reserved = min(pool.licenses_owned - len(current_loans), len(current_holds)) + .order_by(Hold.start) + .all() + ) + licenses_reserved = min( + pool.licenses_owned - len(current_loans), len(current_holds) + ) current_reservations = current_holds[:licenses_reserved] # The licenses will have to go through some number of cycles @@ -659,7 +684,7 @@ def _update_hold_end_date(self, hold): # Each of the owned licenses is currently either on loan or reserved. # Figure out which license this hold will eventually get if every # patron keeps their loans and holds for the maximum time. - copy_index = (hold.position - licenses_reserved - 1) % pool.licenses_owned + copy_index = (hold.position - licenses_reserved - 1) % pool.licenses_owned # In the worse case, the first cycle ends when a current loan expires, or # after a current reservation is checked out and then expires. @@ -667,11 +692,15 @@ def _update_hold_end_date(self, hold): next_cycle_start = current_loans[copy_index].end else: reservation = current_reservations[copy_index - len(current_loans)] - next_cycle_start = reservation.end + datetime.timedelta(days=default_loan_period) + next_cycle_start = reservation.end + datetime.timedelta( + days=default_loan_period + ) # Assume all cycles after the first cycle take the maximum time. cycle_period = default_loan_period + default_reservation_period - hold.end = next_cycle_start + datetime.timedelta(days=(cycle_period * cycles)) + hold.end = next_cycle_start + datetime.timedelta( + days=(cycle_period * cycles) + ) # If the end date isn't set yet or the position just became 0, the # hold just became available. The patron's reservation period starts now. @@ -681,14 +710,14 @@ def _update_hold_end_date(self, hold): def _update_hold_position(self, hold): _db = Session.object_session(hold) pool = hold.license_pool - loans_count = _db.query(Loan).filter( - Loan.license_pool_id==pool.id, - ).filter( - or_( - Loan.end==None, - Loan.end > utc_now() + loans_count = ( + _db.query(Loan) + .filter( + Loan.license_pool_id == pool.id, ) - ).count() + .filter(or_(Loan.end == None, Loan.end > utc_now())) + .count() + ) holds_count = self._count_holds_before(hold) remaining_licenses = pool.licenses_owned - loans_count @@ -705,27 +734,27 @@ def update_hold_queue(self, licensepool): # Update the pool and the next holds in the queue when a license is reserved. _db = Session.object_session(licensepool) - loans_count = _db.query(Loan).filter( - Loan.license_pool_id==licensepool.id - ).filter( - or_( - Loan.end==None, - Loan.end>utc_now() - ) - ).count() + loans_count = ( + _db.query(Loan) + .filter(Loan.license_pool_id == licensepool.id) + .filter(or_(Loan.end == None, Loan.end > utc_now())) + .count() + ) remaining_licenses = max(licensepool.licenses_owned - loans_count, 0) - holds = _db.query(Hold).filter( - Hold.license_pool_id==licensepool.id - ).filter( - or_( - Hold.end==None, - Hold.end>utc_now(), - Hold.position>0, + holds = ( + _db.query(Hold) + .filter(Hold.license_pool_id == licensepool.id) + .filter( + or_( + Hold.end == None, + Hold.end > utc_now(), + Hold.position > 0, + ) ) - ).order_by( - Hold.start - ).all() + .order_by(Hold.start) + .all() + ) if len(holds) > remaining_licenses: new_licenses_available = 0 @@ -744,7 +773,7 @@ def update_hold_queue(self, licensepool): as_of=utc_now(), ) - for hold in holds[:licensepool.licenses_reserved]: + for hold in holds[: licensepool.licenses_reserved]: if hold.position != 0: # This hold just got a reserved license. self._update_hold_end_date(hold) @@ -786,7 +815,8 @@ def release_hold(self, patron, pin, licensepool): _db = Session.object_session(patron) hold = get_one( - _db, Hold, + _db, + Hold, license_pool_id=licensepool.id, patron=patron, ) @@ -810,20 +840,21 @@ def _release_hold(self, hold): def patron_activity(self, patron, pin): """Look up non-expired loans for this collection in the database.""" _db = Session.object_session(patron) - loans = _db.query(Loan).join(Loan.license_pool).filter( - LicensePool.collection_id==self.collection_id - ).filter( - Loan.patron==patron - ).filter( - Loan.end>=utc_now() + loans = ( + _db.query(Loan) + .join(Loan.license_pool) + .filter(LicensePool.collection_id == self.collection_id) + .filter(Loan.patron == patron) + .filter(Loan.end >= utc_now()) ) # Get the patron's holds. If there are any expired holds, delete them. # Update the end date and position for the remaining holds. - holds = _db.query(Hold).join(Hold.license_pool).filter( - LicensePool.collection_id==self.collection_id - ).filter( - Hold.patron==patron + holds = ( + _db.query(Hold) + .join(Hold.license_pool) + .filter(LicensePool.collection_id == self.collection_id) + .filter(Hold.patron == patron) ) remaining_holds = [] for hold in holds: @@ -843,7 +874,8 @@ def patron_activity(self, patron, pin): loan.start, loan.end, external_identifier=loan.external_identifier, - ) for loan in loans + ) + for loan in loans ] + [ HoldInfo( hold.license_pool.collection, @@ -853,7 +885,8 @@ def patron_activity(self, patron, pin): start_date=hold.start, end_date=hold.end, hold_position=hold.position, - ) for hold in remaining_holds + ) + for hold in remaining_holds ] def update_loan(self, loan, status_doc=None): @@ -869,9 +902,16 @@ def update_loan(self, loan, status_doc=None): # We already check that the status is valid in get_license_status_document, # but if the document came from a notification it hasn't been checked yet. if status not in self.STATUS_VALUES: - raise BadResponseException("The License Status Document had an unknown status value.") + raise BadResponseException( + "The License Status Document had an unknown status value." + ) - if status in [self.REVOKED_STATUS, self.RETURNED_STATUS, self.CANCELLED_STATUS, self.EXPIRED_STATUS]: + if status in [ + self.REVOKED_STATUS, + self.RETURNED_STATUS, + self.CANCELLED_STATUS, + self.EXPIRED_STATUS, + ]: # This loan is no longer active. Update the pool's availability # and delete the loan. @@ -896,8 +936,7 @@ def release_hold_from_external_library(self, client, hold): class ODLXMLParser(OPDSXMLParser): - NAMESPACES = dict(OPDSXMLParser.NAMESPACES, - odl="http://opds-spec.org/odl") + NAMESPACES = dict(OPDSXMLParser.NAMESPACES, odl="http://opds-spec.org/odl") class ODLImporter(OPDSImporter): @@ -906,23 +945,24 @@ class ODLImporter(OPDSImporter): The only change from OPDSImporter is that this importer extracts format information from 'odl:license' tags. """ + NAME = ODLAPI.NAME PARSER_CLASS = ODLXMLParser # The media type for a License Info Docuemnt, used to get information # about the license. - LICENSE_INFO_DOCUMENT_MEDIA_TYPE = 'application/vnd.odl.info+json' + LICENSE_INFO_DOCUMENT_MEDIA_TYPE = "application/vnd.odl.info+json" @classmethod def parse_license( - cls, - identifier: str, - total_checkouts: Optional[int], - concurrent_checkouts: Optional[int], - expires: Optional[datetime.datetime], - checkout_link: Optional[str], - odl_status_link: Optional[str], - do_get: Callable + cls, + identifier: str, + total_checkouts: Optional[int], + concurrent_checkouts: Optional[int], + expires: Optional[datetime.datetime], + checkout_link: Optional[str], + odl_status_link: Optional[str], + do_get: Callable, ) -> Optional[LicenseData]: """Check the license's attributes passed as parameters: - if they're correct, turn them into a LicenseData object @@ -969,9 +1009,7 @@ def parse_license( break if odl_status_link: - status_code, _, response = do_get( - odl_status_link, headers={} - ) + status_code, _, response = do_get(odl_status_link, headers={}) if status_code in (200, 201): status = json.loads(response) @@ -1013,7 +1051,9 @@ def parse_license( return None @classmethod - def _detail_for_elementtree_entry(cls, parser, entry_tag, feed_url=None, do_get=None): + def _detail_for_elementtree_entry( + cls, parser, entry_tag, feed_url=None, do_get=None + ): do_get = do_get or Representation.cautious_http_get # TODO: Review for consistency when updated ODL spec is ready. @@ -1024,11 +1064,11 @@ def _detail_for_elementtree_entry(cls, parser, entry_tag, feed_url=None, do_get= licenses_owned = 0 licenses_available = 0 - odl_license_tags = parser._xpath(entry_tag, 'odl:license') or [] + odl_license_tags = parser._xpath(entry_tag, "odl:license") or [] medium = None for odl_license_tag in odl_license_tags: - identifier = subtag(odl_license_tag, 'dcterms:identifier') - full_content_type = subtag(odl_license_tag, 'dcterms:format') + identifier = subtag(odl_license_tag, "dcterms:identifier") + full_content_type = subtag(odl_license_tag, "dcterms:format") if not medium: medium = Edition.medium_from_media_type(full_content_type) @@ -1040,9 +1080,9 @@ def _detail_for_elementtree_entry(cls, parser, entry_tag, feed_url=None, do_get= # But it may instead describe an audiobook protected with # the Feedbooks access-control scheme. - feedbooks_audio = "%s; protection=%s" % ( + feedbooks_audio = "%s; protection=%s" % ( MediaTypes.AUDIOBOOK_MANIFEST_MEDIA_TYPE, - DeliveryMechanism.FEEDBOOKS_AUDIOBOOK_DRM + DeliveryMechanism.FEEDBOOKS_AUDIOBOOK_DRM, ) if full_content_type == feedbooks_audio: content_type = MediaTypes.AUDIOBOOK_MANIFEST_MEDIA_TYPE @@ -1050,15 +1090,13 @@ def _detail_for_elementtree_entry(cls, parser, entry_tag, feed_url=None, do_get= # Additional DRM schemes may be described in # tags. - protection_tags = parser._xpath( - odl_license_tag, 'odl:protection' - ) or [] + protection_tags = parser._xpath(odl_license_tag, "odl:protection") or [] for protection_tag in protection_tags: - drm_scheme = subtag(protection_tag, 'dcterms:format') + drm_scheme = subtag(protection_tag, "dcterms:format") if drm_scheme: drm_schemes.append(drm_scheme) - for drm_scheme in (drm_schemes or [None]): + for drm_scheme in drm_schemes or [None]: formats.append( FormatData( content_type=content_type, @@ -1067,10 +1105,10 @@ def _detail_for_elementtree_entry(cls, parser, entry_tag, feed_url=None, do_get= ) ) - data['medium'] = medium + data["medium"] = medium checkout_link = None - for link_tag in parser._xpath(odl_license_tag, 'odl:tlink') or []: + for link_tag in parser._xpath(odl_license_tag, "odl:tlink") or []: rel = link_tag.attrib.get("rel") if rel == Hyperlink.BORROW: checkout_link = link_tag.attrib.get("href") @@ -1078,12 +1116,13 @@ def _detail_for_elementtree_entry(cls, parser, entry_tag, feed_url=None, do_get= # Look for a link to the License Info Document for this license. odl_status_link = None - for link_tag in parser._xpath(odl_license_tag, 'atom:link') or []: + for link_tag in parser._xpath(odl_license_tag, "atom:link") or []: attrib = link_tag.attrib rel = attrib.get("rel") type = attrib.get("type", "") - if (rel == 'self' - and type.startswith(cls.LICENSE_INFO_DOCUMENT_MEDIA_TYPE)): + if rel == "self" and type.startswith( + cls.LICENSE_INFO_DOCUMENT_MEDIA_TYPE + ): odl_status_link = attrib.get("href") break @@ -1104,7 +1143,7 @@ def _detail_for_elementtree_entry(cls, parser, entry_tag, feed_url=None, do_get= expires, checkout_link, odl_status_link, - do_get + do_get, ) if not license: @@ -1115,24 +1154,26 @@ def _detail_for_elementtree_entry(cls, parser, entry_tag, feed_url=None, do_get= licenses.append(license) - if not data.get('circulation'): - data['circulation'] = dict() - if not data['circulation'].get('formats'): - data['circulation']['formats'] = [] - data['circulation']['formats'].extend(formats) - if not data['circulation'].get('licenses'): - data['circulation']['licenses'] = [] - data['circulation']['licenses'].extend(licenses) - data['circulation']['licenses_owned'] = licenses_owned - data['circulation']['licenses_available'] = licenses_available + if not data.get("circulation"): + data["circulation"] = dict() + if not data["circulation"].get("formats"): + data["circulation"]["formats"] = [] + data["circulation"]["formats"].extend(formats) + if not data["circulation"].get("licenses"): + data["circulation"]["licenses"] = [] + data["circulation"]["licenses"].extend(licenses) + data["circulation"]["licenses_owned"] = licenses_owned + data["circulation"]["licenses_available"] = licenses_available return data class ODLImportMonitor(OPDSImportMonitor): """Import information from an ODL feed.""" + PROTOCOL = ODLImporter.NAME SERVICE_NAME = "ODL Import Monitor" + class ODLHoldReaper(CollectionMonitor): """Check for holds that have expired and delete them, and update the holds queues for their pools.""" @@ -1146,14 +1187,12 @@ def __init__(self, _db, collection=None, api=None, **kwargs): def run_once(self, progress): # Find holds that have expired. - expired_holds = self._db.query(Hold).join( - Hold.license_pool - ).filter( - LicensePool.collection_id==self.api.collection_id - ).filter( - Hold.end 0: raise AlreadyCheckedOut() - holds = _db.query(Hold).filter( - Hold.patron==patron - ).filter( - Hold.license_pool_id==licensepool.id + holds = ( + _db.query(Hold) + .filter(Hold.patron == patron) + .filter(Hold.license_pool_id == licensepool.id) ) if holds.count() > 0: hold = holds.one() @@ -1334,23 +1397,36 @@ def checkout(self, patron, pin, licensepool, internal_format): availability = entry.get("opds_availability", {}) if availability.get("status") != "ready": raise NoAvailableCopies() - checkout_links = [link for link in entry.get("links") if link.get("rel") == Hyperlink.BORROW] + checkout_links = [ + link + for link in entry.get("links") + if link.get("rel") == Hyperlink.BORROW + ] if len(checkout_links) < 1: raise NoAvailableCopies() checkout_url = checkout_links[0].get("href") else: - borrow_links = [link for link in licensepool.identifier.links if link.rel == Hyperlink.BORROW] + borrow_links = [ + link + for link in licensepool.identifier.links + if link.rel == Hyperlink.BORROW + ] if not borrow_links: raise CannotLoan() checkout_url = borrow_links[0].resource.url try: - response = self._get(checkout_url, allowed_response_codes=["2xx", "3xx", "403", "404"]) + response = self._get( + checkout_url, allowed_response_codes=["2xx", "3xx", "403", "404"] + ) except RemoteIntegrationException as e: raise CannotLoan() if response.status_code == 403: raise NoAvailableCopies() elif response.status_code == 404: - if hasattr(response, 'json') and response.json().get('type', '') == NO_LICENSES.uri: + if ( + hasattr(response, "json") + and response.json().get("type", "") == NO_LICENSES.uri + ): raise NoLicenses() feed = self._parse_feed_from_response(response) @@ -1375,7 +1451,7 @@ def checkout(self, patron, pin, licensepool, internal_format): licensepool.identifier.identifier, start, end, - external_identifier=external_identifier + external_identifier=external_identifier, ) elif availability.get("status") in ["ready", "reserved"]: # We tried to borrow this book but it wasn't available, @@ -1391,7 +1467,7 @@ def checkout(self, patron, pin, licensepool, internal_format): start, end, hold_position=position, - external_identifier=external_identifier + external_identifier=external_identifier, ) else: # We didn't get an error, but something went wrong and we don't have a @@ -1401,10 +1477,10 @@ def checkout(self, patron, pin, licensepool, internal_format): def checkin(self, patron, pin, licensepool): _db = Session.object_session(patron) - loan = _db.query(Loan).filter( - Loan.patron==patron - ).filter( - Loan.license_pool_id==licensepool.id + loan = ( + _db.query(Loan) + .filter(Loan.patron == patron) + .filter(Loan.license_pool_id == licensepool.id) ) if loan.count() < 1: raise NotCheckedOut() @@ -1423,7 +1499,11 @@ def checkin(self, patron, pin, licensepool): if len(entries) < 1: raise CannotReturn() entry = entries[0] - revoke_links = [link for link in entry.get("links") if link.get("rel") == "http://librarysimplified.org/terms/rel/revoke"] + revoke_links = [ + link + for link in entry.get("links") + if link.get("rel") == "http://librarysimplified.org/terms/rel/revoke" + ] if len(revoke_links) < 1: raise CannotReturn() revoke_url = revoke_links[0].get("href") @@ -1443,10 +1523,10 @@ def fulfill(self, patron, pin, licensepool, internal_format, **kwargs): """ _db = Session.object_session(patron) - loan = _db.query(Loan).filter( - Loan.patron==patron - ).filter( - Loan.license_pool_id==licensepool.id + loan = ( + _db.query(Loan) + .filter(Loan.patron == patron) + .filter(Loan.license_pool_id == licensepool.id) ) if loan.count() < 1: raise NotCheckedOut() @@ -1480,7 +1560,9 @@ def fulfill(self, patron, pin, licensepool, internal_format, **kwargs): parser = SharedODLImporter.PARSER_CLASS() root = etree.parse(StringIO(response_content)) - fulfill_url = SharedODLImporter.get_fulfill_url(response_content, requested_content_type, requested_drm_scheme) + fulfill_url = SharedODLImporter.get_fulfill_url( + response_content, requested_content_type, requested_drm_scheme + ) if not fulfill_url: raise FormatNotAvailable() @@ -1509,7 +1591,8 @@ def release_hold(self, patron, pin, licensepool): _db = Session.object_session(patron) hold = get_one( - _db, Hold, + _db, + Hold, license_pool_id=licensepool.id, patron=patron, ) @@ -1532,7 +1615,11 @@ def release_hold(self, patron, pin, licensepool): availability = entry.get("opds_availability", {}) if availability.get("status") not in ["reserved", "ready"]: raise CannotReleaseHold() - revoke_links = [link for link in entry.get("links") if link.get("rel") == "http://librarysimplified.org/terms/rel/revoke"] + revoke_links = [ + link + for link in entry.get("links") + if link.get("rel") == "http://librarysimplified.org/terms/rel/revoke" + ] if len(revoke_links) < 1: raise CannotReleaseHold() revoke_url = revoke_links[0].get("href") @@ -1544,22 +1631,26 @@ def release_hold(self, patron, pin, licensepool): def patron_activity(self, patron, pin): _db = Session.object_session(patron) - loans = _db.query(Loan).join(Loan.license_pool).filter( - LicensePool.collection_id==self.collection_id - ).filter( - Loan.patron==patron + loans = ( + _db.query(Loan) + .join(Loan.license_pool) + .filter(LicensePool.collection_id == self.collection_id) + .filter(Loan.patron == patron) ) - holds = _db.query(Hold).join(Hold.license_pool).filter( - LicensePool.collection_id==self.collection_id - ).filter( - Hold.patron==patron + holds = ( + _db.query(Hold) + .join(Hold.license_pool) + .filter(LicensePool.collection_id == self.collection_id) + .filter(Hold.patron == patron) ) activity = [] for loan in loans: info_url = loan.external_identifier - response = self._get(info_url, patron=patron, allowed_response_codes=["2xx", "3xx", "404"]) + response = self._get( + info_url, patron=patron, allowed_response_codes=["2xx", "3xx", "404"] + ) if response.status_code == 404: # 404 is returned when the loan has been deleted. Leave this loan out of the result. continue @@ -1588,7 +1679,9 @@ def patron_activity(self, patron, pin): ) for hold in holds: info_url = hold.external_identifier - response = self._get(info_url, patron=patron, allowed_response_codes=["2xx", "3xx", "404"]) + response = self._get( + info_url, patron=patron, allowed_response_codes=["2xx", "3xx", "404"] + ) if response.status_code == 404: # 404 is returned when the hold has been deleted. Leave this hold out of the result. continue @@ -1632,52 +1725,63 @@ def get_fulfill_url(cls, entry, requested_content_type, requested_drm_scheme): root = etree.parse(StringIO(entry)) fulfill_url = None - for link_tag in parser._xpath(root, 'atom:link'): + for link_tag in parser._xpath(root, "atom:link"): if link_tag.attrib.get("rel") == Hyperlink.GENERIC_OPDS_ACQUISITION: content_type = None drm_scheme = link_tag.attrib.get("type") - indirect_acquisition = parser._xpath(link_tag, "opds:indirectAcquisition") + indirect_acquisition = parser._xpath( + link_tag, "opds:indirectAcquisition" + ) if indirect_acquisition: content_type = indirect_acquisition[0].get("type") else: content_type = drm_scheme drm_scheme = None - if content_type == requested_content_type and drm_scheme == requested_drm_scheme: + if ( + content_type == requested_content_type + and drm_scheme == requested_drm_scheme + ): fulfill_url = link_tag.attrib.get("href") break return fulfill_url - @classmethod - def _detail_for_elementtree_entry(cls, parser, entry_tag, feed_url=None, do_get=None): + def _detail_for_elementtree_entry( + cls, parser, entry_tag, feed_url=None, do_get=None + ): data = OPDSImporter._detail_for_elementtree_entry(parser, entry_tag, feed_url) - borrow_links = [link for link in data.get("links") if link.rel == Hyperlink.BORROW] + borrow_links = [ + link for link in data.get("links") if link.rel == Hyperlink.BORROW + ] licenses_available = 0 licenses_owned = 0 patrons_in_hold_queue = 0 formats = [] - for link_tag in parser._xpath(entry_tag, 'atom:link'): + for link_tag in parser._xpath(entry_tag, "atom:link"): if link_tag.attrib.get("rel") == Hyperlink.BORROW: content_type = None drm_scheme = None - indirect_acquisition = parser._xpath(link_tag, "opds:indirectAcquisition") + indirect_acquisition = parser._xpath( + link_tag, "opds:indirectAcquisition" + ) if indirect_acquisition: drm_scheme = indirect_acquisition[0].attrib.get("type") - second_indirect_acquisition = parser._xpath(indirect_acquisition[0], "opds:indirectAcquisition") + second_indirect_acquisition = parser._xpath( + indirect_acquisition[0], "opds:indirectAcquisition" + ) if second_indirect_acquisition: content_type = second_indirect_acquisition[0].attrib.get("type") else: content_type = drm_scheme drm_scheme = None - - copies_tags = parser._xpath(link_tag, 'opds:copies') + copies_tags = parser._xpath(link_tag, "opds:copies") if copies_tags: copies_tag = copies_tags[0] licenses_available = copies_tag.attrib.get("available") @@ -1686,7 +1790,7 @@ def _detail_for_elementtree_entry(cls, parser, entry_tag, feed_url=None, do_get= licenses_owned = copies_tag.attrib.get("total") if licenses_owned != None: licenses_owned = int(licenses_owned) - holds_tags = parser._xpath(link_tag, 'opds:holds') + holds_tags = parser._xpath(link_tag, "opds:holds") if holds_tags: holds_tag = holds_tags[0] patrons_in_hold_queue = holds_tag.attrib.get("total") @@ -1707,9 +1811,10 @@ def _detail_for_elementtree_entry(cls, parser, entry_tag, feed_url=None, do_get= formats=formats, ) - data['circulation'] = circulation + data["circulation"] = circulation return data + class SharedODLImportMonitor(OPDSImportMonitor): PROTOCOL = SharedODLImporter.NAME SERVICE_NAME = "Shared ODL Import Monitor" @@ -1718,6 +1823,7 @@ def opds_url(self, collection): base_url = collection.external_account_id return base_url + "/crawlable" + class MockSharedODLAPI(SharedODLAPI): """Mock API for tests that overrides _get and tracks requests.""" @@ -1726,14 +1832,14 @@ def mock_collection(cls, _db): """Create a mock ODL collection to use in tests.""" library = DatabaseTest.make_default_library(_db) collection, ignore = get_one_or_create( - _db, Collection, - name="Test Shared ODL Collection", create_method_kwargs=dict( + _db, + Collection, + name="Test Shared ODL Collection", + create_method_kwargs=dict( external_account_id="http://shared-odl", - ) - ) - integration = collection.create_external_integration( - protocol=SharedODLAPI.NAME + ), ) + integration = collection.create_external_integration(protocol=SharedODLAPI.NAME) library.collections.append(collection) return collection @@ -1741,21 +1847,19 @@ def __init__(self, _db, collection, *args, **kwargs): self.responses = [] self.requests = [] self.request_args = [] - super(MockSharedODLAPI, self).__init__( - _db, collection, *args, **kwargs - ) + super(MockSharedODLAPI, self).__init__(_db, collection, *args, **kwargs) def queue_response(self, status_code, headers={}, content=None): - self.responses.insert( - 0, MockRequestsResponse(status_code, headers, content) - ) + self.responses.insert(0, MockRequestsResponse(status_code, headers, content)) def _get(self, url, patron=None, headers=None, allowed_response_codes=None): allowed_response_codes = allowed_response_codes or ["2xx", "3xx"] self.requests.append(url) self.request_args.append((patron, headers, allowed_response_codes)) response = self.responses.pop() - return HTTP._process_response(url, response, allowed_response_codes=allowed_response_codes) + return HTTP._process_response( + url, response, allowed_response_codes=allowed_response_codes + ) class ODLExpiredItemsReaper(IdentifierSweepMonitor): @@ -1769,7 +1873,9 @@ def __init__(self, _db, collection): def process_item(self, identifier): for licensepool in identifier.licensed_through: - remaining_checkouts = 0 # total number of checkouts across all the licenses in the pool + remaining_checkouts = ( + 0 # total number of checkouts across all the licenses in the pool + ) concurrent_checkouts = 0 # number of concurrent checkouts allowed across all the licenses in the pool # 0 is a starting point, @@ -1779,8 +1885,10 @@ def process_item(self, identifier): remaining_checkouts += license_pool_license.remaining_checkouts concurrent_checkouts += license_pool_license.concurrent_checkouts - if remaining_checkouts != licensepool.licenses_owned or \ - concurrent_checkouts != licensepool.licenses_available: + if ( + remaining_checkouts != licensepool.licenses_owned + or concurrent_checkouts != licensepool.licenses_available + ): licenses_owned = max(remaining_checkouts, 0) licenses_available = max(concurrent_checkouts, 0) diff --git a/api/onix.py b/api/onix.py index 1d44de46dc..ed27a816ec 100644 --- a/api/onix.py +++ b/api/onix.py @@ -1,46 +1,49 @@ import logging +from enum import Enum import dateutil.parser -from enum import Enum from lxml import etree from core.classifier import Classifier from core.metadata_layer import ( - Metadata, - IdentifierData, - SubjectData, + CirculationData, ContributorData, + IdentifierData, LinkData, - CirculationData) + Metadata, + SubjectData, +) from core.model import ( Classification, - Identifier, Contributor, + EditionConstants, Hyperlink, + Identifier, + LicensePool, Representation, Subject, - LicensePool, EditionConstants) +) from core.util.datetime_helpers import to_utc from core.util.xmlparser import XMLParser class UsageStatus(Enum): - UNLIMITED = '01' - LIMITED = '02' - PROHIBITED = '03' + UNLIMITED = "01" + LIMITED = "02" + PROHIBITED = "03" class UsageUnit(Enum): - COPIES = '01' - CHARACTERS = '02' - WORDS = '03' - PAGES = '04' - PERCENTAGE = '05' - DEVICES = '06' - CONCURRENT_USERS = '07' - PERCENTAGE_PER_TIME_PERIOD = '08' - DAYS = '09' - TIMES = '10' + COPIES = "01" + CHARACTERS = "02" + WORDS = "03" + PAGES = "04" + PERCENTAGE = "05" + DEVICES = "06" + CONCURRENT_USERS = "07" + PERCENTAGE_PER_TIME_PERIOD = "08" + DAYS = "09" + TIMES = "10" class ONIXExtractor(object): @@ -49,139 +52,139 @@ class ONIXExtractor(object): # TODO: '20' indicates a semicolon-separated list of freeform tags, # which could also be useful. SUBJECT_TYPES = { - '01': Classifier.DDC, - '03': Classifier.LCC, - '04': Classifier.LCSH, - '10': Classifier.BISAC, - '12': Classifier.BIC, + "01": Classifier.DDC, + "03": Classifier.LCC, + "04": Classifier.LCSH, + "10": Classifier.BISAC, + "12": Classifier.BIC, } AUDIENCE_TYPES = { - '01': Classifier.AUDIENCE_ADULT, # General/trade for adult audience - '02': Classifier.AUDIENCE_CHILDREN, # (not for educational purpose) - '03': Classifier.AUDIENCE_YOUNG_ADULT, # (not for educational purpose) - '04': Classifier.AUDIENCE_CHILDREN, # Primary and secondary/elementary and high school - '05': Classifier.AUDIENCE_ADULT, # College/higher education - '06': Classifier.AUDIENCE_ADULT, # Professional and scholarly - '07': Classifier.AUDIENCE_ADULT, # ESL - '08': Classifier.AUDIENCE_ADULT, # Adult education - '09': Classifier.AUDIENCE_ADULT, # Second language teaching other than English + "01": Classifier.AUDIENCE_ADULT, # General/trade for adult audience + "02": Classifier.AUDIENCE_CHILDREN, # (not for educational purpose) + "03": Classifier.AUDIENCE_YOUNG_ADULT, # (not for educational purpose) + "04": Classifier.AUDIENCE_CHILDREN, # Primary and secondary/elementary and high school + "05": Classifier.AUDIENCE_ADULT, # College/higher education + "06": Classifier.AUDIENCE_ADULT, # Professional and scholarly + "07": Classifier.AUDIENCE_ADULT, # ESL + "08": Classifier.AUDIENCE_ADULT, # Adult education + "09": Classifier.AUDIENCE_ADULT, # Second language teaching other than English } CONTRIBUTOR_TYPES = { - 'A01': Contributor.AUTHOR_ROLE, - 'A02': Contributor.AUTHOR_ROLE, # 'With or as told to' - 'A03': Contributor.AUTHOR_ROLE, # Screenplay author - 'A04': Contributor.LYRICIST_ROLE, # Libretto author for an opera - 'A05': Contributor.LYRICIST_ROLE, - 'A06': Contributor.COMPOSER_ROLE, - 'A07': Contributor.ILLUSTRATOR_ROLE, # Visual artist who is the primary creator of the work - 'A08': Contributor.PHOTOGRAPHER_ROLE, - 'A09': Contributor.AUTHOR_ROLE, # 'Created by' - 'A10': Contributor.UNKNOWN_ROLE, # 'From an idea by' - 'A11': Contributor.DESIGNER_ROLE, - 'A12': Contributor.ILLUSTRATOR_ROLE, - 'A13': Contributor.PHOTOGRAPHER_ROLE, - 'A14': Contributor.AUTHOR_ROLE, # Author of the text for a work that is primarily photos or illustrations - 'A15': Contributor.INTRODUCTION_ROLE, # Preface author - 'A16': Contributor.UNKNOWN_ROLE, # Prologue author - 'A17': Contributor.UNKNOWN_ROLE, # Summary author - 'A18': Contributor.UNKNOWN_ROLE, # Supplement author - 'A19': Contributor.AFTERWORD_ROLE, # Afterword author - 'A20': Contributor.UNKNOWN_ROLE, # Author of notes or annotations - 'A21': Contributor.UNKNOWN_ROLE, # Author of commentary on main text - 'A22': Contributor.UNKNOWN_ROLE, # Epilogue author - 'A23': Contributor.FOREWORD_ROLE, - 'A24': Contributor.INTRODUCTION_ROLE, - 'A25': Contributor.UNKNOWN_ROLE, # Author/compiler of footnotes - 'A26': Contributor.UNKNOWN_ROLE, # Author of memoir accompanying main text - 'A27': Contributor.UNKNOWN_ROLE, # Person who carried out experiments reported in the text - 'A29': Contributor.INTRODUCTION_ROLE, # Author of introduction and notes - 'A30': Contributor.UNKNOWN_ROLE, # Writer of computer programs ancillary to the text - 'A31': Contributor.LYRICIST_ROLE, # 'Book and lyrics by' - 'A32': Contributor.CONTRIBUTOR_ROLE, # 'Contributions by' - 'A33': Contributor.UNKNOWN_ROLE, # Appendix author - 'A34': Contributor.UNKNOWN_ROLE, # Compiler of index - 'A35': Contributor.ARTIST_ROLE, # 'Drawings by' - 'A36': Contributor.ARTIST_ROLE, # Cover artist - 'A37': Contributor.UNKNOWN_ROLE, # Responsible for preliminary work on which the work is based - 'A38': Contributor.UNKNOWN_ROLE, # Author of the first edition who is not an author of the current edition - 'A39': Contributor.UNKNOWN_ROLE, # 'Maps by' - 'A40': Contributor.ARTIST_ROLE, # 'Inked or colored by' - 'A41': Contributor.UNKNOWN_ROLE, # 'Paper engineering by' - 'A42': Contributor.UNKNOWN_ROLE, # 'Continued by' - 'A43': Contributor.UNKNOWN_ROLE, # Interviewer - 'A44': Contributor.UNKNOWN_ROLE, # Interviewee - 'A45': Contributor.AUTHOR_ROLE, # Writer of dialogue, captions in a comic book - 'A46': Contributor.ARTIST_ROLE, # Inker - 'A47': Contributor.ARTIST_ROLE, # Colorist - 'A48': Contributor.ARTIST_ROLE, # Letterer - 'A51': Contributor.UNKNOWN_ROLE, # 'Research by' - 'A99': Contributor.UNKNOWN_ROLE, # 'Other primary creator' - 'B01': Contributor.EDITOR_ROLE, - 'B02': Contributor.EDITOR_ROLE, # 'Revised by' - 'B03': Contributor.UNKNOWN_ROLE, # 'Retold by' - 'B04': Contributor.UNKNOWN_ROLE, # 'Abridged by' - 'B05': Contributor.ADAPTER_ROLE, - 'B06': Contributor.TRANSLATOR_ROLE, - 'B07': Contributor.UNKNOWN_ROLE, # 'As told by' - 'B08': Contributor.TRANSLATOR_ROLE, # With commentary on the translation - 'B09': Contributor.EDITOR_ROLE, # Series editor - 'B10': Contributor.TRANSLATOR_ROLE, # 'Edited and translated by' - 'B11': Contributor.EDITOR_ROLE, # Editor-in-chief - 'B12': Contributor.EDITOR_ROLE, # Guest editor - 'B13': Contributor.EDITOR_ROLE, # Volume editor - 'B14': Contributor.EDITOR_ROLE, # Editorial board member - 'B15': Contributor.EDITOR_ROLE, # 'Editorial coordination by' - 'B16': Contributor.EDITOR_ROLE, # Managing editor - 'B17': Contributor.EDITOR_ROLE, # Founding editor of a serial publication - 'B18': Contributor.EDITOR_ROLE, # 'Prepared for publication by' - 'B19': Contributor.EDITOR_ROLE, # Associate editor - 'B20': Contributor.EDITOR_ROLE, # Consultant editor - 'B21': Contributor.EDITOR_ROLE, # General editor - 'B22': Contributor.UNKNOWN_ROLE, # 'Dramatized by' - 'B23': Contributor.EDITOR_ROLE, # 'General rapporteur' - 'B24': Contributor.EDITOR_ROLE, # Literary editor - 'B25': Contributor.COMPOSER_ROLE, # 'Arranged by (music)' - 'B26': Contributor.EDITOR_ROLE, # Technical editor - 'B27': Contributor.UNKNOWN_ROLE, # Thesis advisor - 'B28': Contributor.UNKNOWN_ROLE, # Thesis examiner - 'B29': Contributor.EDITOR_ROLE, # Scientific editor - 'B30': Contributor.UNKNOWN_ROLE, # Historical advisor - 'B31': Contributor.UNKNOWN_ROLE, # Editor of the first edition who is not an editor of the current edition - 'B99': Contributor.EDITOR_ROLE, # Other type of adaptation or editing - 'C01': Contributor.UNKNOWN_ROLE, # 'Compiled by' - 'C02': Contributor.UNKNOWN_ROLE, # 'Selected by' - 'C03': Contributor.UNKNOWN_ROLE, # 'Non-text material selected by' - 'C04': Contributor.UNKNOWN_ROLE, # 'Curated by' - 'C99': Contributor.UNKNOWN_ROLE, # Other type of compilation - 'D01': Contributor.PRODUCER_ROLE, - 'D02': Contributor.DIRECTOR_ROLE, - 'D03': Contributor.MUSICIAN_ROLE, # Conductor - 'D04': Contributor.UNKNOWN_ROLE, # Choreographer - 'D05': Contributor.DIRECTOR_ROLE, # Other type of direction - 'E01': Contributor.ACTOR_ROLE, - 'E02': Contributor.PERFORMER_ROLE, # Dancer - 'E03': Contributor.NARRATOR_ROLE, # 'Narrator' - 'E04': Contributor.UNKNOWN_ROLE, # Commentator - 'E05': Contributor.PERFORMER_ROLE, # Vocal soloist - 'E06': Contributor.PERFORMER_ROLE, # Instrumental soloist - 'E07': Contributor.NARRATOR_ROLE, # Reader of recorded text, as in an audiobook - 'E08': Contributor.PERFORMER_ROLE, # Name of a musical group in a performing role - 'E09': Contributor.PERFORMER_ROLE, # Speaker - 'E10': Contributor.UNKNOWN_ROLE, # Presenter - 'E99': Contributor.PERFORMER_ROLE, # Other type of performer - 'F01': Contributor.PHOTOGRAPHER_ROLE, # 'Filmed/photographed by' - 'F02': Contributor.EDITOR_ROLE, # 'Editor (film or video)' - 'F99': Contributor.UNKNOWN_ROLE, # Other type of recording - 'Z01': Contributor.UNKNOWN_ROLE, # 'Assisted by' - 'Z02': Contributor.UNKNOWN_ROLE, # 'Honored/dedicated to' - 'Z99': Contributor.UNKNOWN_ROLE, # Other creative responsibility + "A01": Contributor.AUTHOR_ROLE, + "A02": Contributor.AUTHOR_ROLE, # 'With or as told to' + "A03": Contributor.AUTHOR_ROLE, # Screenplay author + "A04": Contributor.LYRICIST_ROLE, # Libretto author for an opera + "A05": Contributor.LYRICIST_ROLE, + "A06": Contributor.COMPOSER_ROLE, + "A07": Contributor.ILLUSTRATOR_ROLE, # Visual artist who is the primary creator of the work + "A08": Contributor.PHOTOGRAPHER_ROLE, + "A09": Contributor.AUTHOR_ROLE, # 'Created by' + "A10": Contributor.UNKNOWN_ROLE, # 'From an idea by' + "A11": Contributor.DESIGNER_ROLE, + "A12": Contributor.ILLUSTRATOR_ROLE, + "A13": Contributor.PHOTOGRAPHER_ROLE, + "A14": Contributor.AUTHOR_ROLE, # Author of the text for a work that is primarily photos or illustrations + "A15": Contributor.INTRODUCTION_ROLE, # Preface author + "A16": Contributor.UNKNOWN_ROLE, # Prologue author + "A17": Contributor.UNKNOWN_ROLE, # Summary author + "A18": Contributor.UNKNOWN_ROLE, # Supplement author + "A19": Contributor.AFTERWORD_ROLE, # Afterword author + "A20": Contributor.UNKNOWN_ROLE, # Author of notes or annotations + "A21": Contributor.UNKNOWN_ROLE, # Author of commentary on main text + "A22": Contributor.UNKNOWN_ROLE, # Epilogue author + "A23": Contributor.FOREWORD_ROLE, + "A24": Contributor.INTRODUCTION_ROLE, + "A25": Contributor.UNKNOWN_ROLE, # Author/compiler of footnotes + "A26": Contributor.UNKNOWN_ROLE, # Author of memoir accompanying main text + "A27": Contributor.UNKNOWN_ROLE, # Person who carried out experiments reported in the text + "A29": Contributor.INTRODUCTION_ROLE, # Author of introduction and notes + "A30": Contributor.UNKNOWN_ROLE, # Writer of computer programs ancillary to the text + "A31": Contributor.LYRICIST_ROLE, # 'Book and lyrics by' + "A32": Contributor.CONTRIBUTOR_ROLE, # 'Contributions by' + "A33": Contributor.UNKNOWN_ROLE, # Appendix author + "A34": Contributor.UNKNOWN_ROLE, # Compiler of index + "A35": Contributor.ARTIST_ROLE, # 'Drawings by' + "A36": Contributor.ARTIST_ROLE, # Cover artist + "A37": Contributor.UNKNOWN_ROLE, # Responsible for preliminary work on which the work is based + "A38": Contributor.UNKNOWN_ROLE, # Author of the first edition who is not an author of the current edition + "A39": Contributor.UNKNOWN_ROLE, # 'Maps by' + "A40": Contributor.ARTIST_ROLE, # 'Inked or colored by' + "A41": Contributor.UNKNOWN_ROLE, # 'Paper engineering by' + "A42": Contributor.UNKNOWN_ROLE, # 'Continued by' + "A43": Contributor.UNKNOWN_ROLE, # Interviewer + "A44": Contributor.UNKNOWN_ROLE, # Interviewee + "A45": Contributor.AUTHOR_ROLE, # Writer of dialogue, captions in a comic book + "A46": Contributor.ARTIST_ROLE, # Inker + "A47": Contributor.ARTIST_ROLE, # Colorist + "A48": Contributor.ARTIST_ROLE, # Letterer + "A51": Contributor.UNKNOWN_ROLE, # 'Research by' + "A99": Contributor.UNKNOWN_ROLE, # 'Other primary creator' + "B01": Contributor.EDITOR_ROLE, + "B02": Contributor.EDITOR_ROLE, # 'Revised by' + "B03": Contributor.UNKNOWN_ROLE, # 'Retold by' + "B04": Contributor.UNKNOWN_ROLE, # 'Abridged by' + "B05": Contributor.ADAPTER_ROLE, + "B06": Contributor.TRANSLATOR_ROLE, + "B07": Contributor.UNKNOWN_ROLE, # 'As told by' + "B08": Contributor.TRANSLATOR_ROLE, # With commentary on the translation + "B09": Contributor.EDITOR_ROLE, # Series editor + "B10": Contributor.TRANSLATOR_ROLE, # 'Edited and translated by' + "B11": Contributor.EDITOR_ROLE, # Editor-in-chief + "B12": Contributor.EDITOR_ROLE, # Guest editor + "B13": Contributor.EDITOR_ROLE, # Volume editor + "B14": Contributor.EDITOR_ROLE, # Editorial board member + "B15": Contributor.EDITOR_ROLE, # 'Editorial coordination by' + "B16": Contributor.EDITOR_ROLE, # Managing editor + "B17": Contributor.EDITOR_ROLE, # Founding editor of a serial publication + "B18": Contributor.EDITOR_ROLE, # 'Prepared for publication by' + "B19": Contributor.EDITOR_ROLE, # Associate editor + "B20": Contributor.EDITOR_ROLE, # Consultant editor + "B21": Contributor.EDITOR_ROLE, # General editor + "B22": Contributor.UNKNOWN_ROLE, # 'Dramatized by' + "B23": Contributor.EDITOR_ROLE, # 'General rapporteur' + "B24": Contributor.EDITOR_ROLE, # Literary editor + "B25": Contributor.COMPOSER_ROLE, # 'Arranged by (music)' + "B26": Contributor.EDITOR_ROLE, # Technical editor + "B27": Contributor.UNKNOWN_ROLE, # Thesis advisor + "B28": Contributor.UNKNOWN_ROLE, # Thesis examiner + "B29": Contributor.EDITOR_ROLE, # Scientific editor + "B30": Contributor.UNKNOWN_ROLE, # Historical advisor + "B31": Contributor.UNKNOWN_ROLE, # Editor of the first edition who is not an editor of the current edition + "B99": Contributor.EDITOR_ROLE, # Other type of adaptation or editing + "C01": Contributor.UNKNOWN_ROLE, # 'Compiled by' + "C02": Contributor.UNKNOWN_ROLE, # 'Selected by' + "C03": Contributor.UNKNOWN_ROLE, # 'Non-text material selected by' + "C04": Contributor.UNKNOWN_ROLE, # 'Curated by' + "C99": Contributor.UNKNOWN_ROLE, # Other type of compilation + "D01": Contributor.PRODUCER_ROLE, + "D02": Contributor.DIRECTOR_ROLE, + "D03": Contributor.MUSICIAN_ROLE, # Conductor + "D04": Contributor.UNKNOWN_ROLE, # Choreographer + "D05": Contributor.DIRECTOR_ROLE, # Other type of direction + "E01": Contributor.ACTOR_ROLE, + "E02": Contributor.PERFORMER_ROLE, # Dancer + "E03": Contributor.NARRATOR_ROLE, # 'Narrator' + "E04": Contributor.UNKNOWN_ROLE, # Commentator + "E05": Contributor.PERFORMER_ROLE, # Vocal soloist + "E06": Contributor.PERFORMER_ROLE, # Instrumental soloist + "E07": Contributor.NARRATOR_ROLE, # Reader of recorded text, as in an audiobook + "E08": Contributor.PERFORMER_ROLE, # Name of a musical group in a performing role + "E09": Contributor.PERFORMER_ROLE, # Speaker + "E10": Contributor.UNKNOWN_ROLE, # Presenter + "E99": Contributor.PERFORMER_ROLE, # Other type of performer + "F01": Contributor.PHOTOGRAPHER_ROLE, # 'Filmed/photographed by' + "F02": Contributor.EDITOR_ROLE, # 'Editor (film or video)' + "F99": Contributor.UNKNOWN_ROLE, # Other type of recording + "Z01": Contributor.UNKNOWN_ROLE, # 'Assisted by' + "Z02": Contributor.UNKNOWN_ROLE, # 'Honored/dedicated to' + "Z99": Contributor.UNKNOWN_ROLE, # Other creative responsibility } PRODUCT_CONTENT_TYPES = { - '10': EditionConstants.BOOK_MEDIUM, # Text (eye-readable) - '01': EditionConstants.AUDIO_MEDIUM # Audiobook + "10": EditionConstants.BOOK_MEDIUM, # Text (eye-readable) + "01": EditionConstants.AUDIO_MEDIUM, # Audiobook } _logger = logging.getLogger(__name__) @@ -198,67 +201,87 @@ def parse(cls, file, data_source_name, default_medium=None): tree = etree.parse(file) root = tree.getroot() - for record in root.findall('product'): - title = parser.text_of_optional_subtag(record, 'descriptivedetail/titledetail/titleelement/b203') + for record in root.findall("product"): + title = parser.text_of_optional_subtag( + record, "descriptivedetail/titledetail/titleelement/b203" + ) if not title: - title_prefix = parser.text_of_optional_subtag(record, 'descriptivedetail/titledetail/titleelement/b030') - title_without_prefix = parser.text_of_optional_subtag(record, 'descriptivedetail/titledetail/titleelement/b031') + title_prefix = parser.text_of_optional_subtag( + record, "descriptivedetail/titledetail/titleelement/b030" + ) + title_without_prefix = parser.text_of_optional_subtag( + record, "descriptivedetail/titledetail/titleelement/b031" + ) if title_prefix and title_without_prefix: title = title_prefix + " " + title_without_prefix - medium = parser.text_of_optional_subtag(record, 'b385') + medium = parser.text_of_optional_subtag(record, "b385") if not medium and default_medium: medium = default_medium else: - medium = cls.PRODUCT_CONTENT_TYPES.get(medium, EditionConstants.BOOK_MEDIUM) + medium = cls.PRODUCT_CONTENT_TYPES.get( + medium, EditionConstants.BOOK_MEDIUM + ) - subtitle = parser.text_of_optional_subtag(record, 'descriptivedetail/titledetail/titleelement/b029') - language = parser.text_of_optional_subtag(record, 'descriptivedetail/language/b252') or "eng" - publisher = parser.text_of_optional_subtag(record, 'publishingdetail/publisher/b081') - imprint = parser.text_of_optional_subtag(record, 'publishingdetail/imprint/b079') + subtitle = parser.text_of_optional_subtag( + record, "descriptivedetail/titledetail/titleelement/b029" + ) + language = ( + parser.text_of_optional_subtag( + record, "descriptivedetail/language/b252" + ) + or "eng" + ) + publisher = parser.text_of_optional_subtag( + record, "publishingdetail/publisher/b081" + ) + imprint = parser.text_of_optional_subtag( + record, "publishingdetail/imprint/b079" + ) if imprint == publisher: imprint = None - publishing_date = parser.text_of_optional_subtag(record, 'publishingdetail/publishingdate/b306') + publishing_date = parser.text_of_optional_subtag( + record, "publishingdetail/publishingdate/b306" + ) issued = None if publishing_date: issued = dateutil.parser.isoparse(publishing_date) if issued.tzinfo is None: cls._logger.warning( - "Publishing date {} does not contain timezone information. Assuming UTC." - .format(publishing_date) + "Publishing date {} does not contain timezone information. Assuming UTC.".format( + publishing_date + ) ) issued = to_utc(issued) - identifier_tags = parser._xpath(record, 'productidentifier') + identifier_tags = parser._xpath(record, "productidentifier") identifiers = [] primary_identifier = None for tag in identifier_tags: type = parser.text_of_subtag(tag, "b221") - if type == '02' or type == '15': - primary_identifier = IdentifierData(Identifier.ISBN, parser.text_of_subtag(tag, 'b244')) + if type == "02" or type == "15": + primary_identifier = IdentifierData( + Identifier.ISBN, parser.text_of_subtag(tag, "b244") + ) identifiers.append(primary_identifier) - subject_tags = parser._xpath(record, 'descriptivedetail/subject') + subject_tags = parser._xpath(record, "descriptivedetail/subject") subjects = [] weight = Classification.TRUSTED_DISTRIBUTOR_WEIGHT for tag in subject_tags: - type = parser.text_of_subtag(tag, 'b067') + type = parser.text_of_subtag(tag, "b067") if type in cls.SUBJECT_TYPES: - b069 = parser.text_of_optional_subtag(tag, 'b069') + b069 = parser.text_of_optional_subtag(tag, "b069") if b069: subjects.append( - SubjectData( - cls.SUBJECT_TYPES[type], - b069, - weight=weight - ) + SubjectData(cls.SUBJECT_TYPES[type], b069, weight=weight) ) - audience_tags = parser._xpath(record, 'descriptivedetail/audience/b204') + audience_tags = parser._xpath(record, "descriptivedetail/audience/b204") audiences = [] for tag in audience_tags: if tag.text in cls.AUDIENCE_TYPES: @@ -266,109 +289,132 @@ def parse(cls, file, data_source_name, default_medium=None): SubjectData( Subject.FREEFORM_AUDIENCE, cls.AUDIENCE_TYPES[tag.text], - weight=weight + weight=weight, ) ) # TODO: We don't handle ONIX unnamed and alternatively named contributors. - contributor_tags = parser._xpath(record, 'descriptivedetail/contributor') + contributor_tags = parser._xpath(record, "descriptivedetail/contributor") contributors = [] for tag in contributor_tags: - type = parser.text_of_subtag(tag, 'b035') + type = parser.text_of_subtag(tag, "b035") if type in cls.CONTRIBUTOR_TYPES: - person_name_display = parser.text_of_optional_subtag(tag, 'b036') - person_name_inverted = parser.text_of_optional_subtag(tag, 'b037') - corp_name_display = parser.text_of_optional_subtag(tag, 'b047') - corp_name_inverted = parser.text_of_optional_subtag(tag, 'x443') - bio = parser.text_of_optional_subtag(tag, 'b044') + person_name_display = parser.text_of_optional_subtag(tag, "b036") + person_name_inverted = parser.text_of_optional_subtag(tag, "b037") + corp_name_display = parser.text_of_optional_subtag(tag, "b047") + corp_name_inverted = parser.text_of_optional_subtag(tag, "x443") + bio = parser.text_of_optional_subtag(tag, "b044") family_name = None if person_name_display or person_name_inverted: display_name = person_name_display sort_name = person_name_inverted - family_name = parser.text_of_optional_subtag(tag, 'b040') + family_name = parser.text_of_optional_subtag(tag, "b040") elif corp_name_display or corp_name_inverted: display_name = corp_name_display # Sort form for corporate name might just be the display name sort_name = corp_name_inverted or corp_name_display else: sort_name = display_name = None - contributors.append(ContributorData(sort_name=sort_name, - display_name=display_name, - family_name=family_name, - roles=[cls.CONTRIBUTOR_TYPES[type]], - biography=bio)) + contributors.append( + ContributorData( + sort_name=sort_name, + display_name=display_name, + family_name=family_name, + roles=[cls.CONTRIBUTOR_TYPES[type]], + biography=bio, + ) + ) - collateral_tags = parser._xpath(record, 'collateraldetail/textcontent') + collateral_tags = parser._xpath(record, "collateraldetail/textcontent") links = [] for tag in collateral_tags: - type = parser.text_of_subtag(tag, 'x426') + type = parser.text_of_subtag(tag, "x426") # TODO: '03' is the summary in the example I'm testing, but that # might not be generally true. - if type == '03': - text = parser.text_of_subtag(tag, 'd104') - links.append(LinkData(rel=Hyperlink.DESCRIPTION, - media_type=Representation.TEXT_HTML_MEDIA_TYPE, - content=text)) + if type == "03": + text = parser.text_of_subtag(tag, "d104") + links.append( + LinkData( + rel=Hyperlink.DESCRIPTION, + media_type=Representation.TEXT_HTML_MEDIA_TYPE, + content=text, + ) + ) - usage_constraint_tags = parser._xpath(record, 'descriptivedetail/epubusageconstraint') + usage_constraint_tags = parser._xpath( + record, "descriptivedetail/epubusageconstraint" + ) licenses_owned = LicensePool.UNLIMITED_ACCESS if usage_constraint_tags: - cls._logger.debug('Found {0} EpubUsageConstraint tags'.format(len(usage_constraint_tags))) + cls._logger.debug( + "Found {0} EpubUsageConstraint tags".format( + len(usage_constraint_tags) + ) + ) for usage_constraint_tag in usage_constraint_tags: - usage_status = parser.text_of_subtag(usage_constraint_tag, 'x319') + usage_status = parser.text_of_subtag(usage_constraint_tag, "x319") - cls._logger.debug('EpubUsageStatus: {0}'.format(usage_status)) + cls._logger.debug("EpubUsageStatus: {0}".format(usage_status)) if usage_status == UsageStatus.PROHIBITED.value: - raise Exception('The content is prohibited') + raise Exception("The content is prohibited") elif usage_status == UsageStatus.LIMITED.value: - usage_limit_tags = parser._xpath(record, 'descriptivedetail/epubusageconstraint/epubusagelimit') + usage_limit_tags = parser._xpath( + record, "descriptivedetail/epubusageconstraint/epubusagelimit" + ) - cls._logger.debug('Found {0} EpubUsageLimit tags'.format(len(usage_limit_tags))) + cls._logger.debug( + "Found {0} EpubUsageLimit tags".format(len(usage_limit_tags)) + ) if not usage_limit_tags: continue [usage_limit_tag] = usage_limit_tags - usage_unit = parser.text_of_subtag(usage_limit_tag, 'x321') + usage_unit = parser.text_of_subtag(usage_limit_tag, "x321") - cls._logger.debug('EpubUsageUnit: {0}'.format(usage_unit)) + cls._logger.debug("EpubUsageUnit: {0}".format(usage_unit)) - if usage_unit == UsageUnit.COPIES.value or usage_status == UsageUnit.CONCURRENT_USERS.value: - quantity_limit = parser.text_of_subtag(usage_limit_tag, 'x320') + if ( + usage_unit == UsageUnit.COPIES.value + or usage_status == UsageUnit.CONCURRENT_USERS.value + ): + quantity_limit = parser.text_of_subtag(usage_limit_tag, "x320") - cls._logger.debug('Quantity: {0}'.format(quantity_limit)) + cls._logger.debug("Quantity: {0}".format(quantity_limit)) if licenses_owned == LicensePool.UNLIMITED_ACCESS: licenses_owned = 0 licenses_owned += int(quantity_limit) - metadata_records.append(Metadata( - data_source=data_source_name, - title=title, - subtitle=subtitle, - language=language, - medium=medium, - publisher=publisher, - imprint=imprint, - issued=issued, - primary_identifier=primary_identifier, - identifiers=identifiers, - subjects=subjects, - contributors=contributors, - links=links, - circulation=CirculationData( - data_source_name, - primary_identifier, - licenses_owned=licenses_owned, - licenses_available=licenses_owned, - licenses_reserved=0, - patrons_in_hold_queue=0 + metadata_records.append( + Metadata( + data_source=data_source_name, + title=title, + subtitle=subtitle, + language=language, + medium=medium, + publisher=publisher, + imprint=imprint, + issued=issued, + primary_identifier=primary_identifier, + identifiers=identifiers, + subjects=subjects, + contributors=contributors, + links=links, + circulation=CirculationData( + data_source_name, + primary_identifier, + licenses_owned=licenses_owned, + licenses_available=licenses_owned, + licenses_reserved=0, + patrons_in_hold_queue=0, + ), ) - )) + ) return metadata_records diff --git a/api/opds.py b/api/opds.py index b51eda3840..8c8416b4c7 100644 --- a/api/opds.py +++ b/api/opds.py @@ -1,27 +1,27 @@ -import urllib.request, urllib.parse, urllib.error import copy import logging +import urllib.error +import urllib.parse +import urllib.request +import uuid +from collections import defaultdict + from flask import url_for from lxml import etree -from collections import defaultdict -import uuid from sqlalchemy.orm import lazyload +from api.lanes import ( + CrawlableCollectionBasedLane, + CrawlableCustomListBasedLane, + DynamicLane, +) +from core.analytics import Analytics +from core.app_server import cdn_url_for from core.cdn import cdnify from core.classifier import Classifier -from core.entrypoint import ( - EverythingEntryPoint, -) +from core.entrypoint import EverythingEntryPoint from core.external_search import WorkSearchResult -from core.opds import ( - Annotator, - AcquisitionFeed, - UnfulfillableWork, -) -from core.util.flask_util import OPDSFeedResponse -from core.util.opds_writer import ( - OPDSFeed, -) +from core.lane import Lane, WorkList from core.model import ( CirculationEvent, ConfigurationSetting, @@ -29,6 +29,7 @@ CustomList, DataSource, DeliveryMechanism, + Edition, Hold, Identifier, LicensePool, @@ -37,36 +38,29 @@ Patron, Session, Work, - Edition, -) -from core.lane import ( - Lane, - WorkList, ) +from core.opds import AcquisitionFeed, Annotator, UnfulfillableWork from core.util.datetime_helpers import from_timestamp -from api.lanes import ( - DynamicLane, - CrawlableCustomListBasedLane, - CrawlableCollectionBasedLane, -) -from core.app_server import cdn_url_for +from core.util.flask_util import OPDSFeedResponse +from core.util.opds_writer import OPDSFeed from .adobe_vendor_id import AuthdataUtility from .annotations import AnnotationWriter from .circulation import BaseCirculationAPI -from .config import ( - CannotLoadConfiguration, - Configuration, -) +from .config import CannotLoadConfiguration, Configuration from .novelist import NoveListAPI -from core.analytics import Analytics -class CirculationManagerAnnotator(Annotator): - def __init__(self, lane, - active_loans_by_work={}, active_holds_by_work={}, - active_fulfillments_by_work={}, hidden_content_types=[], - test_mode=False): +class CirculationManagerAnnotator(Annotator): + def __init__( + self, + lane, + active_loans_by_work={}, + active_holds_by_work={}, + active_fulfillments_by_work={}, + hidden_content_types=[], + test_mode=False, + ): if lane: logger_name = "Circulation Manager Annotator for %s" % lane.display_name else: @@ -91,7 +85,12 @@ def is_work_entry_solo(self, work): :rtype: bool """ return any( - work in x for x in (self.active_loans_by_work, self.active_holds_by_work, self.active_fulfillments_by_work) + work in x + for x in ( + self.active_loans_by_work, + self.active_holds_by_work, + self.active_fulfillments_by_work, + ) ) def _lane_identifier(self, lane): @@ -112,7 +111,7 @@ def url_for(self, *args, **kwargs): if self.test_mode: new_kwargs = {} for k, v in list(kwargs.items()): - if not k.startswith('_'): + if not k.startswith("_"): new_kwargs[k] = v return self.test_url_for(False, *args, **new_kwargs) else: @@ -128,26 +127,32 @@ def test_url_for(self, cdn=False, *args, **kwargs): # Generate a plausible-looking URL that doesn't depend on Flask # being set up. if cdn: - host = 'cdn' + host = "cdn" else: - host = 'host' + host = "host" url = ("http://%s/" % host) + "/".join(args) - connector = '?' + connector = "?" for k, v in sorted(kwargs.items()): if v is None: - v = '' + v = "" v = urllib.parse.quote(str(v)) k = urllib.parse.quote(str(k)) url += connector + "%s=%s" % (k, v) - connector = '&' + connector = "&" return url def facet_url(self, facets): return self.feed_url(self.lane, facets=facets, default_route=self.facet_view) - def feed_url(self, lane, facets=None, pagination=None, default_route='feed', extra_kwargs=None): - if (isinstance(lane, WorkList) and - hasattr(lane, 'url_arguments')): + def feed_url( + self, + lane, + facets=None, + pagination=None, + default_route="feed", + extra_kwargs=None, + ): + if isinstance(lane, WorkList) and hasattr(lane, "url_arguments"): route, kwargs = lane.url_arguments else: route = default_route @@ -163,12 +168,16 @@ def feed_url(self, lane, facets=None, pagination=None, default_route='feed', ext def navigation_url(self, lane): return self.cdn_url_for( - "navigation_feed", lane_identifier=self._lane_identifier(lane), - library_short_name=lane.library.short_name, _external=True) + "navigation_feed", + lane_identifier=self._lane_identifier(lane), + library_short_name=lane.library.short_name, + _external=True, + ) def active_licensepool_for(self, work): - loan = (self.active_loans_by_work.get(work) or - self.active_holds_by_work.get(work)) + loan = self.active_loans_by_work.get(work) or self.active_holds_by_work.get( + work + ) if loan: # The active license pool is the one associated with # the loan/hold. @@ -176,8 +185,7 @@ def active_licensepool_for(self, work): else: # There is no active loan. Use the default logic for # determining the active license pool. - return super( - CirculationManagerAnnotator, self).active_licensepool_for(work) + return super(CirculationManagerAnnotator, self).active_licensepool_for(work) def visible_delivery_mechanisms(self, licensepool): """Filter the given `licensepool`'s LicensePoolDeliveryMechanisms @@ -193,14 +201,16 @@ def visible_delivery_mechanisms(self, licensepool): continue yield lpdm - def annotate_work_entry(self, work, active_license_pool, edition, identifier, feed, entry, updated=None): + def annotate_work_entry( + self, work, active_license_pool, edition, identifier, feed, entry, updated=None + ): # If ElasticSearch included a more accurate last_update_time, # use it instead of Work.last_update_time updated = work.last_update_time if isinstance(work, WorkSearchResult): # Elasticsearch puts this field in a list, but we've set it up # so there will be at most one value. - last_updates = getattr(work._hit, 'last_update', []) + last_updates = getattr(work._hit, "last_update", []) if last_updates: # last_update is seconds-since epoch; convert to UTC datetime. updated = from_timestamp(last_updates[0]) @@ -222,17 +232,28 @@ def annotate_work_entry(self, work, active_license_pool, edition, identifier, fe # Now we need to generate a tag for every delivery mechanism # that has well-defined media types. link_tags = self.acquisition_links( - active_license_pool, active_loan, active_hold, active_fulfillment, - feed, identifier + active_license_pool, + active_loan, + active_hold, + active_fulfillment, + feed, + identifier, ) for tag in link_tags: entry.append(tag) def acquisition_links( - self, active_license_pool, active_loan, active_hold, - active_fulfillment, feed, identifier, can_hold=True, - can_revoke_hold=True, set_mechanism_at_borrow=False, - direct_fulfillment_delivery_mechanisms=[] + self, + active_license_pool, + active_loan, + active_hold, + active_fulfillment, + feed, + identifier, + can_hold=True, + can_revoke_hold=True, + set_mechanism_at_borrow=False, + direct_fulfillment_delivery_mechanisms=[], ): """Generate a number of tags that enumerate all acquisition methods. @@ -270,7 +291,9 @@ def acquisition_links( # add a link to revoke it. revoke_links = [] if can_revoke: - revoke_links.append(self.revoke_link(active_license_pool, active_loan, active_hold)) + revoke_links.append( + self.revoke_link(active_license_pool, active_loan, active_hold) + ) # Add next-step information for every useful delivery # mechanism. @@ -278,9 +301,7 @@ def acquisition_links( if can_borrow: # Borrowing a book gives you an OPDS entry that gives you # fulfillment links for every visible delivery mechanism. - visible_mechanisms = self.visible_delivery_mechanisms( - active_license_pool - ) + visible_mechanisms = self.visible_delivery_mechanisms(active_license_pool) if set_mechanism_at_borrow and active_license_pool: # The ebook distributor requires that the delivery # mechanism be set at the point of checkout. This means @@ -288,9 +309,7 @@ def acquisition_links( for mechanism in visible_mechanisms: borrow_links.append( self.borrow_link( - active_license_pool, - mechanism, [mechanism], - active_hold + active_license_pool, mechanism, [mechanism], active_hold ) ) elif active_license_pool: @@ -302,9 +321,7 @@ def acquisition_links( # will be set at the point of fulfillment. borrow_links.append( self.borrow_link( - active_license_pool, - None, visible_mechanisms, - active_hold + active_license_pool, None, visible_mechanisms, active_hold ) ) @@ -326,7 +343,8 @@ def acquisition_links( url = active_fulfillment.content_link rel = OPDSFeed.ACQUISITION_REL link_tag = AcquisitionFeed.acquisition_link( - rel=rel, href=url, types=[type], active_loan=active_loan) + rel=rel, href=url, types=[type], active_loan=active_loan + ) fulfill_links.append(link_tag) elif active_loan and active_loan.fulfillment: @@ -340,26 +358,25 @@ def acquisition_links( # they already chose it and they're stuck with it. for lpdm in active_license_pool.delivery_mechanisms: - if lpdm is active_loan.fulfillment or lpdm.delivery_mechanism.is_streaming: + if ( + lpdm is active_loan.fulfillment + or lpdm.delivery_mechanism.is_streaming + ): fulfill_links.append( self.fulfill_link( active_license_pool, active_loan, - lpdm.delivery_mechanism + lpdm.delivery_mechanism, ) ) else: # The delivery mechanism for this loan has not been # set. There is one fulfill link for every visible # delivery mechanism. - for lpdm in self.visible_delivery_mechanisms( - active_license_pool - ): + for lpdm in self.visible_delivery_mechanisms(active_license_pool): fulfill_links.append( self.fulfill_link( - active_license_pool, - active_loan, - lpdm.delivery_mechanism + active_license_pool, active_loan, lpdm.delivery_mechanism ) ) @@ -377,7 +394,7 @@ def acquisition_links( active_license_pool, active_loan, lpdm.delivery_mechanism, - rel=OPDSFeed.OPEN_ACCESS_REL + rel=OPDSFeed.OPEN_ACCESS_REL, ) direct_fulfill.attrib.update(self.rights_attributes(lpdm)) open_access_links.append(direct_fulfill) @@ -387,26 +404,41 @@ def acquisition_links( if active_license_pool and active_license_pool.open_access: for lpdm in active_license_pool.delivery_mechanisms: if lpdm.resource: - open_access_links.append(self.open_access_link(active_license_pool, lpdm)) + open_access_links.append( + self.open_access_link(active_license_pool, lpdm) + ) - return [x for x in borrow_links + fulfill_links + open_access_links + revoke_links - if x is not None] + return [ + x + for x in borrow_links + fulfill_links + open_access_links + revoke_links + if x is not None + ] def revoke_link(self, active_license_pool, active_loan, active_hold): return None - def borrow_link(self, active_license_pool, - borrow_mechanism, fulfillment_mechanisms, active_hold=None): + def borrow_link( + self, + active_license_pool, + borrow_mechanism, + fulfillment_mechanisms, + active_hold=None, + ): return None - def fulfill_link(self, license_pool, active_loan, delivery_mechanism, - rel=OPDSFeed.ACQUISITION_REL): + def fulfill_link( + self, + license_pool, + active_loan, + delivery_mechanism, + rel=OPDSFeed.ACQUISITION_REL, + ): return None def open_access_link(self, pool, lpdm): _db = Session.object_session(lpdm) url = cdnify(lpdm.resource.url) - kw = dict(rel=OPDSFeed.OPEN_ACCESS_REL, type='') + kw = dict(rel=OPDSFeed.OPEN_ACCESS_REL, type="") # Start off assuming that the URL associated with the # LicensePoolDeliveryMechanism's Resource is the URL we should @@ -417,9 +449,9 @@ def open_access_link(self, pool, lpdm): rep = lpdm.resource.representation if rep: if rep.media_type: - kw['type'] = rep.media_type + kw["type"] = rep.media_type href = rep.public_url - kw['href'] = cdnify(href) + kw["href"] = cdnify(href) link_tag = AcquisitionFeed.link(**kw) link_tag.attrib.update(self.rights_attributes(lpdm)) always_available = OPDSFeed.makeelement( @@ -437,7 +469,7 @@ def rights_attributes(self, lpdm): if not lpdm or not lpdm.rights_status or not lpdm.rights_status.uri: return {} rights_attr = "{%s}rights" % OPDSFeed.DCTERMS_NS - return {rights_attr : lpdm.rights_status.uri } + return {rights_attr: lpdm.rights_status.uri} @classmethod def _single_entry_response( @@ -459,8 +491,7 @@ def _single_entry_response( """ if not work: return feed_class( - _db, title="Unknown work", url=url, works=[], - annotator=annotator + _db, title="Unknown work", url=url, works=[], annotator=annotator ).as_error_response() # This method is generally used for reporting the results of @@ -471,8 +502,8 @@ def _single_entry_response( # specific to the authenticated client. The client should # cache this document for a while, but no one else should # cache it. - response_kwargs.setdefault('max_age', 30*60) - response_kwargs.setdefault('private', True) + response_kwargs.setdefault("max_age", 30 * 60) + response_kwargs.setdefault("private", True) return feed_class.single_entry(_db, work, annotator, **response_kwargs) @@ -499,14 +530,20 @@ class LibraryAnnotator(CirculationManagerAnnotator): Configuration.HELP_URI, ] - def __init__(self, circulation, lane, library, patron=None, - active_loans_by_work={}, active_holds_by_work={}, - active_fulfillments_by_work={}, - facet_view='feed', - test_mode=False, - top_level_title="All Books", - library_identifies_patrons = True, - facets=None + def __init__( + self, + circulation, + lane, + library, + patron=None, + active_loans_by_work={}, + active_holds_by_work={}, + active_fulfillments_by_work={}, + facet_view="feed", + test_mode=False, + top_level_title="All Books", + library_identifies_patrons=True, + facets=None, ): """Constructor. @@ -523,11 +560,12 @@ def __init__(self, circulation, lane, library, patron=None, require a loan. """ super(LibraryAnnotator, self).__init__( - lane, active_loans_by_work=active_loans_by_work, + lane, + active_loans_by_work=active_loans_by_work, active_holds_by_work=active_holds_by_work, active_fulfillments_by_work=active_fulfillments_by_work, hidden_content_types=self._hidden_content_types(library), - test_mode=test_mode + test_mode=test_mode, ) self.circulation = circulation self.library = library @@ -568,11 +606,11 @@ def top_level_title(self): def permalink_for(self, work, license_pool, identifier): url = self.url_for( - 'permalink', + "permalink", identifier_type=identifier.type, identifier=identifier.identifier, library_short_name=self.library.short_name, - _external=True + _external=True, ) return url, OPDSFeed.ENTRY_TYPE @@ -594,11 +632,13 @@ def groups_url(self, lane, facets=None): def default_lane_url(self, facets=None): return self.groups_url(None, facets=facets) - def feed_url(self, lane, facets=None, pagination=None, default_route='feed'): + def feed_url(self, lane, facets=None, pagination=None, default_route="feed"): extra_kwargs = dict() if self.library: - extra_kwargs['library_short_name']=self.library.short_name - return super(LibraryAnnotator, self).feed_url(lane, facets, pagination, default_route, extra_kwargs) + extra_kwargs["library_short_name"] = self.library.short_name + return super(LibraryAnnotator, self).feed_url( + lane, facets, pagination, default_route, extra_kwargs + ) def search_url(self, lane, query, pagination, facets=None): lane_identifier = self._lane_identifier(lane) @@ -608,9 +648,12 @@ def search_url(self, lane, query, pagination, facets=None): if pagination: kwargs.update(dict(list(pagination.items()))) return self.url_for( - "lane_search", lane_identifier=lane_identifier, + "lane_search", + lane_identifier=lane_identifier, library_short_name=self.library.short_name, - _external=True, **kwargs) + _external=True, + **kwargs + ) def group_uri(self, work, license_pool, identifier): if not work in self.lanes_by_work: @@ -620,25 +663,29 @@ def group_uri(self, work, license_pool, identifier): if not lanes: # I don't think this should ever happen? lane_name = None - url = self.cdn_url_for('acquisition_groups', lane_identifier=None, - library_short_name=self.library.short_name, _external=True) + url = self.cdn_url_for( + "acquisition_groups", + lane_identifier=None, + library_short_name=self.library.short_name, + _external=True, + ) title = "All Books" return url, title lane = lanes[0] self.lanes_by_work[work] = lanes[1:] - lane_name = '' + lane_name = "" show_feed = False if isinstance(lane, dict): - show_feed = lane.get('link_to_list_feed', show_feed) - title = lane.get('label', lane_name) - lane = lane['lane'] + show_feed = lane.get("link_to_list_feed", show_feed) + title = lane.get("label", lane_name) + lane = lane["lane"] if isinstance(lane, str): return lane, lane_name - if hasattr(lane, 'display_name') and not title: + if hasattr(lane, "display_name") and not title: title = lane.display_name if show_feed: @@ -654,10 +701,7 @@ def lane_url(self, lane, facets=None): if lane and isinstance(lane, Lane) and lane.sublanes: url = self.groups_url(lane, facets=facets) - elif lane and ( - isinstance(lane, Lane) - or isinstance(lane, DynamicLane) - ): + elif lane and (isinstance(lane, Lane) or isinstance(lane, DynamicLane)): url = self.feed_url(lane, facets) else: # This lane isn't part of our lane hierarchy. It's probably @@ -666,18 +710,20 @@ def lane_url(self, lane, facets=None): url = self.default_lane_url(facets=facets) return url - def annotate_work_entry(self, work, active_license_pool, edition, identifier, feed, entry): + def annotate_work_entry( + self, work, active_license_pool, edition, identifier, feed, entry + ): # Add a link for reporting problems. feed.add_link_to_entry( entry, - rel='issues', + rel="issues", href=self.url_for( - 'report', + "report", identifier_type=identifier.type, identifier=identifier.identifier, library_short_name=self.library.short_name, - _external=True - ) + _external=True, + ), ) super(LibraryAnnotator, self).annotate_work_entry( @@ -696,32 +742,32 @@ def annotate_work_entry(self, work, active_license_pool, edition, identifier, fe # recommendations, too. feed.add_link_to_entry( entry, - rel='recommendations', + rel="recommendations", type=OPDSFeed.ACQUISITION_FEED_TYPE, - title='Recommended Works', + title="Recommended Works", href=self.url_for( - 'recommendations', + "recommendations", identifier_type=identifier.type, identifier=identifier.identifier, library_short_name=self.library.short_name, - _external=True - ) + _external=True, + ), ) # Add a link for related books if available. if self.related_books_available(work, self.library): feed.add_link_to_entry( entry, - rel='related', + rel="related", type=OPDSFeed.ACQUISITION_FEED_TYPE, - title='Recommended Works', + title="Recommended Works", href=self.url_for( - 'related_books', + "related_books", identifier_type=identifier.type, identifier=identifier.identifier, library_short_name=self.library.short_name, - _external=True - ) + _external=True, + ), ) # Add a link to get a patron's annotations for this book. @@ -731,12 +777,12 @@ def annotate_work_entry(self, work, active_license_pool, edition, identifier, fe rel="http://www.w3.org/ns/oa#annotationService", type=AnnotationWriter.CONTENT_TYPE, href=self.url_for( - 'annotations_for_work', + "annotations_for_work", identifier_type=identifier.type, identifier=identifier.identifier, library_short_name=self.library.short_name, - _external=True - ) + _external=True, + ), ) if Analytics.is_configured(self.library): @@ -744,24 +790,21 @@ def annotate_work_entry(self, work, active_license_pool, edition, identifier, fe entry, rel="http://librarysimplified.org/terms/rel/analytics/open-book", href=self.url_for( - 'track_analytics_event', + "track_analytics_event", identifier_type=identifier.type, identifier=identifier.identifier, event_type=CirculationEvent.OPEN_BOOK, library_short_name=self.library.short_name, - _external=True - ) + _external=True, + ), ) @classmethod def related_books_available(cls, work, library): - """:return: bool asserting whether related books might exist for a particular Work - """ + """:return: bool asserting whether related books might exist for a particular Work""" contributions = work.sort_author and work.sort_author != Edition.UNKNOWN_AUTHOR - return (contributions - or work.series - or NoveListAPI.is_configured(library)) + return contributions or work.series or NoveListAPI.is_configured(library) def language_and_audience_key_from_work(self, work): language_key = work.language @@ -772,8 +815,7 @@ def language_and_audience_key_from_work(self, work): elif work.audience == Classifier.AUDIENCE_YOUNG_ADULT: audiences = Classifier.AUDIENCES_JUVENILE elif work.audience == Classifier.AUDIENCE_ALL_AGES: - audiences = [Classifier.AUDIENCE_CHILDREN, - Classifier.AUDIENCE_ALL_AGES] + audiences = [Classifier.AUDIENCE_CHILDREN, Classifier.AUDIENCE_ALL_AGES] elif work.audience in Classifier.AUDIENCES_ADULT: audiences = list(Classifier.AUDIENCES_NO_RESEARCH) elif work.audience == Classifier.AUDIENCE_RESEARCH: @@ -781,10 +823,10 @@ def language_and_audience_key_from_work(self, work): else: audiences = [] - audience_key=None + audience_key = None if audiences: audience_strings = [urllib.parse.quote_plus(a) for a in sorted(audiences)] - audience_key = ','.join(audience_strings) + audience_key = ",".join(audience_strings) return language_key, audience_key @@ -792,12 +834,12 @@ def add_author_links(self, work, feed, entry): """Find all the tags and add a link to each one that points to the author's other works. """ - author_tag = '{%s}author' % OPDSFeed.ATOM_NS + author_tag = "{%s}author" % OPDSFeed.ATOM_NS author_entries = entry.findall(author_tag) languages, audiences = self.language_and_audience_key_from_work(work) for author_entry in author_entries: - name_tag = '{%s}name' % OPDSFeed.ATOM_NS + name_tag = "{%s}name" % OPDSFeed.ATOM_NS # A database ID would be better than a name, but the # tag was created as part of the work's cached @@ -818,21 +860,21 @@ def add_author_links(self, work, feed, entry): feed.add_link_to_entry( author_entry, - rel='contributor', + rel="contributor", type=OPDSFeed.ACQUISITION_FEED_TYPE, title=contributor_name, href=self.url_for( - 'contributor', + "contributor", contributor_name=contributor_name, languages=languages, audiences=audiences, library_short_name=self.library.short_name, - _external=True - ) + _external=True, + ), ) def add_series_link(self, work, feed, entry): - series_tag = OPDSFeed.schema_('Series') + series_tag = OPDSFeed.schema_("Series") series_entry = entry.find(series_tag) if series_entry is None: @@ -842,14 +884,15 @@ def add_series_link(self, work, feed, entry): work_title = work.title self.log.error( 'add_series_link() called on work %s ("%s"), which has no tag in its OPDS entry.', - work_id, work_title + work_id, + work_title, ) return series_name = work.series languages, audiences = self.language_and_audience_key_from_work(work) href = self.url_for( - 'series', + "series", series_name=series_name, languages=languages, audiences=audiences, @@ -858,10 +901,10 @@ def add_series_link(self, work, feed, entry): ) feed.add_link_to_entry( series_entry, - rel='series', + rel="series", type=OPDSFeed.ACQUISITION_FEED_TYPE, title=series_name, - href=href + href=href, ) def annotate_feed(self, feed, lane): @@ -888,17 +931,18 @@ def annotate_feed(self, feed, lane): search_facets = self.facets search_facet_kwargs.update(dict(list(search_facets.items()))) - lane_identifier = self._lane_identifier(lane) search_url = self.url_for( - 'lane_search', lane_identifier=lane_identifier, + "lane_search", + lane_identifier=lane_identifier, library_short_name=self.library.short_name, - _external=True, **search_facet_kwargs + _external=True, + **search_facet_kwargs ) search_link = dict( rel="search", type="application/opensearchdescription+xml", - href=search_url + href=search_url, ) feed.add_link_to_feed(feed.feed, **search_link) @@ -908,13 +952,23 @@ def annotate_feed(self, feed, lane): shelf_link = dict( rel="http://opds-spec.org/shelf", type=OPDSFeed.ACQUISITION_FEED_TYPE, - href=self.url_for('active_loans', library_short_name=self.library.short_name, _external=True)) + href=self.url_for( + "active_loans", + library_short_name=self.library.short_name, + _external=True, + ), + ) feed.add_link_to_feed(feed.feed, **shelf_link) annotations_link = dict( rel="http://www.w3.org/ns/oa#annotationService", type=AnnotationWriter.CONTENT_TYPE, - href=self.url_for('annotations', library_short_name=self.library.short_name, _external=True)) + href=self.url_for( + "annotations", + library_short_name=self.library.short_name, + _external=True, + ), + ) feed.add_link_to_feed(feed.feed, **annotations_link) if lane and lane.uses_customlists: @@ -929,9 +983,10 @@ def annotate_feed(self, feed, lane): if name: crawlable_url = self.url_for( - "crawlable_list_feed", list_name=name, + "crawlable_list_feed", + list_name=name, library_short_name=self.library.short_name, - _external=True + _external=True, ) crawlable_link = dict( rel="http://opds-spec.org/crawlable", @@ -960,24 +1015,38 @@ def _add_link(l): _add_link(d) navigation_urls = ConfigurationSetting.for_library( - Configuration.WEB_HEADER_LINKS, self.library).json_value + Configuration.WEB_HEADER_LINKS, self.library + ).json_value if navigation_urls: navigation_labels = ConfigurationSetting.for_library( - Configuration.WEB_HEADER_LABELS, self.library).json_value + Configuration.WEB_HEADER_LABELS, self.library + ).json_value for (url, label) in zip(navigation_urls, navigation_labels): - d = dict(href=url, title=label, type="text/html", rel="related", role="navigation") + d = dict( + href=url, + title=label, + type="text/html", + rel="related", + role="navigation", + ) _add_link(d) for type, value in Configuration.help_uris(self.library): d = dict(href=value, rel="help") if type: - d['type'] = type + d["type"] = type _add_link(d) def acquisition_links( - self, active_license_pool, active_loan, active_hold, - active_fulfillment, feed, identifier, - direct_fulfillment_delivery_mechanisms=None, mock_api=None + self, + active_license_pool, + active_loan, + active_hold, + active_fulfillment, + feed, + identifier, + direct_fulfillment_delivery_mechanisms=None, + mock_api=None, ): """Generate one or more tags that can be used to borrow, reserve, or fulfill a book, depending on the state of the book @@ -1005,19 +1074,19 @@ def acquisition_links( vendor who provided this LicensePool. If this is not provided, a live API for that vendor will be used. """ - direct_fulfillment_delivery_mechanisms = direct_fulfillment_delivery_mechanisms or [] + direct_fulfillment_delivery_mechanisms = ( + direct_fulfillment_delivery_mechanisms or [] + ) api = mock_api if not api and self.circulation and active_license_pool: api = self.circulation.api_for_license_pool(active_license_pool) if api: set_mechanism_at_borrow = ( - api.SET_DELIVERY_MECHANISM_AT == BaseCirculationAPI.BORROW_STEP) - if (active_license_pool and not self.identifies_patrons - and not active_loan): + api.SET_DELIVERY_MECHANISM_AT == BaseCirculationAPI.BORROW_STEP + ) + if active_license_pool and not self.identifies_patrons and not active_loan: for lpdm in active_license_pool.delivery_mechanisms: - if api.can_fulfill_without_loan( - None, active_license_pool, lpdm - ): + if api.can_fulfill_without_loan(None, active_license_pool, lpdm): # This title can be fulfilled without an # active loan, so we're going to add an acquisition # link that goes directly to the fulfillment step @@ -1029,27 +1098,46 @@ def acquisition_links( set_mechanism_at_borrow = False return super(LibraryAnnotator, self).acquisition_links( - active_license_pool, active_loan, active_hold, active_fulfillment, - feed, identifier, can_hold=self.library.allow_holds, - can_revoke_hold=(active_hold and (not self.circulation or self.circulation.can_revoke_hold(active_license_pool, active_hold))), + active_license_pool, + active_loan, + active_hold, + active_fulfillment, + feed, + identifier, + can_hold=self.library.allow_holds, + can_revoke_hold=( + active_hold + and ( + not self.circulation + or self.circulation.can_revoke_hold( + active_license_pool, active_hold + ) + ) + ), set_mechanism_at_borrow=set_mechanism_at_borrow, - direct_fulfillment_delivery_mechanisms=direct_fulfillment_delivery_mechanisms + direct_fulfillment_delivery_mechanisms=direct_fulfillment_delivery_mechanisms, ) def revoke_link(self, active_license_pool, active_loan, active_hold): if not self.identifies_patrons: return url = self.url_for( - 'revoke_loan_or_hold', + "revoke_loan_or_hold", license_pool_id=active_license_pool.id, library_short_name=self.library.short_name, - _external=True) + _external=True, + ) kw = dict(href=url, rel=OPDSFeed.REVOKE_LOAN_REL) revoke_link_tag = OPDSFeed.makeelement("link", **kw) return revoke_link_tag - def borrow_link(self, active_license_pool, - borrow_mechanism, fulfillment_mechanisms, active_hold=None): + def borrow_link( + self, + active_license_pool, + borrow_mechanism, + fulfillment_mechanisms, + active_hold=None, + ): if not self.identifies_patrons: return identifier = active_license_pool.identifier @@ -1067,7 +1155,8 @@ def borrow_link(self, active_license_pool, identifier=identifier.identifier, mechanism_id=mechanism_id, library_short_name=self.library.short_name, - _external=True) + _external=True, + ) rel = OPDSFeed.BORROW_REL borrow_link = AcquisitionFeed.link( rel=rel, href=borrow_url, type=OPDSFeed.ENTRY_TYPE @@ -1101,8 +1190,13 @@ def borrow_link(self, active_license_pool, borrow_link.extend(indirect_acquisitions) return borrow_link - def fulfill_link(self, license_pool, active_loan, delivery_mechanism, - rel=OPDSFeed.ACQUISITION_REL): + def fulfill_link( + self, + license_pool, + active_loan, + delivery_mechanism, + rel=OPDSFeed.ACQUISITION_REL, + ): """Create a new fulfillment link. This link may include tags from the OPDS Extensions for DRM. @@ -1110,7 +1204,9 @@ def fulfill_link(self, license_pool, active_loan, delivery_mechanism, if not self.identifies_patrons and rel != OPDSFeed.OPEN_ACCESS_REL: return if isinstance(delivery_mechanism, LicensePoolDeliveryMechanism): - logging.warn("LicensePoolDeliveryMechanism passed into fulfill_link instead of DeliveryMechanism!") + logging.warn( + "LicensePoolDeliveryMechanism passed into fulfill_link instead of DeliveryMechanism!" + ) delivery_mechanism = delivery_mechanism.delivery_mechanism format_types = AcquisitionFeed.format_types(delivery_mechanism) if not format_types: @@ -1121,13 +1217,11 @@ def fulfill_link(self, license_pool, active_loan, delivery_mechanism, license_pool_id=license_pool.id, mechanism_id=delivery_mechanism.id, library_short_name=self.library.short_name, - _external=True + _external=True, ) link_tag = AcquisitionFeed.acquisition_link( - rel=rel, href=fulfill_url, - types=format_types, - active_loan=active_loan + rel=rel, href=fulfill_url, types=format_types, active_loan=active_loan ) children = AcquisitionFeed.license_tags(license_pool, active_loan, None) @@ -1146,13 +1240,14 @@ def open_access_link(self, pool, lpdm): license_pool_id=pool.id, mechanism_id=lpdm.delivery_mechanism.id, library_short_name=self.library.short_name, - _external=True + _external=True, ) link_tag.attrib.update(dict(href=fulfill_url)) return link_tag - def drm_device_registration_tags(self, license_pool, active_loan, - delivery_mechanism): + def drm_device_registration_tags( + self, license_pool, active_loan, delivery_mechanism + ): """Construct OPDS Extensions for DRM tags that explain how to register a device with the DRM server that manages this loan. :param delivery_mechanism: A DeliveryMechanism @@ -1196,10 +1291,15 @@ def adobe_id_tags(self, patron_identifier): try: authdata = AuthdataUtility.from_config(self.library) except CannotLoadConfiguration as e: - logging.error("Cannot load Short Client Token configuration; outgoing OPDS entries will not have DRM autodiscovery support", exc_info=e) + logging.error( + "Cannot load Short Client Token configuration; outgoing OPDS entries will not have DRM autodiscovery support", + exc_info=e, + ) return [] if authdata: - vendor_id, token = authdata.short_client_token_for_patron(patron_identifier) + vendor_id, token = authdata.short_client_token_for_patron( + patron_identifier + ) drm_licensor = OPDSFeed.makeelement("{%s}licensor" % OPDSFeed.DRM_NS) vendor_attr = "{%s}vendor" % OPDSFeed.DRM_NS drm_licensor.attrib[vendor_attr] = vendor_id @@ -1211,9 +1311,13 @@ def adobe_id_tags(self, patron_identifier): # endpoint. See: # https://github.com/NYPL-Simplified/Simplified/wiki/DRM-Device-Management device_list_link = OPDSFeed.makeelement("link") - device_list_link.attrib['rel'] = 'http://librarysimplified.org/terms/drm/rel/devices' - device_list_link.attrib['href'] = self.url_for( - "adobe_drm_devices", library_short_name=self.library.short_name, _external=True + device_list_link.attrib[ + "rel" + ] = "http://librarysimplified.org/terms/drm/rel/devices" + device_list_link.attrib["href"] = self.url_for( + "adobe_drm_devices", + library_short_name=self.library.short_name, + _external=True, ) drm_licensor.append(device_list_link) cached = [drm_licensor] @@ -1228,11 +1332,17 @@ def add_patron(self, feed_obj): return patron_details = {} if self.patron.username: - patron_details["{%s}username" % OPDSFeed.SIMPLIFIED_NS] = self.patron.username + patron_details[ + "{%s}username" % OPDSFeed.SIMPLIFIED_NS + ] = self.patron.username if self.patron.authorization_identifier: - patron_details["{%s}authorizationIdentifier" % OPDSFeed.SIMPLIFIED_NS] = self.patron.authorization_identifier + patron_details[ + "{%s}authorizationIdentifier" % OPDSFeed.SIMPLIFIED_NS + ] = self.patron.authorization_identifier - patron_tag = OPDSFeed.makeelement("{%s}patron" % OPDSFeed.SIMPLIFIED_NS, patron_details) + patron_tag = OPDSFeed.makeelement( + "{%s}patron" % OPDSFeed.SIMPLIFIED_NS, patron_details + ) feed_obj.feed.append(patron_tag) def add_authentication_document_link(self, feed_obj): @@ -1245,64 +1355,88 @@ def add_authentication_document_link(self, feed_obj): # patron authentication at this library. feed_obj.add_link_to_feed( feed_obj.feed, - rel='http://opds-spec.org/auth/document', + rel="http://opds-spec.org/auth/document", href=self.url_for( - 'authentication_document', - library_short_name=self.library.short_name, _external=True - ) + "authentication_document", + library_short_name=self.library.short_name, + _external=True, + ), ) class SharedCollectionAnnotator(CirculationManagerAnnotator): - - def __init__(self, collection, lane, - active_loans_by_work={}, active_holds_by_work={}, - active_fulfillments_by_work={}, - test_mode=False, + def __init__( + self, + collection, + lane, + active_loans_by_work={}, + active_holds_by_work={}, + active_fulfillments_by_work={}, + test_mode=False, ): - super(SharedCollectionAnnotator, self).__init__(lane, active_loans_by_work=active_loans_by_work, - active_holds_by_work=active_holds_by_work, - active_fulfillments_by_work=active_fulfillments_by_work, - test_mode=test_mode) + super(SharedCollectionAnnotator, self).__init__( + lane, + active_loans_by_work=active_loans_by_work, + active_holds_by_work=active_holds_by_work, + active_fulfillments_by_work=active_fulfillments_by_work, + test_mode=test_mode, + ) self.collection = collection def top_level_title(self): return self.collection.name def default_lane_url(self): - return self.feed_url(None, default_route='crawlable_collection_feed') + return self.feed_url(None, default_route="crawlable_collection_feed") def lane_url(self, lane): - return self.feed_url(lane, default_route='crawlable_collection_feed') + return self.feed_url(lane, default_route="crawlable_collection_feed") - def feed_url(self, lane, facets=None, pagination=None, default_route='feed'): + def feed_url(self, lane, facets=None, pagination=None, default_route="feed"): extra_kwargs = dict(collection_name=self.collection.name) - return super(SharedCollectionAnnotator, self).feed_url(lane, facets, pagination, default_route, extra_kwargs) + return super(SharedCollectionAnnotator, self).feed_url( + lane, facets, pagination, default_route, extra_kwargs + ) - def acquisition_links(self, active_license_pool, active_loan, active_hold, active_fulfillment, - feed, identifier): + def acquisition_links( + self, + active_license_pool, + active_loan, + active_hold, + active_fulfillment, + feed, + identifier, + ): """Generate a number of tags that enumerate all acquisition methods.""" links = super(SharedCollectionAnnotator, self).acquisition_links( - active_license_pool, active_loan, active_hold, active_fulfillment, feed, identifier) + active_license_pool, + active_loan, + active_hold, + active_fulfillment, + feed, + identifier, + ) info_links = [] if active_loan: url = self.url_for( - 'shared_collection_loan_info', + "shared_collection_loan_info", collection_name=self.collection.name, loan_id=active_loan.id, - _external=True) - kw = dict(href=url, rel='self') + _external=True, + ) + kw = dict(href=url, rel="self") info_link_tag = OPDSFeed.makeelement("link", **kw) info_links.append(info_link_tag) if active_hold and active_hold: url = self.url_for( - 'shared_collection_hold_info', + "shared_collection_hold_info", collection_name=self.collection.name, hold_id=active_hold.id, - _external=True) - kw = dict(href=url, rel='self') + _external=True, + ) + kw = dict(href=url, rel="self") info_link_tag = OPDSFeed.makeelement("link", **kw) info_links.append(info_link_tag) return links + info_links @@ -1311,24 +1445,31 @@ def revoke_link(self, active_license_pool, active_loan, active_hold): url = None if active_loan: url = self.url_for( - 'shared_collection_revoke_loan', + "shared_collection_revoke_loan", collection_name=self.collection.name, loan_id=active_loan.id, - _external=True) + _external=True, + ) elif active_hold: url = self.url_for( - 'shared_collection_revoke_hold', + "shared_collection_revoke_hold", collection_name=self.collection.name, hold_id=active_hold.id, - _external=True) + _external=True, + ) if url: kw = dict(href=url, rel=OPDSFeed.REVOKE_LOAN_REL) revoke_link_tag = OPDSFeed.makeelement("link", **kw) return revoke_link_tag - def borrow_link(self, active_license_pool, - borrow_mechanism, fulfillment_mechanisms, active_hold=None): + def borrow_link( + self, + active_license_pool, + borrow_mechanism, + fulfillment_mechanisms, + active_hold=None, + ): if active_license_pool.open_access: # No need to borrow from a shared collection when the book # already has an open access link. @@ -1383,11 +1524,18 @@ def borrow_link(self, active_license_pool, borrow_link.extend(indirect_acquisitions) return borrow_link - def fulfill_link(self, license_pool, active_loan, delivery_mechanism, - rel=OPDSFeed.ACQUISITION_REL): + def fulfill_link( + self, + license_pool, + active_loan, + delivery_mechanism, + rel=OPDSFeed.ACQUISITION_REL, + ): """Create a new fulfillment link.""" if isinstance(delivery_mechanism, LicensePoolDeliveryMechanism): - logging.warn("LicensePoolDeliveryMechanism passed into fulfill_link instead of DeliveryMechanism!") + logging.warn( + "LicensePoolDeliveryMechanism passed into fulfill_link instead of DeliveryMechanism!" + ) delivery_mechanism = delivery_mechanism.delivery_mechanism format_types = AcquisitionFeed.format_types(delivery_mechanism) if not format_types: @@ -1398,24 +1546,20 @@ def fulfill_link(self, license_pool, active_loan, delivery_mechanism, collection_name=license_pool.collection.name, loan_id=active_loan.id, mechanism_id=delivery_mechanism.id, - _external=True + _external=True, ) link_tag = AcquisitionFeed.acquisition_link( - rel=rel, href=fulfill_url, - types=format_types, - active_loan=active_loan + rel=rel, href=fulfill_url, types=format_types, active_loan=active_loan ) children = AcquisitionFeed.license_tags(license_pool, active_loan, None) link_tag.extend(children) return link_tag -class LibraryLoanAndHoldAnnotator(LibraryAnnotator): +class LibraryLoanAndHoldAnnotator(LibraryAnnotator): @classmethod - def active_loans_for( - cls, circulation, patron, test_mode=False, **response_kwargs - ): + def active_loans_for(cls, circulation, patron, test_mode=False, **response_kwargs): db = Session.object_session(patron) active_loans_by_work = {} for loan in patron.loans: @@ -1429,10 +1573,17 @@ def active_loans_for( active_holds_by_work[work] = hold annotator = cls( - circulation, None, patron.library, patron, active_loans_by_work, active_holds_by_work, - test_mode=test_mode + circulation, + None, + patron.library, + patron, + active_loans_by_work, + active_holds_by_work, + test_mode=test_mode, + ) + url = annotator.url_for( + "active_loans", library_short_name=patron.library.short_name, _external=True ) - url = annotator.url_for('active_loans', library_short_name=patron.library.short_name, _external=True) works = patron.works_on_loan_or_on_hold() feed_obj = AcquisitionFeed(db, "Active loans and holds", url, works, annotator) @@ -1444,8 +1595,15 @@ def active_loans_for( return response @classmethod - def single_item_feed(cls, circulation, item, fulfillment=None, test_mode=False, - feed_class=AcquisitionFeed, **response_kwargs): + def single_item_feed( + cls, + circulation, + item, + fulfillment=None, + test_mode=False, + feed_class=AcquisitionFeed, + **response_kwargs + ): """Construct a response containing a single OPDS entry representing an active loan or hold. @@ -1495,19 +1653,21 @@ def single_item_feed(cls, circulation, item, fulfillment=None, test_mode=False, active_fulfillments_by_work[work] = fulfillment annotator = cls( - circulation, None, library, + circulation, + None, + library, active_loans_by_work=active_loans_by_work, active_holds_by_work=active_holds_by_work, active_fulfillments_by_work=active_fulfillments_by_work, - test_mode=test_mode + test_mode=test_mode, ) identifier = license_pool.identifier url = annotator.url_for( - 'loan_or_hold_detail', + "loan_or_hold_detail", identifier_type=identifier.type, identifier=identifier.identifier, library_short_name=library.short_name, - _external=True + _external=True, ) return annotator._single_entry_response( _db, work, annotator, url, feed_class, **response_kwargs @@ -1522,7 +1682,7 @@ def drm_device_registration_feed_tags(self, patron): logout, even if there is no active loan that requires one. """ tags = copy.deepcopy(self.adobe_id_tags(patron)) - attr = '{%s}scheme' % OPDSFeed.DRM_NS + attr = "{%s}scheme" % OPDSFeed.DRM_NS for tag in tags: tag.attrib[attr] = "http://librarysimplified.org/terms/drm/scheme/ACS" return tags @@ -1534,9 +1694,9 @@ def user_profile_management_protocol_link(self): for the current patron. """ link = OPDSFeed.makeelement("link") - link.attrib['rel'] = 'http://librarysimplified.org/terms/rel/user-profile' - link.attrib['href'] = self.url_for( - 'patron_profile', library_short_name=self.library.short_name, _external=True + link.attrib["rel"] = "http://librarysimplified.org/terms/rel/user-profile" + link.attrib["href"] = self.url_for( + "patron_profile", library_short_name=self.library.short_name, _external=True ) return link @@ -1544,20 +1704,25 @@ def annotate_feed(self, feed, lane): """Annotate the feed with top-level DRM device registration tags and a link to the User Profile Management Protocol endpoint. """ - super(LibraryLoanAndHoldAnnotator, self).annotate_feed( - feed, lane - ) + super(LibraryLoanAndHoldAnnotator, self).annotate_feed(feed, lane) if self.patron: tags = self.drm_device_registration_feed_tags(self.patron) tags.append(self.user_profile_management_protocol_link) for tag in tags: feed.feed.append(tag) -class SharedCollectionLoanAndHoldAnnotator(SharedCollectionAnnotator): +class SharedCollectionLoanAndHoldAnnotator(SharedCollectionAnnotator): @classmethod - def single_item_feed(cls, collection, item, fulfillment=None, test_mode=False, - feed_class=AcquisitionFeed, **response_kwargs): + def single_item_feed( + cls, + collection, + item, + fulfillment=None, + test_mode=False, + feed_class=AcquisitionFeed, + **response_kwargs + ): """Create an OPDS entry representing a single loan or hold. TODO: This and LibraryLoanAndHoldAnnotator.single_item_feed @@ -1579,25 +1744,23 @@ def single_item_feed(cls, collection, item, fulfillment=None, test_mode=False, active_fulfillments_by_work[work] = fulfillment if isinstance(item, Loan): d = active_loans_by_work - route = 'shared_collection_loan_info' + route = "shared_collection_loan_info" route_kwargs = dict(loan_id=item.id) elif isinstance(item, Hold): d = active_holds_by_work - route = 'shared_collection_hold_info' + route = "shared_collection_hold_info" route_kwargs = dict(hold_id=item.id) d[work] = item annotator = cls( - collection, None, + collection, + None, active_loans_by_work=active_loans_by_work, active_holds_by_work=active_holds_by_work, active_fulfillments_by_work=active_fulfillments_by_work, - test_mode=test_mode + test_mode=test_mode, ) url = annotator.url_for( - route, - collection_name=collection.name, - _external=True, - **route_kwargs + route, collection_name=collection.name, _external=True, **route_kwargs ) return annotator._single_entry_response( _db, work, annotator, url, feed_class, **response_kwargs diff --git a/api/opds_for_distributors.py b/api/opds_for_distributors.py index 3f6eec0b20..e52f763ac4 100644 --- a/api/opds_for_distributors.py +++ b/api/opds_for_distributors.py @@ -1,12 +1,10 @@ import datetime -import feedparser import json + +import feedparser from flask_babel import lazy_gettext as _ -from core.opds_import import ( - OPDSImporter, - OPDSImportMonitor, -) +from core.metadata_layer import FormatData, TimestampData from core.model import ( Collection, Credential, @@ -22,29 +20,23 @@ get_one, get_one_or_create, ) -from core.metadata_layer import ( - FormatData, - TimestampData, -) +from core.opds_import import OPDSImporter, OPDSImportMonitor from core.selftest import HasSelfTests -from .circulation import ( - BaseCirculationAPI, - LoanInfo, - FulfillmentInfo, -) +from core.testing import DatabaseTest, MockRequestsResponse from core.util.datetime_helpers import utc_now from core.util.http import HTTP from core.util.string_helpers import base64 -from core.testing import ( - DatabaseTest, - MockRequestsResponse, -) -from .config import IntegrationException + +from .circulation import BaseCirculationAPI, FulfillmentInfo, LoanInfo from .circulation_exceptions import * +from .config import IntegrationException + class OPDSForDistributorsAPI(BaseCirculationAPI, HasSelfTests): NAME = "OPDS for Distributors" - DESCRIPTION = _("Import books from a distributor that requires authentication to get the OPDS feed and download books.") + DESCRIPTION = _( + "Import books from a distributor that requires authentication to get the OPDS feed and download books." + ) SETTINGS = OPDSImporter.BASE_SETTINGS + [ { @@ -56,7 +48,7 @@ class OPDSForDistributorsAPI(BaseCirculationAPI, HasSelfTests): "key": ExternalIntegration.PASSWORD, "label": _("Library's password or secret key"), "required": True, - } + }, ] # In OPDS For Distributors, all items are gated through the @@ -66,36 +58,34 @@ class OPDSForDistributorsAPI(BaseCirculationAPI, HasSelfTests): # combined with the BEARER_TOKEN scheme, then we should import # titles with that media type... SUPPORTED_MEDIA_TYPES = [ - format for (format, drm) in - DeliveryMechanism.default_client_can_fulfill_lookup + format + for (format, drm) in DeliveryMechanism.default_client_can_fulfill_lookup if drm == (DeliveryMechanism.BEARER_TOKEN) and format is not None ] # ...and we should map requests for delivery of that media type to # the (type, BEARER_TOKEN) DeliveryMechanism. delivery_mechanism_to_internal_format = { - (type, DeliveryMechanism.BEARER_TOKEN): type - for type in SUPPORTED_MEDIA_TYPES + (type, DeliveryMechanism.BEARER_TOKEN): type for type in SUPPORTED_MEDIA_TYPES } def __init__(self, _db, collection): self.collection_id = collection.id self.external_integration_id = collection.external_integration.id - self.data_source_name = collection.external_integration.setting(Collection.DATA_SOURCE_NAME_SETTING).value + self.data_source_name = collection.external_integration.setting( + Collection.DATA_SOURCE_NAME_SETTING + ).value self.username = collection.external_integration.username self.password = collection.external_integration.password self.feed_url = collection.external_account_id self.auth_url = None def external_integration(self, _db): - return get_one(_db, ExternalIntegration, - id=self.external_integration_id) + return get_one(_db, ExternalIntegration, id=self.external_integration_id) def _run_self_tests(self, _db): """Try to get a token.""" - yield self.run_test( - "Negotiate a fulfillment token", self._get_token, _db - ) + yield self.run_test("Negotiate a fulfillment token", self._get_token, _db) def _request_with_timeout(self, method, url, *args, **kwargs): """Wrapper around HTTP.request_with_timeout to be overridden for tests.""" @@ -109,62 +99,88 @@ def _get_token(self, _db): # Keep track of the most recent URL we retrieved for error # reporting purposes. current_url = self.feed_url - response = self._request_with_timeout('GET', current_url) + response = self._request_with_timeout("GET", current_url) if response.status_code != 401: # This feed doesn't require authentication, so # we need to find a link to the authentication document. feed = feedparser.parse(response.content) - links = feed.get('feed', {}).get('links', []) - auth_doc_links = [l for l in links if l['rel'] == "http://opds-spec.org/auth/document"] + links = feed.get("feed", {}).get("links", []) + auth_doc_links = [ + l for l in links if l["rel"] == "http://opds-spec.org/auth/document" + ] if not auth_doc_links: - raise LibraryAuthorizationFailedException("No authentication document link found in %s" % current_url) + raise LibraryAuthorizationFailedException( + "No authentication document link found in %s" % current_url + ) current_url = auth_doc_links[0].get("href") - response = self._request_with_timeout('GET', current_url) + response = self._request_with_timeout("GET", current_url) try: auth_doc = json.loads(response.content) except Exception as e: - raise LibraryAuthorizationFailedException("Could not load authentication document from %s" % current_url) - auth_types = auth_doc.get('authentication', []) - credentials_types = [t for t in auth_types if t['type'] == "http://opds-spec.org/auth/oauth/client_credentials"] + raise LibraryAuthorizationFailedException( + "Could not load authentication document from %s" % current_url + ) + auth_types = auth_doc.get("authentication", []) + credentials_types = [ + t + for t in auth_types + if t["type"] == "http://opds-spec.org/auth/oauth/client_credentials" + ] if not credentials_types: - raise LibraryAuthorizationFailedException("Could not find any credential-based authentication mechanisms in %s" % current_url) + raise LibraryAuthorizationFailedException( + "Could not find any credential-based authentication mechanisms in %s" + % current_url + ) - links = credentials_types[0].get('links', []) + links = credentials_types[0].get("links", []) auth_links = [l for l in links if l.get("rel") == "authenticate"] if not auth_links: - raise LibraryAuthorizationFailedException("Could not find any authentication links in %s" % current_url) + raise LibraryAuthorizationFailedException( + "Could not find any authentication links in %s" % current_url + ) self.auth_url = auth_links[0].get("href") def refresh(credential): headers = dict() - auth_header = "Basic %s" % base64.b64encode("%s:%s" % (self.username, self.password)) - headers['Authorization'] = auth_header - headers['Content-Type'] = "application/x-www-form-urlencoded" - body = dict(grant_type='client_credentials') - token_response = self._request_with_timeout('POST', self.auth_url, data=body, headers=headers) + auth_header = "Basic %s" % base64.b64encode( + "%s:%s" % (self.username, self.password) + ) + headers["Authorization"] = auth_header + headers["Content-Type"] = "application/x-www-form-urlencoded" + body = dict(grant_type="client_credentials") + token_response = self._request_with_timeout( + "POST", self.auth_url, data=body, headers=headers + ) token = json.loads(token_response.content) access_token = token.get("access_token") expires_in = token.get("expires_in") if not access_token or not expires_in: raise LibraryAuthorizationFailedException( - "Document retrieved from %s is not a bearer token: %s" % ( + "Document retrieved from %s is not a bearer token: %s" + % ( # Response comes in as a byte string. - self.auth_url, token_response.content.decode("utf-8") + self.auth_url, + token_response.content.decode("utf-8"), ) ) credential.credential = access_token expires_in = expires_in # We'll avoid edge cases by assuming the token expires 75% # into its useful lifetime. - credential.expires = utc_now() + datetime.timedelta(seconds=expires_in*0.75) - return Credential.lookup(_db, self.data_source_name, - "OPDS For Distributors Bearer Token", - patron=None, - refresher_method=refresh, - ) + credential.expires = utc_now() + datetime.timedelta( + seconds=expires_in * 0.75 + ) + + return Credential.lookup( + _db, + self.data_source_name, + "OPDS For Distributors Bearer Token", + patron=None, + refresher_method=refresh, + ) def can_fulfill_without_loan(self, patron, licensepool, lpdm): """Since OPDS For Distributors delivers books to the library rather @@ -179,9 +195,7 @@ def can_fulfill_without_loan(self, patron, licensepool, lpdm): if not lpdm or not lpdm.delivery_mechanism: return False drm_scheme = lpdm.delivery_mechanism.drm_scheme - if drm_scheme in ( - DeliveryMechanism.NO_DRM, DeliveryMechanism.BEARER_TOKEN - ): + if drm_scheme in (DeliveryMechanism.NO_DRM, DeliveryMechanism.BEARER_TOKEN): return True return False @@ -190,7 +204,8 @@ def checkin(self, patron, pin, licensepool): _db = Session.object_session(patron) try: loan = get_one( - _db, Loan, + _db, + Loan, patron_id=patron.id, license_pool_id=licensepool.id, ) @@ -223,7 +238,10 @@ def fulfill(self, patron, pin, licensepool, internal_format, **kwargs): # Find the acquisition link with the right media type. for link in links: media_type = link.resource.representation.media_type - if link.rel == Hyperlink.GENERIC_OPDS_ACQUISITION and media_type == internal_format: + if ( + link.rel == Hyperlink.GENERIC_OPDS_ACQUISITION + and media_type == internal_format + ): url = link.resource.representation.url # Obtain a Credential with the information from our @@ -259,10 +277,11 @@ def fulfill(self, patron, pin, licensepool, internal_format, **kwargs): def patron_activity(self, patron, pin): # Look up loans for this collection in the database. _db = Session.object_session(patron) - loans = _db.query(Loan).join(Loan.license_pool).filter( - LicensePool.collection_id==self.collection_id - ).filter( - Loan.patron==patron + loans = ( + _db.query(Loan) + .join(Loan.license_pool) + .filter(LicensePool.collection_id == self.collection_id) + .filter(Loan.patron == patron) ) return [ LoanInfo( @@ -271,8 +290,11 @@ def patron_activity(self, patron, pin): loan.license_pool.identifier.type, loan.license_pool.identifier.identifier, loan.start, - loan.end - ) for loan in loans] + loan.end, + ) + for loan in loans + ] + class OPDSForDistributorsImporter(OPDSImporter): NAME = OPDSForDistributorsAPI.NAME @@ -283,22 +305,24 @@ def update_work_for_edition(self, *args, **kwargs): not open-access, but a library that can perform this import has a license for the title and can distribute unlimited copies. """ - pool, work = super( - OPDSForDistributorsImporter, self).update_work_for_edition( - *args, **kwargs + pool, work = super(OPDSForDistributorsImporter, self).update_work_for_edition( + *args, **kwargs ) pool.update_availability( - new_licenses_owned=1, new_licenses_available=1, - new_licenses_reserved=0, new_patrons_in_hold_queue=0 + new_licenses_owned=1, + new_licenses_available=1, + new_licenses_reserved=0, + new_patrons_in_hold_queue=0, ) return pool, work @classmethod def _add_format_data(cls, circulation): for link in circulation.links: - if (link.rel == Hyperlink.GENERIC_OPDS_ACQUISITION - and link.media_type in - OPDSForDistributorsAPI.SUPPORTED_MEDIA_TYPES): + if ( + link.rel == Hyperlink.GENERIC_OPDS_ACQUISITION + and link.media_type in OPDSForDistributorsAPI.SUPPORTED_MEDIA_TYPES + ): circulation.formats.append( FormatData( content_type=link.media_type, @@ -313,10 +337,13 @@ class OPDSForDistributorsImportMonitor(OPDSImportMonitor): """Monitor an OPDS feed that requires or allows authentication, such as Biblioboard or Plympton. """ + PROTOCOL = OPDSForDistributorsImporter.NAME def __init__(self, _db, collection, import_class, **kwargs): - super(OPDSForDistributorsImportMonitor, self).__init__(_db, collection, import_class, **kwargs) + super(OPDSForDistributorsImportMonitor, self).__init__( + _db, collection, import_class, **kwargs + ) self.api = OPDSForDistributorsAPI(_db, collection) @@ -328,10 +355,11 @@ def _get(self, url, headers): token = self.api._get_token(self._db).credential headers = dict(headers or {}) auth_header = "Bearer %s" % token - headers['Authorization'] = auth_header + headers["Authorization"] = auth_header return super(OPDSForDistributorsImportMonitor, self)._get(url, headers) + class OPDSForDistributorsReaperMonitor(OPDSForDistributorsImportMonitor): """This is an unusual import monitor that crawls the entire OPDS feed and keeps track of every identifier it sees, to find out if anything @@ -339,7 +367,9 @@ class OPDSForDistributorsReaperMonitor(OPDSForDistributorsImportMonitor): """ def __init__(self, _db, collection, import_class, **kwargs): - super(OPDSForDistributorsReaperMonitor, self).__init__(_db, collection, import_class, **kwargs) + super(OPDSForDistributorsReaperMonitor, self).__init__( + _db, collection, import_class, **kwargs + ) self.seen_identifiers = set() def feed_contains_new_data(self, feed): @@ -364,27 +394,22 @@ def run_once(self, progress): # self.seen_identifiers is full of URNs. We need the values # that go in Identifier.identifier. - identifiers, failures = Identifier.parse_urns( - self._db, self.seen_identifiers - ) + identifiers, failures = Identifier.parse_urns(self._db, self.seen_identifiers) identifier_ids = [x.id for x in list(identifiers.values())] # At this point we've gone through the feed and collected all the identifiers. # If there's anything we didn't see, we know it's no longer available. - qu = self._db.query( - LicensePool - ).join( - Identifier - ).filter( - LicensePool.collection_id==self.collection.id - ).filter( - ~Identifier.id.in_(identifier_ids) - ).filter( - LicensePool.licenses_available > 0 + qu = ( + self._db.query(LicensePool) + .join(Identifier) + .filter(LicensePool.collection_id == self.collection.id) + .filter(~Identifier.id.in_(identifier_ids)) + .filter(LicensePool.licenses_available > 0) ) pools_reaped = qu.count() self.log.info( - "Reaping %s license pools for collection %s." % (pools_reaped, self.collection.name) + "Reaping %s license pools for collection %s." + % (pools_reaped, self.collection.name) ) for pool in qu: @@ -394,23 +419,25 @@ def run_once(self, progress): achievements = "License pools removed: %d." % pools_reaped return TimestampData(achievements=achievements) -class MockOPDSForDistributorsAPI(OPDSForDistributorsAPI): +class MockOPDSForDistributorsAPI(OPDSForDistributorsAPI): @classmethod def mock_collection(self, _db): """Create a mock OPDS For Distributors collection to use in tests.""" library = DatabaseTest.make_default_library(_db) collection, ignore = get_one_or_create( - _db, Collection, - name="Test OPDS For Distributors Collection", create_method_kwargs=dict( + _db, + Collection, + name="Test OPDS For Distributors Collection", + create_method_kwargs=dict( external_account_id="http://opds", - ) + ), ) integration = collection.create_external_integration( protocol=OPDSForDistributorsAPI.NAME ) - integration.username = 'a' - integration.password = 'b' + integration.username = "a" + integration.password = "b" library.collections.append(collection) return collection @@ -422,14 +449,14 @@ def __init__(self, _db, collection, *args, **kwargs): ) def queue_response(self, status_code, headers={}, content=None): - self.responses.insert( - 0, MockRequestsResponse(status_code, headers, content) - ) + self.responses.insert(0, MockRequestsResponse(status_code, headers, content)) def _request_with_timeout(self, method, url, *args, **kwargs): self.requests.append([method, url, args, kwargs]) response = self.responses.pop() return HTTP._process_response( - url, response, kwargs.get('allowed_response_codes'), - kwargs.get('disallowed_response_codes') + url, + response, + kwargs.get("allowed_response_codes"), + kwargs.get("disallowed_response_codes"), ) diff --git a/api/overdrive.py b/api/overdrive.py index fc648dda66..53afa11278 100644 --- a/api/overdrive.py +++ b/api/overdrive.py @@ -1,33 +1,17 @@ import datetime -import dateutil import json -import pytz import re -import requests -import flask import urllib.parse -from flask_babel import lazy_gettext as _ +import dateutil +import flask +import pytz +import requests +from flask_babel import lazy_gettext as _ from sqlalchemy.orm import contains_eager -from .circulation import ( - DeliveryMechanismInfo, - LoanInfo, - HoldInfo, - FulfillmentInfo, - BaseCirculationAPI, -) -from .selftest import ( - HasSelfTests, - SelfTestResult, -) -from core.overdrive import ( - OverdriveAPI as BaseOverdriveAPI, - OverdriveRepresentationExtractor, - OverdriveBibliographicCoverageProvider, - MockOverdriveAPI as BaseMockOverdriveAPI, -) - +from core.analytics import Analytics +from core.metadata_layer import ReplacementPolicy from core.model import ( CirculationEvent, Collection, @@ -44,19 +28,27 @@ Representation, Session, ) - -from core.monitor import ( - CollectionMonitor, - IdentifierSweepMonitor, - TimelineMonitor, +from core.monitor import CollectionMonitor, IdentifierSweepMonitor, TimelineMonitor +from core.overdrive import MockOverdriveAPI as BaseMockOverdriveAPI +from core.overdrive import OverdriveAPI as BaseOverdriveAPI +from core.overdrive import ( + OverdriveBibliographicCoverageProvider, + OverdriveRepresentationExtractor, ) +from core.scripts import Script from core.util.datetime_helpers import strptime_utc from core.util.http import HTTP -from core.metadata_layer import ReplacementPolicy -from core.scripts import Script +from .circulation import ( + BaseCirculationAPI, + DeliveryMechanismInfo, + FulfillmentInfo, + HoldInfo, + LoanInfo, +) from .circulation_exceptions import * -from core.analytics import Analytics +from .selftest import HasSelfTests, SelfTestResult + class OverdriveAPIConstants(object): # These are not real Overdrive formats; we use them internally so @@ -64,56 +56,84 @@ class OverdriveAPIConstants(object): # to get into Overdrive Read, and using it to get a link to a # manifest file. MANIFEST_INTERNAL_FORMATS = set( - ['audiobook-overdrive-manifest', 'ebook-overdrive-manifest'] + ["audiobook-overdrive-manifest", "ebook-overdrive-manifest"] ) - + # These formats can be delivered either as manifest files or as # links to websites that stream the content. STREAMING_FORMATS = [ - 'ebook-overdrive', - 'audiobook-overdrive', + "ebook-overdrive", + "audiobook-overdrive", ] -class OverdriveAPI(BaseOverdriveAPI, BaseCirculationAPI, HasSelfTests, OverdriveAPIConstants): + +class OverdriveAPI( + BaseOverdriveAPI, BaseCirculationAPI, HasSelfTests, OverdriveAPIConstants +): NAME = ExternalIntegration.OVERDRIVE - DESCRIPTION = _("Integrate an Overdrive collection. For an Overdrive Advantage collection, select the consortium's Overdrive collection as the parent.") + DESCRIPTION = _( + "Integrate an Overdrive collection. For an Overdrive Advantage collection, select the consortium's Overdrive collection as the parent." + ) SETTINGS = [ - { "key": Collection.EXTERNAL_ACCOUNT_ID_KEY, "label": _("Library ID"), "required": True }, - { "key": BaseOverdriveAPI.WEBSITE_ID, "label": _("Website ID"), "required": True }, - { "key": ExternalIntegration.USERNAME, "label": _("Client Key"), "required": True }, - { "key": ExternalIntegration.PASSWORD, "label": _("Client Secret"), "required": True }, + { + "key": Collection.EXTERNAL_ACCOUNT_ID_KEY, + "label": _("Library ID"), + "required": True, + }, + { + "key": BaseOverdriveAPI.WEBSITE_ID, + "label": _("Website ID"), + "required": True, + }, + { + "key": ExternalIntegration.USERNAME, + "label": _("Client Key"), + "required": True, + }, + { + "key": ExternalIntegration.PASSWORD, + "label": _("Client Secret"), + "required": True, + }, { "key": BaseOverdriveAPI.SERVER_NICKNAME, "label": _("Server family"), - "description": _("Unless you hear otherwise from Overdrive, your integration should use their production servers."), + "description": _( + "Unless you hear otherwise from Overdrive, your integration should use their production servers." + ), "type": "select", "options": [ - dict( - label=_("Production"), - key=BaseOverdriveAPI.PRODUCTION_SERVERS - ), + dict(label=_("Production"), key=BaseOverdriveAPI.PRODUCTION_SERVERS), dict( label=_("Testing"), key=BaseOverdriveAPI.TESTING_SERVERS, - ) + ), ], "default": BaseOverdriveAPI.PRODUCTION_SERVERS, }, ] + BaseCirculationAPI.SETTINGS LIBRARY_SETTINGS = BaseCirculationAPI.LIBRARY_SETTINGS + [ - { "key": BaseOverdriveAPI.ILS_NAME_KEY, "label": _("ILS Name"), - "default": BaseOverdriveAPI.ILS_NAME_DEFAULT, - "description": _("When multiple libraries share an Overdrive account, Overdrive uses a setting called 'ILS Name' to determine which ILS to check when validating a given patron."), + { + "key": BaseOverdriveAPI.ILS_NAME_KEY, + "label": _("ILS Name"), + "default": BaseOverdriveAPI.ILS_NAME_DEFAULT, + "description": _( + "When multiple libraries share an Overdrive account, Overdrive uses a setting called 'ILS Name' to determine which ILS to check when validating a given patron." + ), }, - BaseCirculationAPI.DEFAULT_LOAN_DURATION_SETTING + BaseCirculationAPI.DEFAULT_LOAN_DURATION_SETTING, ] # An Overdrive Advantage collection inherits everything except the library id # from its parent. CHILD_SETTINGS = [ - { "key": Collection.EXTERNAL_ACCOUNT_ID_KEY, "label": _("Library ID"), "required": True }, + { + "key": Collection.EXTERNAL_ACCOUNT_ID_KEY, + "label": _("Library ID"), + "required": True, + }, ] SET_DELIVERY_MECHANISM_AT = BaseCirculationAPI.FULFILL_STEP @@ -133,19 +153,20 @@ class OverdriveAPI(BaseOverdriveAPI, BaseCirculationAPI, HasSelfTests, Overdrive # When a request comes in for a given DeliveryMechanism, what # do we tell Overdrive? delivery_mechanism_to_internal_format = { - (epub, no_drm): 'ebook-epub-open', - (epub, adobe_drm): 'ebook-epub-adobe', - (pdf, no_drm): 'ebook-pdf-open', - (pdf, adobe_drm): 'ebook-pdf-adobe', - (streaming_text, streaming_drm): 'ebook-overdrive', - (streaming_audio, streaming_drm): 'audiobook-overdrive', - (overdrive_audiobook_manifest, libby_drm): 'audiobook-overdrive-manifest' + (epub, no_drm): "ebook-epub-open", + (epub, adobe_drm): "ebook-epub-adobe", + (pdf, no_drm): "ebook-pdf-open", + (pdf, adobe_drm): "ebook-pdf-adobe", + (streaming_text, streaming_drm): "ebook-overdrive", + (streaming_audio, streaming_drm): "audiobook-overdrive", + (overdrive_audiobook_manifest, libby_drm): "audiobook-overdrive-manifest", } # Once you choose a non-streaming format you're locked into it and can't # use other formats. LOCK_IN_FORMATS = [ - x for x in BaseOverdriveAPI.FORMATS + x + for x in BaseOverdriveAPI.FORMATS if x not in OverdriveAPIConstants.STREAMING_FORMATS and x not in OverdriveAPIConstants.MANIFEST_INTERNAL_FORMATS ] @@ -164,9 +185,7 @@ class OverdriveAPI(BaseOverdriveAPI, BaseCirculationAPI, HasSelfTests, Overdrive def __init__(self, _db, collection): super(OverdriveAPI, self).__init__(_db, collection) self.overdrive_bibliographic_coverage_provider = ( - OverdriveBibliographicCoverageProvider( - collection, api_class=self - ) + OverdriveBibliographicCoverageProvider(collection, api_class=self) ) def external_integration(self, _db): @@ -175,7 +194,8 @@ def external_integration(self, _db): def _run_self_tests(self, _db): result = self.run_test( "Checking global Client Authentication privileges", - self.check_creds, force_refresh=True + self.check_creds, + force_refresh=True, ) yield result if not result.success: @@ -187,20 +207,17 @@ def _count_advantage(): """Count the Overdrive Advantage accounts""" accounts = list(self.get_advantage_accounts()) return "Found %d Overdrive Advantage account(s)." % len(accounts) - yield self.run_test( - "Looking up Overdrive Advantage accounts", - _count_advantage - ) + + yield self.run_test("Looking up Overdrive Advantage accounts", _count_advantage) def _count_books(): """Count the titles in the collection.""" url = self._all_products_link status, headers, body = self.get(url, {}) body = json.loads(body) - return "%d item(s) in collection" % body['totalItems'] - yield self.run_test( - "Counting size of collection", _count_books - ) + return "%d item(s) in collection" % body["totalItems"] + + yield self.run_test("Counting size of collection", _count_books) default_patrons = [] for result in self.default_patrons(self.collection): @@ -208,13 +225,22 @@ def _count_books(): yield result continue library, patron, pin = result - task = "Checking Patron Authentication privileges, using test patron for library %s" % library.name - yield self.run_test( - task, self.get_patron_credential, patron, pin + task = ( + "Checking Patron Authentication privileges, using test patron for library %s" + % library.name ) - - def patron_request(self, patron, pin, url, extra_headers={}, data=None, - exception_on_401=False, method=None): + yield self.run_test(task, self.get_patron_credential, patron, pin) + + def patron_request( + self, + patron, + pin, + url, + extra_headers={}, + data=None, + exception_on_401=False, + method=None, + ): """Make an HTTP request on behalf of a patron. The results are never cached. @@ -222,17 +248,15 @@ def patron_request(self, patron, pin, url, extra_headers={}, data=None, patron_credential = self.get_patron_credential(patron, pin) headers = dict(Authorization="Bearer %s" % patron_credential.credential) headers.update(extra_headers) - if method and method.lower() in ('get', 'post', 'put', 'delete'): + if method and method.lower() in ("get", "post", "put", "delete"): method = method.lower() else: if data: - method = 'post' + method = "post" else: - method = 'get' + method = "get" url = self.endpoint(url) - response = HTTP.request_with_timeout( - method, url, headers=headers, data=data - ) + response = HTTP.request_with_timeout(method, url, headers=headers, data=data) if response.status_code == 401: if exception_on_401: # This is our second try. Give up. @@ -241,10 +265,8 @@ def patron_request(self, patron, pin, url, extra_headers={}, data=None, ) else: # Refresh the token and try again. - self.refresh_patron_access_token( - patron_credential, patron, pin) - return self.patron_request( - patron, pin, url, extra_headers, data, True) + self.refresh_patron_access_token(patron_credential, patron, pin) + return self.patron_request(patron, pin, url, extra_headers, data, True) else: # This is commented out because it may expose patron # information. @@ -254,12 +276,17 @@ def patron_request(self, patron, pin, url, extra_headers={}, data=None, def get_patron_credential(self, patron, pin): """Create an OAuth token for the given patron.""" + def refresh(credential): - return self.refresh_patron_access_token( - credential, patron, pin) + return self.refresh_patron_access_token(credential, patron, pin) + return Credential.lookup( - self._db, DataSource.OVERDRIVE, "OAuth Token", patron, refresh, - collection=self.collection + self._db, + DataSource.OVERDRIVE, + "OAuth Token", + patron, + refresh, + collection=self.collection, ) def scope_string(self, library): @@ -270,7 +297,8 @@ def scope_string(self, library): its own Patron Authentication. """ return "websiteid:%s authorizationname:%s" % ( - self.website_id.decode("utf-8"), self.ils_name(library) + self.website_id.decode("utf-8"), + self.ils_name(library), ) def refresh_patron_access_token(self, credential, patron, pin): @@ -282,29 +310,29 @@ def refresh_patron_access_token(self, credential, patron, pin): payload = dict( grant_type="password", username=patron.authorization_identifier, - scope=self.scope_string(patron.library) + scope=self.scope_string(patron.library), ) if pin: # A PIN was provided. - payload['password'] = pin + payload["password"] = pin else: # No PIN was provided. Depending on the library, # this might be fine. If it's not fine, Overdrive will # refuse to issue a token. - payload['password_required'] = 'false' - payload['password'] = '[ignore]' + payload["password_required"] = "false" + payload["password"] = "[ignore]" response = self.token_post(self.PATRON_TOKEN_ENDPOINT, payload) if response.status_code == 200: self._update_credential(credential, response.json()) elif response.status_code == 400: response = response.json() - message = response['error'] - error = response.get('error_description') + message = response["error"] + error = response.get("error_description") if error: - message += '/' + error + message += "/" + error diagnostic = None debug = message - if error == 'Requested record not found': + if error == "Requested record not found": debug = "The patron failed Overdrive's cross-check against the library's ILS." raise PatronAuthorizationFailedException(message, debug) return credential @@ -326,14 +354,13 @@ def checkout(self, patron, pin, licensepool, internal_format): """ identifier = licensepool.identifier - overdrive_id=identifier.identifier + overdrive_id = identifier.identifier headers = {"Content-Type": "application/json"} payload = dict(fields=[dict(name="reserveId", value=overdrive_id)]) payload = json.dumps(payload) response = self.patron_request( - patron, pin, self.CHECKOUTS_ENDPOINT, extra_headers=headers, - data=payload + patron, pin, self.CHECKOUTS_ENDPOINT, extra_headers=headers, data=payload ) data = response.json() if response.status_code == 400: @@ -366,13 +393,13 @@ def _process_checkout_error(self, patron, pin, licensepool, error): code = "Unknown Error" identifier = licensepool.identifier if isinstance(error, dict): - code = error.get('errorCode', code) - if code == 'NoCopiesAvailable': + code = error.get("errorCode", code) + if code == "NoCopiesAvailable": # Clearly our info is out of date. self.update_licensepool(identifier.identifier) raise NoAvailableCopies() - if code == 'TitleAlreadyCheckedOut': + if code == "TitleAlreadyCheckedOut": # Client should have used a fulfill link instead, but # we can handle it. # @@ -388,7 +415,7 @@ def _process_checkout_error(self, patron, pin, licensepool, error): identifier.identifier, None, expires, - None + None, ) if code in self.ERROR_MESSAGE_TO_EXCEPTION: @@ -406,10 +433,13 @@ def checkin(self, patron, pin, licensepool): loans = [l for l in patron.loans if l.license_pool == licensepool] if loans: loan = loans[0] - if (loan and loan.fulfillment + if ( + loan + and loan.fulfillment and loan.fulfillment.delivery_mechanism and loan.fulfillment.delivery_mechanism.drm_scheme - == DeliveryMechanism.NO_DRM): + == DeliveryMechanism.NO_DRM + ): # This patron fulfilled this loan without DRM. That means we # should be able to find a loanEarlyReturnURL and hit it. if self.perform_early_return(patron, pin, loan): @@ -426,7 +456,7 @@ def checkin(self, patron, pin, licensepool): # loan exists but has not been locked to a delivery mechanism. overdrive_id = licensepool.identifier.identifier url = self.endpoint(self.CHECKOUT_ENDPOINT, overdrive_id=overdrive_id) - return self.patron_request(patron, pin, url, method='DELETE') + return self.patron_request(patron, pin, url, method="DELETE") def perform_early_return(self, patron, pin, loan, http_get=None): """Ask Overdrive for a loanEarlyReturnURL for the given loan @@ -450,8 +480,7 @@ def perform_early_return(self, patron, pin, loan, http_get=None): # Ask Overdrive for a link that can be used to fulfill the book # (but which may also contain an early return URL). url, media_type = self.get_fulfillment_link( - patron, pin, loan.license_pool.identifier.identifier, - internal_format + patron, pin, loan.license_pool.identifier.identifier, internal_format ) # The URL comes from Overdrive, so it probably doesn't need # interpolation, but just in case. @@ -460,7 +489,7 @@ def perform_early_return(self, patron, pin, loan, http_get=None): # Make a regular, non-authenticated request to the fulfillment link. http_get = http_get or HTTP.get_with_timeout response = http_get(url, allow_redirects=False) - location = response.headers.get('location') + location = response.headers.get("location") # Try to find an early return URL in the Location header # sent from the fulfillment request. @@ -482,7 +511,7 @@ def _extract_early_return_url(cls, location): return None parsed = urllib.parse.urlparse(location) query = urllib.parse.parse_qs(parsed.query) - urls = query.get('loanEarlyReturnUrl') + urls = query.get("loanEarlyReturnUrl") if urls: return urls[0] @@ -494,14 +523,14 @@ def fill_out_form(self, **values): return headers, json.dumps(dict(fields=fields)) error_to_exception = { - "TitleNotCheckedOut" : NoActiveLoan, + "TitleNotCheckedOut": NoActiveLoan, } def raise_exception_on_error(self, data, custom_error_to_exception={}): - if not 'errorCode' in data: + if not "errorCode" in data: return - error = data['errorCode'] - message = data.get('message') or '' + error = data["errorCode"] + message = data.get("message") or "" for d in custom_error_to_exception, self.error_to_exception: if error in d: raise d[error](message) @@ -513,9 +542,7 @@ def get_loan(self, patron, pin, overdrive_id): return data def get_hold(self, patron, pin, overdrive_id): - url = self.endpoint( - self.HOLD_ENDPOINT, product_id=overdrive_id.upper() - ) + url = self.endpoint(self.HOLD_ENDPOINT, product_id=overdrive_id.upper()) data = self.patron_request(patron, pin, url).json() self.raise_exception_on_error(data) return data @@ -552,12 +579,18 @@ def fulfill(self, patron, pin, licensepool, internal_format, **kwargs): # It's possible the available formats for this book have changed and we # have an inaccurate delivery mechanism. Try to update the formats, but # reraise the error regardless. - self.log.info("Overdrive id %s was not available as %s, getting updated formats" % (licensepool.identifier.identifier, internal_format)) + self.log.info( + "Overdrive id %s was not available as %s, getting updated formats" + % (licensepool.identifier.identifier, internal_format) + ) try: self.update_formats(licensepool) except Exception as e2: - self.log.error("Could not update formats for Overdrive ID %s" % licensepool.identifier.identifier) + self.log.error( + "Could not update formats for Overdrive ID %s" + % licensepool.identifier.identifier + ) raise e @@ -569,35 +602,37 @@ def fulfill(self, patron, pin, licensepool, internal_format, **kwargs): content_link=url, content_type=media_type, content=None, - content_expires=None + content_expires=None, ) - def get_fulfillment_link(self, patron, pin, overdrive_id, format_type): - """Get the link to the ACSM or manifest for an existing loan. - """ + """Get the link to the ACSM or manifest for an existing loan.""" loan = self.get_loan(patron, pin, overdrive_id) if not loan: raise NoActiveLoan("Could not find active loan for %s" % overdrive_id) download_link = None - if (not loan.get('isFormatLockedIn') - and format_type in self.LOCK_IN_FORMATS): + if not loan.get("isFormatLockedIn") and format_type in self.LOCK_IN_FORMATS: # The format is not locked in. Lock it in. # This will happen the first time someone tries to fulfill # a loan with a lock-in format (basically Adobe-gated formats) - response = self.lock_in_format( - patron, pin, overdrive_id, format_type) + response = self.lock_in_format(patron, pin, overdrive_id, format_type) if response.status_code not in (201, 200): if response.status_code == 400: message = response.json().get("message") - if message == "The selected format may not be available for this title.": - raise FormatNotAvailable("This book is not available in the format you requested.") + if ( + message + == "The selected format may not be available for this title." + ): + raise FormatNotAvailable( + "This book is not available in the format you requested." + ) else: raise CannotFulfill("Could not lock in format %s" % format_type) response = response.json() try: download_link = self.extract_download_link( - response, self.DEFAULT_ERROR_URL) + response, self.DEFAULT_ERROR_URL + ) except IOError as e: # Get the loan fresh and see if that solves the problem. loan = self.get_loan(patron, pin, overdrive_id) @@ -611,8 +646,8 @@ def get_fulfillment_link(self, patron, pin, overdrive_id, format_type): ) if not download_link: raise CannotFulfill( - "No download link for %s, format %s" % ( - overdrive_id, format_type)) + "No download link for %s, format %s" % (overdrive_id, format_type) + ) if download_link: if format_type in self.MANIFEST_INTERNAL_FORMATS: @@ -620,58 +655,66 @@ def get_fulfillment_link(self, patron, pin, overdrive_id, format_type): # credentials to fulfill this URL; we can't do it. scope_string = self.scope_string(patron.library) return OverdriveManifestFulfillmentInfo( - self.collection, download_link, - overdrive_id, scope_string + self.collection, download_link, overdrive_id, scope_string ) return self.get_fulfillment_link_from_download_link( - patron, pin, download_link) + patron, pin, download_link + ) - raise CannotFulfill("Cannot obtain a download link for patron[%r], overdrive_id[%s], format_type[%s].", patron, overdrive_id, format_type) + raise CannotFulfill( + "Cannot obtain a download link for patron[%r], overdrive_id[%s], format_type[%s].", + patron, + overdrive_id, + format_type, + ) - def get_fulfillment_link_from_download_link(self, patron, pin, download_link, fulfill_url=None): + def get_fulfillment_link_from_download_link( + self, patron, pin, download_link, fulfill_url=None + ): # If this for Overdrive's streaming reader, and the link expires, # the patron can go back to the circulation manager fulfill url # again to get a new one. if not fulfill_url and flask.request: fulfill_url = flask.request.url else: - fulfill_url="" + fulfill_url = "" download_link = download_link.replace("{odreadauthurl}", fulfill_url) download_response = self.patron_request(patron, pin, download_link) return self.extract_content_link(download_response.json()) def extract_content_link(self, content_link_gateway_json): - link = content_link_gateway_json['links']['contentlink'] - return link['href'], link['type'] + link = content_link_gateway_json["links"]["contentlink"] + return link["href"], link["type"] def lock_in_format(self, patron, pin, overdrive_id, format_type): overdrive_id = overdrive_id.upper() headers, document = self.fill_out_form( - reserveId=overdrive_id, formatType=format_type) - url = self.endpoint( - self.FORMATS_ENDPOINT, overdrive_id=overdrive_id + reserveId=overdrive_id, formatType=format_type ) + url = self.endpoint(self.FORMATS_ENDPOINT, overdrive_id=overdrive_id) return self.patron_request(patron, pin, url, headers, document) @classmethod - def extract_data_from_checkout_response(cls, checkout_response_json, - format_type, error_url): + def extract_data_from_checkout_response( + cls, checkout_response_json, format_type, error_url + ): expires = cls.extract_expiration_date(checkout_response_json) return expires, cls.get_download_link( - checkout_response_json, format_type, error_url) + checkout_response_json, format_type, error_url + ) @classmethod def extract_data_from_hold_response(cls, hold_response_json): - position = hold_response_json['holdListPosition'] - placed = cls._extract_date(hold_response_json, 'holdPlacedDate') + position = hold_response_json["holdListPosition"] + placed = cls._extract_date(hold_response_json, "holdPlacedDate") return position, placed @classmethod def extract_expiration_date(cls, data): - return cls._extract_date(data, 'expires') + return cls._extract_date(data, "expires") @classmethod def _extract_date(cls, data, field_name): @@ -680,9 +723,7 @@ def _extract_date(cls, data, field_name): if not field_name in data: return None try: - return strptime_utc( - data[field_name], cls.TIME_FORMAT - ) + return strptime_utc(data[field_name], cls.TIME_FORMAT) except ValueError as e: # Wrong format return None @@ -727,24 +768,23 @@ def patron_activity(self, patron, pin): # It's common enough that it's hardly worth mentioning, but it # could theoretically be the sign of a larger problem. self.log.info( - "Overdrive authentication failed, assuming no loans.", - exc_info=e + "Overdrive authentication failed, assuming no loans.", exc_info=e ) loans = {} holds = {} - for checkout in loans.get('checkouts', []): + for checkout in loans.get("checkouts", []): loan_info = self.process_checkout_data(checkout, self.collection) yield loan_info - for hold in holds.get('holds', []): - overdrive_identifier = hold['reserveId'].lower() - start = self._pd(hold.get('holdPlacedDate')) - end = self._pd(hold.get('holdExpires')) - position = hold.get('holdListPosition') + for hold in holds.get("holds", []): + overdrive_identifier = hold["reserveId"].lower() + start = self._pd(hold.get("holdPlacedDate")) + end = self._pd(hold.get("holdExpires")) + position = hold.get("holdListPosition") if position is not None: position = int(position) - if 'checkout' in hold.get('actions', {}): + if "checkout" in hold.get("actions", {}): # This patron needs to decide whether to check the # book out. By our reckoning, the patron's position is # 0, not whatever position Overdrive had for them. @@ -756,7 +796,7 @@ def patron_activity(self, patron, pin): overdrive_identifier, start_date=start, end_date=end, - hold_position=position + hold_position=position, ) @classmethod @@ -767,25 +807,24 @@ def process_checkout_data(cls, checkout, collection): :return: A LoanInfo object if the book can be fulfilled by the default Library Simplified client, and None otherwise. """ - overdrive_identifier = checkout['reserveId'].lower() - start = cls._pd(checkout.get('checkoutDate')) - end = cls._pd(checkout.get('expires')) + overdrive_identifier = checkout["reserveId"].lower() + start = cls._pd(checkout.get("checkoutDate")) + end = cls._pd(checkout.get("expires")) usable_formats = [] # If a format is already locked in, it will be in formats. - for format in checkout.get('formats', []): - format_type = format.get('formatType') + for format in checkout.get("formats", []): + format_type = format.get("formatType") if format_type in cls.FORMATS: usable_formats.append(format_type) - # If a format hasn't been selected yet, available formats are in actions. - actions = checkout.get('actions', {}) - format_action = actions.get('format', {}) - format_fields = format_action.get('fields', []) + actions = checkout.get("actions", {}) + format_action = actions.get("format", {}) + format_fields = format_action.get("fields", []) for field in format_fields: - if field.get('name', "") == "formatType": + if field.get("name", "") == "formatType": format_options = field.get("options", []) for format_type in format_options: if format_type in cls.FORMATS: @@ -810,9 +849,7 @@ def process_checkout_data(cls, checkout, collection): [overdrive_format] = usable_formats internal_formats = list( - OverdriveRepresentationExtractor.internal_formats( - overdrive_format - ) + OverdriveRepresentationExtractor.internal_formats(overdrive_format) ) if len(internal_formats) == 1: @@ -820,8 +857,7 @@ def process_checkout_data(cls, checkout, collection): # Make it clear that Overdrive will only deliver the content # in one specific media type. locked_to = DeliveryMechanismInfo( - content_type=media_type, - drm_scheme=drm_scheme + content_type=media_type, drm_scheme=drm_scheme ) return LoanInfo( @@ -831,7 +867,7 @@ def process_checkout_data(cls, checkout, collection): overdrive_identifier, start_date=start, end_date=end, - locked_to=locked_to + locked_to=locked_to, ) def default_notification_email_address(self, patron, pin): @@ -854,12 +890,10 @@ def default_notification_email_address(self, patron, pin): # Instead, we will ask _Overdrive_ if this patron has a # preferred email address for notifications. address = None - response = self.patron_request( - patron, pin, self.PATRON_INFORMATION_ENDPOINT - ) + response = self.patron_request(patron, pin, self.PATRON_INFORMATION_ENDPOINT) if response.status_code == 200: data = response.json() - address = data.get('lastHoldEmail') + address = data.get("lastHoldEmail") # Great! Except, it's possible that this address is the # 'trash everything' address, because we _used_ to send @@ -870,7 +904,7 @@ def default_notification_email_address(self, patron, pin): self.log.error( "Unable to get patron information for %s: %s", patron.authorization_identifier, - response.content + response.content, ) return address @@ -889,18 +923,15 @@ def place_hold(self, patron, pin, licensepool, notification_email_address): overdrive_id = licensepool.identifier.identifier form_fields = dict(reserveId=overdrive_id) if notification_email_address: - form_fields['emailAddress'] = notification_email_address + form_fields["emailAddress"] = notification_email_address else: - form_fields['ignoreHoldEmail'] = True + form_fields["ignoreHoldEmail"] = True headers, document = self.fill_out_form(**form_fields) response = self.patron_request( - patron, pin, self.HOLDS_ENDPOINT, headers, - document - ) - return self.process_place_hold_response( - response, patron, pin, licensepool + patron, pin, self.HOLDS_ENDPOINT, headers, document ) + return self.process_place_hold_response(response, patron, pin, licensepool) def process_place_hold_response(self, response, patron, pin, licensepool): """Process the response to a HOLDS_ENDPOINT request. @@ -910,14 +941,13 @@ def process_place_hold_response(self, response, patron, pin, licensepool): :raise: A CirculationException explaining why no hold could be placed. """ + def make_holdinfo(hold_response): # Create a HoldInfo object by combining data passed into # the enclosing method with the data from a hold response # (either creating a new hold or fetching an existing # one). - position, start_date = self.extract_data_from_hold_response( - hold_response - ) + position, start_date = self.extract_data_from_hold_response(hold_response) return HoldInfo( licensepool.collection, licensepool.data_source.name, @@ -925,29 +955,27 @@ def make_holdinfo(hold_response): licensepool.identifier.identifier, start_date=start_date, end_date=None, - hold_position=position + hold_position=position, ) family = response.status_code // 100 if family == 4: error = response.json() - if not error or not 'errorCode' in error: + if not error or not "errorCode" in error: raise CannotHold() - code = error['errorCode'] - if code == 'AlreadyOnWaitList': + code = error["errorCode"] + if code == "AlreadyOnWaitList": # The book is already on hold, so this isn't an exceptional # condition. Refresh the queue info and act as though the # request was successful. - hold = self.get_hold( - patron, pin, licensepool.identifier.identifier - ) + hold = self.get_hold(patron, pin, licensepool.identifier.identifier) return make_holdinfo(hold) - elif code == 'NotWithinRenewalWindow': + elif code == "NotWithinRenewalWindow": # The patron has this book checked out and cannot yet # renew their loan. raise CannotRenew() - elif code == 'PatronExceededHoldLimit': + elif code == "PatronExceededHoldLimit": raise PatronHoldLimitReached() else: raise CannotHold(code) @@ -970,18 +998,17 @@ def release_hold(self, patron, pin, licensepool): any reason. """ url = self.endpoint( - self.HOLD_ENDPOINT, - product_id=licensepool.identifier.identifier + self.HOLD_ENDPOINT, product_id=licensepool.identifier.identifier ) - response = self.patron_request(patron, pin, url, method='DELETE') + response = self.patron_request(patron, pin, url, method="DELETE") if response.status_code // 100 == 2 or response.status_code == 404: return True if not response.content: raise CannotReleaseHold() data = response.json() - if not 'errorCode' in data: + if not "errorCode" in data: raise CannotReleaseHold() - if data['errorCode'] == 'PatronDoesntHaveTitleOnHold': + if data["errorCode"] == "PatronDoesntHaveTitleOnHold": # There was never a hold to begin with, so we're fine. return True raise CannotReleaseHold(response.content) @@ -992,11 +1019,11 @@ def circulation_lookup(self, book): circulation_link = self.endpoint( self.AVAILABILITY_ENDPOINT, collection_token=self.collection_token, - product_id=book_id + product_id=book_id, ) book = dict(id=book_id) else: - circulation_link = book['availability_link'] + circulation_link = book["availability_link"] # Make sure we use v2 of the availability API, # even if Overdrive gave us a link to v1. circulation_link = self.make_link_safe(circulation_link) @@ -1011,7 +1038,8 @@ def update_formats(self, licensepool): info = self.metadata_lookup(licensepool.identifier) metadata = OverdriveRepresentationExtractor.book_info_to_metadata( - info, include_bibliographic=True, include_formats=True) + info, include_bibliographic=True, include_formats=True + ) if not metadata: # No work to be done. return @@ -1034,15 +1062,10 @@ def update_licensepool(self, book_id): """ # Retrieve current circulation information about this book try: - book, (status_code, headers, content) = self.circulation_lookup( - book_id - ) + book, (status_code, headers, content) = self.circulation_lookup(book_id) except Exception as e: status_code = None - self.log.error( - "HTTP exception communicating with Overdrive", - exc_info=e - ) + self.log.error("HTTP exception communicating with Overdrive", exc_info=e) # TODO: If you ask for a book that you know about, and # Overdrive says the book doesn't exist in the collection, @@ -1053,7 +1076,8 @@ def update_licensepool(self, book_id): if status_code not in (200, 404): self.log.error( "Could not get availability for %s: status code %s", - book_id, status_code + book_id, + status_code, ) return None, None, False if isinstance(content, (bytes, str)): @@ -1061,22 +1085,22 @@ def update_licensepool(self, book_id): book.update(content) # Update book_id now that we know we have new data. - book_id = book['id'] + book_id = book["id"] license_pool, is_new = LicensePool.for_foreign_id( - self._db, DataSource.OVERDRIVE, Identifier.OVERDRIVE_ID, book_id, - collection=self.collection + self._db, + DataSource.OVERDRIVE, + Identifier.OVERDRIVE_ID, + book_id, + collection=self.collection, ) if is_new or not license_pool.work: # Either this is the first time we've seen this book or its doesn't # have an associated work. Make sure its identifier has bibliographic coverage. self.overdrive_bibliographic_coverage_provider.ensure_coverage( - license_pool.identifier, - force=True + license_pool.identifier, force=True ) - return self.update_licensepool_with_book_info( - book, license_pool, is_new - ) + return self.update_licensepool_with_book_info(book, license_pool, is_new) # Alias for the CirculationAPI interface def update_availability(self, licensepool): @@ -1087,8 +1111,10 @@ def _edition(self, licensepool): Overdrive metadata for the given LicensePool. """ return Edition.for_foreign_id( - self._db, self.source, licensepool.identifier.type, - licensepool.identifier.identifier + self._db, + self.source, + licensepool.identifier.type, + licensepool.identifier.identifier, ) def update_licensepool_with_book_info(self, book, license_pool, is_new_pool): @@ -1101,9 +1127,7 @@ def update_licensepool_with_book_info(self, book, license_pool, is_new_pool): status. """ extractor = OverdriveRepresentationExtractor(self) - circulation = extractor.book_info_to_circulation( - book - ) + circulation = extractor.book_info_to_circulation(book) license_pool, circulation_changed = circulation.apply( self._db, license_pool.collection ) @@ -1115,7 +1139,6 @@ def update_licensepool_with_book_info(self, book, license_pool, is_new_pool): self.log.info("New Overdrive book discovered: %r", edition) return license_pool, is_new_pool, circulation_changed - @classmethod def get_download_link(self, checkout_response, format_type, error_url): """Extract a download link from the given response. @@ -1140,14 +1163,16 @@ def get_download_link(self, checkout_response, format_type, error_url): else: use_format_type = format_type fetch_manifest = False - for f in checkout_response.get('formats', []): - this_type = f['formatType'] + for f in checkout_response.get("formats", []): + this_type = f["formatType"] available_formats.append(this_type) if this_type == use_format_type: format = f break if not format: - if any(x in set(available_formats) for x in self.INCOMPATIBLE_PLATFORM_FORMATS): + if any( + x in set(available_formats) for x in self.INCOMPATIBLE_PLATFORM_FORMATS + ): # The most likely explanation is that the patron # already had this book delivered to their Kindle. raise FulfilledOnIncompatiblePlatform( @@ -1178,16 +1203,16 @@ def extract_download_link(cls, format, error_url, fetch_manifest=False): a manifest file. """ - format_type = format.get('formatType', '(unknown)') - if not 'linkTemplates' in format: + format_type = format.get("formatType", "(unknown)") + if not "linkTemplates" in format: raise IOError("No linkTemplates for format %s" % format_type) - templates = format['linkTemplates'] - if not 'downloadLink' in templates: + templates = format["linkTemplates"] + if not "downloadLink" in templates: raise IOError("No downloadLink for format %s" % format_type) - download_link_data = templates['downloadLink'] - if not 'href' in download_link_data: + download_link_data = templates["downloadLink"] + if not "href" in download_link_data: raise IOError("No downloadLink href for format %s" % format_type) - download_link = download_link_data['href'] + download_link = download_link_data["href"] if download_link: if fetch_manifest: download_link = cls.make_direct_download_link(download_link) @@ -1210,19 +1235,20 @@ def make_direct_download_link(cls, link): link. """ # Remove any Overdrive Read authentication URL and error URL. - for argument_name in ('odreadauthurl', 'errorpageurl'): + for argument_name in ("odreadauthurl", "errorpageurl"): argument_re = re.compile("%s={%s}&?" % (argument_name, argument_name)) link = argument_re.sub("", link) # Add the contentfile=true argument. - if '?' not in link: - link += '?contentfile=true' - elif link.endswith('&') or link.endswith('?'): - link += 'contentfile=true' + if "?" not in link: + link += "?contentfile=true" + elif link.endswith("&") or link.endswith("?"): + link += "contentfile=true" else: - link += '&contentfile=true' + link += "&contentfile=true" return link + class MockOverdriveResponse(object): def __init__(self, status_code, headers, content): self.status_code = status_code @@ -1239,7 +1265,7 @@ class MockOverdriveAPI(BaseMockOverdriveAPI, OverdriveAPI): token_data = '{"access_token":"foo","token_type":"bearer","expires_in":3600,"scope":"LIB META AVAIL SRCH"}' - collection_token = 'fake token' + collection_token = "fake token" def patron_request(self, patron, pin, *args, **kwargs): response = self._make_request(*args, **kwargs) @@ -1249,8 +1275,8 @@ def patron_request(self, patron, pin, *args, **kwargs): # The last item in the record of the request is keyword arguments. # Stick this information in there to minimize confusion. - original_data[-1]['_patron'] = patron - original_data[-1]['_pin'] = patron + original_data[-1]["_patron"] = patron + original_data[-1]["_pin"] = patron return response @@ -1258,11 +1284,14 @@ class OverdriveCirculationMonitor(CollectionMonitor, TimelineMonitor): """Maintain LicensePools for recently changed Overdrive titles. Create basic Editions for any new LicensePools that show up. """ + SERVICE_NAME = "Overdrive Circulation Monitor" PROTOCOL = ExternalIntegration.OVERDRIVE OVERLAP = datetime.timedelta(minutes=1) - def __init__(self, _db, collection, api_class=OverdriveAPI, analytics_class=Analytics): + def __init__( + self, _db, collection, api_class=OverdriveAPI, analytics_class=Analytics + ): """Constructor.""" super(OverdriveCirculationMonitor, self).__init__(_db, collection) self.api = api_class(_db, collection) @@ -1277,9 +1306,7 @@ def catch_up_from(self, start, cutoff, progress): :progress: A TimestampData representing the time previously covered by this Monitor. """ - overdrive_data_source = DataSource.lookup( - self._db, DataSource.OVERDRIVE - ) + overdrive_data_source = DataSource.lookup(self._db, DataSource.OVERDRIVE) # Ask for changes between the last time covered by the Monitor # and the current time. @@ -1295,7 +1322,10 @@ def catch_up_from(self, start, cutoff, progress): if is_new: for library in self.collection.libraries: self.analytics.collect_event( - library, license_pool, CirculationEvent.DISTRIBUTOR_TITLE_ADD, license_pool.last_checked + library, + license_pool, + CirculationEvent.DISTRIBUTOR_TITLE_ADD, + license_pool.last_checked, ) self._db.commit() @@ -1311,6 +1341,7 @@ class NewTitlesOverdriveCollectionMonitor(OverdriveCirculationMonitor): This catches any new titles that slipped through the cracks of the RecentOverdriveCollectionMonitor. """ + SERVICE_NAME = "Overdrive New Title Monitor" OVERLAP = datetime.timedelta(days=7) DEFAULT_START_TIME = OverdriveCirculationMonitor.NEVER @@ -1326,7 +1357,7 @@ def should_stop(self, start, api_description, is_changed): return False # We should stop if this book was added before our start time. - date_added = api_description.get('date_added') + date_added = api_description.get("date_added") if not date_added: # We don't know when this book was added -- shouldn't happen. return False @@ -1344,7 +1375,9 @@ def should_stop(self, start, api_description, is_changed): start = pytz.utc.localize(start) self.log.info( "Date added: %s, start time: %s, result %s", - date_added, start, date_added < start + date_added, + start, + date_added < start, ) return date_added < start @@ -1353,6 +1386,7 @@ class OverdriveCollectionReaper(IdentifierSweepMonitor): """Check for books that are in the local collection but have left our Overdrive collection. """ + SERVICE_NAME = "Overdrive Collection Reaper" PROTOCOL = ExternalIntegration.OVERDRIVE @@ -1374,7 +1408,7 @@ class RecentOverdriveCollectionMonitor(OverdriveCirculationMonitor): # haven't changed since last time. Overdrive results are not in # strict chronological order, but if you see 100 consecutive books # that haven't changed, you're probably done. - MAXIMUM_CONSECUTIVE_UNCHANGED_BOOKS=100 + MAXIMUM_CONSECUTIVE_UNCHANGED_BOOKS = 100 def __init__(self, *args, **kwargs): super(RecentOverdriveCollectionMonitor, self).__init__(*args, **kwargs) @@ -1385,14 +1419,17 @@ def should_stop(self, start, api_description, is_changed): self.consecutive_unchanged_books = 0 else: self.consecutive_unchanged_books += 1 - if (self.consecutive_unchanged_books >= - self.MAXIMUM_CONSECUTIVE_UNCHANGED_BOOKS): + if ( + self.consecutive_unchanged_books + >= self.MAXIMUM_CONSECUTIVE_UNCHANGED_BOOKS + ): # We're supposed to stop this run after finding a # run of books that have not changed, and we have # in fact seen that many consecutive unchanged # books. - self.log.info("Stopping at %d unchanged books.", - self.consecutive_unchanged_books) + self.log.info( + "Stopping at %d unchanged books.", self.consecutive_unchanged_books + ) return True return False @@ -1401,6 +1438,7 @@ class OverdriveFormatSweep(IdentifierSweepMonitor): """Check the current formats of every Overdrive book in our collection. """ + SERVICE_NAME = "Overdrive Format Sweep" DEFAULT_BATCH_SIZE = 25 PROTOCOL = ExternalIntegration.OVERDRIVE @@ -1419,14 +1457,11 @@ def process_item(self, identifier): class OverdriveAdvantageAccountListScript(Script): - def run(self): """Explain every Overdrive collection and, for each one, all of its Advantage collections. """ - collections = Collection.by_protocol( - self._db, ExternalIntegration.OVERDRIVE - ) + collections = Collection.by_protocol(self._db, ExternalIntegration.OVERDRIVE) for collection in collections: self.explain_main_collection(collection) print() @@ -1440,13 +1475,11 @@ def explain_main_collection(self, collection): print("\n".join(collection.explain())) print("A few of the titles in the main collection:") for i, book in enumerate(api.all_ids()): - print("", book['title']) + print("", book["title"]) if i > 10: break advantage_accounts = list(api.get_advantage_accounts()) - print("%d associated Overdrive Advantage account(s)." % len( - advantage_accounts - )) + print("%d associated Overdrive Advantage account(s)." % len(advantage_accounts)) for advantage_collection in advantage_accounts: self.explain_advantage_collection(advantage_collection) print() @@ -1459,15 +1492,13 @@ def explain_advantage_collection(self, collection): print(" A few of the titles in this Advantage collection:") child_api = OverdriveAPI(self._db, child) for i, book in enumerate(child_api.all_ids()): - print(" ", book['title']) + print(" ", book["title"]) if i > 10: break class OverdriveManifestFulfillmentInfo(FulfillmentInfo): - - def __init__(self, collection, content_link, overdrive_identifier, - scope_string): + def __init__(self, collection, content_link, overdrive_identifier, scope_string): """Constructor. Most of the arguments to the superconstructor can be assumed, @@ -1492,6 +1523,6 @@ def as_response(self): headers = { "Location": self.content_link, "X-Overdrive-Scope": self.scope_string, - "Content-Type": self.content_type or 'text/plain', + "Content-Type": self.content_type or "text/plain", } return flask.Response("", 302, headers) diff --git a/api/problem_details.py b/api/problem_details.py index 9058bf6b4b..8aba62d645 100644 --- a/api/problem_details.py +++ b/api/problem_details.py @@ -1,217 +1,232 @@ -from core.util.problem_detail import ProblemDetail as pd -from core.problem_details import * from flask_babel import lazy_gettext as _ +from core.problem_details import * +from core.util.problem_detail import ProblemDetail as pd + REMOTE_INTEGRATION_FAILED = pd( - "http://librarysimplified.org/terms/problem/remote-integration-failed", - 502, - _("Third-party service failed."), - _("The library could not complete your request because a third-party service has failed."), + "http://librarysimplified.org/terms/problem/remote-integration-failed", + 502, + _("Third-party service failed."), + _( + "The library could not complete your request because a third-party service has failed." + ), ) CANNOT_GENERATE_FEED = pd( - "http://librarysimplified.org/terms/problem/cannot-generate-feed", - 500, - _("Feed should be been pre-cached."), - _("This feed should have been pre-cached. It's too expensive to generate dynamically."), + "http://librarysimplified.org/terms/problem/cannot-generate-feed", + 500, + _("Feed should be been pre-cached."), + _( + "This feed should have been pre-cached. It's too expensive to generate dynamically." + ), ) INVALID_CREDENTIALS = pd( - "http://librarysimplified.org/terms/problem/credentials-invalid", - 401, - _("Invalid credentials"), - _("A valid library card barcode number and PIN are required."), + "http://librarysimplified.org/terms/problem/credentials-invalid", + 401, + _("Invalid credentials"), + _("A valid library card barcode number and PIN are required."), ) EXPIRED_CREDENTIALS = pd( - "http://librarysimplified.org/terms/problem/credentials-expired", - 403, - _("Expired credentials."), - _("Your library card has expired. You need to renew it."), + "http://librarysimplified.org/terms/problem/credentials-expired", + 403, + _("Expired credentials."), + _("Your library card has expired. You need to renew it."), ) BLOCKED_CREDENTIALS = pd( - "http://librarysimplified.org/terms/problem/credentials-suspended", - 403, - _("Suspended credentials."), - _("Your library card has been suspended. Contact your branch library."), + "http://librarysimplified.org/terms/problem/credentials-suspended", + 403, + _("Suspended credentials."), + _("Your library card has been suspended. Contact your branch library."), ) NO_LICENSES = pd( - "http://librarysimplified.org/terms/problem/no-licenses", - 404, - _("No licenses."), - _("The library currently has no licenses for this book."), + "http://librarysimplified.org/terms/problem/no-licenses", + 404, + _("No licenses."), + _("The library currently has no licenses for this book."), ) NO_AVAILABLE_LICENSE = pd( - "http://librarysimplified.org/terms/problem/no-available-license", - 403, - _("No available license."), - _("All licenses for this book are loaned out."), + "http://librarysimplified.org/terms/problem/no-available-license", + 403, + _("No available license."), + _("All licenses for this book are loaned out."), ) NO_ACCEPTABLE_FORMAT = pd( - "http://librarysimplified.org/terms/problem/no-acceptable-format", - 400, - _("No acceptable format."), - _("Could not deliver this book in an acceptable format."), + "http://librarysimplified.org/terms/problem/no-acceptable-format", + 400, + _("No acceptable format."), + _("Could not deliver this book in an acceptable format."), ) ALREADY_CHECKED_OUT = pd( - "http://librarysimplified.org/terms/problem/loan-already-exists", - 400, - _("Already checked out"), - _("You have already checked out this book."), + "http://librarysimplified.org/terms/problem/loan-already-exists", + 400, + _("Already checked out"), + _("You have already checked out this book."), ) -GENERIC_LOAN_LIMIT_MESSAGE = _("You have reached your loan limit. You cannot borrow anything further until you return something.") -SPECIFIC_LOAN_LIMIT_MESSAGE = _("You have reached your loan limit of %(limit)d. You cannot borrow anything further until you return something.") +GENERIC_LOAN_LIMIT_MESSAGE = _( + "You have reached your loan limit. You cannot borrow anything further until you return something." +) +SPECIFIC_LOAN_LIMIT_MESSAGE = _( + "You have reached your loan limit of %(limit)d. You cannot borrow anything further until you return something." +) LOAN_LIMIT_REACHED = pd( - "http://librarysimplified.org/terms/problem/loan-limit-reached", - 403, - _("Loan limit reached."), - GENERIC_LOAN_LIMIT_MESSAGE + "http://librarysimplified.org/terms/problem/loan-limit-reached", + 403, + _("Loan limit reached."), + GENERIC_LOAN_LIMIT_MESSAGE, ) -GENERIC_HOLD_LIMIT_MESSAGE = _("You have reached your hold limit. You cannot place another item on hold until you borrow something or remove a hold.") -SPECIFIC_HOLD_LIMIT_MESSAGE = _("You have reached your hold limit of %(limit)d. You cannot place another item on hold until you borrow something or remove a hold.") +GENERIC_HOLD_LIMIT_MESSAGE = _( + "You have reached your hold limit. You cannot place another item on hold until you borrow something or remove a hold." +) +SPECIFIC_HOLD_LIMIT_MESSAGE = _( + "You have reached your hold limit of %(limit)d. You cannot place another item on hold until you borrow something or remove a hold." +) HOLD_LIMIT_REACHED = pd( - "http://librarysimplified.org/terms/problem/hold-limit-reached", - 403, - _("Limit reached."), - GENERIC_HOLD_LIMIT_MESSAGE + "http://librarysimplified.org/terms/problem/hold-limit-reached", + 403, + _("Limit reached."), + GENERIC_HOLD_LIMIT_MESSAGE, ) OUTSTANDING_FINES = pd( - "http://librarysimplified.org/terms/problem/outstanding-fines", - 403, - _("Outstanding fines."), - _("You must pay your outstanding fines before you can borrow more books."), - ) + "http://librarysimplified.org/terms/problem/outstanding-fines", + 403, + _("Outstanding fines."), + _("You must pay your outstanding fines before you can borrow more books."), +) CHECKOUT_FAILED = pd( - "http://librarysimplified.org/terms/problem/cannot-issue-loan", - 502, - _("Could not issue loan."), - _("Could not issue loan (reason unknown)."), + "http://librarysimplified.org/terms/problem/cannot-issue-loan", + 502, + _("Could not issue loan."), + _("Could not issue loan (reason unknown)."), ) HOLD_FAILED = pd( - "http://librarysimplified.org/terms/problem/cannot-place-hold", - 502, - _("Could not place hold."), - _("Could not place hold (reason unknown)."), + "http://librarysimplified.org/terms/problem/cannot-place-hold", + 502, + _("Could not place hold."), + _("Could not place hold (reason unknown)."), ) RENEW_FAILED = pd( - "http://librarysimplified.org/terms/problem/cannot-renew-loan", - 400, - _("Could not renew loan."), - _("Could not renew loan (reason unknown)."), + "http://librarysimplified.org/terms/problem/cannot-renew-loan", + 400, + _("Could not renew loan."), + _("Could not renew loan (reason unknown)."), ) NOT_FOUND_ON_REMOTE = pd( - "http://librarysimplified.org/terms/problem/not-found-on-remote", - 404, - _("No longer in collection."), - _("This book was recently removed from the collection."), + "http://librarysimplified.org/terms/problem/not-found-on-remote", + 404, + _("No longer in collection."), + _("This book was recently removed from the collection."), ) NO_ACTIVE_LOAN = pd( - "http://librarysimplified.org/terms/problem/no-active-loan", - 400, - _("No active loan."), - _("You can't do this without first borrowing this book."), + "http://librarysimplified.org/terms/problem/no-active-loan", + 400, + _("No active loan."), + _("You can't do this without first borrowing this book."), ) NO_ACTIVE_HOLD = pd( - "http://librarysimplified.org/terms/problem/no-active-hold", - 400, - _("No active hold."), - _("You can't do this without first putting this book on hold."), + "http://librarysimplified.org/terms/problem/no-active-hold", + 400, + _("No active hold."), + _("You can't do this without first putting this book on hold."), ) NO_ACTIVE_LOAN_OR_HOLD = pd( - "http://librarysimplified.org/terms/problem/no-active-loan", - 400, - _("No active loan or hold."), - _("You can't do this without first borrowing this book or putting it on hold."), + "http://librarysimplified.org/terms/problem/no-active-loan", + 400, + _("No active loan or hold."), + _("You can't do this without first borrowing this book or putting it on hold."), ) LOAN_NOT_FOUND = pd( - "http://librarysimplified.org/terms/problem/loan-not-found", - 404, - _("Loan not found."), - _("You don't have a loan with the provided id."), + "http://librarysimplified.org/terms/problem/loan-not-found", + 404, + _("Loan not found."), + _("You don't have a loan with the provided id."), ) HOLD_NOT_FOUND = pd( - "http://librarysimplified.org/terms/problem/hold-not-found", - 404, - _("Hold not found."), - _("You don't have a hold with the provided id."), + "http://librarysimplified.org/terms/problem/hold-not-found", + 404, + _("Hold not found."), + _("You don't have a hold with the provided id."), ) COULD_NOT_MIRROR_TO_REMOTE = pd( - "http://librarysimplified.org/terms/problem/cannot-mirror-to-remote", - 502, - _("Could not mirror local state to remote."), - _("Could not convince a third party to accept the change you made. It's likely to show up again soon."), + "http://librarysimplified.org/terms/problem/cannot-mirror-to-remote", + 502, + _("Could not mirror local state to remote."), + _( + "Could not convince a third party to accept the change you made. It's likely to show up again soon." + ), ) NO_SUCH_LANE = pd( - "http://librarysimplified.org/terms/problem/unknown-lane", - 404, - _("No such lane."), - _("You asked for a nonexistent lane."), + "http://librarysimplified.org/terms/problem/unknown-lane", + 404, + _("No such lane."), + _("You asked for a nonexistent lane."), ) NO_SUCH_LIST = pd( - "http://librarysimplified.org/terms/problem/unknown-list", - 404, - _("No such list."), - _("You asked for a nonexistent list."), + "http://librarysimplified.org/terms/problem/unknown-list", + 404, + _("No such list."), + _("You asked for a nonexistent list."), ) NO_SUCH_COLLECTION = pd( - "http://librarysimplified.org/terms/problem/unknown-collection", - 404, - _("No such collection."), - _("You asked for a nonexistent collection."), + "http://librarysimplified.org/terms/problem/unknown-collection", + 404, + _("No such collection."), + _("You asked for a nonexistent collection."), ) FORBIDDEN_BY_POLICY = pd( - "http://librarysimplified.org/terms/problem/forbidden-by-policy", - 403, - _("Forbidden by policy."), - _("Library policy prevents us from carrying out your request."), + "http://librarysimplified.org/terms/problem/forbidden-by-policy", + 403, + _("Forbidden by policy."), + _("Library policy prevents us from carrying out your request."), ) NOT_AGE_APPROPRIATE = FORBIDDEN_BY_POLICY.detailed( _("Library policy considers this title inappropriate for your patron type."), - status_code=451 + status_code=451, ) CANNOT_FULFILL = pd( - "http://librarysimplified.org/terms/problem/cannot-fulfill-loan", - 400, - _("Could not fulfill loan."), - _("Could not fulfill loan."), + "http://librarysimplified.org/terms/problem/cannot-fulfill-loan", + 400, + _("Could not fulfill loan."), + _("Could not fulfill loan."), ) DELIVERY_CONFLICT = pd( - "http://librarysimplified.org/terms/problem/delivery-mechanism-conflict", - 409, - _("Delivery mechanism conflict."), - _("The delivery mechanism for this book has been locked in and can't be changed."), + "http://librarysimplified.org/terms/problem/delivery-mechanism-conflict", + 409, + _("Delivery mechanism conflict."), + _("The delivery mechanism for this book has been locked in and can't be changed."), ) BAD_DELIVERY_MECHANISM = pd( - "http://librarysimplified.org/terms/problem/bad-delivery-mechanism", - 400, - _("Unsupported delivery mechanism."), - _("You selected a delivery mechanism that's not supported by this book."), + "http://librarysimplified.org/terms/problem/bad-delivery-mechanism", + 400, + _("Unsupported delivery mechanism."), + _("You selected a delivery mechanism that's not supported by this book."), ) CANNOT_RELEASE_HOLD = pd( @@ -225,7 +240,9 @@ "http://librarysimplified.org/terms/problem/invalid-oauth-callback-parameters", status_code=400, title=_("Invalid OAuth callback parameters."), - detail=_("The OAuth callback must contain a code and a state parameter with the OAuth provider name."), + detail=_( + "The OAuth callback must contain a code and a state parameter with the OAuth provider name." + ), ) UNKNOWN_OAUTH_PROVIDER = pd( @@ -337,5 +354,5 @@ "http://librarysimplified.org/terms/problem/decryption-error", status_code=502, title=_("Decryption error"), - detail=_("Failed to decrypt a shared secret retrieved from another computer.") + detail=_("Failed to decrypt a shared secret retrieved from another computer."), ) diff --git a/api/proquest/importer.py b/api/proquest/importer.py index 37af121aea..2201b102bc 100644 --- a/api/proquest/importer.py +++ b/api/proquest/importer.py @@ -450,9 +450,7 @@ def _extract_image_links(self, publication, feed_self_url): :rtype: List[LinkData] """ self._logger.debug( - "Started extracting image links from {0}".format( - encode(publication.images) - ) + "Started extracting image links from {0}".format(encode(publication.images)) ) image_links = [] @@ -819,7 +817,11 @@ def __init__( :type process_removals: bool """ if not isinstance(parser, RWPMManifestParser): - raise ValueError("Argument 'parser' must be an instance of {0}".format(RWPMManifestParser)) + raise ValueError( + "Argument 'parser' must be an instance of {0}".format( + RWPMManifestParser + ) + ) import_class_kwargs["parser"] = parser @@ -896,7 +898,9 @@ def _download_feed_pages(self, feed_pages_directory): try: feed_page_content = json.dumps( - feed, default=str, ensure_ascii=True, + feed, + default=str, + ensure_ascii=True, ) feed_page_file.write(feed_page_content) feed_page_file.flush() diff --git a/api/registry.py b/api/registry.py index cf48fef200..ff22f8ef4b 100644 --- a/api/registry.py +++ b/api/registry.py @@ -1,30 +1,28 @@ +import base64 +import json +import logging + import feedparser from flask_babel import lazy_gettext as _ from html_sanitizer import Sanitizer -import json -import logging from sqlalchemy.orm.exc import NoResultFound from sqlalchemy.orm.session import Session +from api.adobe_vendor_id import AuthdataUtility +from api.config import Configuration +from api.controller import CirculationManager +from api.problem_details import * from core.model import ( + ConfigurationSetting, + ExternalIntegration, create, get_one, get_one_or_create, - ConfigurationSetting, - ExternalIntegration, ) from core.scripts import LibraryInputScript from core.util.http import HTTP -from core.util.problem_detail import ( - ProblemDetail, - JSON_MEDIA_TYPE as PROBLEM_DETAIL_JSON_MEDIA_TYPE, -) -import base64 - -from api.adobe_vendor_id import AuthdataUtility -from api.config import Configuration -from api.controller import CirculationManager -from api.problem_details import * +from core.util.problem_detail import JSON_MEDIA_TYPE as PROBLEM_DETAIL_JSON_MEDIA_TYPE +from core.util.problem_detail import ProblemDetail class RemoteRegistry(object): @@ -37,6 +35,7 @@ class RemoteRegistry(object): DISCOVERY_GOAL and wants to help patrons find their libraries) or it may be a shared ODL collection (which has LICENSE_GOAL). """ + DEFAULT_LIBRARY_REGISTRY_URL = "https://registry.thepalaceproject.org" DEFAULT_LIBRARY_REGISTRY_NAME = "Palace Library Registry" @@ -54,9 +53,7 @@ def for_integration_id(cls, _db, integration_id, goal): :param goal: The ExternalIntegration's .goal must be this goal. """ - integration = get_one(_db, ExternalIntegration, - goal=goal, - id=integration_id) + integration = get_one(_db, ExternalIntegration, goal=goal, id=integration_id) if not integration: return None return cls(integration) @@ -65,8 +62,8 @@ def for_integration_id(cls, _db, integration_id, goal): def for_protocol_and_goal(cls, _db, protocol, goal): """Find all LibraryRegistry objects with the given protocol and goal.""" for i in _db.query(ExternalIntegration).filter( - ExternalIntegration.goal==goal, - ExternalIntegration.protocol==protocol, + ExternalIntegration.goal == goal, + ExternalIntegration.protocol == protocol, ): yield cls(i) @@ -136,7 +133,12 @@ def _extract_catalog_information(cls, response): register_url = link.get("href") break if not register_url: - return REMOTE_INTEGRATION_FAILED.detailed(_("The service at %(url)s did not provide a register link.", url=response.url)) + return REMOTE_INTEGRATION_FAILED.detailed( + _( + "The service at %(url)s did not provide a register link.", + url=response.url, + ) + ) return register_url, vendor_id def fetch_registration_document(self, do_get=HTTP.debuggable_get): @@ -157,9 +159,10 @@ def fetch_registration_document(self, do_get=HTTP.debuggable_get): response = do_get(registration_url) if isinstance(response, ProblemDetail): return response - terms_of_service_link, terms_of_service_html = ( - self._extract_registration_information(response) - ) + ( + terms_of_service_link, + terms_of_service_html, + ) = self._extract_registration_information(response) return terms_of_service_link, terms_of_service_html @classmethod @@ -187,10 +190,9 @@ def _extract_registration_information(cls, response): for link in links: if link.get("rel") != "terms-of-service": continue - url = link.get('href') + url = link.get("href") is_http = any( - [url.startswith(protocol + "://") - for protocol in ("http", "https")] + [url.startswith(protocol + "://") for protocol in ("http", "https")] ) if is_http and not tos_link: tos_link = url @@ -221,7 +223,9 @@ def _extract_links(cls, response): links = feed.get("feed", {}).get("links", []) catalog = None else: - return REMOTE_INTEGRATION_FAILED.detailed(_("The service at %(url)s did not return OPDS.", url=response.url)) + return REMOTE_INTEGRATION_FAILED.detailed( + _("The service at %(url)s did not return OPDS.", url=response.url) + ) return catalog, links @classmethod @@ -240,13 +244,9 @@ def _decode_data_url(cls, url): header, encoded = parts if not header.endswith(";base64"): raise ValueError("data: URL not base64-encoded: %s" % url) - media_type = header[len("data:"):-len(";base64")] - if not any( - media_type.startswith(x) for x in ("text/html", "text/plain") - ): - raise ValueError( - "Unsupported media type in data: URL: %s" % media_type - ) + media_type = header[len("data:") : -len(";base64")] + if not any(media_type.startswith(x) for x in ("text/html", "text/plain")): + raise ValueError("Unsupported media type in data: URL: %s" % media_type) html = base64.b64decode(encoded.encode("utf-8")).decode("utf-8") return Sanitizer().sanitize(html) @@ -327,8 +327,14 @@ def setting(self, key, default_value=None): setting.value = default_value return setting - def push(self, stage, url_for, catalog_url=None, do_get=HTTP.debuggable_get, - do_post=HTTP.debuggable_post): + def push( + self, + stage, + url_for, + catalog_url=None, + do_get=HTTP.debuggable_get, + do_post=HTTP.debuggable_post, + ): """Attempt to register a library with a RemoteRegistry. NOTE: This method is designed to be used in a @@ -371,14 +377,17 @@ def push(self, stage, url_for, catalog_url=None, do_get=HTTP.debuggable_get, # needs to be committed to the database _before_ the push # attempt starts. key_pair = ConfigurationSetting.for_library( - Configuration.KEY_PAIR, self.library).json_value + Configuration.KEY_PAIR, self.library + ).json_value if not key_pair: # TODO: We could create the key pair _here_. The database # session will be committed at the end of this request, # so the push attempt would succeed if repeated. return SHARED_SECRET_DECRYPTION_ERROR.detailed( - _("Library %(library)s has no key pair set.", - library=self.library.short_name) + _( + "Library %(library)s has no key pair set.", + library=self.library.short_name, + ) ) public_key, private_key = key_pair cipher = Configuration.cipher(private_key) @@ -430,8 +439,7 @@ def _create_registration_payload(self, url_for, stage): :return: A dictionary suitable for passing into requests.post. """ auth_document_url = url_for( - "authentication_document", - library_short_name=self.library.short_name + "authentication_document", library_short_name=self.library.short_name ) payload = dict(url=auth_document_url, stage=stage) @@ -439,14 +447,14 @@ def _create_registration_payload(self, url_for, stage): # a problem with the way the library is using an integration. contact = Configuration.configuration_contact_uri(self.library) if contact: - payload['contact'] = contact + payload["contact"] = contact return payload def _create_registration_headers(self): shared_secret = self.setting(ExternalIntegration.PASSWORD).value headers = {} if shared_secret: - headers['Authorization'] = "Bearer %s" % shared_secret + headers["Authorization"] = "Bearer %s" % shared_secret return headers @classmethod @@ -458,17 +466,28 @@ def _send_registration_request(cls, register_url, headers, payload, do_post): """ # Allow 400 and 401 so we can provide a more useful error message. response = do_post( - register_url, headers=headers, payload=payload, timeout=60, + register_url, + headers=headers, + payload=payload, + timeout=60, allowed_response_codes=["2xx", "3xx", "400", "401"], ) if response.status_code in [400, 401]: if response.headers.get("Content-Type") == PROBLEM_DETAIL_JSON_MEDIA_TYPE: problem = json.loads(response.content) return INTEGRATION_ERROR.detailed( - _("Remote service returned: \"%(problem)s\"", problem=problem.get("detail"))) + _( + 'Remote service returned: "%(problem)s"', + problem=problem.get("detail"), + ) + ) else: return INTEGRATION_ERROR.detailed( - _("Remote service returned: \"%(problem)s\"", problem=response.content.decode("utf-8"))) + _( + 'Remote service returned: "%(problem)s"', + problem=response.content.decode("utf-8"), + ) + ) return response @classmethod @@ -476,7 +495,7 @@ def _decrypt_shared_secret(cls, cipher, shared_secret): """Attempt to decrypt an encrypted shared secret. :param cipher: A Cipher object. - + :param shared_secret: A byte string. :return: The decrypted shared secret, as a bytestring, or @@ -506,7 +525,10 @@ def _process_registration_result(self, catalog, cipher, desired_stage): # requests. if not isinstance(catalog, dict): return INTEGRATION_ERROR.detailed( - _("Remote service served %(representation)r, which I can't make sense of as an OPDS document.", representation=catalog) + _( + "Remote service served %(representation)r, which I can't make sense of as an OPDS document.", + representation=catalog, + ) ) metadata = catalog.get("metadata", {}) short_name = metadata.get("short_name") @@ -520,12 +542,10 @@ def _process_registration_result(self, catalog, cipher, desired_stage): break if short_name: - setting = self.setting(ExternalIntegration.USERNAME) - setting.value = short_name + setting = self.setting(ExternalIntegration.USERNAME) + setting.value = short_name if shared_secret: - shared_secret = self._decrypt_shared_secret( - cipher, shared_secret - ) + shared_secret = self._decrypt_shared_secret(cipher, shared_secret) if isinstance(shared_secret, ProblemDetail): return shared_secret @@ -562,14 +582,14 @@ class LibraryRegistrationScript(LibraryInputScript): def arg_parser(cls, _db): parser = LibraryInputScript.arg_parser(_db) parser.add_argument( - '--registry-url', + "--registry-url", help="Register libraries with the given registry.", - default=RemoteRegistry.DEFAULT_LIBRARY_REGISTRY_URL + default=RemoteRegistry.DEFAULT_LIBRARY_REGISTRY_URL, ) parser.add_argument( - '--stage', + "--stage", help="Register these libraries in the 'testing' stage or the 'production' stage.", - choices=(Registration.TESTING_STAGE, Registration.PRODUCTION_STAGE) + choices=(Registration.TESTING_STAGE, Registration.PRODUCTION_STAGE), ) return parser @@ -585,6 +605,7 @@ def do_run(self, cmd_args=None, in_unit_test=False): # Set up an application context so we have access to url_for. from api.app import app + app.manager = CirculationManager(self._db, testing=in_unit_test) base_url = ConfigurationSetting.sitewide( self._db, Configuration.BASE_URL_KEY @@ -594,9 +615,7 @@ def do_run(self, cmd_args=None, in_unit_test=False): for library in parsed.libraries: registration = Registration(registry, library) library_stage = stage or registration.stage_field.value - self.process_library( - registration, library_stage, app.manager.url_for - ) + self.process_library(registration, library_stage, app.manager.url_for) ctx.pop() # For testing purposes, return the application object that was @@ -610,8 +629,7 @@ def process_library(self, registration, stage, url_for): "Registration of library %r" % registration.library.short_name ) logger.info( - "Registering with %s as %s", - registration.registry.integration.url, stage + "Registering with %s as %s", registration.registry.integration.url, stage ) try: result = registration.push(stage, url_for) diff --git a/api/routes.py b/api/routes.py index a1c5d3a91a..94ce1445fd 100644 --- a/api/routes.py +++ b/api/routes.py @@ -1,14 +1,9 @@ -from functools import wraps, update_wrapper import logging import os +from functools import update_wrapper, wraps import flask -from flask import ( - Response, - redirect, - request, - make_response, -) +from flask import Response, make_response, redirect, request from flask_cors.core import get_cors_options, set_cors_headers from werkzeug.exceptions import HTTPException @@ -18,53 +13,50 @@ # we never want werkzeug's merge_slashes feature. app.url_map.merge_slashes = False -from .config import Configuration -from core.app_server import ( - ErrorHandler, - compressible, - returns_problem_detail, -) +from flask_babel import lazy_gettext as _ + +from core.app_server import ErrorHandler, compressible, returns_problem_detail from core.model import ConfigurationSetting from core.util.problem_detail import ProblemDetail + +from .config import Configuration from .controller import CirculationManager from .problem_details import REMOTE_INTEGRATION_FAILED -from flask_babel import lazy_gettext as _ + @app.before_first_request def initialize_circulation_manager(): - if os.environ.get('AUTOINITIALIZE') == "False": + if os.environ.get("AUTOINITIALIZE") == "False": # It's the responsibility of the importing code to set app.manager # appropriately. pass else: - if getattr(app, 'manager', None) is None: + if getattr(app, "manager", None) is None: try: app.manager = CirculationManager(app._db) except Exception: - logging.exception( - "Error instantiating circulation manager!" - ) + logging.exception("Error instantiating circulation manager!") raise # Make sure that any changes to the database (as might happen # on initial setup) are committed before continuing. app.manager._db.commit() + @babel.localeselector def get_locale(): languages = Configuration.localization_languages() return request.accept_languages.best_match(languages) + @app.teardown_request def shutdown_session(exception): - if (hasattr(app, 'manager') - and hasattr(app.manager, '_db') - and app.manager._db - ): + if hasattr(app, "manager") and hasattr(app.manager, "_db") and app.manager._db: if exception: app.manager._db.rollback() else: app.manager._db.commit() + def requires_auth(f): @wraps(f) def decorated(*args, **kwargs): @@ -75,8 +67,10 @@ def decorated(*args, **kwargs): return patron else: return f(*args, **kwargs) + return decorated + def allows_auth(f): """Decorator function for a controller method that supports both authenticated and unauthenticated requests. @@ -84,6 +78,7 @@ def allows_auth(f): NOTE: This decorator might not be necessary; you can probably call BaseCirculationManagerController.request_patron instead. """ + @wraps(f) def decorated(*args, **kwargs): # Try to authenticate a patron. This will set flask.request.patron @@ -93,8 +88,10 @@ def decorated(*args, **kwargs): # Call the decorated function regardless of whether # authentication succeeds. return f(*args, **kwargs) + return decorated + # The allows_patron_web decorator will add Cross-Origin Resource Sharing # (CORS) headers to routes that will be used by the patron web interface. # This is necessary for a JS app on a different domain to make requests. @@ -105,7 +102,7 @@ def decorated(*args, **kwargs): def allows_patron_web(f): # Override Flask's default behavior and intercept the OPTIONS method for # every request so CORS headers can be added. - f.required_methods = getattr(f, 'required_methods', set()) + f.required_methods = getattr(f, "required_methods", set()) f.required_methods.add("OPTIONS") f.provide_automatic_options = False @@ -118,15 +115,18 @@ def wrapped_function(*args, **kwargs): patron_web_domains = app.manager.patron_web_domains if patron_web_domains: options = get_cors_options( - app, dict(origins=patron_web_domains, - supports_credentials=True) + app, dict(origins=patron_web_domains, supports_credentials=True) ) set_cors_headers(resp, options) return resp + return update_wrapper(wrapped_function, f) -h = ErrorHandler(app, app.config['DEBUG']) + +h = ErrorHandler(app, app.config["DEBUG"]) + + @app.errorhandler(Exception) @allows_patron_web def exception_handler(exception): @@ -137,11 +137,13 @@ def exception_handler(exception): return exception return h.handle(exception) + def has_library(f): """Decorator to extract the library short name from the arguments.""" + @wraps(f) def decorated(*args, **kwargs): - if 'library_short_name' in kwargs: + if "library_short_name" in kwargs: library_short_name = kwargs.pop("library_short_name") else: library_short_name = None @@ -150,9 +152,13 @@ def decorated(*args, **kwargs): return library.response else: return f(*args, **kwargs) + return decorated -def has_library_through_external_loan_identifier(parameter_name='external_loan_identifier'): + +def has_library_through_external_loan_identifier( + parameter_name="external_loan_identifier", +): """Decorator to get a library using the loan's external identifier. :param parameter_name: Name of the parameter holding the loan's external identifier @@ -161,6 +167,7 @@ def has_library_through_external_loan_identifier(parameter_name='external_loan_i :return: Decorated function :rtype: Callable """ + def decorator(func): @wraps(func) def wrapper(*args, **kwargs): @@ -169,7 +176,11 @@ def wrapper(*args, **kwargs): else: external_loan_identifier = None - library = app.manager.index_controller.library_through_external_loan_identifier(external_loan_identifier) + library = ( + app.manager.index_controller.library_through_external_loan_identifier( + external_loan_identifier + ) + ) if isinstance(library, ProblemDetail): return library.response @@ -180,28 +191,35 @@ def wrapper(*args, **kwargs): return decorator + def allows_library(f): """Decorator similar to @has_library but if there is no library short name, then don't set the request library. """ + @wraps(f) def decorated(*args, **kwargs): - if 'library_short_name' in kwargs: + if "library_short_name" in kwargs: library_short_name = kwargs.pop("library_short_name") - library = app.manager.index_controller.library_for_request(library_short_name) + library = app.manager.index_controller.library_for_request( + library_short_name + ) if isinstance(library, ProblemDetail): return library.response else: library = None return f(*args, **kwargs) + return decorated + def library_route(path, *args, **kwargs): """Decorator to creates routes that have a library short name in either a subdomain or a url path prefix. If not used with @has_library, the view function must have a library_short_name argument. """ + def decorator(f): # This sets up routes for both the subdomain and the url path prefix. # The order of these determines which one will be used by url_for - @@ -209,11 +227,15 @@ def decorator(f): # We may want to have a configuration option to specify whether to # use a subdomain or a url path prefix. prefix_route = app.route("/" + path, *args, **kwargs)(f) - subdomain_route = app.route(path, subdomain="", *args, **kwargs)(prefix_route) + subdomain_route = app.route( + path, subdomain="", *args, **kwargs + )(prefix_route) default_library_route = app.route(path, *args, **kwargs)(subdomain_route) return default_library_route + return decorator + def library_dir_route(path, *args, **kwargs): """Decorator to create library routes that work with or without a trailing slash.""" @@ -230,15 +252,36 @@ def decorator(f): # by the CORS decorator and won't be valid CORS responses. # Decorate f with four routes, with and without the slash, with a prefix or subdomain - prefix_slash = app.route("/" + path_without_slash + "/", strict_slashes=False, *args, **kwargs)(f) - prefix_no_slash = app.route("/" + path_without_slash, *args, **kwargs)(prefix_slash) - subdomain_slash = app.route(path_without_slash + "/", strict_slashes=False, subdomain="", *args, **kwargs)(prefix_no_slash) - subdomain_no_slash = app.route(path_without_slash, subdomain="", *args, **kwargs)(subdomain_slash) - default_library_slash = app.route(path_without_slash, *args, **kwargs)(subdomain_no_slash) - default_library_no_slash = app.route(path_without_slash + "/", *args, **kwargs)(default_library_slash) + prefix_slash = app.route( + "/" + path_without_slash + "/", + strict_slashes=False, + *args, + **kwargs + )(f) + prefix_no_slash = app.route( + "/" + path_without_slash, *args, **kwargs + )(prefix_slash) + subdomain_slash = app.route( + path_without_slash + "/", + strict_slashes=False, + subdomain="", + *args, + **kwargs + )(prefix_no_slash) + subdomain_no_slash = app.route( + path_without_slash, subdomain="", *args, **kwargs + )(subdomain_slash) + default_library_slash = app.route(path_without_slash, *args, **kwargs)( + subdomain_no_slash + ) + default_library_no_slash = app.route(path_without_slash + "/", *args, **kwargs)( + default_library_slash + ) return default_library_no_slash + return decorator + @library_route("/", strict_slashes=False) @has_library @allows_patron_web @@ -247,21 +290,24 @@ def decorator(f): def index(): return app.manager.index_controller() -@library_route('/authentication_document') + +@library_route("/authentication_document") @has_library @returns_problem_detail @compressible def authentication_document(): return app.manager.index_controller.authentication_document() -@library_route('/public_key_document') + +@library_route("/public_key_document") @returns_problem_detail @compressible def public_key_document(): return app.manager.index_controller.public_key_document() -@library_dir_route('/groups', defaults=dict(lane_identifier=None)) -@library_route('/groups/') + +@library_dir_route("/groups", defaults=dict(lane_identifier=None)) +@library_route("/groups/") @has_library @allows_patron_web @returns_problem_detail @@ -269,7 +315,8 @@ def public_key_document(): def acquisition_groups(lane_identifier): return app.manager.opds_feeds.groups(lane_identifier) -@library_route('/feed/qa/series') + +@library_route("/feed/qa/series") @has_library @allows_patron_web @requires_auth @@ -278,7 +325,8 @@ def acquisition_groups(lane_identifier): def qa_series_feed(): return app.manager.opds_feeds.qa_series_feed() -@library_route('/feed/qa') + +@library_route("/feed/qa") @has_library @allows_patron_web @requires_auth @@ -287,8 +335,9 @@ def qa_series_feed(): def qa_feed(): return app.manager.opds_feeds.qa_feed() -@library_dir_route('/feed', defaults=dict(lane_identifier=None)) -@library_route('/feed/') + +@library_dir_route("/feed", defaults=dict(lane_identifier=None)) +@library_route("/feed/") @has_library @allows_patron_web @returns_problem_detail @@ -296,8 +345,9 @@ def qa_feed(): def feed(lane_identifier): return app.manager.opds_feeds.feed(lane_identifier) -@library_dir_route('/navigation', defaults=dict(lane_identifier=None)) -@library_route('/navigation/') + +@library_dir_route("/navigation", defaults=dict(lane_identifier=None)) +@library_route("/navigation/") @has_library @allows_patron_web @returns_problem_detail @@ -305,7 +355,8 @@ def feed(lane_identifier): def navigation_feed(lane_identifier): return app.manager.opds_feeds.navigation(lane_identifier) -@library_route('/crawlable') + +@library_route("/crawlable") @has_library @allows_patron_web @returns_problem_detail @@ -313,7 +364,8 @@ def navigation_feed(lane_identifier): def crawlable_library_feed(): return app.manager.opds_feeds.crawlable_library_feed() -@library_route('/lists//crawlable') + +@library_route("/lists//crawlable") @has_library @allows_patron_web @returns_problem_detail @@ -321,66 +373,94 @@ def crawlable_library_feed(): def crawlable_list_feed(list_name): return app.manager.opds_feeds.crawlable_list_feed(list_name) -@app.route('/collections//crawlable') + +@app.route("/collections//crawlable") @allows_patron_web @returns_problem_detail @compressible def crawlable_collection_feed(collection_name): return app.manager.opds_feeds.crawlable_collection_feed(collection_name) + @app.route("/collections/") @returns_problem_detail def shared_collection_info(collection_name): return app.manager.shared_collection_controller.info(collection_name) + @app.route("/collections//register", methods=["POST"]) @returns_problem_detail def shared_collection_register(collection_name): return app.manager.shared_collection_controller.register(collection_name) -@app.route("/collections////borrow", - methods=['GET', 'POST'], defaults=dict(hold_id=None)) -@app.route("/collections//holds//borrow", - methods=['GET', 'POST'], defaults=dict(identifier_type=None, identifier=None)) + +@app.route( + "/collections////borrow", + methods=["GET", "POST"], + defaults=dict(hold_id=None), +) +@app.route( + "/collections//holds//borrow", + methods=["GET", "POST"], + defaults=dict(identifier_type=None, identifier=None), +) @returns_problem_detail def shared_collection_borrow(collection_name, identifier_type, identifier, hold_id): - return app.manager.shared_collection_controller.borrow(collection_name, identifier_type, identifier, hold_id) + return app.manager.shared_collection_controller.borrow( + collection_name, identifier_type, identifier, hold_id + ) + @app.route("/collections//loans/") @returns_problem_detail def shared_collection_loan_info(collection_name, loan_id): return app.manager.shared_collection_controller.loan_info(collection_name, loan_id) + @app.route("/collections//loans//revoke") @returns_problem_detail def shared_collection_revoke_loan(collection_name, loan_id): - return app.manager.shared_collection_controller.revoke_loan(collection_name, loan_id) + return app.manager.shared_collection_controller.revoke_loan( + collection_name, loan_id + ) -@app.route("/collections//loans//fulfill", defaults=dict(mechanism_id=None)) + +@app.route( + "/collections//loans//fulfill", + defaults=dict(mechanism_id=None), +) @app.route("/collections//loans//fulfill/") @returns_problem_detail def shared_collection_fulfill(collection_name, loan_id, mechanism_id): - return app.manager.shared_collection_controller.fulfill(collection_name, loan_id, mechanism_id) + return app.manager.shared_collection_controller.fulfill( + collection_name, loan_id, mechanism_id + ) + @app.route("/collections//holds/") @returns_problem_detail def shared_collection_hold_info(collection_name, hold_id): return app.manager.shared_collection_controller.hold_info(collection_name, hold_id) + @app.route("/collections//holds//revoke") @returns_problem_detail def shared_collection_revoke_hold(collection_name, hold_id): - return app.manager.shared_collection_controller.revoke_hold(collection_name, hold_id) + return app.manager.shared_collection_controller.revoke_hold( + collection_name, hold_id + ) + -@library_route('/marc') +@library_route("/marc") @has_library @returns_problem_detail @compressible def marc_page(): return app.manager.marc_records.download_page() -@library_dir_route('/search', defaults=dict(lane_identifier=None)) -@library_route('/search/') + +@library_dir_route("/search", defaults=dict(lane_identifier=None)) +@library_route("/search/") @has_library @allows_patron_web @returns_problem_detail @@ -388,7 +468,8 @@ def marc_page(): def lane_search(lane_identifier): return app.manager.opds_feeds.search(lane_identifier) -@library_dir_route('/patrons/me', methods=['GET', 'PUT']) + +@library_dir_route("/patrons/me", methods=["GET", "PUT"]) @has_library @allows_patron_web @requires_auth @@ -396,7 +477,8 @@ def lane_search(lane_identifier): def patron_profile(): return app.manager.profiles.protocol() -@library_dir_route('/loans', methods=['GET', 'HEAD']) + +@library_dir_route("/loans", methods=["GET", "HEAD"]) @has_library @allows_patron_web @requires_auth @@ -405,7 +487,8 @@ def patron_profile(): def active_loans(): return app.manager.loans.sync() -@library_route('/annotations/', methods=['HEAD', 'GET', 'POST']) + +@library_route("/annotations/", methods=["HEAD", "GET", "POST"]) @has_library @allows_patron_web @requires_auth @@ -414,7 +497,8 @@ def active_loans(): def annotations(): return app.manager.annotations.container() -@library_route('/annotations/', methods=['HEAD', 'GET', 'DELETE']) + +@library_route("/annotations/", methods=["HEAD", "GET", "DELETE"]) @has_library @allows_patron_web @requires_auth @@ -423,7 +507,8 @@ def annotations(): def annotation_detail(annotation_id): return app.manager.annotations.detail(annotation_id) -@library_route('/annotations//', methods=['GET']) + +@library_route("/annotations//", methods=["GET"]) @has_library @allows_patron_web @requires_auth @@ -432,9 +517,14 @@ def annotation_detail(annotation_id): def annotations_for_work(identifier_type, identifier): return app.manager.annotations.container_for_work(identifier_type, identifier) -@library_route('/works///borrow', methods=['GET', 'PUT']) -@library_route('/works///borrow/', - methods=['GET', 'PUT']) + +@library_route( + "/works///borrow", methods=["GET", "PUT"] +) +@library_route( + "/works///borrow/", + methods=["GET", "PUT"], +) @has_library @allows_patron_web @requires_auth @@ -442,16 +532,18 @@ def annotations_for_work(identifier_type, identifier): def borrow(identifier_type, identifier, mechanism_id=None): return app.manager.loans.borrow(identifier_type, identifier, mechanism_id) -@library_route('/works//fulfill') -@library_route('/works//fulfill/') -@library_route('/works//fulfill//') + +@library_route("/works//fulfill") +@library_route("/works//fulfill/") +@library_route("/works//fulfill//") @has_library @allows_patron_web @returns_problem_detail def fulfill(license_pool_id, mechanism_id=None, part=None): return app.manager.loans.fulfill(license_pool_id, mechanism_id, part) -@library_route('/loans//revoke', methods=['GET', 'PUT']) + +@library_route("/loans//revoke", methods=["GET", "PUT"]) @has_library @allows_patron_web @requires_auth @@ -459,7 +551,8 @@ def fulfill(license_pool_id, mechanism_id=None, part=None): def revoke_loan_or_hold(license_pool_id): return app.manager.loans.revoke(license_pool_id) -@library_route('/loans//', methods=['GET', 'DELETE']) + +@library_route("/loans//", methods=["GET", "DELETE"]) @has_library @allows_patron_web @requires_auth @@ -467,27 +560,41 @@ def revoke_loan_or_hold(license_pool_id): def loan_or_hold_detail(identifier_type, identifier): return app.manager.loans.detail(identifier_type, identifier) -@library_dir_route('/works') + +@library_dir_route("/works") @has_library @allows_patron_web @returns_problem_detail @compressible def work(): - return app.manager.urn_lookup.work_lookup('work') + return app.manager.urn_lookup.work_lookup("work") + -@library_dir_route('/works/contributor/', defaults=dict(languages=None, audiences=None)) -@library_dir_route('/works/contributor//', defaults=dict(audiences=None)) -@library_route('/works/contributor///') +@library_dir_route( + "/works/contributor/", + defaults=dict(languages=None, audiences=None), +) +@library_dir_route( + "/works/contributor//", defaults=dict(audiences=None) +) +@library_route("/works/contributor///") @has_library @allows_patron_web @returns_problem_detail @compressible def contributor(contributor_name, languages, audiences): - return app.manager.work_controller.contributor(contributor_name, languages, audiences) + return app.manager.work_controller.contributor( + contributor_name, languages, audiences + ) -@library_dir_route('/works/series/', defaults=dict(languages=None, audiences=None)) -@library_dir_route('/works/series//', defaults=dict(audiences=None)) -@library_route('/works/series///') + +@library_dir_route( + "/works/series/", defaults=dict(languages=None, audiences=None) +) +@library_dir_route( + "/works/series//", defaults=dict(audiences=None) +) +@library_route("/works/series///") @has_library @allows_patron_web @returns_problem_detail @@ -495,7 +602,8 @@ def contributor(contributor_name, languages, audiences): def series(series_name, languages, audiences): return app.manager.work_controller.series(series_name, languages, audiences) -@library_route('/works//') + +@library_route("/works//") @has_library @allows_auth @allows_patron_web @@ -504,7 +612,8 @@ def series(series_name, languages, audiences): def permalink(identifier_type, identifier): return app.manager.work_controller.permalink(identifier_type, identifier) -@library_route('/works///recommendations') + +@library_route("/works///recommendations") @has_library @allows_patron_web @returns_problem_detail @@ -512,7 +621,8 @@ def permalink(identifier_type, identifier): def recommendations(identifier_type, identifier): return app.manager.work_controller.recommendations(identifier_type, identifier) -@library_route('/works///related_books') + +@library_route("/works///related_books") @has_library @allows_patron_web @returns_problem_detail @@ -520,23 +630,30 @@ def recommendations(identifier_type, identifier): def related_books(identifier_type, identifier): return app.manager.work_controller.related(identifier_type, identifier) -@library_route('/works///report', methods=['GET', 'POST']) + +@library_route( + "/works///report", methods=["GET", "POST"] +) @has_library @allows_patron_web @returns_problem_detail def report(identifier_type, identifier): return app.manager.work_controller.report(identifier_type, identifier) -@library_route('/analytics///') + +@library_route("/analytics///") @has_library @allows_auth @allows_patron_web @returns_problem_detail def track_analytics_event(identifier_type, identifier, event_type): - return app.manager.analytics_controller.track_event(identifier_type, identifier, event_type) + return app.manager.analytics_controller.track_event( + identifier_type, identifier, event_type + ) + # Adobe Vendor ID implementation -@library_route('/AdobeAuth/authdata') +@library_route("/AdobeAuth/authdata") @has_library @requires_auth @returns_problem_detail @@ -547,57 +664,72 @@ def adobe_vendor_id_get_token(): ) return app.manager.adobe_vendor_id.create_authdata_handler(flask.request.patron) -@library_route('/AdobeAuth/SignIn', methods=['POST']) + +@library_route("/AdobeAuth/SignIn", methods=["POST"]) @has_library @returns_problem_detail def adobe_vendor_id_signin(): return app.manager.adobe_vendor_id.signin_handler() -@app.route('/AdobeAuth/AccountInfo', methods=['POST']) + +@app.route("/AdobeAuth/AccountInfo", methods=["POST"]) @returns_problem_detail def adobe_vendor_id_accountinfo(): return app.manager.adobe_vendor_id.userinfo_handler() -@app.route('/AdobeAuth/Status') + +@app.route("/AdobeAuth/Status") @returns_problem_detail def adobe_vendor_id_status(): return app.manager.adobe_vendor_id.status_handler() + # DRM Device Management Protocol implementation for ACS. -@library_route('/AdobeAuth/devices', methods=['GET', 'POST']) +@library_route("/AdobeAuth/devices", methods=["GET", "POST"]) @has_library @requires_auth @returns_problem_detail def adobe_drm_devices(): return app.manager.adobe_device_management.device_id_list_handler() -@library_route('/AdobeAuth/devices/', methods=['DELETE']) + +@library_route("/AdobeAuth/devices/", methods=["DELETE"]) @has_library @requires_auth @returns_problem_detail def adobe_drm_device(device_id): return app.manager.adobe_device_management.device_id_handler(device_id) + # Route that redirects to the authentication URL for an OAuth provider -@library_route('/oauth_authenticate') +@library_route("/oauth_authenticate") @has_library @returns_problem_detail def oauth_authenticate(): - return app.manager.oauth_controller.oauth_authentication_redirect(flask.request.args, app.manager._db) + return app.manager.oauth_controller.oauth_authentication_redirect( + flask.request.args, app.manager._db + ) + # Redirect URI for OAuth providers, eg. Clever -@library_route('/oauth_callback') +@library_route("/oauth_callback") @has_library @returns_problem_detail def oauth_callback(): - return app.manager.oauth_controller.oauth_authentication_callback(app.manager._db, flask.request.args) + return app.manager.oauth_controller.oauth_authentication_callback( + app.manager._db, flask.request.args + ) + # Route that redirects to the authentication URL for a SAML provider -@library_route('/saml_authenticate') +@library_route("/saml_authenticate") @has_library @returns_problem_detail def saml_authenticate(): - return app.manager.saml_controller.saml_authentication_redirect(flask.request.args, app.manager._db) + return app.manager.saml_controller.saml_authentication_redirect( + flask.request.args, app.manager._db + ) + # Redirect URI for SAML providers # NOTE: we cannot use @has_library decorator and append a library's name to saml_calback route @@ -609,44 +741,50 @@ def saml_authenticate(): # the URL saved in the SP's metadata configured in this IdP will differ. # Library's name is passed as a part of the relay state and processed in SAMLController.saml_authentication_callback @returns_problem_detail -@app.route("/saml_callback", methods=['POST']) +@app.route("/saml_callback", methods=["POST"]) def saml_callback(): - return app.manager.saml_controller.saml_authentication_callback(request, app.manager._db) + return app.manager.saml_controller.saml_authentication_callback( + request, app.manager._db + ) -@app.route('//lcp/licenses//hint') -@app.route('//lcp/licenses//hint') -@has_library_through_external_loan_identifier(parameter_name='license_id') +@app.route("//lcp/licenses//hint") +@app.route("//lcp/licenses//hint") +@has_library_through_external_loan_identifier(parameter_name="license_id") @requires_auth @returns_problem_detail def lcp_passphrase(collection_name, license_id): return app.manager.lcp_controller.get_lcp_passphrase() -@app.route('//lcp/licenses/') -@has_library_through_external_loan_identifier(parameter_name='license_id') +@app.route("//lcp/licenses/") +@has_library_through_external_loan_identifier(parameter_name="license_id") @requires_auth @returns_problem_detail def lcp_license(collection_name, license_id): return app.manager.lcp_controller.get_lcp_license(collection_name, license_id) + # Loan notifications for ODL distributors, eg. Feedbooks -@library_route('/odl_notify/', methods=['GET', 'POST']) +@library_route("/odl_notify/", methods=["GET", "POST"]) @has_library @returns_problem_detail def odl_notify(loan_id): return app.manager.odl_notification_controller.notify(loan_id) + # Controllers used for operations purposes -@app.route('/heartbeat') +@app.route("/heartbeat") @returns_problem_detail def heartbeat(): return app.manager.heartbeat.heartbeat() -@app.route('/healthcheck.html') + +@app.route("/healthcheck.html") def health_check(): return Response("", 200) + @app.route("/images/") def static_image(filename): return app.manager.static_files.image(filename) diff --git a/api/saml/configuration/model.py b/api/saml/configuration/model.py index 737bf61b45..af1e15d711 100644 --- a/api/saml/configuration/model.py +++ b/api/saml/configuration/model.py @@ -25,6 +25,7 @@ cgi.escape = html.escape + class SAMLConfigurationError(BaseError): """Raised in the case of any configuration errors.""" @@ -79,8 +80,8 @@ class SAMLConfiguration(ConfigurationGrouping): default="true", options=[ ConfigurationOption("true", "Use SAML NameID"), - ConfigurationOption("false", "Do NOT use SAML NameID") - ] + ConfigurationOption("false", "Do NOT use SAML NameID"), + ], ) patron_id_attributes = ConfigurationMetadata( diff --git a/api/saml/metadata/federations/model.py b/api/saml/metadata/federations/model.py index 82f63f614a..83d4adc3d3 100644 --- a/api/saml/metadata/federations/model.py +++ b/api/saml/metadata/federations/model.py @@ -32,7 +32,9 @@ def __init__(self, federation_type, idp_metadata_service_url, certificate=None): """ if not federation_type or not isinstance(federation_type, str): raise ValueError("Argument 'federation_type' must be a non-empty string") - if not idp_metadata_service_url or not isinstance(idp_metadata_service_url, str): + if not idp_metadata_service_url or not isinstance( + idp_metadata_service_url, str + ): raise ValueError( "Argument 'idp_metadata_service_url' must be a non-empty string" ) diff --git a/api/saml/metadata/filter.py b/api/saml/metadata/filter.py index 5e42eb826d..90b72a682d 100644 --- a/api/saml/metadata/filter.py +++ b/api/saml/metadata/filter.py @@ -34,7 +34,11 @@ def __init__(self, dsl_evaluator): :type dsl_evaluator: core.python_expression_dsl.evaluator.DSLEvaluator """ if not isinstance(dsl_evaluator, DSLEvaluator): - raise ValueError("Argument 'dsl_evaluator' must be an instance of {0} class".format(DSLEvaluator)) + raise ValueError( + "Argument 'dsl_evaluator' must be an instance of {0} class".format( + DSLEvaluator + ) + ) self._dsl_evaluator = dsl_evaluator self._logger = logging.getLogger(__name__) diff --git a/api/saml/metadata/model.py b/api/saml/metadata/model.py index 90843a3f34..236fce2244 100644 --- a/api/saml/metadata/model.py +++ b/api/saml/metadata/model.py @@ -1075,9 +1075,7 @@ def __init__(self, name_id, attribute_statement, valid_till=None): elif isinstance(valid_till, datetime.datetime): self._valid_till = valid_till - utc_now() elif isinstance(valid_till, int): - self._valid_till = ( - from_timestamp(valid_till) - utc_now() - ) + self._valid_till = from_timestamp(valid_till) - utc_now() elif isinstance(valid_till, datetime.timedelta): self._valid_till = valid_till else: diff --git a/api/saml/provider.py b/api/saml/provider.py index c4aee3bf47..ec6fe261a6 100644 --- a/api/saml/provider.py +++ b/api/saml/provider.py @@ -471,7 +471,9 @@ def saml_callback(self, db, subject): return patron_data # Convert the PatronData into a Patron object - patron, is_new = patron_data.get_or_create_patron(db, self.library_id, self.analytics) + patron, is_new = patron_data.get_or_create_patron( + db, self.library_id, self.analytics + ) # Create a credential for the Patron with self.get_configuration(db) as configuration: diff --git a/api/selftest.py b/api/selftest.py index 501cb6461a..e84fd1ccf4 100644 --- a/api/selftest.py +++ b/api/selftest.py @@ -1,26 +1,17 @@ import sys + from sqlalchemy.orm.session import Session -from .authenticator import LibraryAuthenticator -from .circulation import CirculationAPI -from .feedbooks import ( - FeedbooksOPDSImporter, - FeedbooksImportMonitor, -) from core.config import IntegrationException -from core.model import ( - ExternalIntegration, - LicensePool, -) -from core.opds_import import ( - OPDSImporter, - OPDSImportMonitor, -) +from core.model import ExternalIntegration, LicensePool +from core.opds_import import OPDSImporter, OPDSImportMonitor from core.scripts import LibraryInputScript -from core.selftest import ( - HasSelfTests as CoreHasSelfTests, - SelfTestResult, -) +from core.selftest import HasSelfTests as CoreHasSelfTests +from core.selftest import SelfTestResult + +from .authenticator import LibraryAuthenticator +from .circulation import CirculationAPI +from .feedbooks import FeedbooksImportMonitor, FeedbooksOPDSImporter class HasSelfTests(CoreHasSelfTests): @@ -45,16 +36,14 @@ def default_patrons(self, collection): yield self.test_failure( "Acquiring test patron credentials.", "Collection is not associated with any libraries.", - "Add the collection to a library that has a patron authentication service." + "Add the collection to a library that has a patron authentication service.", ) for library in collection.libraries: name = library.name task = "Acquiring test patron credentials for library %s" % library.name try: - library_authenticator = LibraryAuthenticator.from_config( - _db, library - ) + library_authenticator = LibraryAuthenticator.from_config(_db, library) patron = password = None auth = library_authenticator.basic_auth_provider if auth: @@ -64,7 +53,7 @@ def default_patrons(self, collection): yield self.test_failure( task, "Library has no test patron configured.", - "You can specify a test patron when you configure the library's patron authentication service." + "You can specify a test patron when you configure the library's patron authentication service.", ) continue @@ -128,11 +117,7 @@ def process_result(self, result): success = "SUCCESS" else: success = "FAILURE" - self.out.write( - " %s %s (%.1fsec)\n" % ( - success, result.name, result.duration - ) - ) + self.out.write(" %s %s (%.1fsec)\n" % (success, result.name, result.duration)) if isinstance(result.result, (bytes, str)): self.out.write(" Result: %s\n" % result.result) if result.exception: @@ -161,9 +146,7 @@ def _no_delivery_mechanisms_test(self): else: title = "[title unknown]" identifier = lp.identifier.identifier - titles.append( - "%s (ID: %s)" % (title, identifier) - ) + titles.append("%s (ID: %s)" % (title, identifier)) if titles: return titles @@ -173,5 +156,5 @@ def _no_delivery_mechanisms_test(self): def _run_self_tests(self): yield self.run_test( "Checking for titles that have no delivery mechanisms.", - self._no_delivery_mechanisms_test + self._no_delivery_mechanisms_test, ) diff --git a/api/shared_collection.py b/api/shared_collection.py index 5d67706dd3..12b60254d1 100644 --- a/api/shared_collection.py +++ b/api/shared_collection.py @@ -1,19 +1,17 @@ +import base64 +import json import logging + import flask -import json from flask_babel import lazy_gettext as _ -from core.model import ( - Collection, - ConfigurationSetting, - IntegrationClient, - get_one, -) -from .circulation_exceptions import * -from .config import Configuration from core.config import CannotLoadConfiguration +from core.model import Collection, ConfigurationSetting, IntegrationClient, get_one from core.util.http import HTTP -import base64 + +from .circulation_exceptions import * +from .config import Configuration + class SharedCollectionAPI(object): """Logic for circulating books to patrons of libraries on other @@ -53,7 +51,8 @@ def __init__(self, _db, api_map=None): except CannotLoadConfiguration as e: self.log.error( "Error loading configuration for %s: %s", - collection.name, str(e) + collection.name, + str(e), ) self.initialization_exceptions[collection.id] = e if api: @@ -65,6 +64,7 @@ def default_api_map(self): API class Y to handle that collection. """ from .odl import ODLAPI + return { ODLAPI.NAME: ODLAPI, } @@ -79,8 +79,11 @@ def api(self, collection): api = self.api_for_collection.get(collection.id) if not api: raise CirculationException( - _("Collection %(collection)s is not a shared collection.", - collection=collection.name)) + _( + "Collection %(collection)s is not a shared collection.", + collection=collection.name, + ) + ) return api def register(self, collection, auth_document_url, do_get=HTTP.get_with_timeout): @@ -89,16 +92,20 @@ def register(self, collection, auth_document_url, do_get=HTTP.get_with_timeout): collection's settings.""" if not auth_document_url: raise InvalidInputException( - _("An authentication document URL is required to register a library.")) + _("An authentication document URL is required to register a library.") + ) auth_response = do_get(auth_document_url, allowed_response_codes=["2xx", "3xx"]) try: auth_document = json.loads(auth_response.content) except ValueError as e: raise RemoteInitiatedServerError( - _("Authentication document at %(auth_document_url)s was not valid JSON.", - auth_document_url=auth_document_url), - _("Remote authentication document")) + _( + "Authentication document at %(auth_document_url)s was not valid JSON.", + auth_document_url=auth_document_url, + ), + _("Remote authentication document"), + ) links = auth_document.get("links") start_url = None @@ -109,25 +116,39 @@ def register(self, collection, auth_document_url, do_get=HTTP.get_with_timeout): if not start_url: raise RemoteInitiatedServerError( - _("Authentication document at %(auth_document_url)s did not contain a start link.", - auth_document_url=auth_document_url), - _("Remote authentication document")) + _( + "Authentication document at %(auth_document_url)s did not contain a start link.", + auth_document_url=auth_document_url, + ), + _("Remote authentication document"), + ) external_library_urls = ConfigurationSetting.for_externalintegration( - BaseSharedCollectionAPI.EXTERNAL_LIBRARY_URLS, collection.external_integration + BaseSharedCollectionAPI.EXTERNAL_LIBRARY_URLS, + collection.external_integration, ).json_value if not external_library_urls or start_url not in external_library_urls: raise AuthorizationFailedException( - _("Your library's URL is not one of the allowed URLs for this collection. Ask the collection administrator to add %(library_url)s to the list of allowed URLs.", - library_url=start_url)) + _( + "Your library's URL is not one of the allowed URLs for this collection. Ask the collection administrator to add %(library_url)s to the list of allowed URLs.", + library_url=start_url, + ) + ) public_key = auth_document.get("public_key") - if not public_key or not public_key.get("type") == "RSA" or not public_key.get("value"): + if ( + not public_key + or not public_key.get("type") == "RSA" + or not public_key.get("value") + ): raise RemoteInitiatedServerError( - _("Authentication document at %(auth_document_url)s did not contain an RSA public key.", - auth_document_url=auth_document_url), - _("Remote authentication document")) + _( + "Authentication document at %(auth_document_url)s did not contain an RSA public key.", + auth_document_url=auth_document_url, + ), + _("Remote authentication document"), + ) public_key = public_key.get("value") encryptor = Configuration.cipher(public_key) @@ -144,9 +165,12 @@ def register(self, collection, auth_document_url, do_get=HTTP.get_with_timeout): def check_client_authorization(self, collection, client): """Verify that an IntegrationClient is whitelisted for access to the collection.""" external_library_urls = ConfigurationSetting.for_externalintegration( - BaseSharedCollectionAPI.EXTERNAL_LIBRARY_URLS, collection.external_integration + BaseSharedCollectionAPI.EXTERNAL_LIBRARY_URLS, + collection.external_integration, ).json_value - if client.url not in [IntegrationClient.normalize_url(url) for url in external_library_urls]: + if client.url not in [ + IntegrationClient.normalize_url(url) for url in external_library_urls + ]: raise AuthorizationFailedException() def borrow(self, collection, client, pool, hold=None): @@ -197,18 +221,26 @@ class BaseSharedCollectionAPI(object): SETTINGS = [ { "key": EXTERNAL_LIBRARY_URLS, - "label": _("URLs for libraries on other circulation managers that use this collection"), - "description": _("A URL should include the library's short name (e.g. https://circulation.librarysimplified.org/NYNYPL/), even if it is the only library on the circulation manager."), + "label": _( + "URLs for libraries on other circulation managers that use this collection" + ), + "description": _( + "A URL should include the library's short name (e.g. https://circulation.librarysimplified.org/NYNYPL/), even if it is the only library on the circulation manager." + ), "type": "list", "format": "url", }, { "key": Collection.EBOOK_LOAN_DURATION_KEY, - "label": _("Ebook Loan Duration for libraries on other circulation managers (in Days)"), + "label": _( + "Ebook Loan Duration for libraries on other circulation managers (in Days)" + ), "default": Collection.STANDARD_DEFAULT_LOAN_PERIOD, - "description": _("When a patron from another library borrows an ebook from this collection, the circulation manager will ask for a loan that lasts this number of days. This must be equal to or less than the maximum loan duration negotiated with the distributor."), + "description": _( + "When a patron from another library borrows an ebook from this collection, the circulation manager will ask for a loan that lasts this number of days. This must be equal to or less than the maximum loan duration negotiated with the distributor." + ), "type": "number", - } + }, ] def checkout_to_external_library(self, client, pool, hold=None): diff --git a/api/simple_authentication.py b/api/simple_authentication.py index dd7237320b..1591a41baa 100644 --- a/api/simple_authentication.py +++ b/api/simple_authentication.py @@ -1,15 +1,10 @@ from flask_babel import lazy_gettext as _ -from .authenticator import ( - BasicAuthenticationProvider, - PatronData, -) +from core.model import Patron -from .config import ( - CannotLoadConfiguration, -) +from .authenticator import BasicAuthenticationProvider, PatronData +from .config import CannotLoadConfiguration -from core.model import Patron class SimpleAuthenticationProvider(BasicAuthenticationProvider): """An authentication provider that authenticates a single patron. @@ -17,39 +12,54 @@ class SimpleAuthenticationProvider(BasicAuthenticationProvider): This serves only one purpose: to set up a working circulation manager before connecting it to an ILS. """ + NAME = "Simple Authentication Provider" - DESCRIPTION = _(""" + DESCRIPTION = _( + """ An internal authentication service that authenticates a single patron. This is useful for testing a circulation manager before connecting - it to an ILS.""") + it to an ILS.""" + ) - ADDITIONAL_TEST_IDENTIFIERS = 'additional_test_identifiers' + ADDITIONAL_TEST_IDENTIFIERS = "additional_test_identifiers" - TEST_NEIGHBORHOOD = 'neighborhood' + TEST_NEIGHBORHOOD = "neighborhood" basic_settings = list(BasicAuthenticationProvider.SETTINGS) for i, setting in enumerate(basic_settings): - if setting['key'] == BasicAuthenticationProvider.TEST_IDENTIFIER: + if setting["key"] == BasicAuthenticationProvider.TEST_IDENTIFIER: s = dict(**setting) - s['description'] = BasicAuthenticationProvider.TEST_IDENTIFIER_DESCRIPTION_FOR_REQUIRED_PASSWORD + s[ + "description" + ] = ( + BasicAuthenticationProvider.TEST_IDENTIFIER_DESCRIPTION_FOR_REQUIRED_PASSWORD + ) basic_settings[i] = s - elif setting['key'] == BasicAuthenticationProvider.TEST_PASSWORD: + elif setting["key"] == BasicAuthenticationProvider.TEST_PASSWORD: s = dict(**setting) - s['required'] = True - s['description'] = BasicAuthenticationProvider.TEST_PASSWORD_DESCRIPTION_REQUIRED + s["required"] = True + s[ + "description" + ] = BasicAuthenticationProvider.TEST_PASSWORD_DESCRIPTION_REQUIRED basic_settings[i] = s SETTINGS = basic_settings + [ - { "key": ADDITIONAL_TEST_IDENTIFIERS, - "label": _("Additional test identifiers"), - "type": "list", - "description": _("Identifiers for additional patrons to use in testing. The identifiers will all use the same test password as the first identifier."), + { + "key": ADDITIONAL_TEST_IDENTIFIERS, + "label": _("Additional test identifiers"), + "type": "list", + "description": _( + "Identifiers for additional patrons to use in testing. The identifiers will all use the same test password as the first identifier." + ), + }, + { + "key": TEST_NEIGHBORHOOD, + "label": _("Test neighborhood"), + "description": _( + "For analytics purposes, all patrons will be 'from' this neighborhood." + ), }, - { "key": TEST_NEIGHBORHOOD, - "label": _("Test neighborhood"), - "description": _("For analytics purposes, all patrons will be 'from' this neighborhood."), - } ] def __init__(self, library, integration, analytics=None): @@ -60,17 +70,19 @@ def __init__(self, library, integration, analytics=None): self.test_password = integration.setting(self.TEST_PASSWORD).value test_identifier = integration.setting(self.TEST_IDENTIFIER).value if not (test_identifier and self.test_password): - raise CannotLoadConfiguration( - "Test identifier and password not set." - ) + raise CannotLoadConfiguration("Test identifier and password not set.") self.test_identifiers = [test_identifier, test_identifier + "_username"] - additional_identifiers = integration.setting(self.ADDITIONAL_TEST_IDENTIFIERS).json_value + additional_identifiers = integration.setting( + self.ADDITIONAL_TEST_IDENTIFIERS + ).json_value if additional_identifiers: for identifier in additional_identifiers: self.test_identifiers += [identifier, identifier + "_username"] - self.test_neighborhood = integration.setting(self.TEST_NEIGHBORHOOD).value or None + self.test_neighborhood = ( + integration.setting(self.TEST_NEIGHBORHOOD).value or None + ) def remote_authenticate(self, username, password): "Fake 'remote' authentication." @@ -99,8 +111,8 @@ def generate_patrondata(cls, authorization_identifier, neighborhood=None): permanent_id=identifier + "_id", username=username, personal_name=personal_name, - authorization_expires = None, - fines = None, + authorization_expires=None, + fines=None, neighborhood=neighborhood, ) return patrondata @@ -113,17 +125,21 @@ def valid_patron(self, username, password): the given dictionary? """ if self.collects_password: - password_match = (password==self.test_password) + password_match = password == self.test_password else: - password_match = (password in (None, '')) + password_match = password in (None, "") return password_match and username in self.test_identifiers def _remote_patron_lookup(self, patron_or_patrondata): if not patron_or_patrondata: return None - if ((isinstance(patron_or_patrondata, PatronData) - or isinstance(patron_or_patrondata, Patron)) - and patron_or_patrondata.authorization_identifier in self.test_identifiers): - return self.generate_patrondata(patron_or_patrondata.authorization_identifier) + if ( + isinstance(patron_or_patrondata, PatronData) + or isinstance(patron_or_patrondata, Patron) + ) and patron_or_patrondata.authorization_identifier in self.test_identifiers: + return self.generate_patrondata( + patron_or_patrondata.authorization_identifier + ) + AuthenticationProvider = SimpleAuthenticationProvider diff --git a/api/sip/__init__.py b/api/sip/__init__.py index 7f1d2ed754..86d8f20efc 100644 --- a/api/sip/__init__.py +++ b/api/sip/__init__.py @@ -1,15 +1,15 @@ +import json from datetime import datetime + from flask_babel import lazy_gettext as _ -from api.authenticator import ( - BasicAuthenticationProvider, - PatronData, -) + +from api.authenticator import BasicAuthenticationProvider, PatronData from api.sip.client import SIPClient -from core.util.http import RemoteIntegrationException -from core.util import MoneyUtility -from core.model import ExternalIntegration -import json from api.sip.dialect import Dialect as Sip2Dialect +from core.model import ExternalIntegration +from core.util import MoneyUtility +from core.util.http import RemoteIntegrationException + class SIP2AuthenticationProvider(BasicAuthenticationProvider): @@ -30,79 +30,107 @@ class SIP2AuthenticationProvider(BasicAuthenticationProvider): PATRON_STATUS_BLOCK = "patron status block" SETTINGS = [ - { "key": ExternalIntegration.URL, "label": _("Server"), "required": True }, - { "key": PORT, "label": _("Port"), "required": True , "type": "number" }, - { "key": ExternalIntegration.USERNAME, "label": _("Login User ID") }, - { "key": ExternalIntegration.PASSWORD, "label": _("Login Password") }, - { "key": LOCATION_CODE, "label": _("Location Code") }, + {"key": ExternalIntegration.URL, "label": _("Server"), "required": True}, + {"key": PORT, "label": _("Port"), "required": True, "type": "number"}, + {"key": ExternalIntegration.USERNAME, "label": _("Login User ID")}, + {"key": ExternalIntegration.PASSWORD, "label": _("Login Password")}, + {"key": LOCATION_CODE, "label": _("Location Code")}, { "key": ENCODING, "label": _("Data encoding"), "default": DEFAULT_ENCODING, - "description": _("By default, SIP2 servers encode outgoing data using the Code Page 850 encoding, but some ILSes allow some other encoding to be used, usually UTF-8."), + "description": _( + "By default, SIP2 servers encode outgoing data using the Code Page 850 encoding, but some ILSes allow some other encoding to be used, usually UTF-8." + ), }, - { "key": USE_SSL, "label": _("Connect over SSL?"), - "description": _("Some SIP2 servers require or allow clients to connect securely over SSL. Other servers don't support SSL, and require clients to use an ordinary socket connection."), - "type": "select", - "options": [ - { "key": "true", "label": _("Connect to the SIP2 server over SSL")}, - { "key": "false", "label": _("Connect to the SIP2 server over an ordinary socket connection")}, - ], - "default": "false", - "required": True, + { + "key": USE_SSL, + "label": _("Connect over SSL?"), + "description": _( + "Some SIP2 servers require or allow clients to connect securely over SSL. Other servers don't support SSL, and require clients to use an ordinary socket connection." + ), + "type": "select", + "options": [ + {"key": "true", "label": _("Connect to the SIP2 server over SSL")}, + { + "key": "false", + "label": _( + "Connect to the SIP2 server over an ordinary socket connection" + ), + }, + ], + "default": "false", + "required": True, }, - { "key": ILS, "label": _("ILS"), - "description": _("Some ILS require specific SIP2 settings. If the ILS you are using is in the list please pick it otherwise select 'Generic ILS'."), - "type": "select", - "options": [ - {"key": Sip2Dialect.GENERIC_ILS, "label": _("Generic ILS")}, - {"key": Sip2Dialect.AG_VERSO, "label": _("Auto-Graphics VERSO")}, - ], - "default": Sip2Dialect.GENERIC_ILS, - "required": True, + { + "key": ILS, + "label": _("ILS"), + "description": _( + "Some ILS require specific SIP2 settings. If the ILS you are using is in the list please pick it otherwise select 'Generic ILS'." + ), + "type": "select", + "options": [ + {"key": Sip2Dialect.GENERIC_ILS, "label": _("Generic ILS")}, + {"key": Sip2Dialect.AG_VERSO, "label": _("Auto-Graphics VERSO")}, + ], + "default": Sip2Dialect.GENERIC_ILS, + "required": True, }, - { "key": SSL_CERTIFICATE, "label": _("SSL Certificate"), - "description": _('The SSL certificate used to securely connect to an SSL-enabled SIP2 server. Not all SSL-enabled SIP2 servers require a custom certificate, but some do. This should be a string beginning with -----BEGIN CERTIFICATE----- and ending with -----END CERTIFICATE-----'), - "type": "textarea", + { + "key": SSL_CERTIFICATE, + "label": _("SSL Certificate"), + "description": _( + "The SSL certificate used to securely connect to an SSL-enabled SIP2 server. Not all SSL-enabled SIP2 servers require a custom certificate, but some do. This should be a string beginning with -----BEGIN CERTIFICATE----- and ending with -----END CERTIFICATE-----" + ), + "type": "textarea", }, { - "key": SSL_KEY, "label": _("SSL Key"), - "description" : _('The private key, if any, used to sign the SSL certificate above. If present, this should be a string beginning with -----BEGIN PRIVATE KEY----- and ending with -----END PRIVATE KEY-----'), - "type": "textarea", + "key": SSL_KEY, + "label": _("SSL Key"), + "description": _( + "The private key, if any, used to sign the SSL certificate above. If present, this should be a string beginning with -----BEGIN PRIVATE KEY----- and ending with -----END PRIVATE KEY-----" + ), + "type": "textarea", }, - { "key": FIELD_SEPARATOR, "label": _("Field Separator"), - "default": "|", "required": True, + { + "key": FIELD_SEPARATOR, + "label": _("Field Separator"), + "default": "|", + "required": True, }, - { "key": PATRON_STATUS_BLOCK, - "label": _("SIP2 Patron Status Block"), - "description": _( - "Block patrons from borrowing based on the status of the SIP2 patron status field."), - "type": "select", - "options": [ - {"key": "true", "label": _("Block based on patron status field")}, - {"key": "false", "label": _("No blocks based on patron status field")}, - ], - "default": "true", + { + "key": PATRON_STATUS_BLOCK, + "label": _("SIP2 Patron Status Block"), + "description": _( + "Block patrons from borrowing based on the status of the SIP2 patron status field." + ), + "type": "select", + "options": [ + {"key": "true", "label": _("Block based on patron status field")}, + {"key": "false", "label": _("No blocks based on patron status field")}, + ], + "default": "true", }, ] + BasicAuthenticationProvider.SETTINGS # Map the reasons why SIP2 might report a patron is blocked to the # protocol-independent block reason used by PatronData. SPECIFIC_BLOCK_REASONS = { - SIPClient.CARD_REPORTED_LOST : PatronData.CARD_REPORTED_LOST, - SIPClient.EXCESSIVE_FINES : PatronData.EXCESSIVE_FINES, - SIPClient.EXCESSIVE_FEES : PatronData.EXCESSIVE_FEES, - SIPClient.TOO_MANY_ITEMS_BILLED : PatronData.TOO_MANY_ITEMS_BILLED, - SIPClient.CHARGE_PRIVILEGES_DENIED : PatronData.NO_BORROWING_PRIVILEGES, - SIPClient.TOO_MANY_ITEMS_CHARGED : PatronData.TOO_MANY_LOANS, - SIPClient.TOO_MANY_ITEMS_OVERDUE : PatronData.TOO_MANY_OVERDUE, - SIPClient.TOO_MANY_RENEWALS : PatronData.TOO_MANY_RENEWALS, - SIPClient.TOO_MANY_LOST : PatronData.TOO_MANY_LOST, - SIPClient.RECALL_OVERDUE : PatronData.RECALL_OVERDUE, + SIPClient.CARD_REPORTED_LOST: PatronData.CARD_REPORTED_LOST, + SIPClient.EXCESSIVE_FINES: PatronData.EXCESSIVE_FINES, + SIPClient.EXCESSIVE_FEES: PatronData.EXCESSIVE_FEES, + SIPClient.TOO_MANY_ITEMS_BILLED: PatronData.TOO_MANY_ITEMS_BILLED, + SIPClient.CHARGE_PRIVILEGES_DENIED: PatronData.NO_BORROWING_PRIVILEGES, + SIPClient.TOO_MANY_ITEMS_CHARGED: PatronData.TOO_MANY_LOANS, + SIPClient.TOO_MANY_ITEMS_OVERDUE: PatronData.TOO_MANY_OVERDUE, + SIPClient.TOO_MANY_RENEWALS: PatronData.TOO_MANY_RENEWALS, + SIPClient.TOO_MANY_LOST: PatronData.TOO_MANY_LOST, + SIPClient.RECALL_OVERDUE: PatronData.RECALL_OVERDUE, } - def __init__(self, library, integration, analytics=None, - client=SIPClient, connect=True): + def __init__( + self, library, integration, analytics=None, client=SIPClient, connect=True + ): """An object capable of communicating with a SIP server. :param server: Hostname of the SIP server. @@ -151,7 +179,7 @@ def __init__(self, library, integration, analytics=None, self.encoding = integration.setting(self.ENCODING).value_or_default( self.DEFAULT_ENCODING ) - self.field_separator = integration.setting(self.FIELD_SEPARATOR).value or '|' + self.field_separator = integration.setting(self.FIELD_SEPARATOR).value or "|" self.use_ssl = integration.setting(self.USE_SSL).json_value self.ssl_cert = integration.setting(self.SSL_CERTIFICATE).value self.ssl_key = integration.setting(self.SSL_KEY).value @@ -159,7 +187,9 @@ def __init__(self, library, integration, analytics=None, self.client = client patron_status_block = integration.setting(self.PATRON_STATUS_BLOCK).json_value if patron_status_block is None or patron_status_block: - self.fields_that_deny_borrowing = SIPClient.PATRON_STATUS_FIELDS_THAT_DENY_BORROWING_PRIVILEGES + self.fields_that_deny_borrowing = ( + SIPClient.PATRON_STATUS_FIELDS_THAT_DENY_BORROWING_PRIVILEGES + ) else: self.fields_that_deny_borrowing = [] @@ -186,7 +216,7 @@ def _client(self): ssl_cert=self.ssl_cert, ssl_key=self.ssl_key, encoding=self.encoding.lower(), - dialect=self.dialect + dialect=self.dialect, ) def patron_information(self, username, password): @@ -200,9 +230,7 @@ def patron_information(self, username, password): return info except IOError as e: - raise RemoteIntegrationException( - self.server or 'unknown server', str(e) - ) + raise RemoteIntegrationException(self.server or "unknown server", str(e)) def _remote_patron_lookup(self, patron_or_patrondata): info = self.patron_information( @@ -230,43 +258,46 @@ def makeConnection(sip): return sip.connection sip = self._client - connection = self.run_test( - ("Test Connection"), - makeConnection, - sip - ) + connection = self.run_test(("Test Connection"), makeConnection, sip) yield connection if not connection.success: return login = self.run_test( - ("Test Login with username '%s' and password '%s'" % (self.login_user_id, self.login_password)), - sip.login + ( + "Test Login with username '%s' and password '%s'" + % (self.login_user_id, self.login_password) + ), + sip.login, ) yield login # Log in was successful so test patron's test credentials if login.success: - results = [r for r in super(SIP2AuthenticationProvider, self)._run_self_tests(_db)] + results = [ + r for r in super(SIP2AuthenticationProvider, self)._run_self_tests(_db) + ] for result in results: yield result if results[0].success: + def raw_patron_information(): - info = sip.patron_information(self.test_username, self.test_password) + info = sip.patron_information( + self.test_username, self.test_password + ) return json.dumps(info, indent=1) yield self.run_test( "Patron information request", sip.patron_information_request, self.test_username, - patron_password=self.test_password + patron_password=self.test_password, ) yield self.run_test( - ("Raw test patron information"), - raw_patron_information + ("Raw test patron information"), raw_patron_information ) def info_to_patrondata(self, info, validate_password=True): @@ -275,12 +306,12 @@ def info_to_patrondata(self, info, validate_password=True): SIPClient.patron_information() to an abstract, authenticator-independent PatronData object. """ - if info.get('valid_patron', 'N') == 'N': + if info.get("valid_patron", "N") == "N": # The patron could not be identified as a patron of this # library. Don't return any data. return None - if info.get('valid_patron_password') == 'N' and validate_password: + if info.get("valid_patron_password") == "N" and validate_password: # The patron did not authenticate correctly. Don't # return any data. return None @@ -290,22 +321,25 @@ def info_to_patrondata(self, info, validate_password=True): # authenticated," rather than "you didn't provide a # password so we didn't check." patrondata = PatronData() - if 'sipserver_internal_id' in info: - patrondata.permanent_id = info['sipserver_internal_id'] - if 'patron_identifier' in info: - patrondata.authorization_identifier = info['patron_identifier'] - if 'email_address' in info: - patrondata.email_address = info['email_address'] - if 'personal_name' in info: - patrondata.personal_name = info['personal_name'] - if 'fee_amount' in info: - fines = info['fee_amount'] + if "sipserver_internal_id" in info: + patrondata.permanent_id = info["sipserver_internal_id"] + if "patron_identifier" in info: + patrondata.authorization_identifier = info["patron_identifier"] + if "email_address" in info: + patrondata.email_address = info["email_address"] + if "personal_name" in info: + patrondata.personal_name = info["personal_name"] + if "fee_amount" in info: + fines = info["fee_amount"] else: - fines = '0' + fines = "0" patrondata.fines = MoneyUtility.parse(fines) - if 'sipserver_patron_class' in info: - patrondata.external_type = info['sipserver_patron_class'] - for expire_field in ['sipserver_patron_expiration', 'polaris_patron_expiration']: + if "sipserver_patron_class" in info: + patrondata.external_type = info["sipserver_patron_class"] + for expire_field in [ + "sipserver_patron_expiration", + "polaris_patron_expiration", + ]: if expire_field in info: value = info.get(expire_field) value = self.parse_date(value) @@ -316,15 +350,14 @@ def info_to_patrondata(self, info, validate_password=True): # A True value in most (but not all) subfields of the # patron_status field will prohibit the patron from borrowing # books. - status = info['patron_status_parsed'] + status = info["patron_status_parsed"] block_reason = PatronData.NO_VALUE for field in self.fields_that_deny_borrowing: if status.get(field) is True: block_reason = self.SPECIFIC_BLOCK_REASONS.get( field, PatronData.UNKNOWN_BLOCK ) - if block_reason not in (PatronData.NO_VALUE, - PatronData.UNKNOWN_BLOCK): + if block_reason not in (PatronData.NO_VALUE, PatronData.UNKNOWN_BLOCK): # Even if there are multiple problems with this # patron's account, we can now present a specific # error message. There's no need to look through @@ -335,8 +368,8 @@ def info_to_patrondata(self, info, validate_password=True): # If we can tell by looking at the SIP2 message that the # patron has excessive fines, we can use that as the reason # they're blocked. - if 'fee_limit' in info: - fee_limit = MoneyUtility.parse(info['fee_limit']).amount + if "fee_limit" in info: + fee_limit = MoneyUtility.parse(info["fee_limit"]).amount if fee_limit and patrondata.fines > fee_limit: patrondata.block_reason = PatronData.EXCESSIVE_FINES @@ -357,4 +390,5 @@ def parse_date(cls, value): # NOTE: It's not necessary to implement remote_patron_lookup # because authentication gets patron data as a side effect. + AuthenticationProvider = SIP2AuthenticationProvider diff --git a/api/sip/client.py b/api/sip/client.py index 71bf14be87..120eb3869c 100644 --- a/api/sip/client.py +++ b/api/sip/client.py @@ -32,6 +32,7 @@ import socket import ssl import tempfile + from api.sip.dialect import GenericILS from core.util.datetime_helpers import utc_now @@ -40,6 +41,7 @@ # fields in a way that makes it easy to reliably parse response # documents. + class fixed(object): """A fixed-width field in a SIP2 response.""" @@ -58,34 +60,38 @@ def consume(self, data, in_progress): :return: The original input string, after the value of this field has been removed. """ - value = data[:self.length] + value = data[: self.length] in_progress[self.internal_name] = value - return data[self.length:] + return data[self.length :] @classmethod def _add(cls, internal_name, *args, **kwargs): obj = cls(internal_name, *args, **kwargs) setattr(cls, internal_name, obj) -fixed._add('patron_status', 14) -fixed._add('language', 3) -fixed._add('transaction_date', 18) -fixed._add('hold_items_count', 4) -fixed._add('overdue_items_count', 4) -fixed._add('charged_items_count', 4) -fixed._add('fine_items_count', 4) -fixed._add('recall_items_count', 4) -fixed._add('unavailable_holds_count', 4) -fixed._add('login_ok', 1) -fixed._add('end_session', 1) + +fixed._add("patron_status", 14) +fixed._add("language", 3) +fixed._add("transaction_date", 18) +fixed._add("hold_items_count", 4) +fixed._add("overdue_items_count", 4) +fixed._add("charged_items_count", 4) +fixed._add("fine_items_count", 4) +fixed._add("recall_items_count", 4) +fixed._add("unavailable_holds_count", 4) +fixed._add("login_ok", 1) +fixed._add("end_session", 1) + class named(object): """A variable-length field in a SIP2 response.""" - def __init__(self, internal_name, sip_code, required=False, - length=None, allow_multiple=False): + + def __init__( + self, internal_name, sip_code, required=False, length=None, allow_multiple=False + ): self.sip_code = sip_code self.internal_name = internal_name - self.req=required + self.req = required self.length = length self.allow_multiple = allow_multiple @@ -100,8 +106,9 @@ def required(self): To check whether a specific field actually is required, check `field.req`. """ - return named(self.internal_name, self.sip_code, True, - self.length, self.allow_multiple) + return named( + self.internal_name, self.sip_code, True, self.length, self.allow_multiple + ) def consume(self, value, in_progress): """Process the given value for this field. @@ -117,10 +124,12 @@ def consume(self, value, in_progress): if self.length and len(value) != self.length: self.log.warn( "Expected string of length %d for field %s, but got %r", - self.length, self.sip_code, value + self.length, + self.sip_code, + value, ) if self.allow_multiple: - in_progress.setdefault(self.internal_name,[]).append(value) + in_progress.setdefault(self.internal_name, []).append(value) else: in_progress[self.internal_name] = value @@ -129,6 +138,7 @@ def _add(cls, internal_name, *args, **kwargs): obj = cls(internal_name, *args, **kwargs) setattr(cls, internal_name, obj) + named._add("institution_id", "AO") named._add("patron_identifier", "AA") named._add("personal_name", "AE") @@ -158,33 +168,35 @@ def _add(cls, internal_name, *args, **kwargs): # SIP extensions defined by Georgia Public Library Service's SIP # server, used by Evergreen and Koha. -named._add('sipserver_patron_expiration', 'PA') -named._add('sipserver_patron_class', 'PC') -named._add('sipserver_internet_privileges', 'PI') -named._add('sipserver_internal_id', 'XI') +named._add("sipserver_patron_expiration", "PA") +named._add("sipserver_patron_class", "PC") +named._add("sipserver_internet_privileges", "PI") +named._add("sipserver_internal_id", "XI") # SIP extensions defined by Polaris. -named._add('polaris_patron_birthdate', 'BC') -named._add('polaris_postal_code', 'PZ') -named._add('polaris_patron_expiration', 'PX') -named._add('polaris_patron_expired', 'PY') +named._add("polaris_patron_birthdate", "BC") +named._add("polaris_postal_code", "PZ") +named._add("polaris_patron_expiration", "PX") +named._add("polaris_patron_expired", "PY") # A potential problem: Polaris defines PA to refer to something else. + class RequestResend(IOError): """There was an error transmitting a message and the server has requested that it be resent. """ + class Constants(object): UNKNOWN_LANGUAGE = "000" ENGLISH = "001" # By default, SIP2 messages are encoded using Code Page 850. - DEFAULT_ENCODING = 'cp850' + DEFAULT_ENCODING = "cp850" # SIP2 messages are terminated with the \r character. - TERMINATOR_CHAR = '\r' + TERMINATOR_CHAR = "\r" class SIPClient(Constants): @@ -196,20 +208,20 @@ class SIPClient(Constants): # These are the subfield names associated with the 'patron status' # field as specified in the SIP2 spec. - CHARGE_PRIVILEGES_DENIED = 'charge privileges denied' - RENEWAL_PRIVILEGES_DENIED = 'renewal privileges denied' - RECALL_PRIVILEGES_DENIED = 'recall privileges denied' - HOLD_PRIVILEGES_DENIED = 'hold privileges denied' - CARD_REPORTED_LOST = 'card reported lost' - TOO_MANY_ITEMS_CHARGED = 'too many items charged' - TOO_MANY_ITEMS_OVERDUE = 'too many items overdue' - TOO_MANY_RENEWALS = 'too many renewals' - TOO_MANY_RETURN_CLAIMS = 'too many claims of items returned' - TOO_MANY_LOST= 'too many items lost' - EXCESSIVE_FINES = 'excessive outstanding fines' - EXCESSIVE_FEES = 'excessive outstanding fees' - RECALL_OVERDUE = 'recall overdue' - TOO_MANY_ITEMS_BILLED = 'too many items billed' + CHARGE_PRIVILEGES_DENIED = "charge privileges denied" + RENEWAL_PRIVILEGES_DENIED = "renewal privileges denied" + RECALL_PRIVILEGES_DENIED = "recall privileges denied" + HOLD_PRIVILEGES_DENIED = "hold privileges denied" + CARD_REPORTED_LOST = "card reported lost" + TOO_MANY_ITEMS_CHARGED = "too many items charged" + TOO_MANY_ITEMS_OVERDUE = "too many items overdue" + TOO_MANY_RENEWALS = "too many renewals" + TOO_MANY_RETURN_CLAIMS = "too many claims of items returned" + TOO_MANY_LOST = "too many items lost" + EXCESSIVE_FINES = "excessive outstanding fines" + EXCESSIVE_FEES = "excessive outstanding fees" + RECALL_OVERDUE = "recall overdue" + TOO_MANY_ITEMS_BILLED = "too many items billed" # All the flags, in the order they're used in the 'patron status' # field. @@ -227,7 +239,7 @@ class SIPClient(Constants): EXCESSIVE_FINES, EXCESSIVE_FEES, RECALL_OVERDUE, - TOO_MANY_ITEMS_BILLED + TOO_MANY_ITEMS_BILLED, ] # Some, but not all, of these fields, imply that a patron has lost @@ -241,13 +253,23 @@ class SIPClient(Constants): EXCESSIVE_FINES, EXCESSIVE_FEES, RECALL_OVERDUE, - TOO_MANY_ITEMS_BILLED + TOO_MANY_ITEMS_BILLED, ] - def __init__(self, target_server, target_port, login_user_id=None, - login_password=None, location_code=None, institution_id='', separator=None, - use_ssl=False, ssl_cert=None, ssl_key=None, - encoding=Constants.DEFAULT_ENCODING, dialect=GenericILS + def __init__( + self, + target_server, + target_port, + login_user_id=None, + login_password=None, + location_code=None, + institution_id="", + separator=None, + use_ssl=False, + ssl_cert=None, + ssl_key=None, + encoding=Constants.DEFAULT_ENCODING, + dialect=GenericILS, ): """Initialize a client for (but do not connect to) a SIP2 server. @@ -268,7 +290,7 @@ def __init__(self, target_server, target_port, login_user_id=None, self.target_port = int(target_port) self.location_code = location_code self.institution_id = institution_id - self.separator = separator or '|' + self.separator = separator or "|" self.use_ssl = use_ssl or ssl_cert or ssl_key self.ssl_cert = ssl_cert @@ -277,8 +299,8 @@ def __init__(self, target_server, target_port, login_user_id=None, # Turn the separator string into a regular expression that splits # field name/field value pairs on the separator string. - if self.separator in '|.^$*+?{}()[]\\': - escaped = '\\' + self.separator + if self.separator in "|.^$*+?{}()[]\\": + escaped = "\\" + self.separator else: escaped = self.separator self.separator_re = re.compile(escaped + "([A-Z][A-Z])") @@ -288,7 +310,7 @@ def __init__(self, target_server, target_port, login_user_id=None, self.login_user_id = login_user_id if login_user_id: if not login_password: - login_password = '' + login_password = "" # We need to log in before using this server. self.must_log_in = True else: @@ -301,27 +323,33 @@ def login(self): """Log in to the SIP server if required.""" if self.must_log_in: response = self.make_request( - self.login_message, self.login_response_parser, - self.login_user_id, self.login_password, self.location_code + self.login_message, + self.login_response_parser, + self.login_user_id, + self.login_password, + self.location_code, ) - if response['login_ok'] != '1': + if response["login_ok"] != "1": raise IOError("Error logging in: %r" % response) return response def patron_information(self, *args, **kwargs): - """Get information about a patron. - """ + """Get information about a patron.""" return self.make_request( - self.patron_information_request, self.patron_information_parser, - *args, **kwargs + self.patron_information_request, + self.patron_information_parser, + *args, + **kwargs ) def end_session(self, *args, **kwargs): """Send end session message.""" if self.dialect.sendEndSession: return self.make_request( - self.end_session_message, self.end_session_response_parser, - *args, **kwargs + self.end_session_message, + self.end_session_response_parser, + *args, + **kwargs ) else: return None @@ -341,9 +369,8 @@ def connect(self): self.connection.connect((self.target_server, self.target_port)) except socket.error as message: raise IOError( - "Could not connect to %s:%s - %s" % ( - self.target_server, self.target_port, message - ) + "Could not connect to %s:%s - %s" + % (self.target_server, self.target_port, message) ) # Since this is a new socket connection, reset the message count @@ -377,8 +404,7 @@ def make_secure_connection(self): os.close(fd) connection = self.make_insecure_connection() connection = ssl.wrap_socket( - connection, certfile=tmp_ssl_cert_path, - keyfile=tmp_ssl_key_path + connection, certfile=tmp_ssl_cert_path, keyfile=tmp_ssl_key_path ) # Now that the connection has been established, the temporary @@ -415,7 +441,7 @@ def make_request(self, message_creator, parser, *args, **kwargs): if retries >= self.MAXIMUM_RETRIES: # Only retry MAXIMUM_RETRIES times in case we we are sending # a message the ILS doesn't like, so we don't retry forever - raise IOError('Maximum SIP retries reached') + raise IOError("Maximum SIP retries reached") self.send(message_with_checksum) response = self.read_message() try: @@ -430,13 +456,24 @@ def make_request(self, message_creator, parser, *args, **kwargs): retries += 1 return parsed - def login_message(self, login_user_id, login_password, location_code="", - uid_algorithm="0", - pwd_algorithm="0"): + def login_message( + self, + login_user_id, + login_password, + location_code="", + uid_algorithm="0", + pwd_algorithm="0", + ): """Generate a message for logging in to a SIP server.""" - message = ("93" + uid_algorithm + pwd_algorithm - + "CN" + login_user_id + self.separator - + "CO" + login_password + message = ( + "93" + + uid_algorithm + + pwd_algorithm + + "CN" + + login_user_id + + self.separator + + "CO" + + login_password ) if location_code: message = message + self.separator + "CP" + location_code @@ -444,15 +481,13 @@ def login_message(self, login_user_id, login_password, location_code="", def login_response_parser(self, message): """Parse the response from a login message.""" - return self.parse_response( - message, - 94, - fixed.login_ok - ) + return self.parse_response(message, 94, fixed.login_ok) def end_session_message( - self, patron_identifier, patron_password="", - terminal_password="", + self, + patron_identifier, + patron_password="", + terminal_password="", ): """ This message will be sent when a patron has completed all of their @@ -472,10 +507,17 @@ def end_session_message( code = "35" timestamp = self.now() - message = (code + timestamp + - "AO" + self.institution_id + self.separator + - "AA" + patron_identifier + self.separator + - "AC" + terminal_password + message = ( + code + + timestamp + + "AO" + + self.institution_id + + self.separator + + "AA" + + patron_identifier + + self.separator + + "AC" + + terminal_password ) if patron_password: message += self.separator + "AD" + patron_password @@ -491,13 +533,16 @@ def end_session_response_parser(self, message): named.institution_id.required, named.patron_identifier.required, named.screen_message, - named.print_line + named.print_line, ) def patron_information_request( - self, patron_identifier, patron_password="", - terminal_password="", - language=None, summary=None + self, + patron_identifier, + patron_password="", + terminal_password="", + language=None, + summary=None, ): """ A superset of patron status request. @@ -520,10 +565,19 @@ def patron_information_request( timestamp = self.now() summary = summary or self.summary() - message = (code + language + timestamp + summary - + "AO" + self.institution_id + self.separator + - "AA" + patron_identifier + self.separator + - "AC" + terminal_password + message = ( + code + + language + + timestamp + + summary + + "AO" + + self.institution_id + + self.separator + + "AA" + + patron_identifier + + self.separator + + "AC" + + terminal_password ) if patron_password: message += self.separator + "AD" + patron_password @@ -608,22 +662,21 @@ def patron_information_parser(self, data): named.phone_number, named.screen_message, named.print_line, - # Add common extension fields. named.sipserver_patron_expiration, named.polaris_patron_expiration, named.sipserver_patron_class, named.sipserver_internet_privileges, - named.sipserver_internal_id + named.sipserver_internal_id, ) # As a convenience, parse the patron_status field from a # 14-character string into a dictionary of booleans. try: - parsed = self.parse_patron_status(response.get('patron_status')) + parsed = self.parse_patron_status(response.get("patron_status")) except ValueError as e: parsed = {} - response['patron_status_parsed'] = parsed + response["patron_status_parsed"] = parsed return response def parse_response(self, data, expect_status_code, *fields): @@ -678,7 +731,7 @@ def parse_response(self, data, expect_status_code, *fields): # field object, and process it. while i < len(split): sip_code = split[i] - value = split[i+1] + value = split[i + 1] if sip_code == named.sequence_number.sip_code: # Sequence number is special in two ways. First, it # indicates the end of the message. Second, it doesn't @@ -703,8 +756,7 @@ def parse_response(self, data, expect_status_code, *fields): # If a named field is required and never showed up, sound the alarm. for field in required_fields_not_seen: self.log.error( - "Expected required field %s but did not find it.", - field.sip_code + "Expected required field %s but did not find it.", field.sip_code ) return parsed @@ -713,14 +765,12 @@ def consume_status_code(self, data, expected, in_progress): given response string, and verify that it's as expected. """ status_code = data[:2] - in_progress['_status'] = status_code + in_progress["_status"] = status_code if status_code != expected: - if status_code == '96': # Request SC Resend + if status_code == "96": # Request SC Resend raise RequestResend() else: - raise IOError( - "Unexpected status code %s: %s" % (status_code, data) - ) + raise IOError("Unexpected status code %s: %s" % (status_code, data)) return data[2:] @classmethod @@ -729,15 +779,12 @@ def parse_patron_status(cls, status_string): :return: A 14-element dictionary mapping flag names to boolean values. """ - if (not isinstance(status_string, (bytes, str)) - or len(status_string) != 14): - raise ValueError( - "Patron status must be a 14-character string." - ) + if not isinstance(status_string, (bytes, str)) or len(status_string) != 14: + raise ValueError("Patron status must be a 14-character string.") status = {} for i, field in enumerate(cls.PATRON_STATUS_FIELDS): # ' ' means false, 'Y' means true. - value = status_string[i] != ' ' + value = status_string[i] != " " status[field] = value return status @@ -746,31 +793,39 @@ def now(self): now = utc_now() return datetime.datetime.strftime(now, "%Y%m%d0000%H%M%S") - def summary(self, hold_items=False, overdue_items=False, - charged_items=False, fine_items=False, recall_items=False, - unavailable_holds=False): + def summary( + self, + hold_items=False, + overdue_items=False, + charged_items=False, + fine_items=False, + recall_items=False, + unavailable_holds=False, + ): """Generate the SIP summary field: a 10-character query string for requesting detailed information about a patron's relationship with items. """ summary = "" for item in ( - hold_items, overdue_items, - charged_items, fine_items, recall_items, - unavailable_holds + hold_items, + overdue_items, + charged_items, + fine_items, + recall_items, + unavailable_holds, ): if item: summary += "Y" else: summary += " " # The last four spaces are always empty. - summary += ' ' - if summary.count('Y') > 1: + summary += " " + if summary.count("Y") > 1: # This violates the spec but in my tests it seemed to # work, so we'll allow it. self.log.warn( - 'Summary requested too many kinds of detailed information: %s' % - summary + "Summary requested too many kinds of detailed information: %s" % summary ) return summary @@ -786,7 +841,7 @@ def do_send(self, data): """ self.connection.send(data) - def read_message(self, max_size=1024*1024): + def read_message(self, max_size=1024 * 1024): """Read a SIP2 message from the socket connection. A SIP2 message ends with a \\r character. @@ -834,7 +889,7 @@ def append_checksum(self, text, include_sequence_number=True): check = 0 for each in text: check = check + ord(each) - check = check + ord('\0') + check = check + ord("\0") check = (check ^ 0xFFFF) + 1 checksum = "%4.4X" % (check) @@ -854,8 +909,8 @@ class MockSIPClient(SIPClient): def __init__(self, **kwargs): # Override any settings that might cause us to actually # make requests. - kwargs['target_server'] = None - kwargs['target_port'] = None + kwargs["target_server"] = None + kwargs["target_port"] = None super(MockSIPClient, self).__init__(**kwargs) self.read_count = 0 @@ -882,7 +937,7 @@ def do_send(self, data): self.write_count += 1 self.requests.append(data) - def read_message(self, max_size=1024*1024): + def read_message(self, max_size=1024 * 1024): """Read a response message off the queue.""" self.read_count += 1 response = self.responses[0] @@ -901,6 +956,7 @@ class MockSIPClientFactory(object): every simulated server interaction, making it impossible to queue responses or look at the results. """ + def __init__(self): self.client = None diff --git a/api/sip/dialect.py b/api/sip/dialect.py index a13dbc96ab..27025b092e 100644 --- a/api/sip/dialect.py +++ b/api/sip/dialect.py @@ -1,10 +1,9 @@ class Dialect: - """Describe a SIP2 dialect. - """ + """Describe a SIP2 dialect.""" # Constants for each class - GENERIC_ILS = 'GenericILS' - AG_VERSO = 'AutoGraphicsVerso' + GENERIC_ILS = "GenericILS" + AG_VERSO = "AutoGraphicsVerso" # Settings defined in each class sendEndSession = None @@ -17,8 +16,10 @@ def load_dialect(dialect): else: return GenericILS + class GenericILS(Dialect): sendEndSession = True + class AutoGraphicsVerso(Dialect): sendEndSession = False diff --git a/api/testing.py b/api/testing.py index 1903ccb95b..08b88742a7 100644 --- a/api/testing.py +++ b/api/testing.py @@ -4,33 +4,22 @@ import logging from collections import defaultdict -from core.testing import DatabaseTest - +from api.adobe_vendor_id import AuthdataUtility +from api.circulation import BaseCirculationAPI, CirculationAPI, HoldInfo, LoanInfo +from api.config import Configuration, temp_config +from api.shared_collection import SharedCollectionAPI from core.model import ( ConfigurationSetting, DataSource, ExternalIntegration, + Hold, Identifier, Library, Loan, - Hold, Session, ) -from api.circulation import ( - BaseCirculationAPI, - CirculationAPI, - LoanInfo, - HoldInfo, -) -from api.shared_collection import ( - SharedCollectionAPI, -) -from api.config import ( - Configuration, - temp_config, -) +from core.testing import DatabaseTest -from api.adobe_vendor_id import AuthdataUtility class VendorIDTest(DatabaseTest): """A DatabaseTest that knows how to set up an Adobe Vendor ID @@ -56,21 +45,21 @@ def initialize_adobe(self, vendor_id_library, short_token_libraries=[]): # The first library acts as an Adobe Vendor ID server. self.adobe_vendor_id = self._external_integration( ExternalIntegration.ADOBE_VENDOR_ID, - ExternalIntegration.DRM_GOAL, username=self.TEST_VENDOR_ID, - libraries=[vendor_id_library] + ExternalIntegration.DRM_GOAL, + username=self.TEST_VENDOR_ID, + libraries=[vendor_id_library], ) # The other libraries will share a registry integration. self.registry = self._external_integration( ExternalIntegration.OPDS_REGISTRATION, ExternalIntegration.DISCOVERY_GOAL, - libraries=short_token_libraries + libraries=short_token_libraries, ) # The integration knows which Adobe Vendor ID server it # gets its Adobe IDs from. self.registry.set_setting( - AuthdataUtility.VENDOR_ID_KEY, - self.adobe_vendor_id.username + AuthdataUtility.VENDOR_ID_KEY, self.adobe_vendor_id.username ) # As we give libraries their Short Client Token settings, @@ -108,7 +97,6 @@ def initialize_adobe(self, vendor_id_library, short_token_libraries=[]): class MonitorTest(DatabaseTest): - @property def ts(self): """Make the timestamp used by run() when calling run_once(). @@ -122,7 +110,7 @@ class AnnouncementTest(object): """A test that needs to create announcements.""" # Create raw data to be used in tests. - format = '%Y-%m-%d' + format = "%Y-%m-%d" today = datetime.date.today() yesterday = (today - datetime.timedelta(days=1)).strftime(format) tomorrow = (today + datetime.timedelta(days=1)).strftime(format) @@ -132,23 +120,20 @@ class AnnouncementTest(object): # This announcement is active. active = dict( - id="active", - start=today, - finish=tomorrow, - content="A sample announcement." + id="active", start=today, finish=tomorrow, content="A sample announcement." ) # This announcement expired yesterday. expired = dict(active) - expired['id'] = 'expired' - expired['start'] = a_week_ago - expired['finish'] = yesterday + expired["id"] = "expired" + expired["start"] = a_week_ago + expired["finish"] = yesterday # This announcement should be displayed starting tomorrow. forthcoming = dict(active) - forthcoming['id'] = 'forthcoming' - forthcoming['start'] = tomorrow - forthcoming['finish'] = in_a_week + forthcoming["id"] = "forthcoming" + forthcoming["start"] = tomorrow + forthcoming["finish"] = in_a_week class MockRemoteAPI(BaseCirculationAPI): @@ -159,34 +144,37 @@ def __init__(self, set_delivery_mechanism_at, can_revoke_hold_when_reserved): self.log = logging.getLogger("Mock remote API") self.availability_updated_for = [] - def checkout( - self, patron_obj, patron_password, licensepool, - delivery_mechanism - ): + def checkout(self, patron_obj, patron_password, licensepool, delivery_mechanism): # Should be a LoanInfo. - return self._return_or_raise('checkout') + return self._return_or_raise("checkout") def update_availability(self, licensepool): """Simply record the fact that update_availability was called.""" self.availability_updated_for.append(licensepool) - def place_hold(self, patron, pin, licensepool, - hold_notification_email=None): + def place_hold(self, patron, pin, licensepool, hold_notification_email=None): # Should be a HoldInfo. - return self._return_or_raise('hold') - - def fulfill(self, patron, pin, licensepool, internal_format=None, - part=None, fulfill_part_url=None): + return self._return_or_raise("hold") + + def fulfill( + self, + patron, + pin, + licensepool, + internal_format=None, + part=None, + fulfill_part_url=None, + ): # Should be a FulfillmentInfo. - return self._return_or_raise('fulfill') + return self._return_or_raise("fulfill") def checkin(self, patron, pin, licensepool): # Return value is not checked. - return self._return_or_raise('checkin') + return self._return_or_raise("checkin") def release_hold(self, patron, pin, licensepool): # Return value is not checked. - return self._return_or_raise('release_hold') + return self._return_or_raise("release_hold") def internal_format(self, delivery_mechanism): return delivery_mechanism @@ -195,19 +183,19 @@ def update_loan(self, loan, status_doc): self.availability_updated_for.append(loan.license_pool) def queue_checkout(self, response): - self._queue('checkout', response) + self._queue("checkout", response) def queue_hold(self, response): - self._queue('hold', response) + self._queue("hold", response) def queue_fulfill(self, response): - self._queue('fulfill', response) + self._queue("fulfill", response) def queue_checkin(self, response): - self._queue('checkin', response) + self._queue("checkin", response) def queue_release_hold(self, response): - self._queue('release_hold', response) + self._queue("release_hold", response) def _queue(self, k, v): self.responses[k].append(v) @@ -220,8 +208,8 @@ def _return_or_raise(self, k): raise v return v -class MockCirculationAPI(CirculationAPI): +class MockCirculationAPI(CirculationAPI): def __init__(self, *args, **kwargs): super(MockCirculationAPI, self).__init__(*args, **kwargs) self.responses = defaultdict(list) @@ -230,10 +218,10 @@ def __init__(self, *args, **kwargs): self.remotes = {} def local_loans(self, patron): - return self._db.query(Loan).filter(Loan.patron==patron) + return self._db.query(Loan).filter(Loan.patron == patron) def local_holds(self, patron): - return self._db.query(Hold).filter(Hold.patron==patron) + return self._db.query(Hold).filter(Hold.patron == patron) def add_remote_loan(self, *args, **kwargs): self.remote_loans.append(LoanInfo(*args, **kwargs)) @@ -246,19 +234,19 @@ def patron_activity(self, patron, pin): return self.remote_loans, self.remote_holds, True def queue_checkout(self, licensepool, response): - self._queue('checkout', licensepool, response) + self._queue("checkout", licensepool, response) def queue_hold(self, licensepool, response): - self._queue('hold', licensepool, response) + self._queue("hold", licensepool, response) def queue_fulfill(self, licensepool, response): - self._queue('fulfill', licensepool, response) + self._queue("fulfill", licensepool, response) def queue_checkin(self, licensepool, response): - self._queue('checkin', licensepool, response) + self._queue("checkin", licensepool, response) def queue_release_hold(self, licensepool, response): - self._queue('release_hold', licensepool, response) + self._queue("release_hold", licensepool, response) def _queue(self, method, licensepool, response): mock = self.api_for_license_pool(licensepool) @@ -279,6 +267,7 @@ def api_for_license_pool(self, licensepool): self.remotes[source] = remote return self.remotes[source] + class MockSharedCollectionAPI(SharedCollectionAPI): def __init__(self, *args, **kwargs): super(MockSharedCollectionAPI, self).__init__(*args, **kwargs) @@ -296,32 +285,39 @@ def _return_or_raise(self, k): return v def queue_register(self, response): - self._queue('register', response) + self._queue("register", response) def register(self, collection, url): - return self._return_or_raise('register') + return self._return_or_raise("register") def queue_borrow(self, response): - self._queue('borrow', response) + self._queue("borrow", response) def borrow(self, collection, client, pool, hold=None): - return self._return_or_raise('borrow') + return self._return_or_raise("borrow") def queue_revoke_loan(self, response): - self._queue('revoke-loan', response) + self._queue("revoke-loan", response) def revoke_loan(self, collection, client, loan): - return self._return_or_raise('revoke-loan') + return self._return_or_raise("revoke-loan") def queue_fulfill(self, response): - self._queue('fulfill', response) - - def fulfill(self, patron, pin, licensepool, internal_format=None, - part=None, fulfill_part_url=None): - return self._return_or_raise('fulfill') + self._queue("fulfill", response) + + def fulfill( + self, + patron, + pin, + licensepool, + internal_format=None, + part=None, + fulfill_part_url=None, + ): + return self._return_or_raise("fulfill") def queue_revoke_hold(self, response): - self._queue('revoke-hold', response) + self._queue("revoke-hold", response) def revoke_hold(self, collection, client, hold): - return self._return_or_raise('revoke-hold') + return self._return_or_raise("revoke-hold") diff --git a/api/util/patron.py b/api/util/patron.py index af9cc98ef1..174be438f8 100644 --- a/api/util/patron.py +++ b/api/util/patron.py @@ -1,12 +1,15 @@ import datetime + import dateutil from money import Money -from api.config import Configuration + from api.circulation_exceptions import * +from api.config import Configuration from core.model.patron import Patron from core.util import MoneyUtility from core.util.datetime_helpers import utc_now + class PatronUtility(object): """Apply circulation-specific logic to Patron model objects.""" @@ -69,6 +72,7 @@ def assert_borrowing_privileges(cls, patron): raise OutstandingFines() from api.authenticator import PatronData + if patron.block_reason is not None: if patron.block_reason is PatronData.EXCESSIVE_FINES: # The authentication mechanism itself may know that @@ -105,9 +109,9 @@ def authorization_is_active(cls, patron): # less likely that a patron's authorization will expire before # they think it should. now_local = datetime.datetime.now(tz=dateutil.tz.tzlocal()) - if (patron.authorization_expires - and cls._to_date(patron.authorization_expires) - < cls._to_date(now_local)): + if patron.authorization_expires and cls._to_date( + patron.authorization_expires + ) < cls._to_date(now_local): return False return True diff --git a/api/util/url.py b/api/util/url.py index 238f15c77a..346904c345 100644 --- a/api/util/url.py +++ b/api/util/url.py @@ -1,4 +1,4 @@ -from urllib.parse import urlparse, ParseResult, urlencode +from urllib.parse import ParseResult, urlencode, urlparse class URLUtility(object): @@ -24,7 +24,7 @@ def build_url(base_url, query_parameters): result.path, result.params, urlencode(query_parameters), - result.fragment + result.fragment, ) return result.geturl() diff --git a/api/web_publication_manifest.py b/api/web_publication_manifest.py index 342f819b90..1c8d1f5597 100644 --- a/api/web_publication_manifest.py +++ b/api/web_publication_manifest.py @@ -1,17 +1,16 @@ """Vendor-specific variants of the standard Web Publication Manifest classes. """ -from core.model import ( - DeliveryMechanism, - Representation, -) +from core.model import DeliveryMechanism, Representation from core.util.web_publication_manifest import AudiobookManifest + class SpineItem(object): """Metadata about a piece of playable audio from an audiobook.""" - def __init__(self, title, duration, part, sequence, - media_type=Representation.MP3_MEDIA_TYPE): + def __init__( + self, title, duration, part, sequence, media_type=Representation.MP3_MEDIA_TYPE + ): """Constructor. :param title: The title of this spine item. @@ -40,14 +39,21 @@ class FindawayManifest(AudiobookManifest): # This URI prefix makes it clear when we are using a term coined # by Findaway in a JSON-LD document. - FINDAWAY_EXTENSION_CONTEXT = "http://librarysimplified.org/terms/third-parties/findaway.com/" + FINDAWAY_EXTENSION_CONTEXT = ( + "http://librarysimplified.org/terms/third-parties/findaway.com/" + ) MEDIA_TYPE = DeliveryMechanism.FINDAWAY_DRM def __init__( - self, license_pool, accountId=None, checkoutId=None, - fulfillmentId=None, licenseId=None, sessionKey=None, - spine_items=[] + self, + license_pool, + accountId=None, + checkoutId=None, + fulfillmentId=None, + licenseId=None, + sessionKey=None, + spine_items=[], ): """Create a FindawayManifest object from raw data. @@ -78,7 +84,7 @@ def __init__( context_with_extension = [ "http://readium.org/webpub/default.jsonld", - {"findaway" : self.FINDAWAY_EXTENSION_CONTEXT}, + {"findaway": self.FINDAWAY_EXTENSION_CONTEXT}, ] super(FindawayManifest, self).__init__(context=context_with_extension) @@ -89,40 +95,37 @@ def __init__( # Add Findaway-specific DRM information as an 'encrypted' object # within the metadata object. - encrypted = dict( - scheme='http://librarysimplified.org/terms/drm/scheme/FAE' - ) - self.metadata['encrypted'] = encrypted + encrypted = dict(scheme="http://librarysimplified.org/terms/drm/scheme/FAE") + self.metadata["encrypted"] = encrypted for findaway_extension, value in [ - ('accountId', accountId), - ('checkoutId', checkoutId), - ('fulfillmentId', fulfillmentId), - ('licenseId', licenseId), - ('sessionKey', sessionKey) + ("accountId", accountId), + ("checkoutId", checkoutId), + ("fulfillmentId", fulfillmentId), + ("licenseId", licenseId), + ("sessionKey", sessionKey), ]: if not value: continue - output_key = 'findaway:' + findaway_extension + output_key = "findaway:" + findaway_extension encrypted[output_key] = value # Add the SpineItems as reading order items. None of them will # have working 'href' fields -- it's just to give the client a # picture of the structure of the timeline. - part_key = 'findaway:part' - sequence_key = 'findaway:sequence' + part_key = "findaway:part" + sequence_key = "findaway:sequence" total_duration = 0 spine_items.sort(key=SpineItem.sort_key) for item in spine_items: - kwargs = { - part_key: item.part, - sequence_key: item.sequence - } + kwargs = {part_key: item.part, sequence_key: item.sequence} self.add_reading_order( - href=None, title=item.title, duration=item.duration, - type=item.media_type, **kwargs + href=None, + title=item.title, + duration=item.duration, + type=item.media_type, + **kwargs ) total_duration += item.duration if spine_items: - self.metadata['duration'] = total_duration - + self.metadata["duration"] = total_duration diff --git a/app.py b/app.py index d30512d756..3d126739a9 100644 --- a/app.py +++ b/app.py @@ -1,6 +1,7 @@ -from api import app import sys +from api import app + url = None if len(sys.argv) > 1: url = sys.argv[1] diff --git a/bin/axis_monitor b/bin/axis_monitor index b77d144687..f71a7d664d 100755 --- a/bin/axis_monitor +++ b/bin/axis_monitor @@ -2,9 +2,11 @@ """Monitor the Axis 360 collection by asking about recently changed books.""" import os import sys + bin_dir = os.path.split(__file__)[0] package_dir = os.path.join(bin_dir, "..") sys.path.append(os.path.abspath(package_dir)) -from core.scripts import RunCollectionMonitorScript from api.axis import Axis360CirculationMonitor +from core.scripts import RunCollectionMonitorScript + RunCollectionMonitorScript(Axis360CirculationMonitor).run() diff --git a/bin/axis_reaper b/bin/axis_reaper index f21fa8fb04..af6ffceb42 100755 --- a/bin/axis_reaper +++ b/bin/axis_reaper @@ -2,9 +2,11 @@ """Monitor the Axis collection by looking for books that have been removed.""" import os import sys + bin_dir = os.path.split(__file__)[0] package_dir = os.path.join(bin_dir, "..") sys.path.append(os.path.abspath(package_dir)) -from core.scripts import RunCollectionMonitorScript from api.axis import AxisCollectionReaper +from core.scripts import RunCollectionMonitorScript + RunCollectionMonitorScript(AxisCollectionReaper).run() diff --git a/bin/bibliotheca_circulation_sweep b/bin/bibliotheca_circulation_sweep index d139a989e4..5f68cb1d19 100755 --- a/bin/bibliotheca_circulation_sweep +++ b/bin/bibliotheca_circulation_sweep @@ -2,9 +2,11 @@ """Sweep through our Bibliotheca collections verifying circulation stats.""" import os import sys + bin_dir = os.path.split(__file__)[0] package_dir = os.path.join(bin_dir, "..") sys.path.append(os.path.abspath(package_dir)) -from core.scripts import RunCollectionMonitorScript from api.bibliotheca import BibliothecaCirculationSweep +from core.scripts import RunCollectionMonitorScript + RunCollectionMonitorScript(BibliothecaCirculationSweep).run() diff --git a/bin/bibliotheca_monitor b/bin/bibliotheca_monitor index 728dc321a2..f5bc25015e 100755 --- a/bin/bibliotheca_monitor +++ b/bin/bibliotheca_monitor @@ -2,9 +2,11 @@ """Monitor the Bibliotheca collections by asking about recently changed events.""" import os import sys + bin_dir = os.path.split(__file__)[0] package_dir = os.path.join(bin_dir, "..") sys.path.append(os.path.abspath(package_dir)) -from core.scripts import RunCollectionMonitorScript from api.bibliotheca import BibliothecaEventMonitor +from core.scripts import RunCollectionMonitorScript + RunCollectionMonitorScript(BibliothecaEventMonitor).run() diff --git a/bin/bibliotheca_purchase_monitor b/bin/bibliotheca_purchase_monitor index a21411d106..a34a656af3 100755 --- a/bin/bibliotheca_purchase_monitor +++ b/bin/bibliotheca_purchase_monitor @@ -3,8 +3,13 @@ that happened many years in the past.""" import os import sys + bin_dir = os.path.split(__file__)[0] package_dir = os.path.join(bin_dir, "..") sys.path.append(os.path.abspath(package_dir)) -from api.bibliotheca import BibliothecaPurchaseMonitor, RunBibliothecaPurchaseMonitorScript +from api.bibliotheca import ( + BibliothecaPurchaseMonitor, + RunBibliothecaPurchaseMonitorScript, +) + RunBibliothecaPurchaseMonitorScript(BibliothecaPurchaseMonitor).run() diff --git a/bin/cache_marc_files b/bin/cache_marc_files index e25603f06e..89203087ae 100755 --- a/bin/cache_marc_files +++ b/bin/cache_marc_files @@ -2,10 +2,10 @@ """Refresh and store the MARC files for lanes.""" import os import sys + bin_dir = os.path.split(__file__)[0] package_dir = os.path.join(bin_dir, "..") sys.path.append(os.path.abspath(package_dir)) -from scripts import ( - CacheMARCFiles -) +from scripts import CacheMARCFiles + CacheMARCFiles().run() diff --git a/bin/cache_opds_blocks b/bin/cache_opds_blocks index 9cd6dc05c9..6de4efda6a 100755 --- a/bin/cache_opds_blocks +++ b/bin/cache_opds_blocks @@ -2,10 +2,10 @@ """Refresh the top-level OPDS groups.""" import os import sys + bin_dir = os.path.split(__file__)[0] package_dir = os.path.join(bin_dir, "..") sys.path.append(os.path.abspath(package_dir)) -from scripts import ( - CacheOPDSGroupFeedPerLane -) +from scripts import CacheOPDSGroupFeedPerLane + CacheOPDSGroupFeedPerLane().run() diff --git a/bin/cache_opds_lane_facets b/bin/cache_opds_lane_facets index 34af79bf28..2f1819a5f1 100755 --- a/bin/cache_opds_lane_facets +++ b/bin/cache_opds_lane_facets @@ -2,10 +2,10 @@ """Refresh the OPDS lane facets.""" import os import sys + bin_dir = os.path.split(__file__)[0] package_dir = os.path.join(bin_dir, "..") sys.path.append(os.path.abspath(package_dir)) -from scripts import ( - CacheFacetListsPerLane -) +from scripts import CacheFacetListsPerLane + CacheFacetListsPerLane().run() diff --git a/bin/configuration/configure_collection b/bin/configuration/configure_collection index fd56129212..a2663e4735 100755 --- a/bin/configuration/configure_collection +++ b/bin/configuration/configure_collection @@ -2,10 +2,10 @@ """Configure a collection's settings.""" import os import sys + bin_dir = os.path.split(__file__)[0] package_dir = os.path.join(bin_dir, "..", "..") sys.path.append(os.path.abspath(package_dir)) -from core.scripts import ( - ConfigureCollectionScript -) +from core.scripts import ConfigureCollectionScript + ConfigureCollectionScript().run() diff --git a/bin/configuration/configure_integration b/bin/configuration/configure_integration index ef2f90d8aa..4c18e0a258 100755 --- a/bin/configuration/configure_integration +++ b/bin/configuration/configure_integration @@ -2,10 +2,10 @@ """Configure an integration's settings.""" import os import sys + bin_dir = os.path.split(__file__)[0] package_dir = os.path.join(bin_dir, "..", "..") sys.path.append(os.path.abspath(package_dir)) -from core.scripts import ( - ConfigureIntegrationScript -) +from core.scripts import ConfigureIntegrationScript + ConfigureIntegrationScript().run() diff --git a/bin/configuration/configure_lane b/bin/configuration/configure_lane index 2e57d9d8d3..eb6f730387 100755 --- a/bin/configuration/configure_lane +++ b/bin/configuration/configure_lane @@ -2,10 +2,10 @@ """Configure a lane's settings.""" import os import sys + bin_dir = os.path.split(__file__)[0] package_dir = os.path.join(bin_dir, "..", "..") sys.path.append(os.path.abspath(package_dir)) -from core.scripts import ( - ConfigureLaneScript -) +from core.scripts import ConfigureLaneScript + ConfigureLaneScript().run() diff --git a/bin/configuration/configure_library b/bin/configuration/configure_library index 6ba238761f..7130dada79 100755 --- a/bin/configuration/configure_library +++ b/bin/configuration/configure_library @@ -2,10 +2,10 @@ """Configure a library's settings.""" import os import sys + bin_dir = os.path.split(__file__)[0] package_dir = os.path.join(bin_dir, "..", "..") sys.path.append(os.path.abspath(package_dir)) -from core.scripts import ( - ConfigureLibraryScript -) +from core.scripts import ConfigureLibraryScript + ConfigureLibraryScript().run() diff --git a/bin/configuration/configure_site_setting b/bin/configuration/configure_site_setting index be2943459e..83c562efb5 100755 --- a/bin/configuration/configure_site_setting +++ b/bin/configuration/configure_site_setting @@ -2,11 +2,11 @@ """View or configure the site-wide settings.""" import os import sys + bin_dir = os.path.split(__file__)[0] package_dir = os.path.join(bin_dir, "..", "..") sys.path.append(os.path.abspath(package_dir)) -from core.scripts import ( - ConfigureSiteScript -) from api.config import Configuration +from core.scripts import ConfigureSiteScript + ConfigureSiteScript(config=Configuration).run() diff --git a/bin/configuration/register_library b/bin/configuration/register_library index f2715a147c..72f1ae78e2 100755 --- a/bin/configuration/register_library +++ b/bin/configuration/register_library @@ -2,8 +2,10 @@ """Push the configurations of one or more libraries to a library registry.""" import os import sys + bin_dir = os.path.split(__file__)[0] package_dir = os.path.join(bin_dir, "..", "..") sys.path.append(os.path.abspath(package_dir)) from api.registry import LibraryRegistrationScript + LibraryRegistrationScript().run() diff --git a/bin/configuration/short_client_token_library_configuration b/bin/configuration/short_client_token_library_configuration index 60829744be..72dd18b9b3 100755 --- a/bin/configuration/short_client_token_library_configuration +++ b/bin/configuration/short_client_token_library_configuration @@ -4,6 +4,7 @@ to sign Short Client Tokens. """ import os import sys + bin_dir = os.path.split(__file__)[0] package_dir = os.path.join(bin_dir, "..", "..") sys.path.append(os.path.abspath(package_dir)) diff --git a/bin/configuration/show_collections b/bin/configuration/show_collections index bf476658e4..a9cca3d21a 100755 --- a/bin/configuration/show_collections +++ b/bin/configuration/show_collections @@ -2,10 +2,10 @@ """Show a collection or the full list of collections.""" import os import sys + bin_dir = os.path.split(__file__)[0] package_dir = os.path.join(bin_dir, "..", "..") sys.path.append(os.path.abspath(package_dir)) -from core.scripts import ( - ShowCollectionsScript -) +from core.scripts import ShowCollectionsScript + ShowCollectionsScript().run() diff --git a/bin/configuration/show_integrations b/bin/configuration/show_integrations index 2151c713eb..4643f6afbd 100755 --- a/bin/configuration/show_integrations +++ b/bin/configuration/show_integrations @@ -2,10 +2,10 @@ """Show an integration or the full list of integrations.""" import os import sys + bin_dir = os.path.split(__file__)[0] package_dir = os.path.join(bin_dir, "..", "..") sys.path.append(os.path.abspath(package_dir)) -from core.scripts import ( - ShowIntegrationsScript -) +from core.scripts import ShowIntegrationsScript + ShowIntegrationsScript().run() diff --git a/bin/configuration/show_lanes b/bin/configuration/show_lanes index 0b560ac8f1..f5217e5edf 100755 --- a/bin/configuration/show_lanes +++ b/bin/configuration/show_lanes @@ -2,10 +2,10 @@ """Show a lane or the full list of lanes.""" import os import sys + bin_dir = os.path.split(__file__)[0] package_dir = os.path.join(bin_dir, "..", "..") sys.path.append(os.path.abspath(package_dir)) -from core.scripts import ( - ShowLanesScript -) +from core.scripts import ShowLanesScript + ShowLanesScript().run() diff --git a/bin/configuration/show_libraries b/bin/configuration/show_libraries index d5c7a92cd7..ec16e48654 100755 --- a/bin/configuration/show_libraries +++ b/bin/configuration/show_libraries @@ -2,10 +2,10 @@ """Show a library or the full list of libraries.""" import os import sys + bin_dir = os.path.split(__file__)[0] package_dir = os.path.join(bin_dir, "..", "..") sys.path.append(os.path.abspath(package_dir)) -from core.scripts import ( - ShowLibrariesScript -) +from core.scripts import ShowLibrariesScript + ShowLibrariesScript().run() diff --git a/bin/configuration/vendor_id_library_configuration b/bin/configuration/vendor_id_library_configuration index 706e7c2e64..53296fc7f0 100755 --- a/bin/configuration/vendor_id_library_configuration +++ b/bin/configuration/vendor_id_library_configuration @@ -4,6 +4,7 @@ Tokens to this circulation manager's Adobe Vendor ID implementation. """ import os import sys + bin_dir = os.path.split(__file__)[0] package_dir = os.path.join(bin_dir, "..", "..") sys.path.append(os.path.abspath(package_dir)) diff --git a/bin/custom_list_entry_update_license_pool b/bin/custom_list_entry_update_license_pool index d6a5324a80..40bffb7237 100755 --- a/bin/custom_list_entry_update_license_pool +++ b/bin/custom_list_entry_update_license_pool @@ -6,9 +6,11 @@ newly acquired licenses for books that were already on a custom list. """ import os import sys + bin_dir = os.path.split(__file__)[0] package_dir = os.path.join(bin_dir, "..") sys.path.append(os.path.abspath(package_dir)) from core.monitor import CustomListEntryLicensePoolUpdateMonitor from core.scripts import RunMonitorScript + RunMonitorScript(CustomListEntryLicensePoolUpdateMonitor).run() diff --git a/bin/database_reaper b/bin/database_reaper index 27e5d681aa..678a1576fe 100755 --- a/bin/database_reaper +++ b/bin/database_reaper @@ -4,8 +4,10 @@ from the database. """ import os import sys + bin_dir = os.path.split(__file__)[0] package_dir = os.path.join(bin_dir, "..") sys.path.append(os.path.abspath(package_dir)) from core.scripts import RunReaperMonitorsScript + RunReaperMonitorsScript().run() diff --git a/bin/enki_import b/bin/enki_import index 3a7bd81db0..4442fa46f0 100755 --- a/bin/enki_import +++ b/bin/enki_import @@ -2,9 +2,11 @@ """monitor the Enki collection by asking about recently changed books.""" import os import sys + bin_dir = os.path.split(__file__)[0] package_dir = os.path.join(bin_dir, "..") sys.path.append(os.path.abspath(package_dir)) -from core.scripts import RunCollectionMonitorScript from api.enki import EnkiImport +from core.scripts import RunCollectionMonitorScript + RunCollectionMonitorScript(EnkiImport).run() diff --git a/bin/enki_reaper b/bin/enki_reaper index eee2045cf9..ad9da8ac4d 100755 --- a/bin/enki_reaper +++ b/bin/enki_reaper @@ -2,9 +2,11 @@ """Monitor the Enki collection by looking for books with lost licenses.""" import os import sys + bin_dir = os.path.split(__file__)[0] package_dir = os.path.join(bin_dir, "..") sys.path.append(os.path.abspath(package_dir)) -from core.scripts import RunMonitorScript from api.enki import EnkiCollectionReaper +from core.scripts import RunMonitorScript + RunMonitorScript(EnkiCollectionReaper).run() diff --git a/bin/feedbooks_import_monitor b/bin/feedbooks_import_monitor index a8fc866284..9973426452 100755 --- a/bin/feedbooks_import_monitor +++ b/bin/feedbooks_import_monitor @@ -3,15 +3,14 @@ Feedbooks collections.""" import os import sys + bin_dir = os.path.split(__file__)[0] package_dir = os.path.join(bin_dir, "..") sys.path.append(os.path.abspath(package_dir)) -from core.scripts import OPDSImportScript +from api.feedbooks import FeedbooksImportMonitor, FeedbooksOPDSImporter from core.model import ExternalIntegration -from api.feedbooks import ( - FeedbooksOPDSImporter, - FeedbooksImportMonitor -) +from core.scripts import OPDSImportScript + OPDSImportScript( importer_class=FeedbooksOPDSImporter, monitor_class=FeedbooksImportMonitor, diff --git a/bin/informational/axis-raw-bibliographic b/bin/informational/axis-raw-bibliographic index f6837df2d2..e6b475a9f0 100755 --- a/bin/informational/axis-raw-bibliographic +++ b/bin/informational/axis-raw-bibliographic @@ -7,9 +7,9 @@ bin_dir = os.path.split(__file__)[0] package_dir = os.path.join(bin_dir, "..", "..") sys.path.append(os.path.abspath(package_dir)) -from api.axis import Axis360API # noqa: E402 -from core.model import (Collection, ExternalIntegration) # noqa: E402 -from core.scripts import IdentifierInputScript # noqa: E402 +from api.axis import Axis360API # noqa: E402 +from core.model import Collection, ExternalIntegration # noqa: E402 +from core.scripts import IdentifierInputScript # noqa: E402 class Axis360RawBibliographicScript(IdentifierInputScript): diff --git a/bin/informational/axis-raw-patron-status b/bin/informational/axis-raw-patron-status index bbb7ea4534..7eb15c9473 100755 --- a/bin/informational/axis-raw-patron-status +++ b/bin/informational/axis-raw-patron-status @@ -7,9 +7,9 @@ bin_dir = os.path.split(__file__)[0] package_dir = os.path.join(bin_dir, "..", "..") sys.path.append(os.path.abspath(package_dir)) -from api.axis import Axis360API # noqa: E402 -from core.model import (Collection, ExternalIntegration) # noqa: E402 -from core.scripts import Script # noqa: E402 +from api.axis import Axis360API # noqa: E402 +from core.model import Collection, ExternalIntegration # noqa: E402 +from core.scripts import Script # noqa: E402 class Axis360RawPatronActivityScript(Script): diff --git a/bin/informational/bibliotheca-raw-bibliographic b/bin/informational/bibliotheca-raw-bibliographic index 9a29f58b87..617b5f395b 100755 --- a/bin/informational/bibliotheca-raw-bibliographic +++ b/bin/informational/bibliotheca-raw-bibliographic @@ -7,9 +7,9 @@ bin_dir = os.path.split(__file__)[0] package_dir = os.path.join(bin_dir, "..", "..") sys.path.append(os.path.abspath(package_dir)) -from api.bibliotheca import BibliothecaAPI # noqa: E402 -from core.model import (Collection, ExternalIntegration) # noqa: E402 -from core.scripts import IdentifierInputScript # noqa: E402 +from api.bibliotheca import BibliothecaAPI # noqa: E402 +from core.model import Collection, ExternalIntegration # noqa: E402 +from core.scripts import IdentifierInputScript # noqa: E402 class BibliothecaRawBibliographicScript(IdentifierInputScript): diff --git a/bin/informational/bibliotheca-raw-patron-status b/bin/informational/bibliotheca-raw-patron-status index 4443cee813..1067fa28c9 100755 --- a/bin/informational/bibliotheca-raw-patron-status +++ b/bin/informational/bibliotheca-raw-patron-status @@ -7,9 +7,9 @@ bin_dir = os.path.split(__file__)[0] package_dir = os.path.join(bin_dir, "..", "..") sys.path.append(os.path.abspath(package_dir)) -from api.bibliotheca import BibliothecaAPI # noqa: E402 -from core.model import (Collection, ExternalIntegration, Patron) # noqa: E402 -from core.scripts import Script # noqa: E402 +from api.bibliotheca import BibliothecaAPI # noqa: E402 +from core.model import Collection, ExternalIntegration, Patron # noqa: E402 +from core.scripts import Script # noqa: E402 class BibliothecaRawPatronStatusScript(Script): diff --git a/bin/informational/disappearing_books b/bin/informational/disappearing_books index ba4596153f..80810313ad 100755 --- a/bin/informational/disappearing_books +++ b/bin/informational/disappearing_books @@ -7,6 +7,6 @@ bin_dir = os.path.split(__file__)[0] package_dir = os.path.join(bin_dir, "..", "..") sys.path.append(os.path.abspath(package_dir)) -from scripts import DisappearingBookReportScript # noqa: E402 +from scripts import DisappearingBookReportScript # noqa: E402 DisappearingBookReportScript().run() diff --git a/bin/informational/explain b/bin/informational/explain index 97d27b8e5d..eeb0f07d8c 100755 --- a/bin/informational/explain +++ b/bin/informational/explain @@ -7,6 +7,6 @@ bin_dir = os.path.split(__file__)[0] package_dir = os.path.join(bin_dir, "..", "..") sys.path.append(os.path.abspath(package_dir)) -from core.scripts import Explain # noqa: E402 +from core.scripts import Explain # noqa: E402 Explain().run() diff --git a/bin/informational/language_list b/bin/informational/language_list index 97a6cb26c6..3dbd72cc85 100755 --- a/bin/informational/language_list +++ b/bin/informational/language_list @@ -7,6 +7,6 @@ bin_dir = os.path.split(__file__)[0] package_dir = os.path.join(bin_dir, "..", "..") sys.path.append(os.path.abspath(package_dir)) -from scripts import LanguageListScript # noqa: E402 +from scripts import LanguageListScript # noqa: E402 LanguageListScript().run() diff --git a/bin/informational/list_collection_metadata_identifiers b/bin/informational/list_collection_metadata_identifiers index 0b77ee65dd..2ccd7bc96a 100755 --- a/bin/informational/list_collection_metadata_identifiers +++ b/bin/informational/list_collection_metadata_identifiers @@ -6,6 +6,6 @@ bin_dir = os.path.split(__file__)[0] package_dir = os.path.join(bin_dir, "..", "..") sys.path.append(os.path.abspath(package_dir)) -from core.scripts import ListCollectionMetadataIdentifiersScript # noqa: E402 +from core.scripts import ListCollectionMetadataIdentifiersScript # noqa: E402 ListCollectionMetadataIdentifiersScript().run() diff --git a/bin/informational/overdrive-advantage-list b/bin/informational/overdrive-advantage-list index 39c9563c17..f4bc9b92c6 100755 --- a/bin/informational/overdrive-advantage-list +++ b/bin/informational/overdrive-advantage-list @@ -6,6 +6,6 @@ bin_dir = os.path.split(__file__)[0] package_dir = os.path.join(bin_dir, "..", "..") sys.path.append(os.path.abspath(package_dir)) -from api.overdrive import OverdriveAdvantageAccountListScript # noqa: E402 +from api.overdrive import OverdriveAdvantageAccountListScript # noqa: E402 OverdriveAdvantageAccountListScript().run() diff --git a/bin/informational/overdrive-raw-bibliographic b/bin/informational/overdrive-raw-bibliographic index 086082b8aa..3592dbccd1 100755 --- a/bin/informational/overdrive-raw-bibliographic +++ b/bin/informational/overdrive-raw-bibliographic @@ -7,9 +7,9 @@ bin_dir = os.path.split(__file__)[0] package_dir = os.path.join(bin_dir, "..", "..") sys.path.append(os.path.abspath(package_dir)) -from api.overdrive import OverdriveAPI # noqa: E402 -from core.model import (Collection, ExternalIntegration) # noqa: E402 -from core.scripts import IdentifierInputScript # noqa: E402 +from api.overdrive import OverdriveAPI # noqa: E402 +from core.model import Collection, ExternalIntegration # noqa: E402 +from core.scripts import IdentifierInputScript # noqa: E402 class OverdriveRawBibliographicScript(IdentifierInputScript): diff --git a/bin/informational/overdrive-raw-circulation b/bin/informational/overdrive-raw-circulation index 23f9074d96..82d8a33f75 100755 --- a/bin/informational/overdrive-raw-circulation +++ b/bin/informational/overdrive-raw-circulation @@ -7,9 +7,9 @@ bin_dir = os.path.split(__file__)[0] package_dir = os.path.join(bin_dir, "..", "..") sys.path.append(os.path.abspath(package_dir)) -from api.overdrive import OverdriveAPI # noqa: E402 -from core.scripts import IdentifierInputScript # noqa: E402 -from core.model import (Collection, ExternalIntegration) # noqa: E402 +from api.overdrive import OverdriveAPI # noqa: E402 +from core.model import Collection, ExternalIntegration # noqa: E402 +from core.scripts import IdentifierInputScript # noqa: E402 class OverdriveRawCirculationScript(IdentifierInputScript): diff --git a/bin/informational/patron_information b/bin/informational/patron_information index a5f6bcca2c..fc8ac7b341 100755 --- a/bin/informational/patron_information +++ b/bin/informational/patron_information @@ -9,7 +9,7 @@ package_dir = os.path.join(bin_dir, "..", "..") sys.path.append(os.path.abspath(package_dir)) from api.authenticator import LibraryAuthenticator, PatronData # noqa: E402 -from core.scripts import LibraryInputScript # noqa: E402 +from core.scripts import LibraryInputScript # noqa: E402 class PatronInformationScript(LibraryInputScript): diff --git a/bin/informational/run-self-tests b/bin/informational/run-self-tests index dedec5067f..059c4c8eb0 100755 --- a/bin/informational/run-self-tests +++ b/bin/informational/run-self-tests @@ -7,6 +7,6 @@ bin_dir = os.path.split(__file__)[0] package_dir = os.path.join(bin_dir, "..", "..") sys.path.append(os.path.abspath(package_dir)) -from api.selftest import RunSelfTestsScript # noqa: E402 +from api.selftest import RunSelfTestsScript # noqa: E402 RunSelfTestsScript().run() diff --git a/bin/local_analytics_export b/bin/local_analytics_export index 29909e5e8c..caa2b65687 100755 --- a/bin/local_analytics_export +++ b/bin/local_analytics_export @@ -2,6 +2,7 @@ """Export circulation events for a date range to a CSV file.""" import os import sys + bin_dir = os.path.split(__file__)[0] package_dir = os.path.join(bin_dir, "..") sys.path.append(os.path.abspath(package_dir)) diff --git a/bin/marc_record_coverage b/bin/marc_record_coverage index bc99d4e654..45a6c79920 100755 --- a/bin/marc_record_coverage +++ b/bin/marc_record_coverage @@ -2,6 +2,7 @@ """Make sure all presentation-ready works have up-to-date MARC records.""" import os import sys + bin_dir = os.path.split(__file__)[0] package_dir = os.path.join(bin_dir, "..") sys.path.append(os.path.abspath(package_dir)) diff --git a/bin/metadata_upload_coverage b/bin/metadata_upload_coverage index d66b12ccb3..17af78f2ad 100755 --- a/bin/metadata_upload_coverage +++ b/bin/metadata_upload_coverage @@ -2,6 +2,7 @@ """Upload information to the metadata wrangler.""" import os import sys + bin_dir = os.path.split(__file__)[0] package_dir = os.path.join(bin_dir, "..") sys.path.append(os.path.abspath(package_dir)) diff --git a/bin/metadata_wrangler_auxiliary_metadata b/bin/metadata_wrangler_auxiliary_metadata index 76e05e6b20..62c06cd22f 100755 --- a/bin/metadata_wrangler_auxiliary_metadata +++ b/bin/metadata_wrangler_auxiliary_metadata @@ -2,10 +2,12 @@ """Monitor metadata requests from the Metadata Wrangler remote collection.""" import os import sys + bin_dir = os.path.split(__file__)[0] package_dir = os.path.join(bin_dir, "..") sys.path.append(os.path.abspath(package_dir)) -from core.scripts import RunCollectionMonitorScript from api.metadata_wrangler import MWAuxiliaryMetadataMonitor +from core.scripts import RunCollectionMonitorScript + RunCollectionMonitorScript(MWAuxiliaryMetadataMonitor).run() diff --git a/bin/metadata_wrangler_collection_reaper b/bin/metadata_wrangler_collection_reaper index 07f5b89f27..4518184a38 100755 --- a/bin/metadata_wrangler_collection_reaper +++ b/bin/metadata_wrangler_collection_reaper @@ -2,6 +2,7 @@ """Remove unlicensed items from the remote metadata wrangler Collection.""" import os import sys + bin_dir = os.path.split(__file__)[0] package_dir = os.path.join(bin_dir, "..") sys.path.append(os.path.abspath(package_dir)) diff --git a/bin/metadata_wrangler_collection_registrar b/bin/metadata_wrangler_collection_registrar index 3f6afa9df9..c91ec08f77 100755 --- a/bin/metadata_wrangler_collection_registrar +++ b/bin/metadata_wrangler_collection_registrar @@ -6,6 +6,7 @@ metadata_wrangler_collection_sync is a deprecated name for the same script. """ import os import sys + bin_dir = os.path.split(__file__)[0] package_dir = os.path.join(bin_dir, "..") sys.path.append(os.path.abspath(package_dir)) diff --git a/bin/metadata_wrangler_collection_updates b/bin/metadata_wrangler_collection_updates index a0d5bf3465..ebfac5c240 100755 --- a/bin/metadata_wrangler_collection_updates +++ b/bin/metadata_wrangler_collection_updates @@ -2,9 +2,11 @@ """Monitor bibliographic updates to the Metadata Wrangler remote collection.""" import os import sys + bin_dir = os.path.split(__file__)[0] package_dir = os.path.join(bin_dir, "..") sys.path.append(os.path.abspath(package_dir)) -from core.scripts import RunCollectionMonitorScript from api.metadata_wrangler import MWCollectionUpdateMonitor +from core.scripts import RunCollectionMonitorScript + RunCollectionMonitorScript(MWCollectionUpdateMonitor).run() diff --git a/bin/novelist_update b/bin/novelist_update index de6f80fb66..537a88896a 100755 --- a/bin/novelist_update +++ b/bin/novelist_update @@ -2,10 +2,10 @@ """Get all ISBNs for all collections in a library and send to NoveList.""" import os import sys + bin_dir = os.path.split(__file__)[0] package_dir = os.path.join(bin_dir, "..") sys.path.append(os.path.abspath(package_dir)) -from scripts import ( - NovelistSnapshotScript -) +from scripts import NovelistSnapshotScript + NovelistSnapshotScript().run() diff --git a/bin/odilo_monitor_recent b/bin/odilo_monitor_recent index fccd862cfc..42846ae61a 100755 --- a/bin/odilo_monitor_recent +++ b/bin/odilo_monitor_recent @@ -7,7 +7,7 @@ bin_dir = os.path.split(__file__)[0] package_dir = os.path.join(bin_dir, "..") sys.path.append(os.path.abspath(package_dir)) -from core.scripts import RunCollectionMonitorScript from api.odilo import OdiloCirculationMonitor +from core.scripts import RunCollectionMonitorScript RunCollectionMonitorScript(OdiloCirculationMonitor).run() diff --git a/bin/odl2_reaper b/bin/odl2_reaper index a2103064a9..be16eb47c8 100755 --- a/bin/odl2_reaper +++ b/bin/odl2_reaper @@ -2,9 +2,11 @@ """Remove all expired licenses from ODL 2.x collections.""" import os import sys + bin_dir = os.path.split(__file__)[0] package_dir = os.path.join(bin_dir, "..") sys.path.append(os.path.abspath(package_dir)) -from core.scripts import RunCollectionMonitorScript from api.odl2 import ODL2ExpiredItemsReaper +from core.scripts import RunCollectionMonitorScript + RunCollectionMonitorScript(ODL2ExpiredItemsReaper).run() diff --git a/bin/odl_hold_reaper b/bin/odl_hold_reaper index 0331dd30b6..053c31c3f6 100755 --- a/bin/odl_hold_reaper +++ b/bin/odl_hold_reaper @@ -2,9 +2,11 @@ """Check for ODL holds that have expired and delete them.""" import os import sys + bin_dir = os.path.split(__file__)[0] package_dir = os.path.join(bin_dir, "..") sys.path.append(os.path.abspath(package_dir)) -from core.scripts import RunCollectionMonitorScript from api.odl import ODLHoldReaper +from core.scripts import RunCollectionMonitorScript + RunCollectionMonitorScript(ODLHoldReaper).run() diff --git a/bin/odl_reaper b/bin/odl_reaper index ee09a1d3bc..58841f6962 100755 --- a/bin/odl_reaper +++ b/bin/odl_reaper @@ -2,9 +2,11 @@ """Remove all expired licenses from ODL 1.x collections.""" import os import sys + bin_dir = os.path.split(__file__)[0] package_dir = os.path.join(bin_dir, "..") sys.path.append(os.path.abspath(package_dir)) -from core.scripts import RunCollectionMonitorScript from api.odl import ODLExpiredItemsReaper +from core.scripts import RunCollectionMonitorScript + RunCollectionMonitorScript(ODLExpiredItemsReaper).run() diff --git a/bin/opds_entry_coverage b/bin/opds_entry_coverage index 60dc6ecb19..6126b34900 100755 --- a/bin/opds_entry_coverage +++ b/bin/opds_entry_coverage @@ -2,6 +2,7 @@ """Make sure all presentation-ready works have up-to-date OPDS entries.""" import os import sys + bin_dir = os.path.split(__file__)[0] package_dir = os.path.join(bin_dir, "..") sys.path.append(os.path.abspath(package_dir)) diff --git a/bin/opds_for_distributors_import_monitor b/bin/opds_for_distributors_import_monitor index ce3654ed45..1ee1beb70a 100755 --- a/bin/opds_for_distributors_import_monitor +++ b/bin/opds_for_distributors_import_monitor @@ -3,8 +3,10 @@ OPDS import collections that have authentication.""" import os import sys + bin_dir = os.path.split(__file__)[0] package_dir = os.path.join(bin_dir, "..") sys.path.append(os.path.abspath(package_dir)) from scripts import OPDSForDistributorsImportScript + OPDSForDistributorsImportScript().run() diff --git a/bin/opds_for_distributors_reaper_monitor b/bin/opds_for_distributors_reaper_monitor index a2260af0f6..c638b3aa37 100755 --- a/bin/opds_for_distributors_reaper_monitor +++ b/bin/opds_for_distributors_reaper_monitor @@ -3,8 +3,10 @@ have been removed from OPDS for distributors collections.""" import os import sys + bin_dir = os.path.split(__file__)[0] package_dir = os.path.join(bin_dir, "..") sys.path.append(os.path.abspath(package_dir)) from scripts import OPDSForDistributorsReaperScript + OPDSForDistributorsReaperScript().run() diff --git a/bin/overdrive_format_sweep b/bin/overdrive_format_sweep index 386a2d0623..3db171d94a 100755 --- a/bin/overdrive_format_sweep +++ b/bin/overdrive_format_sweep @@ -2,9 +2,11 @@ """Sweep through our Overdrive collections updating delivery mechanisms.""" import os import sys + bin_dir = os.path.split(__file__)[0] package_dir = os.path.join(bin_dir, "..") sys.path.append(os.path.abspath(package_dir)) -from core.scripts import RunCollectionMonitorScript from api.overdrive import OverdriveFormatSweep +from core.scripts import RunCollectionMonitorScript + RunCollectionMonitorScript(OverdriveFormatSweep).run() diff --git a/bin/overdrive_monitor_recent b/bin/overdrive_monitor_recent index e371dc9297..22089190f0 100755 --- a/bin/overdrive_monitor_recent +++ b/bin/overdrive_monitor_recent @@ -2,9 +2,11 @@ """Monitor the Overdrive collections by going through the recently changed list.""" import os import sys + bin_dir = os.path.split(__file__)[0] package_dir = os.path.join(bin_dir, "..") sys.path.append(os.path.abspath(package_dir)) -from core.scripts import RunCollectionMonitorScript from api.overdrive import RecentOverdriveCollectionMonitor +from core.scripts import RunCollectionMonitorScript + RunCollectionMonitorScript(RecentOverdriveCollectionMonitor).run() diff --git a/bin/overdrive_monitor_search b/bin/overdrive_monitor_search index c0e88fa51e..def32fa926 100755 --- a/bin/overdrive_monitor_search +++ b/bin/overdrive_monitor_search @@ -2,9 +2,11 @@ """Monitor the Overdrive collections using Overdrive's search feature""" import os import sys + bin_dir = os.path.split(__file__)[0] package_dir = os.path.join(bin_dir, "..") sys.path.append(os.path.abspath(package_dir)) -from core.scripts import RunCollectionMonitorScript from api.overdrive import OverdriveCirculationMonitor +from core.scripts import RunCollectionMonitorScript + RunCollectionMonitorScript(OverdriveCirculationMonitor).run() diff --git a/bin/overdrive_new_titles b/bin/overdrive_new_titles index 4796e61632..a133fdcc60 100755 --- a/bin/overdrive_new_titles +++ b/bin/overdrive_new_titles @@ -2,9 +2,11 @@ """Look for new titles added to Overdrive collections which slipped through the cracks.""" import os import sys + bin_dir = os.path.split(__file__)[0] package_dir = os.path.join(bin_dir, "..") sys.path.append(os.path.abspath(package_dir)) -from core.scripts import RunCollectionMonitorScript from api.overdrive import NewTitlesOverdriveCollectionMonitor +from core.scripts import RunCollectionMonitorScript + RunCollectionMonitorScript(NewTitlesOverdriveCollectionMonitor).run() diff --git a/bin/overdrive_reaper b/bin/overdrive_reaper index 92ef34e58d..d7fbf539d1 100755 --- a/bin/overdrive_reaper +++ b/bin/overdrive_reaper @@ -2,9 +2,11 @@ """Monitor the Overdrive collections by looking for books with lost licenses.""" import os import sys + bin_dir = os.path.split(__file__)[0] package_dir = os.path.join(bin_dir, "..") sys.path.append(os.path.abspath(package_dir)) -from core.scripts import RunCollectionMonitorScript from api.overdrive import OverdriveCollectionReaper +from core.scripts import RunCollectionMonitorScript + RunCollectionMonitorScript(OverdriveCollectionReaper).run() diff --git a/bin/repair/add_classification b/bin/repair/add_classification index ac2e2ecd23..b7b8830163 100755 --- a/bin/repair/add_classification +++ b/bin/repair/add_classification @@ -9,8 +9,10 @@ bin/repair/add_classification --identifier-type="Bibliotheca ID" --subject-type= """ import os import sys + bin_dir = os.path.split(__file__)[0] package_dir = os.path.join(bin_dir, "..", "..") sys.path.append(os.path.abspath(package_dir)) from core.scripts import AddClassificationScript + AddClassificationScript().run() diff --git a/bin/repair/adobe_account_id_reset b/bin/repair/adobe_account_id_reset index c0152a9553..47671ad4e2 100755 --- a/bin/repair/adobe_account_id_reset +++ b/bin/repair/adobe_account_id_reset @@ -2,8 +2,10 @@ """Reset the Adobe account IDs for one or more patrons.""" import os import sys + bin_dir = os.path.split(__file__)[0] package_dir = os.path.join(bin_dir, "..", "..") sys.path.append(os.path.abspath(package_dir)) from scripts import AdobeAccountIDResetScript + AdobeAccountIDResetScript().run() diff --git a/bin/repair/availability b/bin/repair/availability index a07c293a40..1115b89f9d 100755 --- a/bin/repair/availability +++ b/bin/repair/availability @@ -2,9 +2,11 @@ """Refresh availability information for one or more specific books.""" import os import sys + bin_dir = os.path.split(__file__)[0] package_dir = os.path.join(bin_dir, "..", "..") sys.path.append(os.path.abspath(package_dir)) from scripts import AvailabilityRefreshScript + AvailabilityRefreshScript().run() diff --git a/bin/repair/axis_bibliographic_coverage b/bin/repair/axis_bibliographic_coverage index 305c83158f..4d114a5115 100755 --- a/bin/repair/axis_bibliographic_coverage +++ b/bin/repair/axis_bibliographic_coverage @@ -2,9 +2,11 @@ """Make sure all Axis 360 books have bibliographic coverage.""" import os import sys + bin_dir = os.path.split(__file__)[0] package_dir = os.path.join(bin_dir, "..", "..") sys.path.append(os.path.abspath(package_dir)) -from core.scripts import RunCoverageProviderScript from core.axis import Axis360BibliographicCoverageProvider +from core.scripts import RunCoverageProviderScript + RunCoverageProviderScript(Axis360BibliographicCoverageProvider).run() diff --git a/bin/repair/bibliotheca_bibliographic_coverage b/bin/repair/bibliotheca_bibliographic_coverage index e90b20d72e..c6d3717488 100755 --- a/bin/repair/bibliotheca_bibliographic_coverage +++ b/bin/repair/bibliotheca_bibliographic_coverage @@ -2,9 +2,12 @@ """Make sure all Bibliotheca books have bibliographic coverage.""" import os import sys + bin_dir = os.path.split(__file__)[0] package_dir = os.path.join(bin_dir, "..", "..") sys.path.append(os.path.abspath(package_dir)) -from core.scripts import RunCollectionCoverageProviderScript from bibliotheca import BibliothecaBibliographicCoverageProvider + +from core.scripts import RunCollectionCoverageProviderScript + RunCollectionCoverageProviderScript(BibliothecaBibliographicCoverageProvider).run() diff --git a/bin/repair/create_edition_for_identifiers_missing_edition b/bin/repair/create_edition_for_identifiers_missing_edition index 1d3b2d5458..d343341377 100755 --- a/bin/repair/create_edition_for_identifiers_missing_edition +++ b/bin/repair/create_edition_for_identifiers_missing_edition @@ -2,8 +2,10 @@ """Use the metadata wrangler to create Editions for identifiers that lack them.""" import os import sys + bin_dir = os.path.split(__file__)[0] package_dir = os.path.join(bin_dir, "..", "..") sys.path.append(os.path.abspath(package_dir)) from scripts import CreateWorksForIdentifiersScript + CreateWorksForIdentifiersScript().run() diff --git a/bin/repair/mirror_resources b/bin/repair/mirror_resources index 8a4154a64e..3613da8de8 100755 --- a/bin/repair/mirror_resources +++ b/bin/repair/mirror_resources @@ -2,8 +2,10 @@ """Mirror resources that haven't been mirrored yet.""" import os import sys + bin_dir = os.path.split(__file__)[0] package_dir = os.path.join(bin_dir, "..", "..") sys.path.append(os.path.abspath(package_dir)) from core.scripts import MirrorResourcesScript + MirrorResourcesScript().run() diff --git a/bin/repair/opds_entries b/bin/repair/opds_entries index 8fc5ab666b..c095e9ad5c 100755 --- a/bin/repair/opds_entries +++ b/bin/repair/opds_entries @@ -2,11 +2,11 @@ """Ensure that all presentation-ready works have an up-to-date OPDS feed.""" import os import sys + bin_dir = os.path.split(__file__)[0] package_dir = os.path.join(bin_dir, "..", "..") sys.path.append(os.path.abspath(package_dir)) -from core.monitor import ( - OPDSEntryCacheMonitor, -) +from core.monitor import OPDSEntryCacheMonitor from core.scripts import RunMonitorScript + RunMonitorScript(OPDSEntryCacheMonitor).run() diff --git a/bin/repair/overdrive_bibliographic_coverage b/bin/repair/overdrive_bibliographic_coverage index 2298854375..9c4218134e 100755 --- a/bin/repair/overdrive_bibliographic_coverage +++ b/bin/repair/overdrive_bibliographic_coverage @@ -2,9 +2,11 @@ """Make sure all Overdrive books have bibliographic coverage.""" import os import sys + bin_dir = os.path.split(__file__)[0] package_dir = os.path.join(bin_dir, "..", "..") sys.path.append(os.path.abspath(package_dir)) -from core.scripts import RunCoverageProviderScript from core.overdrive import OverdriveBibliographicCoverageProvider +from core.scripts import RunCoverageProviderScript + RunCoverageProviderScript(OverdriveBibliographicCoverageProvider).run() diff --git a/bin/repair/permanent_work_id b/bin/repair/permanent_work_id index da82945bf2..8f100e24cb 100755 --- a/bin/repair/permanent_work_id +++ b/bin/repair/permanent_work_id @@ -2,9 +2,11 @@ """Recalculate all permanent work IDs.""" import os import sys + bin_dir = os.path.split(__file__)[0] package_dir = os.path.join(bin_dir, "..", "..") sys.path.append(os.path.abspath(package_dir)) from core.monitor import PermanentWorkIDRefreshMonitor from core.scripts import RunMonitorScript + RunMonitorScript(PermanentWorkIDRefreshMonitor).run() diff --git a/bin/repair/reset_lanes b/bin/repair/reset_lanes index 2375cdca93..3c46b10a97 100755 --- a/bin/repair/reset_lanes +++ b/bin/repair/reset_lanes @@ -2,8 +2,10 @@ """Regenerate the lanes for a library.""" import os import sys + bin_dir = os.path.split(__file__)[0] package_dir = os.path.join(bin_dir, "..", "..") sys.path.append(os.path.abspath(package_dir)) from scripts import LaneResetScript + LaneResetScript().run() diff --git a/bin/repair/search_index b/bin/repair/search_index index b61a9a6f75..81f6930b4e 100755 --- a/bin/repair/search_index +++ b/bin/repair/search_index @@ -2,8 +2,10 @@ """Delete the search index for all works and recreate it.""" import os import sys + bin_dir = os.path.split(__file__)[0] package_dir = os.path.join(bin_dir, "..", "..") sys.path.append(os.path.abspath(package_dir)) from core.scripts import RebuildSearchIndexScript + RebuildSearchIndexScript().run() diff --git a/bin/repair/set_sort_author_where_missing b/bin/repair/set_sort_author_where_missing index 62cfdcbdc1..4696554c75 100755 --- a/bin/repair/set_sort_author_where_missing +++ b/bin/repair/set_sort_author_where_missing @@ -7,8 +7,10 @@ regularly. """ import os import sys + bin_dir = os.path.split(__file__)[0] package_dir = os.path.join(bin_dir, "..", "..") sys.path.append(os.path.abspath(package_dir)) from scripts import FillInAuthorScript + FillInAuthorScript().run() diff --git a/bin/repair/where_are_my_books b/bin/repair/where_are_my_books index 5ba27beaff..4b61b58716 100755 --- a/bin/repair/where_are_my_books +++ b/bin/repair/where_are_my_books @@ -2,8 +2,10 @@ """Try to figure out why Works are not in the system following an import.""" import os import sys + bin_dir = os.path.split(__file__)[0] package_dir = os.path.join(bin_dir, "..", "..") sys.path.append(os.path.abspath(package_dir)) from core.scripts import WhereAreMyBooksScript + WhereAreMyBooksScript().run() diff --git a/bin/repair/work_classification b/bin/repair/work_classification index ac75de69c3..9576ddcb23 100755 --- a/bin/repair/work_classification +++ b/bin/repair/work_classification @@ -2,8 +2,10 @@ """Classify or reclassify works.""" import os import sys + bin_dir = os.path.split(__file__)[0] package_dir = os.path.join(bin_dir, "..", "..") sys.path.append(os.path.abspath(package_dir)) from core.scripts import WorkClassificationScript + WorkClassificationScript().run() diff --git a/bin/repair/work_consolidation b/bin/repair/work_consolidation index ddee9c1e06..238e31cfa9 100755 --- a/bin/repair/work_consolidation +++ b/bin/repair/work_consolidation @@ -2,8 +2,10 @@ """Recalculate works for certain license pools.""" import os import sys + bin_dir = os.path.split(__file__)[0] package_dir = os.path.join(bin_dir, "..", "..") sys.path.append(os.path.abspath(package_dir)) from core.scripts import WorkConsolidationScript + WorkConsolidationScript(force=False).run() diff --git a/bin/repair/work_opds b/bin/repair/work_opds index 99297b77bc..8a54f07758 100755 --- a/bin/repair/work_opds +++ b/bin/repair/work_opds @@ -2,8 +2,10 @@ """Refresh a work's OPDS feeds.""" import os import sys + bin_dir = os.path.split(__file__)[0] package_dir = os.path.join(bin_dir, "..", "..") sys.path.append(os.path.abspath(package_dir)) from core.scripts import WorkOPDSScript + WorkOPDSScript().run() diff --git a/bin/repair/work_recalculate_presentation b/bin/repair/work_recalculate_presentation index 98b417d7c3..45bacedc68 100755 --- a/bin/repair/work_recalculate_presentation +++ b/bin/repair/work_recalculate_presentation @@ -2,8 +2,10 @@ """Recalculate works presentation.""" import os import sys + bin_dir = os.path.split(__file__)[0] package_dir = os.path.join(bin_dir, "..", "..") sys.path.append(os.path.abspath(package_dir)) from core.scripts import WorkPresentationScript + WorkPresentationScript().run() diff --git a/bin/search_index_clear b/bin/search_index_clear index c220b513c8..a9a0794c5c 100755 --- a/bin/search_index_clear +++ b/bin/search_index_clear @@ -6,8 +6,10 @@ providing automatic recovery from bugs and major metadata changes. """ import os import sys + bin_dir = os.path.split(__file__)[0] package_dir = os.path.join(bin_dir, "..") sys.path.append(os.path.abspath(package_dir)) from core.scripts import SearchIndexCoverageRemover + SearchIndexCoverageRemover().run() diff --git a/bin/search_index_refresh b/bin/search_index_refresh index dca1f10e16..f0dfb2a865 100755 --- a/bin/search_index_refresh +++ b/bin/search_index_refresh @@ -4,9 +4,11 @@ out of date. """ import os import sys + bin_dir = os.path.split(__file__)[0] package_dir = os.path.join(bin_dir, "..") sys.path.append(os.path.abspath(package_dir)) from core.external_search import SearchIndexCoverageProvider from core.scripts import RunWorkCoverageProviderScript + RunWorkCoverageProviderScript(SearchIndexCoverageProvider).run() diff --git a/bin/shared_odl_import_monitor b/bin/shared_odl_import_monitor index e13e239951..ff4dc6988d 100755 --- a/bin/shared_odl_import_monitor +++ b/bin/shared_odl_import_monitor @@ -3,8 +3,10 @@ shared ODL collections.""" import os import sys + bin_dir = os.path.split(__file__)[0] package_dir = os.path.join(bin_dir, "..") sys.path.append(os.path.abspath(package_dir)) from scripts import SharedODLImportScript + SharedODLImportScript().run() diff --git a/bin/subjects_assign b/bin/subjects_assign index d0665b39a8..661f90de0c 100755 --- a/bin/subjects_assign +++ b/bin/subjects_assign @@ -2,8 +2,10 @@ """Assign subjects to genres.""" import os import sys + bin_dir = os.path.split(__file__)[0] package_dir = os.path.join(bin_dir, "..") sys.path.append(os.path.abspath(package_dir)) from core.scripts import SubjectAssignmentScript + SubjectAssignmentScript().run() diff --git a/bin/update_custom_list_size b/bin/update_custom_list_size index cc268c3278..cd8519c0dd 100755 --- a/bin/update_custom_list_size +++ b/bin/update_custom_list_size @@ -2,8 +2,10 @@ """Update the cached sizes of all custom lists.""" import os import sys + bin_dir = os.path.split(__file__)[0] package_dir = os.path.join(bin_dir, "..") sys.path.append(os.path.abspath(package_dir)) from core.scripts import UpdateCustomListSizeScript + UpdateCustomListSizeScript().run() diff --git a/bin/update_lane_size b/bin/update_lane_size index d1a3d2708c..9d8a2ec485 100755 --- a/bin/update_lane_size +++ b/bin/update_lane_size @@ -2,8 +2,10 @@ """Update the cached sizes of all lanes.""" import os import sys + bin_dir = os.path.split(__file__)[0] package_dir = os.path.join(bin_dir, "..") sys.path.append(os.path.abspath(package_dir)) from core.scripts import UpdateLaneSizeScript + UpdateLaneSizeScript().run() diff --git a/bin/update_nyt_best_seller_lists b/bin/update_nyt_best_seller_lists index 927e45b8e7..48fe7efd94 100755 --- a/bin/update_nyt_best_seller_lists +++ b/bin/update_nyt_best_seller_lists @@ -2,12 +2,11 @@ """Bring in the entire history of all NYT best-seller lists.""" import os import sys + bin_dir = os.path.split(__file__)[0] package_dir = os.path.join(bin_dir, "..") sys.path.append(os.path.abspath(package_dir)) -from scripts import ( - NYTBestSellerListsScript, -) +from scripts import NYTBestSellerListsScript include_history = ('history' in sys.argv) diff --git a/bin/update_staff_picks b/bin/update_staff_picks index 4759c393fc..a678aad520 100755 --- a/bin/update_staff_picks +++ b/bin/update_staff_picks @@ -2,8 +2,10 @@ """Update the staff picks list from a Google Drive spreadsheet.""" import os import sys + bin_dir = os.path.split(__file__)[0] package_dir = os.path.join(bin_dir, "..") sys.path.append(os.path.abspath(package_dir)) from scripts import UpdateStaffPicksScript + UpdateStaffPicksScript().run() diff --git a/bin/util/compile_translations b/bin/util/compile_translations index 13b2c5084e..26aa77575a 100755 --- a/bin/util/compile_translations +++ b/bin/util/compile_translations @@ -2,6 +2,7 @@ """Compile translations.""" import os import sys + bin_dir = os.path.split(__file__)[0] package_dir = os.path.join(bin_dir, "..", "..") sys.path.append(os.path.abspath(package_dir)) diff --git a/bin/util/custom_list_from_subject b/bin/util/custom_list_from_subject index 417f7f9d7f..6304852317 100755 --- a/bin/util/custom_list_from_subject +++ b/bin/util/custom_list_from_subject @@ -2,12 +2,13 @@ """Maintain a CustomList containing all books classified under certain subjects.""" import os import sys + bin_dir = os.path.split(__file__)[0] package_dir = os.path.join(bin_dir, "..", "..") sys.path.append(os.path.abspath(package_dir)) from core.external_list import ClassificationBasedMembershipManager -from core.scripts import CustomListManagementScript from core.model import DataSource +from core.scripts import CustomListManagementScript if len(sys.argv) < 6: print "Usage: %s [SHORT_NAME] [HUMAN_READABLE_NAME] [PRIMARY_LANGUAGE] [DESCRIPTION] [SUBJECT] [subject2] ..." % sys.argv[0] diff --git a/bin/util/generate_short_token b/bin/util/generate_short_token index 9289ee7ec9..17b758c04e 100755 --- a/bin/util/generate_short_token +++ b/bin/util/generate_short_token @@ -5,6 +5,7 @@ Generate client short tokens from cli. """ import os import sys + bin_dir = os.path.split(__file__)[0] package_dir = os.path.join(bin_dir, "..", "..") sys.path.append(os.path.abspath(package_dir)) diff --git a/bin/util/initialize_instance b/bin/util/initialize_instance index aeb8d24aa3..a44b62760a 100755 --- a/bin/util/initialize_instance +++ b/bin/util/initialize_instance @@ -2,6 +2,7 @@ """Initialize an instance of the Circulation Manager""" import os import sys + bin_dir = os.path.split(__file__)[0] package_dir = os.path.join(bin_dir, "..", "..") sys.path.append(os.path.abspath(package_dir)) diff --git a/bin/work_classification b/bin/work_classification index f212a47654..78be34f72c 100755 --- a/bin/work_classification +++ b/bin/work_classification @@ -3,9 +3,11 @@ """ import os import sys + bin_dir = os.path.split(__file__)[0] package_dir = os.path.join(bin_dir, "..") sys.path.append(os.path.abspath(package_dir)) from core.coverage import WorkClassificationCoverageProvider from core.scripts import RunWorkCoverageProviderScript + RunWorkCoverageProviderScript(WorkClassificationCoverageProvider).run() diff --git a/bin/work_classify_unchecked_subjects b/bin/work_classify_unchecked_subjects index 61f8467a1b..3e7db22655 100755 --- a/bin/work_classify_unchecked_subjects +++ b/bin/work_classify_unchecked_subjects @@ -2,8 +2,10 @@ """(Re)calculate the presentation of works associated with unchecked subjects.""" import os import sys + bin_dir = os.path.split(__file__)[0] package_dir = os.path.join(bin_dir, "..") sys.path.append(os.path.abspath(package_dir)) from core.scripts import ReclassifyWorksForUncheckedSubjectsScript + ReclassifyWorksForUncheckedSubjectsScript().run() diff --git a/bin/work_presentation_editions b/bin/work_presentation_editions index fbbf3b1d41..288fc5c2e9 100755 --- a/bin/work_presentation_editions +++ b/bin/work_presentation_editions @@ -3,9 +3,11 @@ """ import os import sys + bin_dir = os.path.split(__file__)[0] package_dir = os.path.join(bin_dir, "..") sys.path.append(os.path.abspath(package_dir)) from core.coverage import WorkPresentationEditionCoverageProvider from core.scripts import RunWorkCoverageProviderScript + RunWorkCoverageProviderScript(WorkPresentationEditionCoverageProvider).run() diff --git a/docs/source/conf.py b/docs/source/conf.py index 03fae5f64b..8f23bde809 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -8,26 +8,30 @@ # -- Path setup -------------------------------------------------------------- +import datetime + # If extensions (or modules to document with autodoc) are in another directory, # add these directories to sys.path here. If the directory is relative to the # documentation root, use os.path.abspath to make it absolute, like shown here. # import os import sys -import datetime -sys.path.insert(0, os.path.abspath('../../')) + +sys.path.insert(0, os.path.abspath("../../")) # -- Project information ----------------------------------------------------- year = datetime.datetime.now().year -project = 'Library Simplified Circulation Manager' -copyright = '%s, The New York Public Library, Astor, Lenox, and Tilden Foundations' % year -author = 'Library Simplified' +project = "Library Simplified Circulation Manager" +copyright = ( + "%s, The New York Public Library, Astor, Lenox, and Tilden Foundations" % year +) +author = "Library Simplified" # The short X.Y version -version = '' +version = "" # The full version, including alpha/beta/rc tags -release = '' +release = "" # -- General configuration --------------------------------------------------- @@ -39,23 +43,19 @@ # Add any Sphinx extension module names here, as strings. They can be # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom # ones. -extensions = [ - 'sphinx.ext.autodoc', - 'sphinx.ext.viewcode', - 'sphinx.ext.githubpages' -] +extensions = ["sphinx.ext.autodoc", "sphinx.ext.viewcode", "sphinx.ext.githubpages"] # Add any paths that contain templates here, relative to this directory. -templates_path = ['_templates'] +templates_path = ["_templates"] # The suffix(es) of source filenames. # You can specify multiple suffix as a list of string: # # source_suffix = ['.rst', '.md'] -source_suffix = '.rst' +source_suffix = ".rst" # The master toctree document. -master_doc = 'index' +master_doc = "index" # The language for content autogenerated by Sphinx. Refer to documentation # for a list of supported languages. @@ -78,7 +78,7 @@ # The theme to use for HTML and HTML Help pages. See the documentation for # a list of builtin themes. # -html_theme = 'alabaster' +html_theme = "alabaster" # Theme options are theme-specific and customize the look and feel of a theme # further. For a list of options available for each theme, see the @@ -89,7 +89,7 @@ # Add any paths that contain custom static files (such as style sheets) here, # relative to this directory. They are copied after the builtin static files, # so a file named "default.css" will overwrite the builtin "default.css". -html_static_path = ['../build/html/_static'] +html_static_path = ["../build/html/_static"] # Custom sidebar templates, must be a dictionary that maps document names # to template names. @@ -105,7 +105,7 @@ # -- Options for HTMLHelp output --------------------------------------------- # Output file base name for HTML help builder. -htmlhelp_basename = 'LibrarySimplifiedCirculationManagerdoc' +htmlhelp_basename = "LibrarySimplifiedCirculationManagerdoc" # -- Options for LaTeX output ------------------------------------------------ @@ -114,15 +114,12 @@ # The paper size ('letterpaper' or 'a4paper'). # # 'papersize': 'letterpaper', - # The font size ('10pt', '11pt' or '12pt'). # # 'pointsize': '10pt', - # Additional stuff for the LaTeX preamble. # # 'preamble': '', - # Latex figure (float) alignment # # 'figure_align': 'htbp', @@ -132,9 +129,13 @@ # (source start file, target name, title, # author, documentclass [howto, manual, or own class]). latex_documents = [ - (master_doc, 'LibrarySimplifiedCirculationManager.tex', - 'Library Simplified Circulation Manager Documentation', - 'Library Simplified', 'manual'), + ( + master_doc, + "LibrarySimplifiedCirculationManager.tex", + "Library Simplified Circulation Manager Documentation", + "Library Simplified", + "manual", + ), ] @@ -143,9 +144,13 @@ # One entry per manual page. List of tuples # (source start file, name, description, authors, manual section). man_pages = [ - (master_doc, 'librarysimplifiedcirculationmanager', - 'Library Simplified Circulation Manager Documentation', - [author], 1) + ( + master_doc, + "librarysimplifiedcirculationmanager", + "Library Simplified Circulation Manager Documentation", + [author], + 1, + ) ] @@ -155,10 +160,15 @@ # (source start file, target name, title, author, # dir menu entry, description, category) texinfo_documents = [ - (master_doc, 'LibrarySimplifiedCirculationManager', - 'Library Simplified Circulation Manager Documentation', - author, 'LibrarySimplifiedCirculationManager', 'One line description of project.', - 'Miscellaneous'), + ( + master_doc, + "LibrarySimplifiedCirculationManager", + "Library Simplified Circulation Manager Documentation", + author, + "LibrarySimplifiedCirculationManager", + "One line description of project.", + "Miscellaneous", + ), ] @@ -177,7 +187,7 @@ # epub_uid = '' # A list of files that should not be packed into the epub file. -epub_exclude_files = ['search.html'] +epub_exclude_files = ["search.html"] # -- Extension configuration ------------------------------------------------- diff --git a/integration_tests/benchmark_feed_queries.py b/integration_tests/benchmark_feed_queries.py index ceac5b9630..ba17da69cb 100644 --- a/integration_tests/benchmark_feed_queries.py +++ b/integration_tests/benchmark_feed_queries.py @@ -1,16 +1,15 @@ # encoding: utf-8 -from pdb import set_trace import random import time -import numpy +from pdb import set_trace from threading import Thread -from urllib.parse import urlencode, quote +from urllib.parse import quote, urlencode -import random +import numpy import requests -class QueryTimingThread(Thread): +class QueryTimingThread(Thread): def __init__(self, urls): Thread.__init__(self) self.urls = urls @@ -21,7 +20,7 @@ def run(self): for url in self.urls: a = time.time() exception = self.do_query(url) - self.elapsed.append(time.time()-a) + self.elapsed.append(time.time() - a) if exception: self.exceptions.append((url, exception)) @@ -50,6 +49,7 @@ def report(self): print("Exception: %s: %s" % (url, e)) print("") + size = 50 pages = 10 thread_count = 10 @@ -57,229 +57,150 @@ def report(self): queries = [ { - 'language': 'eng', - 'category': 'Adult Fiction', - 'params': { - 'order': 'author', - 'available': 'now', - 'collection': 'full' - } + "language": "eng", + "category": "Adult Fiction", + "params": {"order": "author", "available": "now", "collection": "full"}, }, { - 'language': 'eng', - 'category': 'Adult Fiction', - 'params': { - 'order': 'title', - 'available': 'all', - 'collection': 'main' - } + "language": "eng", + "category": "Adult Fiction", + "params": {"order": "title", "available": "all", "collection": "main"}, }, { - 'language': 'eng', - 'category': 'Adult Nonfiction', - 'params': { - 'order': 'author', - 'available': 'now', - 'collection': 'main' - } + "language": "eng", + "category": "Adult Nonfiction", + "params": {"order": "author", "available": "now", "collection": "main"}, }, { - 'language': 'eng', - 'category': 'Adult Nonfiction', - 'params': { - 'order': 'title', - 'available': 'all', - 'collection': 'featured' - } + "language": "eng", + "category": "Adult Nonfiction", + "params": {"order": "title", "available": "all", "collection": "featured"}, }, { - 'language': 'eng', - 'category': 'English Best Sellers', - 'params': { - 'order': 'author', - 'available': 'all', - 'collection': 'featured' - } + "language": "eng", + "category": "English Best Sellers", + "params": {"order": "author", "available": "all", "collection": "featured"}, }, { - 'language': 'eng', - 'category': 'Young Adult Fiction', - 'params': { - 'order': 'added', - 'available': 'all', - 'collection': 'main' - } + "language": "eng", + "category": "Young Adult Fiction", + "params": {"order": "added", "available": "all", "collection": "main"}, }, { - 'language': 'eng', - 'category': 'Children and Middle Grade', - 'params': { - 'order': 'author', - 'available': 'now', - 'collection': 'featured' - } + "language": "eng", + "category": "Children and Middle Grade", + "params": {"order": "author", "available": "now", "collection": "featured"}, }, { - 'language': 'eng', - 'category': 'Adventure', - 'params': { - 'order': 'author', - 'available': 'main', - 'collection': 'featured' - } + "language": "eng", + "category": "Adventure", + "params": {"order": "author", "available": "main", "collection": "featured"}, }, { - 'language': 'eng', - 'category': 'Classics', - 'params': { - 'order': 'title', - 'available': 'now', - 'collection': 'full' - } + "language": "eng", + "category": "Classics", + "params": {"order": "title", "available": "now", "collection": "full"}, }, { - 'language': 'eng', - 'category': 'Police Procedural', - 'params': { - 'order': 'title', - 'available': 'now', - 'collection': 'featured' - } + "language": "eng", + "category": "Police Procedural", + "params": {"order": "title", "available": "now", "collection": "featured"}, }, { - 'language': 'eng', - 'category': 'Biography & Memoir', - 'params': { - 'order': 'author', - 'available': 'always', - 'collection': 'main' - } + "language": "eng", + "category": "Biography & Memoir", + "params": {"order": "author", "available": "always", "collection": "main"}, }, { - 'language': 'eng', - 'category': 'Business', - 'params': { - 'order': 'added', - 'available': 'now', - 'collection': 'full' - } + "language": "eng", + "category": "Business", + "params": {"order": "added", "available": "now", "collection": "full"}, }, { - 'language': 'eng', - 'category': 'Parenting & Family', - 'params': { - 'order': 'author', - 'available': 'all', - 'collection': 'featured' - } + "language": "eng", + "category": "Parenting & Family", + "params": {"order": "author", "available": "all", "collection": "featured"}, }, { - 'language': 'eng', - 'category': 'Cooking', - 'params': { - 'order': 'title', - 'available': 'all', - 'collection': 'featured' - } + "language": "eng", + "category": "Cooking", + "params": {"order": "title", "available": "all", "collection": "featured"}, }, { - 'language': 'eng', - 'category': 'Latin American History', - 'params': { - 'order': 'author', - 'available': 'all', - 'collection': 'main' - } + "language": "eng", + "category": "Latin American History", + "params": {"order": "author", "available": "all", "collection": "main"}, }, { - 'language': 'eng', - 'category': 'Pets', - 'params': { - 'order': 'title', - 'available': 'now', - 'collection': 'featured' - } + "language": "eng", + "category": "Pets", + "params": {"order": "title", "available": "now", "collection": "featured"}, }, { - 'language': 'eng', - 'category': 'Photography', - 'params': { - 'order': 'author', - 'available': 'now', - 'collection': 'featured' - } + "language": "eng", + "category": "Photography", + "params": {"order": "author", "available": "now", "collection": "featured"}, }, { - 'language': 'eng', - 'category': 'Music', - 'params': { - 'order': 'added', - 'available': 'now', - 'collection': 'featured' - } + "language": "eng", + "category": "Music", + "params": {"order": "added", "available": "now", "collection": "featured"}, }, { - 'language': 'eng', - 'category': 'Life Strategies', - 'params': { - 'order': 'title', - 'available': 'all', - 'collection': 'main' - } + "language": "eng", + "category": "Life Strategies", + "params": {"order": "title", "available": "all", "collection": "main"}, }, { - 'language': 'eng', - 'category': 'Buddhism', - 'params': { - 'order': 'author', - 'available': 'all', - 'collection': 'featured' - } + "language": "eng", + "category": "Buddhism", + "params": {"order": "author", "available": "all", "collection": "featured"}, }, { - 'language': 'eng', - 'category': 'Computers', - 'params': { - 'order': 'added', - 'available': 'now', - 'collection': 'featured' - } + "language": "eng", + "category": "Computers", + "params": {"order": "added", "available": "now", "collection": "featured"}, }, { - 'language': 'eng', - 'category': 'Self Help', - 'params': { - 'order': 'author', - 'available': 'all', - 'collection': 'full' - } + "language": "eng", + "category": "Self Help", + "params": {"order": "author", "available": "all", "collection": "full"}, }, { - 'language': 'eng', - 'category': 'True Crime', - 'params': { - 'order': 'title', - 'available': 'all', - 'collection': 'full' - } - } + "language": "eng", + "category": "True Crime", + "params": {"order": "title", "available": "all", "collection": "full"}, + }, ] + def urls_from_query(query, pages, size): urls = [] for i in range(pages): if i > 0: - query['params']['after'] = i * size - url = quote("%s/feed/%s/%s?%s" % ( - base_url, query['language'], query['category'], urlencode(query['params'])), safe=':/?=&') + query["params"]["after"] = i * size + url = quote( + "%s/feed/%s/%s?%s" + % ( + base_url, + query["language"], + query["category"], + urlencode(query["params"]), + ), + safe=":/?=&", + ) urls.append(url) return urls -threads = [QueryTimingThread(urls=urls_from_query(random.choice(queries), pages, size)) for i in range(thread_count)] + +threads = [ + QueryTimingThread(urls=urls_from_query(random.choice(queries), pages, size)) + for i in range(thread_count) +] for t in threads: t.start() for t in threads: t.join() for t in threads: - t.report() \ No newline at end of file + t.report() diff --git a/integration_tests/test_borrow.py b/integration_tests/test_borrow.py index 22b22972f4..a1ea015c8f 100644 --- a/integration_tests/test_borrow.py +++ b/integration_tests/test_borrow.py @@ -1,40 +1,51 @@ +import os +import feedparser import requests from requests.auth import HTTPBasicAuth -import feedparser -import os from . import CirculationIntegrationTest -class TestBorrow(CirculationIntegrationTest): +class TestBorrow(CirculationIntegrationTest): def test_borrow(self): - if 'TEST_IDENTIFIER' in os.environ: - overdrive_id = os.environ['TEST_IDENTIFIER'] + if "TEST_IDENTIFIER" in os.environ: + overdrive_id = os.environ["TEST_IDENTIFIER"] else: # Fifty Shades of Grey has a large number of copies available overdrive_id = "82cdd641-857a-45ca-8775-34eede35b238" borrow_url = "%sworks/Overdrive/%s/borrow" % (self.url, overdrive_id) - borrow_response = requests.get(borrow_url, auth=HTTPBasicAuth(self.test_username, self.test_password)) + borrow_response = requests.get( + borrow_url, auth=HTTPBasicAuth(self.test_username, self.test_password) + ) # it's possible we already have the book borrowed, if a previous test didn't revoke it assert borrow_response.status_code in [200, 201] feed = feedparser.parse(borrow_response.text) - entries = feed['entries'] + entries = feed["entries"] eq_(1, len(entries)) entry = entries[0] - links = entry['links'] - fulfill_links = [link for link in links if link.rel == "http://opds-spec.org/acquisition"] + links = entry["links"] + fulfill_links = [ + link for link in links if link.rel == "http://opds-spec.org/acquisition" + ] assert len(fulfill_links) > 0 fulfill_url = fulfill_links[0].href - fulfill_response = requests.get(fulfill_url, auth=HTTPBasicAuth(self.test_username, self.test_password)) + fulfill_response = requests.get( + fulfill_url, auth=HTTPBasicAuth(self.test_username, self.test_password) + ) eq_(200, fulfill_response.status_code) - - revoke_links = [link for link in links if link.rel == "http://librarysimplified.org/terms/rel/revoke"] + revoke_links = [ + link + for link in links + if link.rel == "http://librarysimplified.org/terms/rel/revoke" + ] eq_(1, len(revoke_links)) revoke_url = revoke_links[0].href - revoke_response = requests.get(revoke_url, auth=HTTPBasicAuth(self.test_username, self.test_password)) + revoke_response = requests.get( + revoke_url, auth=HTTPBasicAuth(self.test_username, self.test_password) + ) eq_(200, revoke_response.status_code) diff --git a/integration_tests/test_circulation.py b/integration_tests/test_circulation.py index 9f3ee25f69..74ada23e50 100644 --- a/integration_tests/test_circulation.py +++ b/integration_tests/test_circulation.py @@ -1,34 +1,34 @@ #!/usr/bin/env python -import random -import sys import os +import random import sys + bin_dir = os.path.split(__file__)[0] package_dir = os.path.join(bin_dir, "..") sys.path.append(os.path.abspath(package_dir)) +from axis import Axis360API +from circulation_exceptions import * +from overdrive import OverdriveAPI +from threem import ThreeMAPI + +from circulation import CirculationAPI from core.model import ( - get_one_or_create, - production_session, DataSource, Identifier, - LicensePool, Patron, - ) -from threem import ThreeMAPI -from overdrive import OverdriveAPI -from axis import Axis360API - -from circulation import CirculationAPI -from circulation_exceptions import * + get_one_or_create, + production_session, +) barcode, pin, borrow_urn, hold_urn = sys.argv[1:5] -email = os.environ.get('DEFAULT_NOTIFICATION_EMAIL_ADDRESS', 'test@librarysimplified.org') +email = os.environ.get( + "DEFAULT_NOTIFICATION_EMAIL_ADDRESS", "test@librarysimplified.org" +) _db = production_session() -patron, ignore = get_one_or_create( - _db, Patron, authorization_identifier=barcode) +patron, ignore = get_one_or_create(_db, Patron, authorization_identifier=barcode) borrow_identifier = Identifier.parse_urn(_db, borrow_urn, True)[0] hold_identifier = Identifier.parse_urn(_db, hold_urn, True)[0] @@ -50,14 +50,13 @@ else: axis = None -circulation = CirculationAPI(_db, overdrive=overdrive, threem=threem, - axis=axis) +circulation = CirculationAPI(_db, overdrive=overdrive, threem=threem, axis=axis) activity = circulation.patron_activity(patron, pin) -print('-' * 80) +print("-" * 80) for i in activity: print(i) -print('-' * 80) +print("-" * 80) licensepool = borrow_pool mechanism = licensepool.delivery_mechanisms[0] @@ -100,10 +99,10 @@ print(" Exception as expected") activity = circulation.patron_activity(patron, pin) -print('-' * 80) +print("-" * 80) for i in activity: print(i) -print('-' * 80) +print("-" * 80) print("Revoke loan") print(circulation.revoke_loan(patron, pin, licensepool)) @@ -114,4 +113,3 @@ print(circulation.release_hold(patron, pin, licensepool)) print("Release nonexistent hold.") print(circulation.release_hold(patron, pin, licensepool)) - diff --git a/integration_tests/test_feed.py b/integration_tests/test_feed.py index 4e64cb0e0f..add67b26f8 100644 --- a/integration_tests/test_feed.py +++ b/integration_tests/test_feed.py @@ -1,42 +1,45 @@ - +import os from urllib.request import urlopen + import feedparser -import os from . import CirculationIntegrationTest -class TestFeed(CirculationIntegrationTest): +class TestFeed(CirculationIntegrationTest): def test_grouped_feed(self): feed_url = self.url feed = urlopen(feed_url).read() feed = feedparser.parse(str(feed)) - entries = feed['entries'] + entries = feed["entries"] assert len(entries) > 20 # spot-check an entry entry = entries[5] - assert len(entry.get('title')) > 0 - assert len(entry.get('author')) > 0 - links = entry.get('links') + assert len(entry.get("title")) > 0 + assert len(entry.get("author")) > 0 + links = entry.get("links") assert len(links) > 0 # books on the first page should be available to borrow - borrow_links = [link for link in links if link.rel == "http://opds-spec.org/acquisition/borrow"] + borrow_links = [ + link + for link in links + if link.rel == "http://opds-spec.org/acquisition/borrow" + ] eq_(1, len(borrow_links)) def test_genre_feed(self): - if 'TEST_FEED_PATH' in os.environ: - path = os.environ['TEST_FEED_PATH'] + if "TEST_FEED_PATH" in os.environ: + path = os.environ["TEST_FEED_PATH"] else: path = "eng/Romance" feed_url = "%sfeed/%s" % (self.url, path) feed = urlopen(feed_url).read() feed = feedparser.parse(str(feed)) - entries = feed['entries'] + entries = feed["entries"] assert len(entries) > 20 # spot-check an entry entry = entries[5] - assert len(entry.get('title')) > 0 - assert len(entry.get('author')) > 0 - links = entry.get('links') + assert len(entry.get("title")) > 0 + assert len(entry.get("author")) > 0 + links = entry.get("links") assert len(links) > 0 - diff --git a/integration_tests/test_hold.py b/integration_tests/test_hold.py index df75ce5811..515059fffd 100644 --- a/integration_tests/test_hold.py +++ b/integration_tests/test_hold.py @@ -1,41 +1,49 @@ +import os +import feedparser import requests from requests.auth import HTTPBasicAuth -import feedparser -import os from . import CirculationIntegrationTest -class TestHold(CirculationIntegrationTest): +class TestHold(CirculationIntegrationTest): def test_hold(self): - if 'TEST_IDENTIFIER' in os.environ: - overdrive_id = os.environ['TEST_IDENTIFIER'] + if "TEST_IDENTIFIER" in os.environ: + overdrive_id = os.environ["TEST_IDENTIFIER"] else: # Yes Please has a large hold queue overdrive_id = "0abe1ed3-f117-4b7c-a6b0-857a2e7d227b" borrow_url = "%sworks/Overdrive/%s/borrow" % (self.url, overdrive_id) - borrow_response = requests.get(borrow_url, auth=HTTPBasicAuth(self.test_username, self.test_password)) + borrow_response = requests.get( + borrow_url, auth=HTTPBasicAuth(self.test_username, self.test_password) + ) # it's possible we already have the book on hold, if a previous test didn't revoke it assert borrow_response.status_code in [200, 201] feed = feedparser.parse(borrow_response.text) - entries = feed['entries'] + entries = feed["entries"] eq_(1, len(entries)) entry = entries[0] - availability = entry['opds_availability'] - eq_("reserved", availability['status']) + availability = entry["opds_availability"] + eq_("reserved", availability["status"]) - links = entry['links'] - fulfill_links = [link for link in links if link.rel == "http://opds-spec.org/acquisition"] + links = entry["links"] + fulfill_links = [ + link for link in links if link.rel == "http://opds-spec.org/acquisition" + ] eq_(0, len(fulfill_links)) - revoke_links = [link for link in links if link.rel == "http://librarysimplified.org/terms/rel/revoke"] + revoke_links = [ + link + for link in links + if link.rel == "http://librarysimplified.org/terms/rel/revoke" + ] eq_(1, len(revoke_links)) revoke_url = revoke_links[0].href - revoke_response = requests.get(revoke_url, auth=HTTPBasicAuth(self.test_username, self.test_password)) + revoke_response = requests.get( + revoke_url, auth=HTTPBasicAuth(self.test_username, self.test_password) + ) eq_(200, revoke_response.status_code) - - diff --git a/integration_tests/test_search.py b/integration_tests/test_search.py index 652d09b1ae..6c7daeedd5 100644 --- a/integration_tests/test_search.py +++ b/integration_tests/test_search.py @@ -26,22 +26,16 @@ # # $ nosetests integration_tests/test_search.py -from functools import wraps +import json import logging -import urllib.parse - import os import re -import json -from core.model import ( - production_session, - Library, -) +import urllib.parse +from functools import wraps + +from core.external_search import ExternalSearchIndex, Filter from core.lane import Pagination -from core.external_search import ( - ExternalSearchIndex, - Filter -) +from core.model import Library, production_session from core.util.personal_names import ( display_name_to_sort_name, sort_name_to_display_name, @@ -60,6 +54,7 @@ # self.book_by_someone_else, # match across fields (no 'biography') # ] + def known_to_fail(f): @wraps(f) def decorated(*args, **kwargs): @@ -71,10 +66,13 @@ def decorated(*args, **kwargs): return SearchTest.unexpected_successes.append(f) raise Exception("Expected this test to fail, and it didn't! Congratulations?") + return decorated + class Searcher(object): """A class that knows how to perform searches.""" + def __init__(self, library, index): self.library = library self.filter = Filter(collections=self.library) @@ -82,8 +80,7 @@ def __init__(self, library, index): def query(self, query, pagination): return self.index.query_works( - query, filter=self.filter, pagination=pagination, - debug=True + query, filter=self.filter, pagination=pagination, debug=True ) @@ -142,7 +139,7 @@ def format(self, result): _type=result.meta.doc_type, _id=result.meta.id, _index=result.meta.index, - source=source + source=source, ) def _field(self, field, result=None): @@ -164,16 +161,14 @@ def assert_ratio(self, matches, hits, threshold): if actual < threshold: # This test is going to fail. Log some useful information. logging.info( - "Need %d%% matches, got %d%%" % ( - threshold*100, actual*100 - ) + "Need %d%% matches, got %d%%" % (threshold * 100, actual * 100) ) for hit in hits: logging.info(repr(hit)) assert actual >= threshold def _match_scalar(self, value, expect, inclusive=False, case_sensitive=False): - if hasattr(expect, 'search'): + if hasattr(expect, "search"): if expect and value is not None: success = expect.search(value) else: @@ -183,9 +178,9 @@ def _match_scalar(self, value, expect, inclusive=False, case_sensitive=False): if value and not case_sensitive: value = value.lower() if inclusive: - success = (expect in value) + success = expect in value else: - success = (value == expect) + success = value == expect expect_str = expect return success, expect_str @@ -193,8 +188,8 @@ def _match_subject(self, subject, result): """Is the given result classified under the given subject?""" values = [] expect_str = subject - for classification in (result.classifications or []): - value = classification['term'].lower() + for classification in result.classifications or []: + value = classification["term"].lower() values.append(value) success, expect_str = self._match_scalar(value, subject) if success: @@ -205,8 +200,8 @@ def _match_genre(self, subject, result): """Is the given result classified under the given genre?""" values = [] expect_str = subject - for genre in (result.genres or []): - value = genre['name'].lower() + for genre in result.genres or []: + value = genre["name"].lower() values.append(value) success, expect_str = self._match_scalar(value, subject) if success: @@ -234,10 +229,11 @@ def _match_author(self, author, result): if not contributor.role in Filter.AUTHOR_MATCH_ROLES: continue names = [ - contributor[field].lower() for field in ['display_name', 'sort_name'] + contributor[field].lower() + for field in ["display_name", "sort_name"] if contributor[field] ] - if hasattr(author, 'match'): + if hasattr(author, "match"): match = any(author.search(name) for name in names) else: match = any(author == name for name in names) @@ -251,16 +247,16 @@ def match_result(self, result): for field, expect in list(self.kwargs.items()): fields = None - if field == 'subject': + if field == "subject": success, value, expect_str = self._match_subject(expect, result) - elif field == 'genre': + elif field == "genre": success, value, expect_str = self._match_genre(expect, result) - elif field == 'target_age': + elif field == "target_age": success, value, expect_str = self._match_target_age(expect, result) - elif field == 'author': + elif field == "author": success, value, expect_str = self._match_author(expect, result) - elif field == 'title_or_subtitle': - fields = ['title', 'subtitle'] + elif field == "title_or_subtitle": + fields = ["title", "subtitle"] else: fields = [field] if fields: @@ -294,10 +290,11 @@ def evaluate(self, hits): class Common(Evaluator): - """It must be common for the results to match certain criteria. - """ - def __init__(self, threshold=0.5, minimum=None, first_must_match=True, - negate=False, **kwargs): + """It must be common for the results to match certain criteria.""" + + def __init__( + self, threshold=0.5, minimum=None, first_must_match=True, negate=False, **kwargs + ): """Constructor :param threshold: A proportion of the search results must @@ -318,13 +315,15 @@ def __init__(self, threshold=0.5, minimum=None, first_must_match=True, def evaluate_first(self, hit): if self.first_must_match: success, actual, expected = self.match_result(hit) - if hasattr(actual, 'match'): + if hasattr(actual, "match"): actual = actual.pattern if (not success) or (self.negate and success): if self.negate: if actual == expected: logging.info( - "First result matched and shouldn't have. %s == %s", expected, actual + "First result matched and shouldn't have. %s == %s", + expected, + actual, ) assert actual != expected else: @@ -341,23 +340,21 @@ def evaluate_hits(self, hits): if self.threshold is not None: self.assert_ratio( [x[1:] for x in successes], - [x[1:] for x in successes+failures], - self.threshold + [x[1:] for x in successes + failures], + self.threshold, ) if self.minimum is not None: overall_success = len(successes) >= self.minimum if not overall_success: - logging.info( - "Need %d matches, got %d" % (self.minimum, len(successes)) - ) - for i in (successes+failures): + logging.info("Need %d matches, got %d" % (self.minimum, len(successes))) + for i in successes + failures: if i in successes: - template = 'Y (%s == %s)' + template = "Y (%s == %s)" else: - template = 'N (%s != %s)' + template = "N (%s != %s)" vars = [] for display in i[1:]: - if hasattr(display, 'match'): + if hasattr(display, "match"): display = display.pattern vars.append(display) logging.info(template % tuple(vars)) @@ -366,8 +363,9 @@ def evaluate_hits(self, hits): class Uncommon(Common): """The given match must seldom or never happen.""" + def __init__(self, threshold=1, **kwargs): - kwargs['negate'] = True + kwargs["negate"] = True super(Uncommon, self).__init__(threshold=threshold, **kwargs) @@ -375,7 +373,7 @@ class FirstMatch(Common): """The first result must match certain criteria.""" def __init__(self, **kwargs): - threshold = kwargs.pop('threshold', None) + threshold = kwargs.pop("threshold", None) super(FirstMatch, self).__init__( threshold=threshold, first_must_match=True, **kwargs ) @@ -391,6 +389,7 @@ def __init__(self, **kwargs): class SpecificGenre(Common): pass + class SpecificAuthor(FirstMatch): """The first result must be by a specific author. @@ -405,15 +404,17 @@ def __init__(self, author, accept_title=None, threshold=0): self.accept_title = None def author_role(self, expect_author, result): - if hasattr(expect_author, 'match'): + if hasattr(expect_author, "match"): + def match(author): - return ( - expect_author.search(author.display_name) - or expect_author.search(author.sort_name) - ) + return expect_author.search( + author.display_name + ) or expect_author.search(author.sort_name) + else: expect_author_sort = display_name_to_sort_name(expect_author) expect_author_display = sort_name_to_display_name(expect_author) + def match(author): return ( contributor.display_name == expect_author @@ -421,6 +422,7 @@ def match(author): or contributor.sort_name == expect_author_sort or contributor.display_name == expect_author_display ) + for contributor in result.contributors or []: if match(contributor): return contributor.role @@ -428,24 +430,26 @@ def match(author): return None def evaluate_first(self, first): - expect = self.original_kwargs['author'] + expect = self.original_kwargs["author"] if self.author_role(expect, first) is not None: return True - title = self._field('title', first) - subtitle = self._field('subtitle', first) - if self.accept_title and (self.accept_title in title or self.accept_title in subtitle): + title = self._field("title", first) + subtitle = self._field("subtitle", first) + if self.accept_title and ( + self.accept_title in title or self.accept_title in subtitle + ): return True # We have failed. - if hasattr(expect, 'match'): + if hasattr(expect, "match"): expect = expect.pattern eq_(expect, first.contributors) def evaluate_hits(self, hits): last_role = None last_title = None - author = self.original_kwargs['author'] + author = self.original_kwargs["author"] authors = [hit.contributors for hit in hits] author_matches = [] for hit in hits: @@ -470,8 +474,8 @@ def evaluate(self, results): self.assert_ratio(successes, diagnostics, self.threshold) def evaluate_one(self, result): - expect_author = self.kwargs.get('author') - expect_series = self.kwargs.get('series') + expect_author = self.kwargs.get("author") + expect_series = self.kwargs.get("series") # Ideally a series match happens in the .series, but sometimes # it happens in the .title. @@ -489,13 +493,18 @@ def evaluate_one(self, result): # a matching title by a different author is not part of the # series. if expect_author: - author_match, match, details = self._match_author( - expect_author, result - ) + author_match, match, details = self._match_author(expect_author, result) else: author_match = True - actual = (actual_series, actual_title, result.author, - result.sort_author, series_match, title_match, author_match) + actual = ( + actual_series, + actual_title, + result.author, + result.sort_author, + series_match, + title_match, + author_match, + ) return (series_match or title_match) and author_match, actual @@ -520,10 +529,12 @@ def search(self, query, evaluators=None, limit=10): for e in evaluators: e.evaluate(hits) + class VariantSearchTest(SearchTest): """A test suite that runs different searches but evaluates the results against the same evaluator every time. """ + EVALUATOR = None def search(self, query): @@ -537,24 +548,21 @@ def test_junk(self): # Test one long string self.search( "rguhriregiuh43pn5rtsadpfnsadfausdfhaspdiufnhwe42uhdsaipfh", - ReturnsNothing() + ReturnsNothing(), ) def test_multi_word_junk(self): # Test several short strings self.search( "rguhriregiuh 43pn5rts adpfnsadfaus dfhaspdiufnhwe4 2uhdsaipfh", - ReturnsNothing() + ReturnsNothing(), ) def test_wordlike_junk(self): # To a human eye this is obviously gibberish, but it's close # enough to English words that it might pick up a few results # on a fuzzy match. - self.search( - "asdfza oiagher ofnalqk", - ReturnsNothing() - ) + self.search("asdfza oiagher ofnalqk", ReturnsNothing()) class TestTitleMatch(SearchTest): @@ -569,15 +577,19 @@ def test_simple_title_match_bookshop(self): self.search("the bookshop", FirstMatch(title="The Bookshop")) def test_simple_title_match_house(self): - self.search("A house for Mr. biswas", FirstMatch(title="A House for Mr. Biswas")) + self.search( + "A house for Mr. biswas", FirstMatch(title="A House for Mr. Biswas") + ) def test_simple_title_match_clique(self): self.search("clique", FirstMatch(title="The Clique")) def test_simple_title_match_assassin(self): - self.search("blind assassin", FirstMatch( - title=re.compile("^(the )?blind assassin$"), - author="Margaret Atwood") + self.search( + "blind assassin", + FirstMatch( + title=re.compile("^(the )?blind assassin$"), author="Margaret Atwood" + ), ) def test_simple_title_match_dry(self): @@ -590,10 +602,7 @@ def test_simple_title_match_goldfinch(self): # This book is available as both "The Goldfinch" and "Goldfinch" self.search( "goldfinch", - FirstMatch( - title=re.compile("^(the )?goldfinch$"), - author="Donna Tartt" - ) + FirstMatch(title=re.compile("^(the )?goldfinch$"), author="Donna Tartt"), ) def test_simple_title_match_beach(self): @@ -601,15 +610,11 @@ def test_simple_title_match_beach(self): def test_simple_title_match_testing(self): self.search( - "The testing", - FirstMatch(title="The testing", author='Joelle Charbonneau') + "The testing", FirstMatch(title="The testing", author="Joelle Charbonneau") ) def test_simple_title_twentysomething(self): - self.search( - "Twentysomething", - FirstMatch(title="Twentysomething") - ) + self.search("Twentysomething", FirstMatch(title="Twentysomething")) def test_simple_title_match_bell_jar(self): # NOTE: this works on ES6. On ES1, the top result is the Sparknotes for @@ -618,51 +623,41 @@ def test_simple_title_match_bell_jar(self): self.search("bell jar", FirstMatch(author="Sylvia Plath")) def test_simple_title_match_androids(self): - self.search("Do androids dream of electric sheep", - FirstMatch(title="Do Androids Dream of Electric Sheep?")) + self.search( + "Do androids dream of electric sheep", + FirstMatch(title="Do Androids Dream of Electric Sheep?"), + ) def test_genius_foods(self): # In addition to an exact title match, we also check that # food-related books show up in the search results. self.search( - "genius foods", [ + "genius foods", + [ FirstMatch(title="Genius Foods"), - Common( - genre=re.compile("(cook|diet)"), - threshold=0.2 - ) - ] + Common(genre=re.compile("(cook|diet)"), threshold=0.2), + ], ) - def test_it(self): # The book "It" is correctly prioritized over books whose titles contain # the word "it." - self.search( - "It", - FirstMatch(title="It") - ) + self.search("It", FirstMatch(title="It")) def test_girl_on_the_train(self): # There's a different book called "The Girl in the Train". - self.search( - "girl on the train", - FirstMatch(title="The Girl On The Train") - ) + self.search("girl on the train", FirstMatch(title="The Girl On The Train")) + class TestPosessives(SearchTest): """Test searches for book titles that contain posessives.""" def test_washington_partial(self): - self.search( - "washington war", - AtLeastOne(title="George Washington's War") - ) + self.search("washington war", AtLeastOne(title="George Washington's War")) def test_washington_full_no_apostrophe(self): self.search( - "george washingtons war", - FirstMatch(title="George Washington's War") + "george washingtons war", FirstMatch(title="George Washington's War") ) @known_to_fail @@ -675,22 +670,15 @@ def test_washington_partial_apostrophe(self): # # Since most people don't type the apostrophe, the tradeoff is # worth it. - self.search( - "washington's war", - FirstMatch(title="George Washington's War") - ) + self.search("washington's war", FirstMatch(title="George Washington's War")) def test_washington_full_apostrophe(self): self.search( - "george washington's war", - FirstMatch(title="George Washington's War") + "george washington's war", FirstMatch(title="George Washington's War") ) def test_bankers(self): - self.search( - "bankers wife", - FirstMatch(title="The Banker's Wife") - ) + self.search("bankers wife", FirstMatch(title="The Banker's Wife")) def test_brother(self): # The entire posessive is omitted. @@ -727,6 +715,7 @@ def test_police_women_extra_space(self): FirstMatch(title="The Policewomen's Bureau"), ) + class TestSynonyms(SearchTest): # Test synonyms that could be (but currently aren't) defined in # the search index. @@ -750,6 +739,7 @@ def test_ampersand_is_and(self): FirstMatch(title="Master & Apprentice (Star Wars)"), ) + class TestUnownedTitle(SearchTest): # These are title searches for books not owned by NYPL. # Because of this we check that _similar_ books are returned. @@ -761,27 +751,24 @@ def test_boy_saved_baseball(self): # The target title ("The Boy who Saved Baseball") isn't in the # collection, but, ideally, most of the top results should # still be about baseball. - self.search( - "boy saved baseball", - Common(subject=re.compile("baseball")) - ) + self.search("boy saved baseball", Common(subject=re.compile("baseball"))) def test_save_cat(self): # This specific book isn't in the collection, but there's a # book with a very similar title, which is the first result. self.search( "Save the Cat", - [Common(title=re.compile("save the cat"), threshold=0.1), - Common(title=re.compile("(save|cat)"), threshold=0.6)] + [ + Common(title=re.compile("save the cat"), threshold=0.1), + Common(title=re.compile("(save|cat)"), threshold=0.6), + ], ) def test_minecraft_zombie(self): # We don't have this specific title, but there's no shortage of # Minecraft books. self.search( - "Diary of a minecraft zombie", - Common(summary=re.compile("minecraft", re.I)) - + "Diary of a minecraft zombie", Common(summary=re.compile("minecraft", re.I)) ) def test_pie(self): @@ -794,19 +781,21 @@ def test_divorce(self): # Despite the 'children', we are not looking for children's # books. We're looking for books for grown-ups about divorce. self.search( - "The truth about children and divorce", [ + "The truth about children and divorce", + [ Common(audience="adult"), AtLeastOne(subject=re.compile("divorce")), - ] + ], ) def test_patterns_of_fashion(self): # This specific title isn't in the collection, but the results # should still be reasonably relevant. self.search( - "Patterns of fashion", [ + "Patterns of fashion", + [ Common(subject=re.compile("crafts"), first_must_match=False), - ] + ], ) def test_unowned_partial_title_rosetta_stone(self): @@ -816,10 +805,7 @@ def test_unowned_partial_title_rosetta_stone(self): # less relevant to the user, but still reasonable. Instead, the first result # is a memoir by an author whose first name is "Rosetta." - self.search( - "Rosetta stone", - FirstMatch(title=re.compile("(rosetta|stone)")) - ) + self.search("Rosetta stone", FirstMatch(title=re.compile("(rosetta|stone)"))) @known_to_fail def test_unowned_misspelled_partial_title_cosmetics(self): @@ -829,9 +815,10 @@ def test_unowned_misspelled_partial_title_cosmetics(self): # something to do with cosmetics; instead, they're about # comets. self.search( - "Cometics counter", [ + "Cometics counter", + [ AtLeastOne(title=re.compile("cosmetics")), - ] + ], ) def test_title_match_with_genre_name(self): @@ -843,8 +830,7 @@ def test_title_match_with_genre_name(self): # search query but doesn't say "spy" anywhere. self.search( "My life as a spy", - Common(title_or_subtitle=re.compile("life|spy"), - threshold=0.5) + Common(title_or_subtitle=re.compile("life|spy"), threshold=0.5), ) @known_to_fail @@ -862,50 +848,33 @@ class TestMisspelledTitleSearch(SearchTest): @known_to_fail def test_allegiant(self): # A very bad misspelling. - self.search( - "alliagent", - FirstMatch(title="Allegiant") - ) + self.search("alliagent", FirstMatch(title="Allegiant")) def test_marriage_lie(self): - self.search( - "Marriage liez", - FirstMatch(title="The Marriage Lie") - ) + self.search("Marriage liez", FirstMatch(title="The Marriage Lie")) def test_invisible_emmie(self): # One word in the title is slightly misspelled. - self.search( - "Ivisible emmie", FirstMatch(title="Invisible Emmie") - ) + self.search("Ivisible emmie", FirstMatch(title="Invisible Emmie")) def test_karamazov(self): # Extremely uncommon proper noun, slightly misspelled - self.search( - "Brothers karamzov", - FirstMatch(title="The Brothers Karamazov") - ) + self.search("Brothers karamzov", FirstMatch(title="The Brothers Karamazov")) def test_restless_wave(self): # One common word in the title is slightly misspelled. - self.search( - "He restless wave", - FirstMatch(title="The Restless Wave") - ) + self.search("He restless wave", FirstMatch(title="The Restless Wave")) def test_kingdom_of_the_blind(self): # The first word, which is a fairly common word, is slightly misspelled. - self.search( - "Kngdom of the blind", - FirstMatch(title="Kingdom of the Blind") - ) + self.search("Kngdom of the blind", FirstMatch(title="Kingdom of the Blind")) def test_seven_husbands(self): # Two words--1) a common word which is spelled as a different word # ("if" instead of "of"), and 2) a proper noun--are misspelled. self.search( "The seven husbands if evyln hugo", - FirstMatch(title="The Seven Husbands of Evelyn Hugo") + FirstMatch(title="The Seven Husbands of Evelyn Hugo"), ) def test_nightingale(self): @@ -913,68 +882,48 @@ def test_nightingale(self): # # This might fail because a book by Florence Nightingale is # seen as a better match. - self.search( - "The nightenale", - FirstMatch(title="The Nightingale") - ) + self.search("The nightenale", FirstMatch(title="The Nightingale")) @known_to_fail def test_memoirs_geisha(self): # The desired work shows up on the first page, but it should # be first. - self.search( - "Memoire of a ghesia", - FirstMatch(title="Memoirs of a Geisha") - ) + self.search("Memoire of a ghesia", FirstMatch(title="Memoirs of a Geisha")) def test_healthyish(self): # Misspelling of the title, which is a neologism. - self.search( - "healtylish", FirstMatch(title="Healthyish") - ) + self.search("healtylish", FirstMatch(title="Healthyish")) def test_zodiac(self): # Uncommon word, slightly misspelled. - self.search( - "Zodiaf", FirstMatch(title="Zodiac") - ) + self.search("Zodiaf", FirstMatch(title="Zodiac")) def test_for_whom_the_bell_tolls(self): # A relatively common word is spelled as a different, more common word. self.search( - "For whom the bell tools", - FirstMatch(title="For Whom the Bell Tolls") + "For whom the bell tools", FirstMatch(title="For Whom the Bell Tolls") ) @known_to_fail def test_came_to_baghdad(self): # An extremely common word is spelled as a different word. - self.search( - "They cane to baghdad", - FirstMatch(title="They Came To Baghdad") - ) + self.search("They cane to baghdad", FirstMatch(title="They Came To Baghdad")) def test_genghis_khan(self): - self.search( - "Ghangiz Khan", - AtLeastOne(title=re.compile("Genghis Khan", re.I)) - ) + self.search("Ghangiz Khan", AtLeastOne(title=re.compile("Genghis Khan", re.I))) def test_guernsey(self): # One word, which is a place name, is misspelled. self.search( "The gurnsey literary and potato peel society", - FirstMatch(title="The Guernsey Literary & Potato Peel Society") + FirstMatch(title="The Guernsey Literary & Potato Peel Society"), ) def test_british_spelling_color_of_our_sky(self): # Note to pedants: the title of the book as published is # "The Color of Our Sky". - self.search( - "The colour of our sky", - FirstMatch(title="The Color of Our Sky") - ) + self.search("The colour of our sky", FirstMatch(title="The Color of Our Sky")) class TestPartialTitleSearch(SearchTest): @@ -989,29 +938,22 @@ def test_i_funnyest(self): def test_future_home(self): # The search query only contains half of the title. - self.search( - "Future home of", - FirstMatch(title="Future Home Of the Living God") - ) + self.search("Future home of", FirstMatch(title="Future Home Of the Living God")) def test_fundamentals_of_supervision(self): # A word from the middle of the title is missing. self.search( "fundamentals of supervision", - FirstMatch(title="Fundamentals of Library Supervision") + FirstMatch(title="Fundamentals of Library Supervision"), ) def test_hurin(self): # A single word is so unusual that it can identify the book # we're looking for. - for query in ( - "Hurin", "Húrin" - ): + for query in ("Hurin", "Húrin"): self.search( query, - FirstMatch( - title="The Children of Húrin", author=re.compile("tolkien") - ) + FirstMatch(title="The Children of Húrin", author=re.compile("tolkien")), ) @known_to_fail @@ -1025,46 +967,34 @@ def test_open_wide(self): "Open wide a radical", FirstMatch( title="Open Wide", - subtitle="a radically real guide to deep love, rocking relationships, and soulful sex" - ) + subtitle="a radically real guide to deep love, rocking relationships, and soulful sex", + ), ) def test_how_to_win_friends(self): # The search query only contains half of the title. self.search( "How to win friends", - FirstMatch(title="How to Win Friends and Influence People") + FirstMatch(title="How to Win Friends and Influence People"), ) def test_wash_your_face_1(self): # The search query is missing the last word of the title. - self.search( - "Girl wash your", - FirstMatch(title="Girl, Wash Your Face") - ) + self.search("Girl wash your", FirstMatch(title="Girl, Wash Your Face")) def test_wash_your_face_2(self): # The search query is missing the first word of the title. - self.search( - "Wash your face", - FirstMatch(title="Girl, Wash Your Face") - ) + self.search("Wash your face", FirstMatch(title="Girl, Wash Your Face")) def test_theresa(self): # The search results correctly prioritize books with titles containing # "Theresa" over books by authors with the first name "Theresa." - self.search( - "Theresa", - FirstMatch(title=re.compile("Theresa", re.I)) - ) + self.search("Theresa", FirstMatch(title=re.compile("Theresa", re.I))) def test_prime_of_miss_jean_brodie(self): - # The search query only has the first and last words from the title, and - # the last word is misspelled. - self.search( - "Prime brody", - FirstMatch(title="The Prime of Miss Jean Brodie") - ) + # The search query only has the first and last words from the title, and + # the last word is misspelled. + self.search("Prime brody", FirstMatch(title="The Prime of Miss Jean Brodie")) class TestTitleGenreConflict(SearchTest): @@ -1075,22 +1005,17 @@ class TestTitleGenreConflict(SearchTest): def test_drama(self): # The title of the book is the name of a genre, and another # genre has been added to the search term to clarify it. - self.search( - "drama comic", - FirstMatch(title="Drama", author="Raina Telgemeier") - ) + self.search("drama comic", FirstMatch(title="Drama", author="Raina Telgemeier")) def test_title_match_with_genre_name_romance(self): # The title contains the name of a genre. Despite this, # an exact title match should show up first. - self.search( - "modern romance", FirstMatch(title="Modern Romance") - ) + self.search("modern romance", FirstMatch(title="Modern Romance")) def test_modern_romance_with_author(self): self.search( "modern romance aziz ansari", - FirstMatch(title="Modern Romance", author="Aziz Ansari") + FirstMatch(title="Modern Romance", author="Aziz Ansari"), ) def test_partial_title_match_with_genre_name_education(self): @@ -1101,8 +1026,7 @@ def test_partial_title_match_with_genre_name_education(self): def test_title_match_with_genre_name_law(self): self.search( - "law of the mountain man", - FirstMatch(title="Law of the Mountain Man") + "law of the mountain man", FirstMatch(title="Law of the Mountain Man") ) @known_to_fail @@ -1114,24 +1038,18 @@ def test_law_of_the_mountain_man_with_author(self): [ FirstMatch(title="Law of the Mountain Man"), Common(author="William Johnstone"), - ] + ], ) def test_spy(self): - self.search( - "spying on whales", - FirstMatch(title="Spying on Whales") - ) + self.search("spying on whales", FirstMatch(title="Spying on Whales")) def test_dance(self): # This works because of the stopword index. # # Otherwise "Dance of the Dragons" looks like an equally good # result. - self.search( - "dance with dragons", - FirstMatch(title="A Dance With Dragons") - ) + self.search("dance with dragons", FirstMatch(title="A Dance With Dragons")) class TestTitleAuthorConflict(SearchTest): @@ -1141,34 +1059,22 @@ class TestTitleAuthorConflict(SearchTest): def test_lord_jim(self): # The book "Lord Jim" is correctly prioritized over books whose authors' names # contain "Lord" or "Jim." - self.search( - "Lord Jim", - FirstMatch(title="Lord Jim") - ) + self.search("Lord Jim", FirstMatch(title="Lord Jim")) def test_wilder(self): # The book "Wilder" is correctly prioritized over books by authors with the # last name "Wilder." - self.search( - "Wilder", - FirstMatch(title="Wilder") - ) + self.search("Wilder", FirstMatch(title="Wilder")) def test_alice(self): # The book "Alice" is correctly prioritized over books by authors with the # first name "Alice." - self.search( - "Alice", - FirstMatch(title="Alice") - ) + self.search("Alice", FirstMatch(title="Alice")) def test_alex_and_eliza(self): # The book "Alex and Eliza" is correctly prioritized over books by authors with the # first names "Alex" or "Eliza." - self.search( - "Alex and Eliza", - FirstMatch(title="Alex and Eliza") - ) + self.search("Alex and Eliza", FirstMatch(title="Alex and Eliza")) def test_disney(self): # The majority of the search results will be about Walt Disney and/or the @@ -1180,9 +1086,11 @@ def test_disney(self): # It's an unusual situation so I think this is all right. self.search( "disney", - [ Common(title=re.compile("disney"), first_must_match=False), - AtLeastOne(title=re.compile("walt disney")), - AtLeastOne(author="Disney Book Group") ] + [ + Common(title=re.compile("disney"), first_must_match=False), + AtLeastOne(title=re.compile("walt disney")), + AtLeastOne(author="Disney Book Group"), + ], ) def test_bridge(self): @@ -1190,8 +1098,7 @@ def test_bridge(self): # title over books by authors whose names contain "Luis" or # "Rey." self.search( - "the bridge of san luis rey", - FirstMatch(title="The Bridge of San Luis Rey") + "the bridge of san luis rey", FirstMatch(title="The Bridge of San Luis Rey") ) @@ -1200,26 +1107,18 @@ class TestTitleAudienceConflict(SearchTest): # name of an audience or target age. def test_title_match_with_audience_name_children(self): - self.search( - "Children blood", - FirstMatch(title="Children of Blood and Bone") - ) + self.search("Children blood", FirstMatch(title="Children of Blood and Bone")) def test_title_match_with_audience_name_kids(self): - self.search( - "just kids", - FirstMatch(title="Just Kids") - ) + self.search("just kids", FirstMatch(title="Just Kids")) def test_tales_of_a_fourth_grade_nothing(self): self.search( - "fourth grade nothing", - FirstMatch(title="Tales of a Fourth Grade Nothing") + "fourth grade nothing", FirstMatch(title="Tales of a Fourth Grade Nothing") ) class TestMixedTitleAuthorMatch(SearchTest): - @known_to_fail def test_centos_caen(self): # 'centos' shows up in the subtitle. 'caen' is the name @@ -1227,37 +1126,29 @@ def test_centos_caen(self): # # NOTE: The work we're looking for shows up on the first page # but it can't beat out title matches like "CentOS Bible" - self.search( - "centos caen", - FirstMatch(title="fedora linux toolbox") - ) + self.search("centos caen", FirstMatch(title="fedora linux toolbox")) def test_fallen_baldacci(self): self.search( - "fallen baldacci", - FirstMatch(author="David Baldacci", title="The Fallen") + "fallen baldacci", FirstMatch(author="David Baldacci", title="The Fallen") ) def test_dragons(self): # Full title, full but misspelled author self.search( "Michael conolley Nine Dragons", - FirstMatch(title="Nine Dragons", author="Michael Connelly") + FirstMatch(title="Nine Dragons", author="Michael Connelly"), ) def test_dostoyevsky(self): # Full title, partial author self.search( - "Crime and punishment Dostoyevsky", - FirstMatch(title="Crime and Punishment") + "Crime and punishment Dostoyevsky", FirstMatch(title="Crime and Punishment") ) def test_dostoyevsky_partial_title(self): # Partial title, partial author - self.search( - "punishment Dostoyevsky", - FirstMatch(title="Crime and Punishment") - ) + self.search("punishment Dostoyevsky", FirstMatch(title="Crime and Punishment")) @known_to_fail def test_sparks(self): @@ -1267,38 +1158,37 @@ def test_sparks(self): # Breath" self.search( "Every breath by nicholis sparks", - FirstMatch(title="Every Breath", author="Nicholas Sparks") + FirstMatch(title="Every Breath", author="Nicholas Sparks"), ) def test_grisham(self): # Full title, author name misspelled self.search( "The reckoning john grisham", - FirstMatch(title="The Reckoning", author="John Grisham") + FirstMatch(title="The Reckoning", author="John Grisham"), ) def test_singh(self): self.search( "Nalini singh archangel", - [ Common(author="Nalini Singh", threshold=0.9), - Common(title=re.compile("archangel")) ] + [ + Common(author="Nalini Singh", threshold=0.9), + Common(title=re.compile("archangel")), + ], ) def test_sebald_1(self): # This title isn't in the collection, but the author's other # books should still come up. self.search( - "Sebald after", - SpecificAuthor("W. G. Sebald", accept_title="Sebald") + "Sebald after", SpecificAuthor("W. G. Sebald", accept_title="Sebald") ) def test_sebald_2(self): # Specifying the full title gets rid of the book about # this author, probably because "Nature" is the name of a genre. - self.search( - "Sebald after nature", - SpecificAuthor("W. G. Sebald") - ) + self.search("Sebald after nature", SpecificAuthor("W. G. Sebald")) + # Classes that test many different variant searches for a specific # title. @@ -1382,17 +1272,13 @@ class TestSubtitleMatch(SearchTest): def test_shame_stereotypes(self): # "Sister Citizen" has both search terms in its # subtitle. - self.search( - "shame stereotypes", FirstMatch(title="Sister Citizen") - ) + self.search("shame stereotypes", FirstMatch(title="Sister Citizen")) def test_garden_wiser(self): - self.search( - "garden wiser", FirstMatch(title="Gardening for a Lifetime") - ) + self.search("garden wiser", FirstMatch(title="Gardening for a Lifetime")) -class TestAuthorMatch(SearchTest): +class TestAuthorMatch(SearchTest): def test_kelly_link(self): # There is one obvious right answer. self.search("kelly link", SpecificAuthor("Kelly Link")) @@ -1404,8 +1290,10 @@ def test_stephen_king(self): # majority of search results should be books _by_ this author. self.search( "stephen king", - [ SpecificAuthor("Stephen King", accept_title="Stephen King"), - Common(author="Stephen King", threshold=0.7) ] + [ + SpecificAuthor("Stephen King", accept_title="Stephen King"), + Common(author="Stephen King", threshold=0.7), + ], ) def test_fleming(self): @@ -1413,8 +1301,10 @@ def test_fleming(self): # results, but the overwhelming majority of the results should be books by him. self.search( "ian fleming", - [ SpecificAuthor("Ian Fleming", accept_title="Ian Fleming"), - Common(author="Ian Fleming", threshold=0.9) ] + [ + SpecificAuthor("Ian Fleming", accept_title="Ian Fleming"), + Common(author="Ian Fleming", threshold=0.9), + ], ) def test_plato(self): @@ -1422,8 +1312,7 @@ def test_plato(self): # but there should also be some _by_ him. self.search( "plato", - [ SpecificAuthor("Plato", accept_title="Plato"), - AtLeastOne(author="Plato") ] + [SpecificAuthor("Plato", accept_title="Plato"), AtLeastOne(author="Plato")], ) def test_byron(self): @@ -1432,10 +1321,11 @@ def test_byron(self): # # TODO: Books about Byron are consistently prioritized above books by him. self.search( - "Byron", [ + "Byron", + [ AtLeastOne(title=re.compile("byron"), genre=re.compile("biography")), - AtLeastOne(author=re.compile("byron")) - ] + AtLeastOne(author=re.compile("byron")), + ], ) def test_hemingway(self): @@ -1444,107 +1334,81 @@ def test_hemingway(self): # The majority of the search results should be _by_ this author, # but there should also be at least one _about_ him. self.search( - "Hemingway", [ - AtLeastOne(title=re.compile("hemingway"), genre=re.compile("biography")), - AtLeastOne(author="Ernest Hemingway") - ] + "Hemingway", + [ + AtLeastOne( + title=re.compile("hemingway"), genre=re.compile("biography") + ), + AtLeastOne(author="Ernest Hemingway"), + ], ) def test_lagercrantz(self): # The search query contains only the author's last name. # There are several people with this name, and there's no # information that would let us prefer one over the other. - self.search( - "Lagercrantz", SpecificAuthor(re.compile("Lagercrantz")) - ) + self.search("Lagercrantz", SpecificAuthor(re.compile("Lagercrantz"))) def test_burger(self): # The author is correctly prioritized above books whose titles contain # the word "burger." - self.search( - "wolfgang burger", SpecificAuthor("Wolfgang Burger") - ) + self.search("wolfgang burger", SpecificAuthor("Wolfgang Burger")) def test_chase(self): # The author is correctly prioritized above the book "Emma." - self.search( - "Emma chase", SpecificAuthor("Emma Chase") - ) + self.search("Emma chase", SpecificAuthor("Emma Chase")) @known_to_fail def test_deirdre_martin(self): # The author's first name is misspelled in the search query. # # The search results are books about characters named Diedre. - self.search( - "deidre martin", SpecificAuthor("Deirdre Martin") - ) + self.search("deidre martin", SpecificAuthor("Deirdre Martin")) def test_wharton(self): self.search( "edith wharton", - SpecificAuthor("Edith Wharton", accept_title="Edith Wharton") + SpecificAuthor("Edith Wharton", accept_title="Edith Wharton"), ) def test_wharton_misspelled(self): # The author's last name is misspelled in the search query. - self.search( - "edith warton", Common(author="Edith Wharton") - ) + self.search("edith warton", Common(author="Edith Wharton")) def test_danielle_steel(self): # The author's last name is slightly misspelled in the search query. - self.search( - "danielle steele", - SpecificAuthor("Danielle Steel", threshold=1) - ) + self.search("danielle steele", SpecificAuthor("Danielle Steel", threshold=1)) def test_primary_author_with_coauthors(self): # This person is sometimes credited as primary author with # other authors, and sometimes as just a regular co-author. - self.search( - "steven peterman", - SpecificAuthor("Steven Peterman") - ) + self.search("steven peterman", SpecificAuthor("Steven Peterman")) def test_primary_author_with_coauthors_2(self): - self.search( - "jack cohen", - SpecificAuthor("Jack Cohen") - ) + self.search("jack cohen", SpecificAuthor("Jack Cohen")) def test_only_as_coauthor(self): # This person is inevitably credited co-equal with another # author. - self.search( - "stan berenstain", - SpecificAuthor("Stan Berenstain") - ) + self.search("stan berenstain", SpecificAuthor("Stan Berenstain")) def test_narrator(self): # This person is narrator for a lot of Stephen King # audiobooks. Searching for their name may bring up people # with similar names and authorship roles, but they'll show up # pretty frequently. - self.search( - "will patton", - Common(author="Will Patton") - ) + self.search("will patton", Common(author="Will Patton")) def test_unknown_display_name(self): # In NYPL's dataset, we know the sort name for this author but # not the display name. - self.search( - "emma craigie", - SpecificAuthor("Craigie, Emma") - ) + self.search("emma craigie", SpecificAuthor("Craigie, Emma")) def test_nabokov_misspelled(self): # Only the last name is provided in the search query, # and it's misspelled. self.search( - "Nabokof", - SpecificAuthor("Vladimir Nabokov", accept_title="Nabokov") + "Nabokof", SpecificAuthor("Vladimir Nabokov", accept_title="Nabokov") ) def test_ba_paris(self): @@ -1553,22 +1417,15 @@ def test_ba_paris(self): # NOTE: These results are always very good, but sometimes the # first result is a title match with stopword removed: # "Escalier B, Paris 12". - self.search( - "b a paris", SpecificAuthor("B. A. Paris") - ) + self.search("b a paris", SpecificAuthor("B. A. Paris")) def test_griffiths(self): # The search query gives the author's sort name. - self.search( - "Griffiths elly", SpecificAuthor("Elly Griffiths") - ) + self.search("Griffiths elly", SpecificAuthor("Elly Griffiths")) def test_christian_kracht(self): # The author's name contains a genre name. - self.search( - "christian kracht", - FirstMatch(author="Christian Kracht") - ) + self.search("christian kracht", FirstMatch(author="Christian Kracht")) def test_dan_gutman(self): self.search("gutman, dan", Common(author="Dan Gutman")) @@ -1576,9 +1433,7 @@ def test_dan_gutman(self): def test_dan_gutman_with_series(self): self.search( "gutman, dan the weird school", - SpecificSeries( - series="My Weird School", author="Dan Gutman" - ) + SpecificSeries(series="My Weird School", author="Dan Gutman"), ) def test_steve_berry(self): @@ -1591,23 +1446,22 @@ def test_steve_berry(self): def test_thomas_python(self): # All the terms are correctly spelled words, but the patron # clearly means something else. - self.search( - "thomas python", - Common(author="Thomas Pynchon") - ) + self.search("thomas python", Common(author="Thomas Pynchon")) def test_betty_neels_audiobooks(self): # Even though there are no audiobooks, all of the search # results should still be books by this author. self.search( "Betty neels audiobooks", - Common(author="Betty Neels", genre="romance", threshold=1) + Common(author="Betty Neels", genre="romance", threshold=1), ) + # Classes that test many different variant searches for a specific # author. # + class TestTimothyZahn(VariantSearchTest): # Test ways of searching for author Timothy Zahn. EVALUATOR = SpecificAuthor("Timothy Zahn") @@ -1627,17 +1481,18 @@ class TestRainaTelgemeier(VariantSearchTest): EVALUATOR = SpecificAuthor("Raina Telgemeier") def test_correct_spelling(self): - self.search('raina telgemeier') + self.search("raina telgemeier") def test_minor_misspelling(self): - self.search('raina telegmeier') + self.search("raina telegmeier") @known_to_fail def test_misspelling_1(self): - self.search('raina telemger') + self.search("raina telemger") def test_misspelling_2(self): - self.search('raina telgemerier') + self.search("raina telgemerier") + class TestHenningMankell(VariantSearchTest): # A few tests of searches for author Henning Mankell @@ -1707,28 +1562,21 @@ class TestPublisherMatch(SearchTest): # imprint. def test_harlequin_romance(self): - self.search( - "harlequin romance", Common(publisher="harlequin", genre="Romance") - ) + self.search("harlequin romance", Common(publisher="harlequin", genre="Romance")) def test_harlequin_historical(self): self.search( "harlequin historical", # We may get some "harlequin historical classic", which is fine. - Common(imprint=re.compile("harlequin historical"), genre="Romance") + Common(imprint=re.compile("harlequin historical"), genre="Romance"), ) def test_princeton_review(self): - self.search( - "princeton review", - Common(imprint="princeton review") - ) + self.search("princeton review", Common(imprint="princeton review")) @known_to_fail def test_wizards(self): - self.search( - "wizards coast", Common(publisher="wizards of the coast") - ) + self.search("wizards coast", Common(publisher="wizards of the coast")) # We don't want to boost publisher/imprint matches _too_ highly # because publishers and imprints are often single words that @@ -1739,15 +1587,16 @@ def test_penguin(self): # matches in other fields over exact imprint matches. self.search( "penguin", - [Common(title=re.compile("penguin", re.I)), - Uncommon(imprint="Penguin")] + [Common(title=re.compile("penguin", re.I)), Uncommon(imprint="Penguin")], ) def test_vintage(self): self.search( "vintage", - [Common(title=re.compile("vintage", re.I)), - Uncommon(imprint="Vintage", threshold=0.5)] + [ + Common(title=re.compile("vintage", re.I)), + Uncommon(imprint="Vintage", threshold=0.5), + ], ) def test_plympton(self): @@ -1756,8 +1605,10 @@ def test_plympton(self): # publisher. self.search( "plympton", - [Common(author=re.compile("plimpton", re.I)), - Uncommon(publisher="Plympton")] + [ + Common(author=re.compile("plimpton", re.I)), + Uncommon(publisher="Plympton"), + ], ) @known_to_fail @@ -1769,9 +1620,7 @@ def test_scholastic(self): # it's tough to know that "scholastic" is probably a publisher # search, where "penguin" is probably a topic search and # "plympton" is probably a misspelled author search. - self.search( - "scholastic", Common(publisher="scholastic inc.") - ) + self.search("scholastic", Common(publisher="scholastic inc.")) class TestGenreMatch(SearchTest): @@ -1803,7 +1652,7 @@ def test_iain_banks_sf(self): self.search( # Genre and author "iain banks science fiction", - Common(genre=self.any_sf, author="Iain M. Banks") + Common(genre=self.any_sf, author="Iain M. Banks"), ) @known_to_fail @@ -1812,14 +1661,12 @@ def test_christian(self): # classified under other genres. self.search( "christian", - Common(genre=re.compile("(christian|religion)"), - first_must_match=False) + Common(genre=re.compile("(christian|religion)"), first_must_match=False), ) def test_christian_authors(self): self.search( - "christian authors", - Common(genre=re.compile("(christian|religion)")) + "christian authors", Common(genre=re.compile("(christian|religion)")) ) @known_to_fail @@ -1831,7 +1678,7 @@ def test_christian_lust(self): # so bad. self.search( "lust christian", - Common(genre=re.compile("(christian|religion|religious fiction)")) + Common(genre=re.compile("(christian|religion|religious fiction)")), ) @known_to_fail @@ -1842,10 +1689,8 @@ def test_christian_fiction(self): "christian fiction", [ Common(fiction="fiction"), - Common(genre=re.compile( - "(christian|religion|religious fiction)") - ) - ] + Common(genre=re.compile("(christian|religion|religious fiction)")), + ], ) @known_to_fail @@ -1853,16 +1698,10 @@ def test_graphic_novel(self): # NOTE: This fails for a spurious reason. Many of the results # have "Graphic Novel" in the title but are not classified as # such. - self.search( - "Graphic novel", - Common(genre="Comics & Graphic Novels") - ) + self.search("Graphic novel", Common(genre="Comics & Graphic Novels")) def test_horror(self): - self.search( - "Best horror story", - Common(genre=re.compile("horror")) - ) + self.search("Best horror story", Common(genre=re.compile("horror"))) @known_to_fail def test_scary_stories(self): @@ -1879,23 +1718,24 @@ def test_percy_jackson_graphic_novel(self): self.search( "Percy jackson graphic novel", - [Common(author="Rick Riordan"), - AtLeastOne(genre="Comics & Graphic Novels", author="Rick Riordan") - ] + [ + Common(author="Rick Riordan"), + AtLeastOne(genre="Comics & Graphic Novels", author="Rick Riordan"), + ], ) - def test_gossip_girl_manga(self): # A "Gossip Girl" manga series does exist, but it's not in # NYPL's collection. Instead, the results should focus on # the "Gossip Girl" series. self.search( - "Gossip girl Manga", [ + "Gossip girl Manga", + [ SpecificSeries( series="Gossip Girl", author=re.compile("cecily von ziegesar"), ), - ] + ], ) @known_to_fail @@ -1907,7 +1747,7 @@ def test_clique(self): # Genre and title self.search( "The clique graphic novel", - Common(genre="Comics & Graphic Novels", title="The Clique") + Common(genre="Comics & Graphic Novels", title="The Clique"), ) def test_spy(self): @@ -1915,24 +1755,18 @@ def test_spy(self): # fine, since people don't really think of "Spy" as a genre, # and people who do type in "spy" looking for spy books will # find them. - self.search( - "Spy", - Common(title=re.compile("(spy|spies)", re.I)) - ) + self.search("Spy", Common(title=re.compile("(spy|spies)", re.I))) def test_espionage(self): self.search( "Espionage", Common( genre=re.compile("(espionage|history|crime|thriller)"), - ) + ), ) def test_food(self): - self.search( - "food", - Common(genre=re.compile("(cook|diet)")) - ) + self.search("food", Common(genre=re.compile("(cook|diet)"))) def test_mystery(self): self.search("mystery", Common(genre="Mystery")) @@ -1942,41 +1776,38 @@ def test_agatha_christie_mystery(self): # Agatha Christie. self.search( "agatha christie mystery", - [ SpecificGenre(genre="Mystery", author="Agatha Christie"), - Common(author="Agatha Christie", threshold=1) ] + [ + SpecificGenre(genre="Mystery", author="Agatha Christie"), + Common(author="Agatha Christie", threshold=1), + ], ) def test_british_mystery(self): # Genre and keyword self.search( "British mysteries", - Common(genre="Mystery", summary=re.compile("british|london|england|scotland")) + Common( + genre="Mystery", summary=re.compile("british|london|england|scotland") + ), ) def test_finance(self): # Keyword self.search( "Finance", - Common( - genre=re.compile("(business|finance)"), first_must_match=False - ) + Common(genre=re.compile("(business|finance)"), first_must_match=False), ) def test_constitution(self): # Keyword self.search( "Constitution", - Common( - genre=re.compile("(politic|history)"), first_must_match=False - ) + Common(genre=re.compile("(politic|history)"), first_must_match=False), ) def test_deep_poems(self): # This appears to be a search for poems which are deep. - self.search( - "deep poems", - Common(genre="Poetry") - ) + self.search("deep poems", Common(genre="Poetry")) class TestSubjectMatch(SearchTest): @@ -1987,8 +1818,8 @@ def test_alien_misspelled(self): "allien", Common( subject=re.compile("(alien|extraterrestrial|science fiction)"), - first_must_match=False - ) + first_must_match=False, + ), ) def test_alien_misspelled_2(self): @@ -1996,8 +1827,8 @@ def test_alien_misspelled_2(self): "aluens", Common( subject=re.compile("(alien|extraterrestrial|science fiction)"), - first_must_match=False - ) + first_must_match=False, + ), ) @known_to_fail @@ -2008,10 +1839,7 @@ def test_anime_genre(self): # # So we get a few title matches for "Anime" and then go into # books about animals. - self.search( - "anime", - Common(subject=re.compile("(manga|anime)")) - ) + self.search("anime", Common(subject=re.compile("(manga|anime)"))) def test_astrophysics(self): # Keyword @@ -2019,17 +1847,14 @@ def test_astrophysics(self): "Astrophysics", Common( genre="Science", - subject=re.compile("(astrophysics|astronomy|physics|space|science)") - ) + subject=re.compile("(astrophysics|astronomy|physics|space|science)"), + ), ) def test_anxiety(self): self.search( "anxiety", - Common( - genre=re.compile("(psychology|self-help)"), - first_must_match=False - ) + Common(genre=re.compile("(psychology|self-help)"), first_must_match=False), ) def test_beauty_hacks(self): @@ -2037,8 +1862,10 @@ def test_beauty_hacks(self): # type of book; ideally, the search results would return at least one relevant # one. Instead, all of the top results are either books about computer hacking # or romance novels. - self.search("beauty hacks", - AtLeastOne(subject=re.compile("(self-help|style|grooming|personal)"))) + self.search( + "beauty hacks", + AtLeastOne(subject=re.compile("(self-help|style|grooming|personal)")), + ) def test_character_classification(self): # Although we check a book's description, it's very difficult @@ -2049,16 +1876,13 @@ def test_character_classification(self): # one word of overlap with the subject matter classification. self.search( "Gastner, Sheriff (Fictitious character)", - SpecificSeries(series="Bill Gastner Mystery") + SpecificSeries(series="Bill Gastner Mystery"), ) def test_college_essay(self): self.search( "College essay", - Common( - genre=re.compile("study aids"), - subject=re.compile("college") - ) + Common(genre=re.compile("study aids"), subject=re.compile("college")), ) @known_to_fail @@ -2070,7 +1894,7 @@ def test_da_vinci(self): # "Da Vinci Code" territory. Maybe that's fine, though. self.search( "Da Vinci", - Common(genre=re.compile("(biography|art)"), first_must_match=False) + Common(genre=re.compile("(biography|art)"), first_must_match=False), ) @known_to_fail @@ -2082,8 +1906,8 @@ def test_da_vinci_missing_space(self): Common( genre=re.compile("(biography|art)"), first_must_match=False, - threshold=0.3 - ) + threshold=0.3, + ), ) @known_to_fail @@ -2092,10 +1916,7 @@ def test_dirtbike(self): # (two words) renders more relevant results, but still not # enough for the test to pass. self.search( - "dirtbike", - Common( - subject=re.compile("(bik|bicycle|sports|nature|travel)") - ) + "dirtbike", Common(subject=re.compile("(bik|bicycle|sports|nature|travel)")) ) def test_greek_romance(self): @@ -2103,19 +1924,17 @@ def test_greek_romance(self): # something like "Essays on the Greek Romances." self.search( "Greek romance", - [Common(genre="Romance", first_must_match=False), - AtLeastOne(title=re.compile("greek romance"))] + [ + Common(genre="Romance", first_must_match=False), + AtLeastOne(title=re.compile("greek romance")), + ], ) def test_ice_cream(self): # There are a lot of books about making ice cream. The search results # correctly present those before looking for non-cooking "artisan" books. self.search( - "Artisan ice cream", - Common( - genre=re.compile("cook"), - threshold=0.9 - ) + "Artisan ice cream", Common(genre=re.compile("cook"), threshold=0.9) ) def test_information_technology(self): @@ -2124,25 +1943,19 @@ def test_information_technology(self): "information technology", Common( subject=re.compile("(information technology|computer)"), - first_must_match=False - ) + first_must_match=False, + ), ) def test_louis_xiii(self): # There aren't very many books in the collection about Louis # XIII, but he is the basis for the king in "The Three # Musketeers", so that's not a bad answer. - self.search( - "Louis xiii", - AtLeastOne(title="The Three Musketeers") - ) + self.search("Louis xiii", AtLeastOne(title="The Three Musketeers")) def test_managerial_skills(self): self.search( - "managerial skills", - Common( - subject=re.compile("(business|management)") - ) + "managerial skills", Common(subject=re.compile("(business|management)")) ) def test_manga(self): @@ -2153,74 +1966,60 @@ def test_manga(self): [ Common(title=re.compile("manga")), Common(subject=re.compile("(manga|art|comic)")), - ] + ], ) def test_meditation(self): - self.search( - "Meditation", - Common( - genre=re.compile("(self-help|mind|spirit)") - ) - ) + self.search("Meditation", Common(genre=re.compile("(self-help|mind|spirit)"))) def test_music_theory(self): # Keywords self.search( - "music theory", Common( - genre="Music", - subject=re.compile("(music theory|musical theory)") - ) + "music theory", + Common(genre="Music", subject=re.compile("(music theory|musical theory)")), ) def test_native_american(self): # Keyword self.search( - "Native american", [ + "Native american", + [ Common( genre=re.compile("history"), subject=re.compile("(america|u.s.)"), - first_must_match=False + first_must_match=False, ) - ] + ], ) def test_native_american_misspelled(self): # Keyword, misspelled self.search( - "Native amerixan", [ + "Native amerixan", + [ Common( genre=re.compile("history"), subject=re.compile("(america|u.s.)"), first_must_match=False, - threshold=0.4 + threshold=0.4, ) - ] + ], ) - def test_ninjas(self): self.search("ninjas", Common(title=re.compile("ninja"))) def test_ninjas_misspelled(self): # NOTE: The first result is "Ningyo", which does look a # lot like "Ningas"... - self.search( - "ningas", Common(title=re.compile("ninja"), first_must_match=False) - ) + self.search("ningas", Common(title=re.compile("ninja"), first_must_match=False)) def test_pattern_making(self): - self.search( - "Pattern making", - AtLeastOne(subject=re.compile("crafts")) - ) + self.search("Pattern making", AtLeastOne(subject=re.compile("crafts"))) def test_plant_based(self): self.search( - "Plant based", - Common( - subject=re.compile("(cooking|food|nutrition|health)") - ) + "Plant based", Common(subject=re.compile("(cooking|food|nutrition|health)")) ) def test_prank(self): @@ -2233,8 +2032,10 @@ def test_presentations(self): self.search( "presentations", Common( - subject=re.compile("(language arts|business presentations|business|management)") - ) + subject=re.compile( + "(language arts|business presentations|business|management)" + ) + ), ) def test_python_programming(self): @@ -2246,32 +2047,35 @@ def test_python_programming(self): # Most works will show up because of a title match -- verify that we're talking about # Python as a programming language. Common( - title=re.compile("python", re.I), subject=re.compile("(computer technology|programming)", re.I), threshold=0.8, - first_must_match=False + title=re.compile("python", re.I), + subject=re.compile("(computer technology|programming)", re.I), + threshold=0.8, + first_must_match=False, ) - ] + ], ) def test_sewing(self): self.search( "Sewing", - [ FirstMatch(title=re.compile("sewing")), - Common(title=re.compile("sewing")), - ] + [ + FirstMatch(title=re.compile("sewing")), + Common(title=re.compile("sewing")), + ], ) def test_supervising(self): # Keyword - self.search( - "supervising", Common(genre="Business", first_must_match=False) - ) + self.search("supervising", Common(genre="Business", first_must_match=False)) def test_tennis(self): # We will get sports books with "Tennis" in the title. self.search( "tennis", - Common(title=re.compile("Tennis", re.I), - genre=re.compile("(Sports|Games)", re.I)) + Common( + title=re.compile("Tennis", re.I), + genre=re.compile("(Sports|Games)", re.I), + ), ) @known_to_fail @@ -2283,15 +2087,12 @@ def test_texas_fair(self): # TODO: "books about" really skews the results here -- lots of # title matches. self.search( - "books about texas like the fair", - Common(title=re.compile("texas")) + "books about texas like the fair", Common(title=re.compile("texas")) ) def test_witches(self): - self.search( - "witches", - Common(subject=re.compile('witch')) - ) + self.search("witches", Common(subject=re.compile("witch"))) + class TestFuzzyConfounders(SearchTest): """Test searches on very distinct terms that are near each other in @@ -2302,48 +2103,56 @@ class TestFuzzyConfounders(SearchTest): def test_amulet(self): self.search( "amulet", - [Common(title_or_subtitle=re.compile("amulet")), - Uncommon(title_or_subtitle=re.compile("hamlet|harlem|tablet")) - ] + [ + Common(title_or_subtitle=re.compile("amulet")), + Uncommon(title_or_subtitle=re.compile("hamlet|harlem|tablet")), + ], ) def test_hamlet(self): self.search( "Hamlet", - [Common(title_or_subtitle="Hamlet"), - Uncommon(title_or_subtitle=re.compile("amulet|harlem|tablet")) - ] + [ + Common(title_or_subtitle="Hamlet"), + Uncommon(title_or_subtitle=re.compile("amulet|harlem|tablet")), + ], ) def test_harlem(self): self.search( "harlem", - [Common(title_or_subtitle=re.compile("harlem")), - Uncommon(title_or_subtitle=re.compile("amulet|hamlet|tablet")) - ] + [ + Common(title_or_subtitle=re.compile("harlem")), + Uncommon(title_or_subtitle=re.compile("amulet|hamlet|tablet")), + ], ) def test_tablet(self): self.search( "tablet", - [Common(title_or_subtitle=re.compile("tablet")), - Uncommon(title_or_subtitle=re.compile("amulet|hamlet|harlem")) - ] + [ + Common(title_or_subtitle=re.compile("tablet")), + Uncommon(title_or_subtitle=re.compile("amulet|hamlet|harlem")), + ], ) # baseball / basketball def test_baseball(self): self.search( "baseball", - [Common(title=re.compile("baseball")), - Uncommon(title=re.compile("basketball"))] + [ + Common(title=re.compile("baseball")), + Uncommon(title=re.compile("basketball")), + ], ) def test_basketball(self): self.search( "basketball", - [Common(title=re.compile("basketball")), - Uncommon(title=re.compile("baseball"))] + [ + Common(title=re.compile("basketball")), + Uncommon(title=re.compile("baseball")), + ], ) # car / war @@ -2352,15 +2161,15 @@ def test_car(self): "car", # There is a book called "Car Wars", so we can't # completely prohibit 'war' from showing up. - [Common(title=re.compile("car")), - Uncommon(title=re.compile("war"), threshold=0.1)] + [ + Common(title=re.compile("car")), + Uncommon(title=re.compile("war"), threshold=0.1), + ], ) def test_war(self): self.search( - "war", - [Common(title=re.compile("war")), - Uncommon(title=re.compile("car"))] + "war", [Common(title=re.compile("war")), Uncommon(title=re.compile("car"))] ) @@ -2385,21 +2194,14 @@ def test_toronto(self): class TestSeriesMatch(SearchTest): - @known_to_fail def test_dinosaur_cove(self): # NYPL's collection doesn't have any books in this series . - self.search( - "dinosaur cove", - SpecificSeries(series="Dinosaur Cove") - ) + self.search("dinosaur cove", SpecificSeries(series="Dinosaur Cove")) def test_poldi(self): # NYPL's collection only has one book from this series. - self.search( - "Auntie poldi", - FirstMatch(series="Auntie Poldi") - ) + self.search("Auntie poldi", FirstMatch(series="Auntie Poldi")) def test_39_clues(self): # We have many books in this series. @@ -2407,10 +2209,7 @@ def test_39_clues(self): def test_maggie_hope(self): # We have many books in this series. - self.search( - "Maggie hope", - SpecificSeries(series="Maggie Hope", threshold=0.9) - ) + self.search("Maggie hope", SpecificSeries(series="Maggie Hope", threshold=0.9)) def test_game_of_thrones(self): # People often search for the name of the TV show, but the @@ -2420,9 +2219,10 @@ def test_game_of_thrones(self): # find that. self.search( "game of thrones", - [Common(title=re.compile("Game of Thrones", re.I)), - AtLeastOne(series="a song of ice and fire") - ] + [ + Common(title=re.compile("Game of Thrones", re.I)), + AtLeastOne(series="a song of ice and fire"), + ], ) def test_harry_potter(self): @@ -2436,15 +2236,12 @@ def test_harry_potter(self): "Harry potter", SpecificSeries( series="Harry Potter", threshold=0.9, first_must_match=False - ) + ), ) def test_maisie_dobbs(self): # Misspelled proper noun - self.search( - "maise dobbs", - SpecificSeries(series="Maisie Dobbs", threshold=0.5) - ) + self.search("maise dobbs", SpecificSeries(series="Maisie Dobbs", threshold=0.5)) def test_gossip_girl(self): self.search( @@ -2468,33 +2265,33 @@ def test_gossip_girl_misspelled(self): def test_magic(self): # This book isn't in the collection, but the results include other books from # the same series. - self.search( - "Frogs and french kisses", - AtLeastOne(series="Magic in Manhattan") - ) + self.search("Frogs and french kisses", AtLeastOne(series="Magic in Manhattan")) def test_goosebumps(self): self.search( "goosebumps", SpecificSeries( - series="Goosebumps", author="R. L. Stine", - ) + series="Goosebumps", + author="R. L. Stine", + ), ) def test_goosebump_singular(self): self.search( "goosebump", SpecificSeries( - series="Goosebumps", author="R. L. Stine", - ) + series="Goosebumps", + author="R. L. Stine", + ), ) def test_goosebumps_misspelled(self): self.search( "goosebump", SpecificSeries( - series="Goosebumps", author="R. L. Stine", - ) + series="Goosebumps", + author="R. L. Stine", + ), ) def test_severance(self): @@ -2502,34 +2299,22 @@ def test_severance(self): # # Searching for 'severance' alone is going to get title # matches, which is as it should be. - self.search( - "severance trilogy", - AtLeastOne(series="The Severance Trilogy") - ) + self.search("severance trilogy", AtLeastOne(series="The Severance Trilogy")) def test_severance_misspelled(self): # Slightly misspelled - self.search( - "severence trilogy", - AtLeastOne(series="The Severance Trilogy") - ) + self.search("severence trilogy", AtLeastOne(series="The Severance Trilogy")) def test_hunger_games(self): - self.search( - "the hunger games", SpecificSeries(series="The Hunger Games") - ) + self.search("the hunger games", SpecificSeries(series="The Hunger Games")) def test_hunger_games_misspelled(self): - self.search( - "The hinger games", - SpecificSeries(series="The Hunger Games") - ) + self.search("The hinger games", SpecificSeries(series="The Hunger Games")) def test_mockingjay(self): self.search( "The hunger games mockingjay", - [FirstMatch(title="Mockingjay"), - SpecificSeries(series="The Hunger Games")] + [FirstMatch(title="Mockingjay"), SpecificSeries(series="The Hunger Games")], ) def test_i_funny(self): @@ -2546,20 +2331,22 @@ def test_foundation(self): "Isaac asimov foundation", [ FirstMatch(title="Foundation"), - SpecificSeries(series="Foundation", author="Isaac Asimov") - ] + SpecificSeries(series="Foundation", author="Isaac Asimov"), + ], ) def test_dark_tower(self): # There exist two completely unrelated books called "The Dark # Tower"--it's fine for one of those to be the first result. self.search( - "The dark tower", [ + "The dark tower", + [ SpecificSeries( series="The Dark Tower", - author="Stephen King", first_must_match=False + author="Stephen King", + first_must_match=False, ) - ] + ], ) def test_science_comics(self): @@ -2571,8 +2358,9 @@ def test_science_comics(self): # of two genres. self.search( "Science comics", - [FirstMatch(title=re.compile("^science comics")), - ] + [ + FirstMatch(title=re.compile("^science comics")), + ], ) def test_who_is(self): @@ -2591,8 +2379,7 @@ def test_who_was(self): def test_wimpy_kid_misspelled(self): # Series name contains the wrong stopword ('the' vs 'a') self.search( - "dairy of the wimpy kid", - SpecificSeries(series="Diary of a Wimpy Kid") + "dairy of the wimpy kid", SpecificSeries(series="Diary of a Wimpy Kid") ) @@ -2604,8 +2391,8 @@ def test_39_clues_specific_title(self): "39 clues maze of bones", [ FirstMatch(title="The Maze of Bones"), - SpecificSeries(series="the 39 clues") - ] + SpecificSeries(series="the 39 clues"), + ], ) def test_harry_potter_specific_title(self): @@ -2615,10 +2402,11 @@ def test_harry_potter_specific_title(self): # same series, but this doesn't happen much compared to other, # similar tests. We get more partial title matches. self.search( - "chamber of secrets", [ + "chamber of secrets", + [ FirstMatch(title="Harry Potter and the Chamber of Secrets"), - SpecificSeries(series="Harry Potter", threshold=0.2) - ] + SpecificSeries(series="Harry Potter", threshold=0.2), + ], ) @known_to_fail @@ -2632,10 +2420,8 @@ def test_wimpy_kid_specific_title(self): "dairy of the wimpy kid dog days", [ FirstMatch(title="Dog Days", author="Jeff Kinney"), - SpecificSeries( - series="Diary of a Wimpy Kid", author="Jeff Kinney" - ) - ] + SpecificSeries(series="Diary of a Wimpy Kid", author="Jeff Kinney"), + ], ) @known_to_fail @@ -2644,7 +2430,7 @@ def test_foundation_specific_title_by_number(self): # and we don't search it, so there's no way to make this work. self.search( "Isaac Asimov foundation book 1", - FirstMatch(series="Foundation", title="Foundation") + FirstMatch(series="Foundation", title="Foundation"), ) @known_to_fail @@ -2657,7 +2443,7 @@ def test_survivors_specific_title(self): [ Common(series="Survivors"), FirstMatch(title="The Empty City"), - ] + ], ) @@ -2667,7 +2453,7 @@ def test_survivors_specific_title(self): class TestISurvived(VariantSearchTest): # Test different ways of spelling "I Survived" # .series is not set for these books so we check the title. - EVALUATOR = Common(title=re.compile('^i survived ')) + EVALUATOR = Common(title=re.compile("^i survived ")) def test_correct_spelling(self): self.search("i survived") @@ -2691,35 +2477,35 @@ class TestDorkDiaries(VariantSearchTest): EVALUATOR = SpecificAuthor(re.compile("Rachel .* Russell", re.I)) def test_correct_spelling(self): - self.search('dork diaries') + self.search("dork diaries") def test_misspelling_and_number(self): self.search("dork diarys #11") @known_to_fail def test_misspelling_with_punctuation(self): - self.search('doke diaries.') + self.search("doke diaries.") def test_singular(self): self.search("dork diary") def test_misspelling_1(self): - self.search('dork diarys') + self.search("dork diarys") @known_to_fail def test_misspelling_2(self): - self.search('doke dirares') + self.search("doke dirares") @known_to_fail def test_misspelling_3(self): - self.search('doke dares') + self.search("doke dares") @known_to_fail def test_misspelling_4(self): - self.search('doke dires') + self.search("doke dires") def test_misspelling_5(self): - self.search('dork diareis') + self.search("dork diareis") class TestMyLittlePony(VariantSearchTest): @@ -2760,8 +2546,7 @@ def test_language_spanish(self): @known_to_fail def test_author_with_language(self): self.search( - "Pablo escobar spanish", - FirstMatch(author="Pablo Escobar", language="spa") + "Pablo escobar spanish", FirstMatch(author="Pablo Escobar", language="spa") ) def test_gatos(self): @@ -2769,10 +2554,7 @@ def test_gatos(self): # since that's where the word would be used. # # However, 'gatos' also shows up in English, e.g. in place names. - self.search( - "gatos", - Common(language="spa", threshold=0.7) - ) + self.search("gatos", Common(language="spa", threshold=0.7)) class TestAwardSearch(SearchTest): @@ -2788,29 +2570,22 @@ def test_hugo(self): Common(summary=re.compile("hugo award")), Uncommon(author="Victor Hugo"), Uncommon(series=re.compile("hugo")), - ] + ], ) def test_nebula(self): - self.search( - "nebula award", - Common(summary=re.compile("nebula award")) - ) + self.search("nebula award", Common(summary=re.compile("nebula award"))) def test_nebula_no_award(self): # This one does great -- the award is the most common # use of the word "nebula". - self.search( - "nebula", - Common(summary=re.compile("nebula award")) - ) + self.search("nebula", Common(summary=re.compile("nebula award"))) def test_world_fantasy(self): # This award contains the name of a genre. self.search( "world fantasy award", - Common(summary=re.compile("world fantasy award"), - first_must_match=False) + Common(summary=re.compile("world fantasy award"), first_must_match=False), ) @known_to_fail @@ -2819,25 +2594,23 @@ def test_tiptree_award(self): # books -- we want the award winners. self.search( "tiptree award", - [Common(summary=re.compile("tiptree award")), - Uncommon(author=re.compile("james tiptree"))], + [ + Common(summary=re.compile("tiptree award")), + Uncommon(author=re.compile("james tiptree")), + ], ) @known_to_fail def test_newberry(self): # Tends to get author matches. - self.search( - "newbery", - Common(summary=re.compile("newbery medal")) - ) + self.search("newbery", Common(summary=re.compile("newbery medal"))) @known_to_fail def test_man_booker(self): # This gets author and title matches. self.search( "man booker prize", - Common(summary=re.compile("man booker prize"), - first_must_match=False) + Common(summary=re.compile("man booker prize"), first_must_match=False), ) def test_award_winning(self): @@ -2849,7 +2622,7 @@ def test_award_winning(self): [ Common(summary=re.compile("award"), threshold=0.5), Uncommon(title=re.compile("award"), threshold=0.5), - ] + ], ) @known_to_fail @@ -2863,8 +2636,8 @@ def test_staff_picks(self): "staff picks", [ Uncommon(author=re.compile("(staff|picks)")), - Uncommon(title=re.compile("(staff|picks)")) - ] + Uncommon(title=re.compile("(staff|picks)")), + ], ) @@ -2877,7 +2650,7 @@ def test_3_little_pigs(self): [ AtLeastOne(title=re.compile("three little pigs")), Common(title=re.compile("pig")), - ] + ], ) @known_to_fail @@ -2889,54 +2662,36 @@ def test_3_little_pigs_more_precise(self): FirstMatch(title="Three Little Pigs"), ) - def test_batman(self): - self.search( - "batman book", - Common(title=re.compile("batman")) - ) + self.search("batman book", Common(title=re.compile("batman"))) @known_to_fail def test_batman_two_words(self): # Patron is searching for 'batman' but treats it as two words. - self.search( - "bat man book", - Common(title=re.compile("batman")) - ) + self.search("bat man book", Common(title=re.compile("batman"))) def test_christian_grey(self): # This search uses a character name to stand in for a series. self.search( - "christian grey", - FirstMatch(author=re.compile("E.\s*L.\s*James", re.I)) + "christian grey", FirstMatch(author=re.compile("E.\s*L.\s*James", re.I)) ) def test_spiderman_hyphenated(self): - self.search( - "spider-man", Common(title=re.compile("spider-man")) - ) + self.search("spider-man", Common(title=re.compile("spider-man"))) @known_to_fail def test_spiderman_one_word(self): # NOTE: There are some Spider-Man titles but not as many as # with the hyphen. - self.search( - "spiderman", Common(title=re.compile("spider-man")) - ) + self.search("spiderman", Common(title=re.compile("spider-man"))) @known_to_fail def test_spiderman_run_on(self): # NOTE: This gets no results at all. - self.search( - "spidermanbook", Common(title=re.compile("spider-man")) - ) - + self.search("spidermanbook", Common(title=re.compile("spider-man"))) def test_teen_titans(self): - self.search( - "teen titans", - Common(title=re.compile("^teen titans")), limit=5 - ) + self.search("teen titans", Common(title=re.compile("^teen titans")), limit=5) @known_to_fail def test_teen_titans_girls(self): @@ -2945,8 +2700,7 @@ def test_teen_titans_girls(self): # _similar_ results to 'teen titans' and not go off # on tangents because of the 'girls' part. self.search( - "teen titans girls", - Common(title=re.compile("^teen titans")), limit=5 + "teen titans girls", Common(title=re.compile("^teen titans")), limit=5 ) def test_thrawn(self): @@ -2959,10 +2713,11 @@ def test_thrawn(self): [ FirstMatch(title="Thrawn"), Common( - author="Timothy Zahn", series=re.compile("star wars", re.I), - threshold=0.9 + author="Timothy Zahn", + series=re.compile("star wars", re.I), + threshold=0.9, ), - ] + ], ) @@ -2972,11 +2727,11 @@ class TestAgeRangeRestriction(SearchTest): def all_children(self, q): # Verify that this search finds nothing but books for children. - self.search(q, Common(audience='Children', threshold=1)) + self.search(q, Common(audience="Children", threshold=1)) def mostly_adult(self, q): # Verify that this search finds mostly books for grown-ups. - self.search(q, Common(audience='Adult', first_must_match=False)) + self.search(q, Common(audience="Adult", first_must_match=False)) def test_black(self): self.all_children("black age 3-5") @@ -2994,26 +2749,20 @@ def test_panda(self): def test_chapter_books(self): # Chapter books are a book format aimed at a specific # age range. - self.search( - "chapter books", Common(target_age=(6, 10)) - ) + self.search("chapter books", Common(target_age=(6, 10))) def test_chapter_books_misspelled_1(self): # NOTE: We don't do fuzzy matching on things that would become # filter terms. When this works, it's because of fuzzy title # matches and description matches. - self.search( - "chapter bookd", Common(target_age=(6, 10)) - ) + self.search("chapter bookd", Common(target_age=(6, 10))) @known_to_fail def test_chapter_books_misspelled_2(self): # This fails for a similar reason as misspelled_1, though it # actually does a little better -- only the first result is # bad. - self.search( - "chaptr books", Common(target_age=(6, 10)) - ) + self.search("chaptr books", Common(target_age=(6, 10))) @known_to_fail def test_grade_and_subject(self): @@ -3022,10 +2771,7 @@ def test_grade_and_subject(self): # digits. self.search( "Seventh grade science", - [ - Common(target_age=(12, 13)), - Common(genre="Science") - ] + [Common(target_age=(12, 13)), Common(genre="Science")], ) @@ -3036,10 +2782,7 @@ def test_black_and_the_blue(self): # This is a real book title that is almost entirely stopwords. # Putting in a few words of the title will find that specific # title even if most of the words are stopwords. - self.search( - "the black and", - FirstMatch(title="The Black and the Blue") - ) + self.search("the black and", FirstMatch(title="The Black and the Blue")) @known_to_fail def test_the_real(self): @@ -3049,18 +2792,14 @@ def test_the_real(self): # NOTE: These results are very good, but the first result is # "Tiger: The Real Story", which is a subtitle match. A title match # should be better. - self.search( - "the real", - Common(title=re.compile("The Real", re.I)) - ) + self.search("the real", Common(title=re.compile("The Real", re.I))) def test_nothing_but_stopwords(self): # If we always stripped stopwords, this would match nothing, # but we get the best results we can manage -- e.g. # "History of Florence and of the Affairs of Italy" self.search( - "and of the", - Common(title_or_subtitle=re.compile("and of the", re.I)) + "and of the", Common(title_or_subtitle=re.compile("and of the", re.I)) ) @@ -3070,19 +2809,17 @@ def test_nothing_but_stopwords(self): index = ExternalSearchIndex(_db) SearchTest.searcher = Searcher(library, index) + def teardown_module(): failures = SearchTest.expected_failures if failures: - logging.info( - "%d tests were expected to fail, and did.", len(failures) - ) + logging.info("%d tests were expected to fail, and did.", len(failures)) successes = SearchTest.unexpected_successes if successes: - logging.info( - "%d tests passed unexepectedly:", len(successes) - ) + logging.info("%d tests passed unexepectedly:", len(successes)) for success in successes: logging.info( "Line #%d: %s", - success.__code__.co_firstlineno, success.__name__, + success.__code__.co_firstlineno, + success.__name__, ) diff --git a/migration/20150929-1-set_delivery_mechanism_for_3m_books.py b/migration/20150929-1-set_delivery_mechanism_for_3m_books.py deleted file mode 100644 index 6c92f4cbe2..0000000000 --- a/migration/20150929-1-set_delivery_mechanism_for_3m_books.py +++ /dev/null @@ -1,43 +0,0 @@ -#!/usr/bin/env python -"""Look up and set the delivery mechanism for all 3M books.""" -import os -import sys -bin_dir = os.path.split(__file__)[0] -package_dir = os.path.join(bin_dir, "..") -sys.path.append(os.path.abspath(package_dir)) -from core.monitor import IdentifierSweepMonitor -from core.model import ( - Identifier -) -from core.opds_import import SimplifiedOPDSLookup -from threem import ThreeMAPI -from core.scripts import RunMonitorScript - -class SetDeliveryMechanismMonitor(IdentifierSweepMonitor): - - def __init__(self, _db, interval_seconds=None): - super(SetDeliveryMechanismMonitor, self).__init__( - _db, "20150929 migration - Set delivery mechanism for 3M books", - interval_seconds, batch_size=10) - self.api = ThreeMAPI(_db) - self.metadata_client = SimplifiedOPDSLookup( - "http://metadata.alpha.librarysimplified.org/" - ) - - def identifier_query(self): - return self._db.query(Identifier).filter( - Identifier.type==Identifier.THREEM_ID - ) - - def process_identifier(self, identifier): - metadata = self.api.bibliographic_lookup(identifier) - license_pool = identifier.licensed_through - for format in metadata.formats: - print "%s: %s - %s" % (identifier.identifier, format.content_type, format.drm_scheme) - mech = license_pool.set_delivery_mechanism( - format.content_type, - format.drm_scheme, - format.link - ) - -RunMonitorScript(SetDeliveryMechanismMonitor).run() diff --git a/migration/20150929-2-set_delivery_mechanism_for_axis_books.py b/migration/20150929-2-set_delivery_mechanism_for_axis_books.py deleted file mode 100644 index c69cc229db..0000000000 --- a/migration/20150929-2-set_delivery_mechanism_for_axis_books.py +++ /dev/null @@ -1,48 +0,0 @@ -#!/usr/bin/env python -"""Look up and set the delivery mechanism for all 3M books.""" -import os -import sys -bin_dir = os.path.split(__file__)[0] -package_dir = os.path.join(bin_dir, "..") -sys.path.append(os.path.abspath(package_dir)) -from core.monitor import IdentifierSweepMonitor -from core.model import ( - Identifier -) -from core.opds_import import SimplifiedOPDSLookup -from axis import Axis360API, BibliographicParser -from core.scripts import RunMonitorScript - -class SetDeliveryMechanismMonitor(IdentifierSweepMonitor): - - def __init__(self, _db, interval_seconds=None): - super(SetDeliveryMechanismMonitor, self).__init__( - _db, "20150929 migration - Set delivery mechanism for Axis books", - interval_seconds, batch_size=10) - self.api = Axis360API(_db) - - def identifier_query(self): - return self._db.query(Identifier).filter( - Identifier.type==Identifier.AXIS_360_ID - ) - - def process_identifier(self, identifier): - - availability = self.api.availability(title_ids=[identifier.identifier]) - status_code = availability.status_code - content = availability.content - result = list(BibliographicParser().process_all(content)) - if len(result) == 1: - [(metadata, circulation)] = result - license_pool = identifier.licensed_through - for format in metadata.formats: - print "%s: %s - %s" % (identifier.identifier, format.content_type, format.drm_scheme) - mech = license_pool.set_delivery_mechanism( - format.content_type, - format.drm_scheme, - format.link - ) - else: - print "Book not in collection: %s" % identifier.identifier - -RunMonitorScript(SetDeliveryMechanismMonitor).run() diff --git a/migration/20150929-3-set_delivery_mechanism_for_overdrive_books.py b/migration/20150929-3-set_delivery_mechanism_for_overdrive_books.py deleted file mode 100644 index a08363548a..0000000000 --- a/migration/20150929-3-set_delivery_mechanism_for_overdrive_books.py +++ /dev/null @@ -1,44 +0,0 @@ -#!/usr/bin/env python -"""Look up and set the delivery mechanism for all 3M books.""" -import os -import sys -bin_dir = os.path.split(__file__)[0] -package_dir = os.path.join(bin_dir, "..") -sys.path.append(os.path.abspath(package_dir)) -from core.monitor import IdentifierSweepMonitor -from core.model import ( - Identifier -) -from core.opds_import import SimplifiedOPDSLookup -from overdrive import OverdriveAPI, OverdriveRepresentationExtractor -from core.scripts import RunMonitorScript - -class SetDeliveryMechanismMonitor(IdentifierSweepMonitor): - - def __init__(self, _db, interval_seconds=None): - super(SetDeliveryMechanismMonitor, self).__init__( - _db, "20150929 migration - Set delivery mechanism for Overdrive books", - interval_seconds, batch_size=10) - self.api = OverdriveAPI(_db) - - def identifier_query(self): - return self._db.query(Identifier).filter( - Identifier.type==Identifier.OVERDRIVE_ID - ) - - def process_identifier(self, identifier): - - content = self.api.metadata_lookup(identifier) - metadata = OverdriveRepresentationExtractor.book_info_to_metadata(content) - if not metadata: - return - license_pool = identifier.licensed_through - for format in metadata.formats: - print "%s: %s - %s" % (identifier.identifier, format.content_type, format.drm_scheme) - mech = license_pool.set_delivery_mechanism( - format.content_type, - format.drm_scheme, - format.link - ) - -RunMonitorScript(SetDeliveryMechanismMonitor).run() diff --git a/migration/20150929-4-set_delivery_mechanism_for_gutenberg_books.py b/migration/20150929-4-set_delivery_mechanism_for_gutenberg_books.py deleted file mode 100644 index 636e0d3d97..0000000000 --- a/migration/20150929-4-set_delivery_mechanism_for_gutenberg_books.py +++ /dev/null @@ -1,50 +0,0 @@ -#!/usr/bin/env python -"""Look up and set the delivery mechanism for all 3M books.""" -import os -import sys -bin_dir = os.path.split(__file__)[0] -package_dir = os.path.join(bin_dir, "..") -sys.path.append(os.path.abspath(package_dir)) -from core.monitor import IdentifierSweepMonitor -from core.model import ( - Identifier, - Representation, - DeliveryMechanism, -) -from core.opds_import import SimplifiedOPDSLookup -from core.scripts import RunMonitorScript - -class SetDeliveryMechanismMonitor(IdentifierSweepMonitor): - - def __init__(self, _db, interval_seconds=None): - super(SetDeliveryMechanismMonitor, self).__init__( - _db, "20150929 migration - Set delivery mechanism for Gutenberg books", - interval_seconds, batch_size=10) - - def identifier_query(self): - return self._db.query(Identifier).filter( - Identifier.type==Identifier.GUTENBERG_ID - ) - - def process_identifier(self, identifier): - - license_pool = identifier.licensed_through - if not license_pool: - print "No license pool for %s!" % identifier.identifier - return - edition = license_pool.edition - if edition: - best = edition.best_open_access_link - if best: - print edition.id, edition.title, best.url - edition.license_pool.set_delivery_mechanism( - Representation.EPUB_MEDIA_TYPE, DeliveryMechanism.NO_DRM, - best - ) - else: - print "Edition but no link for %s/%s!" % ( - identifier.identifier, edition.title) - else: - print "No edition for %s!" % identifier.identifier - -RunMonitorScript(SetDeliveryMechanismMonitor).run() diff --git a/migration/20150930-fix-empty-author-strings.py b/migration/20150930-fix-empty-author-strings.py deleted file mode 100644 index bcdb680395..0000000000 --- a/migration/20150930-fix-empty-author-strings.py +++ /dev/null @@ -1,49 +0,0 @@ -#!/usr/bin/env python -"""Try to fix the contributors for books that currently have none. -""" - -from pdb import set_trace -import os -import sys -bin_dir = os.path.split(__file__)[0] -package_dir = os.path.join(bin_dir, "..") -sys.path.append(os.path.abspath(package_dir)) -from core.model import ( - production_session, - DataSource, - Work, - Edition, -) -from overdrive import ( - OverdriveAPI, - OverdriveRepresentationExtractor -) -from threem import ThreeMAPI -from core.opds_import import SimplifiedOPDSLookup -lookup = SimplifiedOPDSLookup("http://metadata.alpha.librarysimplified.org/") - -_db = production_session() -overdrive = OverdriveAPI(_db) -threem = ThreeMAPI(_db) - -q = _db.query(Edition).join(Edition.data_source).filter(DataSource.name.in_([DataSource.OVERDRIVE])).filter(Edition.author=='') -print "Fixing %s books." % q.count() -for edition in q: - if edition.data_source.name==DataSource.OVERDRIVE: - data = overdrive.metadata_lookup(edition.primary_identifier) - metadata = OverdriveRepresentationExtractor.book_info_to_metadata(data) - else: - metadata = threem.bibliographic_lookup(edition.primary_identifier) - metadata.update_contributions(_db, edition, metadata_client=lookup, - replace_contributions=True) - if edition.work: - edition.work.calculate_presentation() - else: - edition.calculate_presentation() - - for c in edition.contributions: - print "%s = %s (%s)" % ( - c.role, c.contributor.display_name, c.contributor.name - ) - print edition.author, edition.sort_author - _db.commit() diff --git a/migration/20150930-fix-wikidata-ids-treated-as-names.py b/migration/20150930-fix-wikidata-ids-treated-as-names.py deleted file mode 100644 index efd90257a4..0000000000 --- a/migration/20150930-fix-wikidata-ids-treated-as-names.py +++ /dev/null @@ -1,42 +0,0 @@ -#!/usr/bin/env python -"""Recalculate the display information about all contributors -mistakenly given Wikidata IDs as 'names'. -""" - -from pdb import set_trace -import os -import sys -bin_dir = os.path.split(__file__)[0] -package_dir = os.path.join(bin_dir, "..") -sys.path.append(os.path.abspath(package_dir)) - -from core.model import ( - production_session, - Contributor, -) - -_db = production_session() -from sqlalchemy.sql import text -contributors = _db.query(Contributor).filter( - text("contributors.display_name ~ '^Q[0-9]'") -).order_by(Contributor.id) -print contributors.count() -for contributor in contributors: - display_name, family_name = contributor.default_names() - print "%s/%s: %s => %s, %s => %s" % ( - contributor.id, - contributor.name, - contributor.display_name, display_name, - contributor.family_name, family_name - ) - contributor.display_name = display_name - contributor.wikipedia_name = None - contributor.family_name = family_name - for contribution in contributor.contributions: - edition = contribution.edition - if edition.work: - edition.work.calculate_presentation() - else: - edition.calculate_presentation() - _db.commit() - diff --git a/migration/20160201-add-coverage-records-for-metadata-wrangler.sql b/migration/20160201-add-coverage-records-for-metadata-wrangler.sql deleted file mode 100644 index b6b86f85e8..0000000000 --- a/migration/20160201-add-coverage-records-for-metadata-wrangler.sql +++ /dev/null @@ -1 +0,0 @@ -insert into coveragerecords (identifier_id, data_source_id, date, exception) select i.id, datasources.id, w.presentation_ready_attempt, NULL from datasources, identifiers i join editions e on e.primary_identifier_id=i.id join works w on e.work_id=w.id where w.presentation_ready = true and datasources.name='Library Simplified metadata wrangler' and e.title is not null and e.sort_author is not null; diff --git a/migration/20160331-add-alias-for-new-index.py b/migration/20160331-add-alias-for-new-index.py deleted file mode 100755 index 2daace2021..0000000000 --- a/migration/20160331-add-alias-for-new-index.py +++ /dev/null @@ -1,43 +0,0 @@ -#!/usr/bin/env python -""" -Add an alias to a new search index. - -Process for creating and switching to the new index: -Deploy the new code. -Run this migration, which creates a new index ("-v2") and an alias ("-current") -based on the current index name. -Run `bin/repair/search_index `. -Change the config file to point to the alias instead of the old index name. -Restart the application. - -The old index can be dropped when we're confident the new index works. -""" -import os -import sys -bin_dir = os.path.split(__file__)[0] -package_dir = os.path.join(bin_dir, "..") -sys.path.append(os.path.abspath(package_dir)) -from core.scripts import Script -from core.external_search import ExternalSearchIndex -from api.config import Configuration - - -class AddSearchIndexAlias(Script): - - def do_run(self): - integration = Configuration.integration( - Configuration.ELASTICSEARCH_INTEGRATION, - ) - old_index = integration.get( - Configuration.ELASTICSEARCH_INDEX_KEY, - ) - new_index = old_index + "-v2" - alias = old_index + "-current" - - search_index_client = ExternalSearchIndex(works_index=new_index) - search_index_client.indices.put_alias( - index=new_index, - name=alias - ) - -AddSearchIndexAlias().run() diff --git a/migration/20160513-3m-books-are-never-open-access.sql b/migration/20160513-3m-books-are-never-open-access.sql deleted file mode 100644 index 74cda61feb..0000000000 --- a/migration/20160513-3m-books-are-never-open-access.sql +++ /dev/null @@ -1 +0,0 @@ -update licensepools set open_access=false where open_access is null and data_source_id=(select id from datasources where name='3M'); diff --git a/migration/20160620-remove-license-pools-from-series-cachedfeeds.sql b/migration/20160620-remove-license-pools-from-series-cachedfeeds.sql deleted file mode 100644 index ae0a8a8187..0000000000 --- a/migration/20160620-remove-license-pools-from-series-cachedfeeds.sql +++ /dev/null @@ -1 +0,0 @@ -update cachedfeeds set license_pool_id=NULL where type='series'; diff --git a/migration/20160722-fix-missing-hyperlinks.py b/migration/20160722-fix-missing-hyperlinks.py deleted file mode 100755 index 2388d4d651..0000000000 --- a/migration/20160722-fix-missing-hyperlinks.py +++ /dev/null @@ -1,77 +0,0 @@ -#!/usr/bin/env python -"""Find open-access LicensePools that do not have a Hyperlink -with an open access rel. If they have a delivery mechanism with -a resource, create a Hyperlink for the resource and identifier. -""" - -import os -import sys -import logging -from pdb import set_trace -bin_dir = os.path.split(__file__)[0] -package_dir = os.path.join(bin_dir, "..") -sys.path.append(os.path.abspath(package_dir)) - -from core.model import ( - production_session, - Hyperlink, - LicensePool, - Resource, - get_one_or_create, -) - -_db = production_session() - -open_access_pools = _db.query(LicensePool).filter(LicensePool.open_access==True) - -pools_with_open_access_links = _db.query(LicensePool).join( - Hyperlink, - Hyperlink.identifier_id==LicensePool.identifier_id - ).filter( - Hyperlink.rel==Hyperlink.OPEN_ACCESS_DOWNLOAD - ).filter( - LicensePool.open_access==True - ) - -pool_ids_with_open_access_links = [pool.id for pool in pools_with_open_access_links] - -open_access_pools_without_open_access_links = open_access_pools.filter(~LicensePool.id.in_(pool_ids_with_open_access_links)) - -print "Found %d open access pools without open access links" % open_access_pools_without_open_access_links.count() - -fixed = 0 -no_identifier = 0 -no_resource = 0 - -for pool in open_access_pools_without_open_access_links: - - if not pool.identifier: - no_identifier += 1 - continue - - # Do we have a resource for this pool? - if pool.delivery_mechanisms and pool.delivery_mechanisms[0].resource: - resource = pool.delivery_mechanisms[0].resource - identifier = pool.identifier - - link, is_new = get_one_or_create( - _db, Hyperlink, identifier=identifier, - resource=resource, license_pool=pool, - data_source=pool.data_source, rel=Hyperlink.OPEN_ACCESS_DOWNLOAD, - ) - - if not is_new: - print "Expected to create a new open access link for pool %s but one already existed" % pool - else: - fixed += 1 - pool.presentation_edition.set_open_access_link() - - if not fixed % 20: - _db.commit() - else: - no_resource += 1 - -_db.commit() -print "Fixed %d pools" % fixed -print "%d pools with no resource were not fixed" % no_resource -print "%d pools with no identifier were not fixed" % no_identifier diff --git a/migration/20161102-adobe-id-is-delegated-patron-identifier.py b/migration/20161102-adobe-id-is-delegated-patron-identifier.py deleted file mode 100755 index dc88aff631..0000000000 --- a/migration/20161102-adobe-id-is-delegated-patron-identifier.py +++ /dev/null @@ -1,45 +0,0 @@ -#!/usr/bin/env python -"""For every patron with a credential containing an Adobe ID, make -sure they also get a DelegatedPatronIdentifier containing the same -Adobe ID. This makes sure that they don't suddenly change Adobe IDs -when they start using a client that employs the new JWT-based authdata -system. -""" - -import os -import sys -from pdb import set_trace -bin_dir = os.path.split(__file__)[0] -package_dir = os.path.join(bin_dir, "..") -sys.path.append(os.path.abspath(package_dir)) -from core.model import ( - production_session, - Patron -) -from api.adobe_vendor_id import AuthdataUtility - -_db = production_session() -authdata = AuthdataUtility.from_config() -if not authdata: - print "Adobe IDs not configured, doing nothing." - -count = 0 -qu = _db.query(Patron) -print "Processing %d patrons." % qu.count() -for patron in qu: - credential, delegated_identifier = authdata.migrate_adobe_id(patron) - count += 1 - if not (count % 100): - print count - _db.commit() - if credential is None or delegated_identifier is None: - # This patron did not have an Adobe ID in the first place. - # Do nothing. - continue - output = "%s -> %s -> %s" % ( - patron.authorization_identifier, - credential.credential, - delegated_identifier.delegated_identifier - ) - print output -_db.commit() diff --git a/migration/20170131-remove-empty-cachedfeeds.sql b/migration/20170131-remove-empty-cachedfeeds.sql deleted file mode 100644 index c856ebeacb..0000000000 --- a/migration/20170131-remove-empty-cachedfeeds.sql +++ /dev/null @@ -1 +0,0 @@ -DELETE FROM cachedfeeds WHERE content IS NULL and timestamp IS NULL; diff --git a/migration/20170224-2-copy-collection-configuration-into-database.py b/migration/20170224-2-copy-collection-configuration-into-database.py deleted file mode 100755 index abfe1efd8c..0000000000 --- a/migration/20170224-2-copy-collection-configuration-into-database.py +++ /dev/null @@ -1,158 +0,0 @@ -#!/usr/bin/env python -"""Copy the collection configuration information from the JSON configuration -into Collection objects. -""" - -import os -import sys -import uuid -from pdb import set_trace -bin_dir = os.path.split(__file__)[0] -package_dir = os.path.join(bin_dir, "..") -sys.path.append(os.path.abspath(package_dir)) -from api.config import Configuration -from core.model import ( - get_one_or_create, - production_session, - DataSource, - Library, - Collection, -) - -# We're going directly against the configuration object, rather than -# using shortcuts like OverdriveAPI.from_environment, because this -# script may be running against a branch where the implementation of -# those shortcuts goes against the database. - -_db = production_session() - -def copy_library_registry_information(_db, library): - config = Configuration.integration("Adobe Vendor ID") - if not config: - print("No Adobe Vendor ID configuration, not setting short name or secret.") - return - library.short_name = config.get("library_short_name") - library.library_registry_short_name = config.get("library_short_name") - library.library_registry_shared_secret = config.get("authdata_secret") - - -def convert_overdrive(_db, library): - config = Configuration.integration('Overdrive') - if not config: - print("No Overdrive configuration, not creating a Collection for it.") - return - print("Creating Collection object for Overdrive collection.") - username = config.get('client_key') - password = config.get('client_secret') - library_id = config.get('library_id') - website_id = config.get('website_id') - - collection, ignore = get_one_or_create( - _db, Collection, - protocol=Collection.OVERDRIVE, - name="Overdrive" - ) - library.collections.append(collection) - collection.external_integration.username = username - collection.external_integration.password = password - collection.external_account_id = library_id - collection.external_integration.set_setting("website_id", website_id) - -def convert_bibliotheca(_db, library): - config = Configuration.integration('3M') - if not config: - print("No Bibliotheca configuration, not creating a Collection for it.") - return - print("Creating Collection object for Bibliotheca collection.") - username = config.get('account_id') - password = config.get('account_key') - library_id = config.get('library_id') - collection, ignore = get_one_or_create( - _db, Collection, - protocol=Collection.BIBLIOTHECA, - name="Bibliotheca" - ) - library.collections.append(collection) - collection.external_integration.username = username - collection.external_integration.password = password - collection.external_account_id = library_id - -def convert_axis(_db, library): - config = Configuration.integration('Axis 360') - if not config: - print("No Axis 360 configuration, not creating a Collection for it.") - return - print("Creating Collection object for Axis 360 collection.") - username = config.get('username') - password = config.get('password') - library_id = config.get('library_id') - # This is not technically a URL, it's "production" or "staging", - # but it's converted into a URL internally. - url = config.get('server') - collection, ignore = get_one_or_create( - _db, Collection, - protocol=Collection.AXIS_360, - name="Axis 360" - ) - library.collections.append(collection) - collection.external_integration.username = username - collection.external_integration.password = password - collection.external_account_id = library_id - collection.external_integration.url = url - -def convert_one_click(_db, library): - config = Configuration.integration('OneClick') - if not config: - print("No OneClick configuration, not creating a Collection for it.") - return - print("Creating Collection object for OneClick collection.") - basic_token = config.get('basic_token') - library_id = config.get('library_id') - url = config.get('url') - ebook_loan_length = config.get('ebook_loan_length') - eaudio_loan_length = config.get('eaudio_loan_length') - - collection, ignore = get_one_or_create( - _db, Collection, - protocol=Collection.ONECLICK, - name="OneClick" - ) - library.collections.append(collection) - collection.external_integration.password = basic_token - collection.external_account_id = library_id - collection.external_integration.url = url - collection.external_integration.set_setting("ebook_loan_length", ebook_loan_length) - collection.external_integration.set_setting("eaudio_loan_length", eaudio_loan_length) - -def convert_content_server(_db, library): - config = Configuration.integration("Content Server") - if not config: - print("No content server configuration, not creating a Collection for it.") - return - url = config.get('url') - collection, ignore = get_one_or_create( - _db, Collection, - protocol=Collection.OPDS_IMPORT, - name="Open Access Content Server" - ) - collection.external_integration.setting("data_source").value = DataSource.OA_CONTENT_SERVER - library.collections.append(collection) - -# This is the point in the migration where we first create a Library -# for this system. -library = get_one_or_create( - _db, Library, - create_method_kwargs=dict( - name="Default Library", - short_name="default", - uuid=str(uuid.uuid4()) - ) -) - -copy_library_registry_information(_db, library) -convert_overdrive(_db, library) -convert_bibliotheca(_db, library) -convert_axis(_db, library) -convert_one_click(_db, library) -convert_content_server(_db, library) -_db.commit() diff --git a/migration/20170303-4-use-current-alias.py b/migration/20170303-4-use-current-alias.py deleted file mode 100644 index 87847f7887..0000000000 --- a/migration/20170303-4-use-current-alias.py +++ /dev/null @@ -1,100 +0,0 @@ -#!/usr/bin/env python -"""Create a -current alias for the index being used""" - -import os -import sys -from pdb import set_trace -bin_dir = os.path.split(__file__)[0] -package_dir = os.path.join(bin_dir, "..") -sys.path.append(os.path.abspath(package_dir)) - -from api.config import Configuration as C -from core.external_search import ExternalSearchIndex - -C.load() -config_index = C.integration(C.ELASTICSEARCH_INTEGRATION).get(C.ELASTICSEARCH_INDEX_KEY) -if not config_index: - print "No action taken. Elasticsearch not configured." - sys.exit() - -search = ExternalSearchIndex() -update_required_text = ( - "\n\tConfiguration update required for given alias \"%s\".\n" - "\t============================================\n" - "\tReplace Elasticsearch configuration \"works_index\" value with alias.\n" - "\te.g. \"works_index\" : \"%s\" ===> \"works_index\" : \"%s\"\n\n" -) - -misplaced_alias_text = ( - "\n\tExpected Elasticsearch alias \"%s\" is being used with\n" - "\tindex \"%s\" instead of configured index \"%s\".\n" - "\t============================================\n" -) - - -alias_not_used = search.works_alias == search.works_index -if config_index == search.works_index: - # The configuration doesn't have the alias in its configuration. - current_alias = search.base_index_name(config_index)+search.CURRENT_ALIAS_SUFFIX - - if alias_not_used: - # The current_alias wasn't set during initialization, indicating - # that it's connected to a different alias. (If it didn't exist - # already, it would have been created on this search client.) - indices = ','.join(search.indices.get_alias(name=current_alias).keys()) - manual_steps_text = ( - "\tMANUAL STEPS:\n" - "\t 1. Replace Elasticsearch configuration \"works_index\" value with alias.\n" - "\t e.g. \"works_index\" : \"%s\" ===> \"works_index\" : \"%s\"\n\n" - "\t 2. Confirm alias \"%s\" is pointing to the preferred index.\n\n" - ) - print ( - (misplaced_alias_text + manual_steps_text) % - (current_alias, indices, config_index, config_index, current_alias, current_alias)) - else: - # Initialization found or created an alias, but the configuration - # file itself needs to be updated. - print ( - update_required_text % - (search.works_alias, config_index, search.works_alias)) - -elif 'error' not in search.indices.get_alias(name=config_index, ignore=[404]): - # The configuration has an alias instead of an index. - if config_index == search.works_alias: - print "No action needed. Elasticsearch alias '%s' is properly named and configured." % config_index - print "Works are being uploaded to Elasticsearch index '%s'" % search.works_index - else: - # The alias doesn't use the naming convention we expect. Try to create one - # that does. - index = search.indices.get_alias(name=config_index).keys()[0] - current_alias = search.base_index_name(index)+search.CURRENT_ALIAS_SUFFIX - current_alias_index = ','.join(search.indices.get_alias(name=current_alias).keys()) - - if (current_alias_index != search.works_index or alias_not_used): - # An alias with the proper naming convention exists elsewhere. - # It will have to be manually removed or replaced. - manual_steps_text = ( - "\tEITHER:\n\t Remove -current alias \"%s\" from index \"%s\". " - "\n\t Place it on \"%s\" instead.\n" - "\tOR:\n\t Use -current alias \"%s\" in the configuration file" - "\n\t if \"%s\" is the preferred index.\n\n" - ) - print ( - (misplaced_alias_text + manual_steps_text) % - (current_alias, current_alias_index, index, current_alias, - current_alias_index, index, current_alias, current_alias_index)) - else: - # ExternalSearchIndex.setup_current_alias() already does this, - # so it shouldn't need to happen here. - response = search.indices.put_alias( - index=search.works_index, name=current_alias - ) - print (update_required_text % (current_alias, config_index, current_alias)) -else: - # A catchall just in case. This shouldn't happen. - print "\n\tSomething unexpected happened. Weird!" - print "\t - Given index (in configuration file): \t\"%s\"" % config_index - print "\t - Elasticsearch index (in use): \t\t\"%s\"" % search.works_index - print "\t - Elasticsearch alias (in use): \t\t\"%s\"" % search.works_alias - print "\n\tThe configured index should be manually set to the Elasticsearch alias\n" - print "\tand this migration should be run again." diff --git a/migration/20170713-11-move-configuration-links-into-db.py b/migration/20170713-11-move-configuration-links-into-db.py deleted file mode 100755 index 2ecfafccb0..0000000000 --- a/migration/20170713-11-move-configuration-links-into-db.py +++ /dev/null @@ -1,36 +0,0 @@ -#!/usr/bin/env python -"""Move links from the Configuration file into the database as ConfigurationSettings -for the default Library. -""" -import os -import sys - -bin_dir = os.path.split(__file__)[0] -package_dir = os.path.join(bin_dir, "..") -sys.path.append(os.path.abspath(package_dir)) - -from core.model import ( - ConfigurationSetting, - Library, - get_one_or_create, - production_session, -) -from api.config import Configuration - -Configuration.load() -_db = production_session() -library = Library.default(_db) - -for rel, value in ( - ("terms-of-service", Configuration.get('links', {}).get('terms_of_service', None)), - ("privacy-policy", Configuration.get('links', {}).get('privacy_policy', None)), - ("copyright", Configuration.get('links', {}).get('copyright', None)), - ("about", Configuration.get('links', {}).get('about', None)), - ("license", Configuration.get('links', {}).get('license', None)), -): - if value: - ConfigurationSetting.for_library(rel, library).value = value - -_db.commit() -_db.close() - diff --git a/migration/20170713-13-move-authentication-configuration-to-external-integrations.py b/migration/20170713-13-move-authentication-configuration-to-external-integrations.py deleted file mode 100755 index 1c101ed96e..0000000000 --- a/migration/20170713-13-move-authentication-configuration-to-external-integrations.py +++ /dev/null @@ -1,152 +0,0 @@ -#!/usr/bin/env python -"""Move authentication integration details from the Configuration file -into the database as ExternalIntegrations -""" -import os -import sys -import json -import logging - -bin_dir = os.path.split(__file__)[0] -package_dir = os.path.join(bin_dir, "..") -sys.path.append(os.path.abspath(package_dir)) - -from core.model import ( - ConfigurationSetting, - ExternalIntegration, - get_one_or_create, - production_session, - Library, - create, -) - -from api.config import Configuration -from api.millenium_patron import MilleniumPatronAPI -from api.sip import SIP2AuthenticationProvider -from api.authenticator import ( - BasicAuthenticationProvider, - OAuthAuthenticationProvider, -) - -log = logging.getLogger(name="Circulation manager authentication configuration import") - -def log_import(service_name): - log.info("Importing configuration for %s" % service_name) - -def make_patron_auth_integration(_db, provider): - integration, ignore = get_one_or_create( - _db, ExternalIntegration, protocol=provider.get('module'), - goal=ExternalIntegration.PATRON_AUTH_GOAL - ) - - # If any of the common Basic Auth-type settings were provided, set them - # as ConfigurationSettings on the ExternalIntegration. - test_identifier = provider.get('test_username') - test_password = provider.get('test_password') - if test_identifier: - integration.setting(BasicAuthenticationProvider.TEST_IDENTIFIER).value = test_identifier - if test_password: - integration.setting(BasicAuthenticationProvider.TEST_PASSWORD).value = test_password - identifier_re = provider.get('identifier_regular_expression') - password_re = provider.get('password_regular_expression') - if identifier_re: - integration.setting(BasicAuthenticationProvider.IDENTIFIER_REGULAR_EXPRESSION).value = identifier_re - if password_re: - integration.setting(BasicAuthenticationProvider.PASSWORD_REGULAR_EXPRESSION).value = password_re - - return integration - -def convert_millenium(_db, integration, provider): - - # Cross-check MilleniumPatronAPI.__init__ to see how these values - # are pulled from the ExternalIntegration. - integration.url = provider.get('url') - auth_mode = provider.get('auth_mode') - blacklist = provider.get('authorization_identifier_blacklist') - if blacklist: - integration.setting(MilleniumPatronAPI.IDENTIFIER_BLACKLIST - ).value = json.dumps(blacklist) - if auth_mode: - integration.setting(MilleniumPatronAPI.AUTHENTICATION_MODE - ).value = auth_mode - -def convert_sip(_db, integration, provider): - # Cross-check SIP2AuthenticationProvider.__init__ to see how these values - # are pulled from the ExternalIntegration. - integration.url = provider.get('server') - integration.username = provider.get('login_user_id') - integration.password = provider.get('login_password') - SAP = SIP2AuthenticationProvider - port = provider.get('port') - if port: - integration.setting(SAP.PORT).value = port - location_code = provider.get('location_code') - if location_code: - integration.setting(SAP.LOCATION_CODE).value = location_code - field_separator = provider.get('field_separator') - if field_separator: - integration.setting(SAP.FIELD_SEPARATOR).value = field_separator - -def convert_firstbook(_db, integration, provider): - # Cross-check FirstBookAuthenticationAPI.__init__ to see how these values - # are pulled from the ExternalIntegration. - integration.url = provider.get('url') - integration.password = provider.get('key') - -def convert_clever(_db, integration, provider): - # Cross-check OAuthAuthenticationProvider.from_config to see how - # these values are pulled from the ExternalIntegration. - integration.username = provider.get('client_id') - integration.password = provider.get('client_secret') - expiration_days = provider.get('token_expiration_days') - if expiration_days: - integration.setting(OAuthAuthenticationProvider.TOKEN_EXPIRATION_DAYS - ).value = expiration_days - -Configuration.load() -if not Configuration.instance: - # No need to import configuration if there isn't any. - sys.exit() - -_db = production_session() -try: - integrations = [] - auth_conf = Configuration.policy('authentication') - if not auth_conf: - sys.exit() - - bearer_token_signing_secret = auth_conf.get('bearer_token_signing_secret') - secret_setting = ConfigurationSetting.sitewide( - _db, OAuthAuthenticationProvider.BEARER_TOKEN_SIGNING_SECRET - ) - if bearer_token_signing_secret: - secret_setting.value = bearer_token_signing_secret - - for provider in auth_conf.get('providers'): - integration = make_patron_auth_integration(_db, provider) - module = provider.get('module') - if module == 'api.millenium_patron': - convert_millenium(_db, integration, provider) - elif module == 'api.firstbook': - convert_firstbook(_db, integration, provider) - elif module == 'api.clever': - convert_clever(_db, integration, provider) - elif module == 'api.sip': - convert_sip(_db, integration, provider) - else: - log.warn("I don't know how to convert a provider of type %s. Conversion is probably incomplete." % module) - integrations.append(integration) - - # Add each integration to each library. - library = Library.default(_db) - for library in _db.query(Library): - for integration in integrations: - if integration not in library.integrations: - library.integrations.append(integration) - - print "Sitewide bearer token signing secret: %s" % secret_setting.value - for library in _db.query(Library): - print "\n".join(library.explain(include_secrets=True)) -finally: - _db.commit() - _db.close() diff --git a/migration/20170713-15-move-library-configuration-to-configurationsettings.py b/migration/20170713-15-move-library-configuration-to-configurationsettings.py deleted file mode 100755 index e7e014980f..0000000000 --- a/migration/20170713-15-move-library-configuration-to-configurationsettings.py +++ /dev/null @@ -1,98 +0,0 @@ -#!/usr/bin/env python -"""Move per-library settings from the Configuration file -into the database as ConfigurationSettings. -""" -import os -import sys -import json -import logging - -bin_dir = os.path.split(__file__)[0] -package_dir = os.path.join(bin_dir, "..") -sys.path.append(os.path.abspath(package_dir)) - -from core.model import ( - ConfigurationSetting, - ExternalIntegration, - get_one_or_create, - production_session, - Library, - create, -) - -from api.config import Configuration - -_db = production_session() -try: - Configuration.load() - integrations = [] - - # Get or create a secret key and make it a sitewide setting. - secret_key = Configuration.get('secret_key') - if not secret_key: - secret_key = os.urandom(24).encode('hex') - - secret_setting = ConfigurationSetting.sitewide( - _db, Configuration.SECRET_KEY - ) - secret_setting.value = secret_key - - libraries = _db.query(Library).all() - - # Copy default email address into each library. - key = 'default_notification_email_address' - value = Configuration.get(key) - if value: - for library in libraries: - ConfigurationSetting.for_library(key, library).value = value - - # Copy maximum fines into each library. - for key in ['max_outstanding_fines', 'minimum_featured_quality', - 'featured_lane_size']: - value = Configuration.policy(key) - if value: - for library in libraries: - ConfigurationSetting.for_library(key, library).value = value - - # Convert the string hold_policy into the boolean allow_holds. - hold_policy = Configuration.policy('holds') - if hold_policy == 'hide': - for library in libraries: - library.setting("allow_holds").value = "False" - - # Install the language policies used to configure the lanes. - language_policy = Configuration.policy('languages') - if language_policy: - for variable in [Configuration.LARGE_COLLECTION_LANGUAGES, - Configuration.SMALL_COLLECTION_LANGUAGES, - Configuration.TINY_COLLECTION_LANGUAGES]: - value = language_policy.get(variable) - if value: - for library in libraries: - library.setting(variable).value = json.dumps(value) - - # Copy facet configuration - facet_policy = Configuration.policy("facets", default={}) - enabled = facet_policy.get("enabled", {}) - default = facet_policy.get("default", {}) - for library in libraries: - for k, v in enabled.items(): - library.enabled_facets_setting(k).value = json.dumps(v) - for k, v in default.items(): - library.default_facet_setting(k).value = v - - # Copy external type regular expression into each authentication - # mechanism for each library. - key = 'external_type_regular_expression' - value = Configuration.policy(key) - if value: - for library in libraries: - for integration in library.integrations: - if integration.goal != ExternalIntegration.PATRON_AUTH_GOAL: - continue - ConfigurationSetting.for_library_and_externalintegration( - _db, key, library, integration).value = value - -finally: - _db.commit() - _db.close() diff --git a/migration/20170713-19-move-third-party-config-to-external-integrations.py b/migration/20170713-19-move-third-party-config-to-external-integrations.py deleted file mode 100755 index 96e64e752c..0000000000 --- a/migration/20170713-19-move-third-party-config-to-external-integrations.py +++ /dev/null @@ -1,175 +0,0 @@ -#!/usr/bin/env python -"""Move integration details from the Configuration file into the -database as ExternalIntegrations -""" -import os -import sys -import json -import logging - -bin_dir = os.path.split(__file__)[0] -package_dir = os.path.join(bin_dir, "..") -sys.path.append(os.path.abspath(package_dir)) - -from core.model import ( - ConfigurationSetting, - ExternalIntegration as EI, - Library, - get_one_or_create, - production_session, -) - -from api.adobe_vendor_id import AuthdataUtility -from api.config import Configuration - -log = logging.getLogger(name="Circulation manager configuration import") - -def log_import(integration_or_setting): - log.info("CREATED: %r" % integration_or_setting) - -try: - Configuration.load() - _db = production_session() - LIBRARIES = _db.query(Library).all() - - # Import Circulation Manager base url. - circ_manager_conf = Configuration.integration('Circulation Manager') - if circ_manager_conf: - url = circ_manager_conf.get('url') - if url: - setting = ConfigurationSetting.sitewide(_db, Configuration.BASE_URL_KEY) - setting.value = unicode(url) - log_import(setting) - - # Import Metadata Wrangler configuration. - metadata_wrangler_conf = Configuration.integration('Metadata Wrangler') - - if metadata_wrangler_conf: - integration = EI(protocol=EI.METADATA_WRANGLER, goal=EI.METADATA_GOAL) - _db.add(integration) - - integration.url = metadata_wrangler_conf.get('url') - integration.username = metadata_wrangler_conf.get('client_id') - integration.password = metadata_wrangler_conf.get('client_secret') - - log_import(integration) - - # Import NoveList Select configuration. - novelist = Configuration.integration('NoveList Select') - if novelist: - integration = EI(protocol=EI.NOVELIST, goal=EI.METADATA_GOAL) - _db.add(integration) - - integration.username = novelist.get('profile') - integration.password = novelist.get('password') - - integration.libraries.extend(LIBRARIES) - log_import(integration) - - # Import NYT configuration. - nyt_conf = Configuration.integration('New York Times') - if nyt_conf: - integration = EI(protocol=EI.NYT, goal=EI.METADATA_GOAL) - _db.add(integration) - - integration.password = nyt_conf.get('best_sellers_api_key') - - log_import(integration) - - # Import Adobe Vendor ID configuration. - adobe_conf = Configuration.integration('Adobe Vendor ID') - if adobe_conf: - vendor_id = adobe_conf.get('vendor_id') - node_value = adobe_conf.get('node_value') - other_libraries = adobe_conf.get('other_libraries') - - if node_value: - node_library = Library.default(_db) - integration = EI(protocol=EI.ADOBE_VENDOR_ID, goal=EI.DRM_GOAL) - _db.add(integration) - - integration.username = vendor_id - integration.password = node_value - - if other_libraries: - other_libraries = unicode(json.dumps(other_libraries)) - integration.set_setting('other_libraries', other_libraries) - integration.libraries.append(node_library) - log_import(integration) - - # Import short client token configuration. - integration = EI(protocol='Short Client Token', goal=EI.DRM_GOAL) - _db.add(integration) - integration.set_setting( - AuthdataUtility.VENDOR_ID_KEY, vendor_id - ) - - for library in LIBRARIES: - short_name = library.library_registry_short_name - short_name = short_name or adobe_conf.get('library_short_name') - if short_name: - ConfigurationSetting.for_library_and_externalintegration( - _db, EI.USERNAME, library, integration - ).value = short_name - - shared_secret = library.library_registry_shared_secret - shared_secret = shared_secret or adobe_conf.get('authdata_secret') - ConfigurationSetting.for_library_and_externalintegration( - _db, EI.PASSWORD, library, integration - ).value = shared_secret - - library_url = adobe_conf.get('library_uri') - ConfigurationSetting.for_library( - Configuration.WEBSITE_URL, library).value = library_url - - integration.libraries.append(library) - - # Import Google OAuth configuration. - google_oauth_conf = Configuration.integration('Google OAuth') - if google_oauth_conf: - integration = EI(protocol=EI.GOOGLE_OAUTH, goal=EI.ADMIN_AUTH_GOAL) - _db.add(integration) - - integration.url = google_oauth_conf.get("web", {}).get("auth_uri") - integration.username = google_oauth_conf.get("web", {}).get("client_id") - integration.password = google_oauth_conf.get("web", {}).get("client_secret") - - auth_domain = Configuration.policy('admin_authentication_domain') - if auth_domain: - integration.set_setting('domains', json.dumps([auth_domain])) - - log_import(integration) - - # Import Patron Web Client configuration. - patron_web_client_conf = Configuration.integration('Patron Web Client', {}) - patron_web_client_url = patron_web_client_conf.get('url') - if patron_web_client_url: - setting = ConfigurationSetting.sitewide( - _db, Configuration.PATRON_WEB_CLIENT_URL) - setting.value = patron_web_client_url - log_import(setting) - - # Import analytics configuration. - policies = Configuration.get("policies", {}) - analytics_modules = policies.get("analytics", ["core.local_analytics_provider"]) - - if "api.google_analytics_provider" in analytics_modules: - google_analytics_conf = Configuration.integration("Google Analytics Provider", {}) - tracking_id = google_analytics_conf.get("tracking_id") - - integration = EI(protocol="api.google_analytics_provider", goal=EI.ANALYTICS_GOAL) - _db.add(integration) - integration.url = "http://www.google-analytics.com/collect" - - for library in LIBRARIES: - ConfigurationSetting.for_library_and_externalintegration( - _db, "tracking_id", library, integration).value = tracking_id - library.integrations += [integration] - - if "core.local_analytics_provider" in analytics_modules: - integration = EI(protocol="core.local_analytics_provider", goal=EI.ANALYTICS_GOAL) - _db.add(integration) - -finally: - _db.commit() - _db.close() diff --git a/migration/20170713-9-change-coverage-sync-to-import.sql b/migration/20170713-9-change-coverage-sync-to-import.sql deleted file mode 100644 index 0671e8e206..0000000000 --- a/migration/20170713-9-change-coverage-sync-to-import.sql +++ /dev/null @@ -1,6 +0,0 @@ -UPDATE coveragerecords -SET operation = 'import' -WHERE - data_source_id in ( - select id from datasources where name = 'Library Simplified metadata wrangler' - ) and operation = 'sync'; diff --git a/migration/20170714-add-collection-id-to-licensepools.py b/migration/20170714-add-collection-id-to-licensepools.py deleted file mode 100755 index 4ba04e7279..0000000000 --- a/migration/20170714-add-collection-id-to-licensepools.py +++ /dev/null @@ -1,73 +0,0 @@ -#!/usr/bin/env python -"""Associate LicensePools to their Collections""" -import os -import sys -import logging - -bin_dir = os.path.split(__file__)[0] -package_dir = os.path.join(bin_dir, "..") -sys.path.append(os.path.abspath(package_dir)) - -from core.model import ( - Collection, - DataSource, - LicensePool, - production_session, -) -from core.util import fast_query_count - -_db = production_session() -log = logging.getLogger('Migration [20170714]') - -def log_change(count, collection): - log.info('UPDATED: %d LicensePools given Collection %r' % ( - int(count), collection)) - -try: - collections = _db.query(Collection).all() - collections_by_data_source = dict([(collection.data_source, collection) for collection in collections]) - - base_query = _db.query(LicensePool).filter(LicensePool.collection_id==None) - for data_source, collection in collections_by_data_source.items(): - # Find LicensePools with the matching DataSource. - qu = base_query.filter(LicensePool.data_source==data_source) - qu.update({LicensePool.collection_id : collection.id}) - log_change(fast_query_count(qu), collection) - _db.commit() - - # Some LicensePools may be associated with the a duplicate or - # outdated Bibliotheca DataSource. Find them. - bibliotheca = DataSource.lookup(_db, DataSource.BIBLIOTHECA) - old_sources = _db.query(DataSource.id).filter( - DataSource.name.in_(['3M', 'Bibliotecha'])).subquery() - threem_qu = base_query.filter(LicensePool.data_source_id.in_(old_sources)) - - # Associate these LicensePools with the Bibliotheca Collection. - bibliotheca_collection = collections_by_data_source.get(bibliotheca) - if bibliotheca_collection: - result = threem_qu.update( - {LicensePool.collection_id : bibliotheca_collection.id}, - synchronize_session='fetch' - ) - - # If something changed, log it. - threem_count = fast_query_count(threem_qu) - if threem_count: - log_change(threem_count, bibliotheca_collection) - - remaining = fast_query_count(base_query) - if remaining > 0: - log.warning('No Collection found for %d LicensePools', remaining) - - source_ids = _db.query(LicensePool.data_source_id)\ - .filter(LicensePool.collection_id==None).subquery() - sources = _db.query(DataSource).filter(DataSource.id.in_(source_ids)) - names = ', '.join(["%s" % source.name for source in sources]) - - log.warning('Remaining LicensePools have DataSources: %s', names) -except Exception as e: - _db.close() - raise e -finally: - _db.commit() - _db.close() diff --git a/migration/20170728-move-adobe-vendor-id-and-sct.sql b/migration/20170728-move-adobe-vendor-id-and-sct.sql deleted file mode 100644 index 275a95dced..0000000000 --- a/migration/20170728-move-adobe-vendor-id-and-sct.sql +++ /dev/null @@ -1,38 +0,0 @@ -DO $$ -DECLARE registry_id int; -DECLARE sct_id int; - -BEGIN - -- Create a new registry integration, and store its id. - INSERT INTO externalintegrations (goal, protocol, name) - VALUES ('discovery', 'OPDS Registration', 'Library Simplified Registry') - RETURNING id into registry_id; - - -- Create the registry's url setting. - INSERT INTO configurationsettings (key, value, external_integration_id) - VALUES ('url', 'https://registry.librarysimplified.org', registry_id); - - -- Find the short client token integration and store its id. - SELECT id INTO sct_id FROM externalintegrations - WHERE protocol='Short Client Token' AND GOAL='drm' LIMIT 1; - - -- Move the vendor id from the short client token to the registry. - UPDATE configurationsettings SET external_integration_id=registry_id - WHERE key='vendor_id' and external_integration_id=sct_id; - - -- Move usernames from the short client token to the registry. - UPDATE configurationsettings SET external_integration_id=registry_id - WHERE key='username' and external_integration_id=sct_id; - - -- Move passwords from the short client token to the registry. - UPDATE configurationsettings SET external_integration_id=registry_id - WHERE key='password' and external_integration_id=sct_id; - - -- Move libraries from the short client token to the registry. - UPDATE externalintegrations_libraries SET externalintegration_id=registry_id - WHERE externalintegration_id=sct_id; - - -- Drop the short client token integration. - DELETE FROM externalintegrations WHERE id=sct_id; - -END $$; \ No newline at end of file diff --git a/migration/20171026-set-rbdigital-identifier-for-patrons-with-active-loans.sql b/migration/20171026-set-rbdigital-identifier-for-patrons-with-active-loans.sql deleted file mode 100644 index 1301752a0e..0000000000 --- a/migration/20171026-set-rbdigital-identifier-for-patrons-with-active-loans.sql +++ /dev/null @@ -1,20 +0,0 @@ --- Set patron's authorization identifier as their RBdigital identifier -insert into credentials(data_source_id, patron_id, type, credential) select - datasources.id, - patrons.id, - 'Identifier Sent To Remote Service', - patrons.authorization_identifier -from patrons join datasources on datasources.name='RBdigital' - --- If they don't already have a credential -where patrons.id not in ( - select patron_id from credentials join datasources on credentials.data_source_id=datasources.id and datasources.name='RBdigital' where type='Identifier Sent To Remote Service' -) and ( - -- And they have an active RBdigital loan or hold. - patrons.id in ( - select patron_id from loans join licensepools on loans.license_pool_id=licensepools.id join datasources on licensepools.data_source_id=datasources.id and datasources.name='RBdigital' - ) or patrons.id in ( - select patron_id from holds join licensepools on holds.license_pool_id=licensepools.id join datasources on licensepools.data_source_id=datasources.id and datasources.name='RBdigital' - ) -) -; diff --git a/migration/20180122-make-lanes-if-none-currently.py b/migration/20180122-make-lanes-if-none-currently.py deleted file mode 100755 index bfd075811e..0000000000 --- a/migration/20180122-make-lanes-if-none-currently.py +++ /dev/null @@ -1,36 +0,0 @@ -#!/usr/bin/env python -"""Make sure every library has some lanes.""" -import os -import sys -import logging - -bin_dir = os.path.split(__file__)[0] -package_dir = os.path.join(bin_dir, "..") -sys.path.append(os.path.abspath(package_dir)) - -from core.model import ( - Library, - production_session, -) -from api.lanes import create_default_lanes - -_db = production_session() -for library in _db.query(Library): - num_lanes = len(library.lanes) - if num_lanes: - logging.info( - "%s has %d lanes, not doing anything.", - library.name, num_lanes - ) - else: - logging.warn( - "%s has no lanes, creating some.", - library.name - ) - try: - create_default_lanes(_db, library) - except Exception, e: - logging.error( - "Could not create default lanes; suggest you try resetting them manually.", - exc_info=e - ) diff --git a/migration/20180201-set-library-identifier-from-patron-restriction.py b/migration/20180201-set-library-identifier-from-patron-restriction.py deleted file mode 100755 index 701ab09dd9..0000000000 --- a/migration/20180201-set-library-identifier-from-patron-restriction.py +++ /dev/null @@ -1,51 +0,0 @@ -#!/usr/bin/env python -"""Migrate patron restriction to library identifier.""" -import os -import sys - -bin_dir = os.path.split(__file__)[0] -package_dir = os.path.join(bin_dir, "..") -sys.path.append(os.path.abspath(package_dir)) - -from core.model import ( - Library, - ExternalIntegration, - ConfigurationSetting, - production_session, -) -from api.authenticator import AuthenticationProvider - -try: - _db = production_session() - for library in _db.query(Library): - for integration in library.integrations: - if integration.goal == ExternalIntegration.PATRON_AUTH_GOAL: - # Get old patron restriction. - patron_restriction = ConfigurationSetting.for_library_and_externalintegration( - _db, 'patron_identifier_restriction', library, integration) - - # Get new settings. - library_identifier_field = ConfigurationSetting.for_library_and_externalintegration( - _db, AuthenticationProvider.LIBRARY_IDENTIFIER_FIELD, library, integration) - library_identifier_restriction_type = ConfigurationSetting.for_library_and_externalintegration( - _db, AuthenticationProvider.LIBRARY_IDENTIFIER_RESTRICTION_TYPE, library, integration) - library_identifier_restriction = ConfigurationSetting.for_library_and_externalintegration( - _db, AuthenticationProvider.LIBRARY_IDENTIFIER_RESTRICTION, library, integration) - - # Set new settings. - if not patron_restriction.value: - library_identifier_restriction_type.value = AuthenticationProvider.LIBRARY_IDENTIFIER_RESTRICTION_TYPE_NONE - elif patron_restriction.value.startswith("^"): - library_identifier_restriction_type.value = AuthenticationProvider.LIBRARY_IDENTIFIER_RESTRICTION_TYPE_REGEX - library_identifier_field.value = 'barcode' - library_identifier_restriction.value = patron_restriction.value - else: - library_identifier_restriction_type.value = AuthenticationProvider.LIBRARY_IDENTIFIER_RESTRICTION_TYPE_PREFIX - library_identifier_field.value = 'barcode' - library_identifier_restriction.value = patron_restriction.value - - # Old patron restriction no longer needed. - _db.delete(patron_restriction) -finally: - _db.commit() - _db.close() \ No newline at end of file diff --git a/migration/20180424-2-make-google-oauth-domains-a-per-library-setting.sql b/migration/20180424-2-make-google-oauth-domains-a-per-library-setting.sql deleted file mode 100644 index 9bb8f872ca..0000000000 --- a/migration/20180424-2-make-google-oauth-domains-a-per-library-setting.sql +++ /dev/null @@ -1,9 +0,0 @@ -INSERT INTO externalintegrations_libraries (externalintegration_id, library_id) -SELECT e.id, l.id -FROM libraries as l join externalintegrations as e on e.protocol = 'Google OAuth'; - -INSERT INTO configurationsettings (external_integration_id, library_id, key, value) -SELECT cs.external_integration_id, l.id, cs.key, cs.value -FROM libraries as l join configurationsettings as cs on cs.key = 'domains'; - -DELETE FROM configurationsettings WHERE key = 'domains' and library_id is null; \ No newline at end of file diff --git a/migration/20180427-delete-borrow-links-for-open-access-books-from-shared-odl-collection.py b/migration/20180427-delete-borrow-links-for-open-access-books-from-shared-odl-collection.py deleted file mode 100755 index 5d98b2195d..0000000000 --- a/migration/20180427-delete-borrow-links-for-open-access-books-from-shared-odl-collection.py +++ /dev/null @@ -1,67 +0,0 @@ -#!/usr/bin/env python -"""Due to a bug in version 2.2.0, borrow links for open-access books in a -shared ODL collection were imported. This migration delete the links and -their associated resources and representations.""" -import os -import sys - -bin_dir = os.path.split(__file__)[0] -package_dir = os.path.join(bin_dir, "..") -sys.path.append(os.path.abspath(package_dir)) - -from sqlalchemy.orm import aliased -from sqlalchemy import and_ - -from core.model import ( - Collection, - Hyperlink, - LicensePool, - Representation, - Resource, - production_session, -) -from api.odl import SharedODLAPI - -try: - _db = production_session() - for collection in Collection.by_protocol(_db, SharedODLAPI.NAME): - borrow_link = aliased(Hyperlink) - open_link = aliased(Hyperlink) - - pools = _db.query( - LicensePool - ).join( - borrow_link, - LicensePool.identifier_id==borrow_link.identifier_id, - ).join( - open_link, - LicensePool.identifier_id==open_link.identifier_id, - ).join( - Resource, - borrow_link.resource_id==Resource.id, - ).join( - Representation, - Resource.representation_id==Representation.id, - ).filter( - and_( - LicensePool.collection_id==collection.id, - borrow_link.rel==Hyperlink.BORROW, - open_link.rel==Hyperlink.OPEN_ACCESS_DOWNLOAD, - Representation.media_type=='application/atom+xml;type=entry;profile=opds-catalog', - ) - ) - - print "Deleting hyperlinks for %i license pools" % pools.count() - for pool in pools: - for link in pool.identifier.links: - if link.rel == Hyperlink.BORROW: - resource = link.resource - representation = resource.representation - - _db.delete(representation) - _db.delete(link) - _db.delete(resource) - -finally: - _db.commit() - _db.close() diff --git a/migration/20180604-feedbooks-language-is-external-account-id.sql b/migration/20180604-feedbooks-language-is-external-account-id.sql deleted file mode 100644 index c893daf6a7..0000000000 --- a/migration/20180604-feedbooks-language-is-external-account-id.sql +++ /dev/null @@ -1,23 +0,0 @@ --- For each FeedBooks language, copy the value of the 'language' setting --- to the collection's external_account_id. -update collections set external_account_id='en' where id in ( - select c.id from collections c join externalintegrations e on c.external_integration_id=e.id join configurationsettings cs on cs.external_integration_id=e.id where e.protocol='FeedBooks' and cs.key='language' and cs.value='en' -); -update collections set external_account_id='fr' where id in ( - select c.id from collections c join externalintegrations e on c.external_integration_id=e.id join configurationsettings cs on cs.external_integration_id=e.id where e.protocol='FeedBooks' and cs.key='language' and cs.value='fr' -); -update collections set external_account_id='de' where id in ( - select c.id from collections c join externalintegrations e on c.external_integration_id=e.id join configurationsettings cs on cs.external_integration_id=e.id where e.protocol='FeedBooks' and cs.key='language' and cs.value='de' -); -update collections set external_account_id='it' where id in ( - select c.id from collections c join externalintegrations e on c.external_integration_id=e.id join configurationsettings cs on cs.external_integration_id=e.id where e.protocol='FeedBooks' and cs.key='language' and cs.value='it' -); -update collections set external_account_id='es' where id in ( - select c.id from collections c join externalintegrations e on c.external_integration_id=e.id join configurationsettings cs on cs.external_integration_id=e.id where e.protocol='FeedBooks' and cs.key='language' and cs.value='es' -); - --- Delete all FeedBooks language settings. -delete from configurationsettings where id in ( - select cs.id from configurationsettings cs join externalintegrations e on cs.external_integration_id=e.id where cs.key='language' and e.protocol='FeedBooks' -); - diff --git a/migration/20180619-lanes-have-no-media-type.sql b/migration/20180619-lanes-have-no-media-type.sql deleted file mode 100644 index 2a4be24ae4..0000000000 --- a/migration/20180619-lanes-have-no-media-type.sql +++ /dev/null @@ -1,7 +0,0 @@ --- The lanes.media column is no longer set by default -- the same lanes are --- present for both ebooks and audiobooks, and an EntryPoint is used to --- filter them. --- --- We're not removing lanes.media altogether because it might be useful --- as a way of dividing up sublanes for other entry points. -update lanes set media=null; diff --git a/migration/20180820-oneclick-to-rbdigital.sql b/migration/20180820-oneclick-to-rbdigital.sql deleted file mode 100644 index 33e8af4d27..0000000000 --- a/migration/20180820-oneclick-to-rbdigital.sql +++ /dev/null @@ -1 +0,0 @@ -update timestamps set service='RBDigital CirculationMonitor' where service='OneClick CirculationMonitor'; diff --git a/migration/20180921-recalculate-work.py b/migration/20180921-recalculate-work.py deleted file mode 100755 index 2f49d39aa8..0000000000 --- a/migration/20180921-recalculate-work.py +++ /dev/null @@ -1,38 +0,0 @@ -#!/usr/bin/env python -import os -import sys -import logging -from pdb import set_trace -bin_dir = os.path.split(__file__)[0] -package_dir = os.path.join(bin_dir, "..") -sys.path.append(os.path.abspath(package_dir)) - -from core.model import ( - Edition, - production_session, - LicensePool, - Work, -) - -_db = production_session() - -works = _db.query(Work).filter(Work.fiction == None).order_by(Work.id) -logging.info("Processing %d works with no fiction status.", works.count()) -a = 0 -for work in works: - logging.info("Processing %s", work.title) - work.set_presentation_ready_based_on_content() - if not a % 10: - _db.commit() -_db.commit() - - -license_pools = _db.query(LicensePool).join(LicensePool.presentation_edition).filter(LicensePool.work == None).filter(Edition.title != None).filter(Edition.author == "[Unknown]").order_by(LicensePool.id) -logging.info("Processing %d license pools with no work and no known author.", license_pools.count()) -for license_pool in license_pools: - logging.info("Processing %s", license_pool.presentation_edition.title) - try: - license_pool.calculate_work() - except Exception, e: - logging.error("That didn't work.", exc_info=e) - _db.commit() diff --git a/migration/20181016-recalculate-presentation-for-audiobooks-believed-to-be-books.py b/migration/20181016-recalculate-presentation-for-audiobooks-believed-to-be-books.py deleted file mode 100755 index c69ef5533f..0000000000 --- a/migration/20181016-recalculate-presentation-for-audiobooks-believed-to-be-books.py +++ /dev/null @@ -1,51 +0,0 @@ -#!/usr/bin/env python -import os -import sys -from sqlalchemy.sql import select -from sqlalchemy.sql.expression import ( - join, - and_, -) -bin_dir = os.path.split(__file__)[0] -package_dir = os.path.join(bin_dir, "..") -sys.path.append(os.path.abspath(package_dir)) -from core.model import ( - dump_query, - production_session, - LicensePool, - DataSource, - Edition, - PresentationCalculationPolicy, -) - -# Find all books where the edition associated with the LicensePool has a -# different medium from the presentation edition. -_db = production_session() - -# Find all the LicensePools that aren't books. -subq = select([LicensePool.id]).select_from( - join(LicensePool, Edition, - and_(LicensePool.data_source_id==Edition.data_source_id, - LicensePool.identifier_id==Edition.primary_identifier_id) - ) -).where(Edition.medium != Edition.BOOK_MEDIUM) - -# Of those LicensePools, find every LicensePool whose presentation -# edition says it _is_ a book. -qu = _db.query(LicensePool).join( - Edition, LicensePool.presentation_edition_id==Edition.id -).filter(LicensePool.id.in_(subq)).filter(Edition.medium == Edition.BOOK_MEDIUM) - -print "Recalculating presentation edition for %d LicensePools." % qu.count() - -for lp in qu: - # Recalculate that LicensePool's presentation edition, and then its - # work presentation. - lp.set_presentation_edition() - policy = PresentationCalculationPolicy( - regenerate_opds_entries=True, update_search_index=True - ) - work, is_new = lp.calculate_work() - work.calculate_presentation(policy) - print "New medium: %s" % lp.presentation_edition.medium - _db.commit() diff --git a/migration/20181113-remove-mp3-findaway-delivery-mechanism.sql b/migration/20181113-remove-mp3-findaway-delivery-mechanism.sql deleted file mode 100644 index 178d1a02ba..0000000000 --- a/migration/20181113-remove-mp3-findaway-delivery-mechanism.sql +++ /dev/null @@ -1,12 +0,0 @@ --- If an audio/mpeg+findaway delivery mechanism was created again after the old --- one was removed, change over all rows in licensepooldeliveries that use it. -update licensepooldeliveries set delivery_mechanism_id=( - select id from deliverymechanisms where content_type is null and drm_scheme='application/vnd.librarysimplified.findaway.license+json' -) where delivery_mechanism_id = ( - select id from deliverymechanisms where content_type='audio/mpeg' and drm_scheme='application/vnd.librarysimplified.findaway.license+json' -); - --- Then delete it. -delete from deliverymechanisms where content_type='audio/mpeg' and drm_scheme='application/vnd.librarysimplified.findaway.license+json'; - -update deliverymechanisms set default_client_can_fulfill='t' where content_type is null and drm_scheme='application/vnd.librarysimplified.findaway.license+json'; diff --git a/migration/20190215-update-enki-collection-integration.sql b/migration/20190215-update-enki-collection-integration.sql deleted file mode 100644 index 14dbbb5a20..0000000000 --- a/migration/20190215-update-enki-collection-integration.sql +++ /dev/null @@ -1,13 +0,0 @@ --- Moving the "Library ID" Enki integration setting to be associated with each library --- and not just the collection. -insert into configurationsettings (external_integration_id, library_id, key, value) -select externalintegration_id, library_id, 'enki_library_id', external_account_id -from collections join externalintegrations_libraries as el -on collections.external_integration_id=el.externalintegration_id -join externalintegrations as e on e.id=el.externalintegration_id where e.protocol='Enki'; - - --- Remove external_account_id values for all Enki collections -update collections -set external_account_id=null -where external_integration_id in (select id from externalintegrations where protocol='Enki'); diff --git a/migration/20190220-fix-same-book-in-two-collections-from-same-source.sql b/migration/20190220-fix-same-book-in-two-collections-from-same-source.sql deleted file mode 100644 index 25a09b6982..0000000000 --- a/migration/20190220-fix-same-book-in-two-collections-from-same-source.sql +++ /dev/null @@ -1,3 +0,0 @@ -update licensepools as lp set presentation_edition_id = lp2.presentation_edition_id from licensepools as lp2 where lp.identifier_id = lp2.identifier_id and lp.id != lp2.id and lp.data_source_id = lp2.data_source_id and lp.presentation_edition_id is null and lp2.presentation_edition_id is not null; - -update licensepools as lp set work_id = lp2.work_id from licensepools as lp2 where lp.identifier_id = lp2.identifier_id and lp.id != lp2.id and lp.data_source_id = lp2.data_source_id and lp.work_id is null and lp2.work_id is not null; \ No newline at end of file diff --git a/migration/20190328-2-remove-odl-consolidated-licenses.sql b/migration/20190328-2-remove-odl-consolidated-licenses.sql deleted file mode 100644 index d68d4adbe4..0000000000 --- a/migration/20190328-2-remove-odl-consolidated-licenses.sql +++ /dev/null @@ -1,21 +0,0 @@ --- Temporarily wipe out all the licenses in ODL collections and delete their coverage records. Another migration will run a full import and add individual licenses. - -update licensepools set licenses_owned = 0 from collections join externalintegrations on collections.external_integration_id = externalintegrations.id where licensepools.collection_id = collections.id and externalintegrations.protocol = 'ODL with Consolidated Copies'; -update licensepools set licenses_available = 0 from collections join externalintegrations on collections.external_integration_id = externalintegrations.id where licensepools.collection_id = collections.id and externalintegrations.protocol = 'ODL with Consolidated Copies'; - -delete from coveragerecords where coveragerecords.identifier_id in (select identifiers.id from identifiers join licensepools on licensepools.identifier_id = identifiers.id join collections on licensepools.collection_id = collections.id join externalintegrations on collections.external_integration_id = externalintegrations.id where externalintegrations.protocol = 'ODL With Consolidated Copies') and coveragerecords.operation = 'import'; - -update externalintegrations set protocol = 'ODL' where protocol = 'ODL with Consolidated Copies'; - --- Update the URL for any collections using the DPLA Exchange. - -update collections set external_account_id = 'https://www.feedbooks.com/harvest/odl' where external_account_id = 'https://www.feedbooks.com/library/last_update.atom'; - --- Upload any loans with old URLs. - -update loans set external_identifier = replace(external_identifier, 'https://loan.feedbooks.net/loan/get/', 'https://license.feedbooks.net/loan/status/'); - --- Delete configuration settings that are no longer needed. - -delete from configurationsettings where key = 'consolidated_copies_url'; -delete from configurationsettings where key = 'consolidated_loan_url'; diff --git a/migration/20190328-3-odl-full-reimport.py b/migration/20190328-3-odl-full-reimport.py deleted file mode 100755 index 2945a52513..0000000000 --- a/migration/20190328-3-odl-full-reimport.py +++ /dev/null @@ -1,10 +0,0 @@ -#!/usr/bin/env python -"""Reimport ODL collections to get individual license data. -""" -import os -import sys -bin_dir = os.path.split(__file__)[0] -package_dir = os.path.join(bin_dir, "..") -sys.path.append(os.path.abspath(package_dir)) -from scripts import ODLImportScript -ODLImportScript().run() diff --git a/migration/20190515-reformat-existing-geographic-data.py b/migration/20190515-reformat-existing-geographic-data.py deleted file mode 100755 index d756d3a128..0000000000 --- a/migration/20190515-reformat-existing-geographic-data.py +++ /dev/null @@ -1,57 +0,0 @@ -#!/usr/bin/env python - -from sqlalchemy.sql import * -import json -import os -import sys -bin_dir = os.path.split(__file__)[0] -package_dir = os.path.join(bin_dir, "..") -sys.path.append(os.path.abspath(package_dir)) -from core.model import ( - Library, - production_session -) -from core.model.configuration import ConfigurationSetting -_db = production_session() -area_settings = _db.query(ConfigurationSetting).filter( - or_(ConfigurationSetting.key == "service_area", ConfigurationSetting.key == "focus_area")).filter( - ConfigurationSetting._value != None - ).filter( - ConfigurationSetting._value != "" - ).all() -def format(value): - result = [] - try: - value = json.loads(value) - if type(value) is list: - for x in value: - result += format(x) - elif type(value) is not dict: - result += json.loads(value) - except: - result.append(value) - return result -def fix(value): - result = format(value) - formatted_info = None - if result: - formatted_info = json.dumps({"US": result}) - return formatted_info -expect = json.dumps({"US": ["Waterford, CT"]}) -assert fix("Waterford, CT") == expect -assert fix(json.dumps("Waterford, CT")) == expect -assert fix(json.dumps(["Waterford, CT"])) == expect -# If the value is already in the correct format, fix() shouldn't return anything; -# there's no need to update the setting. -assert fix(expect) == None -for setting in area_settings: - library = _db.query(Library).filter(Library.id == setting.library_id).first() - formatted_info = fix(setting._value) - if formatted_info: - print "Changing %r to %s" % (setting._value, formatted_info) - ConfigurationSetting.for_library_and_externalintegration(_db, setting.key, library, None).value = formatted_info - else: - print "Leaving %s alone" % (setting._value) - -_db.commit() -_db.close() diff --git a/migration/20201112-add-incommon-saml-federation-metadata.py b/migration/20201112-add-incommon-saml-federation-metadata.py index 9b820dada4..b7064bcf84 100755 --- a/migration/20201112-add-incommon-saml-federation-metadata.py +++ b/migration/20201112-add-incommon-saml-federation-metadata.py @@ -10,21 +10,22 @@ package_dir = os.path.join(bin_dir, "..") sys.path.append(os.path.abspath(package_dir)) -from core.model import ( - production_session -) from api.saml.metadata.federations import incommon from api.saml.metadata.federations.model import SAMLFederation +from core.model import production_session with closing(production_session()) as db: - incommon_federation = db.query(SAMLFederation).filter( - SAMLFederation.type == incommon.FEDERATION_TYPE).one_or_none() + incommon_federation = ( + db.query(SAMLFederation) + .filter(SAMLFederation.type == incommon.FEDERATION_TYPE) + .one_or_none() + ) if not incommon_federation: incommon_federation = SAMLFederation( incommon.FEDERATION_TYPE, incommon.IDP_METADATA_SERVICE_URL, - incommon.CERTIFICATE + incommon.CERTIFICATE, ) db.add(incommon_federation) diff --git a/scripts.py b/scripts.py index 808bc82764..2308faa0e1 100644 --- a/scripts.py +++ b/scripts.py @@ -5,35 +5,21 @@ import os import sys import time +from datetime import datetime, timedelta +from enum import Enum from io import StringIO -from datetime import ( - datetime, - timedelta, -) -from enum import Enum -from sqlalchemy import ( - or_, -) +from sqlalchemy import or_ -from api.adobe_vendor_id import ( - AuthdataUtility, -) +from api.adobe_vendor_id import AuthdataUtility from api.authenticator import LibraryAuthenticator -from api.bibliotheca import ( - BibliothecaCirculationSweep -) -from api.config import ( - CannotLoadConfiguration, - Configuration, -) +from api.bibliotheca import BibliothecaCirculationSweep +from api.config import CannotLoadConfiguration, Configuration from api.controller import CirculationManager from api.lanes import create_default_lanes from api.local_analytics_exporter import LocalAnalyticsExporter from api.marc import LibraryAnnotator as MARCLibraryAnnotator -from api.novelist import ( - NoveListAPI -) +from api.novelist import NoveListAPI from api.nyt import NYTBestSellerAPI from api.odl import ( ODLImporter, @@ -47,26 +33,19 @@ OPDSForDistributorsImportMonitor, OPDSForDistributorsReaperMonitor, ) -from api.overdrive import ( - OverdriveAPI, -) +from api.overdrive import OverdriveAPI from core.entrypoint import EntryPoint from core.external_list import CustomListFromCSV from core.external_search import ExternalSearchIndex -from core.lane import Lane -from core.lane import ( - Pagination, - Facets, - FeaturedFacets, -) +from core.lane import Facets, FeaturedFacets, Lane, Pagination from core.marc import MARCExporter from core.metadata_layer import ( CirculationData, FormatData, - ReplacementPolicy, LinkData, + MARCExtractor, + ReplacementPolicy, ) -from core.metadata_layer import MARCExtractor from core.mirror import MirrorUploader from core.model import ( CachedMARCFile, @@ -78,8 +57,8 @@ DataSource, DeliveryMechanism, Edition, + EditionConstants, ExternalIntegration, - get_one, Hold, Hyperlink, Identifier, @@ -92,31 +71,25 @@ Subject, Timestamp, Work, - EditionConstants, + get_one, ) from core.model.configuration import ExternalIntegrationLink -from core.opds import ( - AcquisitionFeed, -) -from core.opds_import import ( - MetadataWranglerOPDSLookup, - OPDSImporter, -) -from core.scripts import OPDSImportScript, CollectionType +from core.opds import AcquisitionFeed +from core.opds_import import MetadataWranglerOPDSLookup, OPDSImporter from core.scripts import ( - Script as CoreScript, + CollectionType, DatabaseMigrationInitializationScript, IdentifierInputScript, LaneSweeperScript, LibraryInputScript, + OPDSImportScript, PatronInputScript, - TimestampScript, ) +from core.scripts import Script as CoreScript +from core.scripts import TimestampScript from core.util import LanguageCodes -from core.util.opds_writer import ( - OPDSFeed, -) from core.util.datetime_helpers import utc_now +from core.util.opds_writer import OPDSFeed class Script(CoreScript): @@ -124,14 +97,15 @@ def load_config(self): if not Configuration.instance: Configuration.load(self._db) + class CreateWorksForIdentifiersScript(Script): """Do the bare minimum to associate each Identifier with an Edition with title and author, so that we can calculate a permanent work ID. """ - to_check = [Identifier.OVERDRIVE_ID, Identifier.THREEM_ID, - Identifier.GUTENBERG_ID] + + to_check = [Identifier.OVERDRIVE_ID, Identifier.THREEM_ID, Identifier.GUTENBERG_ID] BATCH_SIZE = 100 name = "Create works for identifiers" @@ -153,18 +127,25 @@ def run(self): Edition.title == None, Edition.sort_author == None, ) - edition_missing_title_or_author = self._db.query(Identifier).join( - Identifier.primarily_identifies).filter( - either_title_or_author_missing) + edition_missing_title_or_author = ( + self._db.query(Identifier) + .join(Identifier.primarily_identifies) + .filter(either_title_or_author_missing) + ) - no_edition = self._db.query(Identifier).filter( - Identifier.primarily_identifies==None).filter( - Identifier.type.in_(self.to_check)) + no_edition = ( + self._db.query(Identifier) + .filter(Identifier.primarily_identifies == None) + .filter(Identifier.type.in_(self.to_check)) + ) for q, descr in ( - (edition_missing_title_or_author, - "identifiers whose edition is missing title or author"), - (no_edition, "identifiers with no edition")): + ( + edition_missing_title_or_author, + "identifiers whose edition is missing title or author", + ), + (no_edition, "identifiers with no edition"), + ): batch = [] self.log.debug("Trying to fix %d %s", q.count(), descr) for i in q: @@ -179,18 +160,20 @@ def process_batch(self, batch): if response.status_code != 200: raise Exception(response.text) - content_type = response.headers['content-type'] + content_type = response.headers["content-type"] if content_type != OPDSFeed.ACQUISITION_FEED_TYPE: raise Exception("Wrong media type: %s" % content_type) importer = OPDSImporter( - self._db, response.text, - overwrite_rels=[Hyperlink.DESCRIPTION, Hyperlink.IMAGE]) + self._db, + response.text, + overwrite_rels=[Hyperlink.DESCRIPTION, Hyperlink.IMAGE], + ) imported, messages_by_id = importer.import_from_feed() - self.log.info("%d successes, %d failures.", - len(imported), len(messages_by_id)) + self.log.info("%d successes, %d failures.", len(imported), len(messages_by_id)) self._db.commit() + class MetadataCalculationScript(Script): """Force calculate_presentation() to be called on some set of Editions. @@ -221,8 +204,12 @@ def run(self): def checkpoint(): self._db.commit() - self.log.info("%d successes, %d failures, %d new works.", - success, failure, also_created_work) + self.log.info( + "%d successes, %d failures, %d new works.", + success, + failure, + also_created_work, + ) i = 0 for edition in q: @@ -230,7 +217,8 @@ def checkpoint(): if edition.sort_author: success += 1 work, is_new = edition.license_pool.calculate_work( - search_index_client=search_index_client) + search_index_client=search_index_client + ) if work: work.calculate_presentation() if is_new: @@ -242,6 +230,7 @@ def checkpoint(): checkpoint() checkpoint() + class FillInAuthorScript(MetadataCalculationScript): """Fill in Edition.sort_author for Editions that have a list of Contributors, but no .sort_author. @@ -253,9 +242,13 @@ class FillInAuthorScript(MetadataCalculationScript): name = "Fill in missing authors" def q(self): - return self._db.query(Edition).join( - Edition.contributions).join(Contribution.contributor).filter( - Edition.sort_author==None) + return ( + self._db.query(Edition) + .join(Edition.contributions) + .join(Contribution.contributor) + .filter(Edition.sort_author == None) + ) + class UpdateStaffPicksScript(Script): @@ -264,17 +257,16 @@ class UpdateStaffPicksScript(Script): def run(self): inp = self.open() tag_fields = { - 'tags': Subject.NYPL_APPEAL, + "tags": Subject.NYPL_APPEAL, } integ = Configuration.integration(Configuration.STAFF_PICKS_INTEGRATION) fields = integ.get(Configuration.LIST_FIELDS, {}) importer = CustomListFromCSV( - DataSource.LIBRARY_STAFF, CustomList.STAFF_PICKS_NAME, - **fields + DataSource.LIBRARY_STAFF, CustomList.STAFF_PICKS_NAME, **fields ) - reader = csv.DictReader(inp, dialect='excel-tab') + reader = csv.DictReader(inp, dialect="excel-tab") importer.to_customlist(self._db, reader) self._db.commit() @@ -282,21 +274,21 @@ def open(self): if len(sys.argv) > 1: return open(sys.argv[1]) - url = Configuration.integration_url( - Configuration.STAFF_PICKS_INTEGRATION, True - ) - if not url.startswith('https://') or url.startswith('http://'): + url = Configuration.integration_url(Configuration.STAFF_PICKS_INTEGRATION, True) + if not url.startswith("https://") or url.startswith("http://"): url = self.DEFAULT_URL_TEMPLATE % url self.log.info("Retrieving %s", url) representation, cached = Representation.get( - self._db, url, do_get=Representation.browser_http_get, - accept="text/csv", max_age=timedelta(days=1)) + self._db, + url, + do_get=Representation.browser_http_get, + accept="text/csv", + max_age=timedelta(days=1), + ) if representation.status_code != 200: - raise ValueError("Unexpected status code %s" % - representation.status_code) + raise ValueError("Unexpected status code %s" % representation.status_code) if not representation.media_type.startswith("text/csv"): - raise ValueError("Unexpected media type %s" % - representation.media_type) + raise ValueError("Unexpected media type %s" % representation.media_type) return StringIO(representation.content) @@ -308,26 +300,27 @@ class CacheRepresentationPerLane(TimestampScript, LaneSweeperScript): def arg_parser(cls, _db): parser = LaneSweeperScript.arg_parser(_db) parser.add_argument( - '--language', - help='Process only lanes that include books in this language.', - action='append' + "--language", + help="Process only lanes that include books in this language.", + action="append", ) parser.add_argument( - '--max-depth', - help='Stop processing lanes once you reach this depth.', + "--max-depth", + help="Stop processing lanes once you reach this depth.", type=int, - default=None + default=None, ) parser.add_argument( - '--min-depth', - help='Start processing lanes once you reach this depth.', + "--min-depth", + help="Start processing lanes once you reach this depth.", type=int, - default=1 + default=1, ) return parser - def __init__(self, _db=None, cmd_args=None, testing=False, manager=None, - *args, **kwargs): + def __init__( + self, _db=None, cmd_args=None, testing=False, manager=None, *args, **kwargs + ): """Constructor. :param _db: A database connection. :param cmd_args: A mock set of command-line arguments, to use instead @@ -347,9 +340,12 @@ def __init__(self, _db=None, cmd_args=None, testing=False, manager=None, if not manager: manager = CirculationManager(self._db, testing=testing) from api.app import app + app.manager = manager self.app = app - self.base_url = ConfigurationSetting.sitewide(self._db, Configuration.BASE_URL_KEY).value + self.base_url = ConfigurationSetting.sitewide( + self._db, Configuration.BASE_URL_KEY + ).value def parse_args(self, cmd_args=None): parser = self.arg_parser(self._db) @@ -417,7 +413,7 @@ def process_library(self, library): ctx.pop() end = time.time() self.log.info( - "Processed library %s in %.2fsec", library.short_name, end-begin + "Processed library %s in %.2fsec", library.short_name, end - begin ) def process_lane(self, lane): @@ -434,8 +430,7 @@ def process_lane(self, lane): if pagination: extra_description += " Pagination: %s." % pagination.query_string self.log.info( - "Generating feed for %s.%s", lane.full_identifier, - extra_description + "Generating feed for %s.%s", lane.full_identifier, extra_description ) a = time.time() feed = self.do_generate(lane, facets, pagination) @@ -443,8 +438,7 @@ def process_lane(self, lane): if feed: cached_feeds.append(feed) self.log.info( - "Took %.2f sec to make %d bytes.", (b-a), - len(feed.data) + "Took %.2f sec to make %d bytes.", (b - a), len(feed.data) ) total_size = sum(len(x.data) for x in cached_feeds) return cached_feeds @@ -477,55 +471,58 @@ class CacheFacetListsPerLane(CacheRepresentationPerLane): def arg_parser(cls, _db): parser = CacheRepresentationPerLane.arg_parser(_db) available = Facets.DEFAULT_ENABLED_FACETS[Facets.ORDER_FACET_GROUP_NAME] - order_help = 'Generate feeds for this ordering. Possible values: %s.' % ( + order_help = "Generate feeds for this ordering. Possible values: %s." % ( ", ".join(available) ) parser.add_argument( - '--order', + "--order", help=order_help, - action='append', + action="append", default=[], ) available = Facets.DEFAULT_ENABLED_FACETS[Facets.AVAILABILITY_FACET_GROUP_NAME] - availability_help = 'Generate feeds for this availability setting. Possible values: %s.' % ( - ", ".join(available) + availability_help = ( + "Generate feeds for this availability setting. Possible values: %s." + % (", ".join(available)) ) parser.add_argument( - '--availability', + "--availability", help=availability_help, - action='append', + action="append", default=[], ) available = Facets.DEFAULT_ENABLED_FACETS[Facets.COLLECTION_FACET_GROUP_NAME] - collection_help = 'Generate feeds for this collection within each lane. Possible values: %s.' % ( - ", ".join(available) + collection_help = ( + "Generate feeds for this collection within each lane. Possible values: %s." + % (", ".join(available)) ) parser.add_argument( - '--collection', + "--collection", help=collection_help, - action='append', + action="append", default=[], ) available = [x.INTERNAL_NAME for x in EntryPoint.ENTRY_POINTS] - entrypoint_help = 'Generate feeds for this entry point within each lane. Possible values: %s.' % ( - ", ".join(available) + entrypoint_help = ( + "Generate feeds for this entry point within each lane. Possible values: %s." + % (", ".join(available)) ) parser.add_argument( - '--entrypoint', + "--entrypoint", help=entrypoint_help, - action='append', + action="append", default=[], ) default_pages = 2 parser.add_argument( - '--pages', + "--pages", help="Number of pages to cache for each facet. Default: %d" % default_pages, type=int, - default=default_pages + default=default_pages, ) return parser @@ -553,9 +550,7 @@ def facets(self, lane): allowed_orders = library.enabled_facets(Facets.ORDER_FACET_GROUP_NAME) chosen_orders = self.orders or [default_order] - allowed_entrypoint_names = [ - x.INTERNAL_NAME for x in library.entrypoints - ] + allowed_entrypoint_names = [x.INTERNAL_NAME for x in library.entrypoints] default_entrypoint_name = None if allowed_entrypoint_names: default_entrypoint_name = allowed_entrypoint_names[0] @@ -570,15 +565,11 @@ def facets(self, lane): ) chosen_availabilities = self.availabilities or [default_availability] - default_collection = library.default_facet( - Facets.COLLECTION_FACET_GROUP_NAME - ) - allowed_collections = library.enabled_facets( - Facets.COLLECTION_FACET_GROUP_NAME - ) + default_collection = library.default_facet(Facets.COLLECTION_FACET_GROUP_NAME) + allowed_collections = library.enabled_facets(Facets.COLLECTION_FACET_GROUP_NAME) chosen_collections = self.collections or [default_collection] - top_level = (lane.parent is None) + top_level = lane.parent is None for entrypoint_name in chosen_entrypoints: entrypoint = EntryPoint.BY_INTERNAL_NAME.get(entrypoint_name) if not entrypoint: @@ -593,21 +584,27 @@ def facets(self, lane): continue for availability in chosen_availabilities: if availability not in allowed_availabilities: - logging.warn("Ignoring unsupported availability %s" % availability) + logging.warn( + "Ignoring unsupported availability %s" % availability + ) continue for collection in chosen_collections: if collection not in allowed_collections: - logging.warn("Ignoring unsupported collection %s" % collection) + logging.warn( + "Ignoring unsupported collection %s" % collection + ) continue facets = Facets( - library=library, collection=collection, + library=library, + collection=collection, availability=availability, entrypoint=entrypoint, entrypoint_is_default=( - top_level and - entrypoint.INTERNAL_NAME == default_entrypoint_name + top_level + and entrypoint.INTERNAL_NAME == default_entrypoint_name ), - order=order, order_ascending=True + order=order, + order_ascending=True, ) yield facets @@ -630,9 +627,14 @@ def do_generate(self, lane, facets, pagination, feed_class=None): url = annotator.feed_url(lane, facets=facets, pagination=pagination) feed_class = feed_class or AcquisitionFeed return feed_class.page( - _db=self._db, title=title, url=url, worklist=lane, - annotator=annotator, facets=facets, pagination=pagination, - max_age=0 + _db=self._db, + title=title, + url=url, + worklist=lane, + annotator=annotator, + facets=facets, + pagination=pagination, + max_age=0, ) @@ -658,8 +660,13 @@ def do_generate(self, lane, facets, pagination, feed_class=None): # there's no need to consider the case of a lane with no sublanes, # unlike the corresponding code in OPDSFeedController.groups() return feed_class.groups( - _db=self._db, title=title, url=url, worklist=lane, - annotator=annotator, max_age=0, facets=facets + _db=self._db, + title=title, + url=url, + worklist=lane, + annotator=annotator, + max_age=0, + facets=facets, ) def facets(self, lane): @@ -668,7 +675,7 @@ def facets(self, lane): This is the only way grouped feeds are ever generated, so there is no way to override this. """ - top_level = (lane.parent is None) + top_level = lane.parent is None library = lane.get_library(self._db) # If the WorkList has explicitly defined EntryPoints, we want to @@ -687,12 +694,11 @@ def facets(self, lane): minimum_featured_quality=library.minimum_featured_quality, uses_customlists=lane.uses_customlists, entrypoint=entrypoint, - entrypoint_is_default=( - top_level and entrypoint is default_entrypoint - ) + entrypoint_is_default=(top_level and entrypoint is default_entrypoint), ) yield facets + class CacheMARCFiles(LaneSweeperScript): """Generate and cache MARC files for each input library.""" @@ -702,15 +708,16 @@ class CacheMARCFiles(LaneSweeperScript): def arg_parser(cls, _db): parser = LaneSweeperScript.arg_parser(_db) parser.add_argument( - '--max-depth', - help='Stop processing lanes once you reach this depth.', + "--max-depth", + help="Stop processing lanes once you reach this depth.", type=int, default=0, ) parser.add_argument( - '--force', + "--force", help="Generate new MARC files even if MARC files have already been generated recently enough", - dest='force', action='store_true', + dest="force", + action="store_true", ) return parser @@ -727,9 +734,12 @@ def parse_args(self, cmd_args=None): def should_process_library(self, library): integration = ExternalIntegration.lookup( - self._db, ExternalIntegration.MARC_EXPORT, - ExternalIntegration.CATALOG_GOAL, library) - return (integration is not None) + self._db, + ExternalIntegration.MARC_EXPORT, + ExternalIntegration.CATALOG_GOAL, + library, + ) + return integration is not None def process_library(self, library): if self.should_process_library(library): @@ -761,29 +771,40 @@ def process_lane(self, lane, exporter=None): update_frequency = MARCExporter.DEFAULT_UPDATE_FREQUENCY last_update = None - files_q = self._db.query(CachedMARCFile).filter( - CachedMARCFile.library==library - ).filter( - CachedMARCFile.lane==(lane if isinstance(lane, Lane) else None), - ).order_by(CachedMARCFile.end_time.desc()) + files_q = ( + self._db.query(CachedMARCFile) + .filter(CachedMARCFile.library == library) + .filter( + CachedMARCFile.lane == (lane if isinstance(lane, Lane) else None), + ) + .order_by(CachedMARCFile.end_time.desc()) + ) if files_q.count() > 0: last_update = files_q.first().end_time - if not self.force and last_update and (last_update > utc_now() - timedelta(days=update_frequency)): - self.log.info("Skipping lane %s because last update was less than %d days ago" % (lane.display_name, update_frequency)) + if ( + not self.force + and last_update + and (last_update > utc_now() - timedelta(days=update_frequency)) + ): + self.log.info( + "Skipping lane %s because last update was less than %d days ago" + % (lane.display_name, update_frequency) + ) return # To find the storage integration for the exporter, first find the # external integration link associated with the exporter's external # integration. integration_link = get_one( - self._db, ExternalIntegrationLink, + self._db, + ExternalIntegrationLink, external_integration_id=exporter.integration.id, - purpose=ExternalIntegrationLink.MARC + purpose=ExternalIntegrationLink.MARC, ) # Then use the "other" integration value to find the storage integration. - storage_integration = get_one(self._db, ExternalIntegration, - id=integration_link.other_integration_id + storage_integration = get_one( + self._db, ExternalIntegration, id=integration_link.other_integration_id ) if not storage_integration: @@ -791,9 +812,7 @@ def process_lane(self, lane, exporter=None): return # First update the file with ALL the records. - records = exporter.records( - lane, annotator, storage_integration - ) + records = exporter.records(lane, annotator, storage_integration) # Then create a new file with changes since the last update. start_time = None @@ -807,14 +826,13 @@ def process_lane(self, lane, exporter=None): class AdobeAccountIDResetScript(PatronInputScript): - @classmethod def arg_parser(cls, _db): parser = super(AdobeAccountIDResetScript, cls).arg_parser(_db) parser.add_argument( - '--delete', + "--delete", help="Actually delete credentials as opposed to showing what would happen.", - action='store_true' + action="store_true", ) return parser @@ -826,9 +844,7 @@ def do_run(self, *args, **kwargs): self.log.info( "This is a dry run. Nothing will actually change in the database." ) - self.log.info( - "Run with --delete to change the database." - ) + self.log.info("Run with --delete to change the database.") if patrons and self.delete: self.log.warn( @@ -837,7 +853,7 @@ def do_run(self, *args, **kwargs): They will be unable to fulfill any existing loans that involve Adobe-encrypted files. Sleeping for five seconds to give you a chance to back out. You'll get another chance to back out before the database session is committed.""", - len(patrons) + len(patrons), ) time.sleep(5) self.process_patrons(patrons) @@ -853,34 +869,34 @@ def process_patron(self, patron): """ self.log.info( 'Processing patron "%s"', - patron.authorization_identifier or patron.username - or patron.external_identifier + patron.authorization_identifier + or patron.username + or patron.external_identifier, ) for credential in AuthdataUtility.adobe_relevant_credentials(patron): self.log.info( - ' Deleting "%s" credential "%s"', - credential.type, credential.credential + ' Deleting "%s" credential "%s"', credential.type, credential.credential ) if self.delete: self._db.delete(credential) + class AvailabilityRefreshScript(IdentifierInputScript): """Refresh the availability information for a LicensePool, direct from the license source. """ + def do_run(self): args = self.parse_command_line(self._db) if not args.identifiers: - raise Exception( - "You must specify at least one identifier to refresh." - ) + raise Exception("You must specify at least one identifier to refresh.") # We don't know exactly how big to make these batches, but 10 is # always safe. start = 0 size = 10 while start < len(args.identifiers): - batch = args.identifiers[start:start+size] + batch = args.identifiers[start : start + size] self.refresh_availability(batch) self._db.commit() start += size @@ -888,14 +904,14 @@ def do_run(self): def refresh_availability(self, identifiers): provider = None identifier = identifiers[0] - if identifier.type==Identifier.THREEM_ID: + if identifier.type == Identifier.THREEM_ID: sweeper = BibliothecaCirculationSweep(self._db) sweeper.process_batch(identifiers) - elif identifier.type==Identifier.OVERDRIVE_ID: + elif identifier.type == Identifier.OVERDRIVE_ID: api = OverdriveAPI(self._db) for identifier in identifiers: api.update_licensepool(identifier.identifier) - elif identifier.type==Identifier.AXIS_360_ID: + elif identifier.type == Identifier.AXIS_360_ID: provider = Axis360BibliographicCoverageProvider(self._db) provider.process_batch(identifiers) else: @@ -990,7 +1006,9 @@ def run(self, *args, **kwargs): if results is None: super(InstanceInitializationScript, self).run(*args, **kwargs) else: - self.log.error("I think this site has already been initialized; doing nothing.") + self.log.error( + "I think this site has already been initialized; doing nothing." + ) def do_run(self, ignore_search=False): # Creates a "-current" alias on the Elasticsearch client. @@ -1004,8 +1022,10 @@ def do_run(self, ignore_search=False): # Set a timestamp that represents the new database's version. db_init_script = DatabaseMigrationInitializationScript(_db=self._db) existing = get_one( - self._db, Timestamp, service=db_init_script.name, - service_type=Timestamp.SCRIPT_TYPE + self._db, + Timestamp, + service=db_init_script.name, + service_type=Timestamp.SCRIPT_TYPE, ) if existing: # No need to run the script. We already have a timestamp. @@ -1033,26 +1053,26 @@ def do_run(self): now = utc_now() # Reap loans and holds that we know have expired. - for obj, what in ((Loan, 'loans'), (Hold, 'holds')): + for obj, what in ((Loan, "loans"), (Hold, "holds")): qu = self._db.query(obj).filter(obj.end < now) self._reap(qu, "expired %s" % what) for obj, what, max_age in ( - (Loan, 'loans', timedelta(days=90)), - (Hold, 'holds', timedelta(days=365)), + (Loan, "loans", timedelta(days=90)), + (Hold, "holds", timedelta(days=365)), ): # Reap loans and holds which have no end date and are very # old. It's very likely these loans and holds have expired # and we simply don't have the information. older_than = now - max_age - qu = self._db.query(obj).join(obj.license_pool).filter( - obj.end == None).filter( - obj.start < older_than).filter( - LicensePool.open_access == False - ) - explain = "%s older than %s" % ( - what, older_than.strftime("%Y-%m-%d") + qu = ( + self._db.query(obj) + .join(obj.license_pool) + .filter(obj.end == None) + .filter(obj.start < older_than) + .filter(LicensePool.open_access == False) ) + explain = "%s older than %s" % (what, older_than.strftime("%Y-%m-%d")) self._reap(qu, explain) def _reap(self, qu, what): @@ -1080,20 +1100,23 @@ class DisappearingBookReportScript(Script): """ def do_run(self): - qu = self._db.query(LicensePool).filter( - LicensePool.open_access==False).filter( - LicensePool.suppressed==False).filter( - LicensePool.licenses_owned<=0).order_by( - LicensePool.availability_time.desc()) - first_row = ["Identifier", - "Title", - "Author", - "First seen", - "Last seen (best guess)", - "Current licenses owned", - "Current licenses available", - "Changes in number of licenses", - "Changes in title availability", + qu = ( + self._db.query(LicensePool) + .filter(LicensePool.open_access == False) + .filter(LicensePool.suppressed == False) + .filter(LicensePool.licenses_owned <= 0) + .order_by(LicensePool.availability_time.desc()) + ) + first_row = [ + "Identifier", + "Title", + "Author", + "First seen", + "Last seen (best guess)", + "Current licenses owned", + "Current licenses available", + "Changes in number of licenses", + "Changes in title availability", ] print("\t".join(first_row)) @@ -1139,12 +1162,13 @@ def investigate(self, licensepool): # Now we look for relevant circulation events. First, an event # where the title was explicitly removed is pretty clearly # a 'last seen'. - base_query = self._db.query(CirculationEvent).filter( - CirculationEvent.license_pool==licensepool).order_by( - CirculationEvent.start.desc() - ) + base_query = ( + self._db.query(CirculationEvent) + .filter(CirculationEvent.license_pool == licensepool) + .order_by(CirculationEvent.start.desc()) + ) title_removal_events = base_query.filter( - CirculationEvent.type==CirculationEvent.DISTRIBUTOR_TITLE_REMOVE + CirculationEvent.type == CirculationEvent.DISTRIBUTOR_TITLE_REMOVE ) if title_removal_events.count(): candidate = title_removal_events[-1].start @@ -1154,12 +1178,13 @@ def investigate(self, licensepool): # Also look for an event where the title went from a nonzero # number of licenses to a zero number of licenses. That's a # good 'last seen'. - license_removal_events = base_query.filter( - CirculationEvent.type==CirculationEvent.DISTRIBUTOR_LICENSE_REMOVE, - ).filter( - CirculationEvent.old_value>0).filter( - CirculationEvent.new_value<=0 + license_removal_events = ( + base_query.filter( + CirculationEvent.type == CirculationEvent.DISTRIBUTOR_LICENSE_REMOVE, ) + .filter(CirculationEvent.old_value > 0) + .filter(CirculationEvent.new_value <= 0) + ) if license_removal_events.count(): candidate = license_removal_events[-1].start if not last_seen or candidate > last_seen: @@ -1182,27 +1207,29 @@ def explain(self, licensepool): if licensepool.availability_time: first_seen = licensepool.availability_time.strftime(self.format) else: - first_seen = '' + first_seen = "" data.append(first_seen) if last_seen: last_seen = last_seen.strftime(self.format) else: - last_seen = '' + last_seen = "" data.append(last_seen) data.append(licensepool.licenses_owned) data.append(licensepool.licenses_available) license_removals = [] for event in license_removal_events: - description ="%s: %s→%s" % ( - event.start.strftime(self.format), event.old_value, - event.new_value + description = "%s: %s→%s" % ( + event.start.strftime(self.format), + event.old_value, + event.new_value, ) license_removals.append(description) data.append(", ".join(license_removals)) - title_removals = [event.start.strftime(self.format) - for event in title_removal_events] + title_removals = [ + event.start.strftime(self.format) for event in title_removal_events + ] data.append(", ".join(title_removals)) print("\t".join([str(x).encode("utf8") for x in data])) @@ -1221,9 +1248,9 @@ def do_run(self): self.data_source = DataSource.lookup(self._db, DataSource.NYT) # For every best-seller list... names = self.api.list_of_lists() - for l in sorted(names['results'], key=lambda x: x['list_name_encoded']): + for l in sorted(names["results"], key=lambda x: x["list_name_encoded"]): - name = l['list_name_encoded'] + name = l["list_name_encoded"] self.log.info("Handling list %s" % name) best = self.api.best_seller_list(l) @@ -1234,10 +1261,10 @@ def do_run(self): # Mirror the list to the database. customlist = best.to_customlist(self._db) - self.log.info( - "Now %s entries in the list.", len(customlist.entries)) + self.log.info("Now %s entries in the list.", len(customlist.entries)) self._db.commit() + class OPDSForDistributorsImportScript(OPDSImportScript): """Import all books from the OPDS feed associated with a collection that requires authentication.""" @@ -1246,6 +1273,7 @@ class OPDSForDistributorsImportScript(OPDSImportScript): MONITOR_CLASS = OPDSForDistributorsImportMonitor PROTOCOL = OPDSForDistributorsImporter.NAME + class OPDSForDistributorsReaperScript(OPDSImportScript): """Get all books from the OPDS feed associated with a collection to find out if any have been removed.""" @@ -1266,59 +1294,62 @@ class DirectoryImportScript(TimestampScript): def arg_parser(cls, _db): parser = argparse.ArgumentParser() parser.add_argument( - '--collection-name', - help='Titles will be imported into a collection with this name. The collection will be created if it does not already exist.', - required=True + "--collection-name", + help="Titles will be imported into a collection with this name. The collection will be created if it does not already exist.", + required=True, ) parser.add_argument( - '--collection-type', - help='Collection type. Valid values are: OPEN_ACCESS (default), PROTECTED_ACCESS, LCP.', + "--collection-type", + help="Collection type. Valid values are: OPEN_ACCESS (default), PROTECTED_ACCESS, LCP.", type=CollectionType, choices=list(CollectionType), - default=CollectionType.OPEN_ACCESS + default=CollectionType.OPEN_ACCESS, ) parser.add_argument( - '--data-source-name', - help='All data associated with this import activity will be recorded as originating with this data source. The data source will be created if it does not already exist.', - required=True + "--data-source-name", + help="All data associated with this import activity will be recorded as originating with this data source. The data source will be created if it does not already exist.", + required=True, ) parser.add_argument( - '--metadata-file', - help='Path to a file containing MARC or ONIX 3.0 metadata for every title in the collection', - required=True + "--metadata-file", + help="Path to a file containing MARC or ONIX 3.0 metadata for every title in the collection", + required=True, ) parser.add_argument( - '--metadata-format', + "--metadata-format", help='Format of the metadata file ("marc" or "onix")', - default='marc', + default="marc", ) parser.add_argument( - '--cover-directory', - help='Directory containing a full-size cover image for every title in the collection.', + "--cover-directory", + help="Directory containing a full-size cover image for every title in the collection.", ) parser.add_argument( - '--ebook-directory', - help='Directory containing an EPUB or PDF file for every title in the collection.', - required=True + "--ebook-directory", + help="Directory containing an EPUB or PDF file for every title in the collection.", + required=True, ) RS = RightsStatus rights_uris = ", ".join(RS.OPEN_ACCESS) parser.add_argument( - '--rights-uri', - help="A URI explaining the rights status of the works being uploaded. Acceptable values: %s" % rights_uris, - required=True + "--rights-uri", + help="A URI explaining the rights status of the works being uploaded. Acceptable values: %s" + % rights_uris, + required=True, ) parser.add_argument( - '--dry-run', + "--dry-run", help="Show what would be imported, but don't actually do the import.", - action='store_true', + action="store_true", ) parser.add_argument( - '--default-medium-type', - help='Default medium type used in the case when it\'s not explicitly specified in a metadata file. ' - 'Valid values are: {0}.'.format(', '.join(EditionConstants.FULFILLABLE_MEDIA)), + "--default-medium-type", + help="Default medium type used in the case when it's not explicitly specified in a metadata file. " + "Valid values are: {0}.".format( + ", ".join(EditionConstants.FULFILLABLE_MEDIA) + ), type=str, - choices=EditionConstants.FULFILLABLE_MEDIA + choices=EditionConstants.FULFILLABLE_MEDIA, ) return parser @@ -1347,28 +1378,30 @@ def do_run(self, cmd_args=None): ebook_directory=ebook_directory, rights_uri=rights_uri, dry_run=dry_run, - default_medium_type=default_medium_type + default_medium_type=default_medium_type, ) def run_with_arguments( - self, - collection_name, - collection_type, - data_source_name, - metadata_file, - metadata_format, - cover_directory, - ebook_directory, - rights_uri, - dry_run, - default_medium_type=None + self, + collection_name, + collection_type, + data_source_name, + metadata_file, + metadata_format, + cover_directory, + ebook_directory, + rights_uri, + dry_run, + default_medium_type=None, ): if dry_run: self.log.warn( "This is a dry run. No files will be uploaded and nothing will change in the database." ) - collection, mirrors = self.load_collection(collection_name, collection_type, data_source_name) + collection, mirrors = self.load_collection( + collection_name, collection_type, data_source_name + ) if not collection or not mirrors: return @@ -1378,10 +1411,15 @@ def run_with_arguments( if dry_run: mirrors = None - self_hosted_collection = collection_type in (CollectionType.OPEN_ACCESS, CollectionType.PROTECTED_ACCESS) + self_hosted_collection = collection_type in ( + CollectionType.OPEN_ACCESS, + CollectionType.PROTECTED_ACCESS, + ) replacement_policy = ReplacementPolicy.from_license_source(self._db) replacement_policy.mirrors = mirrors - metadata_records = self.load_metadata(metadata_file, metadata_format, data_source_name, default_medium_type) + metadata_records = self.load_metadata( + metadata_file, metadata_format, data_source_name, default_medium_type + ) for metadata in metadata_records: _, licensepool = self.work_from_metadata( collection, @@ -1390,7 +1428,7 @@ def run_with_arguments( replacement_policy, cover_directory, ebook_directory, - rights_uri + rights_uri, ) licensepool.self_hosted = True if self_hosted_collection else False @@ -1420,7 +1458,11 @@ def load_collection(self, collection_name, collection_type, data_source_name): :rtype: Tuple[Collection, List[MirrorUploader]] """ collection, is_new = Collection.by_name_and_protocol( - self._db, collection_name, ExternalIntegration.LCP if collection_type == CollectionType.LCP else ExternalIntegration.MANUAL + self._db, + collection_name, + ExternalIntegration.LCP + if collection_type == CollectionType.LCP + else ExternalIntegration.MANUAL, ) if is_new: @@ -1435,20 +1477,20 @@ def load_collection(self, collection_name, collection_type, data_source_name): ExternalIntegrationLink.COVERS, ExternalIntegrationLink.OPEN_ACCESS_BOOKS if collection_type == CollectionType.OPEN_ACCESS - else ExternalIntegrationLink.PROTECTED_ACCESS_BOOKS + else ExternalIntegrationLink.PROTECTED_ACCESS_BOOKS, ] for type in types: mirror_for_type = MirrorUploader.for_collection(collection, type) if not mirror_for_type: self.log.error( - "An existing %s mirror integration should be assigned to the collection before running the script." % type + "An existing %s mirror integration should be assigned to the collection before running the script." + % type ) return None, None mirrors[type] = mirror_for_type data_source = DataSource.lookup( - self._db, data_source_name, autocreate=True, - offers_licenses=True + self._db, data_source_name, autocreate=True, offers_licenses=True ) collection.external_integration.set_setting( Collection.DATA_SOURCE_NAME_SETTING, data_source.name @@ -1456,20 +1498,26 @@ def load_collection(self, collection_name, collection_type, data_source_name): return collection, mirrors - def load_metadata(self, metadata_file, metadata_format, data_source_name, default_medium_type): + def load_metadata( + self, metadata_file, metadata_format, data_source_name, default_medium_type + ): """Read a metadata file and convert the data into Metadata records.""" metadata_records = [] - if metadata_format == 'marc': + if metadata_format == "marc": extractor = MARCExtractor() - elif metadata_format == 'onix': + elif metadata_format == "onix": extractor = ONIXExtractor() with open(metadata_file) as f: - metadata_records.extend(extractor.parse(f, data_source_name, default_medium_type)) + metadata_records.extend( + extractor.parse(f, data_source_name, default_medium_type) + ) return metadata_records - def work_from_metadata(self, collection, collection_type, metadata, policy, *args, **kwargs): + def work_from_metadata( + self, collection, collection_type, metadata, policy, *args, **kwargs + ): """Creates a Work instance from metadata :param collection: Target collection @@ -1496,8 +1544,7 @@ def work_from_metadata(self, collection, collection_type, metadata, policy, *arg edition, new = metadata.edition(self._db) metadata.apply(edition, collection, replace=policy) - [pool] = [x for x in edition.license_pools - if x.collection == collection] + [pool] = [x for x in edition.license_pools if x.collection == collection] if new: self.log.info("Created new edition for %s", edition.title) else: @@ -1512,13 +1559,14 @@ def work_from_metadata(self, collection, collection_type, metadata, policy, *arg return work, pool def annotate_metadata( - self, - collection_type, - metadata, - policy, - cover_directory, - ebook_directory, - rights_uri): + self, + collection_type, + metadata, + policy, + cover_directory, + ebook_directory, + rights_uri, + ): """Add a CirculationData and possibly an extra LinkData to `metadata` :param collection_type: Collection's type: open access/protected access @@ -1550,7 +1598,7 @@ def annotate_metadata( ebook_directory, mirrors, metadata.title, - rights_uri + rights_uri, ) if not circulation_data: # There is no point in contining. @@ -1558,9 +1606,13 @@ def annotate_metadata( if metadata.circulation: circulation_data.licenses_owned = metadata.circulation.licenses_owned - circulation_data.licenses_available = metadata.circulation.licenses_available + circulation_data.licenses_available = ( + metadata.circulation.licenses_available + ) circulation_data.licenses_reserved = metadata.circulation.licenses_reserved - circulation_data.patrons_in_hold_queue = metadata.circulation.patrons_in_hold_queue + circulation_data.patrons_in_hold_queue = ( + metadata.circulation.patrons_in_hold_queue + ) circulation_data.licenses = metadata.circulation.licenses metadata.circulation = circulation_data @@ -1576,19 +1628,19 @@ def annotate_metadata( metadata.links.append(cover_link) else: logging.info( - "Proceeding with import even though %r has no cover.", - identifier + "Proceeding with import even though %r has no cover.", identifier ) def load_circulation_data( - self, - collection_type, - identifier, - data_source, - ebook_directory, - mirrors, - title, - rights_uri): + self, + collection_type, + identifier, + data_source, + ebook_directory, + mirrors, + title, + rights_uri, + ): """Loads an actual copy of a book from disk :param collection_type: Collection's type: open access/protected access @@ -1617,7 +1669,8 @@ def load_circulation_data( :rtype: CirculationData """ ignore, book_media_type, book_content = self._locate_file( - identifier.identifier, ebook_directory, + identifier.identifier, + ebook_directory, Representation.COMMON_EBOOK_EXTENSIONS, "ebook file", ) @@ -1626,39 +1679,50 @@ def load_circulation_data( # no point in proceeding. return - book_mirror = mirrors[ - ExternalIntegrationLink.OPEN_ACCESS_BOOKS - if collection_type == CollectionType.OPEN_ACCESS - else ExternalIntegrationLink.PROTECTED_ACCESS_BOOKS - ] if mirrors else None + book_mirror = ( + mirrors[ + ExternalIntegrationLink.OPEN_ACCESS_BOOKS + if collection_type == CollectionType.OPEN_ACCESS + else ExternalIntegrationLink.PROTECTED_ACCESS_BOOKS + ] + if mirrors + else None + ) # Use the S3 storage for books. if book_mirror: book_url = book_mirror.book_url( identifier, - '.' + Representation.FILE_EXTENSIONS[book_media_type], + "." + Representation.FILE_EXTENSIONS[book_media_type], open_access=collection_type == CollectionType.OPEN_ACCESS, data_source=data_source, - title=title + title=title, ) else: # This is a dry run and we won't be mirroring anything. - book_url = identifier.identifier + "." + Representation.FILE_EXTENSIONS[book_media_type] + book_url = ( + identifier.identifier + + "." + + Representation.FILE_EXTENSIONS[book_media_type] + ) - book_link_rel = \ - Hyperlink.OPEN_ACCESS_DOWNLOAD \ - if collection_type == CollectionType.OPEN_ACCESS \ + book_link_rel = ( + Hyperlink.OPEN_ACCESS_DOWNLOAD + if collection_type == CollectionType.OPEN_ACCESS else Hyperlink.GENERIC_OPDS_ACQUISITION + ) book_link = LinkData( rel=book_link_rel, href=book_url, media_type=book_media_type, - content=book_content + content=book_content, ) formats = [ FormatData( content_type=book_media_type, - drm_scheme=DeliveryMechanism.LCP_DRM if collection_type == CollectionType.LCP else DeliveryMechanism.NO_DRM, + drm_scheme=DeliveryMechanism.LCP_DRM + if collection_type == CollectionType.LCP + else DeliveryMechanism.NO_DRM, link=book_link, ) ] @@ -1673,20 +1737,23 @@ def load_circulation_data( def load_cover_link(self, identifier, data_source, cover_directory, mirrors): """Load an actual book cover from disk. - + :return: A LinkData containing a cover of the book, or None if no book cover can be found. """ cover_filename, cover_media_type, cover_content = self._locate_file( - identifier.identifier, cover_directory, - Representation.COMMON_IMAGE_EXTENSIONS, "cover image" + identifier.identifier, + cover_directory, + Representation.COMMON_IMAGE_EXTENSIONS, + "cover image", ) if not cover_content: return None cover_filename = ( identifier.identifier - + '.' + Representation.FILE_EXTENSIONS[cover_media_type] + + "." + + Representation.FILE_EXTENSIONS[cover_media_type] ) # Use an S3 storage mirror for specifically for covers. @@ -1707,8 +1774,14 @@ def load_cover_link(self, identifier, data_source, cover_directory, mirrors): return cover_link @classmethod - def _locate_file(cls, base_filename, directory, extensions, - file_type="file", mock_filesystem_operations=None): + def _locate_file( + cls, + base_filename, + directory, + extensions, + file_type="file", + mock_filesystem_operations=None, + ): """Find an acceptable file in the given directory. :param base_filename: A string to be used as the base of the filename. @@ -1740,8 +1813,8 @@ def _locate_file(cls, base_filename, directory, extensions, attempts = [] for extension in extensions: for ext in (extension, extension.upper()): - if not ext.startswith('.'): - ext = '.' + ext + if not ext.startswith("."): + ext = "." + ext filename = base_filename + ext path = os.path.join(directory, filename) attempts.append(path) @@ -1758,7 +1831,9 @@ def _locate_file(cls, base_filename, directory, extensions, # we have failed. logging.warn( "Could not find %s for %s. Looked in: %s", - file_type, base_filename, ", ".join(attempts) + file_type, + base_filename, + ", ".join(attempts), ) return None, None, None @@ -1771,9 +1846,9 @@ class LaneResetScript(LibraryInputScript): def arg_parser(cls, _db): parser = LibraryInputScript.arg_parser(_db) parser.add_argument( - '--reset', + "--reset", help="Actually reset the lanes as opposed to showing what would happen.", - action='store_true' + action="store_true", ) return parser @@ -1785,9 +1860,7 @@ def do_run(self, output=sys.stdout, **kwargs): self.log.info( "This is a dry run. Nothing will actually change in the database." ) - self.log.info( - "Run with --reset to change the database." - ) + self.log.info("Run with --reset to change the database.") if libraries and self.reset: self.log.warn( @@ -1796,7 +1869,7 @@ def do_run(self, output=sys.stdout, **kwargs): custom lists will be deleted (though the lists themselves will be preserved). Sleeping for five seconds to give you a chance to back out. You'll get another chance to back out before the database session is committed.""", - len(libraries) + len(libraries), ) time.sleep(5) self.process_libraries(libraries) @@ -1806,10 +1879,20 @@ def do_run(self, output=sys.stdout, **kwargs): new_lane_output += "\n\nLibrary '%s':\n" % library.name def print_lanes_for_parent(parent): - lanes = self._db.query(Lane).filter(Lane.library==library).filter(Lane.parent==parent).order_by(Lane.priority) + lanes = ( + self._db.query(Lane) + .filter(Lane.library == library) + .filter(Lane.parent == parent) + .order_by(Lane.priority) + ) lane_output = "" for lane in lanes: - lane_output += " " + (" " * len(list(lane.parentage))) + lane.display_name + "\n" + lane_output += ( + " " + + (" " * len(list(lane.parentage))) + + lane.display_name + + "\n" + ) lane_output += print_lanes_for_parent(lane) return lane_output @@ -1825,8 +1908,8 @@ def print_lanes_for_parent(parent): def process_library(self, library): create_default_lanes(self._db, library) -class NovelistSnapshotScript(TimestampScript, LibraryInputScript): +class NovelistSnapshotScript(TimestampScript, LibraryInputScript): def do_run(self, output=sys.stdout, *args, **kwargs): parsed = self.parse_command_line(self._db, *args, **kwargs) for library in parsed.libraries: @@ -1835,15 +1918,16 @@ def do_run(self, output=sys.stdout, *args, **kwargs): except CannotLoadConfiguration as e: self.log.info(str(e)) continue - if (api): + if api: response = api.put_items_novelist(library) - if (response): + if response: result = "NoveList API Response\n" result += str(response) output.write(result) + class ODLImportScript(OPDSImportScript): """Import information from the feed associated with an ODL collection.""" @@ -1852,11 +1936,13 @@ class ODLImportScript(OPDSImportScript): MONITOR_CLASS = ODLImportMonitor PROTOCOL = ODLImporter.NAME + class SharedODLImportScript(OPDSImportScript): IMPORTER_CLASS = SharedODLImporter MONITOR_CLASS = SharedODLImportMonitor PROTOCOL = SharedODLImporter.NAME + class LocalAnalyticsExportScript(Script): """Export circulation events for a date range to a CSV file.""" @@ -1864,12 +1950,12 @@ class LocalAnalyticsExportScript(Script): def arg_parser(cls, _db): parser = argparse.ArgumentParser() parser.add_argument( - '--start', + "--start", help="Include circulation events that happened at or after this time.", required=True, ) parser.add_argument( - '--end', + "--end", help="Include circulation events that happened before this time.", required=True, ) @@ -1884,6 +1970,7 @@ def do_run(self, output=sys.stdout, cmd_args=None, exporter=None): exporter = exporter or LocalAnalyticsExporter() output.write(exporter.export(self._db, start, end)) + class GenerateShortTokenScript(LibraryInputScript): """ Generate a short client token of the specified duration that can be used for testing that @@ -1892,29 +1979,28 @@ class GenerateShortTokenScript(LibraryInputScript): @classmethod def arg_parser(cls, _db): - parser = super(GenerateShortTokenScript, cls).arg_parser(_db, multiple_libraries=False) + parser = super(GenerateShortTokenScript, cls).arg_parser( + _db, multiple_libraries=False + ) parser.add_argument( - '--barcode', + "--barcode", help="The patron barcode.", required=True, ) - parser.add_argument( - '--pin', - help="The patron pin." - ) + parser.add_argument("--pin", help="The patron pin.") group = parser.add_mutually_exclusive_group(required=True) group.add_argument( - '--days', + "--days", help="Token expiry in days.", type=int, ) group.add_argument( - '--hours', + "--hours", help="Token expiry in hours.", type=int, ) group.add_argument( - '--minutes', + "--minutes", help="Token expiry in minutes.", type=int, ) @@ -1933,24 +2019,36 @@ def do_run(self, _db=None, cmd_args=None, output=sys.stdout, authdata=None): patron = get_one(_db, Patron, authorization_identifier=args.barcode) if patron is None: # Fall back to a full patron lookup - auth = LibraryAuthenticator.from_config(_db, args.libraries[0]).basic_auth_provider + auth = LibraryAuthenticator.from_config( + _db, args.libraries[0] + ).basic_auth_provider if auth is None: output.write("No methods to authenticate patron found!\n") sys.exit(-1) - patron = auth.authenticate(_db, credentials={'username': args.barcode, 'password': args.pin}) + patron = auth.authenticate( + _db, credentials={"username": args.barcode, "password": args.pin} + ) if not isinstance(patron, Patron): output.write("Patron not found {}!\n".format(args.barcode)) sys.exit(-1) authdata = authdata or AuthdataUtility.from_config(library, _db) if authdata is None: - output.write("Library not registered with library registry! Please register and try again.") + output.write( + "Library not registered with library registry! Please register and try again." + ) sys.exit(-1) patron_identifier = authdata._adobe_patron_identifier(patron) - expires = {k: v for (k, v) in vars(args).items() if k in ['days', 'hours', 'minutes'] and v is not None} - vendor_id, token = authdata.encode_short_client_token(patron_identifier, expires=expires) - username, password = token.rsplit('|', 1) + expires = { + k: v + for (k, v) in vars(args).items() + if k in ["days", "hours", "minutes"] and v is not None + } + vendor_id, token = authdata.encode_short_client_token( + patron_identifier, expires=expires + ) + username, password = token.rsplit("|", 1) output.write("Vendor ID: {}\n".format(vendor_id)) output.write("Token: {}\n".format(token)) diff --git a/tests/admin/controller/test_admin_auth_services.py b/tests/admin/controller/test_admin_auth_services.py index f47e635f7f..3eaede41a9 100644 --- a/tests/admin/controller/test_admin_auth_services.py +++ b/tests/admin/controller/test_admin_auth_services.py @@ -1,17 +1,19 @@ -import pytest +import json import flask -import json +import pytest from werkzeug.datastructures import MultiDict + from api.admin.exceptions import * from api.app import initialize_database from core.model import ( AdminRole, ConfigurationSetting, - create, ExternalIntegration, + create, get_one, ) + from .test_controller import SettingsControllerTest @@ -24,24 +26,30 @@ def setup_class(cls): def test_admin_auth_services_get_with_no_services(self): with self.request_context_with_admin("/"): - response = self.manager.admin_auth_services_controller.process_admin_auth_services() + response = ( + self.manager.admin_auth_services_controller.process_admin_auth_services() + ) assert response.get("admin_auth_services") == [] # All the protocols in ExternalIntegration.ADMIN_AUTH_PROTOCOLS # are supported by the admin interface. - assert (sorted([p.get("name") for p in response.get("protocols")]) == - sorted(ExternalIntegration.ADMIN_AUTH_PROTOCOLS)) + assert sorted([p.get("name") for p in response.get("protocols")]) == sorted( + ExternalIntegration.ADMIN_AUTH_PROTOCOLS + ) self.admin.remove_role(AdminRole.SYSTEM_ADMIN) self._db.flush() - pytest.raises(AdminNotAuthorized, - self.manager.admin_auth_services_controller.process_admin_auth_services) + pytest.raises( + AdminNotAuthorized, + self.manager.admin_auth_services_controller.process_admin_auth_services, + ) def test_admin_auth_services_get_with_google_oauth_service(self): auth_service, ignore = create( - self._db, ExternalIntegration, + self._db, + ExternalIntegration, protocol=ExternalIntegration.GOOGLE_OAUTH, - goal=ExternalIntegration.ADMIN_AUTH_GOAL + goal=ExternalIntegration.ADMIN_AUTH_GOAL, ) auth_service.url = "http://oauth.test" auth_service.username = "user" @@ -52,7 +60,9 @@ def test_admin_auth_services_get_with_google_oauth_service(self): ).value = json.dumps(["nypl.org"]) with self.request_context_with_admin("/"): - response = self.manager.admin_auth_services_controller.process_admin_auth_services() + response = ( + self.manager.admin_auth_services_controller.process_admin_auth_services() + ) [service] = response.get("admin_auth_services") assert auth_service.id == service.get("id") @@ -67,70 +77,106 @@ def test_admin_auth_services_get_with_google_oauth_service(self): def test_admin_auth_services_post_errors(self): with self.request_context_with_admin("/", method="POST"): - flask.request.form = MultiDict([ - ("protocol", "Unknown"), - ]) - response = self.manager.admin_auth_services_controller.process_admin_auth_services() + flask.request.form = MultiDict( + [ + ("protocol", "Unknown"), + ] + ) + response = ( + self.manager.admin_auth_services_controller.process_admin_auth_services() + ) assert response == UNKNOWN_PROTOCOL with self.request_context_with_admin("/", method="POST"): flask.request.form = MultiDict([]) - response = self.manager.admin_auth_services_controller.process_admin_auth_services() + response = ( + self.manager.admin_auth_services_controller.process_admin_auth_services() + ) assert response == NO_PROTOCOL_FOR_NEW_SERVICE with self.request_context_with_admin("/", method="POST"): - flask.request.form = MultiDict([ - ("id", "1234"), - ]) - response = self.manager.admin_auth_services_controller.process_admin_auth_services() + flask.request.form = MultiDict( + [ + ("id", "1234"), + ] + ) + response = ( + self.manager.admin_auth_services_controller.process_admin_auth_services() + ) assert response == MISSING_SERVICE auth_service, ignore = create( - self._db, ExternalIntegration, + self._db, + ExternalIntegration, protocol=ExternalIntegration.GOOGLE_OAUTH, - goal=ExternalIntegration.ADMIN_AUTH_GOAL + goal=ExternalIntegration.ADMIN_AUTH_GOAL, ) with self.request_context_with_admin("/", method="POST"): - flask.request.form = MultiDict([ - ("id", str(auth_service.id)), - ]) - response = self.manager.admin_auth_services_controller.process_admin_auth_services() + flask.request.form = MultiDict( + [ + ("id", str(auth_service.id)), + ] + ) + response = ( + self.manager.admin_auth_services_controller.process_admin_auth_services() + ) assert response == CANNOT_CHANGE_PROTOCOL with self.request_context_with_admin("/", method="POST"): - flask.request.form = MultiDict([ - ("protocol", "Google OAuth"), - ]) - response = self.manager.admin_auth_services_controller.process_admin_auth_services() + flask.request.form = MultiDict( + [ + ("protocol", "Google OAuth"), + ] + ) + response = ( + self.manager.admin_auth_services_controller.process_admin_auth_services() + ) assert response.uri == INCOMPLETE_CONFIGURATION.uri self.admin.remove_role(AdminRole.SYSTEM_ADMIN) self._db.flush() with self.request_context_with_admin("/", method="POST"): - flask.request.form = MultiDict([ - ("name", "oauth"), - ("protocol", "Google OAuth"), - ("url", "url"), - ("username", "username"), - ("password", "password"), - ("domains", "nypl.org"), - ]) - pytest.raises(AdminNotAuthorized, - self.manager.admin_auth_services_controller.process_admin_auth_services) + flask.request.form = MultiDict( + [ + ("name", "oauth"), + ("protocol", "Google OAuth"), + ("url", "url"), + ("username", "username"), + ("password", "password"), + ("domains", "nypl.org"), + ] + ) + pytest.raises( + AdminNotAuthorized, + self.manager.admin_auth_services_controller.process_admin_auth_services, + ) def test_admin_auth_services_post_create(self): with self.request_context_with_admin("/", method="POST"): - flask.request.form = MultiDict([ - ("name", "oauth"), - ("protocol", "Google OAuth"), - ("url", "http://url2"), - ("username", "username"), - ("password", "password"), - ("libraries", json.dumps([{ "short_name": self._default_library.short_name, - "domains": ["nypl.org", "gmail.com"] }])), - ]) - response = self.manager.admin_auth_services_controller.process_admin_auth_services() + flask.request.form = MultiDict( + [ + ("name", "oauth"), + ("protocol", "Google OAuth"), + ("url", "http://url2"), + ("username", "username"), + ("password", "password"), + ( + "libraries", + json.dumps( + [ + { + "short_name": self._default_library.short_name, + "domains": ["nypl.org", "gmail.com"], + } + ] + ), + ), + ] + ) + response = ( + self.manager.admin_auth_services_controller.process_admin_auth_services() + ) assert response.status_code == 201 # The auth service was created and configured properly. @@ -151,29 +197,44 @@ def test_admin_auth_services_post_create(self): def test_admin_auth_services_post_google_oauth_edit(self): # The auth service exists. auth_service, ignore = create( - self._db, ExternalIntegration, + self._db, + ExternalIntegration, protocol=ExternalIntegration.GOOGLE_OAUTH, - goal=ExternalIntegration.ADMIN_AUTH_GOAL + goal=ExternalIntegration.ADMIN_AUTH_GOAL, ) auth_service.url = "url" auth_service.username = "user" auth_service.password = "pass" auth_service.libraries += [self._default_library] setting = ConfigurationSetting.for_library_and_externalintegration( - self._db, "domains", self._default_library, auth_service) + self._db, "domains", self._default_library, auth_service + ) setting.value = json.dumps(["library1.org"]) with self.request_context_with_admin("/", method="POST"): - flask.request.form = MultiDict([ - ("name", "oauth"), - ("protocol", "Google OAuth"), - ("url", "http://url2"), - ("username", "user2"), - ("password", "pass2"), - ("libraries", json.dumps([{ "short_name": self._default_library.short_name, - "domains": ["library2.org"] }])), - ]) - response = self.manager.admin_auth_services_controller.process_admin_auth_services() + flask.request.form = MultiDict( + [ + ("name", "oauth"), + ("protocol", "Google OAuth"), + ("url", "http://url2"), + ("username", "user2"), + ("password", "pass2"), + ( + "libraries", + json.dumps( + [ + { + "short_name": self._default_library.short_name, + "domains": ["library2.org"], + } + ] + ), + ), + ] + ) + response = ( + self.manager.admin_auth_services_controller.process_admin_auth_services() + ) assert response.status_code == 200 assert auth_service.protocol == response.get_data(as_text=True) @@ -185,9 +246,10 @@ def test_admin_auth_services_post_google_oauth_edit(self): def test_admin_auth_service_delete(self): auth_service, ignore = create( - self._db, ExternalIntegration, + self._db, + ExternalIntegration, protocol=ExternalIntegration.GOOGLE_OAUTH, - goal=ExternalIntegration.ADMIN_AUTH_GOAL + goal=ExternalIntegration.ADMIN_AUTH_GOAL, ) auth_service.url = "url" auth_service.username = "user" @@ -196,12 +258,16 @@ def test_admin_auth_service_delete(self): with self.request_context_with_admin("/", method="DELETE"): self.admin.remove_role(AdminRole.SYSTEM_ADMIN) - pytest.raises(AdminNotAuthorized, - self.manager.admin_auth_services_controller.process_delete, - auth_service.protocol) + pytest.raises( + AdminNotAuthorized, + self.manager.admin_auth_services_controller.process_delete, + auth_service.protocol, + ) self.admin.add_role(AdminRole.SYSTEM_ADMIN) - response = self.manager.admin_auth_services_controller.process_delete(auth_service.protocol) + response = self.manager.admin_auth_services_controller.process_delete( + auth_service.protocol + ) assert response.status_code == 200 service = get_one(self._db, ExternalIntegration, id=auth_service.id) diff --git a/tests/admin/controller/test_analytics_services.py b/tests/admin/controller/test_analytics_services.py index 11e48dfcfd..2f2919a359 100644 --- a/tests/admin/controller/test_analytics_services.py +++ b/tests/admin/controller/test_analytics_services.py @@ -1,29 +1,33 @@ -import pytest +import json import flask -import json +import pytest from werkzeug.datastructures import MultiDict + from api.admin.exceptions import * from api.google_analytics_provider import GoogleAnalyticsProvider from core.local_analytics_provider import LocalAnalyticsProvider from core.model import ( AdminRole, ConfigurationSetting, - create, ExternalIntegration, - get_one, Library, + create, + get_one, ) + from .test_controller import SettingsControllerTest -class TestAnalyticsServices(SettingsControllerTest): +class TestAnalyticsServices(SettingsControllerTest): def test_analytics_services_get_with_one_default_service(self): with self.request_context_with_admin("/"): - response = self.manager.admin_analytics_services_controller.process_analytics_services() + response = ( + self.manager.admin_analytics_services_controller.process_analytics_services() + ) assert len(response.get("analytics_services")) == 1 local_analytics = response.get("analytics_services")[0] - assert local_analytics.get("name") == LocalAnalyticsProvider.NAME; + assert local_analytics.get("name") == LocalAnalyticsProvider.NAME assert local_analytics.get("protocol") == LocalAnalyticsProvider.__module__ protocols = response.get("protocols") @@ -33,33 +37,42 @@ def test_analytics_services_get_with_one_default_service(self): def test_analytics_services_get_with_one_service(self): # Delete the local analytics service that gets created by default. local_analytics_default = get_one( - self._db, ExternalIntegration, - protocol=LocalAnalyticsProvider.__module__ + self._db, ExternalIntegration, protocol=LocalAnalyticsProvider.__module__ ) self._db.delete(local_analytics_default) ga_service, ignore = create( - self._db, ExternalIntegration, + self._db, + ExternalIntegration, protocol=GoogleAnalyticsProvider.__module__, goal=ExternalIntegration.ANALYTICS_GOAL, ) ga_service.url = self._str with self.request_context_with_admin("/"): - response = self.manager.admin_analytics_services_controller.process_analytics_services() + response = ( + self.manager.admin_analytics_services_controller.process_analytics_services() + ) [service] = response.get("analytics_services") assert ga_service.id == service.get("id") assert ga_service.protocol == service.get("protocol") - assert ga_service.url == service.get("settings").get(ExternalIntegration.URL) + assert ga_service.url == service.get("settings").get( + ExternalIntegration.URL + ) ga_service.libraries += [self._default_library] ConfigurationSetting.for_library_and_externalintegration( - self._db, GoogleAnalyticsProvider.TRACKING_ID, self._default_library, ga_service + self._db, + GoogleAnalyticsProvider.TRACKING_ID, + self._default_library, + ga_service, ).value = "trackingid" with self.request_context_with_admin("/"): - response = self.manager.admin_analytics_services_controller.process_analytics_services() + response = ( + self.manager.admin_analytics_services_controller.process_analytics_services() + ) [service] = response.get("analytics_services") [library] = service.get("libraries") @@ -69,14 +82,17 @@ def test_analytics_services_get_with_one_service(self): self._db.delete(ga_service) local_service, ignore = create( - self._db, ExternalIntegration, + self._db, + ExternalIntegration, protocol=LocalAnalyticsProvider.__module__, goal=ExternalIntegration.ANALYTICS_GOAL, ) local_service.libraries += [self._default_library] with self.request_context_with_admin("/"): - response = self.manager.admin_analytics_services_controller.process_analytics_services() + response = ( + self.manager.admin_analytics_services_controller.process_analytics_services() + ) [local_analytics] = response.get("analytics_services") assert local_service.id == local_analytics.get("id") @@ -88,167 +104,240 @@ def test_analytics_services_get_with_one_service(self): def test_analytics_services_post_errors(self): with self.request_context_with_admin("/", method="POST"): flask.request.form = MultiDict([]) - response = self.manager.admin_analytics_services_controller.process_analytics_services() + response = ( + self.manager.admin_analytics_services_controller.process_analytics_services() + ) assert response == MISSING_ANALYTICS_NAME with self.request_context_with_admin("/", method="POST"): - flask.request.form = MultiDict([ - ("name", "Name"), - ("protocol", "Unknown"), - ("url", "http://test"), - ]) - response = self.manager.admin_analytics_services_controller.process_analytics_services() + flask.request.form = MultiDict( + [ + ("name", "Name"), + ("protocol", "Unknown"), + ("url", "http://test"), + ] + ) + response = ( + self.manager.admin_analytics_services_controller.process_analytics_services() + ) assert response == UNKNOWN_PROTOCOL with self.request_context_with_admin("/", method="POST"): - flask.request.form = MultiDict([ - ("name", "Name"), - ("url", "http://test"), - ]) - response = self.manager.admin_analytics_services_controller.process_analytics_services() + flask.request.form = MultiDict( + [ + ("name", "Name"), + ("url", "http://test"), + ] + ) + response = ( + self.manager.admin_analytics_services_controller.process_analytics_services() + ) assert response == NO_PROTOCOL_FOR_NEW_SERVICE with self.request_context_with_admin("/", method="POST"): - flask.request.form = MultiDict([ - ("name", "Name"), - ("id", "123"), - ("url", "http://test"), - ]) - response = self.manager.admin_analytics_services_controller.process_analytics_services() + flask.request.form = MultiDict( + [ + ("name", "Name"), + ("id", "123"), + ("url", "http://test"), + ] + ) + response = ( + self.manager.admin_analytics_services_controller.process_analytics_services() + ) assert response.uri == MISSING_SERVICE.uri service, ignore = create( - self._db, ExternalIntegration, + self._db, + ExternalIntegration, protocol=GoogleAnalyticsProvider.__module__, goal=ExternalIntegration.ANALYTICS_GOAL, name="name", ) with self.request_context_with_admin("/", method="POST"): - flask.request.form = MultiDict([ - ("name", service.name), - ("protocol", GoogleAnalyticsProvider.__module__), - ("url", "http://test"), - ]) - response = self.manager.admin_analytics_services_controller.process_analytics_services() + flask.request.form = MultiDict( + [ + ("name", service.name), + ("protocol", GoogleAnalyticsProvider.__module__), + ("url", "http://test"), + ] + ) + response = ( + self.manager.admin_analytics_services_controller.process_analytics_services() + ) assert response == INTEGRATION_NAME_ALREADY_IN_USE service, ignore = create( - self._db, ExternalIntegration, + self._db, + ExternalIntegration, protocol=GoogleAnalyticsProvider.__module__, goal=ExternalIntegration.ANALYTICS_GOAL, ) with self.request_context_with_admin("/", method="POST"): - flask.request.form = MultiDict([ - ("name", "Name"), - ("id", service.id), - ("protocol", "core.local_analytics_provider"), - ("url", "http://test"), - ]) - response = self.manager.admin_analytics_services_controller.process_analytics_services() + flask.request.form = MultiDict( + [ + ("name", "Name"), + ("id", service.id), + ("protocol", "core.local_analytics_provider"), + ("url", "http://test"), + ] + ) + response = ( + self.manager.admin_analytics_services_controller.process_analytics_services() + ) assert response == CANNOT_CHANGE_PROTOCOL with self.request_context_with_admin("/", method="POST"): - flask.request.form = MultiDict([ - ("id", service.id), - ("name", "analytics name"), - ("protocol", GoogleAnalyticsProvider.__module__), - ("url", None), - ]) - response = self.manager.admin_analytics_services_controller.process_analytics_services() + flask.request.form = MultiDict( + [ + ("id", service.id), + ("name", "analytics name"), + ("protocol", GoogleAnalyticsProvider.__module__), + ("url", None), + ] + ) + response = ( + self.manager.admin_analytics_services_controller.process_analytics_services() + ) assert response.uri == INCOMPLETE_CONFIGURATION.uri with self.request_context_with_admin("/", method="POST"): - flask.request.form = MultiDict([ - ("id", service.id), - ("protocol", GoogleAnalyticsProvider.__module__), - ("name", "some other analytics name"), - (ExternalIntegration.URL, "http://test"), - ("libraries", json.dumps([{"short_name": "not-a-library"}])), - ]) - response = self.manager.admin_analytics_services_controller.process_analytics_services() + flask.request.form = MultiDict( + [ + ("id", service.id), + ("protocol", GoogleAnalyticsProvider.__module__), + ("name", "some other analytics name"), + (ExternalIntegration.URL, "http://test"), + ("libraries", json.dumps([{"short_name": "not-a-library"}])), + ] + ) + response = ( + self.manager.admin_analytics_services_controller.process_analytics_services() + ) assert response.uri == NO_SUCH_LIBRARY.uri library, ignore = create( - self._db, Library, name="Library", short_name="L", + self._db, + Library, + name="Library", + short_name="L", ) with self.request_context_with_admin("/", method="POST"): - flask.request.form = MultiDict([ - ("id", service.id), - ("protocol", GoogleAnalyticsProvider.__module__), - ("name", "some other name"), - (ExternalIntegration.URL, ""), - ("libraries", json.dumps([{"short_name": library.short_name}])), - ]) - response = self.manager.admin_analytics_services_controller.process_analytics_services() + flask.request.form = MultiDict( + [ + ("id", service.id), + ("protocol", GoogleAnalyticsProvider.__module__), + ("name", "some other name"), + (ExternalIntegration.URL, ""), + ("libraries", json.dumps([{"short_name": library.short_name}])), + ] + ) + response = ( + self.manager.admin_analytics_services_controller.process_analytics_services() + ) assert response.uri == INCOMPLETE_CONFIGURATION.uri self.admin.remove_role(AdminRole.SYSTEM_ADMIN) self.admin.remove_role(AdminRole.LIBRARY_MANAGER) with self.request_context_with_admin("/", method="POST"): - flask.request.form = MultiDict([ - ("protocol", LocalAnalyticsProvider.__module__), - (ExternalIntegration.URL, "url"), - ("libraries", json.dumps([])), - ]) - pytest.raises(AdminNotAuthorized, - self.manager.admin_analytics_services_controller.process_analytics_services) + flask.request.form = MultiDict( + [ + ("protocol", LocalAnalyticsProvider.__module__), + (ExternalIntegration.URL, "url"), + ("libraries", json.dumps([])), + ] + ) + pytest.raises( + AdminNotAuthorized, + self.manager.admin_analytics_services_controller.process_analytics_services, + ) def test_analytics_services_post_create(self): library, ignore = create( - self._db, Library, name="Library", short_name="L", + self._db, + Library, + name="Library", + short_name="L", ) with self.request_context_with_admin("/", method="POST"): - flask.request.form = MultiDict([ - ("name", "Google analytics name"), - ("protocol", GoogleAnalyticsProvider.__module__), - (ExternalIntegration.URL, "http://test"), - ("libraries", json.dumps([{"short_name": "L", "tracking_id": "trackingid"}])), - ]) - response = self.manager.admin_analytics_services_controller.process_analytics_services() + flask.request.form = MultiDict( + [ + ("name", "Google analytics name"), + ("protocol", GoogleAnalyticsProvider.__module__), + (ExternalIntegration.URL, "http://test"), + ( + "libraries", + json.dumps([{"short_name": "L", "tracking_id": "trackingid"}]), + ), + ] + ) + response = ( + self.manager.admin_analytics_services_controller.process_analytics_services() + ) assert response.status_code == 201 service = get_one( - self._db, ExternalIntegration, + self._db, + ExternalIntegration, goal=ExternalIntegration.ANALYTICS_GOAL, - protocol=GoogleAnalyticsProvider.__module__ + protocol=GoogleAnalyticsProvider.__module__, ) assert service.id == int(response.get_data()) assert GoogleAnalyticsProvider.__module__ == service.protocol assert "http://test" == service.url assert [library] == service.libraries - assert "trackingid" == ConfigurationSetting.for_library_and_externalintegration( - self._db, GoogleAnalyticsProvider.TRACKING_ID, library, service).value + assert ( + "trackingid" + == ConfigurationSetting.for_library_and_externalintegration( + self._db, GoogleAnalyticsProvider.TRACKING_ID, library, service + ).value + ) local_analytics_default = get_one( - self._db, ExternalIntegration, + self._db, + ExternalIntegration, goal=ExternalIntegration.ANALYTICS_GOAL, - protocol=LocalAnalyticsProvider.__module__ + protocol=LocalAnalyticsProvider.__module__, ) self._db.delete(local_analytics_default) # Creating a local analytics service doesn't require a URL. with self.request_context_with_admin("/", method="POST"): - flask.request.form = MultiDict([ - ("name", "local analytics name"), - ("protocol", LocalAnalyticsProvider.__module__), - ("libraries", json.dumps([{"short_name": "L", "tracking_id": "trackingid"}])), - ]) - response = self.manager.admin_analytics_services_controller.process_analytics_services() + flask.request.form = MultiDict( + [ + ("name", "local analytics name"), + ("protocol", LocalAnalyticsProvider.__module__), + ( + "libraries", + json.dumps([{"short_name": "L", "tracking_id": "trackingid"}]), + ), + ] + ) + response = ( + self.manager.admin_analytics_services_controller.process_analytics_services() + ) assert response.status_code == 201 def test_analytics_services_post_edit(self): l1, ignore = create( - self._db, Library, name="Library 1", short_name="L1", + self._db, + Library, + name="Library 1", + short_name="L1", ) l2, ignore = create( - self._db, Library, name="Library 2", short_name="L2", + self._db, + Library, + name="Library 2", + short_name="L2", ) ga_service, ignore = create( - self._db, ExternalIntegration, + self._db, + ExternalIntegration, protocol=GoogleAnalyticsProvider.__module__, goal=ExternalIntegration.ANALYTICS_GOAL, ) @@ -256,52 +345,69 @@ def test_analytics_services_post_edit(self): ga_service.libraries = [l1] with self.request_context_with_admin("/", method="POST"): - flask.request.form = MultiDict([ - ("id", ga_service.id), - ("name", "some other analytics name"), - ("protocol", GoogleAnalyticsProvider.__module__), - (ExternalIntegration.URL, "http://test"), - ("libraries", json.dumps([{"short_name": "L2", "tracking_id": "l2id"}])), - ]) - response = self.manager.admin_analytics_services_controller.process_analytics_services() + flask.request.form = MultiDict( + [ + ("id", ga_service.id), + ("name", "some other analytics name"), + ("protocol", GoogleAnalyticsProvider.__module__), + (ExternalIntegration.URL, "http://test"), + ( + "libraries", + json.dumps([{"short_name": "L2", "tracking_id": "l2id"}]), + ), + ] + ) + response = ( + self.manager.admin_analytics_services_controller.process_analytics_services() + ) assert response.status_code == 200 assert ga_service.id == int(response.get_data()) assert GoogleAnalyticsProvider.__module__ == ga_service.protocol assert "http://test" == ga_service.url assert [l2] == ga_service.libraries - assert "l2id" == ConfigurationSetting.for_library_and_externalintegration( - self._db, GoogleAnalyticsProvider.TRACKING_ID, l2, ga_service).value + assert ( + "l2id" + == ConfigurationSetting.for_library_and_externalintegration( + self._db, GoogleAnalyticsProvider.TRACKING_ID, l2, ga_service + ).value + ) def test_check_name_unique(self): - kwargs = dict(protocol=GoogleAnalyticsProvider.__module__, - goal=ExternalIntegration.ANALYTICS_GOAL) - existing_service, ignore = create(self._db, ExternalIntegration, name="existing service", **kwargs) - new_service, ignore = create(self._db, ExternalIntegration, name="new service", **kwargs) + kwargs = dict( + protocol=GoogleAnalyticsProvider.__module__, + goal=ExternalIntegration.ANALYTICS_GOAL, + ) + existing_service, ignore = create( + self._db, ExternalIntegration, name="existing service", **kwargs + ) + new_service, ignore = create( + self._db, ExternalIntegration, name="new service", **kwargs + ) - m = self.manager.admin_analytics_services_controller.check_name_unique + m = self.manager.admin_analytics_services_controller.check_name_unique - # Try to change new service so that it has the same name as existing service - # -- this is not allowed. - result = m(new_service, existing_service.name) - assert result == INTEGRATION_NAME_ALREADY_IN_USE + # Try to change new service so that it has the same name as existing service + # -- this is not allowed. + result = m(new_service, existing_service.name) + assert result == INTEGRATION_NAME_ALREADY_IN_USE - # Try to edit existing service without changing its name -- this is fine. - assert ( - None == - m(existing_service, existing_service.name)) + # Try to edit existing service without changing its name -- this is fine. + assert None == m(existing_service, existing_service.name) - # Changing the existing service's name is also fine. - assert ( - None == - m(existing_service, "new name")) + # Changing the existing service's name is also fine. + assert None == m(existing_service, "new name") def test_analytics_service_delete(self): l1, ignore = create( - self._db, Library, name="Library 1", short_name="L1", + self._db, + Library, + name="Library 1", + short_name="L1", ) ga_service, ignore = create( - self._db, ExternalIntegration, + self._db, + ExternalIntegration, protocol=GoogleAnalyticsProvider.__module__, goal=ExternalIntegration.ANALYTICS_GOAL, ) @@ -310,12 +416,16 @@ def test_analytics_service_delete(self): with self.request_context_with_admin("/", method="DELETE"): self.admin.remove_role(AdminRole.SYSTEM_ADMIN) - pytest.raises(AdminNotAuthorized, - self.manager.admin_analytics_services_controller.process_delete, - ga_service.id) + pytest.raises( + AdminNotAuthorized, + self.manager.admin_analytics_services_controller.process_delete, + ga_service.id, + ) self.admin.add_role(AdminRole.SYSTEM_ADMIN) - response = self.manager.admin_analytics_services_controller.process_delete(ga_service.id) + response = self.manager.admin_analytics_services_controller.process_delete( + ga_service.id + ) assert response.status_code == 200 service = get_one(self._db, ExternalIntegration, id=ga_service.id) diff --git a/tests/admin/controller/test_catalog_services.py b/tests/admin/controller/test_catalog_services.py index 7b0b6cce17..e1186175ea 100644 --- a/tests/admin/controller/test_catalog_services.py +++ b/tests/admin/controller/test_catalog_services.py @@ -15,14 +15,16 @@ ) from core.model.configuration import ExternalIntegrationLink from core.s3 import S3Uploader, S3UploaderConfiguration + from .test_controller import SettingsControllerTest class TestCatalogServicesController(SettingsControllerTest): - def test_catalog_services_get_with_no_services(self): with self.request_context_with_admin("/"): - response = self.manager.admin_catalog_services_controller.process_catalog_services() + response = ( + self.manager.admin_catalog_services_controller.process_catalog_services() + ) assert response.get("catalog_services") == [] protocols = response.get("protocols") assert 1 == len(protocols) @@ -32,29 +34,43 @@ def test_catalog_services_get_with_no_services(self): self.admin.remove_role(AdminRole.SYSTEM_ADMIN) self._db.flush() - pytest.raises(AdminNotAuthorized, - self.manager.admin_catalog_services_controller.process_catalog_services) + pytest.raises( + AdminNotAuthorized, + self.manager.admin_catalog_services_controller.process_catalog_services, + ) def test_catalog_services_get_with_marc_exporter(self): integration, ignore = create( - self._db, ExternalIntegration, + self._db, + ExternalIntegration, protocol=ExternalIntegration.MARC_EXPORT, goal=ExternalIntegration.CATALOG_GOAL, name="name", ) integration.libraries += [self._default_library] ConfigurationSetting.for_library_and_externalintegration( - self._db, MARCExporter.MARC_ORGANIZATION_CODE, - self._default_library, integration).value = "US-MaBoDPL" + self._db, + MARCExporter.MARC_ORGANIZATION_CODE, + self._default_library, + integration, + ).value = "US-MaBoDPL" ConfigurationSetting.for_library_and_externalintegration( - self._db, MARCExporter.INCLUDE_SUMMARY, - self._default_library, integration).value = "false" + self._db, + MARCExporter.INCLUDE_SUMMARY, + self._default_library, + integration, + ).value = "false" ConfigurationSetting.for_library_and_externalintegration( - self._db, MARCExporter.INCLUDE_SIMPLIFIED_GENRES, - self._default_library, integration).value = "true" + self._db, + MARCExporter.INCLUDE_SIMPLIFIED_GENRES, + self._default_library, + integration, + ).value = "true" with self.request_context_with_admin("/"): - response = self.manager.admin_catalog_services_controller.process_catalog_services() + response = ( + self.manager.admin_catalog_services_controller.process_catalog_services() + ) [service] = response.get("catalog_services") assert integration.id == service.get("id") assert integration.name == service.get("name") @@ -65,48 +81,64 @@ def test_catalog_services_get_with_marc_exporter(self): assert "false" == library.get(MARCExporter.INCLUDE_SUMMARY) assert "true" == library.get(MARCExporter.INCLUDE_SIMPLIFIED_GENRES) - def test_catalog_services_post_errors(self): with self.request_context_with_admin("/", method="POST"): - flask.request.form = MultiDict([ - ("protocol", "Unknown"), - ]) - response = self.manager.admin_catalog_services_controller.process_catalog_services() + flask.request.form = MultiDict( + [ + ("protocol", "Unknown"), + ] + ) + response = ( + self.manager.admin_catalog_services_controller.process_catalog_services() + ) assert response == UNKNOWN_PROTOCOL with self.request_context_with_admin("/", method="POST"): - flask.request.form = MultiDict([ - ("id", "123"), - ]) - response = self.manager.admin_catalog_services_controller.process_catalog_services() + flask.request.form = MultiDict( + [ + ("id", "123"), + ] + ) + response = ( + self.manager.admin_catalog_services_controller.process_catalog_services() + ) assert response == MISSING_SERVICE service, ignore = create( - self._db, ExternalIntegration, + self._db, + ExternalIntegration, protocol="fake protocol", goal=ExternalIntegration.CATALOG_GOAL, name="name", ) with self.request_context_with_admin("/", method="POST"): - flask.request.form = MultiDict([ - ("id", service.id), - ("protocol", ExternalIntegration.MARC_EXPORT), - ]) - response = self.manager.admin_catalog_services_controller.process_catalog_services() + flask.request.form = MultiDict( + [ + ("id", service.id), + ("protocol", ExternalIntegration.MARC_EXPORT), + ] + ) + response = ( + self.manager.admin_catalog_services_controller.process_catalog_services() + ) assert response == CANNOT_CHANGE_PROTOCOL with self.request_context_with_admin("/", method="POST"): - flask.request.form = MultiDict([ - ("name", service.name), - ("protocol", ExternalIntegration.MARC_EXPORT), - ]) - response = self.manager.admin_catalog_services_controller.process_catalog_services() + flask.request.form = MultiDict( + [ + ("name", service.name), + ("protocol", ExternalIntegration.MARC_EXPORT), + ] + ) + response = ( + self.manager.admin_catalog_services_controller.process_catalog_services() + ) assert response == INTEGRATION_NAME_ALREADY_IN_USE - service, ignore = create( - self._db, ExternalIntegration, + self._db, + ExternalIntegration, protocol=ExternalIntegration.MARC_EXPORT, goal=ExternalIntegration.CATALOG_GOAL, ) @@ -114,17 +146,22 @@ def test_catalog_services_post_errors(self): # Attempt to set an S3 mirror external integration but it does not exist! with self.request_context_with_admin("/", method="POST"): ME = MARCExporter - flask.request.form = MultiDict([ - ("name", "exporter name"), - ("id", service.id), - ("protocol", ME.NAME), - ("mirror_integration_id", "1234") - ]) - response = self.manager.admin_catalog_services_controller.process_catalog_services() + flask.request.form = MultiDict( + [ + ("name", "exporter name"), + ("id", service.id), + ("protocol", ME.NAME), + ("mirror_integration_id", "1234"), + ] + ) + response = ( + self.manager.admin_catalog_services_controller.process_catalog_services() + ) assert response.uri == MISSING_INTEGRATION.uri s3, ignore = create( - self._db, ExternalIntegration, + self._db, + ExternalIntegration, protocol=ExternalIntegration.S3, goal=ExternalIntegration.STORAGE_GOAL, ) @@ -132,25 +169,33 @@ def test_catalog_services_post_errors(self): # Now an S3 integration exists, but it has no MARC bucket configured. with self.request_context_with_admin("/", method="POST"): ME = MARCExporter - flask.request.form = MultiDict([ - ("name", "exporter name"), - ("id", service.id), - ("protocol", ME.NAME), - ("mirror_integration_id", s3.id) - ]) - response = self.manager.admin_catalog_services_controller.process_catalog_services() + flask.request.form = MultiDict( + [ + ("name", "exporter name"), + ("id", service.id), + ("protocol", ME.NAME), + ("mirror_integration_id", s3.id), + ] + ) + response = ( + self.manager.admin_catalog_services_controller.process_catalog_services() + ) assert response.uri == MISSING_INTEGRATION.uri self.admin.remove_role(AdminRole.SYSTEM_ADMIN) self._db.flush() with self.request_context_with_admin("/", method="POST"): - flask.request.form = MultiDict([ - ("name", "new name"), - ("protocol", ME.NAME), - ("mirror_integration_id", s3.id), - ]) - pytest.raises(AdminNotAuthorized, - self.manager.admin_catalog_services_controller.process_catalog_services) + flask.request.form = MultiDict( + [ + ("name", "new name"), + ("protocol", ME.NAME), + ("mirror_integration_id", s3.id), + ] + ) + pytest.raises( + AdminNotAuthorized, + self.manager.admin_catalog_services_controller.process_catalog_services, + ) # This should be the last test to check since rolling back database # changes in the test can cause it to crash. @@ -160,49 +205,77 @@ def test_catalog_services_post_errors(self): with self.request_context_with_admin("/", method="POST"): ME = MARCExporter - flask.request.form = MultiDict([ - ("name", "new name"), - ("protocol", ME.NAME), - ("mirror_integration_id", s3.id), - ("libraries", json.dumps([{ - "short_name": self._default_library.short_name, - ME.INCLUDE_SUMMARY: "false", - ME.INCLUDE_SIMPLIFIED_GENRES: "true", - }])), - ]) - response = self.manager.admin_catalog_services_controller.process_catalog_services() + flask.request.form = MultiDict( + [ + ("name", "new name"), + ("protocol", ME.NAME), + ("mirror_integration_id", s3.id), + ( + "libraries", + json.dumps( + [ + { + "short_name": self._default_library.short_name, + ME.INCLUDE_SUMMARY: "false", + ME.INCLUDE_SIMPLIFIED_GENRES: "true", + } + ] + ), + ), + ] + ) + response = ( + self.manager.admin_catalog_services_controller.process_catalog_services() + ) assert response.uri == MULTIPLE_SERVICES_FOR_LIBRARY.uri def test_catalog_services_post_create(self): ME = MARCExporter s3, ignore = create( - self._db, ExternalIntegration, + self._db, + ExternalIntegration, protocol=ExternalIntegration.S3, goal=ExternalIntegration.STORAGE_GOAL, ) s3.setting(S3UploaderConfiguration.MARC_BUCKET_KEY).value = "marc-files" with self.request_context_with_admin("/", method="POST"): - flask.request.form = MultiDict([ - ("name", "exporter name"), - ("protocol", ME.NAME), - ("mirror_integration_id", s3.id), - ("libraries", json.dumps([{ - "short_name": self._default_library.short_name, - ME.INCLUDE_SUMMARY: "false", - ME.INCLUDE_SIMPLIFIED_GENRES: "true", - }])), - ]) - response = self.manager.admin_catalog_services_controller.process_catalog_services() + flask.request.form = MultiDict( + [ + ("name", "exporter name"), + ("protocol", ME.NAME), + ("mirror_integration_id", s3.id), + ( + "libraries", + json.dumps( + [ + { + "short_name": self._default_library.short_name, + ME.INCLUDE_SUMMARY: "false", + ME.INCLUDE_SIMPLIFIED_GENRES: "true", + } + ] + ), + ), + ] + ) + response = ( + self.manager.admin_catalog_services_controller.process_catalog_services() + ) assert response.status_code == 201 - service = get_one(self._db, ExternalIntegration, goal=ExternalIntegration.CATALOG_GOAL) + service = get_one( + self._db, ExternalIntegration, goal=ExternalIntegration.CATALOG_GOAL + ) # There was one S3 integration and it was selected. The service has an # External Integration Link to the storage integration that is created # in a POST with purpose of ExternalIntegrationLink.MARC. integration_link = get_one( - self._db, ExternalIntegrationLink, external_integration_id=service.id, purpose=ExternalIntegrationLink.MARC + self._db, + ExternalIntegrationLink, + external_integration_id=service.id, + purpose=ExternalIntegrationLink.MARC, ) assert service.id == int(response.get_data()) @@ -212,71 +285,110 @@ def test_catalog_services_post_create(self): # We expect the Catalog external integration to have a link to the # S3 storage external integration assert s3.id == integration_link.other_integration_id - assert "false" == ConfigurationSetting.for_library_and_externalintegration( - self._db, ME.INCLUDE_SUMMARY, self._default_library, service).value - assert "true" == ConfigurationSetting.for_library_and_externalintegration( - self._db, ME.INCLUDE_SIMPLIFIED_GENRES, self._default_library, service).value + assert ( + "false" + == ConfigurationSetting.for_library_and_externalintegration( + self._db, ME.INCLUDE_SUMMARY, self._default_library, service + ).value + ) + assert ( + "true" + == ConfigurationSetting.for_library_and_externalintegration( + self._db, ME.INCLUDE_SIMPLIFIED_GENRES, self._default_library, service + ).value + ) def test_catalog_services_post_edit(self): ME = MARCExporter s3, ignore = create( - self._db, ExternalIntegration, + self._db, + ExternalIntegration, protocol=ExternalIntegration.S3, goal=ExternalIntegration.STORAGE_GOAL, ) s3.setting(S3UploaderConfiguration.MARC_BUCKET_KEY).value = "marc-files" service, ignore = create( - self._db, ExternalIntegration, - protocol=ME.NAME, goal=ExternalIntegration.CATALOG_GOAL, - name="name" + self._db, + ExternalIntegration, + protocol=ME.NAME, + goal=ExternalIntegration.CATALOG_GOAL, + name="name", ) with self.request_context_with_admin("/", method="POST"): - flask.request.form = MultiDict([ - ("name", "exporter name"), - ("id", service.id), - ("protocol", ME.NAME), - ("mirror_integration_id", s3.id), - ("libraries", json.dumps([{ - "short_name": self._default_library.short_name, - ME.INCLUDE_SUMMARY: "false", - ME.INCLUDE_SIMPLIFIED_GENRES: "true", - }])), - ]) - response = self.manager.admin_catalog_services_controller.process_catalog_services() + flask.request.form = MultiDict( + [ + ("name", "exporter name"), + ("id", service.id), + ("protocol", ME.NAME), + ("mirror_integration_id", s3.id), + ( + "libraries", + json.dumps( + [ + { + "short_name": self._default_library.short_name, + ME.INCLUDE_SUMMARY: "false", + ME.INCLUDE_SIMPLIFIED_GENRES: "true", + } + ] + ), + ), + ] + ) + response = ( + self.manager.admin_catalog_services_controller.process_catalog_services() + ) assert response.status_code == 200 integration_link = get_one( - self._db, ExternalIntegrationLink, external_integration_id=service.id, purpose=ExternalIntegrationLink.MARC + self._db, + ExternalIntegrationLink, + external_integration_id=service.id, + purpose=ExternalIntegrationLink.MARC, ) assert service.id == int(response.get_data()) assert ME.NAME == service.protocol assert "exporter name" == service.name assert s3.id == integration_link.other_integration_id assert [self._default_library] == service.libraries - assert "false" == ConfigurationSetting.for_library_and_externalintegration( - self._db, ME.INCLUDE_SUMMARY, self._default_library, service).value - assert "true" == ConfigurationSetting.for_library_and_externalintegration( - self._db, ME.INCLUDE_SIMPLIFIED_GENRES, self._default_library, service).value + assert ( + "false" + == ConfigurationSetting.for_library_and_externalintegration( + self._db, ME.INCLUDE_SUMMARY, self._default_library, service + ).value + ) + assert ( + "true" + == ConfigurationSetting.for_library_and_externalintegration( + self._db, ME.INCLUDE_SIMPLIFIED_GENRES, self._default_library, service + ).value + ) def test_catalog_services_delete(self): ME = MARCExporter service, ignore = create( - self._db, ExternalIntegration, - protocol=ME.NAME, goal=ExternalIntegration.CATALOG_GOAL, - name="name" + self._db, + ExternalIntegration, + protocol=ME.NAME, + goal=ExternalIntegration.CATALOG_GOAL, + name="name", ) with self.request_context_with_admin("/", method="DELETE"): self.admin.remove_role(AdminRole.SYSTEM_ADMIN) - pytest.raises(AdminNotAuthorized, - self.manager.admin_catalog_services_controller.process_delete, - service.id) + pytest.raises( + AdminNotAuthorized, + self.manager.admin_catalog_services_controller.process_delete, + service.id, + ) self.admin.add_role(AdminRole.SYSTEM_ADMIN) - response = self.manager.admin_catalog_services_controller.process_delete(service.id) + response = self.manager.admin_catalog_services_controller.process_delete( + service.id + ) assert response.status_code == 200 service = get_one(self._db, ExternalIntegration, id=service.id) diff --git a/tests/admin/controller/test_cdn_services.py b/tests/admin/controller/test_cdn_services.py index 59db4f795f..0e70484fd8 100644 --- a/tests/admin/controller/test_cdn_services.py +++ b/tests/admin/controller/test_cdn_services.py @@ -1,17 +1,13 @@ -import pytest - import flask +import pytest from werkzeug.datastructures import MultiDict + from api.admin.exceptions import * -from core.model import ( - AdminRole, - Configuration, - create, - ExternalIntegration, - get_one, -) +from core.model import AdminRole, Configuration, ExternalIntegration, create, get_one + from .test_controller import SettingsControllerTest + class TestCDNServices(SettingsControllerTest): def test_cdn_services_get_with_no_services(self): with self.request_context_with_admin("/"): @@ -23,17 +19,22 @@ def test_cdn_services_get_with_no_services(self): self.admin.remove_role(AdminRole.SYSTEM_ADMIN) self._db.flush() - pytest.raises(AdminNotAuthorized, - self.manager.admin_cdn_services_controller.process_cdn_services) + pytest.raises( + AdminNotAuthorized, + self.manager.admin_cdn_services_controller.process_cdn_services, + ) def test_cdn_services_get_with_one_service(self): cdn_service, ignore = create( - self._db, ExternalIntegration, + self._db, + ExternalIntegration, protocol=ExternalIntegration.CDN, goal=ExternalIntegration.CDN_GOAL, ) cdn_service.url = "cdn url" - cdn_service.setting(Configuration.CDN_MIRRORED_DOMAIN_KEY).value = "mirrored domain" + cdn_service.setting( + Configuration.CDN_MIRRORED_DOMAIN_KEY + ).value = "mirrored domain" with self.request_context_with_admin("/"): response = self.manager.admin_cdn_services_controller.process_cdn_services() @@ -42,7 +43,9 @@ def test_cdn_services_get_with_one_service(self): assert cdn_service.id == service.get("id") assert cdn_service.protocol == service.get("protocol") assert "cdn url" == service.get("settings").get(ExternalIntegration.URL) - assert "mirrored domain" == service.get("settings").get(Configuration.CDN_MIRRORED_DOMAIN_KEY) + assert "mirrored domain" == service.get("settings").get( + Configuration.CDN_MIRRORED_DOMAIN_KEY + ) def test_cdn_services_post_errors(self): with self.request_context_with_admin("/", method="POST"): @@ -51,146 +54,184 @@ def test_cdn_services_post_errors(self): assert response == INCOMPLETE_CONFIGURATION with self.request_context_with_admin("/", method="POST"): - flask.request.form = MultiDict([ - ("name", "Name"), - ("protocol", "Unknown"), - ]) + flask.request.form = MultiDict( + [ + ("name", "Name"), + ("protocol", "Unknown"), + ] + ) response = self.manager.admin_cdn_services_controller.process_cdn_services() assert response == UNKNOWN_PROTOCOL with self.request_context_with_admin("/", method="POST"): - flask.request.form = MultiDict([ - ("name", "Name"), - ]) + flask.request.form = MultiDict( + [ + ("name", "Name"), + ] + ) response = self.manager.admin_cdn_services_controller.process_cdn_services() assert response == NO_PROTOCOL_FOR_NEW_SERVICE with self.request_context_with_admin("/", method="POST"): - flask.request.form = MultiDict([ - ("name", "Name"), - ("id", "123"), - ]) + flask.request.form = MultiDict( + [ + ("name", "Name"), + ("id", "123"), + ] + ) response = self.manager.admin_cdn_services_controller.process_cdn_services() assert response == MISSING_SERVICE service, ignore = create( - self._db, ExternalIntegration, + self._db, + ExternalIntegration, protocol=ExternalIntegration.CDN, goal=ExternalIntegration.CDN_GOAL, name="name", ) with self.request_context_with_admin("/", method="POST"): - flask.request.form = MultiDict([ - ("name", service.name), - ("protocol", ExternalIntegration.CDN), - ]) + flask.request.form = MultiDict( + [ + ("name", service.name), + ("protocol", ExternalIntegration.CDN), + ] + ) response = self.manager.admin_cdn_services_controller.process_cdn_services() assert response == INTEGRATION_NAME_ALREADY_IN_USE with self.request_context_with_admin("/", method="POST"): - flask.request.form = MultiDict([ - ("name", "Name"), - ("id", service.id), - ("protocol", ExternalIntegration.CDN), - ]) + flask.request.form = MultiDict( + [ + ("name", "Name"), + ("id", service.id), + ("protocol", ExternalIntegration.CDN), + ] + ) response = self.manager.admin_cdn_services_controller.process_cdn_services() assert response.uri == INCOMPLETE_CONFIGURATION.uri self.admin.remove_role(AdminRole.SYSTEM_ADMIN) with self.request_context_with_admin("/", method="POST"): - flask.request.form = MultiDict([ - ("name", "Name"), - ("protocol", ExternalIntegration.CDN), - (ExternalIntegration.URL, "cdn url"), - (Configuration.CDN_MIRRORED_DOMAIN_KEY, "mirrored domain"), - ]) - pytest.raises(AdminNotAuthorized, - self.manager.admin_cdn_services_controller.process_cdn_services) + flask.request.form = MultiDict( + [ + ("name", "Name"), + ("protocol", ExternalIntegration.CDN), + (ExternalIntegration.URL, "cdn url"), + (Configuration.CDN_MIRRORED_DOMAIN_KEY, "mirrored domain"), + ] + ) + pytest.raises( + AdminNotAuthorized, + self.manager.admin_cdn_services_controller.process_cdn_services, + ) def test_cdn_services_post_create(self): with self.request_context_with_admin("/", method="POST"): - flask.request.form = MultiDict([ - ("name", "Name"), - ("protocol", ExternalIntegration.CDN), - (ExternalIntegration.URL, "http://cdn_url"), - (Configuration.CDN_MIRRORED_DOMAIN_KEY, "mirrored domain"), - ]) + flask.request.form = MultiDict( + [ + ("name", "Name"), + ("protocol", ExternalIntegration.CDN), + (ExternalIntegration.URL, "http://cdn_url"), + (Configuration.CDN_MIRRORED_DOMAIN_KEY, "mirrored domain"), + ] + ) response = self.manager.admin_cdn_services_controller.process_cdn_services() assert response.status_code == 201 - service = get_one(self._db, ExternalIntegration, goal=ExternalIntegration.CDN_GOAL) + service = get_one( + self._db, ExternalIntegration, goal=ExternalIntegration.CDN_GOAL + ) assert service.id == int(response.response[0]) assert ExternalIntegration.CDN == service.protocol assert "http://cdn_url" == service.url - assert "mirrored domain" == service.setting(Configuration.CDN_MIRRORED_DOMAIN_KEY).value + assert ( + "mirrored domain" + == service.setting(Configuration.CDN_MIRRORED_DOMAIN_KEY).value + ) def test_cdn_services_post_edit(self): cdn_service, ignore = create( - self._db, ExternalIntegration, + self._db, + ExternalIntegration, protocol=ExternalIntegration.CDN, goal=ExternalIntegration.CDN_GOAL, ) cdn_service.url = "cdn url" - cdn_service.setting(Configuration.CDN_MIRRORED_DOMAIN_KEY).value = "mirrored domain" + cdn_service.setting( + Configuration.CDN_MIRRORED_DOMAIN_KEY + ).value = "mirrored domain" with self.request_context_with_admin("/", method="POST"): - flask.request.form = MultiDict([ - ("name", "Name"), - ("id", cdn_service.id), - ("protocol", ExternalIntegration.CDN), - (ExternalIntegration.URL, "http://new_cdn_url"), - (Configuration.CDN_MIRRORED_DOMAIN_KEY, "new mirrored domain") - ]) + flask.request.form = MultiDict( + [ + ("name", "Name"), + ("id", cdn_service.id), + ("protocol", ExternalIntegration.CDN), + (ExternalIntegration.URL, "http://new_cdn_url"), + (Configuration.CDN_MIRRORED_DOMAIN_KEY, "new mirrored domain"), + ] + ) response = self.manager.admin_cdn_services_controller.process_cdn_services() assert response.status_code == 200 assert cdn_service.id == int(response.response[0]) assert ExternalIntegration.CDN == cdn_service.protocol assert "http://new_cdn_url" == cdn_service.url - assert "new mirrored domain" == cdn_service.setting(Configuration.CDN_MIRRORED_DOMAIN_KEY).value + assert ( + "new mirrored domain" + == cdn_service.setting(Configuration.CDN_MIRRORED_DOMAIN_KEY).value + ) def test_check_name_unique(self): - kwargs = dict(protocol=ExternalIntegration.CDN, - goal=ExternalIntegration.CDN_GOAL) + kwargs = dict( + protocol=ExternalIntegration.CDN, goal=ExternalIntegration.CDN_GOAL + ) - existing_service, ignore = create(self._db, ExternalIntegration, name="existing service", **kwargs) - new_service, ignore = create(self._db, ExternalIntegration, name="new service", **kwargs) + existing_service, ignore = create( + self._db, ExternalIntegration, name="existing service", **kwargs + ) + new_service, ignore = create( + self._db, ExternalIntegration, name="new service", **kwargs + ) - m = self.manager.admin_cdn_services_controller.check_name_unique + m = self.manager.admin_cdn_services_controller.check_name_unique - # Try to change new service so that it has the same name as existing service - # -- this is not allowed. - result = m(new_service, existing_service.name) - assert result == INTEGRATION_NAME_ALREADY_IN_USE + # Try to change new service so that it has the same name as existing service + # -- this is not allowed. + result = m(new_service, existing_service.name) + assert result == INTEGRATION_NAME_ALREADY_IN_USE - # Try to edit existing service without changing its name -- this is fine. - assert ( - None == - m(existing_service, existing_service.name)) + # Try to edit existing service without changing its name -- this is fine. + assert None == m(existing_service, existing_service.name) - # Changing the existing service's name is also fine. - assert ( - None == - m(existing_service, "new name")) + # Changing the existing service's name is also fine. + assert None == m(existing_service, "new name") def test_cdn_service_delete(self): cdn_service, ignore = create( - self._db, ExternalIntegration, + self._db, + ExternalIntegration, protocol=ExternalIntegration.CDN, goal=ExternalIntegration.CDN_GOAL, ) cdn_service.url = "cdn url" - cdn_service.setting(Configuration.CDN_MIRRORED_DOMAIN_KEY).value = "mirrored domain" + cdn_service.setting( + Configuration.CDN_MIRRORED_DOMAIN_KEY + ).value = "mirrored domain" with self.request_context_with_admin("/", method="DELETE"): self.admin.remove_role(AdminRole.SYSTEM_ADMIN) - pytest.raises(AdminNotAuthorized, - self.manager.admin_cdn_services_controller.process_delete, - cdn_service.id) + pytest.raises( + AdminNotAuthorized, + self.manager.admin_cdn_services_controller.process_delete, + cdn_service.id, + ) self.admin.add_role(AdminRole.SYSTEM_ADMIN) - response = self.manager.admin_cdn_services_controller.process_delete(cdn_service.id) + response = self.manager.admin_cdn_services_controller.process_delete( + cdn_service.id + ) assert response.status_code == 200 service = get_one(self._db, ExternalIntegration, id=cdn_service.id) diff --git a/tests/admin/controller/test_collection_registrations.py b/tests/admin/controller/test_collection_registrations.py index 18f985ca7d..6f2711c173 100644 --- a/tests/admin/controller/test_collection_registrations.py +++ b/tests/admin/controller/test_collection_registrations.py @@ -1,22 +1,16 @@ -import pytest - import flask +import pytest from werkzeug.datastructures import MultiDict + from api.admin.exceptions import * from api.odl import SharedODLAPI -from api.registry import ( - Registration, - RemoteRegistry, -) -from core.model import ( - AdminRole, - ConfigurationSetting, - create, - Library, -) +from api.registry import Registration, RemoteRegistry +from core.model import AdminRole, ConfigurationSetting, Library, create from core.util.http import HTTP + from .test_controller import SettingsControllerTest + class TestCollectionRegistration(SettingsControllerTest): """Test the process of registering a specific collection with a RemoteRegistry. @@ -25,24 +19,41 @@ class TestCollectionRegistration(SettingsControllerTest): def test_collection_library_registrations_get(self): collection = self._default_collection succeeded, ignore = create( - self._db, Library, name="Library 1", short_name="L1", + self._db, + Library, + name="Library 1", + short_name="L1", ) ConfigurationSetting.for_library_and_externalintegration( - self._db, "library-registration-status", succeeded, collection.external_integration, - ).value = "success" + self._db, + "library-registration-status", + succeeded, + collection.external_integration, + ).value = "success" failed, ignore = create( - self._db, Library, name="Library 2", short_name="L2", + self._db, + Library, + name="Library 2", + short_name="L2", ) ConfigurationSetting.for_library_and_externalintegration( - self._db, "library-registration-status", failed, collection.external_integration, - ).value = "failure" + self._db, + "library-registration-status", + failed, + collection.external_integration, + ).value = "failure" unregistered, ignore = create( - self._db, Library, name="Library 3", short_name="L3", + self._db, + Library, + name="Library 3", + short_name="L3", ) collection.libraries = [succeeded, failed, unregistered] with self.request_context_with_admin("/", method="GET"): - response = self.manager.admin_collection_library_registrations_controller.process_collection_library_registrations() + response = ( + self.manager.admin_collection_library_registrations_controller.process_collection_library_registrations() + ) serviceInfo = response.get("library_registrations") assert 1 == len(serviceInfo) @@ -57,14 +68,18 @@ def test_collection_library_registrations_get(self): self.admin.remove_role(AdminRole.SYSTEM_ADMIN) self._db.flush() - pytest.raises(AdminNotAuthorized, - self.manager.admin_collection_library_registrations_controller.process_collection_library_registrations) + pytest.raises( + AdminNotAuthorized, + self.manager.admin_collection_library_registrations_controller.process_collection_library_registrations, + ) def test_collection_library_registrations_post(self): """Test what might happen POSTing to collection_library_registrations.""" # First test the failure cases. - m = self.manager.admin_collection_library_registrations_controller.process_collection_library_registrations + m = ( + self.manager.admin_collection_library_registrations_controller.process_collection_library_registrations + ) # Here, the user doesn't have permission to start the # registration process. @@ -84,10 +99,12 @@ def test_collection_library_registrations_post(self): collection.external_account_id = "collection url" # Oops, the collection doesn't actually support registration. - form = MultiDict([ - ("collection_id", collection.id), - ("library_short_name", "not-a-library"), - ]) + form = MultiDict( + [ + ("collection_id", collection.id), + ("library_short_name", "not-a-library"), + ] + ) with self.request_context_with_admin("/", method="POST"): flask.request.form = form response = m() @@ -106,10 +123,12 @@ def test_collection_library_registrations_post(self): # The push() implementation might return a ProblemDetail for any # number of reasons. library = self._default_library - form = MultiDict([ - ("collection_id", collection.id), - ("library_short_name", library.short_name), - ]) + form = MultiDict( + [ + ("collection_id", collection.id), + ("library_short_name", library.short_name), + ] + ) class Mock(Registration): def push(self, *args, **kwargs): @@ -124,7 +143,9 @@ class Mock(Registration): """When asked to push a registration, do nothing and say it worked. """ + called_with = None + def push(self, *args, **kwargs): Mock.called_with = (args, kwargs) return True @@ -139,9 +160,9 @@ def push(self, *args, **kwargs): assert (Registration.PRODUCTION_STAGE, self.manager.url_for) == args # We would have made real HTTP requests. - assert HTTP.debuggable_post == kwargs.pop('do_post') - assert HTTP.debuggable_get == kwargs.pop('do_get') - # And passed the collection URL over to the shared collection. - assert collection.external_account_id == kwargs.pop('catalog_url') - # No other weird keyword arguments were passed in. + assert HTTP.debuggable_post == kwargs.pop("do_post") + assert HTTP.debuggable_get == kwargs.pop("do_get") + # And passed the collection URL over to the shared collection. + assert collection.external_account_id == kwargs.pop("catalog_url") + # No other weird keyword arguments were passed in. assert {} == kwargs diff --git a/tests/admin/controller/test_collection_self_tests.py b/tests/admin/controller/test_collection_self_tests.py index 1af3820f3b..ec3cebae15 100644 --- a/tests/admin/controller/test_collection_self_tests.py +++ b/tests/admin/controller/test_collection_self_tests.py @@ -1,22 +1,28 @@ - from flask_babel import lazy_gettext as _ + from api.admin.problem_details import * -from api.axis import (Axis360API, MockAxis360API) -from core.opds_import import (OPDSImporter, OPDSImportMonitor) +from api.axis import Axis360API, MockAxis360API +from core.opds_import import OPDSImporter, OPDSImportMonitor from core.selftest import HasSelfTests + from .test_controller import SettingsControllerTest + class TestCollectionSelfTests(SettingsControllerTest): def test_collection_self_tests_with_no_identifier(self): with self.request_context_with_admin("/"): - response = self.manager.admin_collection_self_tests_controller.process_collection_self_tests(None) + response = self.manager.admin_collection_self_tests_controller.process_collection_self_tests( + None + ) assert response.title == MISSING_IDENTIFIER.title assert response.detail == MISSING_IDENTIFIER.detail assert response.status_code == 400 def test_collection_self_tests_with_no_collection_found(self): with self.request_context_with_admin("/"): - response = self.manager.admin_collection_self_tests_controller.process_collection_self_tests(-1) + response = self.manager.admin_collection_self_tests_controller.process_collection_self_tests( + -1 + ) assert response == NO_SUCH_COLLECTION assert response.status_code == 404 @@ -28,7 +34,9 @@ def test_collection_self_tests_test_get(self): # Make sure that HasSelfTest.prior_test_results() was called and that # it is in the response's collection object. with self.request_context_with_admin("/"): - response = self.manager.admin_collection_self_tests_controller.process_collection_self_tests(collection.id) + response = self.manager.admin_collection_self_tests_controller.process_collection_self_tests( + collection.id + ) responseCollection = response.get("self_test_results") @@ -48,9 +56,14 @@ def test_collection_self_tests_failed_post(self): # Failed to run self tests with self.request_context_with_admin("/", method="POST"): - response = self.manager.admin_collection_self_tests_controller.process_collection_self_tests(collection.id) - - (run_self_tests_args, run_self_tests_kwargs) = self.failed_run_self_tests_called_with + response = self.manager.admin_collection_self_tests_controller.process_collection_self_tests( + collection.id + ) + + ( + run_self_tests_args, + run_self_tests_kwargs, + ) = self.failed_run_self_tests_called_with assert response.title == FAILED_TO_RUN_SELF_TESTS.title assert response.detail == "Failed to run self tests for this collection." assert response.status_code == 400 @@ -64,9 +77,14 @@ def test_collection_self_tests_post(self): collection = self._collection() # Successfully ran new self tests for the OPDSImportMonitor provider API with self.request_context_with_admin("/", method="POST"): - response = self.manager.admin_collection_self_tests_controller.process_collection_self_tests(collection.id) - - (run_self_tests_args, run_self_tests_kwargs) = self.run_self_tests_called_with + response = self.manager.admin_collection_self_tests_controller.process_collection_self_tests( + collection.id + ) + + ( + run_self_tests_args, + run_self_tests_kwargs, + ) = self.run_self_tests_called_with assert response.response == _("Successfully ran new self tests") assert response._status == "200 OK" @@ -75,13 +93,17 @@ def test_collection_self_tests_post(self): assert run_self_tests_args[1] == OPDSImportMonitor assert run_self_tests_args[3] == collection - collection = MockAxis360API.mock_collection(self._db) # Successfully ran new self tests with self.request_context_with_admin("/", method="POST"): - response = self.manager.admin_collection_self_tests_controller.process_collection_self_tests(collection.id) - - (run_self_tests_args, run_self_tests_kwargs) = self.run_self_tests_called_with + response = self.manager.admin_collection_self_tests_controller.process_collection_self_tests( + collection.id + ) + + ( + run_self_tests_args, + run_self_tests_kwargs, + ) = self.run_self_tests_called_with assert response.response == _("Successfully ran new self tests") assert response._status == "200 OK" @@ -97,9 +119,14 @@ def test_collection_self_tests_post(self): # No protocol found so run_self_tests was not called with self.request_context_with_admin("/", method="POST"): - response = self.manager.admin_collection_self_tests_controller.process_collection_self_tests(collection.id) - - (run_self_tests_args, run_self_tests_kwargs) = self.run_self_tests_called_with + response = self.manager.admin_collection_self_tests_controller.process_collection_self_tests( + collection.id + ) + + ( + run_self_tests_args, + run_self_tests_kwargs, + ) = self.run_self_tests_called_with assert response.title == FAILED_TO_RUN_SELF_TESTS.title assert response.detail == "Failed to run self tests for this collection." assert response.status_code == 400 diff --git a/tests/admin/controller/test_collections.py b/tests/admin/controller/test_collections.py index 6cbcd320aa..a1a4535b0c 100644 --- a/tests/admin/controller/test_collections.py +++ b/tests/admin/controller/test_collections.py @@ -2,7 +2,6 @@ import flask import pytest - from werkzeug.datastructures import MultiDict from api.admin.exceptions import * @@ -11,14 +10,15 @@ AdminRole, Collection, ConfigurationSetting, - create, ExternalIntegration, - get_one, Library, + create, + get_one, ) from core.model.configuration import ExternalIntegrationLink from core.s3 import S3UploaderConfiguration from core.selftest import HasSelfTests + from .test_controller import SettingsControllerTest @@ -29,7 +29,9 @@ def test_collections_get_with_no_collections(self): self._db.delete(collection) with self.request_context_with_admin("/"): - response = self.manager.admin_collection_settings_controller.process_collections() + response = ( + self.manager.admin_collection_settings_controller.process_collections() + ) assert response.get("collections") == [] names = [p.get("name") for p in response.get("protocols")] @@ -47,11 +49,18 @@ def test_collections_get_collection_protocols(self): # the protocols will not offer a 'mirror_integration_id' # setting for covers or books. with self.request_context_with_admin("/"): - response = self.manager.admin_collection_settings_controller.process_collections() - protocols = response.get('protocols') + response = ( + self.manager.admin_collection_settings_controller.process_collections() + ) + protocols = response.get("protocols") for protocol in protocols: - assert all([not s.get('key').endswith('mirror_integration_id') - for s in protocol['settings'] if s]) + assert all( + [ + not s.get("key").endswith("mirror_integration_id") + for s in protocol["settings"] + if s + ] + ) # When storage integrations are configured, each protocol will # offer a 'mirror_integration_id' setting for covers and books. @@ -60,67 +69,84 @@ def test_collections_get_collection_protocols(self): protocol=ExternalIntegration.S3, goal=ExternalIntegration.STORAGE_GOAL, settings={ - S3UploaderConfiguration.BOOK_COVERS_BUCKET_KEY: 'covers', - S3UploaderConfiguration.OA_CONTENT_BUCKET_KEY: 'open-access-books', - S3UploaderConfiguration.PROTECTED_CONTENT_BUCKET_KEY: 'protected-access-books' - } + S3UploaderConfiguration.BOOK_COVERS_BUCKET_KEY: "covers", + S3UploaderConfiguration.OA_CONTENT_BUCKET_KEY: "open-access-books", + S3UploaderConfiguration.PROTECTED_CONTENT_BUCKET_KEY: "protected-access-books", + }, ) storage2 = self._external_integration( name="integration 2", protocol="Some other protocol", goal=ExternalIntegration.STORAGE_GOAL, settings={ - S3UploaderConfiguration.BOOK_COVERS_BUCKET_KEY: 'covers', - S3UploaderConfiguration.OA_CONTENT_BUCKET_KEY: 'open-access-books', - S3UploaderConfiguration.PROTECTED_CONTENT_BUCKET_KEY: 'protected-access-books' - } + S3UploaderConfiguration.BOOK_COVERS_BUCKET_KEY: "covers", + S3UploaderConfiguration.OA_CONTENT_BUCKET_KEY: "open-access-books", + S3UploaderConfiguration.PROTECTED_CONTENT_BUCKET_KEY: "protected-access-books", + }, ) with self.request_context_with_admin("/"): controller = self.manager.admin_collection_settings_controller response = controller.process_collections() - protocols = response.get('protocols') + protocols = response.get("protocols") for protocol in protocols: - mirror_settings = [x for x in protocol['settings'] - if x.get('key').endswith('mirror_integration_id')] + mirror_settings = [ + x + for x in protocol["settings"] + if x.get("key").endswith("mirror_integration_id") + ] covers_mirror = mirror_settings[0] open_access_books_mirror = mirror_settings[1] protected_access_books_mirror = mirror_settings[2] - assert "Covers Mirror" == covers_mirror['label'] - assert "Open Access Books Mirror" == open_access_books_mirror['label'] - assert "Protected Access Books Mirror" == protected_access_books_mirror['label'] - covers_mirror_option = covers_mirror['options'] - open_books_mirror_option = open_access_books_mirror['options'] - protected_books_mirror_option = protected_access_books_mirror['options'] + assert "Covers Mirror" == covers_mirror["label"] + assert "Open Access Books Mirror" == open_access_books_mirror["label"] + assert ( + "Protected Access Books Mirror" + == protected_access_books_mirror["label"] + ) + covers_mirror_option = covers_mirror["options"] + open_books_mirror_option = open_access_books_mirror["options"] + protected_books_mirror_option = protected_access_books_mirror["options"] # The first option is to disable mirroring on this # collection altogether. no_mirror_covers = covers_mirror_option[0] no_mirror_open_books = open_books_mirror_option[0] no_mirror_protected_books = protected_books_mirror_option[0] - assert controller.NO_MIRROR_INTEGRATION == no_mirror_covers['key'] - assert controller.NO_MIRROR_INTEGRATION == no_mirror_open_books['key'] - assert controller.NO_MIRROR_INTEGRATION == no_mirror_protected_books['key'] + assert controller.NO_MIRROR_INTEGRATION == no_mirror_covers["key"] + assert controller.NO_MIRROR_INTEGRATION == no_mirror_open_books["key"] + assert ( + controller.NO_MIRROR_INTEGRATION == no_mirror_protected_books["key"] + ) # The other options are to use one of the storage # integrations to do the mirroring. - use_covers_mirror = [(x['key'], x['label']) - for x in covers_mirror_option[1:]] - use_open_books_mirror = [(x['key'], x['label']) - for x in open_books_mirror_option[1:]] - use_protected_books_mirror = [(x['key'], x['label']) - for x in protected_books_mirror_option[1:]] + use_covers_mirror = [ + (x["key"], x["label"]) for x in covers_mirror_option[1:] + ] + use_open_books_mirror = [ + (x["key"], x["label"]) for x in open_books_mirror_option[1:] + ] + use_protected_books_mirror = [ + (x["key"], x["label"]) for x in protected_books_mirror_option[1:] + ] # Expect to have two separate mirrors - expect_covers = [(str(integration.id), integration.name) - for integration in (storage1, storage2)] + expect_covers = [ + (str(integration.id), integration.name) + for integration in (storage1, storage2) + ] assert expect_covers == use_covers_mirror - expect_open_books = [(str(integration.id), integration.name) - for integration in (storage1, storage2)] + expect_open_books = [ + (str(integration.id), integration.name) + for integration in (storage1, storage2) + ] assert expect_open_books == use_open_books_mirror - expect_protected_books = [(str(integration.id), integration.name) - for integration in (storage1, storage2)] + expect_protected_books = [ + (str(integration.id), integration.name) + for integration in (storage1, storage2) + ] assert expect_protected_books == use_protected_books_mirror HasSelfTests.prior_test_results = old_prior_test_results @@ -133,24 +159,26 @@ def test_collections_get_collections_with_multiple_collections(self): [c1] = self._default_library.collections c2 = self._collection( - name="Collection 2", protocol=ExternalIntegration.OVERDRIVE, + name="Collection 2", + protocol=ExternalIntegration.OVERDRIVE, ) c2_storage = self._external_integration( - protocol=ExternalIntegration.S3, - goal=ExternalIntegration.STORAGE_GOAL + protocol=ExternalIntegration.S3, goal=ExternalIntegration.STORAGE_GOAL ) c2_external_integration_link = self._external_integration_link( integration=c2.external_integration, - other_integration=c2_storage, purpose=ExternalIntegrationLink.COVERS + other_integration=c2_storage, + purpose=ExternalIntegrationLink.COVERS, ) c2.external_account_id = "1234" c2.external_integration.password = "b" c2.external_integration.username = "user" - c2.external_integration.setting('website_id').value = '100' + c2.external_integration.setting("website_id").value = "100" c3 = self._collection( - name="Collection 3", protocol=ExternalIntegration.OVERDRIVE, + name="Collection 3", + protocol=ExternalIntegration.OVERDRIVE, ) c3.external_account_id = "5678" c3.parent = c2 @@ -159,7 +187,8 @@ def test_collections_get_collections_with_multiple_collections(self): c3.libraries += [l1, self._default_library] c3.external_integration.libraries += [l1] ConfigurationSetting.for_library_and_externalintegration( - self._db, "ebook_loan_duration", l1, c3.external_integration).value = "14" + self._db, "ebook_loan_duration", l1, c3.external_integration + ).value = "14" l1_librarian, ignore = create(self._db, Admin, email="admin@l1.org") l1_librarian.add_role(AdminRole.LIBRARIAN, l1) @@ -169,7 +198,7 @@ def test_collections_get_collections_with_multiple_collections(self): response = controller.process_collections() # The system admin can see all collections. coll2, coll3, coll1 = sorted( - response.get("collections"), key = lambda c: c.get('name') + response.get("collections"), key=lambda c: c.get("name") ) assert c1.id == coll1.get("id") assert c2.id == coll2.get("id") @@ -191,17 +220,23 @@ def test_collections_get_collections_with_multiple_collections(self): settings2 = coll2.get("settings", {}) settings3 = coll3.get("settings", {}) - assert (controller.NO_MIRROR_INTEGRATION == - settings1.get("covers_mirror_integration_id")) - assert (controller.NO_MIRROR_INTEGRATION == - settings1.get("books_mirror_integration_id")) + assert controller.NO_MIRROR_INTEGRATION == settings1.get( + "covers_mirror_integration_id" + ) + assert controller.NO_MIRROR_INTEGRATION == settings1.get( + "books_mirror_integration_id" + ) # Only added an integration for S3 storage for covers. assert str(c2_storage.id) == settings2.get("covers_mirror_integration_id") - assert controller.NO_MIRROR_INTEGRATION == settings2.get("books_mirror_integration_id") - assert (controller.NO_MIRROR_INTEGRATION == - settings3.get("covers_mirror_integration_id")) - assert (controller.NO_MIRROR_INTEGRATION == - settings3.get("books_mirror_integration_id")) + assert controller.NO_MIRROR_INTEGRATION == settings2.get( + "books_mirror_integration_id" + ) + assert controller.NO_MIRROR_INTEGRATION == settings3.get( + "covers_mirror_integration_id" + ) + assert controller.NO_MIRROR_INTEGRATION == settings3.get( + "books_mirror_integration_id" + ) assert c1.external_account_id == settings1.get("external_account_id") assert c2.external_account_id == settings2.get("external_account_id") @@ -214,7 +249,9 @@ def test_collections_get_collections_with_multiple_collections(self): coll3_libraries = coll3.get("libraries") assert 2 == len(coll3_libraries) - coll3_l1, coll3_default = sorted(coll3_libraries, key=lambda x: x.get("short_name")) + coll3_l1, coll3_default = sorted( + coll3_libraries, key=lambda x: x.get("short_name") + ) assert "L1" == coll3_l1.get("short_name") assert "14" == coll3_l1.get("ebook_loan_duration") assert self._default_library.short_name == coll3_default.get("short_name") @@ -234,162 +271,235 @@ def test_collections_get_collections_with_multiple_collections(self): def test_collections_post_errors(self): with self.request_context_with_admin("/", method="POST"): - flask.request.form = MultiDict([ - ("protocol", "Overdrive"), - ]) - response = self.manager.admin_collection_settings_controller.process_collections() + flask.request.form = MultiDict( + [ + ("protocol", "Overdrive"), + ] + ) + response = ( + self.manager.admin_collection_settings_controller.process_collections() + ) assert response == MISSING_COLLECTION_NAME with self.request_context_with_admin("/", method="POST"): - flask.request.form = MultiDict([ - ("name", "collection"), - ]) - response = self.manager.admin_collection_settings_controller.process_collections() + flask.request.form = MultiDict( + [ + ("name", "collection"), + ] + ) + response = ( + self.manager.admin_collection_settings_controller.process_collections() + ) assert response == NO_PROTOCOL_FOR_NEW_SERVICE with self.request_context_with_admin("/", method="POST"): - flask.request.form = MultiDict([ - ("name", "collection"), - ("protocol", "Unknown"), - ]) - response = self.manager.admin_collection_settings_controller.process_collections() + flask.request.form = MultiDict( + [ + ("name", "collection"), + ("protocol", "Unknown"), + ] + ) + response = ( + self.manager.admin_collection_settings_controller.process_collections() + ) assert response == UNKNOWN_PROTOCOL with self.request_context_with_admin("/", method="POST"): - flask.request.form = MultiDict([ - ("id", "123456789"), - ("name", "collection"), - ("protocol", "Bibliotheca"), - ]) - response = self.manager.admin_collection_settings_controller.process_collections() + flask.request.form = MultiDict( + [ + ("id", "123456789"), + ("name", "collection"), + ("protocol", "Bibliotheca"), + ] + ) + response = ( + self.manager.admin_collection_settings_controller.process_collections() + ) assert response == MISSING_COLLECTION collection = self._collection( - name="Collection 1", - protocol=ExternalIntegration.OVERDRIVE + name="Collection 1", protocol=ExternalIntegration.OVERDRIVE ) with self.request_context_with_admin("/", method="POST"): - flask.request.form = MultiDict([ - ("name", "Collection 1"), - ("protocol", "Bibliotheca"), - ]) - response = self.manager.admin_collection_settings_controller.process_collections() + flask.request.form = MultiDict( + [ + ("name", "Collection 1"), + ("protocol", "Bibliotheca"), + ] + ) + response = ( + self.manager.admin_collection_settings_controller.process_collections() + ) assert response == COLLECTION_NAME_ALREADY_IN_USE self.admin.remove_role(AdminRole.SYSTEM_ADMIN) with self.request_context_with_admin("/", method="POST"): - flask.request.form = MultiDict([ - ("id", collection.id), - ("name", "Collection 1"), - ("protocol", "Overdrive"), - ]) - pytest.raises(AdminNotAuthorized, - self.manager.admin_collection_settings_controller.process_collections) + flask.request.form = MultiDict( + [ + ("id", collection.id), + ("name", "Collection 1"), + ("protocol", "Overdrive"), + ] + ) + pytest.raises( + AdminNotAuthorized, + self.manager.admin_collection_settings_controller.process_collections, + ) self.admin.add_role(AdminRole.SYSTEM_ADMIN) with self.request_context_with_admin("/", method="POST"): - flask.request.form = MultiDict([ - ("id", collection.id), - ("name", "Collection 1"), - ("protocol", "Bibliotheca"), - ]) - response = self.manager.admin_collection_settings_controller.process_collections() + flask.request.form = MultiDict( + [ + ("id", collection.id), + ("name", "Collection 1"), + ("protocol", "Bibliotheca"), + ] + ) + response = ( + self.manager.admin_collection_settings_controller.process_collections() + ) assert response == CANNOT_CHANGE_PROTOCOL with self.request_context_with_admin("/", method="POST"): - flask.request.form = MultiDict([ - ("name", "Collection 2"), - ("protocol", "Bibliotheca"), - ("parent_id", "1234"), - ]) - response = self.manager.admin_collection_settings_controller.process_collections() + flask.request.form = MultiDict( + [ + ("name", "Collection 2"), + ("protocol", "Bibliotheca"), + ("parent_id", "1234"), + ] + ) + response = ( + self.manager.admin_collection_settings_controller.process_collections() + ) assert response == PROTOCOL_DOES_NOT_SUPPORT_PARENTS with self.request_context_with_admin("/", method="POST"): - flask.request.form = MultiDict([ - ("name", "Collection 2"), - ("protocol", "Overdrive"), - ("parent_id", "1234"), - ]) - response = self.manager.admin_collection_settings_controller.process_collections() + flask.request.form = MultiDict( + [ + ("name", "Collection 2"), + ("protocol", "Overdrive"), + ("parent_id", "1234"), + ] + ) + response = ( + self.manager.admin_collection_settings_controller.process_collections() + ) assert response == MISSING_PARENT with self.request_context_with_admin("/", method="POST"): - flask.request.form = MultiDict([ - ("name", "collection"), - ("protocol", "OPDS Import"), - ("external_account_id", "http://url.test"), - ("data_source", "test"), - ("libraries", json.dumps([{"short_name": "nosuchlibrary"}])), - ]) - response = self.manager.admin_collection_settings_controller.process_collections() + flask.request.form = MultiDict( + [ + ("name", "collection"), + ("protocol", "OPDS Import"), + ("external_account_id", "http://url.test"), + ("data_source", "test"), + ("libraries", json.dumps([{"short_name": "nosuchlibrary"}])), + ] + ) + response = ( + self.manager.admin_collection_settings_controller.process_collections() + ) assert response.uri == NO_SUCH_LIBRARY.uri with self.request_context_with_admin("/", method="POST"): - flask.request.form = MultiDict([ - ("name", "collection1"), - ("protocol", "OPDS Import"), - ]) - response = self.manager.admin_collection_settings_controller.process_collections() + flask.request.form = MultiDict( + [ + ("name", "collection1"), + ("protocol", "OPDS Import"), + ] + ) + response = ( + self.manager.admin_collection_settings_controller.process_collections() + ) assert response.uri == INCOMPLETE_CONFIGURATION.uri with self.request_context_with_admin("/", method="POST"): - flask.request.form = MultiDict([ - ("name", "collection1"), - ("protocol", "Overdrive"), - ("external_account_id", "1234"), - ("username", "user"), - ("password", "password"), - ]) - response = self.manager.admin_collection_settings_controller.process_collections() + flask.request.form = MultiDict( + [ + ("name", "collection1"), + ("protocol", "Overdrive"), + ("external_account_id", "1234"), + ("username", "user"), + ("password", "password"), + ] + ) + response = ( + self.manager.admin_collection_settings_controller.process_collections() + ) assert response.uri == INCOMPLETE_CONFIGURATION.uri with self.request_context_with_admin("/", method="POST"): - flask.request.form = MultiDict([ - ("name", "collection1"), - ("protocol", "Bibliotheca"), - ("external_account_id", "1234"), - ("password", "password"), - ]) - response = self.manager.admin_collection_settings_controller.process_collections() + flask.request.form = MultiDict( + [ + ("name", "collection1"), + ("protocol", "Bibliotheca"), + ("external_account_id", "1234"), + ("password", "password"), + ] + ) + response = ( + self.manager.admin_collection_settings_controller.process_collections() + ) assert response.uri == INCOMPLETE_CONFIGURATION.uri with self.request_context_with_admin("/", method="POST"): - flask.request.form = MultiDict([ - ("name", "collection1"), - ("protocol", "Axis 360"), - ("username", "user"), - ("password", "password"), - ]) - response = self.manager.admin_collection_settings_controller.process_collections() + flask.request.form = MultiDict( + [ + ("name", "collection1"), + ("protocol", "Axis 360"), + ("username", "user"), + ("password", "password"), + ] + ) + response = ( + self.manager.admin_collection_settings_controller.process_collections() + ) assert response.uri == INCOMPLETE_CONFIGURATION.uri def test_collections_post_create(self): l1, ignore = create( - self._db, Library, name="Library 1", short_name="L1", + self._db, + Library, + name="Library 1", + short_name="L1", ) l2, ignore = create( - self._db, Library, name="Library 2", short_name="L2", + self._db, + Library, + name="Library 2", + short_name="L2", ) l3, ignore = create( - self._db, Library, name="Library 3", short_name="L3", + self._db, + Library, + name="Library 3", + short_name="L3", ) with self.request_context_with_admin("/", method="POST"): - flask.request.form = MultiDict([ - ("name", "New Collection"), - ("protocol", "Overdrive"), - ("libraries", json.dumps([ - {"short_name": "L1", "ils_name": "l1_ils"}, - {"short_name":"L2", "ils_name": "l2_ils"} - ])), - ("external_account_id", "acctid"), - ("username", "username"), - ("password", "password"), - ("website_id", "1234"), - ]) - response = self.manager.admin_collection_settings_controller.process_collections() + flask.request.form = MultiDict( + [ + ("name", "New Collection"), + ("protocol", "Overdrive"), + ( + "libraries", + json.dumps( + [ + {"short_name": "L1", "ils_name": "l1_ils"}, + {"short_name": "L2", "ils_name": "l2_ils"}, + ] + ), + ), + ("external_account_id", "acctid"), + ("username", "username"), + ("password", "password"), + ("website_id", "1234"), + ] + ) + response = ( + self.manager.admin_collection_settings_controller.process_collections() + ) assert response.status_code == 201 # The collection was created and configured properly. @@ -410,21 +520,36 @@ def test_collections_post_create(self): assert "website_id" == setting.key assert "1234" == setting.value - assert "l1_ils" == ConfigurationSetting.for_library_and_externalintegration( - self._db, "ils_name", l1, collection.external_integration).value - assert "l2_ils" == ConfigurationSetting.for_library_and_externalintegration( - self._db, "ils_name", l2, collection.external_integration).value + assert ( + "l1_ils" + == ConfigurationSetting.for_library_and_externalintegration( + self._db, "ils_name", l1, collection.external_integration + ).value + ) + assert ( + "l2_ils" + == ConfigurationSetting.for_library_and_externalintegration( + self._db, "ils_name", l2, collection.external_integration + ).value + ) # This collection will be a child of the first collection. with self.request_context_with_admin("/", method="POST"): - flask.request.form = MultiDict([ - ("name", "Child Collection"), - ("protocol", "Overdrive"), - ("parent_id", collection.id), - ("libraries", json.dumps([{"short_name": "L3", "ils_name": "l3_ils"}])), - ("external_account_id", "child-acctid"), - ]) - response = self.manager.admin_collection_settings_controller.process_collections() + flask.request.form = MultiDict( + [ + ("name", "Child Collection"), + ("protocol", "Overdrive"), + ("parent_id", collection.id), + ( + "libraries", + json.dumps([{"short_name": "L3", "ils_name": "l3_ils"}]), + ), + ("external_account_id", "child-acctid"), + ] + ) + response = ( + self.manager.admin_collection_settings_controller.process_collections() + ) assert response.status_code == 201 # The collection was created and configured properly. @@ -442,32 +567,45 @@ def test_collections_post_create(self): # One library has access to the collection. assert [child] == l3.collections - assert "l3_ils" == ConfigurationSetting.for_library_and_externalintegration( - self._db, "ils_name", l3, child.external_integration).value + assert ( + "l3_ils" + == ConfigurationSetting.for_library_and_externalintegration( + self._db, "ils_name", l3, child.external_integration + ).value + ) def test_collections_post_edit(self): # The collection exists. collection = self._collection( - name="Collection 1", - protocol=ExternalIntegration.OVERDRIVE + name="Collection 1", protocol=ExternalIntegration.OVERDRIVE ) l1, ignore = create( - self._db, Library, name="Library 1", short_name="L1", + self._db, + Library, + name="Library 1", + short_name="L1", ) with self.request_context_with_admin("/", method="POST"): - flask.request.form = MultiDict([ - ("id", collection.id), - ("name", "Collection 1"), - ("protocol", ExternalIntegration.OVERDRIVE), - ("external_account_id", "1234"), - ("username", "user2"), - ("password", "password"), - ("website_id", "1234"), - ("libraries", json.dumps([{"short_name": "L1", "ils_name": "the_ils"}])), - ]) - response = self.manager.admin_collection_settings_controller.process_collections() + flask.request.form = MultiDict( + [ + ("id", collection.id), + ("name", "Collection 1"), + ("protocol", ExternalIntegration.OVERDRIVE), + ("external_account_id", "1234"), + ("username", "user2"), + ("password", "password"), + ("website_id", "1234"), + ( + "libraries", + json.dumps([{"short_name": "L1", "ils_name": "the_ils"}]), + ), + ] + ) + response = ( + self.manager.admin_collection_settings_controller.process_collections() + ) assert response.status_code == 200 assert collection.id == int(response.response[0]) @@ -483,21 +621,29 @@ def test_collections_post_edit(self): assert "website_id" == setting.key assert "1234" == setting.value - assert "the_ils" == ConfigurationSetting.for_library_and_externalintegration( - self._db, "ils_name", l1, collection.external_integration).value + assert ( + "the_ils" + == ConfigurationSetting.for_library_and_externalintegration( + self._db, "ils_name", l1, collection.external_integration + ).value + ) with self.request_context_with_admin("/", method="POST"): - flask.request.form = MultiDict([ - ("id", collection.id), - ("name", "Collection 1"), - ("protocol", ExternalIntegration.OVERDRIVE), - ("external_account_id", "1234"), - ("username", "user2"), - ("password", "password"), - ("website_id", "1234"), - ("libraries", json.dumps([])), - ]) - response = self.manager.admin_collection_settings_controller.process_collections() + flask.request.form = MultiDict( + [ + ("id", collection.id), + ("name", "Collection 1"), + ("protocol", ExternalIntegration.OVERDRIVE), + ("external_account_id", "1234"), + ("username", "user2"), + ("password", "password"), + ("website_id", "1234"), + ("libraries", json.dumps([])), + ] + ) + response = ( + self.manager.admin_collection_settings_controller.process_collections() + ) assert response.status_code == 200 assert collection.id == int(response.response[0]) @@ -511,28 +657,32 @@ def test_collections_post_edit(self): # All ConfigurationSettings for that library and collection # have been deleted. - qu = self._db.query(ConfigurationSetting).filter( - ConfigurationSetting.library==l1 - ).filter( - ConfigurationSetting.external_integration==collection.external_integration + qu = ( + self._db.query(ConfigurationSetting) + .filter(ConfigurationSetting.library == l1) + .filter( + ConfigurationSetting.external_integration + == collection.external_integration + ) ) assert 0 == qu.count() - parent = self._collection( - name="Parent", - protocol=ExternalIntegration.OVERDRIVE - ) + parent = self._collection(name="Parent", protocol=ExternalIntegration.OVERDRIVE) with self.request_context_with_admin("/", method="POST"): - flask.request.form = MultiDict([ - ("id", collection.id), - ("name", "Collection 1"), - ("protocol", ExternalIntegration.OVERDRIVE), - ("parent_id", parent.id), - ("external_account_id", "1234"), - ("libraries", json.dumps([])), - ]) - response = self.manager.admin_collection_settings_controller.process_collections() + flask.request.form = MultiDict( + [ + ("id", collection.id), + ("name", "Collection 1"), + ("protocol", ExternalIntegration.OVERDRIVE), + ("parent_id", parent.id), + ("external_account_id", "1234"), + ("libraries", json.dumps([])), + ] + ) + response = ( + self.manager.admin_collection_settings_controller.process_collections() + ) assert response.status_code == 200 assert collection.id == int(response.response[0]) @@ -555,14 +705,12 @@ def _base_collections_post_request(self, collection): def test_collections_post_edit_mirror_integration(self): # The collection exists. collection = self._collection( - name="Collection 1", - protocol=ExternalIntegration.AXIS_360 + name="Collection 1", protocol=ExternalIntegration.AXIS_360 ) # There is a storage integration not associated with the collection. storage = self._external_integration( - protocol=ExternalIntegration.S3, - goal=ExternalIntegration.STORAGE_GOAL + protocol=ExternalIntegration.S3, goal=ExternalIntegration.STORAGE_GOAL ) # It's possible to associate the storage integration with the @@ -573,14 +721,17 @@ def test_collections_post_edit_mirror_integration(self): base_request + [("books_mirror_integration_id", storage.id)] ) flask.request.form = request - response = self.manager.admin_collection_settings_controller.process_collections() + response = ( + self.manager.admin_collection_settings_controller.process_collections() + ) assert response.status_code == 200 # There is an external integration link to associate the collection's # external integration with the storage integration for a books mirror. external_integration_link = get_one( - self._db, ExternalIntegrationLink, - external_integration_id=collection.external_integration.id + self._db, + ExternalIntegrationLink, + external_integration_id=collection.external_integration.id, ) assert storage.id == external_integration_link.other_integration_id @@ -588,40 +739,44 @@ def test_collections_post_edit_mirror_integration(self): controller = self.manager.admin_collection_settings_controller with self.request_context_with_admin("/", method="POST"): request = MultiDict( - base_request + [("books_mirror_integration_id", - str(controller.NO_MIRROR_INTEGRATION))] + base_request + + [ + ( + "books_mirror_integration_id", + str(controller.NO_MIRROR_INTEGRATION), + ) + ] ) flask.request.form = request response = controller.process_collections() assert response.status_code == 200 external_integration_link = get_one( - self._db, ExternalIntegrationLink, - external_integration_id=collection.external_integration.id + self._db, + ExternalIntegrationLink, + external_integration_id=collection.external_integration.id, ) assert None == external_integration_link # Providing a nonexistent integration ID gives an error. with self.request_context_with_admin("/", method="POST"): - request = MultiDict( - base_request + [("books_mirror_integration_id", -200)] - ) + request = MultiDict(base_request + [("books_mirror_integration_id", -200)]) flask.request.form = request - response = self.manager.admin_collection_settings_controller.process_collections() + response = ( + self.manager.admin_collection_settings_controller.process_collections() + ) assert response == MISSING_SERVICE def test_cannot_set_non_storage_integration_as_mirror_integration(self): # The collection exists. collection = self._collection( - name="Collection 1", - protocol=ExternalIntegration.AXIS_360 + name="Collection 1", protocol=ExternalIntegration.AXIS_360 ) # There is a storage integration not associated with the collection, # which makes it possible to associate storage integrations # with collections through the collections controller. storage = self._external_integration( - protocol=ExternalIntegration.S3, - goal=ExternalIntegration.STORAGE_GOAL + protocol=ExternalIntegration.S3, goal=ExternalIntegration.STORAGE_GOAL ) # Trying to set a non-storage integration (such as the @@ -630,70 +785,86 @@ def test_cannot_set_non_storage_integration_as_mirror_integration(self): base_request = self._base_collections_post_request(collection) with self.request_context_with_admin("/", method="POST"): request = MultiDict( - base_request + [ - ("books_mirror_integration_id", collection.external_integration.id) - ] + base_request + + [("books_mirror_integration_id", collection.external_integration.id)] ) flask.request.form = request - response = self.manager.admin_collection_settings_controller.process_collections() + response = ( + self.manager.admin_collection_settings_controller.process_collections() + ) assert response == INTEGRATION_GOAL_CONFLICT def test_collections_post_edit_library_specific_configuration(self): # The collection exists. collection = self._collection( - name="Collection 1", - protocol=ExternalIntegration.AXIS_360 + name="Collection 1", protocol=ExternalIntegration.AXIS_360 ) l1, ignore = create( - self._db, Library, name="Library 1", short_name="L1", + self._db, + Library, + name="Library 1", + short_name="L1", ) with self.request_context_with_admin("/", method="POST"): - flask.request.form = MultiDict([ - ("id", collection.id), - ("name", "Collection 1"), - ("protocol", ExternalIntegration.AXIS_360), - ("external_account_id", "1234"), - ("username", "user2"), - ("password", "password"), - ("url", "http://axis/"), - ("libraries", json.dumps([ - { - "short_name": "L1", - "ebook_loan_duration": "14" - } - ]) - ), - ]) - response = self.manager.admin_collection_settings_controller.process_collections() + flask.request.form = MultiDict( + [ + ("id", collection.id), + ("name", "Collection 1"), + ("protocol", ExternalIntegration.AXIS_360), + ("external_account_id", "1234"), + ("username", "user2"), + ("password", "password"), + ("url", "http://axis/"), + ( + "libraries", + json.dumps([{"short_name": "L1", "ebook_loan_duration": "14"}]), + ), + ] + ) + response = ( + self.manager.admin_collection_settings_controller.process_collections() + ) assert response.status_code == 200 # Additional settings were set on the collection+library. - assert "14" == ConfigurationSetting.for_library_and_externalintegration( - self._db, "ebook_loan_duration", l1, collection.external_integration).value + assert ( + "14" + == ConfigurationSetting.for_library_and_externalintegration( + self._db, "ebook_loan_duration", l1, collection.external_integration + ).value + ) # Remove the connection between collection and library. with self.request_context_with_admin("/", method="POST"): - flask.request.form = MultiDict([ - ("id", collection.id), - ("name", "Collection 1"), - ("protocol", ExternalIntegration.AXIS_360), - ("external_account_id", "1234"), - ("username", "user2"), - ("password", "password"), - ("url", "http://axis/"), - ("libraries", json.dumps([])), - ]) - response = self.manager.admin_collection_settings_controller.process_collections() + flask.request.form = MultiDict( + [ + ("id", collection.id), + ("name", "Collection 1"), + ("protocol", ExternalIntegration.AXIS_360), + ("external_account_id", "1234"), + ("username", "user2"), + ("password", "password"), + ("url", "http://axis/"), + ("libraries", json.dumps([])), + ] + ) + response = ( + self.manager.admin_collection_settings_controller.process_collections() + ) assert response.status_code == 200 assert collection.id == int(response.response[0]) # The settings associated with the collection+library were removed # when the connection between collection and library was deleted. - assert None == ConfigurationSetting.for_library_and_externalintegration( - self._db, "ebook_loan_duration", l1, collection.external_integration).value + assert ( + None + == ConfigurationSetting.for_library_and_externalintegration( + self._db, "ebook_loan_duration", l1, collection.external_integration + ).value + ) assert [] == collection.libraries def test_collection_delete(self): @@ -702,12 +873,16 @@ def test_collection_delete(self): with self.request_context_with_admin("/", method="DELETE"): self.admin.remove_role(AdminRole.SYSTEM_ADMIN) - pytest.raises(AdminNotAuthorized, - self.manager.admin_collection_settings_controller.process_delete, - collection.id) + pytest.raises( + AdminNotAuthorized, + self.manager.admin_collection_settings_controller.process_delete, + collection.id, + ) self.admin.add_role(AdminRole.SYSTEM_ADMIN) - response = self.manager.admin_collection_settings_controller.process_delete(collection.id) + response = self.manager.admin_collection_settings_controller.process_delete( + collection.id + ) assert response.status_code == 200 # The collection should still be available because it is not immediately deleted. @@ -723,5 +898,7 @@ def test_collection_delete_cant_delete_parent(self): child.parent = parent with self.request_context_with_admin("/", method="DELETE"): - response = self.manager.admin_collection_settings_controller.process_delete(parent.id) + response = self.manager.admin_collection_settings_controller.process_delete( + parent.id + ) assert CANNOT_DELETE_COLLECTION_WITH_CHILDREN == response diff --git a/tests/admin/controller/test_controller.py b/tests/admin/controller/test_controller.py index 3361786f0a..4a90db17f5 100644 --- a/tests/admin/controller/test_controller.py +++ b/tests/admin/controller/test_controller.py @@ -2,43 +2,37 @@ import datetime import json import re -from io import StringIO from contextlib import contextmanager from datetime import timedelta +from io import StringIO import feedparser import flask import pytest - from werkzeug.datastructures import MultiDict from werkzeug.http import dump_cookie from api.admin.controller import ( - setup_admin_controllers, AdminAnnotator, + PatronController, SettingsController, - PatronController + setup_admin_controllers, ) from api.admin.exceptions import * -from api.admin.google_oauth_admin_authentication_provider import GoogleOAuthAdminAuthenticationProvider -from api.admin.password_admin_authentication_provider import PasswordAdminAuthenticationProvider +from api.admin.google_oauth_admin_authentication_provider import ( + GoogleOAuthAdminAuthenticationProvider, +) +from api.admin.password_admin_authentication_provider import ( + PasswordAdminAuthenticationProvider, +) from api.admin.problem_details import * from api.admin.routes import setup_admin from api.admin.validator import Validator -from api.adobe_vendor_id import ( - AdobeVendorIDModel -) -from api.adobe_vendor_id import AuthdataUtility -from api.authenticator import ( - PatronData, -) -from api.axis import (Axis360API, MockAxis360API) -from api.config import ( - Configuration, -) -from core.classifier import ( - genres -) +from api.adobe_vendor_id import AdobeVendorIDModel, AuthdataUtility +from api.authenticator import PatronData +from api.axis import Axis360API, MockAxis360API +from api.config import Configuration +from core.classifier import genres from core.lane import Lane from core.model import ( Admin, @@ -48,26 +42,25 @@ ConfigurationSetting, CustomList, CustomListEntry, - create, DataSource, Edition, ExternalIntegration, Genre, - get_one, - get_one_or_create, Library, Timestamp, - WorkGenre + WorkGenre, + create, + get_one, + get_one_or_create, ) -from core.opds_import import (OPDSImporter, OPDSImportMonitor) +from core.opds_import import OPDSImporter, OPDSImportMonitor from core.s3 import S3UploaderConfiguration from core.selftest import HasSelfTests -from core.util.datetime_helpers import ( - utc_now, -) +from core.util.datetime_helpers import utc_now from core.util.http import HTTP from tests.test_controller import CirculationControllerTest + class AdminControllerTest(CirculationControllerTest): # Automatically creating books before the test wastes time -- we @@ -76,19 +69,23 @@ class AdminControllerTest(CirculationControllerTest): def setup_method(self): super(AdminControllerTest, self).setup_method() - ConfigurationSetting.sitewide(self._db, Configuration.SECRET_KEY).value = "a secret" + ConfigurationSetting.sitewide( + self._db, Configuration.SECRET_KEY + ).value = "a secret" setup_admin(self._db) setup_admin_controllers(self.manager) self.admin, ignore = create( - self._db, Admin, email='example@nypl.org', + self._db, + Admin, + email="example@nypl.org", ) self.admin.password = "password" @contextmanager def request_context_with_admin(self, route, *args, **kwargs): admin = self.admin - if 'admin' in kwargs: - admin = kwargs.pop('admin') + if "admin" in kwargs: + admin = kwargs.pop("admin") with self.app.test_request_context(route, *args, **kwargs) as c: flask.request.form = {} flask.request.files = {} @@ -100,8 +97,8 @@ def request_context_with_admin(self, route, *args, **kwargs): @contextmanager def request_context_with_library_and_admin(self, route, *args, **kwargs): admin = self.admin - if 'admin' in kwargs: - admin = kwargs.pop('admin') + if "admin" in kwargs: + admin = kwargs.pop("admin") with self.request_context_with_library(route, *args, **kwargs) as c: flask.request.form = {} flask.request.files = {} @@ -110,30 +107,30 @@ def request_context_with_library_and_admin(self, route, *args, **kwargs): yield c self._db.commit() -class TestViewController(AdminControllerTest): +class TestViewController(AdminControllerTest): def test_setting_up(self): # Test that the view is in setting-up mode if there's no auth service # and no admin with a password. self.admin.password_hashed = None - with self.app.test_request_context('/admin'): + with self.app.test_request_context("/admin"): response = self.manager.admin_view_controller(None, None) assert 200 == response.status_code html = response.get_data(as_text=True) - assert 'settingUp: true' in html + assert "settingUp: true" in html def test_not_setting_up(self): - with self.app.test_request_context('/admin'): - flask.session['admin_email'] = self.admin.email - flask.session['auth_type'] = PasswordAdminAuthenticationProvider.NAME + with self.app.test_request_context("/admin"): + flask.session["admin_email"] = self.admin.email + flask.session["auth_type"] = PasswordAdminAuthenticationProvider.NAME response = self.manager.admin_view_controller("collection", "book") assert 200 == response.status_code html = response.get_data(as_text=True) - assert 'settingUp: false' in html + assert "settingUp: false" in html def test_redirect_to_sign_in(self): - with self.app.test_request_context('/admin/web/collection/a/(b)/book/c/(d)'): + with self.app.test_request_context("/admin/web/collection/a/(b)/book/c/(d)"): response = self.manager.admin_view_controller("a/(b)", "c/(d)") assert 302 == response.status_code location = response.headers.get("Location") @@ -145,20 +142,23 @@ def test_redirect_to_sign_in(self): def test_redirect_to_library(self): # If the admin doesn't have access to any libraries, they get a message # instead of a redirect. - with self.app.test_request_context('/admin'): - flask.session['admin_email'] = self.admin.email - flask.session['auth_type'] = PasswordAdminAuthenticationProvider.NAME + with self.app.test_request_context("/admin"): + flask.session["admin_email"] = self.admin.email + flask.session["auth_type"] = PasswordAdminAuthenticationProvider.NAME response = self.manager.admin_view_controller(None, None) assert 200 == response.status_code - assert "Your admin account doesn't have access to any libraries" in response.get_data(as_text=True) + assert ( + "Your admin account doesn't have access to any libraries" + in response.get_data(as_text=True) + ) # Unless there aren't any libraries yet. In that case, an admin needs to # get in to create one. for library in self._db.query(Library): self._db.delete(library) - with self.app.test_request_context('/admin'): - flask.session['admin_email'] = self.admin.email - flask.session['auth_type'] = PasswordAdminAuthenticationProvider.NAME + with self.app.test_request_context("/admin"): + flask.session["admin_email"] = self.admin.email + flask.session["auth_type"] = PasswordAdminAuthenticationProvider.NAME response = self.manager.admin_view_controller(None, None) assert 200 == response.status_code assert "" in response.get_data(as_text=True) @@ -169,9 +169,9 @@ def test_redirect_to_library(self): self.admin.add_role(AdminRole.LIBRARIAN, l1) self.admin.add_role(AdminRole.LIBRARY_MANAGER, l3) # An admin with roles gets redirected to the oldest library they have access to. - with self.app.test_request_context('/admin'): - flask.session['admin_email'] = self.admin.email - flask.session['auth_type'] = PasswordAdminAuthenticationProvider.NAME + with self.app.test_request_context("/admin"): + flask.session["admin_email"] = self.admin.email + flask.session["auth_type"] = PasswordAdminAuthenticationProvider.NAME response = self.manager.admin_view_controller(None, None) assert 302 == response.status_code location = response.headers.get("Location") @@ -179,47 +179,48 @@ def test_redirect_to_library(self): # Only the root url redirects - a non-library specific page with another # path won't. - with self.app.test_request_context('/admin/web/config'): - flask.session['admin_email'] = self.admin.email - flask.session['auth_type'] = PasswordAdminAuthenticationProvider.NAME + with self.app.test_request_context("/admin/web/config"): + flask.session["admin_email"] = self.admin.email + flask.session["auth_type"] = PasswordAdminAuthenticationProvider.NAME response = self.manager.admin_view_controller(None, None, "config") assert 200 == response.status_code def test_csrf_token(self): self.admin.password_hashed = None - with self.app.test_request_context('/admin'): + with self.app.test_request_context("/admin"): response = self.manager.admin_view_controller(None, None) assert 200 == response.status_code html = response.get_data(as_text=True) # The CSRF token value is random, but the cookie and the html have the same value. - html_csrf_re = re.compile('csrfToken: \"([^\"]*)\"') + html_csrf_re = re.compile('csrfToken: "([^"]*)"') match = html_csrf_re.search(html) assert match != None csrf = match.groups(0)[0] - assert csrf in response.headers.get('Set-Cookie') - assert 'HttpOnly' in response.headers.get("Set-Cookie") + assert csrf in response.headers.get("Set-Cookie") + assert "HttpOnly" in response.headers.get("Set-Cookie") self.admin.password = "password" # If there's a CSRF token in the request cookie, the response # should keep that same token. token = self._str cookie = dump_cookie("csrf_token", token) - with self.app.test_request_context('/admin', environ_base={'HTTP_COOKIE': cookie}): - flask.session['admin_email'] = self.admin.email - flask.session['auth_type'] = PasswordAdminAuthenticationProvider.NAME + with self.app.test_request_context( + "/admin", environ_base={"HTTP_COOKIE": cookie} + ): + flask.session["admin_email"] = self.admin.email + flask.session["auth_type"] = PasswordAdminAuthenticationProvider.NAME response = self.manager.admin_view_controller("collection", "book") assert 200 == response.status_code html = response.get_data(as_text=True) assert 'csrfToken: "%s"' % token in html - assert token in response.headers.get('Set-Cookie') + assert token in response.headers.get("Set-Cookie") def test_tos_link(self): - def assert_tos(expect_href, expect_text): - with self.app.test_request_context('/admin'): - flask.session['admin_email'] = self.admin.email - flask.session['auth_type'] = PasswordAdminAuthenticationProvider.NAME + with self.app.test_request_context("/admin"): + flask.session["admin_email"] = self.admin.email + flask.session["auth_type"] = PasswordAdminAuthenticationProvider.NAME response = self.manager.admin_view_controller("collection", "book") assert 200 == response.status_code html = response.get_data(as_text=True) @@ -248,75 +249,95 @@ def assert_tos(expect_href, expect_text): def test_show_circ_events_download(self): # The local analytics provider will be configured by default if # there isn't one. - with self.app.test_request_context('/admin'): - flask.session['admin_email'] = self.admin.email - flask.session['auth_type'] = PasswordAdminAuthenticationProvider.NAME + with self.app.test_request_context("/admin"): + flask.session["admin_email"] = self.admin.email + flask.session["auth_type"] = PasswordAdminAuthenticationProvider.NAME response = self.manager.admin_view_controller("collection", "book") assert 200 == response.status_code html = response.get_data(as_text=True) - assert 'showCircEventsDownload: true' in html + assert "showCircEventsDownload: true" in html def test_roles(self): self.admin.add_role(AdminRole.SITEWIDE_LIBRARIAN) self.admin.add_role(AdminRole.LIBRARY_MANAGER, self._default_library) - with self.app.test_request_context('/admin'): - flask.session['admin_email'] = self.admin.email - flask.session['auth_type'] = PasswordAdminAuthenticationProvider.NAME + with self.app.test_request_context("/admin"): + flask.session["admin_email"] = self.admin.email + flask.session["auth_type"] = PasswordAdminAuthenticationProvider.NAME response = self.manager.admin_view_controller("collection", "book") assert 200 == response.status_code html = response.get_data(as_text=True) - assert "\"role\": \"librarian-all\"" in html - assert "\"role\": \"manager\", \"library\": \"%s\"" % self._default_library.short_name in html + assert '"role": "librarian-all"' in html + assert ( + '"role": "manager", "library": "%s"' % self._default_library.short_name + in html + ) + class TestAdminCirculationManagerController(AdminControllerTest): def test_require_system_admin(self): - with self.request_context_with_admin('/admin'): - pytest.raises(AdminNotAuthorized, - self.manager.admin_work_controller.require_system_admin) + with self.request_context_with_admin("/admin"): + pytest.raises( + AdminNotAuthorized, + self.manager.admin_work_controller.require_system_admin, + ) self.admin.add_role(AdminRole.SYSTEM_ADMIN) self.manager.admin_work_controller.require_system_admin() def test_require_sitewide_library_manager(self): - with self.request_context_with_admin('/admin'): - pytest.raises(AdminNotAuthorized, - self.manager.admin_work_controller.require_sitewide_library_manager) + with self.request_context_with_admin("/admin"): + pytest.raises( + AdminNotAuthorized, + self.manager.admin_work_controller.require_sitewide_library_manager, + ) self.admin.add_role(AdminRole.SITEWIDE_LIBRARY_MANAGER) self.manager.admin_work_controller.require_sitewide_library_manager() def test_require_library_manager(self): - with self.request_context_with_admin('/admin'): - pytest.raises(AdminNotAuthorized, - self.manager.admin_work_controller.require_library_manager, - self._default_library) + with self.request_context_with_admin("/admin"): + pytest.raises( + AdminNotAuthorized, + self.manager.admin_work_controller.require_library_manager, + self._default_library, + ) self.admin.add_role(AdminRole.LIBRARY_MANAGER, self._default_library) - self.manager.admin_work_controller.require_library_manager(self._default_library) + self.manager.admin_work_controller.require_library_manager( + self._default_library + ) def test_require_librarian(self): - with self.request_context_with_admin('/admin'): - pytest.raises(AdminNotAuthorized, - self.manager.admin_work_controller.require_librarian, - self._default_library) + with self.request_context_with_admin("/admin"): + pytest.raises( + AdminNotAuthorized, + self.manager.admin_work_controller.require_librarian, + self._default_library, + ) self.admin.add_role(AdminRole.LIBRARIAN, self._default_library) self.manager.admin_work_controller.require_librarian(self._default_library) -class TestSignInController(AdminControllerTest): +class TestSignInController(AdminControllerTest): def setup_method(self): super(TestSignInController, self).setup_method() - self.admin.credential = json.dumps({ - 'access_token': 'abc123', - 'client_id': '', 'client_secret': '', - 'refresh_token': '', 'token_expiry': '', 'token_uri': '', - 'user_agent': '', 'invalid': '' - }) + self.admin.credential = json.dumps( + { + "access_token": "abc123", + "client_id": "", + "client_secret": "", + "refresh_token": "", + "token_expiry": "", + "token_uri": "", + "user_agent": "", + "invalid": "", + } + ) self.admin.password_hashed = None def test_admin_auth_providers(self): - with self.app.test_request_context('/admin'): + with self.app.test_request_context("/admin"): ctrl = self.manager.admin_sign_in_controller # An admin exists, but they have no password and there's @@ -325,40 +346,62 @@ def test_admin_auth_providers(self): # The auth service exists. create( - self._db, ExternalIntegration, + self._db, + ExternalIntegration, protocol=ExternalIntegration.GOOGLE_OAUTH, - goal=ExternalIntegration.ADMIN_AUTH_GOAL + goal=ExternalIntegration.ADMIN_AUTH_GOAL, ) assert 1 == len(ctrl.admin_auth_providers) - assert GoogleOAuthAdminAuthenticationProvider.NAME == ctrl.admin_auth_providers[0].NAME + assert ( + GoogleOAuthAdminAuthenticationProvider.NAME + == ctrl.admin_auth_providers[0].NAME + ) # Here's another admin with a password. pw_admin, ignore = create(self._db, Admin, email="pw@nypl.org") pw_admin.password = "password" assert 2 == len(ctrl.admin_auth_providers) - assert (set([GoogleOAuthAdminAuthenticationProvider.NAME, PasswordAdminAuthenticationProvider.NAME]) == - set([provider.NAME for provider in ctrl.admin_auth_providers])) + assert ( + set( + [ + GoogleOAuthAdminAuthenticationProvider.NAME, + PasswordAdminAuthenticationProvider.NAME, + ] + ) + == set([provider.NAME for provider in ctrl.admin_auth_providers]) + ) # Only an admin with a password. self._db.delete(self.admin) assert 2 == len(ctrl.admin_auth_providers) - assert (set([GoogleOAuthAdminAuthenticationProvider.NAME, PasswordAdminAuthenticationProvider.NAME]) == - set([provider.NAME for provider in ctrl.admin_auth_providers])) + assert ( + set( + [ + GoogleOAuthAdminAuthenticationProvider.NAME, + PasswordAdminAuthenticationProvider.NAME, + ] + ) + == set([provider.NAME for provider in ctrl.admin_auth_providers]) + ) # No admins. Someone new could still log in with google if domains are # configured. self._db.delete(pw_admin) assert 1 == len(ctrl.admin_auth_providers) - assert GoogleOAuthAdminAuthenticationProvider.NAME == ctrl.admin_auth_providers[0].NAME + assert ( + GoogleOAuthAdminAuthenticationProvider.NAME + == ctrl.admin_auth_providers[0].NAME + ) def test_admin_auth_provider(self): - with self.app.test_request_context('/admin'): + with self.app.test_request_context("/admin"): ctrl = self.manager.admin_sign_in_controller create( - self._db, ExternalIntegration, + self._db, + ExternalIntegration, protocol=ExternalIntegration.GOOGLE_OAUTH, - goal=ExternalIntegration.ADMIN_AUTH_GOAL + goal=ExternalIntegration.ADMIN_AUTH_GOAL, ) # We can find a google auth provider. @@ -381,49 +424,62 @@ def test_admin_auth_provider(self): def test_authenticated_admin_from_request(self): # Returns an error if there's no admin auth service. - with self.app.test_request_context('/admin'): - flask.session['admin_email'] = self.admin.email - flask.session['auth_type'] = GoogleOAuthAdminAuthenticationProvider.NAME - response = self.manager.admin_sign_in_controller.authenticated_admin_from_request() + with self.app.test_request_context("/admin"): + flask.session["admin_email"] = self.admin.email + flask.session["auth_type"] = GoogleOAuthAdminAuthenticationProvider.NAME + response = ( + self.manager.admin_sign_in_controller.authenticated_admin_from_request() + ) assert ADMIN_AUTH_NOT_CONFIGURED == response # Works once the admin auth service exists. create( - self._db, ExternalIntegration, + self._db, + ExternalIntegration, protocol=ExternalIntegration.GOOGLE_OAUTH, - goal=ExternalIntegration.ADMIN_AUTH_GOAL + goal=ExternalIntegration.ADMIN_AUTH_GOAL, ) - with self.app.test_request_context('/admin'): - flask.session['admin_email'] = self.admin.email - flask.session['auth_type'] = GoogleOAuthAdminAuthenticationProvider.NAME - response = self.manager.admin_sign_in_controller.authenticated_admin_from_request() + with self.app.test_request_context("/admin"): + flask.session["admin_email"] = self.admin.email + flask.session["auth_type"] = GoogleOAuthAdminAuthenticationProvider.NAME + response = ( + self.manager.admin_sign_in_controller.authenticated_admin_from_request() + ) assert self.admin == response # Returns an error if you aren't authenticated. - with self.app.test_request_context('/admin'): + with self.app.test_request_context("/admin"): # You get back a problem detail when you're not authenticated. - response = self.manager.admin_sign_in_controller.authenticated_admin_from_request() + response = ( + self.manager.admin_sign_in_controller.authenticated_admin_from_request() + ) assert 401 == response.status_code assert INVALID_ADMIN_CREDENTIALS.detail == response.detail # Returns an error if the admin email or auth type is missing from the session. - with self.app.test_request_context('/admin'): - flask.session['auth_type'] = GoogleOAuthAdminAuthenticationProvider.NAME - response = self.manager.admin_sign_in_controller.authenticated_admin_from_request() + with self.app.test_request_context("/admin"): + flask.session["auth_type"] = GoogleOAuthAdminAuthenticationProvider.NAME + response = ( + self.manager.admin_sign_in_controller.authenticated_admin_from_request() + ) assert 401 == response.status_code assert INVALID_ADMIN_CREDENTIALS.detail == response.detail - with self.app.test_request_context('/admin'): - flask.session['admin_email'] = self.admin.email - response = self.manager.admin_sign_in_controller.authenticated_admin_from_request() + with self.app.test_request_context("/admin"): + flask.session["admin_email"] = self.admin.email + response = ( + self.manager.admin_sign_in_controller.authenticated_admin_from_request() + ) assert 401 == response.status_code assert INVALID_ADMIN_CREDENTIALS.detail == response.detail # Returns an error if the admin authentication type isn't configured. - with self.app.test_request_context('/admin'): - flask.session['admin_email'] = self.admin.email - flask.session['auth_type'] = PasswordAdminAuthenticationProvider.NAME - response = self.manager.admin_sign_in_controller.authenticated_admin_from_request() + with self.app.test_request_context("/admin"): + flask.session["admin_email"] = self.admin.email + flask.session["auth_type"] = PasswordAdminAuthenticationProvider.NAME + response = ( + self.manager.admin_sign_in_controller.authenticated_admin_from_request() + ) assert 400 == response.status_code assert ADMIN_AUTH_MECHANISM_NOT_CONFIGURED.detail == response.detail @@ -431,32 +487,39 @@ def test_authenticated_admin(self): # Unset the base URL -- it will be set automatically when we # successfully authenticate as an admin. - base_url = ConfigurationSetting.sitewide( - self._db, Configuration.BASE_URL_KEY - ) + base_url = ConfigurationSetting.sitewide(self._db, Configuration.BASE_URL_KEY) base_url.value = None assert None == base_url.value - # Creates a new admin with fresh details. new_admin_details = { - 'email' : 'admin@nypl.org', - 'credentials' : 'gnarly', - 'type': GoogleOAuthAdminAuthenticationProvider.NAME, - 'roles': [{ "role": AdminRole.LIBRARY_MANAGER, "library": self._default_library.short_name }], + "email": "admin@nypl.org", + "credentials": "gnarly", + "type": GoogleOAuthAdminAuthenticationProvider.NAME, + "roles": [ + { + "role": AdminRole.LIBRARY_MANAGER, + "library": self._default_library.short_name, + } + ], } - with self.app.test_request_context('/admin/sign_in?redirect=foo'): + with self.app.test_request_context("/admin/sign_in?redirect=foo"): flask.request.url = "http://chosen-hostname/admin/sign_in?redirect=foo" - admin = self.manager.admin_sign_in_controller.authenticated_admin(new_admin_details) - assert 'admin@nypl.org' == admin.email - assert 'gnarly' == admin.credential + admin = self.manager.admin_sign_in_controller.authenticated_admin( + new_admin_details + ) + assert "admin@nypl.org" == admin.email + assert "gnarly" == admin.credential [role] = admin.roles assert AdminRole.LIBRARY_MANAGER == role.role assert self._default_library == role.library # Also sets up the admin's flask session. assert "admin@nypl.org" == flask.session["admin_email"] - assert GoogleOAuthAdminAuthenticationProvider.NAME == flask.session["auth_type"] + assert ( + GoogleOAuthAdminAuthenticationProvider.NAME + == flask.session["auth_type"] + ) assert True == flask.session.permanent # The first successfully authenticated admin user automatically @@ -465,16 +528,23 @@ def test_authenticated_admin(self): # Or overwrites credentials for an existing admin. existing_admin_details = { - 'email' : 'example@nypl.org', - 'credentials' : 'b-a-n-a-n-a-s', - 'type': GoogleOAuthAdminAuthenticationProvider.NAME, - 'roles': [{ "role": AdminRole.LIBRARY_MANAGER, "library": self._default_library.short_name }], + "email": "example@nypl.org", + "credentials": "b-a-n-a-n-a-s", + "type": GoogleOAuthAdminAuthenticationProvider.NAME, + "roles": [ + { + "role": AdminRole.LIBRARY_MANAGER, + "library": self._default_library.short_name, + } + ], } - with self.app.test_request_context('/admin/sign_in?redirect=foo'): + with self.app.test_request_context("/admin/sign_in?redirect=foo"): flask.request.url = "http://a-different-hostname/" - admin = self.manager.admin_sign_in_controller.authenticated_admin(existing_admin_details) + admin = self.manager.admin_sign_in_controller.authenticated_admin( + existing_admin_details + ) assert self.admin.id == admin.id - assert 'b-a-n-a-n-a-s' == self.admin.credential + assert "b-a-n-a-n-a-s" == self.admin.credential # No roles were created since the admin already existed. assert [] == admin.roles @@ -485,19 +555,20 @@ def test_authenticated_admin(self): def test_admin_signin(self): # Returns an error if there's no admin auth service. - with self.app.test_request_context('/admin/sign_in?redirect=foo'): + with self.app.test_request_context("/admin/sign_in?redirect=foo"): response = self.manager.admin_sign_in_controller.sign_in() assert ADMIN_AUTH_NOT_CONFIGURED == response create( - self._db, ExternalIntegration, + self._db, + ExternalIntegration, protocol=ExternalIntegration.GOOGLE_OAUTH, - goal=ExternalIntegration.ADMIN_AUTH_GOAL + goal=ExternalIntegration.ADMIN_AUTH_GOAL, ) # Shows the login page if there's an auth service # but no signed in admin. - with self.app.test_request_context('/admin/sign_in?redirect=foo'): + with self.app.test_request_context("/admin/sign_in?redirect=foo"): response = self.manager.admin_sign_in_controller.sign_in() assert 200 == response.status_code response_data = response.get_data(as_text=True) @@ -509,7 +580,7 @@ def test_admin_signin(self): # If there are multiple auth providers, the login page # shows them all. self.admin.password = "password" - with self.app.test_request_context('/admin/sign_in?redirect=foo'): + with self.app.test_request_context("/admin/sign_in?redirect=foo"): response = self.manager.admin_sign_in_controller.sign_in() assert 200 == response.status_code response_data = response.get_data(as_text=True) @@ -519,9 +590,9 @@ def test_admin_signin(self): assert "Password" in response_data # Redirects to the redirect parameter if an admin is signed in. - with self.app.test_request_context('/admin/sign_in?redirect=foo'): - flask.session['admin_email'] = self.admin.email - flask.session['auth_type'] = PasswordAdminAuthenticationProvider.NAME + with self.app.test_request_context("/admin/sign_in?redirect=foo"): + flask.session["admin_email"] = self.admin.email + flask.session["auth_type"] = PasswordAdminAuthenticationProvider.NAME response = self.manager.admin_sign_in_controller.sign_in() assert 302 == response.status_code assert "foo" == response.headers["Location"] @@ -530,58 +601,75 @@ def test_redirect_after_google_sign_in(self): self._db.delete(self.admin) # Returns an error if there's no admin auth service. - with self.app.test_request_context('/admin/GoogleOAuth/callback'): - response = self.manager.admin_sign_in_controller.redirect_after_google_sign_in() + with self.app.test_request_context("/admin/GoogleOAuth/callback"): + response = ( + self.manager.admin_sign_in_controller.redirect_after_google_sign_in() + ) assert ADMIN_AUTH_NOT_CONFIGURED == response # Returns an error if the admin auth service isn't google. admin, ignore = create(self._db, Admin, email="admin@nypl.org") admin.password = "password" - with self.app.test_request_context('/admin/GoogleOAuth/callback'): - response = self.manager.admin_sign_in_controller.redirect_after_google_sign_in() + with self.app.test_request_context("/admin/GoogleOAuth/callback"): + response = ( + self.manager.admin_sign_in_controller.redirect_after_google_sign_in() + ) assert ADMIN_AUTH_MECHANISM_NOT_CONFIGURED == response self._db.delete(admin) auth_integration, ignore = create( - self._db, ExternalIntegration, + self._db, + ExternalIntegration, protocol=ExternalIntegration.GOOGLE_OAUTH, - goal=ExternalIntegration.ADMIN_AUTH_GOAL + goal=ExternalIntegration.ADMIN_AUTH_GOAL, ) auth_integration.libraries += [self._default_library] setting = ConfigurationSetting.for_library_and_externalintegration( - self._db, "domains", self._default_library, auth_integration) + self._db, "domains", self._default_library, auth_integration + ) # Returns an error if google oauth fails.. - with self.app.test_request_context('/admin/GoogleOAuth/callback?error=foo'): - response = self.manager.admin_sign_in_controller.redirect_after_google_sign_in() + with self.app.test_request_context("/admin/GoogleOAuth/callback?error=foo"): + response = ( + self.manager.admin_sign_in_controller.redirect_after_google_sign_in() + ) assert 400 == response.status_code # Returns an error if the admin email isn't a staff email. setting.value = json.dumps(["alibrary.org"]) - with self.app.test_request_context('/admin/GoogleOAuth/callback?code=1234&state=foo'): - response = self.manager.admin_sign_in_controller.redirect_after_google_sign_in() + with self.app.test_request_context( + "/admin/GoogleOAuth/callback?code=1234&state=foo" + ): + response = ( + self.manager.admin_sign_in_controller.redirect_after_google_sign_in() + ) assert 401 == response.status_code # Redirects to the state parameter if the admin email is valid. setting.value = json.dumps(["nypl.org"]) - with self.app.test_request_context('/admin/GoogleOAuth/callback?code=1234&state=foo'): - response = self.manager.admin_sign_in_controller.redirect_after_google_sign_in() + with self.app.test_request_context( + "/admin/GoogleOAuth/callback?code=1234&state=foo" + ): + response = ( + self.manager.admin_sign_in_controller.redirect_after_google_sign_in() + ) assert 302 == response.status_code assert "foo" == response.headers["Location"] def test_password_sign_in(self): # Returns an error if there's no admin auth service and no admins. - with self.app.test_request_context('/admin/sign_in_with_password'): + with self.app.test_request_context("/admin/sign_in_with_password"): response = self.manager.admin_sign_in_controller.password_sign_in() assert ADMIN_AUTH_NOT_CONFIGURED == response # Returns an error if the admin auth service isn't password auth. auth_integration, ignore = create( - self._db, ExternalIntegration, + self._db, + ExternalIntegration, protocol=ExternalIntegration.GOOGLE_OAUTH, - goal=ExternalIntegration.ADMIN_AUTH_GOAL + goal=ExternalIntegration.ADMIN_AUTH_GOAL, ) - with self.app.test_request_context('/admin/sign_in_with_password'): + with self.app.test_request_context("/admin/sign_in_with_password"): response = self.manager.admin_sign_in_controller.password_sign_in() assert ADMIN_AUTH_MECHANISM_NOT_CONFIGURED == response @@ -590,33 +678,45 @@ def test_password_sign_in(self): admin.password = "password" # Returns an error if there's no admin with the provided email. - with self.app.test_request_context('/admin/sign_in_with_password', method='POST'): - flask.request.form = MultiDict([ - ("email", "notanadmin@nypl.org"), - ("password", "password"), - ("redirect", "foo") - ]) + with self.app.test_request_context( + "/admin/sign_in_with_password", method="POST" + ): + flask.request.form = MultiDict( + [ + ("email", "notanadmin@nypl.org"), + ("password", "password"), + ("redirect", "foo"), + ] + ) response = self.manager.admin_sign_in_controller.password_sign_in() assert 401 == response.status_code # Returns an error if the password doesn't match. self.admin.password = "password" - with self.app.test_request_context('/admin/sign_in_with_password', method='POST'): - flask.request.form = MultiDict([ - ("email", self.admin.email), - ("password", "notthepassword"), - ("redirect", "foo") - ]) + with self.app.test_request_context( + "/admin/sign_in_with_password", method="POST" + ): + flask.request.form = MultiDict( + [ + ("email", self.admin.email), + ("password", "notthepassword"), + ("redirect", "foo"), + ] + ) response = self.manager.admin_sign_in_controller.password_sign_in() assert 401 == response.status_code # Redirects if the admin email/password combination is valid. - with self.app.test_request_context('/admin/sign_in_with_password', method='POST'): - flask.request.form = MultiDict([ - ("email", self.admin.email), - ("password", "password"), - ("redirect", "foo") - ]) + with self.app.test_request_context( + "/admin/sign_in_with_password", method="POST" + ): + flask.request.form = MultiDict( + [ + ("email", self.admin.email), + ("password", "password"), + ("redirect", "foo"), + ] + ) response = self.manager.admin_sign_in_controller.password_sign_in() assert 302 == response.status_code assert "foo" == response.headers["Location"] @@ -624,10 +724,12 @@ def test_password_sign_in(self): def test_change_password(self): admin, ignore = create(self._db, Admin, email=self._str) admin.password = "old" - with self.request_context_with_admin('/admin/change_password', admin=admin): - flask.request.form = MultiDict([ - ("password", "new"), - ]) + with self.request_context_with_admin("/admin/change_password", admin=admin): + flask.request.form = MultiDict( + [ + ("password", "new"), + ] + ) response = self.manager.admin_sign_in_controller.change_password() assert 200 == response.status_code assert admin == Admin.authenticate(self._db, admin.email, "new") @@ -636,7 +738,7 @@ def test_change_password(self): def test_sign_out(self): admin, ignore = create(self._db, Admin, email=self._str) admin.password = "pass" - with self.app.test_request_context('/admin/sign_out'): + with self.app.test_request_context("/admin/sign_out"): flask.session["admin_email"] = admin.email flask.session["auth_type"] = PasswordAdminAuthenticationProvider.NAME response = self.manager.admin_sign_in_controller.sign_out() @@ -654,6 +756,7 @@ def setup_method(self): def test__load_patrondata(self): """Test the _load_patrondata helper method.""" + class MockAuthenticator(object): def __init__(self, providers): self.providers = providers @@ -690,8 +793,10 @@ def remote_patron_lookup(self, patrondata): assert 404 == response.status_code assert NO_SUCH_PATRON.uri == response.uri - assert ("This library has no authentication providers, so it has no patrons." == - response.detail) + assert ( + "This library has no authentication providers, so it has no patrons." + == response.detail + ) # Authenticator can't find patron with this identifier authenticator.providers.append(auth_provider) @@ -701,8 +806,10 @@ def remote_patron_lookup(self, patrondata): assert 404 == response.status_code assert NO_SUCH_PATRON.uri == response.uri - assert ("No patron with identifier %s was found at your library" % identifier == - response.detail) + assert ( + "No patron with identifier %s was found at your library" % identifier + == response.detail + ) def test_lookup_patron(self): @@ -732,8 +839,8 @@ def _load_patrondata(self, authenticator): # _load_patrondata() returned a PatronData object. We # converted it to a dictionary, which will be dumped to # JSON on the way out. - assert "An Identifier" == response['authorization_identifier'] - assert "A Patron" == response['personal_name'] + assert "An Identifier" == response["authorization_identifier"] + assert "A Patron" == response["personal_name"] def test_reset_adobe_id(self): # Here's a patron with two Adobe-relevant credentials. @@ -751,6 +858,7 @@ def test_reset_adobe_id(self): # PatronData object, no matter what is asked for. class MockPatronController(PatronController): mock_patrondata = None + def _load_patrondata(self, authenticator): self.called_with = authenticator return self.mock_patrondata @@ -787,8 +895,8 @@ def _load_patrondata(self, authenticator): assert NO_SUCH_PATRON.uri == response.uri assert "Could not create local patron object" in response.detail -class TestTimestampsController(AdminControllerTest): +class TestTimestampsController(AdminControllerTest): def setup_method(self): super(TestTimestampsController, self).setup_method() for timestamp in self._db.query(Timestamp): @@ -799,26 +907,29 @@ def setup_method(self): self.finish = utc_now() cp, ignore = create( - self._db, Timestamp, + self._db, + Timestamp, service_type="coverage_provider", service="test_cp", start=self.start, finish=self.finish, - collection=self.collection + collection=self.collection, ) monitor, ignore = create( - self._db, Timestamp, + self._db, + Timestamp, service_type="monitor", service="test_monitor", start=self.start, finish=self.finish, collection=self.collection, - exception="stack trace string" + exception="stack trace string", ) script, ignore = create( - self._db, Timestamp, + self._db, + Timestamp, achievements="ran a script", service_type="script", service="test_script", @@ -827,7 +938,8 @@ def setup_method(self): ) other, ignore = create( - self._db, Timestamp, + self._db, + Timestamp, service="test_other", start=self.start, finish=self.finish, @@ -835,7 +947,9 @@ def setup_method(self): def test_diagnostics_admin_not_authorized(self): with self.request_context_with_admin("/"): - pytest.raises(AdminNotAuthorized, self.manager.timestamps_controller.diagnostics) + pytest.raises( + AdminNotAuthorized, self.manager.timestamps_controller.diagnostics + ) def test_diagnostics(self): duration = (self.finish - self.start).total_seconds() @@ -844,7 +958,9 @@ def test_diagnostics(self): self.admin.add_role(AdminRole.SYSTEM_ADMIN) response = self.manager.timestamps_controller.diagnostics() - assert set(response.keys()) == set(["coverage_provider", "monitor", "script", "other"]) + assert set(response.keys()) == set( + ["coverage_provider", "monitor", "script", "other"] + ) cp_service = response["coverage_provider"] cp_name, cp_collection = list(cp_service.items())[0] @@ -859,7 +975,9 @@ def test_diagnostics(self): monitor_service = response["monitor"] monitor_name, monitor_collection = list(monitor_service.items())[0] assert monitor_name == "test_monitor" - monitor_collection_name, [monitor_timestamp] = list(monitor_collection.items())[0] + monitor_collection_name, [monitor_timestamp] = list(monitor_collection.items())[ + 0 + ] assert monitor_collection_name == self.collection.name assert monitor_timestamp.get("exception") == "stack trace string" assert monitor_timestamp.get("start") == self.start @@ -886,8 +1004,8 @@ def test_diagnostics(self): assert other_timestamp.get("start") == self.start assert other_timestamp.get("achievements") == None -class TestFeedController(AdminControllerTest): +class TestFeedController(AdminControllerTest): def setup_method(self): super(TestFeedController, self).setup_method() self.admin.add_role(AdminRole.LIBRARIAN, self._default_library) @@ -901,39 +1019,36 @@ def test_complaints(self): "fiction work with complaint 1", language="eng", fiction=True, - with_open_access_download=True) + with_open_access_download=True, + ) complaint1 = self._complaint( - work1.license_pools[0], - type1, - "complaint source 1", - "complaint detail 1") + work1.license_pools[0], type1, "complaint source 1", "complaint detail 1" + ) complaint2 = self._complaint( - work1.license_pools[0], - type2, - "complaint source 2", - "complaint detail 2") + work1.license_pools[0], type2, "complaint source 2", "complaint detail 2" + ) work2 = self._work( "nonfiction work with complaint", language="eng", fiction=False, - with_open_access_download=True) + with_open_access_download=True, + ) complaint3 = self._complaint( - work2.license_pools[0], - type1, - "complaint source 3", - "complaint detail 3") + work2.license_pools[0], type1, "complaint source 3", "complaint detail 3" + ) with self.request_context_with_library_and_admin("/"): response = self.manager.admin_feed_controller.complaints() feed = feedparser.parse(response.get_data(as_text=True)) - entries = feed['entries'] + entries = feed["entries"] assert len(entries) == 2 self.admin.remove_role(AdminRole.LIBRARIAN, self._default_library) with self.request_context_with_library_and_admin("/"): - pytest.raises(AdminNotAuthorized, - self.manager.admin_feed_controller.complaints) + pytest.raises( + AdminNotAuthorized, self.manager.admin_feed_controller.complaints + ) def test_suppressed(self): suppressed_work = self._work(with_open_access_download=True) @@ -944,14 +1059,15 @@ def test_suppressed(self): with self.request_context_with_library_and_admin("/"): response = self.manager.admin_feed_controller.suppressed() feed = feedparser.parse(response.get_data(as_text=True)) - entries = feed['entries'] + entries = feed["entries"] assert 1 == len(entries) - assert suppressed_work.title == entries[0]['title'] + assert suppressed_work.title == entries[0]["title"] self.admin.remove_role(AdminRole.LIBRARIAN, self._default_library) with self.request_context_with_library_and_admin("/"): - pytest.raises(AdminNotAuthorized, - self.manager.admin_feed_controller.suppressed) + pytest.raises( + AdminNotAuthorized, self.manager.admin_feed_controller.suppressed + ) def test_genres(self): with self.app.test_request_context("/"): @@ -959,11 +1075,16 @@ def test_genres(self): for name in genres: top = "Fiction" if genres[name].is_fiction else "Nonfiction" - assert response[top][name] == dict({ - "name": name, - "parents": [parent.name for parent in genres[name].parents], - "subgenres": [subgenre.name for subgenre in genres[name].subgenres] - }) + assert response[top][name] == dict( + { + "name": name, + "parents": [parent.name for parent in genres[name].parents], + "subgenres": [ + subgenre.name for subgenre in genres[name].subgenres + ], + } + ) + class TestCustomListsController(AdminControllerTest): def setup_method(self): @@ -974,13 +1095,17 @@ def test_custom_lists_get(self): # This list has no associated Library and should not be included. no_library, ignore = create(self._db, CustomList, name=self._str) - one_entry, ignore = create(self._db, CustomList, name=self._str, library=self._default_library) + one_entry, ignore = create( + self._db, CustomList, name=self._str, library=self._default_library + ) edition = self._edition() one_entry.add_entry(edition) collection = self._collection() collection.customlists = [one_entry] - no_entries, ignore = create(self._db, CustomList, name=self._str, library=self._default_library) + no_entries, ignore = create( + self._db, CustomList, name=self._str, library=self._default_library + ) with self.request_context_with_library_and_admin("/"): response = self.manager.admin_custom_lists_controller.custom_lists() @@ -1004,53 +1129,85 @@ def test_custom_lists_get(self): self.admin.remove_role(AdminRole.LIBRARIAN, self._default_library) with self.request_context_with_library_and_admin("/"): - pytest.raises(AdminNotAuthorized, - self.manager.admin_custom_lists_controller.custom_lists) + pytest.raises( + AdminNotAuthorized, + self.manager.admin_custom_lists_controller.custom_lists, + ) def test_custom_lists_post_errors(self): - with self.request_context_with_library_and_admin("/", method='POST'): - flask.request.form = MultiDict([ - ("id", "4"), - ("name", "name"), - ]) + with self.request_context_with_library_and_admin("/", method="POST"): + flask.request.form = MultiDict( + [ + ("id", "4"), + ("name", "name"), + ] + ) response = self.manager.admin_custom_lists_controller.custom_lists() assert MISSING_CUSTOM_LIST == response library = self._library() data_source = DataSource.lookup(self._db, DataSource.LIBRARY_STAFF) - list, ignore = create(self._db, CustomList, name=self._str, data_source=data_source) + list, ignore = create( + self._db, CustomList, name=self._str, data_source=data_source + ) list.library = library - with self.request_context_with_library_and_admin("/", method='POST'): - flask.request.form = MultiDict([ - ("id", list.id), - ("name", list.name), - ]) + with self.request_context_with_library_and_admin("/", method="POST"): + flask.request.form = MultiDict( + [ + ("id", list.id), + ("name", list.name), + ] + ) response = self.manager.admin_custom_lists_controller.custom_lists() assert CANNOT_CHANGE_LIBRARY_FOR_CUSTOM_LIST == response - list, ignore = create(self._db, CustomList, name=self._str, data_source=data_source, library=self._default_library) - with self.request_context_with_library_and_admin("/", method='POST'): - flask.request.form = MultiDict([ - ("name", list.name), - ]) + list, ignore = create( + self._db, + CustomList, + name=self._str, + data_source=data_source, + library=self._default_library, + ) + with self.request_context_with_library_and_admin("/", method="POST"): + flask.request.form = MultiDict( + [ + ("name", list.name), + ] + ) response = self.manager.admin_custom_lists_controller.custom_lists() assert CUSTOM_LIST_NAME_ALREADY_IN_USE == response - l1, ignore = create(self._db, CustomList, name=self._str, data_source=data_source, library=self._default_library) - l2, ignore = create(self._db, CustomList, name=self._str, data_source=data_source, library=self._default_library) - with self.request_context_with_library_and_admin("/", method='POST'): - flask.request.form = MultiDict([ - ("id", l2.id), - ("name", l1.name), - ]) + l1, ignore = create( + self._db, + CustomList, + name=self._str, + data_source=data_source, + library=self._default_library, + ) + l2, ignore = create( + self._db, + CustomList, + name=self._str, + data_source=data_source, + library=self._default_library, + ) + with self.request_context_with_library_and_admin("/", method="POST"): + flask.request.form = MultiDict( + [ + ("id", l2.id), + ("name", l1.name), + ] + ) response = self.manager.admin_custom_lists_controller.custom_lists() assert CUSTOM_LIST_NAME_ALREADY_IN_USE == response - with self.request_context_with_library_and_admin("/", method='POST'): - flask.request.form = MultiDict([ - ("name", "name"), - ("collections", json.dumps([12345])), - ]) + with self.request_context_with_library_and_admin("/", method="POST"): + flask.request.form = MultiDict( + [ + ("name", "name"), + ("collections", json.dumps([12345])), + ] + ) response = self.manager.admin_custom_lists_controller.custom_lists() assert MISSING_COLLECTION == response @@ -1058,21 +1215,27 @@ def test_custom_lists_post_errors(self): library = self._library() with self.request_context_with_admin("/", method="POST", admin=admin): flask.request.library = library - flask.request.form = MultiDict([ - ("name", "name"), - ("collections", json.dumps([])), - ]) - pytest.raises(AdminNotAuthorized, - self.manager.admin_custom_lists_controller.custom_lists) + flask.request.form = MultiDict( + [ + ("name", "name"), + ("collections", json.dumps([])), + ] + ) + pytest.raises( + AdminNotAuthorized, + self.manager.admin_custom_lists_controller.custom_lists, + ) def test_custom_lists_post_collection_with_wrong_library(self): # This collection is not associated with any libraries. collection = self._collection() - with self.request_context_with_library_and_admin("/", method='POST'): - flask.request.form = MultiDict([ - ("name", "name"), - ("collections", json.dumps([collection.id])), - ]) + with self.request_context_with_library_and_admin("/", method="POST"): + flask.request.form = MultiDict( + [ + ("name", "name"), + ("collections", json.dumps([collection.id])), + ] + ) response = self.manager.admin_custom_lists_controller.custom_lists() assert COLLECTION_NOT_ASSOCIATED_WITH_LIBRARY == response @@ -1082,11 +1245,18 @@ def test_custom_lists_create(self): collection.libraries = [self._default_library] with self.request_context_with_library_and_admin("/", method="POST"): - flask.request.form = MultiDict([ - ("name", "List"), - ("entries", json.dumps([dict(id=work.presentation_edition.primary_identifier.urn)])), - ("collections", json.dumps([collection.id])), - ]) + flask.request.form = MultiDict( + [ + ("name", "List"), + ( + "entries", + json.dumps( + [dict(id=work.presentation_edition.primary_identifier.urn)] + ), + ), + ("collections", json.dumps([collection.id])), + ] + ) response = self.manager.admin_custom_lists_controller.custom_lists() assert 201 == response.status_code @@ -1103,7 +1273,13 @@ def test_custom_lists_create(self): def test_custom_list_get(self): data_source = DataSource.lookup(self._db, DataSource.LIBRARY_STAFF) - list, ignore = create(self._db, CustomList, name=self._str, library=self._default_library, data_source=data_source) + list, ignore = create( + self._db, + CustomList, + name=self._str, + library=self._default_library, + data_source=data_source, + ) work1 = self._work(with_license_pool=True) work2 = self._work(with_license_pool=True) @@ -1117,8 +1293,9 @@ def test_custom_list_get(self): assert list.name == feed.feed.title assert 2 == len(feed.entries) - [self_custom_list_link] = [x['href'] for x in feed.feed['links'] - if x['rel'] == "self"] + [self_custom_list_link] = [ + x["href"] for x in feed.feed["links"] if x["rel"] == "self" + ] assert self_custom_list_link == feed.feed.id [entry1, entry2] = feed.entries @@ -1134,17 +1311,27 @@ def test_custom_list_get_errors(self): assert MISSING_CUSTOM_LIST == response data_source = DataSource.lookup(self._db, DataSource.LIBRARY_STAFF) - list, ignore = create(self._db, CustomList, name=self._str, library=self._default_library, data_source=data_source) + list, ignore = create( + self._db, + CustomList, + name=self._str, + library=self._default_library, + data_source=data_source, + ) self.admin.remove_role(AdminRole.LIBRARIAN, self._default_library) with self.request_context_with_library_and_admin("/"): - pytest.raises(AdminNotAuthorized, - self.manager.admin_custom_lists_controller.custom_list, - list.id) + pytest.raises( + AdminNotAuthorized, + self.manager.admin_custom_lists_controller.custom_list, + list.id, + ) def test_custom_list_edit(self): data_source = DataSource.lookup(self._db, DataSource.LIBRARY_STAFF) - list, ignore = create(self._db, CustomList, name=self._str, data_source=data_source) + list, ignore = create( + self._db, CustomList, name=self._str, data_source=data_source + ) list.library = self._default_library # Create a Lane that depends on this CustomList for its membership. @@ -1158,7 +1345,9 @@ def test_custom_list_edit(self): w2 = self._work(title="Bravo", with_license_pool=True, language="fre") w3 = self._work(title="Charlie", with_license_pool=True) w2.presentation_edition.medium = Edition.AUDIO_MEDIUM - w3.presentation_edition.permanent_work_id = w2.presentation_edition.permanent_work_id + w3.presentation_edition.permanent_work_id = ( + w2.presentation_edition.permanent_work_id + ) w3.presentation_edition.medium = Edition.BOOK_MEDIUM list.add_entry(w1) @@ -1167,17 +1356,31 @@ def test_custom_list_edit(self): # All three works should be indexed, but only w1 and w2 should be related to the list assert len(self.controller.search_engine.docs) == 3 - currently_indexed_on_list = [v['title'] for (k, v) - in self.controller.search_engine.docs.items() - if v['customlists'] is not None] - assert sorted(currently_indexed_on_list) == ['Alpha', 'Bravo'] - - new_entries = [dict(id=work.presentation_edition.primary_identifier.urn, - medium=Edition.medium_to_additional_type[work.presentation_edition.medium]) - for work in [w2, w3]] - deletedEntries = [dict(id=work.presentation_edition.primary_identifier.urn, - medium=Edition.medium_to_additional_type[work.presentation_edition.medium]) - for work in [w1]] + currently_indexed_on_list = [ + v["title"] + for (k, v) in self.controller.search_engine.docs.items() + if v["customlists"] is not None + ] + assert sorted(currently_indexed_on_list) == ["Alpha", "Bravo"] + + new_entries = [ + dict( + id=work.presentation_edition.primary_identifier.urn, + medium=Edition.medium_to_additional_type[ + work.presentation_edition.medium + ], + ) + for work in [w2, w3] + ] + deletedEntries = [ + dict( + id=work.presentation_edition.primary_identifier.urn, + medium=Edition.medium_to_additional_type[ + work.presentation_edition.medium + ], + ) + for work in [w1] + ] c1 = self._collection() c1.libraries = [self._default_library] @@ -1192,30 +1395,33 @@ def test_custom_list_edit(self): assert lane.size == 350 with self.request_context_with_library_and_admin("/", method="POST"): - flask.request.form = MultiDict([ - ("id", str(list.id)), - ("name", "new name"), - ("entries", json.dumps(new_entries)), - ("deletedEntries", json.dumps(deletedEntries)), - ("collections", json.dumps([c.id for c in new_collections])), - ]) + flask.request.form = MultiDict( + [ + ("id", str(list.id)), + ("name", "new name"), + ("entries", json.dumps(new_entries)), + ("deletedEntries", json.dumps(deletedEntries)), + ("collections", json.dumps([c.id for c in new_collections])), + ] + ) response = self.manager.admin_custom_lists_controller.custom_list(list.id) # The works associated with the list in ES should have changed, though the total # number of documents in the index should be the same. assert len(self.controller.search_engine.docs) == 3 - currently_indexed_on_list = [v['title'] for (k, v) - in self.controller.search_engine.docs.items() - if v['customlists'] is not None] - assert sorted(currently_indexed_on_list) == ['Bravo', 'Charlie'] + currently_indexed_on_list = [ + v["title"] + for (k, v) in self.controller.search_engine.docs.items() + if v["customlists"] is not None + ] + assert sorted(currently_indexed_on_list) == ["Bravo", "Charlie"] assert 200 == response.status_code assert list.id == int(response.get_data(as_text=True)) assert "new name" == list.name - assert (set([w2, w3]) == - set([entry.work for entry in list.entries])) + assert set([w2, w3]) == set([entry.work for entry in list.entries]) assert new_collections == list.collections # If we were using a real search engine instance, the lane's size would be set @@ -1228,15 +1434,19 @@ def test_custom_list_edit(self): self.admin.remove_role(AdminRole.LIBRARIAN, self._default_library) with self.request_context_with_library_and_admin("/", method="POST"): - flask.request.form = MultiDict([ - ("id", str(list.id)), - ("name", "another new name"), - ("entries", json.dumps(new_entries)), - ("collections", json.dumps([c.id for c in new_collections])), - ]) - pytest.raises(AdminNotAuthorized, - self.manager.admin_custom_lists_controller.custom_list, - list.id) + flask.request.form = MultiDict( + [ + ("id", str(list.id)), + ("name", "another new name"), + ("entries", json.dumps(new_entries)), + ("collections", json.dumps([c.id for c in new_collections])), + ] + ) + pytest.raises( + AdminNotAuthorized, + self.manager.admin_custom_lists_controller.custom_list, + list.id, + ) def test_custom_list_delete_success(self): self.admin.add_role(AdminRole.LIBRARY_MANAGER, self._default_library) @@ -1260,9 +1470,7 @@ def test_custom_list_delete_success(self): # Create a second CustomList, from another data source, # containing a single work. nyt = DataSource.lookup(self._db, DataSource.NYT) - list2, ignore = create( - self._db, CustomList, name=self._str, data_source=nyt - ) + list2, ignore = create(self._db, CustomList, name=self._str, data_source=nyt) list2.library = self._default_library list2.add_entry(w2) @@ -1316,11 +1524,15 @@ def test_custom_list_delete_success(self): def test_custom_list_delete_errors(self): data_source = DataSource.lookup(self._db, DataSource.LIBRARY_STAFF) - list, ignore = create(self._db, CustomList, name=self._str, data_source=data_source) + list, ignore = create( + self._db, CustomList, name=self._str, data_source=data_source + ) with self.request_context_with_library_and_admin("/", method="DELETE"): - pytest.raises(AdminNotAuthorized, - self.manager.admin_custom_lists_controller.custom_list, - list.id) + pytest.raises( + AdminNotAuthorized, + self.manager.admin_custom_lists_controller.custom_list, + list.id, + ) self.admin.add_role(AdminRole.LIBRARY_MANAGER, self._default_library) with self.request_context_with_library_and_admin("/", method="DELETE"): @@ -1341,10 +1553,14 @@ def test_lanes_get(self): english = self._lane("English", library=library, languages=["eng"]) english.priority = 0 english.size = 44 - english_fiction = self._lane("Fiction", library=library, parent=english, fiction=True) + english_fiction = self._lane( + "Fiction", library=library, parent=english, fiction=True + ) english_fiction.visible = False english_fiction.size = 33 - english_sf = self._lane("Science Fiction", library=library, parent=english_fiction) + english_sf = self._lane( + "Science Fiction", library=library, parent=english_fiction + ) english_sf.add_genre("Science Fiction") english_sf.inherit_parent_restrictions = True english_sf.size = 22 @@ -1352,10 +1568,19 @@ def test_lanes_get(self): spanish.priority = 1 spanish.size = 11 - w1 = self._work(with_license_pool=True, language="eng", genre="Science Fiction", collection=collection) - w2 = self._work(with_license_pool=True, language="eng", fiction=False, collection=collection) + w1 = self._work( + with_license_pool=True, + language="eng", + genre="Science Fiction", + collection=collection, + ) + w2 = self._work( + with_license_pool=True, language="eng", fiction=False, collection=collection + ) - list, ignore = self._customlist(data_source_name=DataSource.LIBRARY_STAFF, num_entries=0) + list, ignore = self._customlist( + data_source_name=DataSource.LIBRARY_STAFF, num_entries=0 + ) list.library = library lane_for_list = self._lane("List Lane", library=library) lane_for_list.customlists += [list] @@ -1410,92 +1635,110 @@ def test_lanes_get(self): assert True == list_info.get("inherit_parent_restrictions") def test_lanes_post_errors(self): - with self.request_context_with_library_and_admin("/", method='POST'): - flask.request.form = MultiDict([ - ]) + with self.request_context_with_library_and_admin("/", method="POST"): + flask.request.form = MultiDict([]) response = self.manager.admin_lanes_controller.lanes() assert NO_DISPLAY_NAME_FOR_LANE == response - with self.request_context_with_library_and_admin("/", method='POST'): - flask.request.form = MultiDict([ - ("display_name", "lane"), - ]) + with self.request_context_with_library_and_admin("/", method="POST"): + flask.request.form = MultiDict( + [ + ("display_name", "lane"), + ] + ) response = self.manager.admin_lanes_controller.lanes() assert NO_CUSTOM_LISTS_FOR_LANE == response - list, ignore = self._customlist(data_source_name=DataSource.LIBRARY_STAFF, num_entries=0) + list, ignore = self._customlist( + data_source_name=DataSource.LIBRARY_STAFF, num_entries=0 + ) list.library = self._default_library with self.request_context_with_library_and_admin("/", method="POST"): - flask.request.form = MultiDict([ - ("id", "12345"), - ("display_name", "lane"), - ("custom_list_ids", json.dumps([list.id])), - ]) + flask.request.form = MultiDict( + [ + ("id", "12345"), + ("display_name", "lane"), + ("custom_list_ids", json.dumps([list.id])), + ] + ) response = self.manager.admin_lanes_controller.lanes() assert MISSING_LANE == response library = self._library() - with self.request_context_with_library_and_admin("/", method='POST'): + with self.request_context_with_library_and_admin("/", method="POST"): flask.request.library = library - flask.request.form = MultiDict([ - ("display_name", "lane"), - ("custom_list_ids", json.dumps([list.id])), - ]) - pytest.raises(AdminNotAuthorized, - self.manager.admin_lanes_controller.lanes) + flask.request.form = MultiDict( + [ + ("display_name", "lane"), + ("custom_list_ids", json.dumps([list.id])), + ] + ) + pytest.raises(AdminNotAuthorized, self.manager.admin_lanes_controller.lanes) lane1 = self._lane("lane1") lane2 = self._lane("lane2") with self.request_context_with_library_and_admin("/", method="POST"): - flask.request.form = MultiDict([ - ("id", lane1.id), - ("display_name", "lane1"), - ("custom_list_ids", json.dumps([list.id])), - ]) + flask.request.form = MultiDict( + [ + ("id", lane1.id), + ("display_name", "lane1"), + ("custom_list_ids", json.dumps([list.id])), + ] + ) response = self.manager.admin_lanes_controller.lanes() assert CANNOT_EDIT_DEFAULT_LANE == response lane1.customlists += [list] with self.request_context_with_library_and_admin("/", method="POST"): - flask.request.form = MultiDict([ - ("id", lane1.id), - ("display_name", "lane2"), - ("custom_list_ids", json.dumps([list.id])), - ]) + flask.request.form = MultiDict( + [ + ("id", lane1.id), + ("display_name", "lane2"), + ("custom_list_ids", json.dumps([list.id])), + ] + ) response = self.manager.admin_lanes_controller.lanes() assert LANE_WITH_PARENT_AND_DISPLAY_NAME_ALREADY_EXISTS == response with self.request_context_with_library_and_admin("/", method="POST"): - flask.request.form = MultiDict([ - ("display_name", "lane2"), - ("custom_list_ids", json.dumps([list.id])), - ]) + flask.request.form = MultiDict( + [ + ("display_name", "lane2"), + ("custom_list_ids", json.dumps([list.id])), + ] + ) response = self.manager.admin_lanes_controller.lanes() assert LANE_WITH_PARENT_AND_DISPLAY_NAME_ALREADY_EXISTS == response with self.request_context_with_library_and_admin("/", method="POST"): - flask.request.form = MultiDict([ - ("parent_id", "12345"), - ("display_name", "lane"), - ("custom_list_ids", json.dumps([list.id])), - ]) + flask.request.form = MultiDict( + [ + ("parent_id", "12345"), + ("display_name", "lane"), + ("custom_list_ids", json.dumps([list.id])), + ] + ) response = self.manager.admin_lanes_controller.lanes() assert MISSING_LANE.uri == response.uri with self.request_context_with_library_and_admin("/", method="POST"): - flask.request.form = MultiDict([ - ("parent_id", lane1.id), - ("display_name", "lane"), - ("custom_list_ids", json.dumps(["12345"])), - ]) + flask.request.form = MultiDict( + [ + ("parent_id", lane1.id), + ("display_name", "lane"), + ("custom_list_ids", json.dumps(["12345"])), + ] + ) response = self.manager.admin_lanes_controller.lanes() assert MISSING_CUSTOM_LIST.uri == response.uri def test_lanes_create(self): - list, ignore = self._customlist(data_source_name=DataSource.LIBRARY_STAFF, num_entries=0) + list, ignore = self._customlist( + data_source_name=DataSource.LIBRARY_STAFF, num_entries=0 + ) list.library = self._default_library # The new lane's parent has a sublane already. @@ -1503,16 +1746,18 @@ def test_lanes_create(self): sibling = self._lane("sibling", parent=parent) with self.request_context_with_library_and_admin("/", method="POST"): - flask.request.form = MultiDict([ - ("parent_id", parent.id), - ("display_name", "lane"), - ("custom_list_ids", json.dumps([list.id])), - ("inherit_parent_restrictions", "false"), - ]) + flask.request.form = MultiDict( + [ + ("parent_id", parent.id), + ("display_name", "lane"), + ("custom_list_ids", json.dumps([list.id])), + ("inherit_parent_restrictions", "false"), + ] + ) response = self.manager.admin_lanes_controller.lanes() assert 201 == response.status_code - [lane] = self._db.query(Lane).filter(Lane.display_name=="lane") + [lane] = self._db.query(Lane).filter(Lane.display_name == "lane") assert lane.id == int(response.get_data(as_text=True)) assert self._default_library == lane.library assert "lane" == lane.display_name @@ -1530,9 +1775,13 @@ def test_lanes_edit(self): work = self._work(with_license_pool=True) - list1, ignore = self._customlist(data_source_name=DataSource.LIBRARY_STAFF, num_entries=0) + list1, ignore = self._customlist( + data_source_name=DataSource.LIBRARY_STAFF, num_entries=0 + ) list1.library = self._default_library - list2, ignore = self._customlist(data_source_name=DataSource.LIBRARY_STAFF, num_entries=0) + list2, ignore = self._customlist( + data_source_name=DataSource.LIBRARY_STAFF, num_entries=0 + ) list2.library = self._default_library list2.add_entry(work) @@ -1546,12 +1795,14 @@ def test_lanes_edit(self): self.controller.search_engine.docs = dict(id1="value1", id2="value2") with self.request_context_with_library_and_admin("/", method="POST"): - flask.request.form = MultiDict([ - ("id", str(lane.id)), - ("display_name", "new name"), - ("custom_list_ids", json.dumps([list2.id])), - ("inherit_parent_restrictions", "true"), - ]) + flask.request.form = MultiDict( + [ + ("id", str(lane.id)), + ("display_name", "new name"), + ("custom_list_ids", json.dumps([list2.id])), + ("inherit_parent_restrictions", "true"), + ] + ) response = self.manager.admin_lanes_controller.lanes() assert 200 == response.status_code @@ -1567,10 +1818,12 @@ def test_lane_delete_success(self): library = self._library() self.admin.add_role(AdminRole.LIBRARY_MANAGER, library) lane = self._lane("lane", library=library) - list, ignore = self._customlist(data_source_name=DataSource.LIBRARY_STAFF, num_entries=0) + list, ignore = self._customlist( + data_source_name=DataSource.LIBRARY_STAFF, num_entries=0 + ) list.library = library lane.customlists += [list] - assert 1 == self._db.query(Lane).filter(Lane.library==library).count() + assert 1 == self._db.query(Lane).filter(Lane.library == library).count() with self.request_context_with_library_and_admin("/", method="DELETE"): flask.request.library = library @@ -1578,10 +1831,15 @@ def test_lane_delete_success(self): assert 200 == response.status_code # The lane has been deleted. - assert 0 == self._db.query(Lane).filter(Lane.library==library).count() + assert 0 == self._db.query(Lane).filter(Lane.library == library).count() # The custom list still exists though. - assert 1 == self._db.query(CustomList).filter(CustomList.library==library).count() + assert ( + 1 + == self._db.query(CustomList) + .filter(CustomList.library == library) + .count() + ) lane = self._lane("lane", library=library) lane.customlists += [list] @@ -1589,7 +1847,7 @@ def test_lane_delete_success(self): child.customlists += [list] grandchild = self._lane("grandchild", parent=child, library=library) grandchild.customlists += [list] - assert 3 == self._db.query(Lane).filter(Lane.library==library).count() + assert 3 == self._db.query(Lane).filter(Lane.library == library).count() with self.request_context_with_library_and_admin("/", method="DELETE"): flask.request.library = library @@ -1597,10 +1855,15 @@ def test_lane_delete_success(self): assert 200 == response.status_code # The lanes have all been deleted. - assert 0 == self._db.query(Lane).filter(Lane.library==library).count() + assert 0 == self._db.query(Lane).filter(Lane.library == library).count() # The custom list still exists though. - assert 1 == self._db.query(CustomList).filter(CustomList.library==library).count() + assert ( + 1 + == self._db.query(CustomList) + .filter(CustomList.library == library) + .count() + ) def test_lane_delete_errors(self): with self.request_context_with_library_and_admin("/", method="DELETE"): @@ -1611,9 +1874,9 @@ def test_lane_delete_errors(self): library = self._library() with self.request_context_with_library_and_admin("/", method="DELETE"): flask.request.library = library - pytest.raises(AdminNotAuthorized, - self.manager.admin_lanes_controller.lane, - lane.id) + pytest.raises( + AdminNotAuthorized, self.manager.admin_lanes_controller.lane, lane.id + ) with self.request_context_with_library_and_admin("/", method="DELETE"): response = self.manager.admin_lanes_controller.lane(lane.id) @@ -1643,9 +1906,11 @@ def test_show_lane_errors(self): self.admin.remove_role(AdminRole.LIBRARY_MANAGER, self._default_library) with self.request_context_with_library_and_admin("/"): - pytest.raises(AdminNotAuthorized, - self.manager.admin_lanes_controller.show_lane, - parent.id) + pytest.raises( + AdminNotAuthorized, + self.manager.admin_lanes_controller.show_lane, + parent.id, + ) def test_hide_lane_success(self): lane = self._lane("lane") @@ -1663,9 +1928,11 @@ def test_hide_lane_errors(self): lane = self._lane() self.admin.remove_role(AdminRole.LIBRARY_MANAGER, self._default_library) with self.request_context_with_library_and_admin("/"): - pytest.raises(AdminNotAuthorized, - self.manager.admin_lanes_controller.show_lane, - lane.id) + pytest.raises( + AdminNotAuthorized, + self.manager.admin_lanes_controller.show_lane, + lane.id, + ) def test_reset(self): library = self._library() @@ -1680,10 +1947,16 @@ def test_reset(self): assert 200 == response.status_code # The old lane is gone. - assert 0 == self._db.query(Lane).filter(Lane.library==library).filter(Lane.id==old_lane.id).count() + assert ( + 0 + == self._db.query(Lane) + .filter(Lane.library == library) + .filter(Lane.id == old_lane.id) + .count() + ) # tests/test_lanes.py tests the default lane creation, but make sure some # lanes were created. - assert 0 < self._db.query(Lane).filter(Lane.library==library).count() + assert 0 < self._db.query(Lane).filter(Lane.library == library).count() def test_change_order(self): library = self._library() @@ -1696,14 +1969,18 @@ def test_change_order(self): child1.priority = 0 child2.priority = 1 - new_order = [{ "id": parent2.id, "sublanes": [{ "id": child2.id }, { "id": child1.id }] }, - { "id": parent1.id }] + new_order = [ + {"id": parent2.id, "sublanes": [{"id": child2.id}, {"id": child1.id}]}, + {"id": parent1.id}, + ] with self.request_context_with_library_and_admin("/"): flask.request.library = library flask.request.data = json.dumps(new_order) - pytest.raises(AdminNotAuthorized, self.manager.admin_lanes_controller.change_order) + pytest.raises( + AdminNotAuthorized, self.manager.admin_lanes_controller.change_order + ) self.admin.add_role(AdminRole.LIBRARY_MANAGER, library) response = self.manager.admin_lanes_controller.change_order() @@ -1714,6 +1991,7 @@ def test_change_order(self): assert 0 == child2.priority assert 1 == child1.priority + class TestDashboardController(AdminControllerTest): # Unlike most of these controllers, we do want to have a book @@ -1727,56 +2005,75 @@ def test_circulation_events(self): CirculationEvent.DISTRIBUTOR_CHECKOUT, CirculationEvent.DISTRIBUTOR_HOLD_PLACE, CirculationEvent.DISTRIBUTOR_HOLD_RELEASE, - CirculationEvent.DISTRIBUTOR_TITLE_ADD + CirculationEvent.DISTRIBUTOR_TITLE_ADD, ] time = utc_now() - timedelta(minutes=len(types)) for type in types: get_one_or_create( - self._db, CirculationEvent, - license_pool=lp, type=type, start=time, end=time, + self._db, + CirculationEvent, + license_pool=lp, + type=type, + start=time, + end=time, ) time += timedelta(minutes=1) with self.request_context_with_library_and_admin("/"): response = self.manager.admin_dashboard_controller.circulation_events() - url = AdminAnnotator(self.manager.d_circulation, self._default_library).permalink_for(self.english_1, lp, lp.identifier) - - events = response['circulation_events'] - assert types[::-1] == [event['type'] for event in events] - assert [self.english_1.title]*len(types) == [event['book']['title'] for event in events] - assert [url]*len(types) == [event['book']['url'] for event in events] + url = AdminAnnotator( + self.manager.d_circulation, self._default_library + ).permalink_for(self.english_1, lp, lp.identifier) + + events = response["circulation_events"] + assert types[::-1] == [event["type"] for event in events] + assert [self.english_1.title] * len(types) == [ + event["book"]["title"] for event in events + ] + assert [url] * len(types) == [event["book"]["url"] for event in events] # request fewer events with self.request_context_with_library_and_admin("/?num=2"): response = self.manager.admin_dashboard_controller.circulation_events() - url = AdminAnnotator(self.manager.d_circulation, self._default_library).permalink_for(self.english_1, lp, lp.identifier) + url = AdminAnnotator( + self.manager.d_circulation, self._default_library + ).permalink_for(self.english_1, lp, lp.identifier) - assert 2 == len(response['circulation_events']) + assert 2 == len(response["circulation_events"]) def test_bulk_circulation_events(self): [lp] = self.english_1.license_pools edition = self.english_1.presentation_edition identifier = self.english_1.presentation_edition.primary_identifier genres = self._db.query(Genre).all() - get_one_or_create(self._db, WorkGenre, work=self.english_1, genre=genres[0], affinity=0.2) + get_one_or_create( + self._db, WorkGenre, work=self.english_1, genre=genres[0], affinity=0.2 + ) time = utc_now() - timedelta(minutes=1) event, ignore = get_one_or_create( - self._db, CirculationEvent, - license_pool=lp, type=CirculationEvent.DISTRIBUTOR_CHECKOUT, - start=time, end=time + self._db, + CirculationEvent, + license_pool=lp, + type=CirculationEvent.DISTRIBUTOR_CHECKOUT, + start=time, + end=time, ) time += timedelta(minutes=1) # Try an end-to-end test, getting all circulation events for # the current day. with self.app.test_request_context("/"): - response, requested_date, date_end, library_short_name = self.manager.admin_dashboard_controller.bulk_circulation_events() + ( + response, + requested_date, + date_end, + library_short_name, + ) = self.manager.admin_dashboard_controller.bulk_circulation_events() reader = csv.reader( - [row for row in response.split("\r\n") if row], - dialect=csv.excel + [row for row in response.split("\r\n") if row], dialect=csv.excel ) - rows = [row for row in reader][1::] # skip header row + rows = [row for row in reader][1::] # skip header row assert 1 == len(rows) [row] = rows assert CirculationEvent.DISTRIBUTOR_CHECKOUT == row[1] @@ -1789,14 +2086,21 @@ def test_bulk_circulation_events(self): # parameters into a LocalAnalyticsExporter object. class MockLocalAnalyticsExporter(object): def export(self, _db, date_start, date_end, locations, library): - self.called_with = ( - _db, date_start, date_end, locations, library - ) + self.called_with = (_db, date_start, date_end, locations, library) return "A CSV file" exporter = MockLocalAnalyticsExporter() - with self.request_context_with_library("/?date=2018-01-01&dateEnd=2018-01-04&locations=loc1,loc2"): - response, requested_date, date_end, library_short_name = self.manager.admin_dashboard_controller.bulk_circulation_events(analytics_exporter=exporter) + with self.request_context_with_library( + "/?date=2018-01-01&dateEnd=2018-01-04&locations=loc1,loc2" + ): + ( + response, + requested_date, + date_end, + library_short_name, + ) = self.manager.admin_dashboard_controller.bulk_circulation_events( + analytics_exporter=exporter + ) # export() was called with the arguments we expect. # @@ -1824,7 +2128,6 @@ def export(self, _db, date_start, date_end, locations, library): assert "2018-01-04" == date_end assert self._default_library.short_name == library_short_name - def test_stats_patrons(self): with self.request_context_with_admin("/"): self.admin.add_role(AdminRole.SYSTEM_ADMIN) @@ -1834,14 +2137,16 @@ def test_stats_patrons(self): library_data = response.get(self._default_library.short_name) total_data = response.get("total") for data in [library_data, total_data]: - patron_data = data.get('patrons') - assert 1 == patron_data.get('total') - assert 0 == patron_data.get('with_active_loans') - assert 0 == patron_data.get('with_active_loans_or_holds') - assert 0 == patron_data.get('loans') - assert 0 == patron_data.get('holds') - - edition, pool = self._edition(with_license_pool=True, with_open_access_download=False) + patron_data = data.get("patrons") + assert 1 == patron_data.get("total") + assert 0 == patron_data.get("with_active_loans") + assert 0 == patron_data.get("with_active_loans_or_holds") + assert 0 == patron_data.get("loans") + assert 0 == patron_data.get("holds") + + edition, pool = self._edition( + with_license_pool=True, with_open_access_download=False + ) edition2, open_access_pool = self._edition(with_open_access_download=True) # patron1 has a loan. @@ -1861,12 +2166,12 @@ def test_stats_patrons(self): library_data = response.get(self._default_library.short_name) total_data = response.get("total") for data in [library_data, total_data]: - patron_data = data.get('patrons') - assert 4 == patron_data.get('total') - assert 1 == patron_data.get('with_active_loans') - assert 2 == patron_data.get('with_active_loans_or_holds') - assert 1 == patron_data.get('loans') - assert 1 == patron_data.get('holds') + patron_data = data.get("patrons") + assert 4 == patron_data.get("total") + assert 1 == patron_data.get("with_active_loans") + assert 2 == patron_data.get("with_active_loans_or_holds") + assert 1 == patron_data.get("loans") + assert 1 == patron_data.get("holds") # These patrons are in a different library.. l2 = self._library() @@ -1878,16 +2183,16 @@ def test_stats_patrons(self): response = self.manager.admin_dashboard_controller.stats() library_data = response.get(self._default_library.short_name) total_data = response.get("total") - assert 4 == library_data.get('patrons').get('total') - assert 1 == library_data.get('patrons').get('with_active_loans') - assert 2 == library_data.get('patrons').get('with_active_loans_or_holds') - assert 1 == library_data.get('patrons').get('loans') - assert 1 == library_data.get('patrons').get('holds') - assert 6 == total_data.get('patrons').get('total') - assert 2 == total_data.get('patrons').get('with_active_loans') - assert 4 == total_data.get('patrons').get('with_active_loans_or_holds') - assert 2 == total_data.get('patrons').get('loans') - assert 2 == total_data.get('patrons').get('holds') + assert 4 == library_data.get("patrons").get("total") + assert 1 == library_data.get("patrons").get("with_active_loans") + assert 2 == library_data.get("patrons").get("with_active_loans_or_holds") + assert 1 == library_data.get("patrons").get("loans") + assert 1 == library_data.get("patrons").get("holds") + assert 6 == total_data.get("patrons").get("total") + assert 2 == total_data.get("patrons").get("with_active_loans") + assert 4 == total_data.get("patrons").get("with_active_loans_or_holds") + assert 2 == total_data.get("patrons").get("loans") + assert 2 == total_data.get("patrons").get("holds") # If the admin only has access to some libraries, only those will be counted # in the total stats. @@ -1897,16 +2202,16 @@ def test_stats_patrons(self): response = self.manager.admin_dashboard_controller.stats() library_data = response.get(self._default_library.short_name) total_data = response.get("total") - assert 4 == library_data.get('patrons').get('total') - assert 1 == library_data.get('patrons').get('with_active_loans') - assert 2 == library_data.get('patrons').get('with_active_loans_or_holds') - assert 1 == library_data.get('patrons').get('loans') - assert 1 == library_data.get('patrons').get('holds') - assert 4 == total_data.get('patrons').get('total') - assert 1 == total_data.get('patrons').get('with_active_loans') - assert 2 == total_data.get('patrons').get('with_active_loans_or_holds') - assert 1 == total_data.get('patrons').get('loans') - assert 1 == total_data.get('patrons').get('holds') + assert 4 == library_data.get("patrons").get("total") + assert 1 == library_data.get("patrons").get("with_active_loans") + assert 2 == library_data.get("patrons").get("with_active_loans_or_holds") + assert 1 == library_data.get("patrons").get("loans") + assert 1 == library_data.get("patrons").get("holds") + assert 4 == total_data.get("patrons").get("total") + assert 1 == total_data.get("patrons").get("with_active_loans") + assert 2 == total_data.get("patrons").get("with_active_loans_or_holds") + assert 1 == total_data.get("patrons").get("loans") + assert 1 == total_data.get("patrons").get("holds") def test_stats_inventory(self): with self.request_context_with_admin("/"): @@ -1918,23 +2223,29 @@ def test_stats_inventory(self): library_data = response.get(self._default_library.short_name) total_data = response.get("total") for data in [library_data, total_data]: - inventory_data = data.get('inventory') - assert 1 == inventory_data.get('titles') - assert 0 == inventory_data.get('licenses') - assert 0 == inventory_data.get('available_licenses') + inventory_data = data.get("inventory") + assert 1 == inventory_data.get("titles") + assert 0 == inventory_data.get("licenses") + assert 0 == inventory_data.get("available_licenses") # This edition has no licenses owned and isn't counted in the inventory. - edition1, pool1 = self._edition(with_license_pool=True, with_open_access_download=False) + edition1, pool1 = self._edition( + with_license_pool=True, with_open_access_download=False + ) pool1.open_access = False pool1.licenses_owned = 0 pool1.licenses_available = 0 - edition2, pool2 = self._edition(with_license_pool=True, with_open_access_download=False) + edition2, pool2 = self._edition( + with_license_pool=True, with_open_access_download=False + ) pool2.open_access = False pool2.licenses_owned = 10 pool2.licenses_available = 0 - edition3, pool3 = self._edition(with_license_pool=True, with_open_access_download=False) + edition3, pool3 = self._edition( + with_license_pool=True, with_open_access_download=False + ) pool3.open_access = False pool3.licenses_owned = 5 pool3.licenses_available = 4 @@ -1943,26 +2254,28 @@ def test_stats_inventory(self): library_data = response.get(self._default_library.short_name) total_data = response.get("total") for data in [library_data, total_data]: - inventory_data = data.get('inventory') - assert 3 == inventory_data.get('titles') - assert 15 == inventory_data.get('licenses') - assert 4 == inventory_data.get('available_licenses') + inventory_data = data.get("inventory") + assert 3 == inventory_data.get("titles") + assert 15 == inventory_data.get("licenses") + assert 4 == inventory_data.get("available_licenses") # This edition is in a different collection. c2 = self._collection() - edition4, pool4 = self._edition(with_license_pool=True, with_open_access_download=False, collection=c2) + edition4, pool4 = self._edition( + with_license_pool=True, with_open_access_download=False, collection=c2 + ) pool4.licenses_owned = 2 pool4.licenses_available = 2 response = self.manager.admin_dashboard_controller.stats() library_data = response.get(self._default_library.short_name) total_data = response.get("total") - assert 3 == library_data.get('inventory').get('titles') - assert 4 == total_data.get('inventory').get('titles') - assert 15 == library_data.get('inventory').get('licenses') - assert 17 == total_data.get('inventory').get('licenses') - assert 4 == library_data.get('inventory').get('available_licenses') - assert 6 == total_data.get('inventory').get('available_licenses') + assert 3 == library_data.get("inventory").get("titles") + assert 4 == total_data.get("inventory").get("titles") + assert 15 == library_data.get("inventory").get("licenses") + assert 17 == total_data.get("inventory").get("licenses") + assert 4 == library_data.get("inventory").get("available_licenses") + assert 6 == total_data.get("inventory").get("available_licenses") self.admin.remove_role(AdminRole.SYSTEM_ADMIN) self.admin.add_role(AdminRole.LIBRARIAN, self._default_library) @@ -1973,10 +2286,10 @@ def test_stats_inventory(self): library_data = response.get(self._default_library.short_name) total_data = response.get("total") for data in [library_data, total_data]: - inventory_data = data.get('inventory') - assert 3 == inventory_data.get('titles') - assert 15 == inventory_data.get('licenses') - assert 4 == inventory_data.get('available_licenses') + inventory_data = data.get("inventory") + assert 3 == inventory_data.get("titles") + assert 15 == inventory_data.get("licenses") + assert 4 == inventory_data.get("available_licenses") def test_stats_collections(self): with self.request_context_with_admin("/"): @@ -1988,45 +2301,53 @@ def test_stats_collections(self): library_data = response.get(self._default_library.short_name) total_data = response.get("total") for data in [library_data, total_data]: - collections_data = data.get('collections') + collections_data = data.get("collections") assert 1 == len(collections_data) collection_data = collections_data.get(self._default_collection.name) - assert 0 == collection_data.get('licensed_titles') - assert 1 == collection_data.get('open_access_titles') - assert 0 == collection_data.get('licenses') - assert 0 == collection_data.get('available_licenses') + assert 0 == collection_data.get("licensed_titles") + assert 1 == collection_data.get("open_access_titles") + assert 0 == collection_data.get("licenses") + assert 0 == collection_data.get("available_licenses") c2 = self._collection() c3 = self._collection() c3.libraries += [self._default_library] - edition1, pool1 = self._edition(with_license_pool=True, - with_open_access_download=False, - data_source_name=DataSource.OVERDRIVE, - collection=c2) + edition1, pool1 = self._edition( + with_license_pool=True, + with_open_access_download=False, + data_source_name=DataSource.OVERDRIVE, + collection=c2, + ) pool1.open_access = False pool1.licenses_owned = 10 pool1.licenses_available = 5 - edition2, pool2 = self._edition(with_license_pool=True, - with_open_access_download=False, - data_source_name=DataSource.OVERDRIVE, - collection=c3) + edition2, pool2 = self._edition( + with_license_pool=True, + with_open_access_download=False, + data_source_name=DataSource.OVERDRIVE, + collection=c3, + ) pool2.open_access = False pool2.licenses_owned = 0 pool2.licenses_available = 0 - edition3, pool3 = self._edition(with_license_pool=True, - with_open_access_download=False, - data_source_name=DataSource.BIBLIOTHECA) + edition3, pool3 = self._edition( + with_license_pool=True, + with_open_access_download=False, + data_source_name=DataSource.BIBLIOTHECA, + ) pool3.open_access = False pool3.licenses_owned = 3 pool3.licenses_available = 0 - edition4, pool4 = self._edition(with_license_pool=True, - with_open_access_download=False, - data_source_name=DataSource.AXIS_360, - collection=c2) + edition4, pool4 = self._edition( + with_license_pool=True, + with_open_access_download=False, + data_source_name=DataSource.AXIS_360, + collection=c2, + ) pool4.open_access = False pool4.licenses_owned = 5 pool4.licenses_available = 5 @@ -2034,29 +2355,29 @@ def test_stats_collections(self): response = self.manager.admin_dashboard_controller.stats() library_data = response.get(self._default_library.short_name) total_data = response.get("total") - library_collections_data = library_data.get('collections') - total_collections_data = total_data.get('collections') + library_collections_data = library_data.get("collections") + total_collections_data = total_data.get("collections") assert 2 == len(library_collections_data) assert 3 == len(total_collections_data) for data in [library_collections_data, total_collections_data]: c1_data = data.get(self._default_collection.name) - assert 1 == c1_data.get('licensed_titles') - assert 1 == c1_data.get('open_access_titles') - assert 3 == c1_data.get('licenses') - assert 0 == c1_data.get('available_licenses') + assert 1 == c1_data.get("licensed_titles") + assert 1 == c1_data.get("open_access_titles") + assert 3 == c1_data.get("licenses") + assert 0 == c1_data.get("available_licenses") c3_data = data.get(c3.name) - assert 0 == c3_data.get('licensed_titles') - assert 0 == c3_data.get('open_access_titles') - assert 0 == c3_data.get('licenses') - assert 0 == c3_data.get('available_licenses') + assert 0 == c3_data.get("licensed_titles") + assert 0 == c3_data.get("open_access_titles") + assert 0 == c3_data.get("licenses") + assert 0 == c3_data.get("available_licenses") assert None == library_collections_data.get(c2.name) c2_data = total_collections_data.get(c2.name) - assert 2 == c2_data.get('licensed_titles') - assert 0 == c2_data.get('open_access_titles') - assert 15 == c2_data.get('licenses') - assert 10 == c2_data.get('available_licenses') + assert 2 == c2_data.get("licensed_titles") + assert 0 == c2_data.get("open_access_titles") + assert 15 == c2_data.get("licenses") + assert 10 == c2_data.get("available_licenses") self.admin.remove_role(AdminRole.SYSTEM_ADMIN) self.admin.add_role(AdminRole.LIBRARY_MANAGER, self._default_library) @@ -2072,16 +2393,16 @@ def test_stats_collections(self): assert None == collections_data.get(c2.name) c1_data = collections_data.get(self._default_collection.name) - assert 1 == c1_data.get('licensed_titles') - assert 1 == c1_data.get('open_access_titles') - assert 3 == c1_data.get('licenses') - assert 0 == c1_data.get('available_licenses') + assert 1 == c1_data.get("licensed_titles") + assert 1 == c1_data.get("open_access_titles") + assert 3 == c1_data.get("licenses") + assert 0 == c1_data.get("available_licenses") c3_data = collections_data.get(c3.name) - assert 0 == c3_data.get('licensed_titles') - assert 0 == c3_data.get('open_access_titles') - assert 0 == c3_data.get('licenses') - assert 0 == c3_data.get('available_licenses') + assert 0 == c3_data.get("licensed_titles") + assert 0 == c3_data.get("open_access_titles") + assert 0 == c3_data.get("licenses") + assert 0 == c3_data.get("available_licenses") class SettingsControllerTest(AdminControllerTest): @@ -2091,14 +2412,16 @@ def setup_method(self): super(SettingsControllerTest, self).setup_method() # Delete any existing patron auth services created by controller test setup. for auth_service in self._db.query(ExternalIntegration).filter( - ExternalIntegration.goal==ExternalIntegration.PATRON_AUTH_GOAL - ): + ExternalIntegration.goal == ExternalIntegration.PATRON_AUTH_GOAL + ): self._db.delete(auth_service) # Delete any existing sitewide ConfigurationSettings. - for setting in self._db.query(ConfigurationSetting).filter( - ConfigurationSetting.library_id==None).filter( - ConfigurationSetting.external_integration_id==None): + for setting in ( + self._db.query(ConfigurationSetting) + .filter(ConfigurationSetting.library_id == None) + .filter(ConfigurationSetting.external_integration_id == None) + ): self._db.delete(setting) self.responses = [] @@ -2119,7 +2442,7 @@ def mock_prior_test_results(self, *args, **kwargs): duration=0.9, start="2018-08-08T16:04:05Z", end="2018-08-08T16:05:05Z", - results=[] + results=[], ) self.self_test_results = self_test_results @@ -2163,7 +2486,9 @@ def test_get_prior_test_results(self): OPDSCollection = self._collection() # If a collection's protocol is OPDSImporter, make sure that # OPDSImportMonitor.prior_test_results is called - self_test_results = controller._get_prior_test_results(OPDSCollection, OPDSImporter) + self_test_results = controller._get_prior_test_results( + OPDSCollection, OPDSImporter + ) args = self.prior_test_results_called_with[0] assert args[1] == OPDSImportMonitor assert args[3] == OPDSCollection @@ -2173,39 +2498,44 @@ def test_get_prior_test_results(self): @classmethod def oops(cls, *args, **kwargs): raise Exception("Test result disaster!") + HasSelfTests.prior_test_results = oops self_test_results = controller._get_prior_test_results( OPDSCollection, OPDSImporter ) assert ( - "Exception getting self-test results for collection %s: Test result disaster!" % ( - OPDSCollection.name - ) == - self_test_results["exception"]) + "Exception getting self-test results for collection %s: Test result disaster!" + % (OPDSCollection.name) + == self_test_results["exception"] + ) HasSelfTests.prior_test_results = old_prior_test_results class TestSettingsController(SettingsControllerTest): - def test_get_integration_protocols(self): """Test the _get_integration_protocols helper method.""" + class Protocol(object): - __module__ = 'my name' - NAME = 'my label' - DESCRIPTION = 'my description' + __module__ = "my name" + NAME = "my label" + DESCRIPTION = "my description" SITEWIDE = True - SETTINGS = [1,2,3] - CHILD_SETTINGS = [4,5] + SETTINGS = [1, 2, 3] + CHILD_SETTINGS = [4, 5] LIBRARY_SETTINGS = [6] CARDINALITY = 1 [result] = SettingsController._get_integration_protocols([Protocol]) expect = dict( - sitewide=True, description='my description', - settings=[1, 2, 3], library_settings=[6], - child_settings=[4, 5], label='my label', - cardinality=1, name='my name' + sitewide=True, + description="my description", + settings=[1, 2, 3], + library_settings=[6], + child_settings=[4, 5], + label="my label", + cardinality=1, + name="my name", ) assert expect == result @@ -2214,11 +2544,11 @@ class Protocol(object): # And look in a different place for the name. [result] = SettingsController._get_integration_protocols( - [Protocol], protocol_name_attr='NAME' + [Protocol], protocol_name_attr="NAME" ) - assert 'my label' == result['name'] - assert 'cardinality' not in result + assert "my label" == result["name"] + assert "cardinality" not in result def test_get_integration_info(self): """Test the _get_integration_info helper method.""" @@ -2228,9 +2558,7 @@ def test_get_integration_info(self): # with the given goal, but none of them match the # configuration. goal = self._str - integration = self._external_integration( - protocol="a protocol", goal=goal - ) + integration = self._external_integration(protocol="a protocol", goal=goal) assert [] == m(goal, [dict(name="some other protocol")]) def test_create_integration(self): @@ -2245,15 +2573,15 @@ def test_create_integration(self): goal = "some goal" # You get an error if you don't pass in a protocol. - assert ( - (NO_PROTOCOL_FOR_NEW_SERVICE, False) == - m(protocol_definitions, None, goal)) + assert (NO_PROTOCOL_FOR_NEW_SERVICE, False) == m( + protocol_definitions, None, goal + ) # You get an error if you do provide a protocol but no definition # for it can be found. - assert ( - (UNKNOWN_PROTOCOL, False) == - m(protocol_definitions, "no definition", goal)) + assert (UNKNOWN_PROTOCOL, False) == m( + protocol_definitions, "no definition", goal + ) # If the protocol has multiple cardinality you can create as many # integrations using that protocol as you want. @@ -2282,104 +2610,126 @@ class MockValidator(Validator): def __init__(self): self.was_called = False self.args = [] + def validate(self, settings, content): self.was_called = True self.args.append(settings) self.args.append(content) + def validate_error(self, settings, content): return INVALID_EMAIL validator = MockValidator() with self.request_context_with_admin("/", method="POST"): - flask.request.form = MultiDict([ - ("name", "The New York Public Library"), - ("short_name", "nypl"), - (Configuration.WEBSITE_URL, "https://library.library/"), - (Configuration.DEFAULT_NOTIFICATION_EMAIL_ADDRESS, "email@example.com"), - (Configuration.HELP_EMAIL, "help@example.com") - ]) - flask.request.files = MultiDict([ - (Configuration.LOGO, StringIO()) - ]) - response = self.manager.admin_settings_controller.validate_formats(Configuration.LIBRARY_SETTINGS, validator) + flask.request.form = MultiDict( + [ + ("name", "The New York Public Library"), + ("short_name", "nypl"), + (Configuration.WEBSITE_URL, "https://library.library/"), + ( + Configuration.DEFAULT_NOTIFICATION_EMAIL_ADDRESS, + "email@example.com", + ), + (Configuration.HELP_EMAIL, "help@example.com"), + ] + ) + flask.request.files = MultiDict([(Configuration.LOGO, StringIO())]) + response = self.manager.admin_settings_controller.validate_formats( + Configuration.LIBRARY_SETTINGS, validator + ) assert response == None assert validator.was_called == True assert validator.args[0] == Configuration.LIBRARY_SETTINGS - assert validator.args[1] == {"files": flask.request.files, "form": flask.request.form} + assert validator.args[1] == { + "files": flask.request.files, + "form": flask.request.form, + } validator.validate = validator.validate_error # If the validator returns an problem detail, validate_formats returns it. - response = self.manager.admin_settings_controller.validate_formats(Configuration.LIBRARY_SETTINGS, validator) + response = self.manager.admin_settings_controller.validate_formats( + Configuration.LIBRARY_SETTINGS, validator + ) assert response == INVALID_EMAIL def test__mirror_integration_settings(self): # If no storage integrations are available, return none - mirror_integration_settings = self.manager.admin_settings_controller._mirror_integration_settings + mirror_integration_settings = ( + self.manager.admin_settings_controller._mirror_integration_settings + ) assert None == mirror_integration_settings() # Storages created will appear for settings of any purpose storage1 = self._external_integration( - "protocol1", ExternalIntegration.STORAGE_GOAL, name="storage1", + "protocol1", + ExternalIntegration.STORAGE_GOAL, + name="storage1", settings={ - S3UploaderConfiguration.BOOK_COVERS_BUCKET_KEY: 'covers', - S3UploaderConfiguration.OA_CONTENT_BUCKET_KEY: 'open-access-books', - S3UploaderConfiguration.PROTECTED_CONTENT_BUCKET_KEY: 'protected-access-books' - } + S3UploaderConfiguration.BOOK_COVERS_BUCKET_KEY: "covers", + S3UploaderConfiguration.OA_CONTENT_BUCKET_KEY: "open-access-books", + S3UploaderConfiguration.PROTECTED_CONTENT_BUCKET_KEY: "protected-access-books", + }, ) settings = mirror_integration_settings() assert settings[0]["key"] == "covers_mirror_integration_id" assert settings[0]["label"] == "Covers Mirror" - assert (settings[0]["options"][0]['key'] == - self.manager.admin_settings_controller.NO_MIRROR_INTEGRATION) - assert (settings[0]["options"][1]['key'] == - str(storage1.id)) + assert ( + settings[0]["options"][0]["key"] + == self.manager.admin_settings_controller.NO_MIRROR_INTEGRATION + ) + assert settings[0]["options"][1]["key"] == str(storage1.id) assert settings[1]["key"] == "books_mirror_integration_id" assert settings[1]["label"] == "Open Access Books Mirror" - assert (settings[1]["options"][0]['key'] == - self.manager.admin_settings_controller.NO_MIRROR_INTEGRATION) - assert (settings[1]["options"][1]['key'] == - str(storage1.id)) + assert ( + settings[1]["options"][0]["key"] + == self.manager.admin_settings_controller.NO_MIRROR_INTEGRATION + ) + assert settings[1]["options"][1]["key"] == str(storage1.id) assert settings[2]["label"] == "Protected Access Books Mirror" - assert (settings[2]["options"][0]['key'] == - self.manager.admin_settings_controller.NO_MIRROR_INTEGRATION) - assert (settings[2]["options"][1]['key'] == - str(storage1.id)) + assert ( + settings[2]["options"][0]["key"] + == self.manager.admin_settings_controller.NO_MIRROR_INTEGRATION + ) + assert settings[2]["options"][1]["key"] == str(storage1.id) storage2 = self._external_integration( - "protocol2", ExternalIntegration.STORAGE_GOAL, name="storage2", + "protocol2", + ExternalIntegration.STORAGE_GOAL, + name="storage2", settings={ - S3UploaderConfiguration.BOOK_COVERS_BUCKET_KEY: 'covers', - S3UploaderConfiguration.OA_CONTENT_BUCKET_KEY: 'open-access-books', - S3UploaderConfiguration.PROTECTED_CONTENT_BUCKET_KEY: 'protected-access-books' - } + S3UploaderConfiguration.BOOK_COVERS_BUCKET_KEY: "covers", + S3UploaderConfiguration.OA_CONTENT_BUCKET_KEY: "open-access-books", + S3UploaderConfiguration.PROTECTED_CONTENT_BUCKET_KEY: "protected-access-books", + }, ) settings = mirror_integration_settings() assert settings[0]["key"] == "covers_mirror_integration_id" assert settings[0]["label"] == "Covers Mirror" - assert (settings[0]["options"][0]['key'] == - self.manager.admin_settings_controller.NO_MIRROR_INTEGRATION) - assert (settings[0]["options"][1]['key'] == - str(storage1.id)) - assert (settings[0]["options"][2]['key'] == - str(storage2.id)) + assert ( + settings[0]["options"][0]["key"] + == self.manager.admin_settings_controller.NO_MIRROR_INTEGRATION + ) + assert settings[0]["options"][1]["key"] == str(storage1.id) + assert settings[0]["options"][2]["key"] == str(storage2.id) assert settings[1]["key"] == "books_mirror_integration_id" assert settings[1]["label"] == "Open Access Books Mirror" - assert (settings[1]["options"][0]['key'] == - self.manager.admin_settings_controller.NO_MIRROR_INTEGRATION) - assert (settings[1]["options"][1]['key'] == - str(storage1.id)) - assert (settings[1]["options"][2]['key'] == - str(storage2.id)) + assert ( + settings[1]["options"][0]["key"] + == self.manager.admin_settings_controller.NO_MIRROR_INTEGRATION + ) + assert settings[1]["options"][1]["key"] == str(storage1.id) + assert settings[1]["options"][2]["key"] == str(storage2.id) assert settings[2]["label"] == "Protected Access Books Mirror" - assert (settings[2]["options"][0]['key'] == - self.manager.admin_settings_controller.NO_MIRROR_INTEGRATION) - assert (settings[2]["options"][1]['key'] == - str(storage1.id)) + assert ( + settings[2]["options"][0]["key"] + == self.manager.admin_settings_controller.NO_MIRROR_INTEGRATION + ) + assert settings[2]["options"][1]["key"] == str(storage1.id) def test_check_url_unique(self): # Verify our ability to catch duplicate integrations for a @@ -2397,9 +2747,7 @@ def test_check_url_unique(self): # Here's another ExternalIntegration that might or might not # be about to become a duplicate of the original. - new = self._external_integration( - protocol=protocol, goal="new goal" - ) + new = self._external_integration(protocol=protocol, goal="new goal") new.goal = original.goal assert new != original @@ -2417,9 +2765,7 @@ def is_dupe(url, protocol, goal): ) # The original ExternalIntegration is not a duplicate of itself. - assert ( - None == - m(original, original.url, protocol, goal)) + assert None == m(original, original.url, protocol, goal) # However, any other ExternalIntegration with the same URL, # protocol, and goal is considered a duplicate. @@ -2454,16 +2800,16 @@ def m(url): assert [] == m("not a url") # Variants of an HTTP URL with a trailing slash. - assert ( - ['http://url/', 'http://url', 'https://url/', 'https://url'] == - m("http://url/")) + assert ["http://url/", "http://url", "https://url/", "https://url"] == m( + "http://url/" + ) # Variants of an HTTPS URL with a trailing slash. - assert ( - ['https://url/', 'https://url', 'http://url/', 'http://url'] == - m("https://url/")) + assert ["https://url/", "https://url", "http://url/", "http://url"] == m( + "https://url/" + ) # Variants of a URL with no trailing slash. - assert ( - ['https://url', 'https://url/', 'http://url', 'http://url/'] == - m("https://url")) + assert ["https://url", "https://url/", "http://url", "http://url/"] == m( + "https://url" + ) diff --git a/tests/admin/controller/test_discovery_services.py b/tests/admin/controller/test_discovery_services.py index f543996473..38ae98df28 100644 --- a/tests/admin/controller/test_discovery_services.py +++ b/tests/admin/controller/test_discovery_services.py @@ -25,24 +25,33 @@ class TestDiscoveryServices(SettingsControllerTest): def test_discovery_services_get_with_no_services_creates_default(self): with self.request_context_with_admin("/"): - response = self.manager.admin_discovery_services_controller.process_discovery_services() + response = ( + self.manager.admin_discovery_services_controller.process_discovery_services() + ) [service] = response.get("discovery_services") protocols = response.get("protocols") - assert ExternalIntegration.OPDS_REGISTRATION in [p.get("name") for p in protocols] + assert ExternalIntegration.OPDS_REGISTRATION in [ + p.get("name") for p in protocols + ] assert "settings" in protocols[0] assert ExternalIntegration.OPDS_REGISTRATION == service.get("protocol") - assert RemoteRegistry.DEFAULT_LIBRARY_REGISTRY_URL == service.get("settings").get(ExternalIntegration.URL) + assert RemoteRegistry.DEFAULT_LIBRARY_REGISTRY_URL == service.get( + "settings" + ).get(ExternalIntegration.URL) assert RemoteRegistry.DEFAULT_LIBRARY_REGISTRY_NAME == service.get("name") # Only system admins can see the discovery services. self.admin.remove_role(AdminRole.SYSTEM_ADMIN) self._db.flush() - pytest.raises(AdminNotAuthorized, - self.manager.admin_discovery_services_controller.process_discovery_services) + pytest.raises( + AdminNotAuthorized, + self.manager.admin_discovery_services_controller.process_discovery_services, + ) def test_discovery_services_get_with_one_service(self): discovery_service, ignore = create( - self._db, ExternalIntegration, + self._db, + ExternalIntegration, protocol=ExternalIntegration.OPDS_REGISTRATION, goal=ExternalIntegration.DISCOVERY_GOAL, ) @@ -56,111 +65,138 @@ def test_discovery_services_get_with_one_service(self): assert discovery_service.id == service.get("id") assert discovery_service.protocol == service.get("protocol") - assert discovery_service.url == service.get("settings").get(ExternalIntegration.URL) + assert discovery_service.url == service.get("settings").get( + ExternalIntegration.URL + ) def test_discovery_services_post_errors(self): controller = self.manager.admin_discovery_services_controller with self.request_context_with_admin("/", method="POST"): - flask.request.form = MultiDict([ - ("name", "Name"), - ("protocol", "Unknown"), - ]) + flask.request.form = MultiDict( + [ + ("name", "Name"), + ("protocol", "Unknown"), + ] + ) response = controller.process_discovery_services() assert response == UNKNOWN_PROTOCOL with self.request_context_with_admin("/", method="POST"): - flask.request.form = MultiDict([ - ("name", "Name"), - ]) + flask.request.form = MultiDict( + [ + ("name", "Name"), + ] + ) response = controller.process_discovery_services() assert response == NO_PROTOCOL_FOR_NEW_SERVICE with self.request_context_with_admin("/", method="POST"): - flask.request.form = MultiDict([ - ("name", "Name"), - ("id", "123"), - ("protocol", ExternalIntegration.OPDS_REGISTRATION), - ]) + flask.request.form = MultiDict( + [ + ("name", "Name"), + ("id", "123"), + ("protocol", ExternalIntegration.OPDS_REGISTRATION), + ] + ) response = controller.process_discovery_services() assert response == MISSING_SERVICE service, ignore = create( - self._db, ExternalIntegration, + self._db, + ExternalIntegration, protocol=ExternalIntegration.OPDS_REGISTRATION, goal=ExternalIntegration.DISCOVERY_GOAL, name="name", ) with self.request_context_with_admin("/", method="POST"): - flask.request.form = MultiDict([ - ("name", service.name), - ("protocol", ExternalIntegration.OPDS_REGISTRATION), - ]) + flask.request.form = MultiDict( + [ + ("name", service.name), + ("protocol", ExternalIntegration.OPDS_REGISTRATION), + ] + ) response = controller.process_discovery_services() assert response == INTEGRATION_NAME_ALREADY_IN_USE existing_integration = self._external_integration( ExternalIntegration.OPDS_REGISTRATION, ExternalIntegration.DISCOVERY_GOAL, - url=self._url + url=self._url, ) with self.request_context_with_admin("/", method="POST"): - flask.request.form = MultiDict([ - ("name", "new name"), - ("protocol", existing_integration.protocol), - ("url", existing_integration.url) - ]) + flask.request.form = MultiDict( + [ + ("name", "new name"), + ("protocol", existing_integration.protocol), + ("url", existing_integration.url), + ] + ) response = controller.process_discovery_services() assert response == INTEGRATION_URL_ALREADY_IN_USE with self.request_context_with_admin("/", method="POST"): - flask.request.form = MultiDict([ - ("id", service.id), - ("protocol", ExternalIntegration.OPDS_REGISTRATION), - ]) + flask.request.form = MultiDict( + [ + ("id", service.id), + ("protocol", ExternalIntegration.OPDS_REGISTRATION), + ] + ) response = controller.process_discovery_services() assert response.uri == INCOMPLETE_CONFIGURATION.uri self.admin.remove_role(AdminRole.SYSTEM_ADMIN) with self.request_context_with_admin("/", method="POST"): - flask.request.form = MultiDict([ - ("protocol", ExternalIntegration.OPDS_REGISTRATION), - (ExternalIntegration.URL, "registry url"), - ]) - pytest.raises(AdminNotAuthorized, - controller.process_discovery_services) + flask.request.form = MultiDict( + [ + ("protocol", ExternalIntegration.OPDS_REGISTRATION), + (ExternalIntegration.URL, "registry url"), + ] + ) + pytest.raises(AdminNotAuthorized, controller.process_discovery_services) def test_discovery_services_post_create(self): with self.request_context_with_admin("/", method="POST"): - flask.request.form = MultiDict([ - ("name", "Name"), - ("protocol", ExternalIntegration.OPDS_REGISTRATION), - (ExternalIntegration.URL, "http://registry_url"), - ]) - response = self.manager.admin_discovery_services_controller.process_discovery_services() + flask.request.form = MultiDict( + [ + ("name", "Name"), + ("protocol", ExternalIntegration.OPDS_REGISTRATION), + (ExternalIntegration.URL, "http://registry_url"), + ] + ) + response = ( + self.manager.admin_discovery_services_controller.process_discovery_services() + ) assert response.status_code == 201 - service = get_one(self._db, ExternalIntegration, goal=ExternalIntegration.DISCOVERY_GOAL) + service = get_one( + self._db, ExternalIntegration, goal=ExternalIntegration.DISCOVERY_GOAL + ) assert service.id == int(response.response[0]) assert ExternalIntegration.OPDS_REGISTRATION == service.protocol assert "http://registry_url" == service.url def test_discovery_services_post_edit(self): discovery_service, ignore = create( - self._db, ExternalIntegration, + self._db, + ExternalIntegration, protocol=ExternalIntegration.OPDS_REGISTRATION, goal=ExternalIntegration.DISCOVERY_GOAL, ) discovery_service.url = "registry url" with self.request_context_with_admin("/", method="POST"): - flask.request.form = MultiDict([ - ("name", "Name"), - ("id", discovery_service.id), - ("protocol", ExternalIntegration.OPDS_REGISTRATION), - (ExternalIntegration.URL, "http://new_registry_url"), - ]) - response = self.manager.admin_discovery_services_controller.process_discovery_services() + flask.request.form = MultiDict( + [ + ("name", "Name"), + ("id", discovery_service.id), + ("protocol", ExternalIntegration.OPDS_REGISTRATION), + (ExternalIntegration.URL, "http://new_registry_url"), + ] + ) + response = ( + self.manager.admin_discovery_services_controller.process_discovery_services() + ) assert response.status_code == 200 assert discovery_service.id == int(response.response[0]) @@ -168,32 +204,35 @@ def test_discovery_services_post_edit(self): assert "http://new_registry_url" == discovery_service.url def test_check_name_unique(self): - kwargs = dict(protocol=ExternalIntegration.OPDS_REGISTRATION, - goal=ExternalIntegration.DISCOVERY_GOAL,) + kwargs = dict( + protocol=ExternalIntegration.OPDS_REGISTRATION, + goal=ExternalIntegration.DISCOVERY_GOAL, + ) - existing_service, ignore = create(self._db, ExternalIntegration, name="existing service", **kwargs) - new_service, ignore = create(self._db, ExternalIntegration, name="new service", **kwargs) + existing_service, ignore = create( + self._db, ExternalIntegration, name="existing service", **kwargs + ) + new_service, ignore = create( + self._db, ExternalIntegration, name="new service", **kwargs + ) - m = self.manager.admin_discovery_services_controller.check_name_unique + m = self.manager.admin_discovery_services_controller.check_name_unique - # Try to change new service so that it has the same name as existing service - # -- this is not allowed. - result = m(new_service, existing_service.name) - assert result == INTEGRATION_NAME_ALREADY_IN_USE + # Try to change new service so that it has the same name as existing service + # -- this is not allowed. + result = m(new_service, existing_service.name) + assert result == INTEGRATION_NAME_ALREADY_IN_USE - # Try to edit existing service without changing its name -- this is fine. - assert ( - None == - m(existing_service, existing_service.name)) + # Try to edit existing service without changing its name -- this is fine. + assert None == m(existing_service, existing_service.name) - # Changing the existing service's name is also fine. - assert ( - None == - m(existing_service, "new name")) + # Changing the existing service's name is also fine. + assert None == m(existing_service, "new name") def test_discovery_service_delete(self): discovery_service, ignore = create( - self._db, ExternalIntegration, + self._db, + ExternalIntegration, protocol=ExternalIntegration.OPDS_REGISTRATION, goal=ExternalIntegration.DISCOVERY_GOAL, ) @@ -201,12 +240,16 @@ def test_discovery_service_delete(self): with self.request_context_with_admin("/", method="DELETE"): self.admin.remove_role(AdminRole.SYSTEM_ADMIN) - pytest.raises(AdminNotAuthorized, - self.manager.admin_discovery_services_controller.process_delete, - discovery_service.id) + pytest.raises( + AdminNotAuthorized, + self.manager.admin_discovery_services_controller.process_delete, + discovery_service.id, + ) self.admin.add_role(AdminRole.SYSTEM_ADMIN) - response = self.manager.admin_discovery_services_controller.process_delete(discovery_service.id) + response = self.manager.admin_discovery_services_controller.process_delete( + discovery_service.id + ) assert response.status_code == 200 service = get_one(self._db, ExternalIntegration, id=discovery_service.id) diff --git a/tests/admin/controller/test_individual_admins.py b/tests/admin/controller/test_individual_admins.py index a8e3fecaa5..a71e233a7c 100644 --- a/tests/admin/controller/test_individual_admins.py +++ b/tests/admin/controller/test_individual_admins.py @@ -1,22 +1,18 @@ -import pytest +import json import flask +import pytest from flask_babel import lazy_gettext as _ -import json from werkzeug.datastructures import MultiDict + from api.admin.exceptions import * from api.admin.problem_details import * -from core.model import ( - Admin, - AdminRole, - create, - get_one, -) +from core.model import Admin, AdminRole, create, get_one from .test_controller import SettingsControllerTest -class TestIndividualAdmins(SettingsControllerTest): +class TestIndividualAdmins(SettingsControllerTest): def test_individual_admins_get(self): for admin in self._db.query(Admin): self._db.delete(admin) @@ -41,79 +37,280 @@ def test_individual_admins_get(self): with self.request_context_with_admin("/", admin=admin1): # A system admin can see all other admins' roles. - response = self.manager.admin_individual_admin_settings_controller.process_get() + response = ( + self.manager.admin_individual_admin_settings_controller.process_get() + ) admins = response.get("individualAdmins") - assert (sorted([{"email": "admin1@nypl.org", "roles": [{ "role": AdminRole.SYSTEM_ADMIN }]}, - {"email": "admin2@nypl.org", "roles": [{ "role": AdminRole.LIBRARY_MANAGER, "library": self._default_library.short_name }, { "role": AdminRole.SITEWIDE_LIBRARIAN }]}, - {"email": "admin3@nypl.org", "roles": [{ "role": AdminRole.LIBRARIAN, "library": self._default_library.short_name }]}, - {"email": "admin4@l2.org", "roles": [{ "role": AdminRole.LIBRARY_MANAGER, "library": library2.short_name }]}, - {"email": "admin5@l2.org", "roles": [{ "role": AdminRole.LIBRARIAN, "library": library2.short_name }]}], key=lambda x:x["email"]) == - sorted(admins, key=lambda x:x["email"])) + assert ( + sorted( + [ + { + "email": "admin1@nypl.org", + "roles": [{"role": AdminRole.SYSTEM_ADMIN}], + }, + { + "email": "admin2@nypl.org", + "roles": [ + { + "role": AdminRole.LIBRARY_MANAGER, + "library": self._default_library.short_name, + }, + {"role": AdminRole.SITEWIDE_LIBRARIAN}, + ], + }, + { + "email": "admin3@nypl.org", + "roles": [ + { + "role": AdminRole.LIBRARIAN, + "library": self._default_library.short_name, + } + ], + }, + { + "email": "admin4@l2.org", + "roles": [ + { + "role": AdminRole.LIBRARY_MANAGER, + "library": library2.short_name, + } + ], + }, + { + "email": "admin5@l2.org", + "roles": [ + { + "role": AdminRole.LIBRARIAN, + "library": library2.short_name, + } + ], + }, + ], + key=lambda x: x["email"], + ) + == sorted(admins, key=lambda x: x["email"]) + ) with self.request_context_with_admin("/", admin=admin2): # A sitewide librarian or library manager can also see all admins' roles. - response = self.manager.admin_individual_admin_settings_controller.process_get() + response = ( + self.manager.admin_individual_admin_settings_controller.process_get() + ) admins = response.get("individualAdmins") - assert (sorted([{"email": "admin1@nypl.org", "roles": [{ "role": AdminRole.SYSTEM_ADMIN }]}, - {"email": "admin2@nypl.org", "roles": [{ "role": AdminRole.LIBRARY_MANAGER, "library": self._default_library.short_name }, { "role": AdminRole.SITEWIDE_LIBRARIAN }]}, - {"email": "admin3@nypl.org", "roles": [{ "role": AdminRole.LIBRARIAN, "library": self._default_library.short_name }]}, - {"email": "admin4@l2.org", "roles": [{ "role": AdminRole.LIBRARY_MANAGER, "library": library2.short_name }]}, - {"email": "admin5@l2.org", "roles": [{ "role": AdminRole.LIBRARIAN, "library": library2.short_name }]}], key=lambda x:x["email"]) == - sorted(admins, key=lambda x:x["email"])) + assert ( + sorted( + [ + { + "email": "admin1@nypl.org", + "roles": [{"role": AdminRole.SYSTEM_ADMIN}], + }, + { + "email": "admin2@nypl.org", + "roles": [ + { + "role": AdminRole.LIBRARY_MANAGER, + "library": self._default_library.short_name, + }, + {"role": AdminRole.SITEWIDE_LIBRARIAN}, + ], + }, + { + "email": "admin3@nypl.org", + "roles": [ + { + "role": AdminRole.LIBRARIAN, + "library": self._default_library.short_name, + } + ], + }, + { + "email": "admin4@l2.org", + "roles": [ + { + "role": AdminRole.LIBRARY_MANAGER, + "library": library2.short_name, + } + ], + }, + { + "email": "admin5@l2.org", + "roles": [ + { + "role": AdminRole.LIBRARIAN, + "library": library2.short_name, + } + ], + }, + ], + key=lambda x: x["email"], + ) + == sorted(admins, key=lambda x: x["email"]) + ) with self.request_context_with_admin("/", admin=admin3): # A librarian or library manager of a specific library can see all admins, but only # roles that affect their libraries. - response = self.manager.admin_individual_admin_settings_controller.process_get() + response = ( + self.manager.admin_individual_admin_settings_controller.process_get() + ) admins = response.get("individualAdmins") - assert (sorted([{"email": "admin1@nypl.org", "roles": [{ "role": AdminRole.SYSTEM_ADMIN }]}, - {"email": "admin2@nypl.org", "roles": [{ "role": AdminRole.LIBRARY_MANAGER, "library": self._default_library.short_name }, { "role": AdminRole.SITEWIDE_LIBRARIAN }]}, - {"email": "admin3@nypl.org", "roles": [{ "role": AdminRole.LIBRARIAN, "library": self._default_library.short_name }]}, + assert ( + sorted( + [ + { + "email": "admin1@nypl.org", + "roles": [{"role": AdminRole.SYSTEM_ADMIN}], + }, + { + "email": "admin2@nypl.org", + "roles": [ + { + "role": AdminRole.LIBRARY_MANAGER, + "library": self._default_library.short_name, + }, + {"role": AdminRole.SITEWIDE_LIBRARIAN}, + ], + }, + { + "email": "admin3@nypl.org", + "roles": [ + { + "role": AdminRole.LIBRARIAN, + "library": self._default_library.short_name, + } + ], + }, {"email": "admin4@l2.org", "roles": []}, - {"email": "admin5@l2.org", "roles": []}], key=lambda x:x["email"]) == - sorted(admins, key=lambda x:x["email"])) + {"email": "admin5@l2.org", "roles": []}, + ], + key=lambda x: x["email"], + ) + == sorted(admins, key=lambda x: x["email"]) + ) with self.request_context_with_admin("/", admin=admin4): - response = self.manager.admin_individual_admin_settings_controller.process_get() + response = ( + self.manager.admin_individual_admin_settings_controller.process_get() + ) admins = response.get("individualAdmins") - assert (sorted([{"email": "admin1@nypl.org", "roles": [{ "role": AdminRole.SYSTEM_ADMIN }]}, - {"email": "admin2@nypl.org", "roles": [{ "role": AdminRole.SITEWIDE_LIBRARIAN }]}, + assert ( + sorted( + [ + { + "email": "admin1@nypl.org", + "roles": [{"role": AdminRole.SYSTEM_ADMIN}], + }, + { + "email": "admin2@nypl.org", + "roles": [{"role": AdminRole.SITEWIDE_LIBRARIAN}], + }, {"email": "admin3@nypl.org", "roles": []}, - {"email": "admin4@l2.org", "roles": [{ "role": AdminRole.LIBRARY_MANAGER, "library": library2.short_name }]}, - {"email": "admin5@l2.org", "roles": [{ "role": AdminRole.LIBRARIAN, "library": library2.short_name }]}], key=lambda x:x["email"]) == - sorted(admins, key=lambda x:x["email"])) + { + "email": "admin4@l2.org", + "roles": [ + { + "role": AdminRole.LIBRARY_MANAGER, + "library": library2.short_name, + } + ], + }, + { + "email": "admin5@l2.org", + "roles": [ + { + "role": AdminRole.LIBRARIAN, + "library": library2.short_name, + } + ], + }, + ], + key=lambda x: x["email"], + ) + == sorted(admins, key=lambda x: x["email"]) + ) with self.request_context_with_admin("/", admin=admin5): - response = self.manager.admin_individual_admin_settings_controller.process_get() + response = ( + self.manager.admin_individual_admin_settings_controller.process_get() + ) admins = response.get("individualAdmins") - assert (sorted([{"email": "admin1@nypl.org", "roles": [{ "role": AdminRole.SYSTEM_ADMIN }]}, - {"email": "admin2@nypl.org", "roles": [{ "role": AdminRole.SITEWIDE_LIBRARIAN }]}, + assert ( + sorted( + [ + { + "email": "admin1@nypl.org", + "roles": [{"role": AdminRole.SYSTEM_ADMIN}], + }, + { + "email": "admin2@nypl.org", + "roles": [{"role": AdminRole.SITEWIDE_LIBRARIAN}], + }, {"email": "admin3@nypl.org", "roles": []}, - {"email": "admin4@l2.org", "roles": [{ "role": AdminRole.LIBRARY_MANAGER, "library": library2.short_name }]}, - {"email": "admin5@l2.org", "roles": [{ "role": AdminRole.LIBRARIAN, "library": library2.short_name }]}], key=lambda x:x["email"]) == - sorted(admins, key=lambda x:x["email"])) + { + "email": "admin4@l2.org", + "roles": [ + { + "role": AdminRole.LIBRARY_MANAGER, + "library": library2.short_name, + } + ], + }, + { + "email": "admin5@l2.org", + "roles": [ + { + "role": AdminRole.LIBRARIAN, + "library": library2.short_name, + } + ], + }, + ], + key=lambda x: x["email"], + ) + == sorted(admins, key=lambda x: x["email"]) + ) def test_individual_admins_post_errors(self): with self.request_context_with_admin("/", method="POST"): flask.request.form = MultiDict([]) - response = self.manager.admin_individual_admin_settings_controller.process_post() + response = ( + self.manager.admin_individual_admin_settings_controller.process_post() + ) assert response.uri == INCOMPLETE_CONFIGURATION.uri with self.request_context_with_admin("/", method="POST"): - flask.request.form = MultiDict([ - ("email", "test@library.org"), - ("roles", json.dumps([{ "role": AdminRole.LIBRARIAN, "library": "notalibrary" }])), - ]) - response = self.manager.admin_individual_admin_settings_controller.process_post() + flask.request.form = MultiDict( + [ + ("email", "test@library.org"), + ( + "roles", + json.dumps( + [{"role": AdminRole.LIBRARIAN, "library": "notalibrary"}] + ), + ), + ] + ) + response = ( + self.manager.admin_individual_admin_settings_controller.process_post() + ) assert response.uri == LIBRARY_NOT_FOUND.uri library = self._library() with self.request_context_with_admin("/", method="POST"): - flask.request.form = MultiDict([ - ("email", "test@library.org"), - ("roles", json.dumps([{ "role": "notarole", "library": library.short_name }])), - ]) - response = self.manager.admin_individual_admin_settings_controller.process_post() + flask.request.form = MultiDict( + [ + ("email", "test@library.org"), + ( + "roles", + json.dumps( + [{"role": "notarole", "library": library.short_name}] + ), + ), + ] + ) + response = ( + self.manager.admin_individual_admin_settings_controller.process_post() + ) assert response.uri == UNKNOWN_ROLE.uri def test_individual_admins_post_permissions(self): @@ -121,32 +318,48 @@ def test_individual_admins_post_permissions(self): l2 = self._library() system, ignore = create(self._db, Admin, email="system@example.com") system.add_role(AdminRole.SYSTEM_ADMIN) - sitewide_manager, ignore = create(self._db, Admin, email="sitewide_manager@example.com") + sitewide_manager, ignore = create( + self._db, Admin, email="sitewide_manager@example.com" + ) sitewide_manager.add_role(AdminRole.SITEWIDE_LIBRARY_MANAGER) - sitewide_librarian, ignore = create(self._db, Admin, email="sitewide_librarian@example.com") + sitewide_librarian, ignore = create( + self._db, Admin, email="sitewide_librarian@example.com" + ) sitewide_librarian.add_role(AdminRole.SITEWIDE_LIBRARIAN) - manager1, ignore = create(self._db, Admin, email="library_manager_l1@example.com") + manager1, ignore = create( + self._db, Admin, email="library_manager_l1@example.com" + ) manager1.add_role(AdminRole.LIBRARY_MANAGER, l1) librarian1, ignore = create(self._db, Admin, email="librarian_l1@example.com") librarian1.add_role(AdminRole.LIBRARIAN, l1) l2 = self._library() - manager2, ignore = create(self._db, Admin, email="library_manager_l2@example.com") + manager2, ignore = create( + self._db, Admin, email="library_manager_l2@example.com" + ) manager2.add_role(AdminRole.LIBRARY_MANAGER, l2) librarian2, ignore = create(self._db, Admin, email="librarian_l2@example.com") librarian2.add_role(AdminRole.LIBRARIAN, l2) - def test_changing_roles(admin_making_request, target_admin, roles=None, allowed=False): - with self.request_context_with_admin("/", method="POST", admin=admin_making_request): - flask.request.form = MultiDict([ - ("email", target_admin.email), - ("roles", json.dumps(roles or [])), - ]) + def test_changing_roles( + admin_making_request, target_admin, roles=None, allowed=False + ): + with self.request_context_with_admin( + "/", method="POST", admin=admin_making_request + ): + flask.request.form = MultiDict( + [ + ("email", target_admin.email), + ("roles", json.dumps(roles or [])), + ] + ) if allowed: self.manager.admin_individual_admin_settings_controller.process_post() self._db.rollback() else: - pytest.raises(AdminNotAuthorized, - self.manager.admin_individual_admin_settings_controller.process_post) + pytest.raises( + AdminNotAuthorized, + self.manager.admin_individual_admin_settings_controller.process_post, + ) # Various types of user trying to change a system admin's roles test_changing_roles(system, system, allowed=True) @@ -176,36 +389,60 @@ def test_changing_roles(admin_making_request, target_admin, roles=None, allowed= test_changing_roles(librarian2, sitewide_librarian) test_changing_roles(manager1, manager1, allowed=True) - test_changing_roles(manager1, sitewide_librarian, - roles=[{ "role": AdminRole.SITEWIDE_LIBRARIAN }, - { "role": AdminRole.LIBRARY_MANAGER, "library": l1.short_name }], - allowed=True) + test_changing_roles( + manager1, + sitewide_librarian, + roles=[ + {"role": AdminRole.SITEWIDE_LIBRARIAN}, + {"role": AdminRole.LIBRARY_MANAGER, "library": l1.short_name}, + ], + allowed=True, + ) test_changing_roles(manager1, librarian1, allowed=True) - test_changing_roles(manager2, librarian2, - roles=[{ "role": AdminRole.LIBRARIAN, "library": l1.short_name }]) - test_changing_roles(manager2, librarian1, - roles=[{ "role": AdminRole.LIBRARY_MANAGER, "library": l1.short_name }]) + test_changing_roles( + manager2, + librarian2, + roles=[{"role": AdminRole.LIBRARIAN, "library": l1.short_name}], + ) + test_changing_roles( + manager2, + librarian1, + roles=[{"role": AdminRole.LIBRARY_MANAGER, "library": l1.short_name}], + ) test_changing_roles(sitewide_librarian, librarian1) - test_changing_roles(sitewide_manager, sitewide_manager, - roles=[{ "role": AdminRole.SYSTEM_ADMIN }]) - test_changing_roles(sitewide_librarian, manager1, - roles=[{ "role": AdminRole.SITEWIDE_LIBRARY_MANAGER }]) + test_changing_roles( + sitewide_manager, sitewide_manager, roles=[{"role": AdminRole.SYSTEM_ADMIN}] + ) + test_changing_roles( + sitewide_librarian, + manager1, + roles=[{"role": AdminRole.SITEWIDE_LIBRARY_MANAGER}], + ) def test_changing_password(admin_making_request, target_admin, allowed=False): - with self.request_context_with_admin("/", method="POST", admin=admin_making_request): - flask.request.form = MultiDict([ - ("email", target_admin.email), - ("password", "new password"), - ("roles", json.dumps([role.to_dict() for role in target_admin.roles])), - ]) + with self.request_context_with_admin( + "/", method="POST", admin=admin_making_request + ): + flask.request.form = MultiDict( + [ + ("email", target_admin.email), + ("password", "new password"), + ( + "roles", + json.dumps([role.to_dict() for role in target_admin.roles]), + ), + ] + ) if allowed: self.manager.admin_individual_admin_settings_controller.process_post() self._db.rollback() else: - pytest.raises(AdminNotAuthorized, - self.manager.admin_individual_admin_settings_controller.process_post) + pytest.raises( + AdminNotAuthorized, + self.manager.admin_individual_admin_settings_controller.process_post, + ) # Various types of user trying to change a system admin's password test_changing_password(system, system, allowed=True) @@ -268,12 +505,26 @@ def test_changing_password(admin_making_request, target_admin, allowed=False): def test_individual_admins_post_create(self): with self.request_context_with_admin("/", method="POST"): - flask.request.form = MultiDict([ - ("email", "admin@nypl.org"), - ("password", "pass"), - ("roles", json.dumps([{ "role": AdminRole.LIBRARY_MANAGER, "library": self._default_library.short_name }])), - ]) - response = self.manager.admin_individual_admin_settings_controller.process_post() + flask.request.form = MultiDict( + [ + ("email", "admin@nypl.org"), + ("password", "pass"), + ( + "roles", + json.dumps( + [ + { + "role": AdminRole.LIBRARY_MANAGER, + "library": self._default_library.short_name, + } + ] + ), + ), + ] + ) + response = ( + self.manager.admin_individual_admin_settings_controller.process_post() + ) assert response.status_code == 201 # The admin was created. @@ -288,12 +539,26 @@ def test_individual_admins_post_create(self): # The new admin is a library manager, so they can create librarians. with self.request_context_with_admin("/", method="POST", admin=admin_match): - flask.request.form = MultiDict([ - ("email", "admin2@nypl.org"), - ("password", "pass"), - ("roles", json.dumps([{ "role": AdminRole.LIBRARIAN, "library": self._default_library.short_name }])), - ]) - response = self.manager.admin_individual_admin_settings_controller.process_post() + flask.request.form = MultiDict( + [ + ("email", "admin2@nypl.org"), + ("password", "pass"), + ( + "roles", + json.dumps( + [ + { + "role": AdminRole.LIBRARIAN, + "library": self._default_library.short_name, + } + ] + ), + ), + ] + ) + response = ( + self.manager.admin_individual_admin_settings_controller.process_post() + ) assert response.status_code == 201 admin_match = Admin.authenticate(self._db, "admin2@nypl.org", "pass") @@ -308,19 +573,35 @@ def test_individual_admins_post_create(self): def test_individual_admins_post_edit(self): # An admin exists. admin, ignore = create( - self._db, Admin, email="admin@nypl.org", + self._db, + Admin, + email="admin@nypl.org", ) admin.password = "password" admin.add_role(AdminRole.SYSTEM_ADMIN) with self.request_context_with_admin("/", method="POST"): - flask.request.form = MultiDict([ - ("email", "admin@nypl.org"), - ("password", "new password"), - ("roles", json.dumps([{"role": AdminRole.SITEWIDE_LIBRARIAN}, - {"role": AdminRole.LIBRARY_MANAGER, "library": self._default_library.short_name}])), - ]) - response = self.manager.admin_individual_admin_settings_controller.process_post() + flask.request.form = MultiDict( + [ + ("email", "admin@nypl.org"), + ("password", "new password"), + ( + "roles", + json.dumps( + [ + {"role": AdminRole.SITEWIDE_LIBRARIAN}, + { + "role": AdminRole.LIBRARY_MANAGER, + "library": self._default_library.short_name, + }, + ] + ), + ), + ] + ) + response = ( + self.manager.admin_individual_admin_settings_controller.process_post() + ) assert response.status_code == 200 assert admin.email == response.get_data(as_text=True) @@ -329,7 +610,9 @@ def test_individual_admins_post_edit(self): old_password_match = Admin.authenticate(self._db, "admin@nypl.org", "password") assert None == old_password_match - new_password_match = Admin.authenticate(self._db, "admin@nypl.org", "new password") + new_password_match = Admin.authenticate( + self._db, "admin@nypl.org", "new password" + ) assert admin == new_password_match # The roles were changed. @@ -341,34 +624,45 @@ def test_individual_admins_post_edit(self): assert self._default_library == manager.library def test_individual_admin_delete(self): - librarian, ignore = create( - self._db, Admin, email=self._str) + librarian, ignore = create(self._db, Admin, email=self._str) librarian.password = "password" librarian.add_role(AdminRole.LIBRARIAN, self._default_library) - sitewide_manager, ignore = create( - self._db, Admin, email=self._str) + sitewide_manager, ignore = create(self._db, Admin, email=self._str) sitewide_manager.add_role(AdminRole.SITEWIDE_LIBRARY_MANAGER) - system_admin, ignore = create( - self._db, Admin, email=self._str) + system_admin, ignore = create(self._db, Admin, email=self._str) system_admin.add_role(AdminRole.SYSTEM_ADMIN) with self.request_context_with_admin("/", method="DELETE", admin=librarian): - pytest.raises(AdminNotAuthorized, - self.manager.admin_individual_admin_settings_controller.process_delete, - librarian.email) - - with self.request_context_with_admin("/", method="DELETE", admin=sitewide_manager): - response = self.manager.admin_individual_admin_settings_controller.process_delete(librarian.email) + pytest.raises( + AdminNotAuthorized, + self.manager.admin_individual_admin_settings_controller.process_delete, + librarian.email, + ) + + with self.request_context_with_admin( + "/", method="DELETE", admin=sitewide_manager + ): + response = ( + self.manager.admin_individual_admin_settings_controller.process_delete( + librarian.email + ) + ) assert response.status_code == 200 - pytest.raises(AdminNotAuthorized, - self.manager.admin_individual_admin_settings_controller.process_delete, - system_admin.email) + pytest.raises( + AdminNotAuthorized, + self.manager.admin_individual_admin_settings_controller.process_delete, + system_admin.email, + ) with self.request_context_with_admin("/", method="DELETE", admin=system_admin): - response = self.manager.admin_individual_admin_settings_controller.process_delete(system_admin.email) + response = ( + self.manager.admin_individual_admin_settings_controller.process_delete( + system_admin.email + ) + ) assert response.status_code == 200 admin = get_one(self._db, Admin, id=librarian.id) @@ -383,35 +677,58 @@ def test_individual_admins_post_create_on_setup(self): # Creating an admin that's not a system admin will fail. with self.app.test_request_context("/", method="POST"): - flask.request.form = MultiDict([ - ("email", "first_admin@nypl.org"), - ("password", "pass"), - ("roles", json.dumps([{ "role": AdminRole.LIBRARY_MANAGER, "library": self._default_library.short_name }])), - ]) + flask.request.form = MultiDict( + [ + ("email", "first_admin@nypl.org"), + ("password", "pass"), + ( + "roles", + json.dumps( + [ + { + "role": AdminRole.LIBRARY_MANAGER, + "library": self._default_library.short_name, + } + ] + ), + ), + ] + ) flask.request.files = {} - pytest.raises(AdminNotAuthorized, self.manager.admin_individual_admin_settings_controller.process_post) + pytest.raises( + AdminNotAuthorized, + self.manager.admin_individual_admin_settings_controller.process_post, + ) self._db.rollback() # The password is required. with self.app.test_request_context("/", method="POST"): - flask.request.form = MultiDict([ - ("email", "first_admin@nypl.org"), - ("roles", json.dumps([{ "role": AdminRole.SYSTEM_ADMIN }])), - ]) + flask.request.form = MultiDict( + [ + ("email", "first_admin@nypl.org"), + ("roles", json.dumps([{"role": AdminRole.SYSTEM_ADMIN}])), + ] + ) flask.request.files = {} - response = self.manager.admin_individual_admin_settings_controller.process_post() + response = ( + self.manager.admin_individual_admin_settings_controller.process_post() + ) assert 400 == response.status_code assert response.uri == INCOMPLETE_CONFIGURATION.uri # Creating a system admin with a password works. with self.app.test_request_context("/", method="POST"): - flask.request.form = MultiDict([ - ("email", "first_admin@nypl.org"), - ("password", "pass"), - ("roles", json.dumps([{ "role": AdminRole.SYSTEM_ADMIN }])), - ]) + flask.request.form = MultiDict( + [ + ("email", "first_admin@nypl.org"), + ("password", "pass"), + ("roles", json.dumps([{"role": AdminRole.SYSTEM_ADMIN}])), + ] + ) flask.request.files = {} - response = self.manager.admin_individual_admin_settings_controller.process_post() + response = ( + self.manager.admin_individual_admin_settings_controller.process_post() + ) assert 201 == response.status_code # The admin was created. diff --git a/tests/admin/controller/test_library.py b/tests/admin/controller/test_library.py index 739027ddf1..967a047513 100644 --- a/tests/admin/controller/test_library.py +++ b/tests/admin/controller/test_library.py @@ -1,40 +1,37 @@ import base64 import datetime -from io import BytesIO import json +from io import BytesIO import flask -from PIL import Image import pytest +from PIL import Image from werkzeug.datastructures import MultiDict -from .test_controller import SettingsControllerTest from api.admin.announcement_list_validator import AnnouncementListValidator from api.admin.controller.library_settings import LibrarySettingsController from api.admin.exceptions import * from api.admin.geographic_validator import GeographicValidator -from api.announcements import ( - Announcements, - Announcement, -) +from api.announcements import Announcement, Announcements from api.config import Configuration from api.testing import AnnouncementTest from core.facets import FacetConstants from core.model import ( AdminRole, ConfigurationSetting, + Library, get_one, get_one_or_create, - Library, ) from core.util.problem_detail import ProblemDetail +from .test_controller import SettingsControllerTest -class TestLibrarySettings(SettingsControllerTest, AnnouncementTest): +class TestLibrarySettings(SettingsControllerTest, AnnouncementTest): @pytest.fixture() def logo_properties(self): - image_data_raw = b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x00\x01\x00\x00\x00\x01\x01\x03\x00\x00\x00%\xdbV\xca\x00\x00\x00\x06PLTE\xffM\x00\x01\x01\x01\x8e\x1e\xe5\x1b\x00\x00\x00\x01tRNS\xcc\xd24V\xfd\x00\x00\x00\nIDATx\x9cc`\x00\x00\x00\x02\x00\x01H\xaf\xa4q\x00\x00\x00\x00IEND\xaeB`\x82' + image_data_raw = b"\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x00\x01\x00\x00\x00\x01\x01\x03\x00\x00\x00%\xdbV\xca\x00\x00\x00\x06PLTE\xffM\x00\x01\x01\x01\x8e\x1e\xe5\x1b\x00\x00\x00\x01tRNS\xcc\xd24V\xfd\x00\x00\x00\nIDATx\x9cc`\x00\x00\x00\x02\x00\x01H\xaf\xa4q\x00\x00\x00\x00IEND\xaeB`\x82" image_data_b64_bytes = base64.b64encode(image_data_raw) image_data_b64_unicode = image_data_b64_bytes.decode("utf-8") data_url = "data:image/png;base64," + image_data_b64_unicode @@ -55,7 +52,7 @@ def library_form(self, library, fields={}): "short_name": library.short_name, Configuration.WEBSITE_URL: "https://library.library/", Configuration.HELP_EMAIL: "help@example.com", - Configuration.DEFAULT_NOTIFICATION_EMAIL_ADDRESS: "email@example.com" + Configuration.DEFAULT_NOTIFICATION_EMAIL_ADDRESS: "email@example.com", } defaults.update(fields) form = MultiDict(list(defaults.items())) @@ -88,8 +85,14 @@ def test_libraries_get_with_geographic_info(self): with self.request_context_with_admin("/"): response = self.manager.admin_library_settings_controller.process_get() library_settings = response.get("libraries")[0].get("settings") - assert library_settings.get("focus_area") == {'CA': [{'N3L': 'Paris, Ontario'}], 'US': [{'11235': 'Brooklyn, NY'}]} - assert library_settings.get("service_area") == {'CA': [{'J2S': 'Saint-Hyacinthe Southwest, Quebec'}], 'US': [{'31415': 'Savannah, GA'}]} + assert library_settings.get("focus_area") == { + "CA": [{"N3L": "Paris, Ontario"}], + "US": [{"11235": "Brooklyn, NY"}], + } + assert library_settings.get("service_area") == { + "CA": [{"J2S": "Saint-Hyacinthe Southwest, Quebec"}], + "US": [{"31415": "Savannah, GA"}], + } def test_libraries_get_with_announcements(self): # Delete any existing library created by the controller test setup. @@ -110,16 +113,22 @@ def test_libraries_get_with_announcements(self): # We find out about the library's announcements. announcements = library_settings.get(Announcements.SETTING_NAME) - assert ( - [self.active['id'], self.expired['id'], self.forthcoming['id']] == - [x.get('id') for x in json.loads(announcements)]) + assert [self.active["id"], self.expired["id"], self.forthcoming["id"]] == [ + x.get("id") for x in json.loads(announcements) + ] # The objects found in `library_settings` aren't exactly # the same as what is stored in the database: string dates # can be parsed into datetime.date objects. for i in json.loads(announcements): - assert isinstance(datetime.datetime.strptime(i.get('start'), "%Y-%m-%d"), datetime.date) - assert isinstance(datetime.datetime.strptime(i.get('finish'), "%Y-%m-%d"), datetime.date) + assert isinstance( + datetime.datetime.strptime(i.get("start"), "%Y-%m-%d"), + datetime.date, + ) + assert isinstance( + datetime.datetime.strptime(i.get("finish"), "%Y-%m-%d"), + datetime.date, + ) def test_libraries_get_with_multiple_libraries(self): # Delete any existing library created by the controller test setup. @@ -133,10 +142,14 @@ def test_libraries_get_with_multiple_libraries(self): # L2 has some additional library-wide settings. ConfigurationSetting.for_library(Configuration.FEATURED_LANE_SIZE, l2).value = 5 ConfigurationSetting.for_library( - Configuration.DEFAULT_FACET_KEY_PREFIX + FacetConstants.ORDER_FACET_GROUP_NAME, l2 + Configuration.DEFAULT_FACET_KEY_PREFIX + + FacetConstants.ORDER_FACET_GROUP_NAME, + l2, ).value = FacetConstants.ORDER_TITLE ConfigurationSetting.for_library( - Configuration.ENABLED_FACETS_KEY_PREFIX + FacetConstants.ORDER_FACET_GROUP_NAME, l2 + Configuration.ENABLED_FACETS_KEY_PREFIX + + FacetConstants.ORDER_FACET_GROUP_NAME, + l2, ).value = json.dumps([FacetConstants.ORDER_TITLE, FacetConstants.ORDER_AUTHOR]) ConfigurationSetting.for_library( Configuration.LARGE_COLLECTION_LANGUAGES, l2 @@ -164,17 +177,26 @@ def test_libraries_get_with_multiple_libraries(self): assert 4 == len(libraries[1].get("settings").keys()) settings = libraries[1].get("settings") assert "5" == settings.get(Configuration.FEATURED_LANE_SIZE) - assert (FacetConstants.ORDER_TITLE == - settings.get(Configuration.DEFAULT_FACET_KEY_PREFIX + FacetConstants.ORDER_FACET_GROUP_NAME)) - assert ([FacetConstants.ORDER_TITLE, FacetConstants.ORDER_AUTHOR] == - settings.get(Configuration.ENABLED_FACETS_KEY_PREFIX + FacetConstants.ORDER_FACET_GROUP_NAME)) + assert FacetConstants.ORDER_TITLE == settings.get( + Configuration.DEFAULT_FACET_KEY_PREFIX + + FacetConstants.ORDER_FACET_GROUP_NAME + ) + assert [ + FacetConstants.ORDER_TITLE, + FacetConstants.ORDER_AUTHOR, + ] == settings.get( + Configuration.ENABLED_FACETS_KEY_PREFIX + + FacetConstants.ORDER_FACET_GROUP_NAME + ) assert ["French"] == settings.get(Configuration.LARGE_COLLECTION_LANGUAGES) def test_libraries_post_errors(self): with self.request_context_with_admin("/", method="POST"): - flask.request.form = MultiDict([ - ("name", "Brooklyn Public Library"), - ]) + flask.request.form = MultiDict( + [ + ("name", "Brooklyn Public Library"), + ] + ) response = self.manager.admin_library_settings_controller.process_post() assert response == MISSING_LIBRARY_SHORT_NAME @@ -185,31 +207,35 @@ def test_libraries_post_errors(self): assert response.uri == LIBRARY_NOT_FOUND.uri with self.request_context_with_admin("/", method="POST"): - flask.request.form = MultiDict([ - ("name", "Brooklyn Public Library"), - ("short_name", library.short_name), - ]) + flask.request.form = MultiDict( + [ + ("name", "Brooklyn Public Library"), + ("short_name", library.short_name), + ] + ) response = self.manager.admin_library_settings_controller.process_post() assert response == LIBRARY_SHORT_NAME_ALREADY_IN_USE - bpl, ignore = get_one_or_create( - self._db, Library, short_name="bpl" - ) + bpl, ignore = get_one_or_create(self._db, Library, short_name="bpl") with self.request_context_with_admin("/", method="POST"): - flask.request.form = MultiDict([ - ("uuid", bpl.uuid), - ("name", "Brooklyn Public Library"), - ("short_name", library.short_name), - ]) + flask.request.form = MultiDict( + [ + ("uuid", bpl.uuid), + ("name", "Brooklyn Public Library"), + ("short_name", library.short_name), + ] + ) response = self.manager.admin_library_settings_controller.process_post() assert response == LIBRARY_SHORT_NAME_ALREADY_IN_USE with self.request_context_with_admin("/", method="POST"): - flask.request.form = MultiDict([ - ("uuid", library.uuid), - ("name", "The New York Public Library"), - ("short_name", library.short_name), - ]) + flask.request.form = MultiDict( + [ + ("uuid", library.uuid), + ("name", "The New York Public Library"), + ("short_name", library.short_name), + ] + ) response = self.manager.admin_library_settings_controller.process_post() assert response.uri == INCOMPLETE_CONFIGURATION.uri @@ -217,8 +243,11 @@ def test_libraries_post_errors(self): # well on white. Here primary will, secondary should not. with self.request_context_with_admin("/", method="POST"): flask.request.form = self.library_form( - library, {Configuration.WEB_PRIMARY_COLOR: "#000000", - Configuration.WEB_SECONDARY_COLOR: "#e0e0e0"} + library, + { + Configuration.WEB_PRIMARY_COLOR: "#000000", + Configuration.WEB_SECONDARY_COLOR: "#e0e0e0", + }, ) response = self.manager.admin_library_settings_controller.process_post() assert response.uri == INVALID_CONFIGURATION_OPTION.uri @@ -229,109 +258,174 @@ def test_libraries_post_errors(self): # aren't the same length. library = self._library() with self.request_context_with_admin("/", method="POST"): - flask.request.form = MultiDict([ - ("uuid", library.uuid), - ("name", "The New York Public Library"), - ("short_name", library.short_name), - (Configuration.WEBSITE_URL, "https://library.library/"), - (Configuration.DEFAULT_NOTIFICATION_EMAIL_ADDRESS, "email@example.com"), - (Configuration.HELP_EMAIL, "help@example.com"), - (Configuration.WEB_HEADER_LINKS, "http://library.com/1"), - (Configuration.WEB_HEADER_LINKS, "http://library.com/2"), - (Configuration.WEB_HEADER_LABELS, "One"), - ]) + flask.request.form = MultiDict( + [ + ("uuid", library.uuid), + ("name", "The New York Public Library"), + ("short_name", library.short_name), + (Configuration.WEBSITE_URL, "https://library.library/"), + ( + Configuration.DEFAULT_NOTIFICATION_EMAIL_ADDRESS, + "email@example.com", + ), + (Configuration.HELP_EMAIL, "help@example.com"), + (Configuration.WEB_HEADER_LINKS, "http://library.com/1"), + (Configuration.WEB_HEADER_LINKS, "http://library.com/2"), + (Configuration.WEB_HEADER_LABELS, "One"), + ] + ) response = self.manager.admin_library_settings_controller.process_post() assert response.uri == INVALID_CONFIGURATION_OPTION.uri def test__data_url_for_image(self, logo_properties): """""" - image, expected_data_url = [logo_properties[key] for key in ( - "image", "data_url" - )] + image, expected_data_url = [ + logo_properties[key] for key in ("image", "data_url") + ] data_url = LibrarySettingsController._data_url_for_image(image) assert expected_data_url == data_url def test_libraries_post_create(self, logo_properties): - class TestFileUpload(BytesIO): - headers = { "Content-Type": "image/png" } + headers = {"Content-Type": "image/png"} # Pull needed properties from logo fixture - image_data, expected_logo_data_url, image = [logo_properties[key] for key in ( - "raw_bytes", "data_url", "image" - )] + image_data, expected_logo_data_url, image = [ + logo_properties[key] for key in ("raw_bytes", "data_url", "image") + ] # LibrarySettingsController scales down images that are too large, # so we fail here if our test fixture image is large enough to cause # a mismatch between the expected data URL and the one configured. assert max(*image.size) <= Configuration.LOGO_MAX_DIMENSION original_geographic_validate = GeographicValidator().validate_geographic_areas + class MockGeographicValidator(GeographicValidator): def __init__(self): self.was_called = False + def validate_geographic_areas(self, values, db): self.was_called = True return original_geographic_validate(values, db) - original_announcement_validate = AnnouncementListValidator().validate_announcements + original_announcement_validate = ( + AnnouncementListValidator().validate_announcements + ) + class MockAnnouncementListValidator(AnnouncementListValidator): def __init__(self): self.was_called = False + def validate_announcements(self, values): self.was_called = True return original_announcement_validate(values) with self.request_context_with_admin("/", method="POST"): - flask.request.form = MultiDict([ - ("name", "The New York Public Library"), - ("short_name", "nypl"), - ("library_description", "Short description of library"), - (Configuration.WEBSITE_URL, "https://library.library/"), - (Configuration.TINY_COLLECTION_LANGUAGES, ['ger']), - (Configuration.LIBRARY_SERVICE_AREA, ['06759', 'everywhere', 'MD', 'Boston, MA']), - (Configuration.LIBRARY_FOCUS_AREA, ['Manitoba', 'Broward County, FL', 'QC']), - (Announcements.SETTING_NAME, json.dumps([self.active, self.forthcoming])), - (Configuration.DEFAULT_NOTIFICATION_EMAIL_ADDRESS, "email@example.com"), - (Configuration.HELP_EMAIL, "help@example.com"), - (Configuration.FEATURED_LANE_SIZE, "5"), - (Configuration.DEFAULT_FACET_KEY_PREFIX + FacetConstants.ORDER_FACET_GROUP_NAME, - FacetConstants.ORDER_RANDOM), - (Configuration.ENABLED_FACETS_KEY_PREFIX + FacetConstants.ORDER_FACET_GROUP_NAME + "_" + FacetConstants.ORDER_TITLE, - ''), - (Configuration.ENABLED_FACETS_KEY_PREFIX + FacetConstants.ORDER_FACET_GROUP_NAME + "_" + FacetConstants.ORDER_RANDOM, - ''), - ]) - flask.request.files = MultiDict([ - (Configuration.LOGO, TestFileUpload(image_data)), - ]) + flask.request.form = MultiDict( + [ + ("name", "The New York Public Library"), + ("short_name", "nypl"), + ("library_description", "Short description of library"), + (Configuration.WEBSITE_URL, "https://library.library/"), + (Configuration.TINY_COLLECTION_LANGUAGES, ["ger"]), + ( + Configuration.LIBRARY_SERVICE_AREA, + ["06759", "everywhere", "MD", "Boston, MA"], + ), + ( + Configuration.LIBRARY_FOCUS_AREA, + ["Manitoba", "Broward County, FL", "QC"], + ), + ( + Announcements.SETTING_NAME, + json.dumps([self.active, self.forthcoming]), + ), + ( + Configuration.DEFAULT_NOTIFICATION_EMAIL_ADDRESS, + "email@example.com", + ), + (Configuration.HELP_EMAIL, "help@example.com"), + (Configuration.FEATURED_LANE_SIZE, "5"), + ( + Configuration.DEFAULT_FACET_KEY_PREFIX + + FacetConstants.ORDER_FACET_GROUP_NAME, + FacetConstants.ORDER_RANDOM, + ), + ( + Configuration.ENABLED_FACETS_KEY_PREFIX + + FacetConstants.ORDER_FACET_GROUP_NAME + + "_" + + FacetConstants.ORDER_TITLE, + "", + ), + ( + Configuration.ENABLED_FACETS_KEY_PREFIX + + FacetConstants.ORDER_FACET_GROUP_NAME + + "_" + + FacetConstants.ORDER_RANDOM, + "", + ), + ] + ) + flask.request.files = MultiDict( + [ + (Configuration.LOGO, TestFileUpload(image_data)), + ] + ) geographic_validator = MockGeographicValidator() announcement_validator = MockAnnouncementListValidator() validators = dict( geographic=geographic_validator, announcements=announcement_validator, ) - response = self.manager.admin_library_settings_controller.process_post(validators) + response = self.manager.admin_library_settings_controller.process_post( + validators + ) assert response.status_code == 201 library = get_one(self._db, Library, short_name="nypl") assert library.uuid == response.get_data(as_text=True) assert library.name == "The New York Public Library" assert library.short_name == "nypl" - assert "5" == ConfigurationSetting.for_library(Configuration.FEATURED_LANE_SIZE, library).value - assert (FacetConstants.ORDER_RANDOM == - ConfigurationSetting.for_library( - Configuration.DEFAULT_FACET_KEY_PREFIX + FacetConstants.ORDER_FACET_GROUP_NAME, - library).value) - assert (json.dumps([FacetConstants.ORDER_TITLE]) == - ConfigurationSetting.for_library( - Configuration.ENABLED_FACETS_KEY_PREFIX + FacetConstants.ORDER_FACET_GROUP_NAME, - library).value) - assert (expected_logo_data_url == ConfigurationSetting.for_library(Configuration.LOGO, library).value) + assert ( + "5" + == ConfigurationSetting.for_library( + Configuration.FEATURED_LANE_SIZE, library + ).value + ) + assert ( + FacetConstants.ORDER_RANDOM + == ConfigurationSetting.for_library( + Configuration.DEFAULT_FACET_KEY_PREFIX + + FacetConstants.ORDER_FACET_GROUP_NAME, + library, + ).value + ) + assert ( + json.dumps([FacetConstants.ORDER_TITLE]) + == ConfigurationSetting.for_library( + Configuration.ENABLED_FACETS_KEY_PREFIX + + FacetConstants.ORDER_FACET_GROUP_NAME, + library, + ).value + ) + assert ( + expected_logo_data_url + == ConfigurationSetting.for_library(Configuration.LOGO, library).value + ) assert geographic_validator.was_called == True - assert ('{"US": ["06759", "everywhere", "MD", "Boston, MA"], "CA": []}' == - ConfigurationSetting.for_library(Configuration.LIBRARY_SERVICE_AREA, library).value) - assert ('{"US": ["Broward County, FL"], "CA": ["Manitoba", "Quebec"]}' == - ConfigurationSetting.for_library(Configuration.LIBRARY_FOCUS_AREA, library).value) + assert ( + '{"US": ["06759", "everywhere", "MD", "Boston, MA"], "CA": []}' + == ConfigurationSetting.for_library( + Configuration.LIBRARY_SERVICE_AREA, library + ).value + ) + assert ( + '{"US": ["Broward County, FL"], "CA": ["Manitoba", "Quebec"]}' + == ConfigurationSetting.for_library( + Configuration.LIBRARY_FOCUS_AREA, library + ).value + ) # Announcements were validated. assert announcement_validator.was_called == True @@ -339,7 +433,9 @@ def validate_announcements(self, values): # The validated result was written to the database, such that we can # parse it as a list of Announcement objects. announcements = Announcements.for_library(library).announcements - assert [self.active['id'], self.forthcoming['id']] == [x.id for x in announcements] + assert [self.active["id"], self.forthcoming["id"]] == [ + x.id for x in announcements + ] assert all(isinstance(x, Announcement) for x in announcements) # When the library was created, default lanes were also created @@ -347,46 +443,68 @@ def validate_announcements(self, values): # collection (not a good choice for a real library), so only # two lanes were created: "Other Languages" and then "German" # underneath it. - [german, other_languages] = sorted( - library.lanes, key=lambda x: x.display_name - ) + [german, other_languages] = sorted(library.lanes, key=lambda x: x.display_name) assert None == other_languages.parent - assert ['ger'] == other_languages.languages + assert ["ger"] == other_languages.languages assert other_languages == german.parent - assert ['ger'] == german.languages + assert ["ger"] == german.languages def test_libraries_post_edit(self): # A library already exists. library = self._library("New York Public Library", "nypl") - ConfigurationSetting.for_library(Configuration.FEATURED_LANE_SIZE, library).value = 5 ConfigurationSetting.for_library( - Configuration.DEFAULT_FACET_KEY_PREFIX + FacetConstants.ORDER_FACET_GROUP_NAME, library + Configuration.FEATURED_LANE_SIZE, library + ).value = 5 + ConfigurationSetting.for_library( + Configuration.DEFAULT_FACET_KEY_PREFIX + + FacetConstants.ORDER_FACET_GROUP_NAME, + library, ).value = FacetConstants.ORDER_RANDOM ConfigurationSetting.for_library( - Configuration.ENABLED_FACETS_KEY_PREFIX + FacetConstants.ORDER_FACET_GROUP_NAME, library + Configuration.ENABLED_FACETS_KEY_PREFIX + + FacetConstants.ORDER_FACET_GROUP_NAME, + library, ).value = json.dumps([FacetConstants.ORDER_TITLE, FacetConstants.ORDER_RANDOM]) ConfigurationSetting.for_library( Configuration.LOGO, library ).value = "A tiny image" with self.request_context_with_admin("/", method="POST"): - flask.request.form = MultiDict([ - ("uuid", library.uuid), - ("name", "The New York Public Library"), - ("short_name", "nypl"), - (Configuration.FEATURED_LANE_SIZE, "20"), - (Configuration.MINIMUM_FEATURED_QUALITY, "0.9"), - (Configuration.WEBSITE_URL, "https://library.library/"), - (Configuration.DEFAULT_NOTIFICATION_EMAIL_ADDRESS, "email@example.com"), - (Configuration.HELP_EMAIL, "help@example.com"), - (Configuration.DEFAULT_FACET_KEY_PREFIX + FacetConstants.ORDER_FACET_GROUP_NAME, - FacetConstants.ORDER_AUTHOR), - (Configuration.ENABLED_FACETS_KEY_PREFIX + FacetConstants.ORDER_FACET_GROUP_NAME + "_" + FacetConstants.ORDER_AUTHOR, - ''), - (Configuration.ENABLED_FACETS_KEY_PREFIX + FacetConstants.ORDER_FACET_GROUP_NAME + "_" + FacetConstants.ORDER_RANDOM, - ''), - ]) + flask.request.form = MultiDict( + [ + ("uuid", library.uuid), + ("name", "The New York Public Library"), + ("short_name", "nypl"), + (Configuration.FEATURED_LANE_SIZE, "20"), + (Configuration.MINIMUM_FEATURED_QUALITY, "0.9"), + (Configuration.WEBSITE_URL, "https://library.library/"), + ( + Configuration.DEFAULT_NOTIFICATION_EMAIL_ADDRESS, + "email@example.com", + ), + (Configuration.HELP_EMAIL, "help@example.com"), + ( + Configuration.DEFAULT_FACET_KEY_PREFIX + + FacetConstants.ORDER_FACET_GROUP_NAME, + FacetConstants.ORDER_AUTHOR, + ), + ( + Configuration.ENABLED_FACETS_KEY_PREFIX + + FacetConstants.ORDER_FACET_GROUP_NAME + + "_" + + FacetConstants.ORDER_AUTHOR, + "", + ), + ( + Configuration.ENABLED_FACETS_KEY_PREFIX + + FacetConstants.ORDER_FACET_GROUP_NAME + + "_" + + FacetConstants.ORDER_RANDOM, + "", + ), + ] + ) flask.request.files = MultiDict([]) response = self.manager.admin_library_settings_controller.process_post() assert response.status_code == 200 @@ -400,31 +518,44 @@ def test_libraries_post_edit(self): # The library-wide settings were updated. def val(x): return ConfigurationSetting.for_library(x, library).value + assert "https://library.library/" == val(Configuration.WEBSITE_URL) - assert "email@example.com" == val(Configuration.DEFAULT_NOTIFICATION_EMAIL_ADDRESS) + assert "email@example.com" == val( + Configuration.DEFAULT_NOTIFICATION_EMAIL_ADDRESS + ) assert "help@example.com" == val(Configuration.HELP_EMAIL) assert "20" == val(Configuration.FEATURED_LANE_SIZE) assert "0.9" == val(Configuration.MINIMUM_FEATURED_QUALITY) - assert (FacetConstants.ORDER_AUTHOR == - val(Configuration.DEFAULT_FACET_KEY_PREFIX + FacetConstants.ORDER_FACET_GROUP_NAME)) - assert (json.dumps([FacetConstants.ORDER_AUTHOR]) == - val(Configuration.ENABLED_FACETS_KEY_PREFIX + FacetConstants.ORDER_FACET_GROUP_NAME)) + assert FacetConstants.ORDER_AUTHOR == val( + Configuration.DEFAULT_FACET_KEY_PREFIX + + FacetConstants.ORDER_FACET_GROUP_NAME + ) + assert json.dumps([FacetConstants.ORDER_AUTHOR]) == val( + Configuration.ENABLED_FACETS_KEY_PREFIX + + FacetConstants.ORDER_FACET_GROUP_NAME + ) # The library-wide logo was not updated and has been left alone. - assert ("A tiny image" == - ConfigurationSetting.for_library(Configuration.LOGO, library).value) + assert ( + "A tiny image" + == ConfigurationSetting.for_library(Configuration.LOGO, library).value + ) def test_library_delete(self): library = self._library() with self.request_context_with_admin("/", method="DELETE"): self.admin.remove_role(AdminRole.SYSTEM_ADMIN) - pytest.raises(AdminNotAuthorized, - self.manager.admin_library_settings_controller.process_delete, - library.uuid) + pytest.raises( + AdminNotAuthorized, + self.manager.admin_library_settings_controller.process_delete, + library.uuid, + ) self.admin.add_role(AdminRole.SYSTEM_ADMIN) - response = self.manager.admin_library_settings_controller.process_delete(library.uuid) + response = self.manager.admin_library_settings_controller.process_delete( + library.uuid + ) assert response.status_code == 200 library = get_one(self._db, Library, uuid=library.uuid) @@ -443,6 +574,7 @@ class MockValidator(object): def format_as_string(self, value): self.format_as_string_called_with = value return value + ", formatted for storage" + validator1 = MockValidator() validators = dict(format1=validator1) @@ -453,14 +585,16 @@ class MockController(LibrarySettingsController): def _validate_setting(self, library, setting, validator): self._validate_setting_calls.append((library, setting, validator)) if self.succeed: - return "validated %s" % setting['key'] + return "validated %s" % setting["key"] else: return INVALID_INPUT.detailed("invalid!") # Run library_configuration_settings in a situation where all validations succeed. controller = MockController(self.manager) library = self._default_library - result = controller.library_configuration_settings(library, validators, settings) + result = controller.library_configuration_settings( + library, validators, settings + ) # No problem detail was returned -- the 'request' can continue. assert None == result @@ -476,11 +610,14 @@ def _validate_setting(self, library, setting, validator): # The 'validated' value from the MockValidator was then formatted # for storage using the format() method. - assert "validated %s" % settings[0]['key'] == validator1.format_as_string_called_with + assert ( + "validated %s" % settings[0]["key"] + == validator1.format_as_string_called_with + ) # Each (validated and formatted) value was written to the # database. - setting1, setting2 = [library.setting(x['key']) for x in settings] + setting1, setting2 = [library.setting(x["key"]) for x in settings] assert "validated %s, formatted for storage" % setting1.key == setting1.value assert "validated %s" % setting2.key == setting2.value @@ -494,8 +631,9 @@ def _validate_setting(self, library, setting, validator): ) # _validate_setting was only called once. - assert ([(library, settings[0], validator1)] == - controller._validate_setting_calls) + assert [ + (library, settings[0], validator1) + ] == controller._validate_setting_calls # When it returned a ProblemDetail, that ProblemDetail # was propagated outwards. @@ -504,35 +642,36 @@ def _validate_setting(self, library, setting, validator): # No new values were written to the database. for x in settings: - assert None == library.setting(x['key']).value + assert None == library.setting(x["key"]).value def test__validate_setting(self): # Verify the rules for validating different kinds of settings, # one simulated setting at a time. library = self._default_library + class MockController(LibrarySettingsController): # Mock the functions that pull various values out of the # 'current request' or the 'database' so we don't need an # actual current request or actual database settings. def scalar_setting(self, setting): - return self.scalar_form_values.get(setting['key']) + return self.scalar_form_values.get(setting["key"]) def list_setting(self, setting, json_objects=False): - value = self.list_form_values.get(setting['key']) + value = self.list_form_values.get(setting["key"]) if json_objects: value = [json.loads(x) for x in value] return json.dumps(value) def image_setting(self, setting): - return self.image_form_values.get(setting['key']) + return self.image_form_values.get(setting["key"]) def current_value(self, setting, _library): # While we're here, make sure the right Library # object was passed in. assert _library == library - return self.current_values.get(setting['key']) + return self.current_values.get(setting["key"]) # Now insert mock data into the 'form submission' and # the 'database'. @@ -540,31 +679,28 @@ def current_value(self, setting, _library): # Simulate list values in a form submission. The geographic values # go in as normal strings; the announcements go in as strings that are # JSON-encoded data structures. - announcement_list = [{"content" : "announcement1"}, {"content": "announcement2"}] + announcement_list = [ + {"content": "announcement1"}, + {"content": "announcement2"}, + ] list_form_values = dict( geographic_setting=["geographic values"], - announcement_list=[ - json.dumps(x) for x in announcement_list - ], + announcement_list=[json.dumps(x) for x in announcement_list], language_codes=["English", "fr"], list_value=["a list"], ) # Simulate scalar values in a form submission. - scalar_form_values = dict( - string_value="a scalar value" - ) + scalar_form_values = dict(string_value="a scalar value") # Simulate uploaded images in a form submission. - image_form_values = dict( - image_setting="some image data" - ) + image_form_values = dict(image_setting="some image data") # Simulate values present in the database but not present # in the form submission. current_values = dict( - value_not_present_in_request = "a database value", - previously_uploaded_image = "an old image", + value_not_present_in_request="a database value", + previously_uploaded_image="an old image", ) # First test some simple cases: scalar values. @@ -576,12 +712,15 @@ def current_value(self, setting, _library): # But not for this setting: we end up going to the database # instead. - assert "a database value" == m(library, dict(key="value_not_present_in_request")) + assert "a database value" == m( + library, dict(key="value_not_present_in_request") + ) # And not for this setting either: there is no database value, # so we have to use the default associated with the setting configuration. - assert "a default value" == m(library, dict(key="some_other_value", - default="a default value")) + assert "a default value" == m( + library, dict(key="some_other_value", default="a default value") + ) # An uploaded image is (from the perspective of this method) also simple. @@ -589,7 +728,9 @@ def current_value(self, setting, _library): assert "some image data" == m(library, dict(key="image_setting", type="image")) # Here, no image was uploaded so we use the currently stored database value. - assert "an old image" == m(library, dict(key="previously_uploaded_image", type="image")) + assert "an old image" == m( + library, dict(key="previously_uploaded_image", type="image") + ) # There are some lists which are more complex, but a normal list is # simple: the return value is the JSON-encoded list. @@ -598,43 +739,47 @@ def current_value(self, setting, _library): # Now let's look at the more complex lists. # A list of language codes. - assert ( - json.dumps(["eng", "fre"]) == - m(library, dict(key="language_codes", format="language-code", type="list"))) + assert json.dumps(["eng", "fre"]) == m( + library, dict(key="language_codes", format="language-code", type="list") + ) # A list of geographic places class MockGeographicValidator(object): value = "validated value" + def validate_geographic_areas(self, value, _db): self.called_with = (value, _db) return self.value + validator = MockGeographicValidator() # The validator was consulted and its response was used as the # value. - assert ( - 'validated value' == - m(library, dict(key="geographic_setting", format="geographic"), validator)) + assert "validated value" == m( + library, dict(key="geographic_setting", format="geographic"), validator + ) assert (json.dumps(["geographic values"]), self._db) == validator.called_with # Just to be explicit, let's also test the case where the 'response' sent from the # validator is a ProblemDetail. validator.value = INVALID_INPUT - assert ( - INVALID_INPUT == - m(library, dict(key="geographic_setting", format="geographic"), validator)) + assert INVALID_INPUT == m( + library, dict(key="geographic_setting", format="geographic"), validator + ) # A list of announcements. class MockAnnouncementValidator(object): value = "validated value" + def validate_announcements(self, value): self.called_with = value return self.value + validator = MockAnnouncementValidator() - assert ( - 'validated value' == - m(library, dict(key="announcement_list", type="announcements"), validator)) + assert "validated value" == m( + library, dict(key="announcement_list", type="announcements"), validator + ) assert json.dumps(controller.announcement_list) == validator.called_with def test__format_validated_value(self): diff --git a/tests/admin/controller/test_library_registrations.py b/tests/admin/controller/test_library_registrations.py index 9709bcb5f3..be5c02b9ab 100644 --- a/tests/admin/controller/test_library_registrations.py +++ b/tests/admin/controller/test_library_registrations.py @@ -1,72 +1,81 @@ -import pytest +import json import flask -import json +import pytest from werkzeug.datastructures import MultiDict + from api.admin.exceptions import * -from api.registry import ( - RemoteRegistry, - Registration, -) +from api.registry import Registration, RemoteRegistry from core.model import ( AdminRole, ConfigurationSetting, - create, ExternalIntegration, Library, + create, ) -from core.testing import ( - DummyHTTPClient, - MockRequestsResponse, -) +from core.testing import DummyHTTPClient, MockRequestsResponse from core.util.http import HTTP + from .test_controller import SettingsControllerTest + class TestLibraryRegistration(SettingsControllerTest): """Test the process of registering a library with a RemoteRegistry.""" def test_discovery_service_library_registrations_get(self): # Here's a discovery service. discovery_service, ignore = create( - self._db, ExternalIntegration, + self._db, + ExternalIntegration, protocol=ExternalIntegration.OPDS_REGISTRATION, goal=ExternalIntegration.DISCOVERY_GOAL, ) # We'll be making a mock request to this URL later. - discovery_service.setting(ExternalIntegration.URL).value = ( - "http://service-url/" - ) + discovery_service.setting(ExternalIntegration.URL).value = "http://service-url/" # We successfully registered this library with the service. succeeded, ignore = create( - self._db, Library, name="Library 1", short_name="L1", + self._db, + Library, + name="Library 1", + short_name="L1", ) config = ConfigurationSetting.for_library_and_externalintegration config( - self._db, "library-registration-status", succeeded, - discovery_service + self._db, "library-registration-status", succeeded, discovery_service ).value = "success" # We tried to register this library with the service but were # unsuccessful. config( - self._db, "library-registration-stage", succeeded, - discovery_service + self._db, "library-registration-stage", succeeded, discovery_service ).value = "production" failed, ignore = create( - self._db, Library, name="Library 2", short_name="L2", + self._db, + Library, + name="Library 2", + short_name="L2", ) config( - self._db, "library-registration-status", failed, discovery_service, + self._db, + "library-registration-status", + failed, + discovery_service, ).value = "failure" config( - self._db, "library-registration-stage", failed, discovery_service, + self._db, + "library-registration-stage", + failed, + discovery_service, ).value = "testing" # We've never tried to register this library with the service. unregistered, ignore = create( - self._db, Library, name="Library 3", short_name="L3", + self._db, + Library, + name="Library 3", + short_name="L3", ) discovery_service.libraries = [succeeded, failed] @@ -83,34 +92,30 @@ def test_discovery_service_library_registrations_get(self): # In this case we'll make two requests. The first request will # ask for the root catalog, where we'll look for a # registration link. - root_catalog = dict( - links=[dict(href="http://register-here/", rel="register")] - ) + root_catalog = dict(links=[dict(href="http://register-here/", rel="register")]) client.queue_requests_response( - 200, RemoteRegistry.OPDS_2_TYPE, - content=json.dumps(root_catalog) + 200, RemoteRegistry.OPDS_2_TYPE, content=json.dumps(root_catalog) ) # The second request will fetch that registration link -- then # we'll look for TOS data inside. registration_document = dict( links=[ + dict(rel="terms-of-service", type="text/html", href="http://tos/"), dict( - rel="terms-of-service", type="text/html", - href="http://tos/" + rel="terms-of-service", + type="text/html", + href="data:text/html;charset=utf-8;base64,PHA+SG93IGFib3V0IHRoYXQgVE9TPC9wPg==", ), - dict( - rel="terms-of-service", type="text/html", - href="data:text/html;charset=utf-8;base64,PHA+SG93IGFib3V0IHRoYXQgVE9TPC9wPg==" - ) ] ) client.queue_requests_response( - 200, RemoteRegistry.OPDS_2_TYPE, - content=json.dumps(registration_document) + 200, RemoteRegistry.OPDS_2_TYPE, content=json.dumps(registration_document) ) - controller = self.manager.admin_discovery_service_library_registrations_controller + controller = ( + self.manager.admin_discovery_service_library_registrations_controller + ) m = controller.process_discovery_service_library_registrations with self.request_context_with_admin("/", method="GET"): response = m(do_get=client.do_get) @@ -124,15 +129,14 @@ def test_discovery_service_library_registrations_get(self): # happened. The target of the first request is the URL to # the discovery service's main catalog. The second request # is to the "register" link found in that catalog. - assert (["http://service-url/", "http://register-here/"] == - client.requests) + assert ["http://service-url/", "http://register-here/"] == client.requests # The TOS link and TOS HTML snippet were recovered from # the registration document served in response to the # second HTTP request, and included in the dictionary. - assert "http://tos/" == service['terms_of_service_link'] - assert "

    How about that TOS

    " == service['terms_of_service_html'] - assert None == service['access_problem'] + assert "http://tos/" == service["terms_of_service_link"] + assert "

    How about that TOS

    " == service["terms_of_service_html"] + assert None == service["access_problem"] # The dictionary includes a 'libraries' object, a list of # dictionaries with information about the relationships @@ -141,17 +145,15 @@ def test_discovery_service_library_registrations_get(self): info1, info2 = service["libraries"] # Here's the library that successfully registered. - assert ( - info1 == - dict(short_name=succeeded.short_name, status="success", - stage="production")) + assert info1 == dict( + short_name=succeeded.short_name, status="success", stage="production" + ) # And here's the library that tried to register but # failed. - assert ( - info2 == - dict(short_name=failed.short_name, status="failure", - stage="testing")) + assert info2 == dict( + short_name=failed.short_name, status="failure", stage="testing" + ) # Note that `unregistered`, the library that never tried # to register with this discover service, is not included. @@ -162,7 +164,8 @@ def test_discovery_service_library_registrations_get(self): # there will be no second request. client.requests = [] client.queue_requests_response( - 502, content=REMOTE_INTEGRATION_FAILED, + 502, + content=REMOTE_INTEGRATION_FAILED, ) response = m(do_get=client.do_get) @@ -170,16 +173,15 @@ def test_discovery_service_library_registrations_get(self): # available. [service] = response["library_registrations"] assert discovery_service.id == service["id"] - assert 2 == len(service['libraries']) - assert None == service['terms_of_service_link'] - assert None == service['terms_of_service_html'] + assert 2 == len(service["libraries"]) + assert None == service["terms_of_service_link"] + assert None == service["terms_of_service_html"] # The problem detail document that prevented the TOS data # from showing up has been converted to a dictionary and # included in the dictionary of information for this # discovery service. - assert (REMOTE_INTEGRATION_FAILED.uri == - service['access_problem']['type']) + assert REMOTE_INTEGRATION_FAILED.uri == service["access_problem"]["type"] # When the user lacks the SYSTEM_ADMIN role, the # controller won't even start processing their GET @@ -193,30 +195,36 @@ def test_discovery_service_library_registrations_post(self): discovery_service_library_registrations. """ - controller = self.manager.admin_discovery_service_library_registrations_controller + controller = ( + self.manager.admin_discovery_service_library_registrations_controller + ) m = controller.process_discovery_service_library_registrations # Here, the user doesn't have permission to start the # registration process. self.admin.remove_role(AdminRole.SYSTEM_ADMIN) with self.request_context_with_admin("/", method="POST"): - pytest.raises(AdminNotAuthorized, m, - do_get=self.do_request, do_post=self.do_request) + pytest.raises( + AdminNotAuthorized, m, do_get=self.do_request, do_post=self.do_request + ) self.admin.add_role(AdminRole.SYSTEM_ADMIN) # The integration ID might not correspond to a valid # ExternalIntegration. with self.request_context_with_admin("/", method="POST"): - flask.request.form = MultiDict([ - ("integration_id", "1234"), - ]) + flask.request.form = MultiDict( + [ + ("integration_id", "1234"), + ] + ) response = m() assert MISSING_SERVICE == response # Create an ExternalIntegration to avoid that problem in future # tests. discovery_service, ignore = create( - self._db, ExternalIntegration, + self._db, + ExternalIntegration, protocol=ExternalIntegration.OPDS_REGISTRATION, goal=ExternalIntegration.DISCOVERY_GOAL, ) @@ -224,20 +232,24 @@ def test_discovery_service_library_registrations_post(self): # The library name might not correspond to a real library. with self.request_context_with_admin("/", method="POST"): - flask.request.form = MultiDict([ - ("integration_id", discovery_service.id), - ("library_short_name", "not-a-library"), - ]) + flask.request.form = MultiDict( + [ + ("integration_id", discovery_service.id), + ("library_short_name", "not-a-library"), + ] + ) response = m() assert NO_SUCH_LIBRARY == response # Take care of that problem. library = self._default_library - form = MultiDict([ - ("integration_id", discovery_service.id), - ("library_short_name", library.short_name), - ("registration_stage", Registration.TESTING_STAGE), - ]) + form = MultiDict( + [ + ("integration_id", discovery_service.id), + ("library_short_name", library.short_name), + ("registration_stage", Registration.TESTING_STAGE), + ] + ) # Registration.push might return a ProblemDetail for whatever # reason. @@ -245,8 +257,7 @@ class Mock(Registration): # We reproduce the signature, even though it's not # necessary for what we're testing, so that if the push() # signature changes this test will fail. - def push(self, stage, url_for, catalog_url=None, do_get=None, - do_post=None): + def push(self, stage, url_for, catalog_url=None, do_get=None, do_post=None): return REMOTE_INTEGRATION_FAILED with self.request_context_with_admin("/", method="POST"): @@ -259,7 +270,9 @@ class Mock(Registration): """When asked to push a registration, do nothing and say it worked. """ + called_with = None + def push(self, *args, **kwargs): Mock.called_with = (args, kwargs) return True @@ -276,8 +289,8 @@ def push(self, *args, **kwargs): assert (Registration.TESTING_STAGE, self.manager.url_for) == args # We would have made real HTTP requests. - assert HTTP.debuggable_post == kwargs.pop('do_post') - assert HTTP.debuggable_get == kwargs.pop('do_get') + assert HTTP.debuggable_post == kwargs.pop("do_post") + assert HTTP.debuggable_get == kwargs.pop("do_get") # No other keyword arguments were passed in. assert {} == kwargs diff --git a/tests/admin/controller/test_metadata_service_self_tests.py b/tests/admin/controller/test_metadata_service_self_tests.py index aad87fad28..c66b56aea7 100644 --- a/tests/admin/controller/test_metadata_service_self_tests.py +++ b/tests/admin/controller/test_metadata_service_self_tests.py @@ -1,31 +1,29 @@ - from flask_babel import lazy_gettext as _ -from core.selftest import ( - HasSelfTests, - SelfTestResult, -) -from .test_controller import SettingsControllerTest -from core.model import ( - create, - ExternalIntegration, -) -from core.opds_import import MetadataWranglerOPDSLookup from api.admin.problem_details import * from api.nyt import NYTBestSellerAPI +from core.model import ExternalIntegration, create +from core.opds_import import MetadataWranglerOPDSLookup +from core.selftest import HasSelfTests, SelfTestResult -class TestMetadataServiceSelfTests(SettingsControllerTest): +from .test_controller import SettingsControllerTest + +class TestMetadataServiceSelfTests(SettingsControllerTest): def test_metadata_service_self_tests_with_no_identifier(self): with self.request_context_with_admin("/"): - response = self.manager.admin_metadata_service_self_tests_controller.process_metadata_service_self_tests(None) + response = self.manager.admin_metadata_service_self_tests_controller.process_metadata_service_self_tests( + None + ) assert response.title == MISSING_IDENTIFIER.title assert response.detail == MISSING_IDENTIFIER.detail assert response.status_code == 400 def test_metadata_service_self_tests_with_no_metadata_service_found(self): with self.request_context_with_admin("/"): - response = self.manager.admin_metadata_service_self_tests_controller.process_metadata_service_self_tests(-1) + response = self.manager.admin_metadata_service_self_tests_controller.process_metadata_service_self_tests( + -1 + ) assert response == MISSING_SERVICE assert response.status_code == 404 @@ -33,23 +31,30 @@ def test_metadata_service_self_tests_test_get(self): old_prior_test_results = HasSelfTests.prior_test_results HasSelfTests.prior_test_results = self.mock_prior_test_results metadata_service, ignore = create( - self._db, ExternalIntegration, + self._db, + ExternalIntegration, protocol=ExternalIntegration.NYT, goal=ExternalIntegration.METADATA_GOAL, ) # Make sure that HasSelfTest.prior_test_results() was called and that # it is in the response's self tests object. with self.request_context_with_admin("/"): - response = self.manager.admin_metadata_service_self_tests_controller.process_metadata_service_self_tests(metadata_service.id) + response = self.manager.admin_metadata_service_self_tests_controller.process_metadata_service_self_tests( + metadata_service.id + ) response_metadata_service = response.get("self_test_results") assert response_metadata_service.get("id") == metadata_service.id assert response_metadata_service.get("name") == metadata_service.name - assert response_metadata_service.get("protocol").get("label") == NYTBestSellerAPI.NAME + assert ( + response_metadata_service.get("protocol").get("label") + == NYTBestSellerAPI.NAME + ) assert response_metadata_service.get("goal") == metadata_service.goal assert ( - response_metadata_service.get("self_test_results") == - HasSelfTests.prior_test_results()) + response_metadata_service.get("self_test_results") + == HasSelfTests.prior_test_results() + ) HasSelfTests.prior_test_results = old_prior_test_results def test_metadata_service_self_tests_post(self): @@ -57,11 +62,14 @@ def test_metadata_service_self_tests_post(self): HasSelfTests.run_self_tests = self.mock_run_self_tests metadata_service, ignore = create( - self._db, ExternalIntegration, + self._db, + ExternalIntegration, protocol=ExternalIntegration.NYT, - goal=ExternalIntegration.METADATA_GOAL + goal=ExternalIntegration.METADATA_GOAL, + ) + m = ( + self.manager.admin_metadata_service_self_tests_controller.self_tests_process_post ) - m = self.manager.admin_metadata_service_self_tests_controller.self_tests_process_post with self.request_context_with_admin("/", method="POST"): response = m(metadata_service.id) assert response._status == "200 OK" @@ -74,13 +82,7 @@ def test_metadata_service_self_tests_post(self): # (NYTBestSellerAPI.from_config) # * The database connection again (to be passed into # NYTBestSellerAPI.from_config). - assert ( - ( - self._db, - NYTBestSellerAPI.from_config, - self._db - ) == - positional) + assert (self._db, NYTBestSellerAPI.from_config, self._db) == positional # run_self_tests was not called with any keyword arguments. assert {} == keyword diff --git a/tests/admin/controller/test_metadata_services.py b/tests/admin/controller/test_metadata_services.py index 1812a38560..cc1411e23a 100644 --- a/tests/admin/controller/test_metadata_services.py +++ b/tests/admin/controller/test_metadata_services.py @@ -1,30 +1,28 @@ -import pytest +import json import flask -import json +import pytest from werkzeug.datastructures import MultiDict -from core.util.http import HTTP -from api.nyt import NYTBestSellerAPI + +from api.admin.controller.metadata_services import MetadataServicesController from api.admin.exceptions import * from api.admin.problem_details import INVALID_URL from api.novelist import NoveListAPI -from api.admin.controller.metadata_services import MetadataServicesController +from api.nyt import NYTBestSellerAPI +from core.model import AdminRole, ExternalIntegration, Library, create, get_one from core.opds_import import MetadataWranglerOPDSLookup -from core.model import ( - AdminRole, - create, - ExternalIntegration, - get_one, - Library, -) +from core.util.http import HTTP + from .test_controller import SettingsControllerTest + class TestMetadataServices(SettingsControllerTest): def create_service(self, name): return create( - self._db, ExternalIntegration, + self._db, + ExternalIntegration, protocol=ExternalIntegration.__dict__.get(name) or "fake", - goal=ExternalIntegration.METADATA_GOAL + goal=ExternalIntegration.METADATA_GOAL, )[0] def test_process_metadata_services_dispatches_by_request_method(self): @@ -47,10 +45,7 @@ def process_post(self): self._db.flush() with self.request_context_with_admin("/"): - pytest.raises( - AdminNotAuthorized, - controller.process_metadata_services - ) + pytest.raises(AdminNotAuthorized, controller.process_metadata_services) def test_process_get_with_no_services(self): with self.request_context_with_admin("/"): @@ -100,11 +95,15 @@ def test_process_get_with_self_tests(self): # But we just need to make sure that the response has a self_test_results attribute--for this test, # it doesn't matter what it is--so that's fine. assert ( - service.get("self_test_results").get("exception") == - "Exception getting self-test results for metadata service Test: Metadata Wrangler improperly configured.") + service.get("self_test_results").get("exception") + == "Exception getting self-test results for metadata service Test: Metadata Wrangler improperly configured." + ) def test_find_protocol_class(self): - [wrangler, nyt, novelist, fake] = [self.create_service(x) for x in ["METADATA_WRANGLER", "NYT", "NOVELIST", "FAKE"]] + [wrangler, nyt, novelist, fake] = [ + self.create_service(x) + for x in ["METADATA_WRANGLER", "NYT", "NOVELIST", "FAKE"] + ] m = self.manager.admin_metadata_services_controller.find_protocol_class assert m(wrangler)[0] == MetadataWranglerOPDSLookup @@ -115,10 +114,12 @@ def test_find_protocol_class(self): def test_metadata_services_post_errors(self): controller = self.manager.admin_metadata_services_controller with self.request_context_with_admin("/", method="POST"): - flask.request.form = MultiDict([ - ("name", "Name"), - ("protocol", "Unknown"), - ]) + flask.request.form = MultiDict( + [ + ("name", "Name"), + ("protocol", "Unknown"), + ] + ) response = controller.process_post() assert response == UNKNOWN_PROTOCOL @@ -128,18 +129,22 @@ def test_metadata_services_post_errors(self): assert response == INCOMPLETE_CONFIGURATION with self.request_context_with_admin("/", method="POST"): - flask.request.form = MultiDict([ - ("name", "Name"), - ]) + flask.request.form = MultiDict( + [ + ("name", "Name"), + ] + ) response = controller.process_post() assert response == NO_PROTOCOL_FOR_NEW_SERVICE with self.request_context_with_admin("/", method="POST"): - flask.request.form = MultiDict([ - ("name", "Name"), - ("id", "123"), - ("protocol", ExternalIntegration.NYT), - ]) + flask.request.form = MultiDict( + [ + ("name", "Name"), + ("id", "123"), + ("protocol", ExternalIntegration.NYT), + ] + ) response = controller.process_post() assert response == MISSING_SERVICE @@ -147,63 +152,75 @@ def test_metadata_services_post_errors(self): service.name = "name" with self.request_context_with_admin("/", method="POST"): - flask.request.form = MultiDict([ - ("name", service.name), - ("protocol", ExternalIntegration.NYT), - ]) + flask.request.form = MultiDict( + [ + ("name", service.name), + ("protocol", ExternalIntegration.NYT), + ] + ) response = controller.process_post() assert response == INTEGRATION_NAME_ALREADY_IN_USE with self.request_context_with_admin("/", method="POST"): - flask.request.form = MultiDict([ - ("name", "Name"), - ("id", service.id), - ("protocol", ExternalIntegration.NYT), - ]) + flask.request.form = MultiDict( + [ + ("name", "Name"), + ("id", service.id), + ("protocol", ExternalIntegration.NYT), + ] + ) response = controller.process_post() assert response == CANNOT_CHANGE_PROTOCOL with self.request_context_with_admin("/", method="POST"): - flask.request.form = MultiDict([ - ("id", service.id), - ("protocol", ExternalIntegration.NOVELIST), - ]) + flask.request.form = MultiDict( + [ + ("id", service.id), + ("protocol", ExternalIntegration.NOVELIST), + ] + ) response = controller.process_post() assert response.uri == INCOMPLETE_CONFIGURATION.uri with self.request_context_with_admin("/", method="POST"): - flask.request.form = MultiDict([ - ("name", "Name"), - ("id", service.id), - ("protocol", ExternalIntegration.NOVELIST), - (ExternalIntegration.USERNAME, "user"), - (ExternalIntegration.PASSWORD, "pass"), - ("libraries", json.dumps([{"short_name": "not-a-library"}])), - ]) + flask.request.form = MultiDict( + [ + ("name", "Name"), + ("id", service.id), + ("protocol", ExternalIntegration.NOVELIST), + (ExternalIntegration.USERNAME, "user"), + (ExternalIntegration.PASSWORD, "pass"), + ("libraries", json.dumps([{"short_name": "not-a-library"}])), + ] + ) response = controller.process_post() assert response.uri == NO_SUCH_LIBRARY.uri def test_metadata_services_post_create(self): controller = self.manager.admin_metadata_services_controller library, ignore = create( - self._db, Library, name="Library", short_name="L", + self._db, + Library, + name="Library", + short_name="L", ) with self.request_context_with_admin("/", method="POST"): - flask.request.form = MultiDict([ - ("name", "Name"), - ("protocol", ExternalIntegration.NOVELIST), - (ExternalIntegration.USERNAME, "user"), - (ExternalIntegration.PASSWORD, "pass"), - ("libraries", json.dumps([{"short_name": "L"}])), - ]) + flask.request.form = MultiDict( + [ + ("name", "Name"), + ("protocol", ExternalIntegration.NOVELIST), + (ExternalIntegration.USERNAME, "user"), + (ExternalIntegration.PASSWORD, "pass"), + ("libraries", json.dumps([{"short_name": "L"}])), + ] + ) response = controller.process_post() assert response.status_code == 201 # A new ExternalIntegration has been created based on the submitted # information. service = get_one( - self._db, ExternalIntegration, - goal=ExternalIntegration.METADATA_GOAL + self._db, ExternalIntegration, goal=ExternalIntegration.METADATA_GOAL ) assert service.id == int(response.response[0]) assert ExternalIntegration.NOVELIST == service.protocol @@ -213,10 +230,16 @@ def test_metadata_services_post_create(self): def test_metadata_services_post_edit(self): l1, ignore = create( - self._db, Library, name="Library 1", short_name="L1", + self._db, + Library, + name="Library 1", + short_name="L1", ) l2, ignore = create( - self._db, Library, name="Library 2", short_name="L2", + self._db, + Library, + name="Library 2", + short_name="L2", ) novelist_service = self.create_service("NOVELIST") novelist_service.username = "olduser" @@ -225,14 +248,16 @@ def test_metadata_services_post_edit(self): controller = self.manager.admin_metadata_services_controller with self.request_context_with_admin("/", method="POST"): - flask.request.form = MultiDict([ - ("name", "Name"), - ("id", novelist_service.id), - ("protocol", ExternalIntegration.NOVELIST), - (ExternalIntegration.USERNAME, "user"), - (ExternalIntegration.PASSWORD, "pass"), - ("libraries", json.dumps([{"short_name": "L2"}])), - ]) + flask.request.form = MultiDict( + [ + ("name", "Name"), + ("id", novelist_service.id), + ("protocol", ExternalIntegration.NOVELIST), + (ExternalIntegration.USERNAME, "user"), + (ExternalIntegration.PASSWORD, "pass"), + ("libraries", json.dumps([{"short_name": "L2"}])), + ] + ) response = controller.process_post() assert response.status_code == 200 @@ -240,18 +265,21 @@ def test_metadata_services_post_calls_register_with_metadata_wrangler(self): """Verify that process_post() calls register_with_metadata_wrangler if the rest of the request is handled successfully. """ + class Mock(MetadataServicesController): RETURN_VALUE = INVALID_URL called_with = None - def register_with_metadata_wrangler( - self, do_get, do_post, is_new, service - ): + + def register_with_metadata_wrangler(self, do_get, do_post, is_new, service): self.called_with = (do_get, do_post, is_new, service) return self.RETURN_VALUE controller = Mock(self.manager) library, ignore = create( - self._db, Library, name="Library", short_name="L", + self._db, + Library, + name="Library", + short_name="L", ) do_get = object() do_post = object() @@ -263,12 +291,14 @@ def register_with_metadata_wrangler( # register_with_metadata_wrangler was not called. assert None == controller.called_with - form = MultiDict([ - ("name", "Name"), - ("protocol", ExternalIntegration.NOVELIST), - (ExternalIntegration.USERNAME, "user"), - (ExternalIntegration.PASSWORD, "pass"), - ]) + form = MultiDict( + [ + ("name", "Name"), + ("protocol", ExternalIntegration.NOVELIST), + (ExternalIntegration.USERNAME, "user"), + (ExternalIntegration.PASSWORD, "pass"), + ] + ) with self.request_context_with_admin("/", method="POST"): flask.request.form = form @@ -281,8 +311,9 @@ def register_with_metadata_wrangler( assert INVALID_URL == response # We ended up not creating an ExternalIntegration. - assert None == get_one(self._db, ExternalIntegration, - goal=ExternalIntegration.METADATA_GOAL) + assert None == get_one( + self._db, ExternalIntegration, goal=ExternalIntegration.METADATA_GOAL + ) # But the ExternalIntegration we _would_ have created was # passed in to register_with_metadata_wrangler. @@ -303,8 +334,7 @@ def register_with_metadata_wrangler( # This time we successfully created an ExternalIntegration. integration = get_one( - self._db, ExternalIntegration, - goal=ExternalIntegration.METADATA_GOAL + self._db, ExternalIntegration, goal=ExternalIntegration.METADATA_GOAL ) assert integration != None @@ -318,11 +348,11 @@ def test_register_with_metadata_wrangler(self): """Verify that register_with_metadata wrangler calls process_sitewide_registration appropriately. """ + class Mock(MetadataServicesController): called_with = None - def process_sitewide_registration( - self, integration, do_get, do_post - ): + + def process_sitewide_registration(self, integration, do_get, do_post): self.called_with = (integration, do_get, do_post) controller = Mock(self.manager) @@ -332,9 +362,7 @@ def process_sitewide_registration( # If register_with_metadata_wrangler is called on an ExternalIntegration # with some other service, nothing happens. - integration = self._external_integration( - protocol=ExternalIntegration.NOVELIST - ) + integration = self._external_integration(protocol=ExternalIntegration.NOVELIST) m(do_get, do_post, True, integration) assert None == controller.called_with @@ -343,7 +371,7 @@ def process_sitewide_registration( integration = self._external_integration( protocol=ExternalIntegration.METADATA_WRANGLER ) - integration.password = 'already done' + integration.password = "already done" m(do_get, do_post, False, integration) assert None == controller.called_with @@ -359,32 +387,36 @@ def process_sitewide_registration( result = m(do_get, do_post, False, integration) def test_check_name_unique(self): - kwargs = dict(protocol=ExternalIntegration.NYT, - goal=ExternalIntegration.METADATA_GOAL) + kwargs = dict( + protocol=ExternalIntegration.NYT, goal=ExternalIntegration.METADATA_GOAL + ) - existing_service, ignore = create(self._db, ExternalIntegration, name="existing service", **kwargs) - new_service, ignore = create(self._db, ExternalIntegration, name="new service", **kwargs) + existing_service, ignore = create( + self._db, ExternalIntegration, name="existing service", **kwargs + ) + new_service, ignore = create( + self._db, ExternalIntegration, name="new service", **kwargs + ) - m = self.manager.admin_metadata_services_controller.check_name_unique + m = self.manager.admin_metadata_services_controller.check_name_unique - # Try to change new service so that it has the same name as existing service - # -- this is not allowed. - result = m(new_service, existing_service.name) - assert result == INTEGRATION_NAME_ALREADY_IN_USE + # Try to change new service so that it has the same name as existing service + # -- this is not allowed. + result = m(new_service, existing_service.name) + assert result == INTEGRATION_NAME_ALREADY_IN_USE - # Try to edit existing service without changing its name -- this is fine. - assert ( - None == - m(existing_service, existing_service.name)) + # Try to edit existing service without changing its name -- this is fine. + assert None == m(existing_service, existing_service.name) - # Changing the existing service's name is also fine. - assert ( - None == - m(existing_service, "new name")) + # Changing the existing service's name is also fine. + assert None == m(existing_service, "new name") def test_metadata_service_delete(self): l1, ignore = create( - self._db, Library, name="Library 1", short_name="L1", + self._db, + Library, + name="Library 1", + short_name="L1", ) novelist_service = self.create_service("NOVELIST") novelist_service.username = "olduser" @@ -393,12 +425,16 @@ def test_metadata_service_delete(self): with self.request_context_with_admin("/", method="DELETE"): self.admin.remove_role(AdminRole.SYSTEM_ADMIN) - pytest.raises(AdminNotAuthorized, - self.manager.admin_metadata_services_controller.process_delete, - novelist_service.id) + pytest.raises( + AdminNotAuthorized, + self.manager.admin_metadata_services_controller.process_delete, + novelist_service.id, + ) self.admin.add_role(AdminRole.SYSTEM_ADMIN) - response = self.manager.admin_metadata_services_controller.process_delete(novelist_service.id) + response = self.manager.admin_metadata_services_controller.process_delete( + novelist_service.id + ) assert response.status_code == 200 service = get_one(self._db, ExternalIntegration, id=novelist_service.id) diff --git a/tests/admin/controller/test_patron_auth.py b/tests/admin/controller/test_patron_auth.py index 093c4b3510..b3d818e83e 100644 --- a/tests/admin/controller/test_patron_auth.py +++ b/tests/admin/controller/test_patron_auth.py @@ -1,15 +1,13 @@ -import pytest +import json import flask +import pytest from flask_babel import lazy_gettext as _ -import json from werkzeug.datastructures import MultiDict + from api.admin.controller.patron_auth_services import PatronAuthServicesController from api.admin.exceptions import * -from api.authenticator import ( - AuthenticationProvider, - BasicAuthenticationProvider, -) +from api.authenticator import AuthenticationProvider, BasicAuthenticationProvider from api.clever import CleverAuthenticationAPI from api.firstbook import FirstBookAuthenticationAPI from api.millenium_patron import MilleniumPatronAPI @@ -19,19 +17,22 @@ from api.sip import SIP2AuthenticationProvider from core.model import ( AdminRole, - create, ConfigurationSetting, ExternalIntegration, - get_one, Library, + create, + get_one, ) + from .test_controller import SettingsControllerTest -class TestPatronAuth(SettingsControllerTest): +class TestPatronAuth(SettingsControllerTest): def test_patron_auth_services_get_with_no_services(self): with self.request_context_with_admin("/"): - response = self.manager.admin_patron_auth_services_controller.process_patron_auth_services() + response = ( + self.manager.admin_patron_auth_services_controller.process_patron_auth_services() + ) assert response.get("patron_auth_services") == [] protocols = response.get("protocols") assert 8 == len(protocols) @@ -43,12 +44,13 @@ def test_patron_auth_services_get_with_no_services(self): self._db.flush() pytest.raises( AdminNotAuthorized, - self.manager.admin_patron_auth_services_controller.process_patron_auth_services + self.manager.admin_patron_auth_services_controller.process_patron_auth_services, ) def test_patron_auth_services_get_with_simple_auth_service(self): auth_service, ignore = create( - self._db, ExternalIntegration, + self._db, + ExternalIntegration, protocol=SimpleAuthenticationProvider.__module__, goal=ExternalIntegration.PATRON_AUTH_GOAL, name="name", @@ -57,74 +59,111 @@ def test_patron_auth_services_get_with_simple_auth_service(self): auth_service.setting(BasicAuthenticationProvider.TEST_PASSWORD).value = "pass" with self.request_context_with_admin("/"): - response = self.manager.admin_patron_auth_services_controller.process_patron_auth_services() + response = ( + self.manager.admin_patron_auth_services_controller.process_patron_auth_services() + ) [service] = response.get("patron_auth_services") assert auth_service.id == service.get("id") assert auth_service.name == service.get("name") assert SimpleAuthenticationProvider.__module__ == service.get("protocol") - assert "user" == service.get("settings").get(BasicAuthenticationProvider.TEST_IDENTIFIER) - assert "pass" == service.get("settings").get(BasicAuthenticationProvider.TEST_PASSWORD) + assert "user" == service.get("settings").get( + BasicAuthenticationProvider.TEST_IDENTIFIER + ) + assert "pass" == service.get("settings").get( + BasicAuthenticationProvider.TEST_PASSWORD + ) assert [] == service.get("libraries") auth_service.libraries += [self._default_library] with self.request_context_with_admin("/"): - response = self.manager.admin_patron_auth_services_controller.process_patron_auth_services() + response = ( + self.manager.admin_patron_auth_services_controller.process_patron_auth_services() + ) [service] = response.get("patron_auth_services") - assert "user" == service.get("settings").get(BasicAuthenticationProvider.TEST_IDENTIFIER) + assert "user" == service.get("settings").get( + BasicAuthenticationProvider.TEST_IDENTIFIER + ) [library] = service.get("libraries") assert self._default_library.short_name == library.get("short_name") - assert None == library.get(AuthenticationProvider.EXTERNAL_TYPE_REGULAR_EXPRESSION) + assert None == library.get( + AuthenticationProvider.EXTERNAL_TYPE_REGULAR_EXPRESSION + ) ConfigurationSetting.for_library_and_externalintegration( - self._db, AuthenticationProvider.EXTERNAL_TYPE_REGULAR_EXPRESSION, - self._default_library, auth_service, + self._db, + AuthenticationProvider.EXTERNAL_TYPE_REGULAR_EXPRESSION, + self._default_library, + auth_service, ).value = "^(u)" with self.request_context_with_admin("/"): - response = self.manager.admin_patron_auth_services_controller.process_patron_auth_services() + response = ( + self.manager.admin_patron_auth_services_controller.process_patron_auth_services() + ) [service] = response.get("patron_auth_services") [library] = service.get("libraries") assert self._default_library.short_name == library.get("short_name") - assert "^(u)" == library.get(AuthenticationProvider.EXTERNAL_TYPE_REGULAR_EXPRESSION) + assert "^(u)" == library.get( + AuthenticationProvider.EXTERNAL_TYPE_REGULAR_EXPRESSION + ) def test_patron_auth_services_get_with_millenium_auth_service(self): auth_service, ignore = create( - self._db, ExternalIntegration, + self._db, + ExternalIntegration, protocol=MilleniumPatronAPI.__module__, - goal=ExternalIntegration.PATRON_AUTH_GOAL + goal=ExternalIntegration.PATRON_AUTH_GOAL, ) auth_service.setting(BasicAuthenticationProvider.TEST_IDENTIFIER).value = "user" auth_service.setting(BasicAuthenticationProvider.TEST_PASSWORD).value = "pass" - auth_service.setting(BasicAuthenticationProvider.IDENTIFIER_REGULAR_EXPRESSION).value = "u*" - auth_service.setting(BasicAuthenticationProvider.PASSWORD_REGULAR_EXPRESSION).value = "p*" + auth_service.setting( + BasicAuthenticationProvider.IDENTIFIER_REGULAR_EXPRESSION + ).value = "u*" + auth_service.setting( + BasicAuthenticationProvider.PASSWORD_REGULAR_EXPRESSION + ).value = "p*" auth_service.libraries += [self._default_library] ConfigurationSetting.for_library_and_externalintegration( - self._db, AuthenticationProvider.EXTERNAL_TYPE_REGULAR_EXPRESSION, - self._default_library, auth_service, + self._db, + AuthenticationProvider.EXTERNAL_TYPE_REGULAR_EXPRESSION, + self._default_library, + auth_service, ).value = "^(u)" with self.request_context_with_admin("/"): - response = self.manager.admin_patron_auth_services_controller.process_patron_auth_services() + response = ( + self.manager.admin_patron_auth_services_controller.process_patron_auth_services() + ) [service] = response.get("patron_auth_services") assert auth_service.id == service.get("id") assert MilleniumPatronAPI.__module__ == service.get("protocol") - assert "user" == service.get("settings").get(BasicAuthenticationProvider.TEST_IDENTIFIER) - assert "pass" == service.get("settings").get(BasicAuthenticationProvider.TEST_PASSWORD) - assert "u*" == service.get("settings").get(BasicAuthenticationProvider.IDENTIFIER_REGULAR_EXPRESSION) - assert "p*" == service.get("settings").get(BasicAuthenticationProvider.PASSWORD_REGULAR_EXPRESSION) + assert "user" == service.get("settings").get( + BasicAuthenticationProvider.TEST_IDENTIFIER + ) + assert "pass" == service.get("settings").get( + BasicAuthenticationProvider.TEST_PASSWORD + ) + assert "u*" == service.get("settings").get( + BasicAuthenticationProvider.IDENTIFIER_REGULAR_EXPRESSION + ) + assert "p*" == service.get("settings").get( + BasicAuthenticationProvider.PASSWORD_REGULAR_EXPRESSION + ) [library] = service.get("libraries") assert self._default_library.short_name == library.get("short_name") - assert "^(u)" == library.get(AuthenticationProvider.EXTERNAL_TYPE_REGULAR_EXPRESSION) - + assert "^(u)" == library.get( + AuthenticationProvider.EXTERNAL_TYPE_REGULAR_EXPRESSION + ) def test_patron_auth_services_get_with_sip2_auth_service(self): auth_service, ignore = create( - self._db, ExternalIntegration, + self._db, + ExternalIntegration, protocol=SIP2AuthenticationProvider.__module__, - goal=ExternalIntegration.PATRON_AUTH_GOAL + goal=ExternalIntegration.PATRON_AUTH_GOAL, ) auth_service.url = "url" auth_service.setting(SIP2AuthenticationProvider.PORT).value = "1234" @@ -135,42 +174,59 @@ def test_patron_auth_services_get_with_sip2_auth_service(self): auth_service.libraries += [self._default_library] ConfigurationSetting.for_library_and_externalintegration( - self._db, AuthenticationProvider.EXTERNAL_TYPE_REGULAR_EXPRESSION, - self._default_library, auth_service, + self._db, + AuthenticationProvider.EXTERNAL_TYPE_REGULAR_EXPRESSION, + self._default_library, + auth_service, ).value = "^(u)" with self.request_context_with_admin("/"): - response = self.manager.admin_patron_auth_services_controller.process_patron_auth_services() + response = ( + self.manager.admin_patron_auth_services_controller.process_patron_auth_services() + ) [service] = response.get("patron_auth_services") assert auth_service.id == service.get("id") assert SIP2AuthenticationProvider.__module__ == service.get("protocol") assert "url" == service.get("settings").get(ExternalIntegration.URL) - assert "1234" == service.get("settings").get(SIP2AuthenticationProvider.PORT) + assert "1234" == service.get("settings").get( + SIP2AuthenticationProvider.PORT + ) assert "user" == service.get("settings").get(ExternalIntegration.USERNAME) assert "pass" == service.get("settings").get(ExternalIntegration.PASSWORD) - assert "5" == service.get("settings").get(SIP2AuthenticationProvider.LOCATION_CODE) - assert "," == service.get("settings").get(SIP2AuthenticationProvider.FIELD_SEPARATOR) + assert "5" == service.get("settings").get( + SIP2AuthenticationProvider.LOCATION_CODE + ) + assert "," == service.get("settings").get( + SIP2AuthenticationProvider.FIELD_SEPARATOR + ) [library] = service.get("libraries") assert self._default_library.short_name == library.get("short_name") - assert "^(u)" == library.get(AuthenticationProvider.EXTERNAL_TYPE_REGULAR_EXPRESSION) + assert "^(u)" == library.get( + AuthenticationProvider.EXTERNAL_TYPE_REGULAR_EXPRESSION + ) def test_patron_auth_services_get_with_firstbook_auth_service(self): auth_service, ignore = create( - self._db, ExternalIntegration, + self._db, + ExternalIntegration, protocol=FirstBookAuthenticationAPI.__module__, - goal=ExternalIntegration.PATRON_AUTH_GOAL + goal=ExternalIntegration.PATRON_AUTH_GOAL, ) auth_service.url = "url" auth_service.password = "pass" auth_service.libraries += [self._default_library] ConfigurationSetting.for_library_and_externalintegration( - self._db, AuthenticationProvider.EXTERNAL_TYPE_REGULAR_EXPRESSION, - self._default_library, auth_service, + self._db, + AuthenticationProvider.EXTERNAL_TYPE_REGULAR_EXPRESSION, + self._default_library, + auth_service, ).value = "^(u)" with self.request_context_with_admin("/"): - response = self.manager.admin_patron_auth_services_controller.process_patron_auth_services() + response = ( + self.manager.admin_patron_auth_services_controller.process_patron_auth_services() + ) [service] = response.get("patron_auth_services") assert auth_service.id == service.get("id") @@ -179,20 +235,25 @@ def test_patron_auth_services_get_with_firstbook_auth_service(self): assert "pass" == service.get("settings").get(ExternalIntegration.PASSWORD) [library] = service.get("libraries") assert self._default_library.short_name == library.get("short_name") - assert "^(u)" == library.get(AuthenticationProvider.EXTERNAL_TYPE_REGULAR_EXPRESSION) + assert "^(u)" == library.get( + AuthenticationProvider.EXTERNAL_TYPE_REGULAR_EXPRESSION + ) def test_patron_auth_services_get_with_clever_auth_service(self): auth_service, ignore = create( - self._db, ExternalIntegration, + self._db, + ExternalIntegration, protocol=CleverAuthenticationAPI.__module__, - goal=ExternalIntegration.PATRON_AUTH_GOAL + goal=ExternalIntegration.PATRON_AUTH_GOAL, ) auth_service.username = "user" auth_service.password = "pass" auth_service.libraries += [self._default_library] with self.request_context_with_admin("/"): - response = self.manager.admin_patron_auth_services_controller.process_patron_auth_services() + response = ( + self.manager.admin_patron_auth_services_controller.process_patron_auth_services() + ) [service] = response.get("patron_auth_services") assert auth_service.id == service.get("id") @@ -204,18 +265,23 @@ def test_patron_auth_services_get_with_clever_auth_service(self): def test_patron_auth_services_get_with_saml_auth_service(self): auth_service, ignore = create( - self._db, ExternalIntegration, + self._db, + ExternalIntegration, protocol=SAMLWebSSOAuthenticationProvider.__module__, - goal=ExternalIntegration.PATRON_AUTH_GOAL + goal=ExternalIntegration.PATRON_AUTH_GOAL, ) auth_service.libraries += [self._default_library] with self.request_context_with_admin("/"): - response = self.manager.admin_patron_auth_services_controller.process_patron_auth_services() + response = ( + self.manager.admin_patron_auth_services_controller.process_patron_auth_services() + ) [service] = response.get("patron_auth_services") assert auth_service.id == service.get("id") - assert SAMLWebSSOAuthenticationProvider.__module__ == service.get("protocol") + assert SAMLWebSSOAuthenticationProvider.__module__ == service.get( + "protocol" + ) [library] = service.get("libraries") assert self._default_library.short_name == library.get("short_name") @@ -234,141 +300,220 @@ def _common_basic_auth_arguments(self): def test_patron_auth_services_post_errors(self): with self.request_context_with_admin("/", method="POST"): - flask.request.form = MultiDict([ - ("protocol", "Unknown"), - ]) - response = self.manager.admin_patron_auth_services_controller.process_patron_auth_services() + flask.request.form = MultiDict( + [ + ("protocol", "Unknown"), + ] + ) + response = ( + self.manager.admin_patron_auth_services_controller.process_patron_auth_services() + ) assert response == UNKNOWN_PROTOCOL with self.request_context_with_admin("/", method="POST"): flask.request.form = MultiDict([]) - response = self.manager.admin_patron_auth_services_controller.process_patron_auth_services() + response = ( + self.manager.admin_patron_auth_services_controller.process_patron_auth_services() + ) assert response == NO_PROTOCOL_FOR_NEW_SERVICE with self.request_context_with_admin("/", method="POST"): - flask.request.form = MultiDict([ - ("id", "123"), - ]) - response = self.manager.admin_patron_auth_services_controller.process_patron_auth_services() + flask.request.form = MultiDict( + [ + ("id", "123"), + ] + ) + response = ( + self.manager.admin_patron_auth_services_controller.process_patron_auth_services() + ) assert response == MISSING_SERVICE auth_service, ignore = create( - self._db, ExternalIntegration, + self._db, + ExternalIntegration, protocol=SimpleAuthenticationProvider.__module__, goal=ExternalIntegration.PATRON_AUTH_GOAL, name="name", ) with self.request_context_with_admin("/", method="POST"): - flask.request.form = MultiDict([ - ("id", auth_service.id), - ("protocol", SIP2AuthenticationProvider.__module__), - ]) - response = self.manager.admin_patron_auth_services_controller.process_patron_auth_services() + flask.request.form = MultiDict( + [ + ("id", auth_service.id), + ("protocol", SIP2AuthenticationProvider.__module__), + ] + ) + response = ( + self.manager.admin_patron_auth_services_controller.process_patron_auth_services() + ) assert response == CANNOT_CHANGE_PROTOCOL with self.request_context_with_admin("/", method="POST"): - flask.request.form = MultiDict([ - ("name", auth_service.name), - ("protocol", SIP2AuthenticationProvider.__module__), - ]) - response = self.manager.admin_patron_auth_services_controller.process_patron_auth_services() + flask.request.form = MultiDict( + [ + ("name", auth_service.name), + ("protocol", SIP2AuthenticationProvider.__module__), + ] + ) + response = ( + self.manager.admin_patron_auth_services_controller.process_patron_auth_services() + ) assert response == INTEGRATION_NAME_ALREADY_IN_USE auth_service, ignore = create( - self._db, ExternalIntegration, + self._db, + ExternalIntegration, protocol=MilleniumPatronAPI.__module__, - goal=ExternalIntegration.PATRON_AUTH_GOAL + goal=ExternalIntegration.PATRON_AUTH_GOAL, ) common_args = self._common_basic_auth_arguments() with self.request_context_with_admin("/", method="POST"): M = MilleniumPatronAPI - flask.request.form = MultiDict([ - ("name", "some auth name"), - ("id", auth_service.id), - ("protocol", MilleniumPatronAPI.__module__), - (ExternalIntegration.URL, "http://url"), - (M.AUTHENTICATION_MODE, "Invalid mode"), - (M.VERIFY_CERTIFICATE, "true"), - ] + common_args) - response = self.manager.admin_patron_auth_services_controller.process_patron_auth_services() + flask.request.form = MultiDict( + [ + ("name", "some auth name"), + ("id", auth_service.id), + ("protocol", MilleniumPatronAPI.__module__), + (ExternalIntegration.URL, "http://url"), + (M.AUTHENTICATION_MODE, "Invalid mode"), + (M.VERIFY_CERTIFICATE, "true"), + ] + + common_args + ) + response = ( + self.manager.admin_patron_auth_services_controller.process_patron_auth_services() + ) assert response.uri == INVALID_CONFIGURATION_OPTION.uri auth_service, ignore = create( - self._db, ExternalIntegration, + self._db, + ExternalIntegration, protocol=SimpleAuthenticationProvider.__module__, - goal=ExternalIntegration.PATRON_AUTH_GOAL + goal=ExternalIntegration.PATRON_AUTH_GOAL, ) with self.request_context_with_admin("/", method="POST"): - flask.request.form = MultiDict([ - ("id", auth_service.id), - ("protocol", SimpleAuthenticationProvider.__module__), - ]) - response = self.manager.admin_patron_auth_services_controller.process_patron_auth_services() + flask.request.form = MultiDict( + [ + ("id", auth_service.id), + ("protocol", SimpleAuthenticationProvider.__module__), + ] + ) + response = ( + self.manager.admin_patron_auth_services_controller.process_patron_auth_services() + ) assert response.uri == INCOMPLETE_CONFIGURATION.uri with self.request_context_with_admin("/", method="POST"): - flask.request.form = MultiDict([ - ("protocol", SimpleAuthenticationProvider.__module__), - ("libraries", json.dumps([{ "short_name": "not-a-library" }])), - ] + common_args) - response = self.manager.admin_patron_auth_services_controller.process_patron_auth_services() + flask.request.form = MultiDict( + [ + ("protocol", SimpleAuthenticationProvider.__module__), + ("libraries", json.dumps([{"short_name": "not-a-library"}])), + ] + + common_args + ) + response = ( + self.manager.admin_patron_auth_services_controller.process_patron_auth_services() + ) assert response.uri == NO_SUCH_LIBRARY.uri library, ignore = create( - self._db, Library, name="Library", short_name="L", + self._db, + Library, + name="Library", + short_name="L", ) auth_service.libraries += [library] with self.request_context_with_admin("/", method="POST"): - flask.request.form = MultiDict([ - ("protocol", SimpleAuthenticationProvider.__module__), - ("libraries", json.dumps([{ - "short_name": library.short_name, - AuthenticationProvider.LIBRARY_IDENTIFIER_RESTRICTION_TYPE: AuthenticationProvider.LIBRARY_IDENTIFIER_RESTRICTION_TYPE_NONE, - AuthenticationProvider.LIBRARY_IDENTIFIER_FIELD: AuthenticationProvider.LIBRARY_IDENTIFIER_RESTRICTION_BARCODE, - }])), - ] + common_args) - response = self.manager.admin_patron_auth_services_controller.process_patron_auth_services() + flask.request.form = MultiDict( + [ + ("protocol", SimpleAuthenticationProvider.__module__), + ( + "libraries", + json.dumps( + [ + { + "short_name": library.short_name, + AuthenticationProvider.LIBRARY_IDENTIFIER_RESTRICTION_TYPE: AuthenticationProvider.LIBRARY_IDENTIFIER_RESTRICTION_TYPE_NONE, + AuthenticationProvider.LIBRARY_IDENTIFIER_FIELD: AuthenticationProvider.LIBRARY_IDENTIFIER_RESTRICTION_BARCODE, + } + ] + ), + ), + ] + + common_args + ) + response = ( + self.manager.admin_patron_auth_services_controller.process_patron_auth_services() + ) assert response.uri == MULTIPLE_BASIC_AUTH_SERVICES.uri self._db.delete(auth_service) with self.request_context_with_admin("/", method="POST"): - flask.request.form = MultiDict([ - ("protocol", SimpleAuthenticationProvider.__module__), - ("libraries", json.dumps([{ - "short_name": library.short_name, - AuthenticationProvider.LIBRARY_IDENTIFIER_RESTRICTION_TYPE: AuthenticationProvider.LIBRARY_IDENTIFIER_RESTRICTION_TYPE_NONE, - AuthenticationProvider.LIBRARY_IDENTIFIER_FIELD: AuthenticationProvider.LIBRARY_IDENTIFIER_RESTRICTION_BARCODE, - AuthenticationProvider.EXTERNAL_TYPE_REGULAR_EXPRESSION: "(invalid re", - }])), - ] + common_args) - response = self.manager.admin_patron_auth_services_controller.process_patron_auth_services() + flask.request.form = MultiDict( + [ + ("protocol", SimpleAuthenticationProvider.__module__), + ( + "libraries", + json.dumps( + [ + { + "short_name": library.short_name, + AuthenticationProvider.LIBRARY_IDENTIFIER_RESTRICTION_TYPE: AuthenticationProvider.LIBRARY_IDENTIFIER_RESTRICTION_TYPE_NONE, + AuthenticationProvider.LIBRARY_IDENTIFIER_FIELD: AuthenticationProvider.LIBRARY_IDENTIFIER_RESTRICTION_BARCODE, + AuthenticationProvider.EXTERNAL_TYPE_REGULAR_EXPRESSION: "(invalid re", + } + ] + ), + ), + ] + + common_args + ) + response = ( + self.manager.admin_patron_auth_services_controller.process_patron_auth_services() + ) assert response == INVALID_EXTERNAL_TYPE_REGULAR_EXPRESSION with self.request_context_with_admin("/", method="POST"): - flask.request.form = MultiDict([ - ("protocol", SimpleAuthenticationProvider.__module__), - ("libraries", json.dumps([{ - "short_name": library.short_name, - AuthenticationProvider.LIBRARY_IDENTIFIER_RESTRICTION_TYPE: AuthenticationProvider.LIBRARY_IDENTIFIER_RESTRICTION_TYPE_REGEX, - AuthenticationProvider.LIBRARY_IDENTIFIER_FIELD: AuthenticationProvider.LIBRARY_IDENTIFIER_RESTRICTION_BARCODE, - AuthenticationProvider.LIBRARY_IDENTIFIER_RESTRICTION: "(invalid re", - }])), - ] + common_args) - response = self.manager.admin_patron_auth_services_controller.process_patron_auth_services() + flask.request.form = MultiDict( + [ + ("protocol", SimpleAuthenticationProvider.__module__), + ( + "libraries", + json.dumps( + [ + { + "short_name": library.short_name, + AuthenticationProvider.LIBRARY_IDENTIFIER_RESTRICTION_TYPE: AuthenticationProvider.LIBRARY_IDENTIFIER_RESTRICTION_TYPE_REGEX, + AuthenticationProvider.LIBRARY_IDENTIFIER_FIELD: AuthenticationProvider.LIBRARY_IDENTIFIER_RESTRICTION_BARCODE, + AuthenticationProvider.LIBRARY_IDENTIFIER_RESTRICTION: "(invalid re", + } + ] + ), + ), + ] + + common_args + ) + response = ( + self.manager.admin_patron_auth_services_controller.process_patron_auth_services() + ) assert response == INVALID_LIBRARY_IDENTIFIER_RESTRICTION_REGULAR_EXPRESSION self.admin.remove_role(AdminRole.SYSTEM_ADMIN) self._db.flush() with self.request_context_with_admin("/", method="POST"): - flask.request.form = MultiDict([ - ("protocol", SimpleAuthenticationProvider.__module__), - ] + self._common_basic_auth_arguments()) - pytest.raises(AdminNotAuthorized, - self.manager.admin_patron_auth_services_controller.process_patron_auth_services) + flask.request.form = MultiDict( + [ + ("protocol", SimpleAuthenticationProvider.__module__), + ] + + self._common_basic_auth_arguments() + ) + pytest.raises( + AdminNotAuthorized, + self.manager.admin_patron_auth_services_controller.process_patron_auth_services, + ) def _get_mock(self): manager = self.manager @@ -389,58 +534,103 @@ def test_patron_auth_services_post_create(self): mock_controller = self._get_mock() library, ignore = create( - self._db, Library, name="Library", short_name="L", + self._db, + Library, + name="Library", + short_name="L", ) with self.request_context_with_admin("/", method="POST"): - flask.request.form = MultiDict([ - ("protocol", SimpleAuthenticationProvider.__module__), - ("libraries", json.dumps([{ - "short_name": library.short_name, - AuthenticationProvider.EXTERNAL_TYPE_REGULAR_EXPRESSION: "^(.)", - AuthenticationProvider.LIBRARY_IDENTIFIER_RESTRICTION_TYPE: AuthenticationProvider.LIBRARY_IDENTIFIER_RESTRICTION_TYPE_REGEX, - AuthenticationProvider.LIBRARY_IDENTIFIER_FIELD: AuthenticationProvider.LIBRARY_IDENTIFIER_RESTRICTION_BARCODE, - AuthenticationProvider.LIBRARY_IDENTIFIER_RESTRICTION: "^1234", - }])), - ] + self._common_basic_auth_arguments()) + flask.request.form = MultiDict( + [ + ("protocol", SimpleAuthenticationProvider.__module__), + ( + "libraries", + json.dumps( + [ + { + "short_name": library.short_name, + AuthenticationProvider.EXTERNAL_TYPE_REGULAR_EXPRESSION: "^(.)", + AuthenticationProvider.LIBRARY_IDENTIFIER_RESTRICTION_TYPE: AuthenticationProvider.LIBRARY_IDENTIFIER_RESTRICTION_TYPE_REGEX, + AuthenticationProvider.LIBRARY_IDENTIFIER_FIELD: AuthenticationProvider.LIBRARY_IDENTIFIER_RESTRICTION_BARCODE, + AuthenticationProvider.LIBRARY_IDENTIFIER_RESTRICTION: "^1234", + } + ] + ), + ), + ] + + self._common_basic_auth_arguments() + ) response = mock_controller.process_patron_auth_services() assert response.status_code == 201 assert mock_controller.validate_formats_call_count == 1 - auth_service = get_one(self._db, ExternalIntegration, goal=ExternalIntegration.PATRON_AUTH_GOAL) + auth_service = get_one( + self._db, ExternalIntegration, goal=ExternalIntegration.PATRON_AUTH_GOAL + ) assert auth_service.id == int(response.response[0]) assert SimpleAuthenticationProvider.__module__ == auth_service.protocol - assert "user" == auth_service.setting(BasicAuthenticationProvider.TEST_IDENTIFIER).value - assert "pass" == auth_service.setting(BasicAuthenticationProvider.TEST_PASSWORD).value + assert ( + "user" + == auth_service.setting(BasicAuthenticationProvider.TEST_IDENTIFIER).value + ) + assert ( + "pass" + == auth_service.setting(BasicAuthenticationProvider.TEST_PASSWORD).value + ) assert [library] == auth_service.libraries - assert "^(.)" == ConfigurationSetting.for_library_and_externalintegration( - self._db, AuthenticationProvider.EXTERNAL_TYPE_REGULAR_EXPRESSION, - library, auth_service).value + assert ( + "^(.)" + == ConfigurationSetting.for_library_and_externalintegration( + self._db, + AuthenticationProvider.EXTERNAL_TYPE_REGULAR_EXPRESSION, + library, + auth_service, + ).value + ) common_args = self._common_basic_auth_arguments() with self.request_context_with_admin("/", method="POST"): - flask.request.form = MultiDict([ - ("protocol", MilleniumPatronAPI.__module__), - (ExternalIntegration.URL, "url"), - (MilleniumPatronAPI.VERIFY_CERTIFICATE, "true"), - (MilleniumPatronAPI.AUTHENTICATION_MODE, MilleniumPatronAPI.PIN_AUTHENTICATION_MODE), - ] + common_args) + flask.request.form = MultiDict( + [ + ("protocol", MilleniumPatronAPI.__module__), + (ExternalIntegration.URL, "url"), + (MilleniumPatronAPI.VERIFY_CERTIFICATE, "true"), + ( + MilleniumPatronAPI.AUTHENTICATION_MODE, + MilleniumPatronAPI.PIN_AUTHENTICATION_MODE, + ), + ] + + common_args + ) response = mock_controller.process_patron_auth_services() assert response.status_code == 201 assert mock_controller.validate_formats_call_count == 2 - auth_service2 = get_one(self._db, ExternalIntegration, - goal=ExternalIntegration.PATRON_AUTH_GOAL, - protocol=MilleniumPatronAPI.__module__) + auth_service2 = get_one( + self._db, + ExternalIntegration, + goal=ExternalIntegration.PATRON_AUTH_GOAL, + protocol=MilleniumPatronAPI.__module__, + ) assert auth_service2 != auth_service assert auth_service2.id == int(response.response[0]) assert "url" == auth_service2.url - assert "user" == auth_service2.setting(BasicAuthenticationProvider.TEST_IDENTIFIER).value - assert "pass" == auth_service2.setting(BasicAuthenticationProvider.TEST_PASSWORD).value - assert ("true" == - auth_service2.setting(MilleniumPatronAPI.VERIFY_CERTIFICATE).value) - assert (MilleniumPatronAPI.PIN_AUTHENTICATION_MODE == - auth_service2.setting(MilleniumPatronAPI.AUTHENTICATION_MODE).value) + assert ( + "user" + == auth_service2.setting(BasicAuthenticationProvider.TEST_IDENTIFIER).value + ) + assert ( + "pass" + == auth_service2.setting(BasicAuthenticationProvider.TEST_PASSWORD).value + ) + assert ( + "true" == auth_service2.setting(MilleniumPatronAPI.VERIFY_CERTIFICATE).value + ) + assert ( + MilleniumPatronAPI.PIN_AUTHENTICATION_MODE + == auth_service2.setting(MilleniumPatronAPI.AUTHENTICATION_MODE).value + ) assert None == auth_service2.setting(MilleniumPatronAPI.BLOCK_TYPES).value assert [] == auth_service2.libraries @@ -448,66 +638,115 @@ def test_patron_auth_services_post_edit(self): mock_controller = self._get_mock() l1, ignore = create( - self._db, Library, name="Library 1", short_name="L1", + self._db, + Library, + name="Library 1", + short_name="L1", ) l2, ignore = create( - self._db, Library, name="Library 2", short_name="L2", + self._db, + Library, + name="Library 2", + short_name="L2", ) auth_service, ignore = create( - self._db, ExternalIntegration, + self._db, + ExternalIntegration, protocol=SimpleAuthenticationProvider.__module__, - goal=ExternalIntegration.PATRON_AUTH_GOAL + goal=ExternalIntegration.PATRON_AUTH_GOAL, ) - auth_service.setting(BasicAuthenticationProvider.TEST_IDENTIFIER).value = "old_user" - auth_service.setting(BasicAuthenticationProvider.TEST_PASSWORD).value = "old_password" + auth_service.setting( + BasicAuthenticationProvider.TEST_IDENTIFIER + ).value = "old_user" + auth_service.setting( + BasicAuthenticationProvider.TEST_PASSWORD + ).value = "old_password" auth_service.libraries = [l1] with self.request_context_with_admin("/", method="POST"): - flask.request.form = MultiDict([ - ("id", auth_service.id), - ("protocol", SimpleAuthenticationProvider.__module__), - ("libraries", json.dumps([{ - "short_name": l2.short_name, - AuthenticationProvider.EXTERNAL_TYPE_REGULAR_EXPRESSION: "^(.)", - AuthenticationProvider.LIBRARY_IDENTIFIER_RESTRICTION_TYPE: AuthenticationProvider.LIBRARY_IDENTIFIER_RESTRICTION_TYPE_NONE, - AuthenticationProvider.LIBRARY_IDENTIFIER_FIELD: AuthenticationProvider.LIBRARY_IDENTIFIER_RESTRICTION_BARCODE, - }])), - ] + self._common_basic_auth_arguments()) - response = self.manager.admin_patron_auth_services_controller.process_patron_auth_services() + flask.request.form = MultiDict( + [ + ("id", auth_service.id), + ("protocol", SimpleAuthenticationProvider.__module__), + ( + "libraries", + json.dumps( + [ + { + "short_name": l2.short_name, + AuthenticationProvider.EXTERNAL_TYPE_REGULAR_EXPRESSION: "^(.)", + AuthenticationProvider.LIBRARY_IDENTIFIER_RESTRICTION_TYPE: AuthenticationProvider.LIBRARY_IDENTIFIER_RESTRICTION_TYPE_NONE, + AuthenticationProvider.LIBRARY_IDENTIFIER_FIELD: AuthenticationProvider.LIBRARY_IDENTIFIER_RESTRICTION_BARCODE, + } + ] + ), + ), + ] + + self._common_basic_auth_arguments() + ) + response = ( + self.manager.admin_patron_auth_services_controller.process_patron_auth_services() + ) assert response.status_code == 200 assert mock_controller.validate_formats_call_count == 1 assert auth_service.id == int(response.response[0]) assert SimpleAuthenticationProvider.__module__ == auth_service.protocol - assert "user" == auth_service.setting(BasicAuthenticationProvider.TEST_IDENTIFIER).value - assert "pass" == auth_service.setting(BasicAuthenticationProvider.TEST_PASSWORD).value + assert ( + "user" + == auth_service.setting(BasicAuthenticationProvider.TEST_IDENTIFIER).value + ) + assert ( + "pass" + == auth_service.setting(BasicAuthenticationProvider.TEST_PASSWORD).value + ) assert [l2] == auth_service.libraries - assert "^(.)" == ConfigurationSetting.for_library_and_externalintegration( - self._db, AuthenticationProvider.EXTERNAL_TYPE_REGULAR_EXPRESSION, - l2, auth_service).value + assert ( + "^(.)" + == ConfigurationSetting.for_library_and_externalintegration( + self._db, + AuthenticationProvider.EXTERNAL_TYPE_REGULAR_EXPRESSION, + l2, + auth_service, + ).value + ) def test_patron_auth_service_delete(self): l1, ignore = create( - self._db, Library, name="Library 1", short_name="L1", + self._db, + Library, + name="Library 1", + short_name="L1", ) auth_service, ignore = create( - self._db, ExternalIntegration, + self._db, + ExternalIntegration, protocol=SimpleAuthenticationProvider.__module__, - goal=ExternalIntegration.PATRON_AUTH_GOAL + goal=ExternalIntegration.PATRON_AUTH_GOAL, ) - auth_service.setting(BasicAuthenticationProvider.TEST_IDENTIFIER).value = "old_user" - auth_service.setting(BasicAuthenticationProvider.TEST_PASSWORD).value = "old_password" + auth_service.setting( + BasicAuthenticationProvider.TEST_IDENTIFIER + ).value = "old_user" + auth_service.setting( + BasicAuthenticationProvider.TEST_PASSWORD + ).value = "old_password" auth_service.libraries = [l1] with self.request_context_with_admin("/", method="DELETE"): self.admin.remove_role(AdminRole.SYSTEM_ADMIN) - pytest.raises(AdminNotAuthorized, - self.manager.admin_patron_auth_services_controller.process_delete, - auth_service.id) + pytest.raises( + AdminNotAuthorized, + self.manager.admin_patron_auth_services_controller.process_delete, + auth_service.id, + ) self.admin.add_role(AdminRole.SYSTEM_ADMIN) - response = self.manager.admin_patron_auth_services_controller.process_delete(auth_service.id) + response = ( + self.manager.admin_patron_auth_services_controller.process_delete( + auth_service.id + ) + ) assert response.status_code == 200 service = get_one(self._db, ExternalIntegration, id=auth_service.id) diff --git a/tests/admin/controller/test_patron_auth_self_tests.py b/tests/admin/controller/test_patron_auth_self_tests.py index 9b5fe3128d..d2309fe185 100644 --- a/tests/admin/controller/test_patron_auth_self_tests.py +++ b/tests/admin/controller/test_patron_auth_self_tests.py @@ -1,49 +1,54 @@ - from flask_babel import lazy_gettext as _ + from api.admin.problem_details import * -from core.selftest import ( - HasSelfTests, - SelfTestResult, -) from api.simple_authentication import SimpleAuthenticationProvider +from core.model import ExternalIntegration, create +from core.selftest import HasSelfTests, SelfTestResult + from .test_controller import SettingsControllerTest -from core.model import ( - create, - ExternalIntegration, -) -class TestPatronAuthSelfTests(SettingsControllerTest): +class TestPatronAuthSelfTests(SettingsControllerTest): def _auth_service(self, libraries=[]): auth_service, ignore = create( - self._db, ExternalIntegration, + self._db, + ExternalIntegration, protocol=SimpleAuthenticationProvider.__module__, goal=ExternalIntegration.PATRON_AUTH_GOAL, name="name", - libraries=libraries + libraries=libraries, ) return auth_service def test_patron_auth_self_tests_with_no_identifier(self): with self.request_context_with_admin("/"): - response = self.manager.admin_patron_auth_service_self_tests_controller.process_patron_auth_service_self_tests(None) + response = self.manager.admin_patron_auth_service_self_tests_controller.process_patron_auth_service_self_tests( + None + ) assert response.title == MISSING_IDENTIFIER.title assert response.detail == MISSING_IDENTIFIER.detail assert response.status_code == 400 def test_patron_auth_self_tests_with_no_auth_service_found(self): with self.request_context_with_admin("/"): - response = self.manager.admin_patron_auth_service_self_tests_controller.process_patron_auth_service_self_tests(-1) + response = self.manager.admin_patron_auth_service_self_tests_controller.process_patron_auth_service_self_tests( + -1 + ) assert response == MISSING_SERVICE assert response.status_code == 404 def test_patron_auth_self_tests_get_with_no_libraries(self): auth_service = self._auth_service() with self.request_context_with_admin("/"): - response = self.manager.admin_patron_auth_service_self_tests_controller.process_patron_auth_service_self_tests(auth_service.id) + response = self.manager.admin_patron_auth_service_self_tests_controller.process_patron_auth_service_self_tests( + auth_service.id + ) results = response.get("self_test_results").get("self_test_results") assert results.get("disabled") == True - assert results.get("exception") == "You must associate this service with at least one library before you can run self tests for it." + assert ( + results.get("exception") + == "You must associate this service with at least one library before you can run self tests for it." + ) def test_patron_auth_self_tests_test_get(self): old_prior_test_results = HasSelfTests.prior_test_results @@ -53,23 +58,32 @@ def test_patron_auth_self_tests_test_get(self): # Make sure that HasSelfTest.prior_test_results() was called and that # it is in the response's self tests object. with self.request_context_with_admin("/"): - response = self.manager.admin_patron_auth_service_self_tests_controller.process_patron_auth_service_self_tests(auth_service.id) + response = self.manager.admin_patron_auth_service_self_tests_controller.process_patron_auth_service_self_tests( + auth_service.id + ) response_auth_service = response.get("self_test_results") assert response_auth_service.get("name") == auth_service.name assert response_auth_service.get("protocol") == auth_service.protocol assert response_auth_service.get("id") == auth_service.id assert response_auth_service.get("goal") == auth_service.goal - assert response_auth_service.get("self_test_results") == self.self_test_results + assert ( + response_auth_service.get("self_test_results") == self.self_test_results + ) HasSelfTests.prior_test_results = old_prior_test_results def test_patron_auth_self_tests_post_with_no_libraries(self): auth_service = self._auth_service() with self.request_context_with_admin("/", method="POST"): - response = self.manager.admin_patron_auth_service_self_tests_controller.process_patron_auth_service_self_tests(auth_service.id) + response = self.manager.admin_patron_auth_service_self_tests_controller.process_patron_auth_service_self_tests( + auth_service.id + ) assert response.title == FAILED_TO_RUN_SELF_TESTS.title - assert response.detail == "Failed to run self tests for this patron authentication service." + assert ( + response.detail + == "Failed to run self tests for this patron authentication service." + ) assert response.status_code == 400 def test_patron_auth_self_tests_test_post(self): @@ -78,13 +92,18 @@ def test_patron_auth_self_tests_test_post(self): auth_service = self._auth_service([self._library()]) with self.request_context_with_admin("/", method="POST"): - response = self.manager.admin_patron_auth_service_self_tests_controller.process_patron_auth_service_self_tests(auth_service.id) + response = self.manager.admin_patron_auth_service_self_tests_controller.process_patron_auth_service_self_tests( + auth_service.id + ) assert response._status == "200 OK" assert "Successfully ran new self tests" == response.get_data(as_text=True) # run_self_tests was called with the database twice (the # second time to be used in the ExternalSearchIntegration # constructor). There were no keyword arguments. - assert ((self._db, None, auth_service.libraries[0], auth_service), {}) == self.run_self_tests_called_with + assert ( + (self._db, None, auth_service.libraries[0], auth_service), + {}, + ) == self.run_self_tests_called_with HasSelfTests.run_self_tests = old_run_self_tests diff --git a/tests/admin/controller/test_search_service_self_tests.py b/tests/admin/controller/test_search_service_self_tests.py index 30d91fe217..b1f35718e8 100644 --- a/tests/admin/controller/test_search_service_self_tests.py +++ b/tests/admin/controller/test_search_service_self_tests.py @@ -1,31 +1,34 @@ - from flask_babel import lazy_gettext as _ + from api.admin.problem_details import * -from api.axis import (Axis360API, MockAxis360API) -from core.opds_import import (OPDSImporter, OPDSImportMonitor) -from core.selftest import ( - HasSelfTests, - SelfTestResult, +from api.axis import Axis360API, MockAxis360API +from core.external_search import ( + ExternalSearchIndex, + MockExternalSearchIndex, + MockSearchResult, ) +from core.model import ExternalIntegration, create +from core.opds_import import OPDSImporter, OPDSImportMonitor +from core.selftest import HasSelfTests, SelfTestResult + from .test_controller import SettingsControllerTest -from core.model import ( - create, - ExternalIntegration, -) -from core.external_search import ExternalSearchIndex, MockExternalSearchIndex, MockSearchResult + class TestSearchServiceSelfTests(SettingsControllerTest): def test_search_service_self_tests_with_no_identifier(self): with self.request_context_with_admin("/"): - response = self.manager.admin_search_service_self_tests_controller.process_search_service_self_tests(None) + response = self.manager.admin_search_service_self_tests_controller.process_search_service_self_tests( + None + ) assert response.title == MISSING_IDENTIFIER.title assert response.detail == MISSING_IDENTIFIER.detail assert response.status_code == 400 - def test_search_service_self_tests_with_no_search_service_found(self): with self.request_context_with_admin("/"): - response = self.manager.admin_search_service_self_tests_controller.process_search_service_self_tests(-1) + response = self.manager.admin_search_service_self_tests_controller.process_search_service_self_tests( + -1 + ) assert response == MISSING_SERVICE assert response.status_code == 404 @@ -33,23 +36,30 @@ def test_search_service_self_tests_test_get(self): old_prior_test_results = HasSelfTests.prior_test_results HasSelfTests.prior_test_results = self.mock_prior_test_results search_service, ignore = create( - self._db, ExternalIntegration, + self._db, + ExternalIntegration, protocol=ExternalIntegration.ELASTICSEARCH, goal=ExternalIntegration.SEARCH_GOAL, ) # Make sure that HasSelfTest.prior_test_results() was called and that # it is in the response's self tests object. with self.request_context_with_admin("/"): - response = self.manager.admin_search_service_self_tests_controller.process_search_service_self_tests(search_service.id) + response = self.manager.admin_search_service_self_tests_controller.process_search_service_self_tests( + search_service.id + ) response_search_service = response.get("self_test_results") assert response_search_service.get("id") == search_service.id assert response_search_service.get("name") == search_service.name - assert response_search_service.get("protocol").get("label") == search_service.protocol + assert ( + response_search_service.get("protocol").get("label") + == search_service.protocol + ) assert response_search_service.get("goal") == search_service.goal assert ( - response_search_service.get("self_test_results") == - HasSelfTests.prior_test_results()) + response_search_service.get("self_test_results") + == HasSelfTests.prior_test_results() + ) HasSelfTests.prior_test_results = old_prior_test_results @@ -58,11 +68,14 @@ def test_search_service_self_tests_post(self): HasSelfTests.run_self_tests = self.mock_run_self_tests search_service, ignore = create( - self._db, ExternalIntegration, + self._db, + ExternalIntegration, protocol=ExternalIntegration.ELASTICSEARCH, - goal=ExternalIntegration.SEARCH_GOAL + goal=ExternalIntegration.SEARCH_GOAL, + ) + m = ( + self.manager.admin_search_service_self_tests_controller.self_tests_process_post ) - m = self.manager.admin_search_service_self_tests_controller.self_tests_process_post with self.request_context_with_admin("/", method="POST"): response = m(search_service.id) assert response._status == "200 OK" diff --git a/tests/admin/controller/test_search_services.py b/tests/admin/controller/test_search_services.py index a4f1c96d3d..84174ae83f 100644 --- a/tests/admin/controller/test_search_services.py +++ b/tests/admin/controller/test_search_services.py @@ -1,40 +1,46 @@ -import pytest - import flask +import pytest from werkzeug.datastructures import MultiDict + from api.admin.exceptions import * from core.external_search import ExternalSearchIndex -from core.model import ( - AdminRole, - create, - get_one, - ExternalIntegration, -) +from core.model import AdminRole, ExternalIntegration, create, get_one + from .test_controller import SettingsControllerTest + class TestSearchServices(SettingsControllerTest): def test_search_services_get_with_no_services(self): with self.request_context_with_admin("/"): response = self.manager.admin_search_services_controller.process_services() assert response.get("search_services") == [] protocols = response.get("protocols") - assert ExternalIntegration.ELASTICSEARCH in [p.get("name") for p in protocols] + assert ExternalIntegration.ELASTICSEARCH in [ + p.get("name") for p in protocols + ] assert "settings" in protocols[0] self.admin.remove_role(AdminRole.SYSTEM_ADMIN) self._db.flush() - pytest.raises(AdminNotAuthorized, - self.manager.admin_search_services_controller.process_services) + pytest.raises( + AdminNotAuthorized, + self.manager.admin_search_services_controller.process_services, + ) def test_search_services_get_with_one_service(self): search_service, ignore = create( - self._db, ExternalIntegration, + self._db, + ExternalIntegration, protocol=ExternalIntegration.ELASTICSEARCH, goal=ExternalIntegration.SEARCH_GOAL, ) search_service.url = "search url" - search_service.setting(ExternalSearchIndex.WORKS_INDEX_PREFIX_KEY).value = "works-index-prefix" - search_service.setting(ExternalSearchIndex.TEST_SEARCH_TERM_KEY).value = "search-term-for-self-tests" + search_service.setting( + ExternalSearchIndex.WORKS_INDEX_PREFIX_KEY + ).value = "works-index-prefix" + search_service.setting( + ExternalSearchIndex.TEST_SEARCH_TERM_KEY + ).value = "search-term-for-self-tests" with self.request_context_with_admin("/"): response = self.manager.admin_search_services_controller.process_services() @@ -43,17 +49,23 @@ def test_search_services_get_with_one_service(self): assert search_service.id == service.get("id") assert search_service.protocol == service.get("protocol") assert "search url" == service.get("settings").get(ExternalIntegration.URL) - assert "works-index-prefix" == service.get("settings").get(ExternalSearchIndex.WORKS_INDEX_PREFIX_KEY) - assert "search-term-for-self-tests" == service.get("settings").get(ExternalSearchIndex.TEST_SEARCH_TERM_KEY) + assert "works-index-prefix" == service.get("settings").get( + ExternalSearchIndex.WORKS_INDEX_PREFIX_KEY + ) + assert "search-term-for-self-tests" == service.get("settings").get( + ExternalSearchIndex.TEST_SEARCH_TERM_KEY + ) def test_search_services_post_errors(self): controller = self.manager.admin_search_services_controller with self.request_context_with_admin("/", method="POST"): - flask.request.form = MultiDict([ - ("name", "Name"), - ("protocol", "Unknown"), - ]) + flask.request.form = MultiDict( + [ + ("name", "Name"), + ("protocol", "Unknown"), + ] + ) response = controller.process_services() assert response == UNKNOWN_PROTOCOL @@ -63,132 +75,180 @@ def test_search_services_post_errors(self): assert response == NO_PROTOCOL_FOR_NEW_SERVICE with self.request_context_with_admin("/", method="POST"): - flask.request.form = MultiDict([ - ("name", "Name"), - ("id", "123"), - ]) + flask.request.form = MultiDict( + [ + ("name", "Name"), + ("id", "123"), + ] + ) response = controller.process_services() assert response == MISSING_SERVICE service, ignore = create( - self._db, ExternalIntegration, + self._db, + ExternalIntegration, protocol=ExternalIntegration.ELASTICSEARCH, goal=ExternalIntegration.SEARCH_GOAL, ) with self.request_context_with_admin("/", method="POST"): - flask.request.form = MultiDict([ - ("name", "Name"), - ("protocol", ExternalIntegration.ELASTICSEARCH), - ]) + flask.request.form = MultiDict( + [ + ("name", "Name"), + ("protocol", ExternalIntegration.ELASTICSEARCH), + ] + ) response = controller.process_services() assert response.uri == MULTIPLE_SITEWIDE_SERVICES.uri self._db.delete(service) service, ignore = create( - self._db, ExternalIntegration, + self._db, + ExternalIntegration, protocol=ExternalIntegration.CDN, goal=ExternalIntegration.CDN_GOAL, name="name", ) with self.request_context_with_admin("/", method="POST"): - flask.request.form = MultiDict([ - ("name", service.name), - ("protocol", ExternalIntegration.ELASTICSEARCH), - ]) + flask.request.form = MultiDict( + [ + ("name", service.name), + ("protocol", ExternalIntegration.ELASTICSEARCH), + ] + ) response = controller.process_services() assert response == INTEGRATION_NAME_ALREADY_IN_USE service, ignore = create( - self._db, ExternalIntegration, + self._db, + ExternalIntegration, protocol=ExternalIntegration.ELASTICSEARCH, goal=ExternalIntegration.SEARCH_GOAL, ) with self.request_context_with_admin("/", method="POST"): - flask.request.form = MultiDict([ - ("name", "Name"), - ("id", service.id), - ("protocol", ExternalIntegration.ELASTICSEARCH), - ]) + flask.request.form = MultiDict( + [ + ("name", "Name"), + ("id", service.id), + ("protocol", ExternalIntegration.ELASTICSEARCH), + ] + ) response = controller.process_services() assert response.uri == INCOMPLETE_CONFIGURATION.uri self.admin.remove_role(AdminRole.SYSTEM_ADMIN) with self.request_context_with_admin("/", method="POST"): - flask.request.form = MultiDict([ - ("protocol", ExternalIntegration.ELASTICSEARCH), - (ExternalIntegration.URL, "search url"), - (ExternalSearchIndex.WORKS_INDEX_PREFIX_KEY, "works-index-prefix"), - ]) - pytest.raises(AdminNotAuthorized, - controller.process_services) + flask.request.form = MultiDict( + [ + ("protocol", ExternalIntegration.ELASTICSEARCH), + (ExternalIntegration.URL, "search url"), + (ExternalSearchIndex.WORKS_INDEX_PREFIX_KEY, "works-index-prefix"), + ] + ) + pytest.raises(AdminNotAuthorized, controller.process_services) def test_search_services_post_create(self): with self.request_context_with_admin("/", method="POST"): - flask.request.form = MultiDict([ - ("name", "Name"), - ("protocol", ExternalIntegration.ELASTICSEARCH), - (ExternalIntegration.URL, "http://search_url"), - (ExternalSearchIndex.WORKS_INDEX_PREFIX_KEY, "works-index-prefix"), - (ExternalSearchIndex.TEST_SEARCH_TERM_KEY, "sample-search-term") - ]) + flask.request.form = MultiDict( + [ + ("name", "Name"), + ("protocol", ExternalIntegration.ELASTICSEARCH), + (ExternalIntegration.URL, "http://search_url"), + (ExternalSearchIndex.WORKS_INDEX_PREFIX_KEY, "works-index-prefix"), + (ExternalSearchIndex.TEST_SEARCH_TERM_KEY, "sample-search-term"), + ] + ) response = self.manager.admin_search_services_controller.process_services() assert response.status_code == 201 - service = get_one(self._db, ExternalIntegration, goal=ExternalIntegration.SEARCH_GOAL) + service = get_one( + self._db, ExternalIntegration, goal=ExternalIntegration.SEARCH_GOAL + ) assert service.id == int(response.response[0]) assert ExternalIntegration.ELASTICSEARCH == service.protocol assert "http://search_url" == service.url - assert "works-index-prefix" == service.setting(ExternalSearchIndex.WORKS_INDEX_PREFIX_KEY).value - assert "sample-search-term" == service.setting(ExternalSearchIndex.TEST_SEARCH_TERM_KEY).value + assert ( + "works-index-prefix" + == service.setting(ExternalSearchIndex.WORKS_INDEX_PREFIX_KEY).value + ) + assert ( + "sample-search-term" + == service.setting(ExternalSearchIndex.TEST_SEARCH_TERM_KEY).value + ) def test_search_services_post_edit(self): search_service, ignore = create( - self._db, ExternalIntegration, + self._db, + ExternalIntegration, protocol=ExternalIntegration.ELASTICSEARCH, goal=ExternalIntegration.SEARCH_GOAL, ) search_service.url = "search url" - search_service.setting(ExternalSearchIndex.WORKS_INDEX_PREFIX_KEY).value = "works-index-prefix" - search_service.setting(ExternalSearchIndex.TEST_SEARCH_TERM_KEY).value = "sample-search-term" + search_service.setting( + ExternalSearchIndex.WORKS_INDEX_PREFIX_KEY + ).value = "works-index-prefix" + search_service.setting( + ExternalSearchIndex.TEST_SEARCH_TERM_KEY + ).value = "sample-search-term" with self.request_context_with_admin("/", method="POST"): - flask.request.form = MultiDict([ - ("name", "Name"), - ("id", search_service.id), - ("protocol", ExternalIntegration.ELASTICSEARCH), - (ExternalIntegration.URL, "http://new_search_url"), - (ExternalSearchIndex.WORKS_INDEX_PREFIX_KEY, "new-works-index-prefix"), - (ExternalSearchIndex.TEST_SEARCH_TERM_KEY, "new-sample-search-term") - ]) + flask.request.form = MultiDict( + [ + ("name", "Name"), + ("id", search_service.id), + ("protocol", ExternalIntegration.ELASTICSEARCH), + (ExternalIntegration.URL, "http://new_search_url"), + ( + ExternalSearchIndex.WORKS_INDEX_PREFIX_KEY, + "new-works-index-prefix", + ), + ( + ExternalSearchIndex.TEST_SEARCH_TERM_KEY, + "new-sample-search-term", + ), + ] + ) response = self.manager.admin_search_services_controller.process_services() assert response.status_code == 200 assert search_service.id == int(response.response[0]) assert ExternalIntegration.ELASTICSEARCH == search_service.protocol assert "http://new_search_url" == search_service.url - assert "new-works-index-prefix" == search_service.setting(ExternalSearchIndex.WORKS_INDEX_PREFIX_KEY).value - assert "new-sample-search-term" == search_service.setting(ExternalSearchIndex.TEST_SEARCH_TERM_KEY).value + assert ( + "new-works-index-prefix" + == search_service.setting(ExternalSearchIndex.WORKS_INDEX_PREFIX_KEY).value + ) + assert ( + "new-sample-search-term" + == search_service.setting(ExternalSearchIndex.TEST_SEARCH_TERM_KEY).value + ) def test_search_service_delete(self): search_service, ignore = create( - self._db, ExternalIntegration, + self._db, + ExternalIntegration, protocol=ExternalIntegration.ELASTICSEARCH, goal=ExternalIntegration.SEARCH_GOAL, ) search_service.url = "search url" - search_service.setting(ExternalSearchIndex.WORKS_INDEX_PREFIX_KEY).value = "works-index-prefix" + search_service.setting( + ExternalSearchIndex.WORKS_INDEX_PREFIX_KEY + ).value = "works-index-prefix" with self.request_context_with_admin("/", method="DELETE"): self.admin.remove_role(AdminRole.SYSTEM_ADMIN) - pytest.raises(AdminNotAuthorized, - self.manager.admin_search_services_controller.process_delete, - search_service.id) + pytest.raises( + AdminNotAuthorized, + self.manager.admin_search_services_controller.process_delete, + search_service.id, + ) self.admin.add_role(AdminRole.SYSTEM_ADMIN) - response = self.manager.admin_search_services_controller.process_delete(search_service.id) + response = self.manager.admin_search_services_controller.process_delete( + search_service.id + ) assert response.status_code == 200 service = get_one(self._db, ExternalIntegration, id=search_service.id) diff --git a/tests/admin/controller/test_sitewide_registration.py b/tests/admin/controller/test_sitewide_registration.py index 0be7d93507..79488c9d33 100644 --- a/tests/admin/controller/test_sitewide_registration.py +++ b/tests/admin/controller/test_sitewide_registration.py @@ -1,21 +1,22 @@ -import binascii import base64 -import flask +import binascii import json -import jwt import os + +import flask +import jwt from werkzeug.datastructures import MultiDict + from api.admin.problem_details import * from api.config import Configuration -from core.model import ( - ExternalIntegration, -) +from core.model import ExternalIntegration from core.testing import MockRequestsResponse from core.util.problem_detail import ProblemDetail + from .test_controller import SettingsControllerTest -class TestSitewideRegistration(SettingsControllerTest): +class TestSitewideRegistration(SettingsControllerTest): def test_sitewide_registration_post_errors(self): def assert_remote_integration_error(response, message=None): assert REMOTE_INTEGRATION_FAILED.uri == response.uri @@ -25,7 +26,8 @@ def assert_remote_integration_error(response, message=None): metadata_wrangler_service = self._external_integration( ExternalIntegration.METADATA_WRANGLER, - goal=ExternalIntegration.METADATA_GOAL, url=self._url + goal=ExternalIntegration.METADATA_GOAL, + url=self._url, ) default_form = None controller = self.manager.admin_metadata_services_controller @@ -39,7 +41,7 @@ def assert_remote_integration_error(response, message=None): # If an error is raised during registration, a ProblemDetail is returned. def error_get(*args, **kwargs): - raise RuntimeError('Mock error during request') + raise RuntimeError("Mock error during request") with self.request_context_with_admin("/"): response = controller.process_sitewide_registration( @@ -49,7 +51,7 @@ def error_get(*args, **kwargs): # # If the response has the wrong media type, a ProblemDetail is returned. self.responses.append( - MockRequestsResponse(200, headers={'Content-Type' : 'text/plain'}) + MockRequestsResponse(200, headers={"Content-Type": "text/plain"}) ) with self.request_context_with_admin("/"): @@ -57,15 +59,13 @@ def error_get(*args, **kwargs): metadata_wrangler_service, do_get=self.do_request ) assert_remote_integration_error( - response, 'The service did not provide a valid catalog.' + response, "The service did not provide a valid catalog." ) # If the response returns a ProblemDetail, its contents are wrapped # in another ProblemDetail. status_code, content, headers = MULTIPLE_BASIC_AUTH_SERVICES.response - self.responses.append( - MockRequestsResponse(content, headers, status_code) - ) + self.responses.append(MockRequestsResponse(content, headers, status_code)) with self.request_context_with_admin("/"): response = controller.process_sitewide_registration( metadata_wrangler_service, do_get=self.do_request @@ -78,7 +78,7 @@ def error_get(*args, **kwargs): # If no registration link is available, a ProblemDetail is returned catalog = dict(id=self._url, links=[]) - headers = { 'Content-Type' : 'application/opds+json' } + headers = {"Content-Type": "application/opds+json"} self.responses.append( MockRequestsResponse(200, content=json.dumps(catalog), headers=headers) ) @@ -88,51 +88,64 @@ def error_get(*args, **kwargs): metadata_wrangler_service, do_get=self.do_request ) assert_remote_integration_error( - response, 'The service did not provide a register link.' + response, "The service did not provide a register link." ) # If no registration details are given, a ProblemDetail is returned link_type = self.manager.admin_settings_controller.METADATA_SERVICE_URI_TYPE - catalog['links'] = [dict(rel='register', href=self._url, type=link_type)] + catalog["links"] = [dict(rel="register", href=self._url, type=link_type)] registration = dict(id=self._url, metadata={}) - self.responses.extend([ - MockRequestsResponse(200, content=json.dumps(registration), headers=headers), - MockRequestsResponse(200, content=json.dumps(catalog), headers=headers) - ]) + self.responses.extend( + [ + MockRequestsResponse( + 200, content=json.dumps(registration), headers=headers + ), + MockRequestsResponse(200, content=json.dumps(catalog), headers=headers), + ] + ) - with self.request_context_with_admin('/', method='POST'): + with self.request_context_with_admin("/", method="POST"): response = controller.process_sitewide_registration( - metadata_wrangler_service, do_get=self.do_request, do_post=self.do_request + metadata_wrangler_service, + do_get=self.do_request, + do_post=self.do_request, ) assert_remote_integration_error( - response, 'The service did not provide registration information.' + response, "The service did not provide registration information." ) # If we get all the way to the registration POST, but that # request results in a ProblemDetail, that ProblemDetail is # passed along. - self.responses.extend([ - MockRequestsResponse(200, content=json.dumps(registration), headers=headers), - MockRequestsResponse(200, content=json.dumps(catalog), headers=headers) - ]) + self.responses.extend( + [ + MockRequestsResponse( + 200, content=json.dumps(registration), headers=headers + ), + MockRequestsResponse(200, content=json.dumps(catalog), headers=headers), + ] + ) def bad_do_post(self, *args, **kwargs): return MULTIPLE_BASIC_AUTH_SERVICES - with self.request_context_with_admin('/', method='POST'): - flask.request.form = MultiDict([ - ('integration_id', metadata_wrangler_service.id), - ]) + + with self.request_context_with_admin("/", method="POST"): + flask.request.form = MultiDict( + [ + ("integration_id", metadata_wrangler_service.id), + ] + ) response = controller.process_sitewide_registration( metadata_wrangler_service, do_get=self.do_request, do_post=bad_do_post ) assert MULTIPLE_BASIC_AUTH_SERVICES == response - def test_sitewide_registration_post_success(self): # A service to register with metadata_wrangler_service = self._external_integration( ExternalIntegration.METADATA_WRANGLER, - goal=ExternalIntegration.METADATA_GOAL, url=self._url + goal=ExternalIntegration.METADATA_GOAL, + url=self._url, ) # The service knows this site's public key, and is going @@ -141,17 +154,19 @@ def test_sitewide_registration_post_success(self): encryptor = Configuration.cipher(public_key) # A catalog with registration url - register_link_type = self.manager.admin_settings_controller.METADATA_SERVICE_URI_TYPE + register_link_type = ( + self.manager.admin_settings_controller.METADATA_SERVICE_URI_TYPE + ) registration_url = self._url catalog = dict( - id = metadata_wrangler_service.url, - links = [ - dict(rel='collection-add', href=self._url, type='collection'), - dict(rel='register', href=registration_url, type=register_link_type), - dict(rel='collection-remove', href=self._url, type='collection'), - ] + id=metadata_wrangler_service.url, + links=[ + dict(rel="collection-add", href=self._url, type="collection"), + dict(rel="register", href=registration_url, type=register_link_type), + dict(rel="collection-remove", href=self._url, type="collection"), + ], ) - headers = { 'Content-Type' : 'application/opds+json' } + headers = {"Content-Type": "application/opds+json"} self.responses.append( MockRequestsResponse(200, content=json.dumps(catalog), headers=headers) ) @@ -160,18 +175,23 @@ def test_sitewide_registration_post_success(self): shared_secret = binascii.hexlify(os.urandom(24)) encrypted_secret = base64.b64encode(encryptor.encrypt(shared_secret)) registration = dict( - id = metadata_wrangler_service.url, - metadata = dict(shared_secret=encrypted_secret.decode("utf-8")) + id=metadata_wrangler_service.url, + metadata=dict(shared_secret=encrypted_secret.decode("utf-8")), + ) + self.responses.insert( + 0, MockRequestsResponse(200, content=json.dumps(registration)) ) - self.responses.insert(0, MockRequestsResponse(200, content=json.dumps(registration))) - with self.request_context_with_admin('/', method='POST'): - flask.request.form = MultiDict([ - ('integration_id', metadata_wrangler_service.id), - ]) + with self.request_context_with_admin("/", method="POST"): + flask.request.form = MultiDict( + [ + ("integration_id", metadata_wrangler_service.id), + ] + ) response = self.manager.admin_metadata_services_controller.process_sitewide_registration( - metadata_wrangler_service, do_get=self.do_request, - do_post=self.do_request + metadata_wrangler_service, + do_get=self.do_request, + do_post=self.do_request, ) assert None == response @@ -185,7 +205,7 @@ def test_sitewide_registration_post_success(self): url, [document], ignore = registration_request assert url == registration_url - for k in 'url', 'jwt': + for k in "url", "jwt": assert k in document # The end result is that our ExternalIntegration for the metadata @@ -195,19 +215,17 @@ def test_sitewide_registration_post_success(self): def test_sitewide_registration_document(self): """Test the document sent along to sitewide registration.""" controller = self.manager.admin_metadata_services_controller - with self.request_context_with_admin('/'): + with self.request_context_with_admin("/"): doc = controller.sitewide_registration_document() # The registrar knows where to go to get our public key. - assert doc['url'] == controller.url_for('public_key_document') + assert doc["url"] == controller.url_for("public_key_document") # The JWT proves that we control the public/private key pair. public_key, private_key = self.manager.sitewide_key_pair - parsed = jwt.decode( - doc['jwt'], public_key, algorithm='RS256' - ) + parsed = jwt.decode(doc["jwt"], public_key, algorithm="RS256") # The JWT must be valid or jwt.decode() would have raised # an exception. This simply verifies that the JWT includes # an expiration date and doesn't last forever. - assert 'exp' in parsed + assert "exp" in parsed diff --git a/tests/admin/controller/test_sitewide_services.py b/tests/admin/controller/test_sitewide_services.py index 75af63e67e..bb55b56a85 100644 --- a/tests/admin/controller/test_sitewide_services.py +++ b/tests/admin/controller/test_sitewide_services.py @@ -1,14 +1,14 @@ - import flask + from api.admin.controller import SettingsController from api.admin.controller.sitewide_services import * from api.admin.controller.storage_services import StorageServicesController -from core.model import ( - ExternalIntegration, -) -from core.s3 import S3Uploader, MockS3Uploader +from core.model import ExternalIntegration +from core.s3 import MockS3Uploader, S3Uploader + from .test_controller import SettingsControllerTest + class TestSitewideServices(SettingsControllerTest): def test_sitewide_service_management(self): # The configuration of search and logging collections is delegated to @@ -17,12 +17,14 @@ def test_sitewide_service_management(self): # Search collections are more comprehensively tested in test_search_services. EI = ExternalIntegration + class MockSearch(SearchServicesController): - def _manage_sitewide_service(self,*args): + def _manage_sitewide_service(self, *args): self.manage_called_with = args def _delete_integration(self, *args): self.delete_called_with = args + controller = MockSearch(self.manager) with self.request_context_with_admin("/"): @@ -30,11 +32,10 @@ def _delete_integration(self, *args): goal, apis, key_name, problem = controller.manage_called_with assert EI.SEARCH_GOAL == goal assert ExternalSearchIndex in apis - assert 'search_services' == key_name - assert 'new search service' in problem + assert "search_services" == key_name + assert "new search service" in problem with self.request_context_with_admin("/"): id = object() controller.process_delete(id) - assert ((id, EI.SEARCH_GOAL) == - controller.delete_called_with) + assert (id, EI.SEARCH_GOAL) == controller.delete_called_with diff --git a/tests/admin/controller/test_sitewide_settings.py b/tests/admin/controller/test_sitewide_settings.py index 51c45885f6..9f6ceefe23 100644 --- a/tests/admin/controller/test_sitewide_settings.py +++ b/tests/admin/controller/test_sitewide_settings.py @@ -1,21 +1,21 @@ +import flask import pytest +from werkzeug.datastructures import ImmutableMultiDict, MultiDict from api.admin.exceptions import * from api.config import Configuration +from core.model import AdminRole, ConfigurationSetting from core.opds import AcquisitionFeed -from core.model import ( - AdminRole, - ConfigurationSetting -) + from .test_controller import SettingsControllerTest -from werkzeug.datastructures import ImmutableMultiDict, MultiDict -import flask -class TestSitewideSettings(SettingsControllerTest): +class TestSitewideSettings(SettingsControllerTest): def test_sitewide_settings_get(self): with self.request_context_with_admin("/"): - response = self.manager.admin_sitewide_configuration_settings_controller.process_get() + response = ( + self.manager.admin_sitewide_configuration_settings_controller.process_get() + ) settings = response.get("settings") all_settings = response.get("all_settings") @@ -25,17 +25,23 @@ def test_sitewide_settings_get(self): assert Configuration.DATABASE_LOG_LEVEL in keys assert Configuration.SECRET_KEY in keys - ConfigurationSetting.sitewide(self._db, Configuration.DATABASE_LOG_LEVEL).value = 'INFO' - ConfigurationSetting.sitewide(self._db, Configuration.SECRET_KEY).value = "secret" + ConfigurationSetting.sitewide( + self._db, Configuration.DATABASE_LOG_LEVEL + ).value = "INFO" + ConfigurationSetting.sitewide( + self._db, Configuration.SECRET_KEY + ).value = "secret" self._db.flush() with self.request_context_with_admin("/"): - response = self.manager.admin_sitewide_configuration_settings_controller.process_get() + response = ( + self.manager.admin_sitewide_configuration_settings_controller.process_get() + ) settings = response.get("settings") all_settings = response.get("all_settings") assert 2 == len(settings) - settings_by_key = { s.get("key") : s.get("value") for s in settings } + settings_by_key = {s.get("key"): s.get("value") for s in settings} assert "INFO" == settings_by_key.get(Configuration.DATABASE_LOG_LEVEL) assert "secret" == settings_by_key.get(Configuration.SECRET_KEY) keys = [s.get("key") for s in all_settings] @@ -45,56 +51,77 @@ def test_sitewide_settings_get(self): self.admin.remove_role(AdminRole.SYSTEM_ADMIN) self._db.flush() - pytest.raises(AdminNotAuthorized, - self.manager.admin_sitewide_configuration_settings_controller.process_get) + pytest.raises( + AdminNotAuthorized, + self.manager.admin_sitewide_configuration_settings_controller.process_get, + ) def test_sitewide_settings_post_errors(self): with self.request_context_with_admin("/", method="POST"): flask.request.form = MultiDict([("key", None)]) - response = self.manager.admin_sitewide_configuration_settings_controller.process_post() + response = ( + self.manager.admin_sitewide_configuration_settings_controller.process_post() + ) assert response == MISSING_SITEWIDE_SETTING_KEY with self.request_context_with_admin("/", method="POST"): - flask.request.form = MultiDict([ - ("key", Configuration.SECRET_KEY), - ("value", None) - ]) - response = self.manager.admin_sitewide_configuration_settings_controller.process_post() + flask.request.form = MultiDict( + [("key", Configuration.SECRET_KEY), ("value", None)] + ) + response = ( + self.manager.admin_sitewide_configuration_settings_controller.process_post() + ) assert response == MISSING_SITEWIDE_SETTING_VALUE self.admin.remove_role(AdminRole.SYSTEM_ADMIN) with self.request_context_with_admin("/", method="POST"): - flask.request.form = MultiDict([ - ("key", Configuration.SECRET_KEY), - ("value", "secret"), - ]) - pytest.raises(AdminNotAuthorized, - self.manager.admin_sitewide_configuration_settings_controller.process_post) + flask.request.form = MultiDict( + [ + ("key", Configuration.SECRET_KEY), + ("value", "secret"), + ] + ) + pytest.raises( + AdminNotAuthorized, + self.manager.admin_sitewide_configuration_settings_controller.process_post, + ) def test_sitewide_settings_post_create(self): with self.request_context_with_admin("/", method="POST"): - flask.request.form = MultiDict([ - ("key", Configuration.DATABASE_LOG_LEVEL), - ("value", "10"), - ]) - response = self.manager.admin_sitewide_configuration_settings_controller.process_post() + flask.request.form = MultiDict( + [ + ("key", Configuration.DATABASE_LOG_LEVEL), + ("value", "10"), + ] + ) + response = ( + self.manager.admin_sitewide_configuration_settings_controller.process_post() + ) assert response.status_code == 200 # The setting was created. - setting = ConfigurationSetting.sitewide(self._db, Configuration.DATABASE_LOG_LEVEL) + setting = ConfigurationSetting.sitewide( + self._db, Configuration.DATABASE_LOG_LEVEL + ) assert setting.key == response.get_data(as_text=True) assert "10" == setting.value def test_sitewide_settings_post_edit(self): - setting = ConfigurationSetting.sitewide(self._db, Configuration.DATABASE_LOG_LEVEL) + setting = ConfigurationSetting.sitewide( + self._db, Configuration.DATABASE_LOG_LEVEL + ) setting.value = "WARN" with self.request_context_with_admin("/", method="POST"): - flask.request.form = MultiDict([ - ("key", Configuration.DATABASE_LOG_LEVEL), - ("value", "ERROR"), - ]) - response = self.manager.admin_sitewide_configuration_settings_controller.process_post() + flask.request.form = MultiDict( + [ + ("key", Configuration.DATABASE_LOG_LEVEL), + ("value", "ERROR"), + ] + ) + response = ( + self.manager.admin_sitewide_configuration_settings_controller.process_post() + ) assert response.status_code == 200 # The setting was changed. @@ -102,17 +129,23 @@ def test_sitewide_settings_post_edit(self): assert "ERROR" == setting.value def test_sitewide_setting_delete(self): - setting = ConfigurationSetting.sitewide(self._db, Configuration.DATABASE_LOG_LEVEL) + setting = ConfigurationSetting.sitewide( + self._db, Configuration.DATABASE_LOG_LEVEL + ) setting.value = "WARN" with self.request_context_with_admin("/", method="DELETE"): self.admin.remove_role(AdminRole.SYSTEM_ADMIN) - pytest.raises(AdminNotAuthorized, - self.manager.admin_sitewide_configuration_settings_controller.process_delete, - setting.key) + pytest.raises( + AdminNotAuthorized, + self.manager.admin_sitewide_configuration_settings_controller.process_delete, + setting.key, + ) self.admin.add_role(AdminRole.SYSTEM_ADMIN) - response = self.manager.admin_sitewide_configuration_settings_controller.process_delete(setting.key) + response = self.manager.admin_sitewide_configuration_settings_controller.process_delete( + setting.key + ) assert response.status_code == 200 assert None == setting.value diff --git a/tests/admin/controller/test_storage_services.py b/tests/admin/controller/test_storage_services.py index 558d318b89..8a958daf15 100644 --- a/tests/admin/controller/test_storage_services.py +++ b/tests/admin/controller/test_storage_services.py @@ -1,22 +1,22 @@ - import flask + from api.admin.controller import SettingsController from api.admin.controller.storage_services import StorageServicesController -from core.model import ( - ExternalIntegration, -) -from core.s3 import S3Uploader, MockS3Uploader +from core.model import ExternalIntegration +from core.s3 import MockS3Uploader, S3Uploader + from .test_controller import SettingsControllerTest + class TestStorageServices(SettingsControllerTest): def test_storage_service_management(self): - class MockStorage(StorageServicesController): def _get_integration_protocols(self, apis, protocol_name_attr): self.manage_called_with = (apis, protocol_name_attr) def _delete_integration(self, *args): self.delete_called_with = args + controller = MockStorage(self.manager) EI = ExternalIntegration with self.request_context_with_admin("/"): @@ -24,7 +24,7 @@ def _delete_integration(self, *args): (apis, procotol_name) = controller.manage_called_with assert S3Uploader in apis - assert procotol_name == 'NAME' + assert procotol_name == "NAME" with self.request_context_with_admin("/"): id = object() diff --git a/tests/admin/controller/test_work_editor.py b/tests/admin/controller/test_work_editor.py index c3067a8e3d..455a2376b0 100644 --- a/tests/admin/controller/test_work_editor.py +++ b/tests/admin/controller/test_work_editor.py @@ -1,30 +1,26 @@ -import pytest - -from api.admin.exceptions import * -from api.admin.problem_details import * -import feedparser -from werkzeug.datastructures import ImmutableMultiDict, MultiDict import base64 -import flask import json import math import operator import os -from PIL import Image +from functools import reduce from io import BytesIO -from tests.admin.controller.test_controller import AdminControllerTest -from tests.test_controller import CirculationControllerTest -from core.classifier import ( - genres, - SimplifiedGenreClassifier -) + +import feedparser +import flask +import pytest +from PIL import Image +from werkzeug.datastructures import ImmutableMultiDict, MultiDict + +from api.admin.exceptions import * +from api.admin.problem_details import * +from core.classifier import SimplifiedGenreClassifier, genres from core.model import ( AdminRole, Classification, - Contributor, Complaint, + Contributor, CoverageRecord, - create, CustomList, DataSource, Edition, @@ -36,16 +32,19 @@ RightsStatus, SessionManager, Subject, + create, ) from core.model.configuration import ExternalIntegrationLink from core.s3 import MockS3Uploader from core.testing import ( AlwaysSuccessfulCoverageProvider, - NeverSuccessfulCoverageProvider, MockRequestsResponse, + NeverSuccessfulCoverageProvider, ) from core.util.datetime_helpers import datetime_utc -from functools import reduce +from tests.admin.controller.test_controller import AdminControllerTest +from tests.test_controller import CirculationControllerTest + class TestWorkController(AdminControllerTest): @@ -67,11 +66,17 @@ def test_details(self): ) assert 200 == response.status_code feed = feedparser.parse(response.get_data()) - [entry] = feed['entries'] - suppress_links = [x['href'] for x in entry['links'] - if x['rel'] == "http://librarysimplified.org/terms/rel/hide"] - unsuppress_links = [x['href'] for x in entry['links'] - if x['rel'] == "http://librarysimplified.org/terms/rel/restore"] + [entry] = feed["entries"] + suppress_links = [ + x["href"] + for x in entry["links"] + if x["rel"] == "http://librarysimplified.org/terms/rel/hide" + ] + unsuppress_links = [ + x["href"] + for x in entry["links"] + if x["rel"] == "http://librarysimplified.org/terms/rel/restore" + ] assert 0 == len(unsuppress_links) assert 1 == len(suppress_links) assert lp.identifier.identifier in suppress_links[0] @@ -83,48 +88,65 @@ def test_details(self): ) assert 200 == response.status_code feed = feedparser.parse(response.get_data()) - [entry] = feed['entries'] - suppress_links = [x['href'] for x in entry['links'] - if x['rel'] == "http://librarysimplified.org/terms/rel/hide"] - unsuppress_links = [x['href'] for x in entry['links'] - if x['rel'] == "http://librarysimplified.org/terms/rel/restore"] + [entry] = feed["entries"] + suppress_links = [ + x["href"] + for x in entry["links"] + if x["rel"] == "http://librarysimplified.org/terms/rel/hide" + ] + unsuppress_links = [ + x["href"] + for x in entry["links"] + if x["rel"] == "http://librarysimplified.org/terms/rel/restore" + ] assert 0 == len(suppress_links) assert 1 == len(unsuppress_links) assert lp.identifier.identifier in unsuppress_links[0] self.admin.remove_role(AdminRole.LIBRARIAN, self._default_library) with self.request_context_with_library_and_admin("/"): - pytest.raises(AdminNotAuthorized, - self.manager.admin_work_controller.details, - lp.identifier.type, lp.identifier.identifier) + pytest.raises( + AdminNotAuthorized, + self.manager.admin_work_controller.details, + lp.identifier.type, + lp.identifier.identifier, + ) def test_roles(self): roles = self.manager.admin_work_controller.roles() assert Contributor.ILLUSTRATOR_ROLE in list(roles.values()) assert Contributor.NARRATOR_ROLE in list(roles.values()) - assert (Contributor.ILLUSTRATOR_ROLE == - roles[Contributor.MARC_ROLE_CODES[Contributor.ILLUSTRATOR_ROLE]]) - assert (Contributor.NARRATOR_ROLE == - roles[Contributor.MARC_ROLE_CODES[Contributor.NARRATOR_ROLE]]) + assert ( + Contributor.ILLUSTRATOR_ROLE + == roles[Contributor.MARC_ROLE_CODES[Contributor.ILLUSTRATOR_ROLE]] + ) + assert ( + Contributor.NARRATOR_ROLE + == roles[Contributor.MARC_ROLE_CODES[Contributor.NARRATOR_ROLE]] + ) def test_languages(self): languages = self.manager.admin_work_controller.languages() - assert 'en' in list(languages.keys()) - assert 'fre' in list(languages.keys()) + assert "en" in list(languages.keys()) + assert "fre" in list(languages.keys()) names = [name for sublist in list(languages.values()) for name in sublist] - assert 'English' in names - assert 'French' in names + assert "English" in names + assert "French" in names def test_media(self): media = self.manager.admin_work_controller.media() assert Edition.BOOK_MEDIUM in list(media.values()) - assert Edition.medium_to_additional_type[Edition.BOOK_MEDIUM] in list(media.keys()) + assert Edition.medium_to_additional_type[Edition.BOOK_MEDIUM] in list( + media.keys() + ) def test_rights_status(self): rights_status = self.manager.admin_work_controller.rights_status() public_domain = rights_status.get(RightsStatus.PUBLIC_DOMAIN_USA) - assert RightsStatus.NAMES.get(RightsStatus.PUBLIC_DOMAIN_USA) == public_domain.get("name") + assert RightsStatus.NAMES.get( + RightsStatus.PUBLIC_DOMAIN_USA + ) == public_domain.get("name") assert True == public_domain.get("open_access") assert True == public_domain.get("allows_derivatives") @@ -139,7 +161,9 @@ def test_rights_status(self): assert False == cc_by_nd.get("allows_derivatives") copyright = rights_status.get(RightsStatus.IN_COPYRIGHT) - assert RightsStatus.NAMES.get(RightsStatus.IN_COPYRIGHT) == copyright.get("name") + assert RightsStatus.NAMES.get(RightsStatus.IN_COPYRIGHT) == copyright.get( + "name" + ) assert False == copyright.get("open_access") assert False == copyright.get("allows_derivatives") @@ -153,51 +177,45 @@ def _make_test_edit_request(self, data): def test_edit_unknown_role(self): response = self._make_test_edit_request( - [('contributor-role', self._str), - ('contributor-name', self._str)]) + [("contributor-role", self._str), ("contributor-name", self._str)] + ) assert 400 == response.status_code assert UNKNOWN_ROLE.uri == response.uri def test_edit_invalid_series_position(self): response = self._make_test_edit_request( - [('series', self._str), - ('series_position', 'five')]) + [("series", self._str), ("series_position", "five")] + ) assert 400 == response.status_code assert INVALID_SERIES_POSITION.uri == response.uri def test_edit_unknown_medium(self): - response = self._make_test_edit_request( - [('medium', self._str)]) + response = self._make_test_edit_request([("medium", self._str)]) assert 400 == response.status_code assert UNKNOWN_MEDIUM.uri == response.uri def test_edit_unknown_language(self): - response = self._make_test_edit_request( - [('language', self._str)]) + response = self._make_test_edit_request([("language", self._str)]) assert 400 == response.status_code assert UNKNOWN_LANGUAGE.uri == response.uri def test_edit_invalid_date_format(self): - response = self._make_test_edit_request( - [('issued', self._str)]) + response = self._make_test_edit_request([("issued", self._str)]) assert 400 == response.status_code assert INVALID_DATE_FORMAT.uri == response.uri def test_edit_invalid_rating_not_number(self): - response = self._make_test_edit_request( - [('rating', 'abc')]) + response = self._make_test_edit_request([("rating", "abc")]) assert 400 == response.status_code assert INVALID_RATING.uri == response.uri def test_edit_invalid_rating_above_scale(self): - response = self._make_test_edit_request( - [('rating', 9999)]) + response = self._make_test_edit_request([("rating", 9999)]) assert 400 == response.status_code assert INVALID_RATING.uri == response.uri def test_edit_invalid_rating_below_scale(self): - response = self._make_test_edit_request( - [('rating', -3)]) + response = self._make_test_edit_request([("rating", -3)]) assert 400 == response.status_code assert INVALID_RATING.uri == response.uri @@ -205,32 +223,38 @@ def test_edit(self): [lp] = self.english_1.license_pools staff_data_source = DataSource.lookup(self._db, DataSource.LIBRARY_STAFF) + def staff_edition_count(): - return self._db.query(Edition) \ + return ( + self._db.query(Edition) .filter( Edition.data_source == staff_data_source, - Edition.primary_identifier_id == self.english_1.presentation_edition.primary_identifier.id - ) \ + Edition.primary_identifier_id + == self.english_1.presentation_edition.primary_identifier.id, + ) .count() + ) with self.request_context_with_library_and_admin("/"): - flask.request.form = ImmutableMultiDict([ - ("title", "New title"), - ("subtitle", "New subtitle"), - ("contributor-role", "Author"), - ("contributor-name", "New Author"), - ("contributor-role", "Narrator"), - ("contributor-name", "New Narrator"), - ("series", "New series"), - ("series_position", "144"), - ("medium", "Audio"), - ("language", "French"), - ("publisher", "New Publisher"), - ("imprint", "New Imprint"), - ("issued", "2017-11-05"), - ("rating", "2"), - ("summary", "

    New summary

    ") - ]) + flask.request.form = ImmutableMultiDict( + [ + ("title", "New title"), + ("subtitle", "New subtitle"), + ("contributor-role", "Author"), + ("contributor-name", "New Author"), + ("contributor-role", "Narrator"), + ("contributor-name", "New Narrator"), + ("series", "New series"), + ("series_position", "144"), + ("medium", "Audio"), + ("language", "French"), + ("publisher", "New Publisher"), + ("imprint", "New Imprint"), + ("issued", "2017-11-05"), + ("rating", "2"), + ("summary", "

    New summary

    "), + ] + ) response = self.manager.admin_work_controller.edit( lp.identifier.type, lp.identifier.identifier ) @@ -243,7 +267,8 @@ def staff_edition_count(): assert "New Author" in self.english_1.simple_opds_entry [author, narrator] = sorted( self.english_1.presentation_edition.contributions, - key=lambda x: x.contributor.display_name) + key=lambda x: x.contributor.display_name, + ) assert "New Author" == author.contributor.display_name assert "Author, New" == author.contributor.sort_name assert "Primary Author" == author.role @@ -258,7 +283,9 @@ def staff_edition_count(): assert "fre" == self.english_1.presentation_edition.language assert "New Publisher" == self.english_1.publisher assert "New Imprint" == self.english_1.presentation_edition.imprint - assert datetime_utc(2017, 11, 5) == self.english_1.presentation_edition.issued + assert ( + datetime_utc(2017, 11, 5) == self.english_1.presentation_edition.issued + ) assert 0.25 == self.english_1.quality assert "

    New summary

    " == self.english_1.summary_text assert "<p>New summary</p>" in self.english_1.simple_opds_entry @@ -266,34 +293,37 @@ def staff_edition_count(): with self.request_context_with_library_and_admin("/"): # Change the summary again and add an author. - flask.request.form = ImmutableMultiDict([ - ("title", "New title"), - ("subtitle", "New subtitle"), - ("contributor-role", "Author"), - ("contributor-name", "New Author"), - ("contributor-role", "Narrator"), - ("contributor-name", "New Narrator"), - ("contributor-role", "Author"), - ("contributor-name", "Second Author"), - ("series", "New series"), - ("series_position", "144"), - ("medium", "Audio"), - ("language", "French"), - ("publisher", "New Publisher"), - ("imprint", "New Imprint"), - ("issued", "2017-11-05"), - ("rating", "2"), - ("summary", "abcd") - ]) + flask.request.form = ImmutableMultiDict( + [ + ("title", "New title"), + ("subtitle", "New subtitle"), + ("contributor-role", "Author"), + ("contributor-name", "New Author"), + ("contributor-role", "Narrator"), + ("contributor-name", "New Narrator"), + ("contributor-role", "Author"), + ("contributor-name", "Second Author"), + ("series", "New series"), + ("series_position", "144"), + ("medium", "Audio"), + ("language", "French"), + ("publisher", "New Publisher"), + ("imprint", "New Imprint"), + ("issued", "2017-11-05"), + ("rating", "2"), + ("summary", "abcd"), + ] + ) response = self.manager.admin_work_controller.edit( lp.identifier.type, lp.identifier.identifier ) assert 200 == response.status_code assert "abcd" == self.english_1.summary_text - assert 'New summary' not in self.english_1.simple_opds_entry + assert "New summary" not in self.english_1.simple_opds_entry [author, narrator, author2] = sorted( self.english_1.presentation_edition.contributions, - key=lambda x: x.contributor.display_name) + key=lambda x: x.contributor.display_name, + ) assert "New Author" == author.contributor.display_name assert "Author, New" == author.contributor.sort_name assert "Primary Author" == author.role @@ -306,21 +336,23 @@ def staff_edition_count(): with self.request_context_with_library_and_admin("/"): # Now delete the subtitle, narrator, series, and summary entirely - flask.request.form = ImmutableMultiDict([ - ("title", "New title"), - ("contributor-role", "Author"), - ("contributor-name", "New Author"), - ("subtitle", ""), - ("series", ""), - ("series_position", ""), - ("medium", "Audio"), - ("language", "French"), - ("publisher", "New Publisher"), - ("imprint", "New Imprint"), - ("issued", "2017-11-05"), - ("rating", "2"), - ("summary", "") - ]) + flask.request.form = ImmutableMultiDict( + [ + ("title", "New title"), + ("contributor-role", "Author"), + ("contributor-name", "New Author"), + ("subtitle", ""), + ("series", ""), + ("series_position", ""), + ("medium", "Audio"), + ("language", "French"), + ("publisher", "New Publisher"), + ("imprint", "New Imprint"), + ("issued", "2017-11-05"), + ("rating", "2"), + ("summary", ""), + ] + ) response = self.manager.admin_work_controller.edit( lp.identifier.type, lp.identifier.identifier ) @@ -331,22 +363,24 @@ def staff_edition_count(): assert None == self.english_1.series assert None == self.english_1.series_position assert "" == self.english_1.summary_text - assert 'New subtitle' not in self.english_1.simple_opds_entry + assert "New subtitle" not in self.english_1.simple_opds_entry assert "Narrator" not in self.english_1.simple_opds_entry - assert 'New series' not in self.english_1.simple_opds_entry - assert '144' not in self.english_1.simple_opds_entry - assert 'abcd' not in self.english_1.simple_opds_entry + assert "New series" not in self.english_1.simple_opds_entry + assert "144" not in self.english_1.simple_opds_entry + assert "abcd" not in self.english_1.simple_opds_entry assert 1 == staff_edition_count() with self.request_context_with_library_and_admin("/"): # Set the fields one more time - flask.request.form = ImmutableMultiDict([ - ("title", "New title"), - ("subtitle", "Final subtitle"), - ("series", "Final series"), - ("series_position", "169"), - ("summary", "

    Final summary

    ") - ]) + flask.request.form = ImmutableMultiDict( + [ + ("title", "New title"), + ("subtitle", "Final subtitle"), + ("series", "Final series"), + ("series_position", "169"), + ("summary", "

    Final summary

    "), + ] + ) response = self.manager.admin_work_controller.edit( lp.identifier.type, lp.identifier.identifier ) @@ -355,21 +389,28 @@ def staff_edition_count(): assert "Final series" == self.english_1.series assert 169 == self.english_1.series_position assert "

    Final summary

    " == self.english_1.summary_text - assert 'Final subtitle' in self.english_1.simple_opds_entry - assert 'Final series' in self.english_1.simple_opds_entry - assert '169' in self.english_1.simple_opds_entry - assert "<p>Final summary</p>" in self.english_1.simple_opds_entry + assert "Final subtitle" in self.english_1.simple_opds_entry + assert "Final series" in self.english_1.simple_opds_entry + assert "169" in self.english_1.simple_opds_entry + assert ( + "<p>Final summary</p>" in self.english_1.simple_opds_entry + ) assert 1 == staff_edition_count() # Make sure a non-librarian of this library can't edit. self.admin.remove_role(AdminRole.LIBRARIAN, self._default_library) with self.request_context_with_library_and_admin("/"): - flask.request.form = ImmutableMultiDict([ - ("title", "Another new title"), - ]) - pytest.raises(AdminNotAuthorized, - self.manager.admin_work_controller.edit, - lp.identifier.type, lp.identifier.identifier) + flask.request.form = ImmutableMultiDict( + [ + ("title", "Another new title"), + ] + ) + pytest.raises( + AdminNotAuthorized, + self.manager.admin_work_controller.edit, + lp.identifier.type, + lp.identifier.identifier, + ) def test_edit_classifications(self): # start with a couple genres based on BISAC classifications from Axis 360 @@ -383,13 +424,13 @@ def test_edit_classifications(self): data_source=axis_360, subject_type=Subject.BISAC, subject_identifier="FICTION / Horror", - weight=1 + weight=1, ) classification2 = primary_identifier.classify( data_source=axis_360, subject_type=Subject.BISAC, subject_identifier="FICTION / Science Fiction / Time Travel", - weight=1 + weight=1, ) genre1, ignore = Genre.lookup(self._db, "Horror") genre2, ignore = Genre.lookup(self._db, "Science Fiction") @@ -397,12 +438,14 @@ def test_edit_classifications(self): # make no changes with self.request_context_with_library_and_admin("/"): - flask.request.form = MultiDict([ - ("audience", "Adult"), - ("fiction", "fiction"), - ("genres", "Horror"), - ("genres", "Science Fiction") - ]) + flask.request.form = MultiDict( + [ + ("audience", "Adult"), + ("fiction", "fiction"), + ("genres", "Horror"), + ("genres", "Science Fiction"), + ] + ) requested_genres = flask.request.form.getlist("genres") response = self.manager.admin_work_controller.edit_classifications( lp.identifier.type, lp.identifier.identifier @@ -410,18 +453,17 @@ def test_edit_classifications(self): assert response.status_code == 200 staff_data_source = DataSource.lookup(self._db, DataSource.LIBRARY_STAFF) - genre_classifications = self._db \ - .query(Classification) \ - .join(Subject) \ + genre_classifications = ( + self._db.query(Classification) + .join(Subject) .filter( Classification.identifier == primary_identifier, Classification.data_source == staff_data_source, - Subject.genre_id != None + Subject.genre_id != None, ) + ) staff_genres = [ - c.subject.genre.name - for c in genre_classifications - if c.subject.genre + c.subject.genre.name for c in genre_classifications if c.subject.genre ] assert staff_genres == [] assert "Adult" == work.audience @@ -431,10 +473,9 @@ def test_edit_classifications(self): # remove all genres with self.request_context_with_library_and_admin("/"): - flask.request.form = MultiDict([ - ("audience", "Adult"), - ("fiction", "fiction") - ]) + flask.request.form = MultiDict( + [("audience", "Adult"), ("fiction", "fiction")] + ) response = self.manager.admin_work_controller.edit_classifications( lp.identifier.type, lp.identifier.identifier ) @@ -442,15 +483,16 @@ def test_edit_classifications(self): primary_identifier = work.presentation_edition.primary_identifier staff_data_source = DataSource.lookup(self._db, DataSource.LIBRARY_STAFF) - none_classification_count = self._db \ - .query(Classification) \ - .join(Subject) \ + none_classification_count = ( + self._db.query(Classification) + .join(Subject) .filter( Classification.identifier == primary_identifier, Classification.data_source == staff_data_source, - Subject.identifier == SimplifiedGenreClassifier.NONE - ) \ + Subject.identifier == SimplifiedGenreClassifier.NONE, + ) .all() + ) assert 1 == len(none_classification_count) assert "Adult" == work.audience assert 18 == work.target_age.lower @@ -459,13 +501,15 @@ def test_edit_classifications(self): # completely change genres with self.request_context_with_library_and_admin("/"): - flask.request.form = MultiDict([ - ("audience", "Adult"), - ("fiction", "fiction"), - ("genres", "Drama"), - ("genres", "Urban Fantasy"), - ("genres", "Women's Fiction") - ]) + flask.request.form = MultiDict( + [ + ("audience", "Adult"), + ("fiction", "fiction"), + ("genres", "Drama"), + ("genres", "Urban Fantasy"), + ("genres", "Women's Fiction"), + ] + ) requested_genres = flask.request.form.getlist("genres") response = self.manager.admin_work_controller.edit_classifications( lp.identifier.type, lp.identifier.identifier @@ -482,13 +526,15 @@ def test_edit_classifications(self): # remove some genres and change audience and target age with self.request_context_with_library_and_admin("/"): - flask.request.form = MultiDict([ - ("audience", "Young Adult"), - ("target_age_min", 16), - ("target_age_max", 18), - ("fiction", "fiction"), - ("genres", "Urban Fantasy") - ]) + flask.request.form = MultiDict( + [ + ("audience", "Young Adult"), + ("target_age_min", 16), + ("target_age_max", 18), + ("fiction", "fiction"), + ("genres", "Urban Fantasy"), + ] + ) requested_genres = flask.request.form.getlist("genres") response = self.manager.admin_work_controller.edit_classifications( lp.identifier.type, lp.identifier.identifier @@ -507,14 +553,16 @@ def test_edit_classifications(self): # try to add a nonfiction genre with self.request_context_with_library_and_admin("/"): - flask.request.form = MultiDict([ - ("audience", "Young Adult"), - ("target_age_min", 16), - ("target_age_max", 18), - ("fiction", "fiction"), - ("genres", "Cooking"), - ("genres", "Urban Fantasy") - ]) + flask.request.form = MultiDict( + [ + ("audience", "Young Adult"), + ("target_age_min", 16), + ("target_age_max", 18), + ("fiction", "fiction"), + ("genres", "Cooking"), + ("genres", "Urban Fantasy"), + ] + ) response = self.manager.admin_work_controller.edit_classifications( lp.identifier.type, lp.identifier.identifier ) @@ -529,14 +577,16 @@ def test_edit_classifications(self): # try to add Erotica with self.request_context_with_library_and_admin("/"): - flask.request.form = MultiDict([ - ("audience", "Young Adult"), - ("target_age_min", 16), - ("target_age_max", 18), - ("fiction", "fiction"), - ("genres", "Erotica"), - ("genres", "Urban Fantasy") - ]) + flask.request.form = MultiDict( + [ + ("audience", "Young Adult"), + ("target_age_min", 16), + ("target_age_max", 18), + ("fiction", "fiction"), + ("genres", "Erotica"), + ("genres", "Urban Fantasy"), + ] + ) response = self.manager.admin_work_controller.edit_classifications( lp.identifier.type, lp.identifier.identifier ) @@ -552,13 +602,15 @@ def test_edit_classifications(self): # try to set min target age greater than max target age # othe edits should not go through with self.request_context_with_library_and_admin("/"): - flask.request.form = MultiDict([ - ("audience", "Young Adult"), - ("target_age_min", 16), - ("target_age_max", 14), - ("fiction", "nonfiction"), - ("genres", "Cooking") - ]) + flask.request.form = MultiDict( + [ + ("audience", "Young Adult"), + ("target_age_min", 16), + ("target_age_max", 14), + ("fiction", "nonfiction"), + ("genres", "Cooking"), + ] + ) response = self.manager.admin_work_controller.edit_classifications( lp.identifier.type, lp.identifier.identifier ) @@ -571,13 +623,15 @@ def test_edit_classifications(self): # change to nonfiction with nonfiction genres and new target age with self.request_context_with_library_and_admin("/"): - flask.request.form = MultiDict([ - ("audience", "Young Adult"), - ("target_age_min", 15), - ("target_age_max", 17), - ("fiction", "nonfiction"), - ("genres", "Cooking") - ]) + flask.request.form = MultiDict( + [ + ("audience", "Young Adult"), + ("target_age_min", 15), + ("target_age_max", 17), + ("fiction", "nonfiction"), + ("genres", "Cooking"), + ] + ) requested_genres = flask.request.form.getlist("genres") response = self.manager.admin_work_controller.edit_classifications( lp.identifier.type, lp.identifier.identifier @@ -592,11 +646,13 @@ def test_edit_classifications(self): # set to Adult and make sure that target ages is set automatically with self.request_context_with_library_and_admin("/"): - flask.request.form = MultiDict([ - ("audience", "Adult"), - ("fiction", "nonfiction"), - ("genres", "Cooking") - ]) + flask.request.form = MultiDict( + [ + ("audience", "Adult"), + ("fiction", "nonfiction"), + ("genres", "Cooking"), + ] + ) requested_genres = flask.request.form.getlist("genres") response = self.manager.admin_work_controller.edit_classifications( lp.identifier.type, lp.identifier.identifier @@ -609,14 +665,19 @@ def test_edit_classifications(self): # Make sure a non-librarian of this library can't edit. self.admin.remove_role(AdminRole.LIBRARIAN, self._default_library) with self.request_context_with_library_and_admin("/"): - flask.request.form = MultiDict([ - ("audience", "Children"), - ("fiction", "nonfiction"), - ("genres", "Biography") - ]) - pytest.raises(AdminNotAuthorized, - self.manager.admin_work_controller.edit_classifications, - lp.identifier.type, lp.identifier.identifier) + flask.request.form = MultiDict( + [ + ("audience", "Children"), + ("fiction", "nonfiction"), + ("genres", "Biography"), + ] + ) + pytest.raises( + AdminNotAuthorized, + self.manager.admin_work_controller.edit_classifications, + lp.identifier.type, + lp.identifier.identifier, + ) def test_suppress(self): [lp] = self.english_1.license_pools @@ -631,17 +692,19 @@ def test_suppress(self): lp.suppressed = False self.admin.remove_role(AdminRole.LIBRARIAN, self._default_library) with self.request_context_with_library_and_admin("/"): - pytest.raises(AdminNotAuthorized, - self.manager.admin_work_controller.suppress, - lp.identifier.type, lp.identifier.identifier) + pytest.raises( + AdminNotAuthorized, + self.manager.admin_work_controller.suppress, + lp.identifier.type, + lp.identifier.identifier, + ) def test_unsuppress(self): [lp] = self.english_1.license_pools lp.suppressed = True broken_lp = self._licensepool( - self.english_1.presentation_edition, - data_source_name=DataSource.OVERDRIVE + self.english_1.presentation_edition, data_source_name=DataSource.OVERDRIVE ) broken_lp.work = self.english_1 broken_lp.suppressed = True @@ -650,7 +713,8 @@ def test_unsuppress(self): Complaint.register( broken_lp, "http://librarysimplified.org/terms/problem/cannot-render", - "blah", "blah" + "blah", + "blah", ) with self.request_context_with_library_and_admin("/"): @@ -667,22 +731,27 @@ def test_unsuppress(self): lp.suppressed = True self.admin.remove_role(AdminRole.LIBRARIAN, self._default_library) with self.request_context_with_library_and_admin("/"): - pytest.raises(AdminNotAuthorized, - self.manager.admin_work_controller.unsuppress, - lp.identifier.type, lp.identifier.identifier) + pytest.raises( + AdminNotAuthorized, + self.manager.admin_work_controller.unsuppress, + lp.identifier.type, + lp.identifier.identifier, + ) def test_refresh_metadata(self): wrangler = DataSource.lookup(self._db, DataSource.METADATA_WRANGLER) class AlwaysSuccessfulMetadataProvider(AlwaysSuccessfulCoverageProvider): DATA_SOURCE_NAME = wrangler.name + success_provider = AlwaysSuccessfulMetadataProvider(self._db) class NeverSuccessfulMetadataProvider(NeverSuccessfulCoverageProvider): DATA_SOURCE_NAME = wrangler.name + failure_provider = NeverSuccessfulMetadataProvider(self._db) - with self.request_context_with_library_and_admin('/'): + with self.request_context_with_library_and_admin("/"): [lp] = self.english_1.license_pools response = self.manager.admin_work_controller.refresh_metadata( lp.identifier.type, lp.identifier.identifier, provider=success_provider @@ -707,9 +776,13 @@ class NeverSuccessfulMetadataProvider(NeverSuccessfulCoverageProvider): self.admin.remove_role(AdminRole.LIBRARIAN, self._default_library) with self.request_context_with_library_and_admin("/"): - pytest.raises(AdminNotAuthorized, - self.manager.admin_work_controller.refresh_metadata, - lp.identifier.type, lp.identifier.identifier, provider=success_provider) + pytest.raises( + AdminNotAuthorized, + self.manager.admin_work_controller.refresh_metadata, + lp.identifier.type, + lp.identifier.identifier, + provider=success_provider, + ) def test_complaints(self): type = iter(Complaint.VALID_TYPES) @@ -720,22 +793,17 @@ def test_complaints(self): "fiction work with complaint", language="eng", fiction=True, - with_open_access_download=True) + with_open_access_download=True, + ) complaint1 = self._complaint( - work.license_pools[0], - type1, - "complaint1 source", - "complaint1 detail") + work.license_pools[0], type1, "complaint1 source", "complaint1 detail" + ) complaint2 = self._complaint( - work.license_pools[0], - type1, - "complaint2 source", - "complaint2 detail") + work.license_pools[0], type1, "complaint2 source", "complaint2 detail" + ) complaint3 = self._complaint( - work.license_pools[0], - type2, - "complaint3 source", - "complaint3 detail") + work.license_pools[0], type2, "complaint3 source", "complaint3 detail" + ) [lp] = work.license_pools @@ -743,16 +811,19 @@ def test_complaints(self): response = self.manager.admin_work_controller.complaints( lp.identifier.type, lp.identifier.identifier ) - assert response['book']['identifier_type'] == lp.identifier.type - assert response['book']['identifier'] == lp.identifier.identifier - assert response['complaints'][type1] == 2 - assert response['complaints'][type2] == 1 + assert response["book"]["identifier_type"] == lp.identifier.type + assert response["book"]["identifier"] == lp.identifier.identifier + assert response["complaints"][type1] == 2 + assert response["complaints"][type2] == 1 self.admin.remove_role(AdminRole.LIBRARIAN, self._default_library) with self.request_context_with_library_and_admin("/"): - pytest.raises(AdminNotAuthorized, - self.manager.admin_work_controller.complaints, - lp.identifier.type, lp.identifier.identifier) + pytest.raises( + AdminNotAuthorized, + self.manager.admin_work_controller.complaints, + lp.identifier.type, + lp.identifier.identifier, + ) def test_resolve_complaints(self): type = iter(Complaint.VALID_TYPES) @@ -763,17 +834,14 @@ def test_resolve_complaints(self): "fiction work with complaint", language="eng", fiction=True, - with_open_access_download=True) + with_open_access_download=True, + ) complaint1 = self._complaint( - work.license_pools[0], - type1, - "complaint1 source", - "complaint1 detail") + work.license_pools[0], type1, "complaint1 source", "complaint1 detail" + ) complaint2 = self._complaint( - work.license_pools[0], - type1, - "complaint2 source", - "complaint2 detail") + work.license_pools[0], type1, "complaint2 source", "complaint2 detail" + ) [lp] = work.license_pools @@ -783,7 +851,9 @@ def test_resolve_complaints(self): response = self.manager.admin_work_controller.resolve_complaints( lp.identifier.type, lp.identifier.identifier ) - unresolved_complaints = [complaint for complaint in lp.complaints if complaint.resolved == None] + unresolved_complaints = [ + complaint for complaint in lp.complaints if complaint.resolved == None + ] assert response.status_code == 404 assert len(unresolved_complaints) == 2 @@ -793,8 +863,9 @@ def test_resolve_complaints(self): response = self.manager.admin_work_controller.resolve_complaints( lp.identifier.type, lp.identifier.identifier ) - unresolved_complaints = [complaint for complaint in lp.complaints - if complaint.resolved == None] + unresolved_complaints = [ + complaint for complaint in lp.complaints if complaint.resolved == None + ] assert response.status_code == 200 assert len(unresolved_complaints) == 0 @@ -809,9 +880,12 @@ def test_resolve_complaints(self): self.admin.remove_role(AdminRole.LIBRARIAN, self._default_library) with self.request_context_with_library_and_admin("/"): flask.request.form = ImmutableMultiDict([("type", type1)]) - pytest.raises(AdminNotAuthorized, - self.manager.admin_work_controller.resolve_complaints, - lp.identifier.type, lp.identifier.identifier) + pytest.raises( + AdminNotAuthorized, + self.manager.admin_work_controller.resolve_complaints, + lp.identifier.type, + lp.identifier.identifier, + ) def test_classifications(self): e, pool = self._edition(with_license_pool=True) @@ -826,38 +900,42 @@ def test_classifications(self): subject3.genre = None source = DataSource.lookup(self._db, DataSource.AXIS_360) classification1 = self._classification( - identifier=identifier, subject=subject1, - data_source=source, weight=1) + identifier=identifier, subject=subject1, data_source=source, weight=1 + ) classification2 = self._classification( - identifier=identifier, subject=subject2, - data_source=source, weight=3) + identifier=identifier, subject=subject2, data_source=source, weight=3 + ) classification3 = self._classification( - identifier=identifier, subject=subject3, - data_source=source, weight=2) + identifier=identifier, subject=subject3, data_source=source, weight=2 + ) [lp] = work.license_pools with self.request_context_with_library_and_admin("/"): response = self.manager.admin_work_controller.classifications( - lp.identifier.type, lp.identifier.identifier) - assert response['book']['identifier_type'] == lp.identifier.type - assert response['book']['identifier'] == lp.identifier.identifier + lp.identifier.type, lp.identifier.identifier + ) + assert response["book"]["identifier_type"] == lp.identifier.type + assert response["book"]["identifier"] == lp.identifier.identifier expected_results = [classification2, classification3, classification1] - assert len(response['classifications']) == len(expected_results) + assert len(response["classifications"]) == len(expected_results) for i, classification in enumerate(expected_results): subject = classification.subject source = classification.data_source - assert response['classifications'][i]['name'] == subject.identifier - assert response['classifications'][i]['type'] == subject.type - assert response['classifications'][i]['source'] == source.name - assert response['classifications'][i]['weight'] == classification.weight + assert response["classifications"][i]["name"] == subject.identifier + assert response["classifications"][i]["type"] == subject.type + assert response["classifications"][i]["source"] == source.name + assert response["classifications"][i]["weight"] == classification.weight self.admin.remove_role(AdminRole.LIBRARIAN, self._default_library) with self.request_context_with_library_and_admin("/"): - pytest.raises(AdminNotAuthorized, - self.manager.admin_work_controller.classifications, - lp.identifier.type, lp.identifier.identifier) + pytest.raises( + AdminNotAuthorized, + self.manager.admin_work_controller.classifications, + lp.identifier.type, + lp.identifier.identifier, + ) def test_validate_cover_image(self): base_path = os.path.split(__file__)[0] @@ -869,7 +947,10 @@ def test_validate_cover_image(self): result = self.manager.admin_work_controller._validate_cover_image(too_small) assert INVALID_IMAGE.uri == result.uri - assert "Cover image must be at least 600px in width and 900px in height." == result.detail + assert ( + "Cover image must be at least 600px in width and 900px in height." + == result.detail + ) path = os.path.join(resource_path, "blue.jpg") valid = Image.open(path) @@ -887,19 +968,30 @@ def test_process_cover_image(self): processed = Image.open(path) # Without a title position, the image won't be changed. - processed = self.manager.admin_work_controller._process_cover_image(work, processed, "none") + processed = self.manager.admin_work_controller._process_cover_image( + work, processed, "none" + ) image_histogram = original.histogram() expected_histogram = processed.histogram() - root_mean_square = math.sqrt(reduce(operator.add, - list(map(lambda a,b: (a-b)**2, image_histogram, expected_histogram)))/len(image_histogram)) + root_mean_square = math.sqrt( + reduce( + operator.add, + list( + map(lambda a, b: (a - b) ** 2, image_histogram, expected_histogram) + ), + ) + / len(image_histogram) + ) assert root_mean_square < 10 # Here the title and author are added in the center. Compare the result # with a pre-generated version. processed = Image.open(path) - processed = self.manager.admin_work_controller._process_cover_image(work, processed, "center") + processed = self.manager.admin_work_controller._process_cover_image( + work, processed, "center" + ) path = os.path.join(resource_path, "blue_with_title_author.png") expected_image = Image.open(path) @@ -907,8 +999,15 @@ def test_process_cover_image(self): image_histogram = processed.histogram() expected_histogram = expected_image.histogram() - root_mean_square = math.sqrt(reduce(operator.add, - list(map(lambda a,b: (a-b)**2, image_histogram, expected_histogram)))/len(image_histogram)) + root_mean_square = math.sqrt( + reduce( + operator.add, + list( + map(lambda a, b: (a - b) ** 2, image_histogram, expected_histogram) + ), + ) + / len(image_histogram) + ) assert root_mean_square < 10 def test_preview_book_cover(self): @@ -916,20 +1015,27 @@ def test_preview_book_cover(self): identifier = work.license_pools[0].identifier with self.request_context_with_library_and_admin("/"): - response = self.manager.admin_work_controller.preview_book_cover(identifier.type, identifier.identifier) + response = self.manager.admin_work_controller.preview_book_cover( + identifier.type, identifier.identifier + ) assert INVALID_IMAGE.uri == response.uri assert "Image file or image URL is required." == response.detail with self.request_context_with_library_and_admin("/"): - flask.request.form = MultiDict([ - ("cover_url", "bad_url"), - ]) - response = self.manager.admin_work_controller.preview_book_cover(identifier.type, identifier.identifier) + flask.request.form = MultiDict( + [ + ("cover_url", "bad_url"), + ] + ) + response = self.manager.admin_work_controller.preview_book_cover( + identifier.type, identifier.identifier + ) assert INVALID_URL.uri == response.uri assert '"bad_url" is not a valid URL.' == response.detail class TestFileUpload(BytesIO): - headers = { "Content-Type": "image/png" } + headers = {"Content-Type": "image/png"} + base_path = os.path.split(__file__)[0] folder = os.path.dirname(base_path) resource_path = os.path.join(folder, "..", "files", "images") @@ -940,82 +1046,107 @@ class TestFileUpload(BytesIO): image_data = buffer.getvalue() with self.request_context_with_library_and_admin("/"): - flask.request.form = MultiDict([ - ("title_position", "none") - ]) - flask.request.files = MultiDict([ - ("cover_file", TestFileUpload(image_data)), - ]) - response = self.manager.admin_work_controller.preview_book_cover(identifier.type, identifier.identifier) + flask.request.form = MultiDict([("title_position", "none")]) + flask.request.files = MultiDict( + [ + ("cover_file", TestFileUpload(image_data)), + ] + ) + response = self.manager.admin_work_controller.preview_book_cover( + identifier.type, identifier.identifier + ) assert 200 == response.status_code - assert "data:image/png;base64,%s" % base64.b64encode(image_data) == response.get_data(as_text=True) + assert "data:image/png;base64,%s" % base64.b64encode( + image_data + ) == response.get_data(as_text=True) self.admin.remove_role(AdminRole.LIBRARIAN, self._default_library) with self.request_context_with_library_and_admin("/"): - pytest.raises(AdminNotAuthorized, - self.manager.admin_work_controller.preview_book_cover, - identifier.type, identifier.identifier) - + pytest.raises( + AdminNotAuthorized, + self.manager.admin_work_controller.preview_book_cover, + identifier.type, + identifier.identifier, + ) def test_change_book_cover(self): # Mock image processing which has been tested in other methods. process_called_with = [] + def mock_process(work, image, position): # Modify the image to ensure it gets a different generic URI. image.thumbnail((500, 500)) process_called_with.append((work, image, position)) return image + old_process = self.manager.admin_work_controller._process_cover_image self.manager.admin_work_controller._process_cover_image = mock_process work = self._work(with_license_pool=True) identifier = work.license_pools[0].identifier mirror_type = ExternalIntegrationLink.COVERS - mirrors = dict(covers_mirror=MockS3Uploader(),books_mirror=None) + mirrors = dict(covers_mirror=MockS3Uploader(), books_mirror=None) with self.request_context_with_library_and_admin("/"): - flask.request.form = MultiDict([ - ("rights_status", RightsStatus.CC_BY), - ("rights_explanation", "explanation"), - ]) - response = self.manager.admin_work_controller.change_book_cover(identifier.type, identifier.identifier, mirrors) + flask.request.form = MultiDict( + [ + ("rights_status", RightsStatus.CC_BY), + ("rights_explanation", "explanation"), + ] + ) + response = self.manager.admin_work_controller.change_book_cover( + identifier.type, identifier.identifier, mirrors + ) assert INVALID_IMAGE.uri == response.uri assert "Image file or image URL is required." == response.detail with self.request_context_with_library_and_admin("/"): - flask.request.form = MultiDict([ - ("cover_url", "http://example.com"), - ("title_position", "none"), - ]) + flask.request.form = MultiDict( + [ + ("cover_url", "http://example.com"), + ("title_position", "none"), + ] + ) flask.request.files = MultiDict([]) - response = self.manager.admin_work_controller.change_book_cover(identifier.type, identifier.identifier) + response = self.manager.admin_work_controller.change_book_cover( + identifier.type, identifier.identifier + ) assert INVALID_IMAGE.uri == response.uri assert "You must specify the image's license." == response.detail with self.request_context_with_library_and_admin("/"): - flask.request.form = MultiDict([ - ("cover_url", "bad_url"), - ("title_position", "none"), - ("rights_status", RightsStatus.CC_BY), - ]) - response = self.manager.admin_work_controller.change_book_cover(identifier.type, identifier.identifier, mirrors) + flask.request.form = MultiDict( + [ + ("cover_url", "bad_url"), + ("title_position", "none"), + ("rights_status", RightsStatus.CC_BY), + ] + ) + response = self.manager.admin_work_controller.change_book_cover( + identifier.type, identifier.identifier, mirrors + ) assert INVALID_URL.uri == response.uri assert '"bad_url" is not a valid URL.' == response.detail with self.request_context_with_library_and_admin("/"): - flask.request.form = MultiDict([ - ("cover_url", "http://example.com"), - ("title_position", "none"), - ("rights_status", RightsStatus.CC_BY), - ("rights_explanation", "explanation"), - ]) + flask.request.form = MultiDict( + [ + ("cover_url", "http://example.com"), + ("title_position", "none"), + ("rights_status", RightsStatus.CC_BY), + ("rights_explanation", "explanation"), + ] + ) flask.request.files = MultiDict([]) - response = self.manager.admin_work_controller.change_book_cover(identifier.type, identifier.identifier) + response = self.manager.admin_work_controller.change_book_cover( + identifier.type, identifier.identifier + ) assert INVALID_CONFIGURATION_OPTION.uri == response.uri assert "Could not find a storage integration" in response.detail class TestFileUpload(BytesIO): - headers = { "Content-Type": "image/png" } + headers = {"Content-Type": "image/png"} + base_path = os.path.split(__file__)[0] folder = os.path.dirname(base_path) resource_path = os.path.join(folder, "..", "files", "images") @@ -1029,15 +1160,21 @@ class TestFileUpload(BytesIO): # Upload a new cover image but don't modify it. with self.request_context_with_library_and_admin("/"): - flask.request.form = MultiDict([ - ("title_position", "none"), - ("rights_status", RightsStatus.CC_BY), - ("rights_explanation", "explanation"), - ]) - flask.request.files = MultiDict([ - ("cover_file", TestFileUpload(image_data)), - ]) - response = self.manager.admin_work_controller.change_book_cover(identifier.type, identifier.identifier, mirrors) + flask.request.form = MultiDict( + [ + ("title_position", "none"), + ("rights_status", RightsStatus.CC_BY), + ("rights_explanation", "explanation"), + ] + ) + flask.request.files = MultiDict( + [ + ("cover_file", TestFileUpload(image_data)), + ] + ) + response = self.manager.admin_work_controller.change_book_cover( + identifier.type, identifier.identifier, mirrors + ) assert 200 == response.status_code [link] = identifier.links @@ -1062,7 +1199,9 @@ class TestFileUpload(BytesIO): assert [] == process_called_with assert [representation, thumbnail] == mirrors[mirror_type].uploaded - assert [representation.mirror_url, thumbnail.mirror_url] == mirrors[mirror_type].destinations + assert [representation.mirror_url, thumbnail.mirror_url] == mirrors[ + mirror_type + ].destinations work = self._work(with_license_pool=True) identifier = work.license_pools[0].identifier @@ -1070,15 +1209,21 @@ class TestFileUpload(BytesIO): # Upload a new cover image and add the title and author to it. # Both the original image and the generated image will become resources. with self.request_context_with_library_and_admin("/"): - flask.request.form = MultiDict([ - ("title_position", "center"), - ("rights_status", RightsStatus.CC_BY), - ("rights_explanation", "explanation"), - ]) - flask.request.files = MultiDict([ - ("cover_file", TestFileUpload(image_data)), - ]) - response = self.manager.admin_work_controller.change_book_cover(identifier.type, identifier.identifier, mirrors) + flask.request.form = MultiDict( + [ + ("title_position", "center"), + ("rights_status", RightsStatus.CC_BY), + ("rights_explanation", "explanation"), + ] + ) + flask.request.files = MultiDict( + [ + ("cover_file", TestFileUpload(image_data)), + ] + ) + response = self.manager.admin_work_controller.change_book_cover( + identifier.type, identifier.identifier, mirrors + ) assert 200 == response.status_code [link] = identifier.links @@ -1089,9 +1234,16 @@ class TestFileUpload(BytesIO): assert identifier.urn in resource.url assert staff_data_source == resource.data_source assert RightsStatus.CC_BY == resource.rights_status.uri - assert "The original image license allows derivatives." == resource.rights_explanation + assert ( + "The original image license allows derivatives." + == resource.rights_explanation + ) - transformation = self._db.query(ResourceTransformation).filter(ResourceTransformation.derivative_id==resource.id).one() + transformation = ( + self._db.query(ResourceTransformation) + .filter(ResourceTransformation.derivative_id == resource.id) + .one() + ) original_resource = transformation.original assert resource != original_resource assert identifier.urn in original_resource.url @@ -1101,7 +1253,10 @@ class TestFileUpload(BytesIO): assert image_data == original_resource.representation.content assert None == original_resource.representation.mirror_url assert "center" == transformation.settings.get("title_position") - assert resource.representation.content != original_resource.representation.content + assert ( + resource.representation.content + != original_resource.representation.content + ) assert image_data != resource.representation.content assert work == process_called_with[0][0] @@ -1115,52 +1270,79 @@ class TestFileUpload(BytesIO): assert identifier.identifier in resource.representation.mirror_url assert identifier.identifier in thumbnail.mirror_url - assert [resource.representation, thumbnail] == mirrors[mirror_type].uploaded[2:] - assert [resource.representation.mirror_url, thumbnail.mirror_url] == mirrors[mirror_type].destinations[2:] + assert [resource.representation, thumbnail] == mirrors[ + mirror_type + ].uploaded[2:] + assert [ + resource.representation.mirror_url, + thumbnail.mirror_url, + ] == mirrors[mirror_type].destinations[2:] self.admin.remove_role(AdminRole.LIBRARIAN, self._default_library) with self.request_context_with_library_and_admin("/"): - pytest.raises(AdminNotAuthorized, - self.manager.admin_work_controller.preview_book_cover, - identifier.type, identifier.identifier) + pytest.raises( + AdminNotAuthorized, + self.manager.admin_work_controller.preview_book_cover, + identifier.type, + identifier.identifier, + ) self.manager.admin_work_controller._process_cover_image = old_process def test_custom_lists_get(self): staff_data_source = DataSource.lookup(self._db, DataSource.LIBRARY_STAFF) - list, ignore = create(self._db, CustomList, name=self._str, library=self._default_library, data_source=staff_data_source) + list, ignore = create( + self._db, + CustomList, + name=self._str, + library=self._default_library, + data_source=staff_data_source, + ) work = self._work(with_license_pool=True) list.add_entry(work) identifier = work.presentation_edition.primary_identifier with self.request_context_with_library_and_admin("/"): - response = self.manager.admin_work_controller.custom_lists(identifier.type, identifier.identifier) - lists = response.get('custom_lists') + response = self.manager.admin_work_controller.custom_lists( + identifier.type, identifier.identifier + ) + lists = response.get("custom_lists") assert 1 == len(lists) assert list.id == lists[0].get("id") assert list.name == lists[0].get("name") self.admin.remove_role(AdminRole.LIBRARIAN, self._default_library) with self.request_context_with_library_and_admin("/"): - pytest.raises(AdminNotAuthorized, - self.manager.admin_work_controller.custom_lists, - identifier.type, identifier.identifier) + pytest.raises( + AdminNotAuthorized, + self.manager.admin_work_controller.custom_lists, + identifier.type, + identifier.identifier, + ) def test_custom_lists_edit_with_missing_list(self): work = self._work(with_license_pool=True) identifier = work.presentation_edition.primary_identifier with self.request_context_with_library_and_admin("/", method="POST"): - flask.request.form = MultiDict([ - ("id", "4"), - ("name", "name"), - ]) + flask.request.form = MultiDict( + [ + ("id", "4"), + ("name", "name"), + ] + ) response = self.manager.admin_custom_lists_controller.custom_lists() assert MISSING_CUSTOM_LIST == response def test_custom_lists_edit_success(self): staff_data_source = DataSource.lookup(self._db, DataSource.LIBRARY_STAFF) - list, ignore = create(self._db, CustomList, name=self._str, library=self._default_library, data_source=staff_data_source) + list, ignore = create( + self._db, + CustomList, + name=self._str, + library=self._default_library, + data_source=staff_data_source, + ) work = self._work(with_license_pool=True) identifier = work.presentation_edition.primary_identifier @@ -1175,10 +1357,12 @@ def test_custom_lists_edit_success(self): # Add the list to the work. with self.request_context_with_library_and_admin("/", method="POST"): - flask.request.form = MultiDict([ - ("lists", json.dumps([{ "id": str(list.id), "name": list.name }])) - ]) - response = self.manager.admin_work_controller.custom_lists(identifier.type, identifier.identifier) + flask.request.form = MultiDict( + [("lists", json.dumps([{"id": str(list.id), "name": list.name}]))] + ) + response = self.manager.admin_work_controller.custom_lists( + identifier.type, identifier.identifier + ) assert 200 == response.status_code assert 1 == len(work.custom_list_entries) assert 1 == len(list.entries) @@ -1193,10 +1377,14 @@ def test_custom_lists_edit_success(self): # Now remove the work from the list. self.controller.search_engine.docs = dict(id1="doc1") with self.request_context_with_library_and_admin("/", method="POST"): - flask.request.form = MultiDict([ - ("lists", json.dumps([])), - ]) - response = self.manager.admin_work_controller.custom_lists(identifier.type, identifier.identifier) + flask.request.form = MultiDict( + [ + ("lists", json.dumps([])), + ] + ) + response = self.manager.admin_work_controller.custom_lists( + identifier.type, identifier.identifier + ) assert 200 == response.status_code assert 0 == len(work.custom_list_entries) assert 0 == len(list.entries) @@ -1206,21 +1394,28 @@ def test_custom_lists_edit_success(self): # Add a list that didn't exist before. with self.request_context_with_library_and_admin("/", method="POST"): - flask.request.form = MultiDict([ - ("lists", json.dumps([{ "name": "new list" }])) - ]) - response = self.manager.admin_work_controller.custom_lists(identifier.type, identifier.identifier) + flask.request.form = MultiDict( + [("lists", json.dumps([{"name": "new list"}]))] + ) + response = self.manager.admin_work_controller.custom_lists( + identifier.type, identifier.identifier + ) assert 200 == response.status_code assert 1 == len(work.custom_list_entries) - new_list = CustomList.find(self._db, "new list", staff_data_source, self._default_library) + new_list = CustomList.find( + self._db, "new list", staff_data_source, self._default_library + ) assert new_list == work.custom_list_entries[0].customlist assert True == work.custom_list_entries[0].featured self.admin.remove_role(AdminRole.LIBRARIAN, self._default_library) with self.request_context_with_library_and_admin("/", method="POST"): - flask.request.form = MultiDict([ - ("lists", json.dumps([{ "name": "another new list" }])) - ]) - pytest.raises(AdminNotAuthorized, - self.manager.admin_work_controller.custom_lists, - identifier.type, identifier.identifier) + flask.request.form = MultiDict( + [("lists", json.dumps([{"name": "another new list"}]))] + ) + pytest.raises( + AdminNotAuthorized, + self.manager.admin_work_controller.custom_lists, + identifier.type, + identifier.identifier, + ) diff --git a/tests/admin/test_announcement_list_validator.py b/tests/admin/test_announcement_list_validator.py index 441ebf1fa0..9c9784b945 100644 --- a/tests/admin/test_announcement_list_validator.py +++ b/tests/admin/test_announcement_list_validator.py @@ -1,18 +1,14 @@ -from datetime import ( - date, - datetime, - timedelta, -) import json +from datetime import date, datetime, timedelta -from core.problem_details import INVALID_INPUT -from core.util.problem_detail import ProblemDetail -from api.announcements import Announcement from api.admin.announcement_list_validator import AnnouncementListValidator +from api.announcements import Announcement from api.testing import AnnouncementTest +from core.problem_details import INVALID_INPUT +from core.util.problem_detail import ProblemDetail -class TestAnnouncementListValidator(AnnouncementTest): +class TestAnnouncementListValidator(AnnouncementTest): def assert_invalid(self, x, detail): assert isinstance(x, ProblemDetail) assert INVALID_INPUT.uri == x.uri @@ -31,7 +27,7 @@ def test_validate_announcements(self): class AlwaysAcceptValidator(AnnouncementListValidator): def validate_announcement(self, announcement): - announcement['validated'] = True + announcement["validated"] = True return announcement validator = AlwaysAcceptValidator(maximum_announcements=2) @@ -61,9 +57,8 @@ def validate_announcement(self, announcement): for invalid in dict(), json.dumps(dict()), "non-json string": self.assert_invalid( m(invalid), - "Invalid announcement list format: %(announcements)r" % dict( - announcements=invalid - ) + "Invalid announcement list format: %(announcements)r" + % dict(announcements=invalid), ) # validate_announcements runs some checks on the list of announcements. @@ -93,8 +88,7 @@ def validate_announcement(self, announcement): validator = AlwaysRejectValidator() self.assert_invalid( - validator.validate_announcements(["an announcement"]), - "Rejected!" + validator.validate_announcements(["an announcement"]), "Rejected!" ) def test_validate_announcement_success(self): @@ -106,25 +100,25 @@ def test_validate_announcement_success(self): today = date.today() in_a_week = today + timedelta(days=7) valid = dict( - start=today.strftime('%Y-%m-%d'), - finish=in_a_week.strftime('%Y-%m-%d'), - content="This is a test of announcement validation." + start=today.strftime("%Y-%m-%d"), + finish=in_a_week.strftime("%Y-%m-%d"), + content="This is a test of announcement validation.", ) validated = m(valid) # A UUID has been added in the 'id' field. - id = validated.pop('id') + id = validated.pop("id") assert 36 == len(id) for position in 8, 13, 18, 23: - assert '-' == id[position] + assert "-" == id[position] # Date strings have been converted to date objects. - assert today == validated['start'] - assert in_a_week == validated['finish'] - + assert today == validated["start"] + assert in_a_week == validated["finish"] + # Now simulate an edit, where an ID is provided. - validated['id'] = 'an existing id' + validated["id"] = "an existing id" # Now the incoming data is validated but not changed at all. assert validated == m(validated) @@ -132,14 +126,12 @@ def test_validate_announcement_success(self): # If no start date is specified, today's date is used. If no # finish date is specified, a default associated with the # validator is used. - no_finish_date = dict( - content="This is a test of announcment validation" - ) + no_finish_date = dict(content="This is a test of announcment validation") validated = m(no_finish_date) - assert today == validated['start'] + assert today == validated["start"] assert ( - today + timedelta(days=validator.default_duration_days) == - validated['finish'] + today + timedelta(days=validator.default_duration_days) + == validated["finish"] ) def test_validate_announcement_failure(self): @@ -151,8 +143,9 @@ def test_validate_announcement_failure(self): # Totally bogus format for invalid in '{"a": "string"}', ["a list"]: self.assert_invalid( - m(invalid), - "Invalid announcement format: %(announcement)r" % dict(announcement=invalid) + m(invalid), + "Invalid announcement format: %(announcement)r" + % dict(announcement=invalid), ) # Some baseline valid value to use in tests where _some_ of the data is valid. @@ -163,19 +156,17 @@ def test_validate_announcement_failure(self): # Missing a required field no_content = dict(start=today) self.assert_invalid(m(no_content), "Missing required field: content") - + # Bad content -- tested at greater length in another test. bad_content = dict(start=today, content="short") self.assert_invalid( - m(bad_content), - "Value too short (5 versus 15 characters): short" + m(bad_content), "Value too short (5 versus 15 characters): short" ) # Bad start date -- tested at greater length in another test. bad_start_date = dict(start="not-a-date", content=message) self.assert_invalid( - m(bad_start_date), - "Value for start is not a date: not-a-date" + m(bad_start_date), "Value for start is not a date: not-a-date" ) # Bad finish date. @@ -184,27 +175,23 @@ def test_validate_announcement_failure(self): bad_data = dict(start=today, finish=bad_finish_date, content=message) self.assert_invalid( m(bad_data), - "Value for finish must be no earlier than %s" % ( - tomorrow.strftime(validator.DATE_FORMAT) - ) + "Value for finish must be no earlier than %s" + % (tomorrow.strftime(validator.DATE_FORMAT)), ) - def test_validate_length(self): # Test the validate_length helper method in more detail than # it's tested in validate_announcement. m = AnnouncementListValidator.validate_length value = "four" assert value == m(value, 3, 5) - + self.assert_invalid( - m(value, 10, 20), - "Value too short (4 versus 10 characters): four" + m(value, 10, 20), "Value too short (4 versus 10 characters): four" ) self.assert_invalid( - m(value, 1, 3), - "Value too long (4 versus 3 characters): four" + m(value, 1, 3), "Value too long (4 versus 3 characters): four" ) def test_validate_date(self): @@ -220,14 +207,14 @@ def test_validate_date(self): assert february_1 == m("somedate", february_1) assert february_1 == m("somedate", datetime(2020, 2, 1)) - # But if a string is used, it must be in a specific format. + # But if a string is used, it must be in a specific format. self.assert_invalid( m("somedate", "not-a-date"), "Value for somedate is not a date: not-a-date" ) # If a minimum (date or datetime) is provided, the selection # must be on or after that date. - + january_1 = date(2020, 1, 1) january_1_datetime = datetime(2020, 1, 1) assert february_1 == m("somedate", february_1, minimum=january_1) @@ -257,5 +244,3 @@ def test_format(self): # Announcement objects and then back to dictionaries using # Announcement.json_ready. assert [Announcement(**x).json_ready for x in announcements] == as_list - - diff --git a/tests/admin/test_config.py b/tests/admin/test_config.py index 693202f1e5..9c72484c5f 100644 --- a/tests/admin/test_config.py +++ b/tests/admin/test_config.py @@ -8,7 +8,6 @@ class TestAdminUI(object): - @staticmethod def _set_env(monkeypatch, key: str, value: Optional[str]): if value: @@ -17,56 +16,123 @@ def _set_env(monkeypatch, key: str, value: Optional[str]): monkeypatch.delenv(key) @pytest.mark.parametrize( - 'package_name, package_version, mode, expected_result_startswith', + "package_name, package_version, mode, expected_result_startswith", [ - [None, None, OperationalMode.production, - 'https://cdn.jsdelivr.net/npm/@thepalaceproject/circulation-admin@'], - ['@some-scope/some-package', '1.0.0', OperationalMode.production, - 'https://cdn.jsdelivr.net/npm/@some-scope/some-package@1.0.0'], - ['some-package', '1.0.0', OperationalMode.production, - 'https://cdn.jsdelivr.net/npm/some-package@1.0.0'], - [None, None, OperationalMode.development, '/'], - [None, '1.0.0', OperationalMode.development, '/'], - ['some-package', '1.0.0', OperationalMode.development, '/'], - ]) - def test_package_url(self, monkeypatch, package_name: Optional[str], package_version: Optional[str], - mode: OperationalMode, expected_result_startswith: str): - self._set_env(monkeypatch, 'TPP_CIRCULATION_ADMIN_PACKAGE_NAME', package_name) - self._set_env(monkeypatch, 'TPP_CIRCULATION_ADMIN_PACKAGE_VERSION', package_version) + [ + None, + None, + OperationalMode.production, + "https://cdn.jsdelivr.net/npm/@thepalaceproject/circulation-admin@", + ], + [ + "@some-scope/some-package", + "1.0.0", + OperationalMode.production, + "https://cdn.jsdelivr.net/npm/@some-scope/some-package@1.0.0", + ], + [ + "some-package", + "1.0.0", + OperationalMode.production, + "https://cdn.jsdelivr.net/npm/some-package@1.0.0", + ], + [None, None, OperationalMode.development, "/"], + [None, "1.0.0", OperationalMode.development, "/"], + ["some-package", "1.0.0", OperationalMode.development, "/"], + ], + ) + def test_package_url( + self, + monkeypatch, + package_name: Optional[str], + package_version: Optional[str], + mode: OperationalMode, + expected_result_startswith: str, + ): + self._set_env(monkeypatch, "TPP_CIRCULATION_ADMIN_PACKAGE_NAME", package_name) + self._set_env( + monkeypatch, "TPP_CIRCULATION_ADMIN_PACKAGE_VERSION", package_version + ) result = AdminConfig.package_url(_operational_mode=mode) assert result.startswith(expected_result_startswith) @pytest.mark.parametrize( - 'package_name, package_version, expected_result', + "package_name, package_version, expected_result", [ - [None, None, '/my-base-dir/node_modules/@thepalaceproject/circulation-admin'], - [None, '1.0.0', '/my-base-dir/node_modules/@thepalaceproject/circulation-admin'], - ['some-package', '1.0.0', '/my-base-dir/node_modules/some-package'], - ]) - def test_package_development_directory(self, monkeypatch, package_name: Optional[str], - package_version: Optional[str], expected_result: str): - self._set_env(monkeypatch, 'TPP_CIRCULATION_ADMIN_PACKAGE_NAME', package_name) - self._set_env(monkeypatch, 'TPP_CIRCULATION_ADMIN_PACKAGE_VERSION', package_version) - result = AdminConfig.package_development_directory(_base_dir='/my-base-dir') + [ + None, + None, + "/my-base-dir/node_modules/@thepalaceproject/circulation-admin", + ], + [ + None, + "1.0.0", + "/my-base-dir/node_modules/@thepalaceproject/circulation-admin", + ], + ["some-package", "1.0.0", "/my-base-dir/node_modules/some-package"], + ], + ) + def test_package_development_directory( + self, + monkeypatch, + package_name: Optional[str], + package_version: Optional[str], + expected_result: str, + ): + self._set_env(monkeypatch, "TPP_CIRCULATION_ADMIN_PACKAGE_NAME", package_name) + self._set_env( + monkeypatch, "TPP_CIRCULATION_ADMIN_PACKAGE_VERSION", package_version + ) + result = AdminConfig.package_development_directory(_base_dir="/my-base-dir") assert result == expected_result @pytest.mark.parametrize( - 'asset_key, operational_mode, expected_result', + "asset_key, operational_mode, expected_result", [ - ['admin_css', OperationalMode.development, '/admin/static/circulation-admin.css'], - ['admin_css', OperationalMode.production, - 'https://cdn.jsdelivr.net/npm/known-package-name@1.0.0/dist/circulation-admin.css'], - ['admin_js', OperationalMode.development, '/admin/static/circulation-admin.js'], - ['admin_js', OperationalMode.production, - 'https://cdn.jsdelivr.net/npm/known-package-name@1.0.0/dist/circulation-admin.js'], - ['another-asset.jpg', OperationalMode.development, '/admin/static/another-asset.jpg'], - ['another-asset.jpg', OperationalMode.production, - 'https://cdn.jsdelivr.net/npm/known-package-name@1.0.0/dist/another-asset.jpg'], - ] + [ + "admin_css", + OperationalMode.development, + "/admin/static/circulation-admin.css", + ], + [ + "admin_css", + OperationalMode.production, + "https://cdn.jsdelivr.net/npm/known-package-name@1.0.0/dist/circulation-admin.css", + ], + [ + "admin_js", + OperationalMode.development, + "/admin/static/circulation-admin.js", + ], + [ + "admin_js", + OperationalMode.production, + "https://cdn.jsdelivr.net/npm/known-package-name@1.0.0/dist/circulation-admin.js", + ], + [ + "another-asset.jpg", + OperationalMode.development, + "/admin/static/another-asset.jpg", + ], + [ + "another-asset.jpg", + OperationalMode.production, + "https://cdn.jsdelivr.net/npm/known-package-name@1.0.0/dist/another-asset.jpg", + ], + ], ) - def test_lookup_asset_url(self, monkeypatch, asset_key: str, operational_mode: OperationalMode, - expected_result: str): - self._set_env(monkeypatch, 'TPP_CIRCULATION_ADMIN_PACKAGE_NAME', 'known-package-name') - self._set_env(monkeypatch, 'TPP_CIRCULATION_ADMIN_PACKAGE_VERSION', '1.0.0') - result = AdminConfig.lookup_asset_url(key=asset_key, _operational_mode=operational_mode) + def test_lookup_asset_url( + self, + monkeypatch, + asset_key: str, + operational_mode: OperationalMode, + expected_result: str, + ): + self._set_env( + monkeypatch, "TPP_CIRCULATION_ADMIN_PACKAGE_NAME", "known-package-name" + ) + self._set_env(monkeypatch, "TPP_CIRCULATION_ADMIN_PACKAGE_VERSION", "1.0.0") + result = AdminConfig.lookup_asset_url( + key=asset_key, _operational_mode=operational_mode + ) assert result == expected_result diff --git a/tests/admin/test_geographic_validator.py b/tests/admin/test_geographic_validator.py index 62c6136bbb..c522bb835c 100644 --- a/tests/admin/test_geographic_validator.py +++ b/tests/admin/test_geographic_validator.py @@ -1,24 +1,26 @@ +import json +import urllib.error +import urllib.parse +import urllib.request + +import pypostalcode +import uszipcode from api.admin.controller.library_settings import LibrarySettingsController from api.admin.geographic_validator import GeographicValidator from api.admin.problem_details import * from api.config import Configuration from api.registry import RemoteRegistry -from core.model import ( - create, - ExternalIntegration -) +from core.model import ExternalIntegration, create from core.testing import MockRequestsResponse -import json -import pypostalcode from tests.admin.controller.test_controller import SettingsControllerTest -import urllib.request, urllib.parse, urllib.error -import uszipcode + class TestGeographicValidator(SettingsControllerTest): def test_validate_geographic_areas(self): original_validator = GeographicValidator db = self._db + class Mock(GeographicValidator): def __init__(self): self._db = db @@ -26,9 +28,11 @@ def __init__(self): def mock_find_location_through_registry(self, value, db): self.value = value + def mock_find_location_through_registry_with_error(self, value, db): self.value = value return REMOTE_INTEGRATION_FAILED + def mock_find_location_through_registry_success(self, value, db): self.value = value return "CA" @@ -54,7 +58,10 @@ def mock_find_location_through_registry_success(self, value, db): # Invalid 2-letter abbreviation response = mock.validate_geographic_areas('["ZZ"]', self._db) assert response.uri == UNKNOWN_LOCATION.uri - assert response.detail == '"ZZ" is not a valid U.S. state or Canadian province abbreviation.' + assert ( + response.detail + == '"ZZ" is not a valid U.S. state or Canadian province abbreviation.' + ) # The validator should have returned the problem detail without bothering to ask the registry. assert mock.value == None @@ -80,17 +87,24 @@ def mock_find_location_through_registry_success(self, value, db): # The Canadian zip code is valid, but it corresponds to a place too small for the registry to know about it. response = mock.validate_geographic_areas('["J5J"]', self._db) assert response.uri == UNKNOWN_LOCATION.uri - assert response.detail == 'Unable to locate "J5J" (Saint-Sophie, Quebec). Try entering the name of a larger area.' + assert ( + response.detail + == 'Unable to locate "J5J" (Saint-Sophie, Quebec). Try entering the name of a larger area.' + ) assert mock.value == "Saint-Sophie, Quebec" # Can't connect to registry - mock.find_location_through_registry = mock.mock_find_location_through_registry_with_error + mock.find_location_through_registry = ( + mock.mock_find_location_through_registry_with_error + ) response = mock.validate_geographic_areas('["Victoria, BC"]', self._db) # The controller goes ahead and calls find_location_through_registry, but it can't connect to the registry. assert response.uri == REMOTE_INTEGRATION_FAILED.uri # The registry successfully finds the place - mock.find_location_through_registry = mock.mock_find_location_through_registry_success + mock.find_location_through_registry = ( + mock.mock_find_location_through_registry_success + ) response = mock.validate_geographic_areas('["Victoria, BC"]', self._db) assert response == {"CA": ["Victoria, BC"], "US": []} @@ -107,19 +121,30 @@ def test_find_location_through_registry(self): class Mock(GeographicValidator): called_with = [] + def mock_ask_registry(self, service_area_object, db): places = {"US": ["Chicago"], "CA": ["Victoria, BC"]} - service_area_info = json.loads(urllib.parse.unquote(service_area_object)) + service_area_info = json.loads( + urllib.parse.unquote(service_area_object) + ) nation = list(service_area_info.keys())[0] city_or_county = list(service_area_info.values())[0] if city_or_county == "ERROR": test.responses.append(MockRequestsResponse(502)) elif city_or_county in places[nation]: self.called_with.append(service_area_info) - test.responses.append(MockRequestsResponse(200, content=json.dumps(dict(unknown=None, ambiguous=None)))) + test.responses.append( + MockRequestsResponse( + 200, content=json.dumps(dict(unknown=None, ambiguous=None)) + ) + ) else: self.called_with.append(service_area_info) - test.responses.append(MockRequestsResponse(200, content=json.dumps(dict(unknown=[city_or_county])))) + test.responses.append( + MockRequestsResponse( + 200, content=json.dumps(dict(unknown=[city_or_county])) + ) + ) return original_ask_registry(service_area_object, db, get) mock = Mock() @@ -142,14 +167,19 @@ def mock_ask_registry(self, service_area_object, db): mock.called_with = [] - nowhere_response = mock.find_location_through_registry("Not a real place", self._db) + nowhere_response = mock.find_location_through_registry( + "Not a real place", self._db + ) assert len(mock.called_with) == 2 assert {"US": "Not a real place"} == mock.called_with[0] assert {"CA": "Not a real place"} == mock.called_with[1] assert nowhere_response == None error_response = mock.find_location_through_registry("ERROR", self._db) - assert error_response.detail == "Unable to contact the registry at https://registry_url." + assert ( + error_response.detail + == "Unable to contact the registry at https://registry_url." + ) assert error_response.status_code == 502 def test_ask_registry(self, monkeypatch): @@ -167,66 +197,115 @@ def test_ask_registry(self, monkeypatch): # Registry 1 knows about the place self.responses.append(true_response) - response_1 = validator.ask_registry(json.dumps({"CA": "Victoria, BC"}), self._db, self.do_request) + response_1 = validator.ask_registry( + json.dumps({"CA": "Victoria, BC"}), self._db, self.do_request + ) assert response_1 == True assert len(self.requests) == 1 request_1 = self.requests.pop() - assert request_1[0] == 'https://registry_1_url/coverage?coverage={"CA": "Victoria, BC"}' + assert ( + request_1[0] + == 'https://registry_1_url/coverage?coverage={"CA": "Victoria, BC"}' + ) # Registry 1 says the place is unknown, but Registry 2 finds it. self.responses.append(true_response) self.responses.append(unknown_response) - response_2 = validator.ask_registry(json.dumps({"CA": "Victoria, BC"}), self._db, self.do_request) + response_2 = validator.ask_registry( + json.dumps({"CA": "Victoria, BC"}), self._db, self.do_request + ) assert response_2 == True assert len(self.requests) == 2 request_2 = self.requests.pop() - assert request_2[0] == 'https://registry_2_url/coverage?coverage={"CA": "Victoria, BC"}' + assert ( + request_2[0] + == 'https://registry_2_url/coverage?coverage={"CA": "Victoria, BC"}' + ) request_1 = self.requests.pop() - assert request_1[0] == 'https://registry_1_url/coverage?coverage={"CA": "Victoria, BC"}' + assert ( + request_1[0] + == 'https://registry_1_url/coverage?coverage={"CA": "Victoria, BC"}' + ) # Registry_1 says the place is ambiguous and Registry_2 says it's unknown, but Registry_3 finds it. self.responses.append(true_response) self.responses.append(unknown_response) self.responses.append(ambiguous_response) - response_3 = validator.ask_registry(json.dumps({"CA": "Victoria, BC"}), self._db, self.do_request) + response_3 = validator.ask_registry( + json.dumps({"CA": "Victoria, BC"}), self._db, self.do_request + ) assert response_3 == True assert len(self.requests) == 3 request_3 = self.requests.pop() - assert request_3[0] == 'https://registry_3_url/coverage?coverage={"CA": "Victoria, BC"}' + assert ( + request_3[0] + == 'https://registry_3_url/coverage?coverage={"CA": "Victoria, BC"}' + ) request_2 = self.requests.pop() - assert request_2[0] == 'https://registry_2_url/coverage?coverage={"CA": "Victoria, BC"}' + assert ( + request_2[0] + == 'https://registry_2_url/coverage?coverage={"CA": "Victoria, BC"}' + ) request_1 = self.requests.pop() - assert request_1[0] == 'https://registry_1_url/coverage?coverage={"CA": "Victoria, BC"}' + assert ( + request_1[0] + == 'https://registry_1_url/coverage?coverage={"CA": "Victoria, BC"}' + ) # Registry 1 returns a problem detail, but Registry 2 finds the place self.responses.append(true_response) self.responses.append(problem_response) - response_4 = validator.ask_registry(json.dumps({"CA": "Victoria, BC"}), self._db, self.do_request) + response_4 = validator.ask_registry( + json.dumps({"CA": "Victoria, BC"}), self._db, self.do_request + ) assert response_4 == True assert len(self.requests) == 2 request_2 = self.requests.pop() - assert request_2[0] == 'https://registry_2_url/coverage?coverage={"CA": "Victoria, BC"}' + assert ( + request_2[0] + == 'https://registry_2_url/coverage?coverage={"CA": "Victoria, BC"}' + ) request_1 = self.requests.pop() - assert request_1[0] == 'https://registry_1_url/coverage?coverage={"CA": "Victoria, BC"}' + assert ( + request_1[0] + == 'https://registry_1_url/coverage?coverage={"CA": "Victoria, BC"}' + ) # Registry 1 returns a problem detail and the other two registries can't find the place self.responses.append(unknown_response) self.responses.append(ambiguous_response) self.responses.append(problem_response) - response_5 = validator.ask_registry(json.dumps({"CA": "Victoria, BC"}), self._db, self.do_request) + response_5 = validator.ask_registry( + json.dumps({"CA": "Victoria, BC"}), self._db, self.do_request + ) assert response_5.status_code == 502 - assert response_5.detail == "Unable to contact the registry at https://registry_1_url." + assert ( + response_5.detail + == "Unable to contact the registry at https://registry_1_url." + ) assert len(self.requests) == 3 request_3 = self.requests.pop() - assert request_3[0] == 'https://registry_3_url/coverage?coverage={"CA": "Victoria, BC"}' + assert ( + request_3[0] + == 'https://registry_3_url/coverage?coverage={"CA": "Victoria, BC"}' + ) request_2 = self.requests.pop() - assert request_2[0] == 'https://registry_2_url/coverage?coverage={"CA": "Victoria, BC"}' + assert ( + request_2[0] + == 'https://registry_2_url/coverage?coverage={"CA": "Victoria, BC"}' + ) request_1 = self.requests.pop() - assert request_1[0] == 'https://registry_1_url/coverage?coverage={"CA": "Victoria, BC"}' + assert ( + request_1[0] + == 'https://registry_1_url/coverage?coverage={"CA": "Victoria, BC"}' + ) def _registry(self, url): integration, is_new = create( - self._db, ExternalIntegration, protocol=ExternalIntegration.OPDS_REGISTRATION, goal=ExternalIntegration.DISCOVERY_GOAL + self._db, + ExternalIntegration, + protocol=ExternalIntegration.OPDS_REGISTRATION, + goal=ExternalIntegration.DISCOVERY_GOAL, ) integration.url = url return RemoteRegistry(integration) @@ -240,7 +319,10 @@ def _registries(self, urls, monkeypatch): integrations = [] for url in urls: integration, is_new = create( - self._db, ExternalIntegration, protocol=ExternalIntegration.OPDS_REGISTRATION, goal=ExternalIntegration.DISCOVERY_GOAL + self._db, + ExternalIntegration, + protocol=ExternalIntegration.OPDS_REGISTRATION, + goal=ExternalIntegration.DISCOVERY_GOAL, ) integration.url = url integrations.append(integration) @@ -250,12 +332,9 @@ def mock_for_protocol_and_goal(_db, protocol, goal): yield RemoteRegistry(integration) monkeypatch.setattr( - RemoteRegistry, - "for_protocol_and_goal", - mock_for_protocol_and_goal + RemoteRegistry, "for_protocol_and_goal", mock_for_protocol_and_goal ) - def test_is_zip(self): validator = GeographicValidator() assert validator.is_zip("06759", "US") == True @@ -275,9 +354,9 @@ def test_look_up_zip(self): us_zip_unformatted = validator.look_up_zip("06759", "US") assert isinstance(us_zip_unformatted, uszipcode.SimpleZipcode) us_zip_formatted = validator.look_up_zip("06759", "US", True) - assert us_zip_formatted == {'06759': 'Litchfield, CT'} + assert us_zip_formatted == {"06759": "Litchfield, CT"} ca_zip_unformatted = validator.look_up_zip("R2V", "CA") assert isinstance(ca_zip_unformatted, pypostalcode.PostalCode) ca_zip_formatted = validator.look_up_zip("R2V", "CA", True) - assert ca_zip_formatted == {'R2V': 'Winnipeg (Seven Oaks East), Manitoba'} + assert ca_zip_formatted == {"R2V": "Winnipeg (Seven Oaks East), Manitoba"} diff --git a/tests/admin/test_google_oauth_admin_authentication_provider.py b/tests/admin/test_google_oauth_admin_authentication_provider.py index ad048382af..16e1038308 100644 --- a/tests/admin/test_google_oauth_admin_authentication_provider.py +++ b/tests/admin/test_google_oauth_admin_authentication_provider.py @@ -1,12 +1,10 @@ import json -from oauth2client import client as GoogleClient -from core.testing import DatabaseTest -from core.util.problem_detail import ProblemDetail +from oauth2client import client as GoogleClient from api.admin.google_oauth_admin_authentication_provider import ( - GoogleOAuthAdminAuthenticationProvider, DummyGoogleClient, + GoogleOAuthAdminAuthenticationProvider, ) from api.admin.problem_details import INVALID_ADMIN_CREDENTIALS from core.model import ( @@ -16,34 +14,43 @@ ExternalIntegration, create, ) +from core.testing import DatabaseTest +from core.util.problem_detail import ProblemDetail -class TestGoogleOAuthAdminAuthenticationProvider(DatabaseTest): +class TestGoogleOAuthAdminAuthenticationProvider(DatabaseTest): def test_callback(self): super(TestGoogleOAuthAdminAuthenticationProvider, self).setup_method() auth_integration, ignore = create( - self._db, ExternalIntegration, + self._db, + ExternalIntegration, protocol=ExternalIntegration.GOOGLE_OAUTH, - goal=ExternalIntegration.ADMIN_AUTH_GOAL + goal=ExternalIntegration.ADMIN_AUTH_GOAL, + ) + self.google = GoogleOAuthAdminAuthenticationProvider( + auth_integration, "", test_mode=True ) - self.google = GoogleOAuthAdminAuthenticationProvider(auth_integration, "", test_mode=True) auth_integration.libraries += [self._default_library] ConfigurationSetting.for_library_and_externalintegration( self._db, "domains", self._default_library, auth_integration ).value = json.dumps(["nypl.org"]) # Returns a problem detail when Google returns an error. - error_response, redirect = self.google.callback(self._db, {'error' : 'access_denied'}) + error_response, redirect = self.google.callback( + self._db, {"error": "access_denied"} + ) assert True == isinstance(error_response, ProblemDetail) assert 400 == error_response.status_code - assert True == error_response.detail.endswith('access_denied') + assert True == error_response.detail.endswith("access_denied") assert None == redirect # Successful case creates a dict of admin details - success, redirect = self.google.callback(self._db, {'code' : 'abc'}) - assert 'example@nypl.org' == success['email'] - default_credentials = json.dumps({"id_token": {"hd": "nypl.org", "email": "example@nypl.org"}}) - assert default_credentials == success['credentials'] + success, redirect = self.google.callback(self._db, {"code": "abc"}) + assert "example@nypl.org" == success["email"] + default_credentials = json.dumps( + {"id_token": {"hd": "nypl.org", "email": "example@nypl.org"}} + ) + assert default_credentials == success["credentials"] assert GoogleOAuthAdminAuthenticationProvider.NAME == success["type"] [role] = success.get("roles") assert AdminRole.LIBRARIAN == role.get("role") @@ -51,9 +58,10 @@ def test_callback(self): # If domains are set, the admin's domain must match one of the domains. setting = ConfigurationSetting.for_library_and_externalintegration( - self._db, "domains", self._default_library, auth_integration) + self._db, "domains", self._default_library, auth_integration + ) setting.value = json.dumps(["otherlibrary.org"]) - failure, ignore = self.google.callback(self._db, {'code' : 'abc'}) + failure, ignore = self.google.callback(self._db, {"code": "abc"}) assert INVALID_ADMIN_CREDENTIALS == failure setting.value = json.dumps(["nypl.org"]) @@ -62,26 +70,30 @@ def test_callback(self): class ExceptionRaisingClient(DummyGoogleClient): def step2_exchange(self, auth_code): raise GoogleClient.FlowExchangeError("mock error") + self.google.dummy_client = ExceptionRaisingClient() - error_response, redirect = self.google.callback(self._db, {'code' : 'abc'}) + error_response, redirect = self.google.callback(self._db, {"code": "abc"}) assert True == isinstance(error_response, ProblemDetail) assert 400 == error_response.status_code - assert True == error_response.detail.endswith('mock error') + assert True == error_response.detail.endswith("mock error") assert None == redirect def test_domains(self): super(TestGoogleOAuthAdminAuthenticationProvider, self).setup_method() auth_integration, ignore = create( - self._db, ExternalIntegration, + self._db, + ExternalIntegration, protocol=ExternalIntegration.GOOGLE_OAUTH, - goal=ExternalIntegration.ADMIN_AUTH_GOAL + goal=ExternalIntegration.ADMIN_AUTH_GOAL, ) auth_integration.libraries += [self._default_library] ConfigurationSetting.for_library_and_externalintegration( self._db, "domains", self._default_library, auth_integration ).value = json.dumps(["nypl.org"]) - google = GoogleOAuthAdminAuthenticationProvider(auth_integration, "", test_mode=True) + google = GoogleOAuthAdminAuthenticationProvider( + auth_integration, "", test_mode=True + ) assert ["nypl.org"] == list(google.domains.keys()) assert [self._default_library] == google.domains["nypl.org"] @@ -98,9 +110,10 @@ def test_domains(self): def test_staff_email(self): super(TestGoogleOAuthAdminAuthenticationProvider, self).setup_method() auth_integration, ignore = create( - self._db, ExternalIntegration, + self._db, + ExternalIntegration, protocol=ExternalIntegration.GOOGLE_OAUTH, - goal=ExternalIntegration.ADMIN_AUTH_GOAL + goal=ExternalIntegration.ADMIN_AUTH_GOAL, ) nypl_admin = create(self._db, Admin, email="admin@nypl.org") @@ -108,7 +121,9 @@ def test_staff_email(self): # If no domains are set, the admin must already exist in the db # to be considered library staff. - google = GoogleOAuthAdminAuthenticationProvider(auth_integration, "", test_mode=True) + google = GoogleOAuthAdminAuthenticationProvider( + auth_integration, "", test_mode=True + ) assert True == google.staff_email(self._db, "admin@nypl.org") assert True == google.staff_email(self._db, "admin@bklynlibrary.org") @@ -118,7 +133,8 @@ def test_staff_email(self): # if the admin doesn't exist yet. auth_integration.libraries += [self._default_library] setting = ConfigurationSetting.for_library_and_externalintegration( - self._db, "domains", self._default_library, auth_integration) + self._db, "domains", self._default_library, auth_integration + ) setting.value = json.dumps(["nypl.org"]) assert True == google.staff_email(self._db, "admin@nypl.org") assert True == google.staff_email(self._db, "admin@bklynlibrary.org") @@ -130,4 +146,3 @@ def test_staff_email(self): assert True == google.staff_email(self._db, "admin@bklynlibrary.org") assert True == google.staff_email(self._db, "someone@nypl.org") assert True == google.staff_email(self._db, "someone@bklynlibrary.org") - diff --git a/tests/admin/test_opds.py b/tests/admin/test_opds.py index 27a948f5eb..953281abd7 100644 --- a/tests/admin/test_opds.py +++ b/tests/admin/test_opds.py @@ -1,32 +1,26 @@ - - import feedparser from api.admin.opds import AdminAnnotator, AdminFeed from api.opds import AcquisitionFeed -from core.model import ( - Complaint, - DataSource, - ExternalIntegration, - Library, - Measurement, -) -from core.model.configuration import ExternalIntegrationLink from core.lane import Facets, Pagination +from core.model import Complaint, DataSource, ExternalIntegration, Library, Measurement +from core.model.configuration import ExternalIntegrationLink from core.opds import Annotator - from core.testing import DatabaseTest -class TestOPDS(DatabaseTest): +class TestOPDS(DatabaseTest): def links(self, entry, rel=None): - if 'feed' in entry: - entry = entry['feed'] - links = sorted(entry['links'], key=lambda x: (x['rel'], x.get('title'))) + if "feed" in entry: + entry = entry["feed"] + links = sorted(entry["links"], key=lambda x: (x["rel"], x.get("title"))) r = [] for l in links: - if (not rel or l['rel'] == rel or - (isinstance(rel, list) and l['rel'] in rel)): + if ( + not rel + or l["rel"] == rel + or (isinstance(rel, list) and l["rel"] in rel) + ): r.append(l) return r @@ -34,13 +28,21 @@ def test_feed_includes_staff_rating(self): work = self._work(with_open_access_download=True) lp = work.license_pools[0] staff_data_source = DataSource.lookup(self._db, DataSource.LIBRARY_STAFF) - lp.identifier.add_measurement(staff_data_source, Measurement.RATING, 3, weight=1000) + lp.identifier.add_measurement( + staff_data_source, Measurement.RATING, 3, weight=1000 + ) - feed = AcquisitionFeed(self._db, "test", "url", [work], AdminAnnotator(None, self._default_library, test_mode=True)) - [entry] = feedparser.parse(str(feed))['entries'] - rating = entry['schema_rating'] - assert 3 == float(rating['schema:ratingvalue']) - assert Measurement.RATING == rating['additionaltype'] + feed = AcquisitionFeed( + self._db, + "test", + "url", + [work], + AdminAnnotator(None, self._default_library, test_mode=True), + ) + [entry] = feedparser.parse(str(feed))["entries"] + rating = entry["schema_rating"] + assert 3 == float(rating["schema:ratingvalue"]) + assert Measurement.RATING == rating["additionaltype"] def test_feed_includes_refresh_link(self): work = self._work(with_open_access_download=True) @@ -49,21 +51,41 @@ def test_feed_includes_refresh_link(self): self._db.commit() # If the metadata wrangler isn't configured, the link is left out. - feed = AcquisitionFeed(self._db, "test", "url", [work], AdminAnnotator(None, self._default_library, test_mode=True)) - [entry] = feedparser.parse(str(feed))['entries'] - assert ([] == - [x for x in entry['links'] if x['rel'] == "http://librarysimplified.org/terms/rel/refresh"]) + feed = AcquisitionFeed( + self._db, + "test", + "url", + [work], + AdminAnnotator(None, self._default_library, test_mode=True), + ) + [entry] = feedparser.parse(str(feed))["entries"] + assert [] == [ + x + for x in entry["links"] + if x["rel"] == "http://librarysimplified.org/terms/rel/refresh" + ] # If we configure a metadata wrangler integration, the link appears. integration = self._external_integration( ExternalIntegration.METADATA_WRANGLER, goal=ExternalIntegration.METADATA_GOAL, - settings={ ExternalIntegration.URL: "http://metadata" }, - password="pw") + settings={ExternalIntegration.URL: "http://metadata"}, + password="pw", + ) integration.collections += [self._default_collection] - feed = AcquisitionFeed(self._db, "test", "url", [work], AdminAnnotator(None, self._default_library, test_mode=True)) - [entry] = feedparser.parse(str(feed))['entries'] - [refresh_link] = [x for x in entry['links'] if x['rel'] == "http://librarysimplified.org/terms/rel/refresh"] + feed = AcquisitionFeed( + self._db, + "test", + "url", + [work], + AdminAnnotator(None, self._default_library, test_mode=True), + ) + [entry] = feedparser.parse(str(feed))["entries"] + [refresh_link] = [ + x + for x in entry["links"] + if x["rel"] == "http://librarysimplified.org/terms/rel/refresh" + ] assert lp.identifier.identifier in refresh_link["href"] def test_feed_includes_suppress_link(self): @@ -72,30 +94,64 @@ def test_feed_includes_suppress_link(self): lp.suppressed = False self._db.commit() - feed = AcquisitionFeed(self._db, "test", "url", [work], AdminAnnotator(None, self._default_library, test_mode=True)) - [entry] = feedparser.parse(str(feed))['entries'] - [suppress_link] = [x for x in entry['links'] if x['rel'] == "http://librarysimplified.org/terms/rel/hide"] + feed = AcquisitionFeed( + self._db, + "test", + "url", + [work], + AdminAnnotator(None, self._default_library, test_mode=True), + ) + [entry] = feedparser.parse(str(feed))["entries"] + [suppress_link] = [ + x + for x in entry["links"] + if x["rel"] == "http://librarysimplified.org/terms/rel/hide" + ] assert lp.identifier.identifier in suppress_link["href"] - unsuppress_links = [x for x in entry['links'] if x['rel'] == "http://librarysimplified.org/terms/rel/restore"] + unsuppress_links = [ + x + for x in entry["links"] + if x["rel"] == "http://librarysimplified.org/terms/rel/restore" + ] assert 0 == len(unsuppress_links) lp.suppressed = True self._db.commit() - feed = AcquisitionFeed(self._db, "test", "url", [work], AdminAnnotator(None, self._default_library, test_mode=True)) - [entry] = feedparser.parse(str(feed))['entries'] - [unsuppress_link] = [x for x in entry['links'] if x['rel'] == "http://librarysimplified.org/terms/rel/restore"] + feed = AcquisitionFeed( + self._db, + "test", + "url", + [work], + AdminAnnotator(None, self._default_library, test_mode=True), + ) + [entry] = feedparser.parse(str(feed))["entries"] + [unsuppress_link] = [ + x + for x in entry["links"] + if x["rel"] == "http://librarysimplified.org/terms/rel/restore" + ] assert lp.identifier.identifier in unsuppress_link["href"] - suppress_links = [x for x in entry['links'] if x['rel'] == "http://librarysimplified.org/terms/rel/hide"] + suppress_links = [ + x + for x in entry["links"] + if x["rel"] == "http://librarysimplified.org/terms/rel/hide" + ] assert 0 == len(suppress_links) def test_feed_includes_edit_link(self): work = self._work(with_open_access_download=True) lp = work.license_pools[0] - feed = AcquisitionFeed(self._db, "test", "url", [work], AdminAnnotator(None, self._default_library, test_mode=True)) - [entry] = feedparser.parse(str(feed))['entries'] - [edit_link] = [x for x in entry['links'] if x['rel'] == "edit"] + feed = AcquisitionFeed( + self._db, + "test", + "url", + [work], + AdminAnnotator(None, self._default_library, test_mode=True), + ) + [entry] = feedparser.parse(str(feed))["entries"] + [edit_link] = [x for x in entry["links"] if x["rel"] == "edit"] assert lp.identifier.identifier in edit_link["href"] def test_feed_includes_change_cover_link(self): @@ -103,16 +159,28 @@ def test_feed_includes_change_cover_link(self): lp = work.license_pools[0] library = self._default_library - feed = AcquisitionFeed(self._db, "test", "url", [work], AdminAnnotator(None, library, test_mode=True)) - [entry] = feedparser.parse(str(feed))['entries'] + feed = AcquisitionFeed( + self._db, + "test", + "url", + [work], + AdminAnnotator(None, library, test_mode=True), + ) + [entry] = feedparser.parse(str(feed))["entries"] # Since there's no storage integration, the change cover link isn't included. - assert [] == [x for x in entry['links'] if x['rel'] == "http://librarysimplified.org/terms/rel/change_cover"] + assert [] == [ + x + for x in entry["links"] + if x["rel"] == "http://librarysimplified.org/terms/rel/change_cover" + ] # There is now a covers storage integration that is linked to the external # integration for a collection that the work is in. It will use that # covers mirror and the change cover link is included. - storage = self._external_integration(ExternalIntegration.S3, ExternalIntegration.STORAGE_GOAL) + storage = self._external_integration( + ExternalIntegration.S3, ExternalIntegration.STORAGE_GOAL + ) storage.username = "user" storage.password = "pass" @@ -121,20 +189,29 @@ def test_feed_includes_change_cover_link(self): external_integration_link = self._external_integration_link( integration=collection._external_integration, other_integration=storage, - purpose=purpose + purpose=purpose, ) library.collections.append(collection) work = self._work(with_open_access_download=True, collection=collection) lp = work.license_pools[0] - feed = AcquisitionFeed(self._db, "test", "url", [work], AdminAnnotator(None, library, test_mode=True)) - [entry] = feedparser.parse(str(feed))['entries'] + feed = AcquisitionFeed( + self._db, + "test", + "url", + [work], + AdminAnnotator(None, library, test_mode=True), + ) + [entry] = feedparser.parse(str(feed))["entries"] - [change_cover_link] = [x for x in entry['links'] if x['rel'] == "http://librarysimplified.org/terms/rel/change_cover"] + [change_cover_link] = [ + x + for x in entry["links"] + if x["rel"] == "http://librarysimplified.org/terms/rel/change_cover" + ] assert lp.identifier.identifier in change_cover_link["href"] def test_complaints_feed(self): - """Test the ability to show a paginated feed of works with complaints. - """ + """Test the ability to show a paginated feed of works with complaints.""" type = iter(Complaint.VALID_TYPES) type1 = next(type) @@ -144,42 +221,50 @@ def test_complaints_feed(self): "fiction work with complaint", language="eng", fiction=True, - with_open_access_download=True) + with_open_access_download=True, + ) work1_complaint1 = self._complaint( work1.license_pools[0], type1, "work1 complaint1 source", - "work1 complaint1 detail") + "work1 complaint1 detail", + ) work1_complaint2 = self._complaint( work1.license_pools[0], type1, "work1 complaint2 source", - "work1 complaint2 detail") + "work1 complaint2 detail", + ) work1_complaint3 = self._complaint( work1.license_pools[0], type2, "work1 complaint3 source", - "work1 complaint3 detail") + "work1 complaint3 detail", + ) work2 = self._work( "nonfiction work with complaint", language="eng", fiction=False, - with_open_access_download=True) + with_open_access_download=True, + ) work2_complaint1 = self._complaint( work2.license_pools[0], type2, "work2 complaint1 source", - "work2 complaint1 detail") + "work2 complaint1 detail", + ) work3 = self._work( "fiction work without complaint", language="eng", fiction=True, - with_open_access_download=True) + with_open_access_download=True, + ) work4 = self._work( "nonfiction work without complaint", language="eng", fiction=False, - with_open_access_download=True) + with_open_access_download=True, + ) facets = Facets.default(self._default_library) pagination = Pagination(size=1) @@ -187,42 +272,50 @@ def test_complaints_feed(self): def make_page(pagination): return AdminFeed.complaints( - library=self._default_library, title="Complaints", - url=self._url, annotator=annotator, - pagination=pagination + library=self._default_library, + title="Complaints", + url=self._url, + annotator=annotator, + pagination=pagination, ) first_page = make_page(pagination) parsed = feedparser.parse(str(first_page)) - assert 1 == len(parsed['entries']) - assert work1.title == parsed['entries'][0]['title'] + assert 1 == len(parsed["entries"]) + assert work1.title == parsed["entries"][0]["title"] # Verify that the entry has acquisition links. - links = parsed['entries'][0]['links'] - open_access_links = [l for l in links if l['rel'] == "http://opds-spec.org/acquisition/open-access"] + links = parsed["entries"][0]["links"] + open_access_links = [ + l + for l in links + if l["rel"] == "http://opds-spec.org/acquisition/open-access" + ] assert 1 == len(open_access_links) # Make sure the links are in place. - [start] = self.links(parsed, 'start') - assert annotator.groups_url(None) == start['href'] - assert annotator.top_level_title() == start['title'] + [start] = self.links(parsed, "start") + assert annotator.groups_url(None) == start["href"] + assert annotator.top_level_title() == start["title"] - [up] = self.links(parsed, 'up') - assert annotator.groups_url(None) == up['href'] - assert annotator.top_level_title() == up['title'] + [up] = self.links(parsed, "up") + assert annotator.groups_url(None) == up["href"] + assert annotator.top_level_title() == up["title"] - [next_link] = self.links(parsed, 'next') - assert annotator.complaints_url(facets, pagination.next_page) == next_link['href'] + [next_link] = self.links(parsed, "next") + assert ( + annotator.complaints_url(facets, pagination.next_page) == next_link["href"] + ) # This was the first page, so no previous link. - assert [] == self.links(parsed, 'previous') + assert [] == self.links(parsed, "previous") # Now get the second page and make sure it has a 'previous' link. second_page = make_page(pagination.next_page) parsed = feedparser.parse(str(second_page)) - [previous] = self.links(parsed, 'previous') - assert annotator.complaints_url(facets, pagination) == previous['href'] - assert 1 == len(parsed['entries']) - assert work2.title == parsed['entries'][0]['title'] + [previous] = self.links(parsed, "previous") + assert annotator.complaints_url(facets, pagination) == previous["href"] + assert 1 == len(parsed["entries"]) + assert work2.title == parsed["entries"][0]["title"] def test_suppressed_feed(self): # Test the ability to show a paginated feed of suppressed works. @@ -245,51 +338,52 @@ def test_suppressed_feed(self): def make_page(pagination): return AdminFeed.suppressed( - _db=self._db, title="Hidden works", - url=self._url, annotator=annotator, - pagination=pagination + _db=self._db, + title="Hidden works", + url=self._url, + annotator=annotator, + pagination=pagination, ) first_page = make_page(pagination) parsed = feedparser.parse(str(first_page)) - assert 1 == len(parsed['entries']) - assert parsed['entries'][0].title in titles - titles.remove(parsed['entries'][0].title) + assert 1 == len(parsed["entries"]) + assert parsed["entries"][0].title in titles + titles.remove(parsed["entries"][0].title) [remaining_title] = titles # Make sure the links are in place. - [start] = self.links(parsed, 'start') - assert annotator.groups_url(None) == start['href'] - assert annotator.top_level_title() == start['title'] + [start] = self.links(parsed, "start") + assert annotator.groups_url(None) == start["href"] + assert annotator.top_level_title() == start["title"] - [up] = self.links(parsed, 'up') - assert annotator.groups_url(None) == up['href'] - assert annotator.top_level_title() == up['title'] + [up] = self.links(parsed, "up") + assert annotator.groups_url(None) == up["href"] + assert annotator.top_level_title() == up["title"] - [next_link] = self.links(parsed, 'next') - assert annotator.suppressed_url(pagination.next_page) == next_link['href'] + [next_link] = self.links(parsed, "next") + assert annotator.suppressed_url(pagination.next_page) == next_link["href"] # This was the first page, so no previous link. - assert [] == self.links(parsed, 'previous') + assert [] == self.links(parsed, "previous") # Now get the second page and make sure it has a 'previous' link. second_page = make_page(pagination.next_page) parsed = feedparser.parse(str(second_page)) - [previous] = self.links(parsed, 'previous') - assert annotator.suppressed_url(pagination) == previous['href'] - assert 1 == len(parsed['entries']) - assert remaining_title == parsed['entries'][0]['title'] + [previous] = self.links(parsed, "previous") + assert annotator.suppressed_url(pagination) == previous["href"] + assert 1 == len(parsed["entries"]) + assert remaining_title == parsed["entries"][0]["title"] # The third page is empty. third_page = make_page(pagination.next_page.next_page) parsed = feedparser.parse(str(third_page)) - [previous] = self.links(parsed, 'previous') - assert annotator.suppressed_url(pagination.next_page) == previous['href'] - assert 0 == len(parsed['entries']) + [previous] = self.links(parsed, "previous") + assert annotator.suppressed_url(pagination.next_page) == previous["href"] + assert 0 == len(parsed["entries"]) class MockAnnotator(AdminAnnotator): - def __init__(self, library): super(MockAnnotator, self).__init__(None, library, test_mode=True) @@ -302,21 +396,20 @@ def groups_url(self, lane): def complaints_url(self, facets, pagination): base = "http://complaints/" - sep = '?' + sep = "?" if facets: base += sep + facets.query_string - sep = '&' + sep = "&" if pagination: base += sep + pagination.query_string return base def suppressed_url(self, pagination): base = "http://complaints/" - sep = '?' + sep = "?" if pagination: base += sep + pagination.query_string return base def annotate_feed(self, feed): super(MockAnnotator, self).annotate_feed(feed) - diff --git a/tests/admin/test_password_admin_authentication_provider.py b/tests/admin/test_password_admin_authentication_provider.py index 66b51dabfb..13c54285a3 100644 --- a/tests/admin/test_password_admin_authentication_provider.py +++ b/tests/admin/test_password_admin_authentication_provider.py @@ -1,13 +1,12 @@ -from core.testing import DatabaseTest -from api.admin.problem_details import * -from api.admin.password_admin_authentication_provider import PasswordAdminAuthenticationProvider -from core.model import ( - Admin, - create, +from api.admin.password_admin_authentication_provider import ( + PasswordAdminAuthenticationProvider, ) +from api.admin.problem_details import * +from core.model import Admin, create +from core.testing import DatabaseTest -class TestPasswordAdminAuthenticationProvider(DatabaseTest): +class TestPasswordAdminAuthenticationProvider(DatabaseTest): def test_sign_in(self): password_auth = PasswordAdminAuthenticationProvider(None) @@ -21,32 +20,45 @@ def test_sign_in(self): admin3, ignore = create(self._db, Admin, email="admin3@nypl.org") # Both admins with passwords can sign in. - admin_details, redirect = password_auth.sign_in(self._db, dict(email="admin1@nypl.org", password="pass1", redirect="foo")) + admin_details, redirect = password_auth.sign_in( + self._db, dict(email="admin1@nypl.org", password="pass1", redirect="foo") + ) assert "admin1@nypl.org" == admin_details.get("email") assert PasswordAdminAuthenticationProvider.NAME == admin_details.get("type") assert "foo" == redirect - admin_details, redirect = password_auth.sign_in(self._db, dict(email="admin2@nypl.org", password="pass2", redirect="foo")) + admin_details, redirect = password_auth.sign_in( + self._db, dict(email="admin2@nypl.org", password="pass2", redirect="foo") + ) assert "admin2@nypl.org" == admin_details.get("email") assert PasswordAdminAuthenticationProvider.NAME == admin_details.get("type") assert "foo" == redirect # An admin can't sign in with an incorrect password.. - admin_details, redirect = password_auth.sign_in(self._db, dict(email="admin1@nypl.org", password="not-the-password", redirect="foo")) + admin_details, redirect = password_auth.sign_in( + self._db, + dict(email="admin1@nypl.org", password="not-the-password", redirect="foo"), + ) assert INVALID_ADMIN_CREDENTIALS == admin_details assert None == redirect # An admin can't sign in with a different admin's password. - admin_details, redirect = password_auth.sign_in(self._db, dict(email="admin1@nypl.org", password="pass2", redirect="foo")) + admin_details, redirect = password_auth.sign_in( + self._db, dict(email="admin1@nypl.org", password="pass2", redirect="foo") + ) assert INVALID_ADMIN_CREDENTIALS == admin_details assert None == redirect # The admin with no password can't sign in. - admin_details, redirect = password_auth.sign_in(self._db, dict(email="admin3@nypl.org", redirect="foo")) + admin_details, redirect = password_auth.sign_in( + self._db, dict(email="admin3@nypl.org", redirect="foo") + ) assert INVALID_ADMIN_CREDENTIALS == admin_details assert None == redirect # An admin email that's not in the db at all can't sign in. - admin_details, redirect = password_auth.sign_in(self._db, dict(email="admin4@nypl.org", password="pass1", redirect="foo")) + admin_details, redirect = password_auth.sign_in( + self._db, dict(email="admin4@nypl.org", password="pass1", redirect="foo") + ) assert INVALID_ADMIN_CREDENTIALS == admin_details assert None == redirect diff --git a/tests/admin/test_routes.py b/tests/admin/test_routes.py index 3599290018..7c0a6faf3a 100644 --- a/tests/admin/test_routes.py +++ b/tests/admin/test_routes.py @@ -1,51 +1,49 @@ import contextlib import logging +import os + import flask from flask import Response from werkzeug.exceptions import MethodNotAllowed -import os - -from core.app_server import ErrorHandler -from core.model import ( - Admin, - ConfigurationSetting, - get_one_or_create -) from api import app from api import routes as api_routes +from api.admin import routes +from api.admin.controller import AdminController, setup_admin_controllers +from api.admin.problem_details import * from api.config import Configuration from api.controller import CirculationManager -from api.admin.controller import AdminController -from api.admin.controller import setup_admin_controllers -from api.admin.problem_details import * -from api.routes import ( - exception_handler, - h as error_handler_object, -) -from api.admin import routes +from api.routes import exception_handler +from api.routes import h as error_handler_object +from core.app_server import ErrorHandler +from core.model import Admin, ConfigurationSetting, get_one_or_create + from ..test_controller import ControllerTest from ..test_routes import ( MockApp, - MockManager, MockController, + MockManager, RouteTest, RouteTestFixtures, ) + class MockAdminApp(object): """Pretends to be a Flask application with a configured CirculationManager and Admin routes. """ + def __init__(self): self.manager = MockAdminManager() + class MockAdminManager(MockManager): def __getattr__(self, controller_name): return self._cache.setdefault( controller_name, MockAdminController(controller_name) ) + class MockAdminController(MockController): AUTHENTICATED_ADMIN = "i am a mock admin" @@ -59,8 +57,7 @@ def authenticated_admin_from_request(self): return INVALID_ADMIN_CREDENTIALS else: return Response( - "authenticated_admin_from_request called without authorizing", - 401 + "authenticated_admin_from_request called without authorizing", 401 ) def get_csrf_token(self): @@ -113,11 +110,11 @@ def setup_method(self): self.manager = app.manager self.original_app = self.routes.app self.original_api_app = self.api_routes.app - self.resolver = self.original_app.url_map.bind('', '/') + self.resolver = self.original_app.url_map.bind("", "/") # For convenience, set self.controller to a specific controller # whose routes are being tested. - controller_name = getattr(self, 'CONTROLLER_NAME', None) + controller_name = getattr(self, "CONTROLLER_NAME", None) if controller_name: self.controller = getattr(self.manager, controller_name) @@ -145,12 +142,14 @@ def assert_authenticated_request_calls(self, url, method, *args, **kwargs): """ authentication_required = kwargs.pop("authentication_required", True) - http_method = kwargs.pop('http_method', 'GET') + http_method = kwargs.pop("http_method", "GET") response = self.request(url, http_method) if authentication_required: assert 401 == response.status_code - assert ("authenticated_admin_from_request called without authorizing" == - response.get_data(as_text=True)) + assert ( + "authenticated_admin_from_request called without authorizing" + == response.get_data(as_text=True) + ) else: assert 200 == response.status_code @@ -158,10 +157,10 @@ def assert_authenticated_request_calls(self, url, method, *args, **kwargs): # will succeed, and try again. self.manager.admin_sign_in_controller.authenticated = True try: - kwargs['http_method'] = http_method + kwargs["http_method"] = http_method # The file response case is specific to the bulk circulation # events route where a CSV file is returned. - if kwargs.get('file_response', None) is not None: + if kwargs.get("file_response", None) is not None: self.assert_file_response(url, *args, **kwargs) else: self.assert_request_calls(url, method, *args, **kwargs) @@ -171,10 +170,10 @@ def assert_authenticated_request_calls(self, url, method, *args, **kwargs): self.manager.admin_sign_in_controller.authenticated = False def assert_file_response(self, url, *args, **kwargs): - http_method = kwargs.pop('http_method', 'GET') + http_method = kwargs.pop("http_method", "GET") response = self.request(url, http_method) - assert response.headers['Content-type'] == 'text/csv' + assert response.headers["Content-type"] == "text/csv" def assert_redirect_call(self, url, *args, **kwargs): @@ -182,7 +181,7 @@ def assert_redirect_call(self, url, *args, **kwargs): # is authenticated and there is a csrf token. self.manager.admin_sign_in_controller.csrf_token = True self.manager.admin_sign_in_controller.authenticated = True - http_method = kwargs.pop('http_method', 'GET') + http_method = kwargs.pop("http_method", "GET") response = self.request(url, http_method) # A Flask template string is returned. @@ -226,41 +225,44 @@ class TestAdminSignIn(AdminRouteTest): CONTROLLER_NAME = "admin_sign_in_controller" def test_google_auth_callback(self): - url = '/admin/GoogleAuth/callback' + url = "/admin/GoogleAuth/callback" self.assert_request_calls(url, self.controller.redirect_after_google_sign_in) def test_sign_in_with_password(self): - url = '/admin/sign_in_with_password' - self.assert_request_calls(url, self.controller.password_sign_in, http_method='POST') + url = "/admin/sign_in_with_password" + self.assert_request_calls( + url, self.controller.password_sign_in, http_method="POST" + ) - self.assert_supported_methods(url, 'POST') + self.assert_supported_methods(url, "POST") def test_sign_in(self): - url = '/admin/sign_in' + url = "/admin/sign_in" self.assert_request_calls(url, self.controller.sign_in) def test_sign_out(self): - url = '/admin/sign_out' + url = "/admin/sign_out" self.assert_authenticated_request_calls(url, self.controller.sign_out) def test_change_password(self): - url = '/admin/change_password' + url = "/admin/change_password" self.assert_authenticated_request_calls( - url, self.controller.change_password, http_method='POST' + url, self.controller.change_password, http_method="POST" ) - self.assert_supported_methods(url, 'POST') + self.assert_supported_methods(url, "POST") def test_sign_in_again(self): - url = '/admin/sign_in_again' + url = "/admin/sign_in_again" self.assert_redirect_call(url) def test_redirect(self): - url = '/admin' + url = "/admin" response = self.request(url) assert 302 == response.status_code assert "Redirecting..." in response.get_data(as_text=True) + class TestAdminWork(AdminRouteTest): CONTROLLER_NAME = "admin_work_controller" @@ -268,86 +270,113 @@ class TestAdminWork(AdminRouteTest): def test_details(self): url = "/admin/works//an/identifier" self.assert_authenticated_request_calls( - url, self.controller.details, '', 'an/identifier' + url, self.controller.details, "", "an/identifier" ) - self.assert_supported_methods(url, 'GET') + self.assert_supported_methods(url, "GET") def test_classifications(self): url = "/admin/works//an/identifier/classifications" self.assert_authenticated_request_calls( - url, self.controller.classifications, '', 'an/identifier' + url, self.controller.classifications, "", "an/identifier" ) - self.assert_supported_methods(url, 'GET') + self.assert_supported_methods(url, "GET") def test_preview_book_cover(self): url = "/admin/works//an/identifier/preview_book_cover" self.assert_authenticated_request_calls( - url, self.controller.preview_book_cover, '', 'an/identifier', - http_method='POST' + url, + self.controller.preview_book_cover, + "", + "an/identifier", + http_method="POST", ) def test_change_book_cover(self): url = "/admin/works//an/identifier/change_book_cover" self.assert_authenticated_request_calls( - url, self.controller.change_book_cover, '', - 'an/identifier', http_method='POST' + url, + self.controller.change_book_cover, + "", + "an/identifier", + http_method="POST", ) def test_complaints(self): url = "/admin/works//an/identifier/complaints" self.assert_authenticated_request_calls( - url, self.controller.complaints, '', 'an/identifier' + url, self.controller.complaints, "", "an/identifier" ) - self.assert_supported_methods(url, 'GET') + self.assert_supported_methods(url, "GET") def test_custom_lists(self): url = "/admin/works//an/identifier/lists" self.assert_authenticated_request_calls( - url, self.controller.custom_lists, '', 'an/identifier', - http_method='POST' + url, + self.controller.custom_lists, + "", + "an/identifier", + http_method="POST", ) - self.assert_supported_methods(url, 'GET', 'POST') + self.assert_supported_methods(url, "GET", "POST") def test_edit(self): url = "/admin/works//an/identifier/edit" self.assert_authenticated_request_calls( - url, self.controller.edit, '', 'an/identifier', - http_method='POST' + url, + self.controller.edit, + "", + "an/identifier", + http_method="POST", ) def test_suppress(self): url = "/admin/works//an/identifier/suppress" self.assert_authenticated_request_calls( - url, self.controller.suppress, '', 'an/identifier', - http_method='POST' + url, + self.controller.suppress, + "", + "an/identifier", + http_method="POST", ) def test_unsuppress(self): url = "/admin/works//an/identifier/unsuppress" self.assert_authenticated_request_calls( - url, self.controller.unsuppress, '', 'an/identifier', - http_method='POST' + url, + self.controller.unsuppress, + "", + "an/identifier", + http_method="POST", ) def test_refresh_metadata(self): url = "/admin/works//an/identifier/refresh" self.assert_authenticated_request_calls( - url, self.controller.refresh_metadata, '', 'an/identifier', - http_method='POST' + url, + self.controller.refresh_metadata, + "", + "an/identifier", + http_method="POST", ) def test_resolve_complaints(self): url = "/admin/works//an/identifier/resolve_complaints" self.assert_authenticated_request_calls( - url, self.controller.resolve_complaints, '', 'an/identifier', - http_method='POST' + url, + self.controller.resolve_complaints, + "", + "an/identifier", + http_method="POST", ) def test_edit_classifications(self): url = "/admin/works//an/identifier/edit_classifications" self.assert_authenticated_request_calls( - url, self.controller.edit_classifications, '', 'an/identifier', - http_method='POST' + url, + self.controller.edit_classifications, + "", + "an/identifier", + http_method="POST", ) def test_roles(self): @@ -366,6 +395,7 @@ def test_right_status(self): url = "/admin/rights_status" self.assert_request_calls(url, self.controller.rights_status) + class TestAdminFeed(AdminRouteTest): CONTROLLER_NAME = "admin_feed_controller" @@ -382,6 +412,7 @@ def test_genres(self): url = "/admin/genres" self.assert_authenticated_request_calls(url, self.controller.genres) + class TestAdminDashboard(AdminRouteTest): CONTROLLER_NAME = "admin_dashboard_controller" @@ -389,8 +420,7 @@ class TestAdminDashboard(AdminRouteTest): def test_bulk_circulation_events(self): url = "/admin/bulk_circulation_events" self.assert_authenticated_request_calls( - url, self.controller.bulk_circulation_events, - file_response=True + url, self.controller.bulk_circulation_events, file_response=True ) def test_circulation_events(self): @@ -401,23 +431,23 @@ def test_stats(self): url = "/admin/stats" self.assert_authenticated_request_calls(url, self.controller.stats) + class TestAdminLibrarySettings(AdminRouteTest): CONTROLLER_NAME = "admin_library_settings_controller" def test_process_libraries(self): url = "/admin/libraries" - self.assert_authenticated_request_calls( - url, self.controller.process_libraries - ) - self.assert_supported_methods(url, 'GET', 'POST') + self.assert_authenticated_request_calls(url, self.controller.process_libraries) + self.assert_supported_methods(url, "GET", "POST") def test_delete(self): url = "/admin/library/" self.assert_authenticated_request_calls( - url, self.controller.process_delete, '', http_method='DELETE' + url, self.controller.process_delete, "", http_method="DELETE" ) - self.assert_supported_methods(url, 'DELETE') + self.assert_supported_methods(url, "DELETE") + class TestAdminCollectionSettings(AdminRouteTest): @@ -428,14 +458,15 @@ def test_process_get(self): self.assert_authenticated_request_calls( url, self.controller.process_collections ) - self.assert_supported_methods(url, 'GET', 'POST') + self.assert_supported_methods(url, "GET", "POST") def test_process_post(self): url = "/admin/collection/" self.assert_authenticated_request_calls( - url, self.controller.process_delete, '', http_method='DELETE' + url, self.controller.process_delete, "", http_method="DELETE" ) - self.assert_supported_methods(url, 'DELETE') + self.assert_supported_methods(url, "DELETE") + class TestAdminCollectionSelfTests(AdminRouteTest): @@ -444,9 +475,10 @@ class TestAdminCollectionSelfTests(AdminRouteTest): def test_process_collection_self_tests(self): url = "/admin/collection_self_tests/" self.assert_authenticated_request_calls( - url, self.controller.process_collection_self_tests, '' + url, self.controller.process_collection_self_tests, "" ) + class TestAdminCollectionLibraryRegistrations(AdminRouteTest): CONTROLLER_NAME = "admin_collection_library_registrations_controller" @@ -456,7 +488,8 @@ def test_process_collection_library_registrations(self): self.assert_authenticated_request_calls( url, self.controller.process_collection_library_registrations ) - self.assert_supported_methods(url, 'GET', 'POST') + self.assert_supported_methods(url, "GET", "POST") + class TestAdminAuthServices(AdminRouteTest): @@ -467,14 +500,15 @@ def test_process_admin_auth_services(self): self.assert_authenticated_request_calls( url, self.controller.process_admin_auth_services ) - self.assert_supported_methods(url, 'GET', 'POST') + self.assert_supported_methods(url, "GET", "POST") def test_process_delete(self): url = "/admin/admin_auth_service/" self.assert_authenticated_request_calls( - url, self.controller.process_delete, '', http_method='DELETE' + url, self.controller.process_delete, "", http_method="DELETE" ) - self.assert_supported_methods(url, 'DELETE') + self.assert_supported_methods(url, "DELETE") + class TestAdminIndividualAdminSettings(AdminRouteTest): @@ -485,14 +519,15 @@ def test_process_individual_admins(self): self.assert_authenticated_request_calls( url, self.controller.process_individual_admins ) - self.assert_supported_methods(url, 'GET', 'POST') + self.assert_supported_methods(url, "GET", "POST") def test_process_delete(self): url = "/admin/individual_admin/" self.assert_authenticated_request_calls( - url, self.controller.process_delete, '', http_method='DELETE' + url, self.controller.process_delete, "", http_method="DELETE" ) - self.assert_supported_methods(url, 'DELETE') + self.assert_supported_methods(url, "DELETE") + class TestAdminPatronAuthServices(AdminRouteTest): @@ -503,14 +538,15 @@ def test_process_patron_auth_services(self): self.assert_authenticated_request_calls( url, self.controller.process_patron_auth_services ) - self.assert_supported_methods(url, 'GET', 'POST') + self.assert_supported_methods(url, "GET", "POST") def test_process_delete(self): url = "/admin/patron_auth_service/" self.assert_authenticated_request_calls( - url, self.controller.process_delete, '', http_method='DELETE' + url, self.controller.process_delete, "", http_method="DELETE" ) - self.assert_supported_methods(url, 'DELETE') + self.assert_supported_methods(url, "DELETE") + class TestAdminPatronAuthServicesSelfTests(AdminRouteTest): @@ -519,9 +555,10 @@ class TestAdminPatronAuthServicesSelfTests(AdminRouteTest): def test_process_patron_auth_service_self_tests(self): url = "/admin/patron_auth_service_self_tests/" self.assert_authenticated_request_calls( - url, self.controller.process_patron_auth_service_self_tests, '' + url, self.controller.process_patron_auth_service_self_tests, "" ) - self.assert_supported_methods(url, 'GET', 'POST') + self.assert_supported_methods(url, "GET", "POST") + class TestAdminPatron(AdminRouteTest): @@ -530,16 +567,16 @@ class TestAdminPatron(AdminRouteTest): def test_lookup_patron(self): url = "/admin/manage_patrons" self.assert_authenticated_request_calls( - url, self.controller.lookup_patron, http_method='POST' + url, self.controller.lookup_patron, http_method="POST" ) - self.assert_supported_methods(url, 'POST') + self.assert_supported_methods(url, "POST") def test_reset_adobe_id(self): url = "/admin/manage_patrons/reset_adobe_id" self.assert_authenticated_request_calls( - url, self.controller.reset_adobe_id, http_method='POST' + url, self.controller.reset_adobe_id, http_method="POST" ) - self.assert_supported_methods(url, 'POST') + self.assert_supported_methods(url, "POST") class TestAdminMetadataServices(AdminRouteTest): @@ -551,14 +588,15 @@ def test_process_metadata_services(self): self.assert_authenticated_request_calls( url, self.controller.process_metadata_services ) - self.assert_supported_methods(url, 'GET', 'POST') + self.assert_supported_methods(url, "GET", "POST") def test_process_delete(self): url = "/admin/metadata_service/" self.assert_authenticated_request_calls( - url, self.controller.process_delete, '', http_method='DELETE' + url, self.controller.process_delete, "", http_method="DELETE" ) - self.assert_supported_methods(url, 'DELETE') + self.assert_supported_methods(url, "DELETE") + class TestAdminAnalyticsServices(AdminRouteTest): @@ -569,14 +607,15 @@ def test_process_analytics_services(self): self.assert_authenticated_request_calls( url, self.controller.process_analytics_services ) - self.assert_supported_methods(url, 'GET', 'POST') + self.assert_supported_methods(url, "GET", "POST") def test_process_delete(self): url = "/admin/analytics_service/" self.assert_authenticated_request_calls( - url, self.controller.process_delete, '', http_method='DELETE' + url, self.controller.process_delete, "", http_method="DELETE" ) - self.assert_supported_methods(url, 'DELETE') + self.assert_supported_methods(url, "DELETE") + class TestAdminCDNServices(AdminRouteTest): @@ -587,14 +626,15 @@ def test_process_cdn_services(self): self.assert_authenticated_request_calls( url, self.controller.process_cdn_services ) - self.assert_supported_methods(url, 'GET', 'POST') + self.assert_supported_methods(url, "GET", "POST") def test_process_delete(self): url = "/admin/cdn_service/" self.assert_authenticated_request_calls( - url, self.controller.process_delete, '', http_method='DELETE' + url, self.controller.process_delete, "", http_method="DELETE" ) - self.assert_supported_methods(url, 'DELETE') + self.assert_supported_methods(url, "DELETE") + class TestAdminSearchServices(AdminRouteTest): @@ -602,17 +642,16 @@ class TestAdminSearchServices(AdminRouteTest): def test_process_services(self): url = "/admin/search_services" - self.assert_authenticated_request_calls( - url, self.controller.process_services - ) - self.assert_supported_methods(url, 'GET', 'POST') + self.assert_authenticated_request_calls(url, self.controller.process_services) + self.assert_supported_methods(url, "GET", "POST") def test_process_delete(self): url = "/admin/search_service/" self.assert_authenticated_request_calls( - url, self.controller.process_delete, '', http_method='DELETE' + url, self.controller.process_delete, "", http_method="DELETE" ) - self.assert_supported_methods(url, 'DELETE') + self.assert_supported_methods(url, "DELETE") + class TestAdminSearchServicesSelfTests(AdminRouteTest): @@ -621,9 +660,10 @@ class TestAdminSearchServicesSelfTests(AdminRouteTest): def test_process_search_service_self_tests(self): url = "/admin/search_service_self_tests/" self.assert_authenticated_request_calls( - url, self.controller.process_search_service_self_tests, '' + url, self.controller.process_search_service_self_tests, "" ) - self.assert_supported_methods(url, 'GET', 'POST') + self.assert_supported_methods(url, "GET", "POST") + class TestAdminStorageServices(AdminRouteTest): @@ -631,17 +671,16 @@ class TestAdminStorageServices(AdminRouteTest): def test_process_services(self): url = "/admin/storage_services" - self.assert_authenticated_request_calls( - url, self.controller.process_services - ) - self.assert_supported_methods(url, 'GET', 'POST') + self.assert_authenticated_request_calls(url, self.controller.process_services) + self.assert_supported_methods(url, "GET", "POST") def test_process_delete(self): url = "/admin/storage_service/" self.assert_authenticated_request_calls( - url, self.controller.process_delete, '', http_method='DELETE' + url, self.controller.process_delete, "", http_method="DELETE" ) - self.assert_supported_methods(url, 'DELETE') + self.assert_supported_methods(url, "DELETE") + class TestAdminCatalogServices(AdminRouteTest): @@ -652,14 +691,15 @@ def test_process_catalog_services(self): self.assert_authenticated_request_calls( url, self.controller.process_catalog_services ) - self.assert_supported_methods(url, 'GET', 'POST') + self.assert_supported_methods(url, "GET", "POST") def test_process_delete(self): url = "/admin/catalog_service/" self.assert_authenticated_request_calls( - url, self.controller.process_delete, '', http_method='DELETE' + url, self.controller.process_delete, "", http_method="DELETE" ) - self.assert_supported_methods(url, 'DELETE') + self.assert_supported_methods(url, "DELETE") + class TestAdminDiscoveryServices(AdminRouteTest): @@ -670,14 +710,15 @@ def test_process_discovery_services(self): self.assert_authenticated_request_calls( url, self.controller.process_discovery_services ) - self.assert_supported_methods(url, 'GET', 'POST') + self.assert_supported_methods(url, "GET", "POST") def test_process_delete(self): url = "/admin/discovery_service/" self.assert_authenticated_request_calls( - url, self.controller.process_delete, '', http_method='DELETE' + url, self.controller.process_delete, "", http_method="DELETE" ) - self.assert_supported_methods(url, 'DELETE') + self.assert_supported_methods(url, "DELETE") + class TestAdminSitewideServices(AdminRouteTest): @@ -685,17 +726,16 @@ class TestAdminSitewideServices(AdminRouteTest): def test_process_services(self): url = "/admin/sitewide_settings" - self.assert_authenticated_request_calls( - url, self.controller.process_services - ) - self.assert_supported_methods(url, 'GET', 'POST') + self.assert_authenticated_request_calls(url, self.controller.process_services) + self.assert_supported_methods(url, "GET", "POST") def test_process_delete(self): url = "/admin/sitewide_setting/" self.assert_authenticated_request_calls( - url, self.controller.process_delete, '', http_method='DELETE' + url, self.controller.process_delete, "", http_method="DELETE" ) - self.assert_supported_methods(url, 'DELETE') + self.assert_supported_methods(url, "DELETE") + class TestAdminLoggingServices(AdminRouteTest): @@ -703,17 +743,16 @@ class TestAdminLoggingServices(AdminRouteTest): def test_process_services(self): url = "/admin/logging_services" - self.assert_authenticated_request_calls( - url, self.controller.process_services - ) - self.assert_supported_methods(url, 'GET', 'POST') + self.assert_authenticated_request_calls(url, self.controller.process_services) + self.assert_supported_methods(url, "GET", "POST") def test_process_delete(self): url = "/admin/logging_service/" self.assert_authenticated_request_calls( - url, self.controller.process_delete, '', http_method='DELETE' + url, self.controller.process_delete, "", http_method="DELETE" ) - self.assert_supported_methods(url, 'DELETE') + self.assert_supported_methods(url, "DELETE") + class TestAdminDiscoveryServiceLibraryRegistrations(AdminRouteTest): @@ -724,7 +763,8 @@ def test_process_discovery_service_library_registrations(self): self.assert_authenticated_request_calls( url, self.controller.process_discovery_service_library_registrations ) - self.assert_supported_methods(url, 'GET', 'POST') + self.assert_supported_methods(url, "GET", "POST") + class TestAdminCustomListsServices(AdminRouteTest): @@ -732,17 +772,15 @@ class TestAdminCustomListsServices(AdminRouteTest): def test_custom_lists(self): url = "/admin/custom_lists" - self.assert_authenticated_request_calls( - url, self.controller.custom_lists - ) - self.assert_supported_methods(url, 'GET', 'POST') + self.assert_authenticated_request_calls(url, self.controller.custom_lists) + self.assert_supported_methods(url, "GET", "POST") def test_custom_list(self): url = "/admin/custom_list/" self.assert_authenticated_request_calls( - url, self.controller.custom_list, '' + url, self.controller.custom_list, "" ) - self.assert_supported_methods(url, 'GET', 'POST', 'DELETE') + self.assert_supported_methods(url, "GET", "POST", "DELETE") class TestAdminLanes(AdminRouteTest): @@ -752,42 +790,43 @@ class TestAdminLanes(AdminRouteTest): def test_lanes(self): url = "/admin/lanes" self.assert_authenticated_request_calls(url, self.controller.lanes) - self.assert_supported_methods(url, 'GET', 'POST') + self.assert_supported_methods(url, "GET", "POST") def test_lane(self): url = "/admin/lane/" self.assert_authenticated_request_calls( - url, self.controller.lane, '', http_method='DELETE' + url, self.controller.lane, "", http_method="DELETE" ) - self.assert_supported_methods(url, 'DELETE') + self.assert_supported_methods(url, "DELETE") def test_show_lane(self): url = "/admin/lane//show" self.assert_authenticated_request_calls( - url, self.controller.show_lane, '', http_method='POST' + url, self.controller.show_lane, "", http_method="POST" ) - self.assert_supported_methods(url, 'POST') + self.assert_supported_methods(url, "POST") def test_hide_lane(self): url = "/admin/lane//hide" self.assert_authenticated_request_calls( - url, self.controller.hide_lane, '', http_method='POST' + url, self.controller.hide_lane, "", http_method="POST" ) - self.assert_supported_methods(url, 'POST') + self.assert_supported_methods(url, "POST") def test_reset(self): url = "/admin/lanes/reset" self.assert_authenticated_request_calls( - url, self.controller.reset, http_method='POST' + url, self.controller.reset, http_method="POST" ) - self.assert_supported_methods(url, 'POST') + self.assert_supported_methods(url, "POST") def test_change_order(self): url = "/admin/lanes/change_order" self.assert_authenticated_request_calls( - url, self.controller.change_order, http_method='POST' + url, self.controller.change_order, http_method="POST" ) - self.assert_supported_methods(url, 'POST') + self.assert_supported_methods(url, "POST") + class TestTimestamps(AdminRouteTest): @@ -797,15 +836,14 @@ def test_diagnostics(self): url = "/admin/diagnostics" self.assert_authenticated_request_calls(url, self.controller.diagnostics) + class TestAdminView(AdminRouteTest): CONTROLLER_NAME = "admin_view_controller" def test_admin_view(self): url = "/admin/web/" - self.assert_request_calls( - url, self.controller, None, None, path=None - ) + self.assert_request_calls(url, self.controller, None, None, path=None) url = "/admin/web/collection/a/collection/book/a/book" self.assert_request_calls( @@ -813,19 +851,14 @@ def test_admin_view(self): ) url = "/admin/web/collection/a/collection" - self.assert_request_calls( - url, self.controller, "a/collection", None, path=None - ) + self.assert_request_calls(url, self.controller, "a/collection", None, path=None) url = "/admin/web/book/a/book" - self.assert_request_calls( - url, self.controller, None, "a/book", path=None - ) + self.assert_request_calls(url, self.controller, None, "a/book", path=None) url = "/admin/web/a/path" - self.assert_request_calls( - url, self.controller, None, None, path="a/path" - ) + self.assert_request_calls(url, self.controller, None, None, path="a/path") + class TestAdminStatic(AdminRouteTest): @@ -840,7 +873,7 @@ def test_static_file(self): os.path.join( os.path.abspath(os.path.dirname(__file__)), "../..", - "api/admin/node_modules/@thepalaceproject/circulation-admin/dist" + "api/admin/node_modules/@thepalaceproject/circulation-admin/dist", ) ) diff --git a/tests/admin/test_validator.py b/tests/admin/test_validator.py index 3a8638aa1d..b1957cd74b 100644 --- a/tests/admin/test_validator.py +++ b/tests/admin/test_validator.py @@ -1,10 +1,9 @@ from io import StringIO - from parameterized import parameterized from werkzeug.datastructures import MultiDict -from api.admin.validator import Validator, PatronAuthenticationValidatorFactory +from api.admin.validator import PatronAuthenticationValidatorFactory, Validator from api.config import Configuration from api.shared_collection import BaseSharedCollectionAPI from tests.admin.fixtures.dummy_validator import DummyAuthenticationProviderValidator @@ -17,18 +16,26 @@ def test_validate_email(self): # One valid input from form form = MultiDict([("help-email", valid)]) - response = Validator().validate_email(Configuration.LIBRARY_SETTINGS, {"form": form}) + response = Validator().validate_email( + Configuration.LIBRARY_SETTINGS, {"form": form} + ) assert response == None # One invalid input from form form = MultiDict([("help-email", invalid)]) - response = Validator().validate_email(Configuration.LIBRARY_SETTINGS, {"form": form}) + response = Validator().validate_email( + Configuration.LIBRARY_SETTINGS, {"form": form} + ) assert response.detail == '"invalid_format" is not a valid email address.' assert response.status_code == 400 # One valid and one invalid input from form - form = MultiDict([("help-email", valid), ("configuration_contact_email_address", invalid)]) - response = Validator().validate_email(Configuration.LIBRARY_SETTINGS, {"form": form}) + form = MultiDict( + [("help-email", valid), ("configuration_contact_email_address", invalid)] + ) + response = Validator().validate_email( + Configuration.LIBRARY_SETTINGS, {"form": form} + ) assert response.detail == '"invalid_format" is not a valid email address.' assert response.status_code == 400 @@ -42,18 +49,24 @@ def test_validate_email(self): assert response.status_code == 400 # Two valid in a list - form = MultiDict([('help-email', valid), ('help-email', 'valid2@email.com')]) - response = Validator().validate_email(Configuration.LIBRARY_SETTINGS, {"form": form}) + form = MultiDict([("help-email", valid), ("help-email", "valid2@email.com")]) + response = Validator().validate_email( + Configuration.LIBRARY_SETTINGS, {"form": form} + ) assert response == None # One valid and one empty in a list - form = MultiDict([('help-email', valid), ('help-email', '')]) - response = Validator().validate_email(Configuration.LIBRARY_SETTINGS, {"form": form}) + form = MultiDict([("help-email", valid), ("help-email", "")]) + response = Validator().validate_email( + Configuration.LIBRARY_SETTINGS, {"form": form} + ) assert response == None # One valid and one invalid in a list - form = MultiDict([('help-email', valid), ('help-email', invalid)]) - response = Validator().validate_email(Configuration.LIBRARY_SETTINGS, {"form": form}) + form = MultiDict([("help-email", valid), ("help-email", invalid)]) + response = Validator().validate_email( + Configuration.LIBRARY_SETTINGS, {"form": form} + ) assert response.detail == '"invalid_format" is not a valid email address.' assert response.status_code == 400 @@ -63,34 +76,61 @@ def test_validate_url(self): # Valid form = MultiDict([("help-web", valid)]) - response = Validator().validate_url(Configuration.LIBRARY_SETTINGS, {"form": form}) + response = Validator().validate_url( + Configuration.LIBRARY_SETTINGS, {"form": form} + ) assert response == None # Invalid form = MultiDict([("help-web", invalid)]) - response = Validator().validate_url(Configuration.LIBRARY_SETTINGS, {"form": form}) + response = Validator().validate_url( + Configuration.LIBRARY_SETTINGS, {"form": form} + ) assert response.detail == '"invalid_url" is not a valid URL.' assert response.status_code == 400 # One valid, one invalid form = MultiDict([("help-web", valid), ("terms-of-service", invalid)]) - response = Validator().validate_url(Configuration.LIBRARY_SETTINGS, {"form": form}) + response = Validator().validate_url( + Configuration.LIBRARY_SETTINGS, {"form": form} + ) assert response.detail == '"invalid_url" is not a valid URL.' assert response.status_code == 400 # Two valid in a list - form = MultiDict([(BaseSharedCollectionAPI.EXTERNAL_LIBRARY_URLS, "http://library1.com"), (BaseSharedCollectionAPI.EXTERNAL_LIBRARY_URLS, "http://library2.com")]) - response = Validator().validate_url(BaseSharedCollectionAPI.SETTINGS, {"form": form}) + form = MultiDict( + [ + (BaseSharedCollectionAPI.EXTERNAL_LIBRARY_URLS, "http://library1.com"), + (BaseSharedCollectionAPI.EXTERNAL_LIBRARY_URLS, "http://library2.com"), + ] + ) + response = Validator().validate_url( + BaseSharedCollectionAPI.SETTINGS, {"form": form} + ) assert response == None # One valid and one empty in a list - form = MultiDict([(BaseSharedCollectionAPI.EXTERNAL_LIBRARY_URLS, "http://library1.com"), (BaseSharedCollectionAPI.EXTERNAL_LIBRARY_URLS, "")]) - response = Validator().validate_url(BaseSharedCollectionAPI.SETTINGS, {"form": form}) + form = MultiDict( + [ + (BaseSharedCollectionAPI.EXTERNAL_LIBRARY_URLS, "http://library1.com"), + (BaseSharedCollectionAPI.EXTERNAL_LIBRARY_URLS, ""), + ] + ) + response = Validator().validate_url( + BaseSharedCollectionAPI.SETTINGS, {"form": form} + ) assert response == None # One valid and one invalid in a list - form = MultiDict([(BaseSharedCollectionAPI.EXTERNAL_LIBRARY_URLS, "http://library1.com"), (BaseSharedCollectionAPI.EXTERNAL_LIBRARY_URLS, invalid)]) - response = Validator().validate_url(BaseSharedCollectionAPI.SETTINGS, {"form": form}) + form = MultiDict( + [ + (BaseSharedCollectionAPI.EXTERNAL_LIBRARY_URLS, "http://library1.com"), + (BaseSharedCollectionAPI.EXTERNAL_LIBRARY_URLS, invalid), + ] + ) + response = Validator().validate_url( + BaseSharedCollectionAPI.SETTINGS, {"form": form} + ) assert response.detail == '"invalid_url" is not a valid URL.' assert response.status_code == 400 @@ -100,36 +140,54 @@ def test_validate_number(self): # Valid form = MultiDict([("hold_limit", valid)]) - response = Validator().validate_number(Configuration.LIBRARY_SETTINGS, {"form": form}) + response = Validator().validate_number( + Configuration.LIBRARY_SETTINGS, {"form": form} + ) assert response == None # Invalid form = MultiDict([("hold_limit", invalid)]) - response = Validator().validate_number(Configuration.LIBRARY_SETTINGS, {"form": form}) + response = Validator().validate_number( + Configuration.LIBRARY_SETTINGS, {"form": form} + ) assert response.detail == '"ten" is not a number.' assert response.status_code == 400 # One valid, one invalid form = MultiDict([("hold_limit", valid), ("loan_limit", invalid)]) - response = Validator().validate_number(Configuration.LIBRARY_SETTINGS, {"form": form}) + response = Validator().validate_number( + Configuration.LIBRARY_SETTINGS, {"form": form} + ) assert response.detail == '"ten" is not a number.' assert response.status_code == 400 # Invalid: below minimum form = MultiDict([("hold_limit", -5)]) - response = Validator().validate_number(Configuration.LIBRARY_SETTINGS, {"form": form}) - assert response.detail == 'Maximum number of books a patron can have on hold at once must be greater than 0.' + response = Validator().validate_number( + Configuration.LIBRARY_SETTINGS, {"form": form} + ) + assert ( + response.detail + == "Maximum number of books a patron can have on hold at once must be greater than 0." + ) assert response.status_code == 400 # Valid: below maximum form = MultiDict([("minimum_featured_quality", ".9")]) - response = Validator().validate_number(Configuration.LIBRARY_SETTINGS, {"form": form}) + response = Validator().validate_number( + Configuration.LIBRARY_SETTINGS, {"form": form} + ) assert response == None # Invalid: above maximum form = MultiDict([("minimum_featured_quality", "2")]) - response = Validator().validate_number(Configuration.LIBRARY_SETTINGS, {"form": form}) - assert response.detail == "Minimum quality for books that show up in 'featured' lanes cannot be greater than 1." + response = Validator().validate_number( + Configuration.LIBRARY_SETTINGS, {"form": form} + ) + assert ( + response.detail + == "Minimum quality for books that show up in 'featured' lanes cannot be greater than 1." + ) assert response.status_code == 400 def test_validate_language_code(self): @@ -138,67 +196,101 @@ def test_validate_language_code(self): mixed = ["eng", "abc", "spa"] form = MultiDict([("large_collections", all_valid)]) - response = Validator().validate_language_code(Configuration.LIBRARY_SETTINGS, {"form": form}) + response = Validator().validate_language_code( + Configuration.LIBRARY_SETTINGS, {"form": form} + ) assert response == None form = MultiDict([("large_collections", all_invalid)]) - response = Validator().validate_language_code(Configuration.LIBRARY_SETTINGS, {"form": form}) + response = Validator().validate_language_code( + Configuration.LIBRARY_SETTINGS, {"form": form} + ) assert response.detail == '"abc" is not a valid language code.' assert response.status_code == 400 form = MultiDict([("large_collections", mixed)]) - response = Validator().validate_language_code(Configuration.LIBRARY_SETTINGS, {"form": form}) + response = Validator().validate_language_code( + Configuration.LIBRARY_SETTINGS, {"form": form} + ) assert response.detail == '"abc" is not a valid language code.' assert response.status_code == 400 - form = MultiDict([("large_collections", all_valid), ("small_collections", all_valid), ("tiny_collections", mixed)]) - response = Validator().validate_language_code(Configuration.LIBRARY_SETTINGS, {"form": form}) + form = MultiDict( + [ + ("large_collections", all_valid), + ("small_collections", all_valid), + ("tiny_collections", mixed), + ] + ) + response = Validator().validate_language_code( + Configuration.LIBRARY_SETTINGS, {"form": form} + ) assert response.detail == '"abc" is not a valid language code.' assert response.status_code == 400 def test_validate_image(self): def create_image_file(format_string): - image_data = '\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x00\x01\x00\x00\x00\x01\x01\x03\x00\x00\x00%\xdbV\xca\x00\x00\x00\x06PLTE\xffM\x00\x01\x01\x01\x8e\x1e\xe5\x1b\x00\x00\x00\x01tRNS\xcc\xd24V\xfd\x00\x00\x00\nIDATx\x9cc`\x00\x00\x00\x02\x00\x01H\xaf\xa4q\x00\x00\x00\x00IEND\xaeB`\x82' + image_data = "\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x00\x01\x00\x00\x00\x01\x01\x03\x00\x00\x00%\xdbV\xca\x00\x00\x00\x06PLTE\xffM\x00\x01\x01\x01\x8e\x1e\xe5\x1b\x00\x00\x00\x01tRNS\xcc\xd24V\xfd\x00\x00\x00\nIDATx\x9cc`\x00\x00\x00\x02\x00\x01H\xaf\xa4q\x00\x00\x00\x00IEND\xaeB`\x82" + class TestImageFile(StringIO): - headers = { "Content-Type": "image/" + format_string } + headers = {"Content-Type": "image/" + format_string} + return TestImageFile(image_data) return result [png, jpeg, gif, invalid] = [ - MultiDict([(Configuration.LOGO, create_image_file(x))]) for x in ["png", "jpeg", "gif", "abc"] + MultiDict([(Configuration.LOGO, create_image_file(x))]) + for x in ["png", "jpeg", "gif", "abc"] ] - png_response = Validator().validate_image(Configuration.LIBRARY_SETTINGS, {"files": png}) + png_response = Validator().validate_image( + Configuration.LIBRARY_SETTINGS, {"files": png} + ) assert png_response == None - jpeg_response = Validator().validate_image(Configuration.LIBRARY_SETTINGS, {"files": jpeg}) + jpeg_response = Validator().validate_image( + Configuration.LIBRARY_SETTINGS, {"files": jpeg} + ) assert jpeg_response == None - gif_response = Validator().validate_image(Configuration.LIBRARY_SETTINGS, {"files": gif}) + gif_response = Validator().validate_image( + Configuration.LIBRARY_SETTINGS, {"files": gif} + ) assert gif_response == None - abc_response = Validator().validate_image(Configuration.LIBRARY_SETTINGS, {"files": invalid}) - assert abc_response.detail == 'Upload for Logo image must be in GIF, PNG, or JPG format. (Upload was image/abc.)' + abc_response = Validator().validate_image( + Configuration.LIBRARY_SETTINGS, {"files": invalid} + ) + assert ( + abc_response.detail + == "Upload for Logo image must be in GIF, PNG, or JPG format. (Upload was image/abc.)" + ) assert abc_response.status_code == 400 def test_validate(self): called = [] + class Mock(Validator): def validate_email(self, settings, content): called.append("validate_email") + def validate_url(self, settings, content): called.append("validate_url") + def validate_number(self, settings, content): called.append("validate_number") + def validate_language_code(self, settings, content): called.append("validate_language_code") + def validate_image(self, settings, content): called.append("validate_image") + Mock().validate(Configuration.LIBRARY_SETTINGS, {}) assert called == [ - 'validate_email', - 'validate_url', - 'validate_number', - 'validate_language_code', - 'validate_image' + "validate_email", + "validate_url", + "validate_number", + "validate_language_code", + "validate_image", ] def test__is_url(self): @@ -220,10 +312,15 @@ def test__is_url(self): class PatronAuthenticationValidatorFactoryTest(object): - @parameterized.expand([ - ('validator_using_class_name', 'tests.admin.fixtures.dummy_validator'), - ('validator_using_factory_method', 'tests.admin.fixtures.dummy_validator_factory') - ]) + @parameterized.expand( + [ + ("validator_using_class_name", "tests.admin.fixtures.dummy_validator"), + ( + "validator_using_factory_method", + "tests.admin.fixtures.dummy_validator_factory", + ), + ] + ) def test_create_can_create(self, name, protocol): # Arrange factory = PatronAuthenticationValidatorFactory() diff --git a/tests/clever/test_clever.py b/tests/clever/test_clever.py index e29d84358a..d46a829b8f 100644 --- a/tests/clever/test_clever.py +++ b/tests/clever/test_clever.py @@ -1,20 +1,20 @@ -import os import datetime +import os from flask import request, url_for from api.clever import ( - CleverAuthenticationAPI, - UNSUPPORTED_CLEVER_USER_TYPE, CLEVER_NOT_ELIGIBLE, CLEVER_UNKNOWN_SCHOOL, + UNSUPPORTED_CLEVER_USER_TYPE, + CleverAuthenticationAPI, external_type_from_clever_grade, ) from api.problem_details import INVALID_CREDENTIALS from core.model import ExternalIntegration +from core.testing import DatabaseTest from core.util.datetime_helpers import utc_now from core.util.problem_detail import ProblemDetail -from core.testing import DatabaseTest class MockAPI(CleverAuthenticationAPI): @@ -46,7 +46,14 @@ def test_external_type_from_clever_grade(self): THEN: The matching external_type value should be returned, or None if the match fails """ for e_grade in [ - "InfantToddler", "Preschool", "PreKindergarten", "TransitionalKindergarten", "Kindergarten", "1", "2", "3" + "InfantToddler", + "Preschool", + "PreKindergarten", + "TransitionalKindergarten", + "Kindergarten", + "1", + "2", + "3", ]: assert external_type_from_clever_grade(e_grade) == "E" @@ -61,20 +68,24 @@ def test_external_type_from_clever_grade(self): class TestCleverAuthenticationAPI(DatabaseTest): - def setup_method(self): super(TestCleverAuthenticationAPI, self).setup_method() self.api = MockAPI(self._default_library, self.mock_integration) - os.environ['AUTOINITIALIZE'] = "False" + os.environ["AUTOINITIALIZE"] = "False" from api.app import app - del os.environ['AUTOINITIALIZE'] + + del os.environ["AUTOINITIALIZE"] self.app = app @property def mock_integration(self): """Make a fake ExternalIntegration that can be used to configure a CleverAuthenticationAPI""" - integration = self._external_integration(protocol="OAuth", goal=ExternalIntegration.PATRON_AUTH_GOAL, - username="fake_client_id", password="fake_client_secret") + integration = self._external_integration( + protocol="OAuth", + goal=ExternalIntegration.PATRON_AUTH_GOAL, + username="fake_client_id", + password="fake_client_secret", + ) integration.setting(MockAPI.OAUTH_TOKEN_EXPIRATION_DAYS).value = 20 return integration @@ -95,7 +106,10 @@ def test_remote_exchange_code_for_bearer_token(self): # Test success. self.api.queue_response(dict(access_token="a token")) with self.app.test_request_context("/"): - assert self.api.remote_exchange_code_for_bearer_token(self._db, "code") == "a token" + assert ( + self.api.remote_exchange_code_for_bearer_token(self._db, "code") + == "a token" + ) # Test failure. self.api.queue_response(None) @@ -113,46 +127,76 @@ def test_remote_exchange_payload(self): with self.app.test_request_context("/"): payload = self.api._remote_exchange_payload(self._db, "a code") - expect_uri = url_for("oauth_callback", - library_short_name=self._default_library.name, - _external=True) - assert 'authorization_code' == payload['grant_type'] - assert expect_uri == payload['redirect_uri'] - assert 'a code' == payload['code'] + expect_uri = url_for( + "oauth_callback", + library_short_name=self._default_library.name, + _external=True, + ) + assert "authorization_code" == payload["grant_type"] + assert expect_uri == payload["redirect_uri"] + assert "a code" == payload["code"] def test_remote_patron_lookup_unsupported_user_type(self): - self.api.queue_response(dict(type='district_admin', data=dict(id='1234'))) + self.api.queue_response(dict(type="district_admin", data=dict(id="1234"))) token = self.api.remote_patron_lookup("token") assert UNSUPPORTED_CLEVER_USER_TYPE == token def test_remote_patron_lookup_ineligible(self): - self.api.queue_response(dict(type='student', data=dict(id='1234'), links=[dict(rel='canonical', uri='test')])) - self.api.queue_response(dict(data=dict(school='1234', district='1234'))) - self.api.queue_response(dict(data=dict(nces_id='I am not Title I'))) + self.api.queue_response( + dict( + type="student", + data=dict(id="1234"), + links=[dict(rel="canonical", uri="test")], + ) + ) + self.api.queue_response(dict(data=dict(school="1234", district="1234"))) + self.api.queue_response(dict(data=dict(nces_id="I am not Title I"))) token = self.api.remote_patron_lookup("") assert CLEVER_NOT_ELIGIBLE == token def test_remote_patron_lookup_missing_nces_id(self): - self.api.queue_response(dict(type='student', data=dict(id='1234'), links=[dict(rel='canonical', uri='test')])) - self.api.queue_response(dict(data=dict(school='1234', district='1234'))) + self.api.queue_response( + dict( + type="student", + data=dict(id="1234"), + links=[dict(rel="canonical", uri="test")], + ) + ) + self.api.queue_response(dict(data=dict(school="1234", district="1234"))) self.api.queue_response(dict(data=dict())) token = self.api.remote_patron_lookup("") assert CLEVER_UNKNOWN_SCHOOL == token def test_remote_patron_unknown_student_grade(self): - self.api.queue_response(dict(type='student', data=dict(id='2'), links=[dict(rel='canonical', uri='test')])) - self.api.queue_response(dict(data=dict(school='1234', district='1234', name='Abcd', grade=""))) - self.api.queue_response(dict(data=dict(nces_id='44270647'))) + self.api.queue_response( + dict( + type="student", + data=dict(id="2"), + links=[dict(rel="canonical", uri="test")], + ) + ) + self.api.queue_response( + dict(data=dict(school="1234", district="1234", name="Abcd", grade="")) + ) + self.api.queue_response(dict(data=dict(nces_id="44270647"))) patrondata = self.api.remote_patron_lookup("token") assert patrondata.external_type is None def test_remote_patron_lookup_title_i(self): - self.api.queue_response(dict(type='student', data=dict(id='5678'), links=[dict(rel='canonical', uri='test')])) - self.api.queue_response(dict(data=dict(school='1234', district='1234', name='Abcd', grade="10"))) - self.api.queue_response(dict(data=dict(nces_id='44270647'))) + self.api.queue_response( + dict( + type="student", + data=dict(id="5678"), + links=[dict(rel="canonical", uri="test")], + ) + ) + self.api.queue_response( + dict(data=dict(school="1234", district="1234", name="Abcd", grade="10")) + ) + self.api.queue_response(dict(data=dict(nces_id="44270647"))) patrondata = self.api.remote_patron_lookup("token") assert patrondata.personal_name is None @@ -164,18 +208,36 @@ def test_remote_patron_lookup_free_lunch_status(self): def test_remote_patron_lookup_external_type(self): # Teachers have an external type of 'A' indicating all access. - self.api.queue_response(dict(type='teacher', data=dict(id='1'), links=[dict(rel='canonical', uri='test')])) - self.api.queue_response(dict(data=dict(school='1234', district='1234', name='Abcd'))) - self.api.queue_response(dict(data=dict(nces_id='44270647'))) + self.api.queue_response( + dict( + type="teacher", + data=dict(id="1"), + links=[dict(rel="canonical", uri="test")], + ) + ) + self.api.queue_response( + dict(data=dict(school="1234", district="1234", name="Abcd")) + ) + self.api.queue_response(dict(data=dict(nces_id="44270647"))) patrondata = self.api.remote_patron_lookup("teacher token") assert "A" == patrondata.external_type # Student type is based on grade def queue_student(grade): - self.api.queue_response(dict(type='student', data=dict(id='2'), links=[dict(rel='canonical', uri='test')])) - self.api.queue_response(dict(data=dict(school='1234', district='1234', name='Abcd', grade=grade))) - self.api.queue_response(dict(data=dict(nces_id='44270647'))) + self.api.queue_response( + dict( + type="student", + data=dict(id="2"), + links=[dict(rel="canonical", uri="test")], + ) + ) + self.api.queue_response( + dict( + data=dict(school="1234", district="1234", name="Abcd", grade=grade) + ) + ) + self.api.queue_response(dict(data=dict(nces_id="44270647"))) queue_student(grade="1") patrondata = self.api.remote_patron_lookup("token") @@ -192,16 +254,26 @@ def queue_student(grade): def test_oauth_callback_creates_patron(self): """Test a successful run of oauth_callback.""" self.api.queue_response(dict(access_token="bearer token")) - self.api.queue_response(dict(type='teacher', data=dict(id='1'), links=[dict(rel='canonical', uri='test')])) - self.api.queue_response(dict(data=dict(school='1234', district='1234', name='Abcd'))) - self.api.queue_response(dict(data=dict(nces_id='44270647'))) + self.api.queue_response( + dict( + type="teacher", + data=dict(id="1"), + links=[dict(rel="canonical", uri="test")], + ) + ) + self.api.queue_response( + dict(data=dict(school="1234", district="1234", name="Abcd")) + ) + self.api.queue_response(dict(data=dict(nces_id="44270647"))) with self.app.test_request_context("/"): response = self.api.oauth_callback(self._db, dict(code="teacher code")) credential, patron, patrondata = response # The bearer token was turned into a Credential. - expect_credential, ignore = self.api.create_token(self._db, patron, "bearer token") + expect_credential, ignore = self.api.create_token( + self._db, patron, "bearer token" + ) assert credential == expect_credential # Since the patron is a teacher, their external_type was set to 'A'. @@ -235,10 +307,13 @@ def test_external_authenticate_url(self): with self.app.test_request_context("/"): request.library = self._default_library params = my_api.external_authenticate_url("state", self._db) - expected_redirect_uri = url_for("oauth_callback", library_short_name=self._default_library.short_name, - _external=True) + expected_redirect_uri = url_for( + "oauth_callback", + library_short_name=self._default_library.short_name, + _external=True, + ) expected = ( - 'https://clever.com/oauth/authorize' - '?response_type=code&client_id=fake_client_id&redirect_uri=%s&state=state' + "https://clever.com/oauth/authorize" + "?response_type=code&client_id=fake_client_id&redirect_uri=%s&state=state" ) % expected_redirect_uri assert params == expected diff --git a/tests/conftest.py b/tests/conftest.py index 478630b649..ab585547ec 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,4 +1,3 @@ - # Pull in the session_fixture defined in core/testing.py # which does the database setup and initialization pytest_plugins = ["core.testing"] diff --git a/tests/lcp/database_test.py b/tests/lcp/database_test.py index 1e7117674b..3db2cfdddb 100644 --- a/tests/lcp/database_test.py +++ b/tests/lcp/database_test.py @@ -11,6 +11,5 @@ def setup_method(self): super(DatabaseTest, self).setup_method() self._integration = self._external_integration( - protocol=LCPAPI.NAME, - goal=ExternalIntegration.LICENSE_GOAL + protocol=LCPAPI.NAME, goal=ExternalIntegration.LICENSE_GOAL ) diff --git a/tests/lcp/fixtures.py b/tests/lcp/fixtures.py index 8f70ec75c5..31f3e80442 100644 --- a/tests/lcp/fixtures.py +++ b/tests/lcp/fixtures.py @@ -1,23 +1,27 @@ -EXISTING_BOOK_FILE_PATH = '/books/ebook.epub' -NOT_EXISTING_BOOK_FILE_PATH = '/books/notexistingbook.epub' +EXISTING_BOOK_FILE_PATH = "/books/ebook.epub" +NOT_EXISTING_BOOK_FILE_PATH = "/books/notexistingbook.epub" -BOOK_IDENTIFIER = 'EBOOK' +BOOK_IDENTIFIER = "EBOOK" -CONTENT_ENCRYPTION_KEY = '+RulyN2G8MfAahNEO/Xz0TwBT5xMzvbFFHqqWGPrO3M=' -PROTECTED_CONTENT_LOCATION = '/opt/readium/files/encrypted/1f162bc2-be6f-42a9-8153-96d675418ff1.epub' -PROTECTED_CONTENT_DISPOSITION = '1f162bc2-be6f-42a9-8153-96d675418ff1.epub' -PROTECTED_CONTENT_TYPE = 'application/epub+zip' +CONTENT_ENCRYPTION_KEY = "+RulyN2G8MfAahNEO/Xz0TwBT5xMzvbFFHqqWGPrO3M=" +PROTECTED_CONTENT_LOCATION = ( + "/opt/readium/files/encrypted/1f162bc2-be6f-42a9-8153-96d675418ff1.epub" +) +PROTECTED_CONTENT_DISPOSITION = "1f162bc2-be6f-42a9-8153-96d675418ff1.epub" +PROTECTED_CONTENT_TYPE = "application/epub+zip" PROTECTED_CONTENT_LENGTH = 798385 -PROTECTED_CONTENT_SHA256 = 'e058281cbc11bae29451e5e2c8003efa1164c3f6dde6dcc003c8bb79e2acb88f' +PROTECTED_CONTENT_SHA256 = ( + "e058281cbc11bae29451e5e2c8003efa1164c3f6dde6dcc003c8bb79e2acb88f" +) -LCPENCRYPT_NOT_EXISTING_DIRECTORY_RESULT = \ - '''Error opening input file, for more information type 'lcpencrypt -help' ; level 30 +LCPENCRYPT_NOT_EXISTING_DIRECTORY_RESULT = """Error opening input file, for more information type 'lcpencrypt -help' ; level 30 open {0}: no such file or directory -'''.format(NOT_EXISTING_BOOK_FILE_PATH) +""".format( + NOT_EXISTING_BOOK_FILE_PATH +) -LCPENCRYPT_FAILED_ENCRYPTION_RESULT = \ - '''{{ +LCPENCRYPT_FAILED_ENCRYPTION_RESULT = """{{ "content-id": "{0}", "content-encryption-key": null, "protected-content-location": "{1}", @@ -26,14 +30,11 @@ "protected-content-disposition": "{2}" }} Encryption was successful -'''.format( - BOOK_IDENTIFIER, - PROTECTED_CONTENT_LOCATION, - NOT_EXISTING_BOOK_FILE_PATH - ) +""".format( + BOOK_IDENTIFIER, PROTECTED_CONTENT_LOCATION, NOT_EXISTING_BOOK_FILE_PATH +) -LCPENCRYPT_SUCCESSFUL_ENCRYPTION_RESULT = \ - '''{{ +LCPENCRYPT_SUCCESSFUL_ENCRYPTION_RESULT = """{{ "content-id": "{0}", "content-encryption-key": "{1}", "protected-content-location": "{2}", @@ -43,22 +44,20 @@ "protected-content-type": "{6}" }} Encryption was successful -'''.format( - BOOK_IDENTIFIER, - CONTENT_ENCRYPTION_KEY, - PROTECTED_CONTENT_LOCATION, - PROTECTED_CONTENT_LENGTH, - PROTECTED_CONTENT_SHA256, - PROTECTED_CONTENT_DISPOSITION, - PROTECTED_CONTENT_TYPE - ) +""".format( + BOOK_IDENTIFIER, + CONTENT_ENCRYPTION_KEY, + PROTECTED_CONTENT_LOCATION, + PROTECTED_CONTENT_LENGTH, + PROTECTED_CONTENT_SHA256, + PROTECTED_CONTENT_DISPOSITION, + PROTECTED_CONTENT_TYPE, +) -LCPENCRYPT_FAILED_LCPSERVER_NOTIFICATION = \ - '''Error notifying the License Server; level 60 -lcp server error 401''' +LCPENCRYPT_FAILED_LCPSERVER_NOTIFICATION = """Error notifying the License Server; level 60 +lcp server error 401""" -LCPENCRYPT_SUCCESSFUL_NOTIFICATION_RESULT = \ - '''License Server was notified +LCPENCRYPT_SUCCESSFUL_NOTIFICATION_RESULT = """License Server was notified {{ "content-id": "{0}", "content-encryption-key": "{1}", @@ -69,18 +68,18 @@ "protected-content-type": "{6}" }} Encryption was successful -'''.format( - BOOK_IDENTIFIER, - CONTENT_ENCRYPTION_KEY, - PROTECTED_CONTENT_LOCATION, - PROTECTED_CONTENT_LENGTH, - PROTECTED_CONTENT_SHA256, - PROTECTED_CONTENT_DISPOSITION, - PROTECTED_CONTENT_TYPE - ) +""".format( + BOOK_IDENTIFIER, + CONTENT_ENCRYPTION_KEY, + PROTECTED_CONTENT_LOCATION, + PROTECTED_CONTENT_LENGTH, + PROTECTED_CONTENT_SHA256, + PROTECTED_CONTENT_DISPOSITION, + PROTECTED_CONTENT_TYPE, +) -LCPSERVER_LICENSE = ''' +LCPSERVER_LICENSE = """ { "provider": "http://circulation.manager", "id": "e99be177-4902-426a-9b96-0872ae877e2f", @@ -127,13 +126,13 @@ "algorithm": "http://www.w3.org/2001/04/xmldsig-more#rsa-sha256" } } -''' +""" -LCPSERVER_URL = 'http://localhost:8989' -LCPSERVER_USER = 'lcp' -LCPSERVER_PASSWORD = 'secretpassword' -LCPSERVER_INPUT_DIRECTORY = '/opt/readium/encrypted' +LCPSERVER_URL = "http://localhost:8989" +LCPSERVER_USER = "lcp" +LCPSERVER_PASSWORD = "secretpassword" +LCPSERVER_INPUT_DIRECTORY = "/opt/readium/encrypted" -CONTENT_ID = '1' -TEXT_HINT = 'Not very helpful hint' -PROVIDER_NAME = 'http://circulation.manager' +CONTENT_ID = "1" +TEXT_HINT = "Not very helpful hint" +PROVIDER_NAME = "http://circulation.manager" diff --git a/tests/lcp/test_collection.py b/tests/lcp/test_collection.py index a9523c8bfd..b2c42ba00d 100644 --- a/tests/lcp/test_collection.py +++ b/tests/lcp/test_collection.py @@ -2,14 +2,18 @@ import json from freezegun import freeze_time -from mock import create_autospec, MagicMock, patch +from mock import MagicMock, create_autospec, patch from api.lcp.collection import LCPAPI, LCPFulfilmentInfo from api.lcp.encrypt import LCPEncryptionConfiguration -from api.lcp.server import LCPServerConfiguration, LCPServer -from core.model import ExternalIntegration, DataSource -from core.model.configuration import HasExternalIntegration, ConfigurationStorage, ConfigurationAttribute, \ - ConfigurationFactory +from api.lcp.server import LCPServer, LCPServerConfiguration +from core.model import DataSource, ExternalIntegration +from core.model.configuration import ( + ConfigurationAttribute, + ConfigurationFactory, + ConfigurationStorage, + HasExternalIntegration, +) from core.util.datetime_helpers import utc_now from tests.lcp import fixtures from tests.lcp.database_test import DatabaseTest @@ -22,7 +26,9 @@ def setup_method(self, mock_search=True): self._lcp_collection = self._collection(protocol=ExternalIntegration.LCP) self._integration = self._lcp_collection.external_integration integration_association = create_autospec(spec=HasExternalIntegration) - integration_association.external_integration = MagicMock(return_value=self._integration) + integration_association.external_integration = MagicMock( + return_value=self._integration + ) self._configuration_storage = ConfigurationStorage(integration_association) self._configuration_factory = ConfigurationFactory() @@ -32,279 +38,330 @@ def test_settings(self): # lcpserver_url assert ( - LCPAPI.SETTINGS[0][ConfigurationAttribute.KEY.value] == - LCPServerConfiguration.lcpserver_url.key) + LCPAPI.SETTINGS[0][ConfigurationAttribute.KEY.value] + == LCPServerConfiguration.lcpserver_url.key + ) assert ( - LCPAPI.SETTINGS[0][ConfigurationAttribute.LABEL.value] == - LCPServerConfiguration.lcpserver_url.label) + LCPAPI.SETTINGS[0][ConfigurationAttribute.LABEL.value] + == LCPServerConfiguration.lcpserver_url.label + ) assert ( - LCPAPI.SETTINGS[0][ConfigurationAttribute.DESCRIPTION.value] == - LCPServerConfiguration.lcpserver_url.description) + LCPAPI.SETTINGS[0][ConfigurationAttribute.DESCRIPTION.value] + == LCPServerConfiguration.lcpserver_url.description + ) + assert LCPAPI.SETTINGS[0][ConfigurationAttribute.TYPE.value] == None assert ( - LCPAPI.SETTINGS[0][ConfigurationAttribute.TYPE.value] == - None) + LCPAPI.SETTINGS[0][ConfigurationAttribute.REQUIRED.value] + == LCPServerConfiguration.lcpserver_url.required + ) assert ( - LCPAPI.SETTINGS[0][ConfigurationAttribute.REQUIRED.value] == - LCPServerConfiguration.lcpserver_url.required) + LCPAPI.SETTINGS[0][ConfigurationAttribute.DEFAULT.value] + == LCPServerConfiguration.lcpserver_url.default + ) assert ( - LCPAPI.SETTINGS[0][ConfigurationAttribute.DEFAULT.value] == - LCPServerConfiguration.lcpserver_url.default) - assert ( - LCPAPI.SETTINGS[0][ConfigurationAttribute.CATEGORY.value] == - LCPServerConfiguration.lcpserver_url.category) + LCPAPI.SETTINGS[0][ConfigurationAttribute.CATEGORY.value] + == LCPServerConfiguration.lcpserver_url.category + ) # lcpserver_user assert ( - LCPAPI.SETTINGS[1][ConfigurationAttribute.KEY.value] == - LCPServerConfiguration.lcpserver_user.key) - assert ( - LCPAPI.SETTINGS[1][ConfigurationAttribute.LABEL.value] == - LCPServerConfiguration.lcpserver_user.label) + LCPAPI.SETTINGS[1][ConfigurationAttribute.KEY.value] + == LCPServerConfiguration.lcpserver_user.key + ) assert ( - LCPAPI.SETTINGS[1][ConfigurationAttribute.DESCRIPTION.value] == - LCPServerConfiguration.lcpserver_user.description) + LCPAPI.SETTINGS[1][ConfigurationAttribute.LABEL.value] + == LCPServerConfiguration.lcpserver_user.label + ) assert ( - LCPAPI.SETTINGS[1][ConfigurationAttribute.TYPE.value] == - None) + LCPAPI.SETTINGS[1][ConfigurationAttribute.DESCRIPTION.value] + == LCPServerConfiguration.lcpserver_user.description + ) + assert LCPAPI.SETTINGS[1][ConfigurationAttribute.TYPE.value] == None assert ( - LCPAPI.SETTINGS[1][ConfigurationAttribute.REQUIRED.value] == - LCPServerConfiguration.lcpserver_user.required) + LCPAPI.SETTINGS[1][ConfigurationAttribute.REQUIRED.value] + == LCPServerConfiguration.lcpserver_user.required + ) assert ( - LCPAPI.SETTINGS[1][ConfigurationAttribute.DEFAULT.value] == - LCPServerConfiguration.lcpserver_user.default) + LCPAPI.SETTINGS[1][ConfigurationAttribute.DEFAULT.value] + == LCPServerConfiguration.lcpserver_user.default + ) assert ( - LCPAPI.SETTINGS[1][ConfigurationAttribute.CATEGORY.value] == - LCPServerConfiguration.lcpserver_user.category) + LCPAPI.SETTINGS[1][ConfigurationAttribute.CATEGORY.value] + == LCPServerConfiguration.lcpserver_user.category + ) # lcpserver_password assert ( - LCPAPI.SETTINGS[2][ConfigurationAttribute.KEY.value] == - LCPServerConfiguration.lcpserver_password.key) + LCPAPI.SETTINGS[2][ConfigurationAttribute.KEY.value] + == LCPServerConfiguration.lcpserver_password.key + ) assert ( - LCPAPI.SETTINGS[2][ConfigurationAttribute.LABEL.value] == - LCPServerConfiguration.lcpserver_password.label) + LCPAPI.SETTINGS[2][ConfigurationAttribute.LABEL.value] + == LCPServerConfiguration.lcpserver_password.label + ) assert ( - LCPAPI.SETTINGS[2][ConfigurationAttribute.DESCRIPTION.value] == - LCPServerConfiguration.lcpserver_password.description) + LCPAPI.SETTINGS[2][ConfigurationAttribute.DESCRIPTION.value] + == LCPServerConfiguration.lcpserver_password.description + ) + assert LCPAPI.SETTINGS[2][ConfigurationAttribute.TYPE.value] == None assert ( - LCPAPI.SETTINGS[2][ConfigurationAttribute.TYPE.value] == - None) + LCPAPI.SETTINGS[2][ConfigurationAttribute.REQUIRED.value] + == LCPServerConfiguration.lcpserver_password.required + ) assert ( - LCPAPI.SETTINGS[2][ConfigurationAttribute.REQUIRED.value] == - LCPServerConfiguration.lcpserver_password.required) + LCPAPI.SETTINGS[2][ConfigurationAttribute.DEFAULT.value] + == LCPServerConfiguration.lcpserver_password.default + ) assert ( - LCPAPI.SETTINGS[2][ConfigurationAttribute.DEFAULT.value] == - LCPServerConfiguration.lcpserver_password.default) - assert ( - LCPAPI.SETTINGS[2][ConfigurationAttribute.CATEGORY.value] == - LCPServerConfiguration.lcpserver_password.category) + LCPAPI.SETTINGS[2][ConfigurationAttribute.CATEGORY.value] + == LCPServerConfiguration.lcpserver_password.category + ) # lcpserver_input_directory assert ( - LCPAPI.SETTINGS[3][ConfigurationAttribute.KEY.value] == - LCPServerConfiguration.lcpserver_input_directory.key) - assert ( - LCPAPI.SETTINGS[3][ConfigurationAttribute.LABEL.value] == - LCPServerConfiguration.lcpserver_input_directory.label) + LCPAPI.SETTINGS[3][ConfigurationAttribute.KEY.value] + == LCPServerConfiguration.lcpserver_input_directory.key + ) assert ( - LCPAPI.SETTINGS[3][ConfigurationAttribute.DESCRIPTION.value] == - LCPServerConfiguration.lcpserver_input_directory.description) + LCPAPI.SETTINGS[3][ConfigurationAttribute.LABEL.value] + == LCPServerConfiguration.lcpserver_input_directory.label + ) assert ( - LCPAPI.SETTINGS[3][ConfigurationAttribute.TYPE.value] == - None) + LCPAPI.SETTINGS[3][ConfigurationAttribute.DESCRIPTION.value] + == LCPServerConfiguration.lcpserver_input_directory.description + ) + assert LCPAPI.SETTINGS[3][ConfigurationAttribute.TYPE.value] == None assert ( - LCPAPI.SETTINGS[3][ConfigurationAttribute.REQUIRED.value] == - LCPServerConfiguration.lcpserver_input_directory.required) + LCPAPI.SETTINGS[3][ConfigurationAttribute.REQUIRED.value] + == LCPServerConfiguration.lcpserver_input_directory.required + ) assert ( - LCPAPI.SETTINGS[3][ConfigurationAttribute.DEFAULT.value] == - LCPServerConfiguration.lcpserver_input_directory.default) + LCPAPI.SETTINGS[3][ConfigurationAttribute.DEFAULT.value] + == LCPServerConfiguration.lcpserver_input_directory.default + ) assert ( - LCPAPI.SETTINGS[3][ConfigurationAttribute.CATEGORY.value] == - LCPServerConfiguration.lcpserver_input_directory.category) + LCPAPI.SETTINGS[3][ConfigurationAttribute.CATEGORY.value] + == LCPServerConfiguration.lcpserver_input_directory.category + ) # lcpserver_page_size assert ( - LCPAPI.SETTINGS[4][ConfigurationAttribute.KEY.value] == - LCPServerConfiguration.lcpserver_page_size.key) + LCPAPI.SETTINGS[4][ConfigurationAttribute.KEY.value] + == LCPServerConfiguration.lcpserver_page_size.key + ) assert ( - LCPAPI.SETTINGS[4][ConfigurationAttribute.LABEL.value] == - LCPServerConfiguration.lcpserver_page_size.label) + LCPAPI.SETTINGS[4][ConfigurationAttribute.LABEL.value] + == LCPServerConfiguration.lcpserver_page_size.label + ) assert ( - LCPAPI.SETTINGS[4][ConfigurationAttribute.DESCRIPTION.value] == - LCPServerConfiguration.lcpserver_page_size.description) + LCPAPI.SETTINGS[4][ConfigurationAttribute.DESCRIPTION.value] + == LCPServerConfiguration.lcpserver_page_size.description + ) + assert LCPAPI.SETTINGS[4][ConfigurationAttribute.TYPE.value] == "number" assert ( - LCPAPI.SETTINGS[4][ConfigurationAttribute.TYPE.value] == - 'number') + LCPAPI.SETTINGS[4][ConfigurationAttribute.REQUIRED.value] + == LCPServerConfiguration.lcpserver_page_size.required + ) assert ( - LCPAPI.SETTINGS[4][ConfigurationAttribute.REQUIRED.value] == - LCPServerConfiguration.lcpserver_page_size.required) + LCPAPI.SETTINGS[4][ConfigurationAttribute.DEFAULT.value] + == LCPServerConfiguration.lcpserver_page_size.default + ) assert ( - LCPAPI.SETTINGS[4][ConfigurationAttribute.DEFAULT.value] == - LCPServerConfiguration.lcpserver_page_size.default) - assert ( - LCPAPI.SETTINGS[4][ConfigurationAttribute.CATEGORY.value] == - LCPServerConfiguration.lcpserver_page_size.category) + LCPAPI.SETTINGS[4][ConfigurationAttribute.CATEGORY.value] + == LCPServerConfiguration.lcpserver_page_size.category + ) # provider_name assert ( - LCPAPI.SETTINGS[5][ConfigurationAttribute.KEY.value] == - LCPServerConfiguration.provider_name.key) - assert ( - LCPAPI.SETTINGS[5][ConfigurationAttribute.LABEL.value] == - LCPServerConfiguration.provider_name.label) + LCPAPI.SETTINGS[5][ConfigurationAttribute.KEY.value] + == LCPServerConfiguration.provider_name.key + ) assert ( - LCPAPI.SETTINGS[5][ConfigurationAttribute.DESCRIPTION.value] == - LCPServerConfiguration.provider_name.description) + LCPAPI.SETTINGS[5][ConfigurationAttribute.LABEL.value] + == LCPServerConfiguration.provider_name.label + ) assert ( - LCPAPI.SETTINGS[5][ConfigurationAttribute.TYPE.value] == - None) + LCPAPI.SETTINGS[5][ConfigurationAttribute.DESCRIPTION.value] + == LCPServerConfiguration.provider_name.description + ) + assert LCPAPI.SETTINGS[5][ConfigurationAttribute.TYPE.value] == None assert ( - LCPAPI.SETTINGS[5][ConfigurationAttribute.REQUIRED.value] == - LCPServerConfiguration.provider_name.required) + LCPAPI.SETTINGS[5][ConfigurationAttribute.REQUIRED.value] + == LCPServerConfiguration.provider_name.required + ) assert ( - LCPAPI.SETTINGS[5][ConfigurationAttribute.DEFAULT.value] == - LCPServerConfiguration.provider_name.default) + LCPAPI.SETTINGS[5][ConfigurationAttribute.DEFAULT.value] + == LCPServerConfiguration.provider_name.default + ) assert ( - LCPAPI.SETTINGS[5][ConfigurationAttribute.CATEGORY.value] == - LCPServerConfiguration.provider_name.category) + LCPAPI.SETTINGS[5][ConfigurationAttribute.CATEGORY.value] + == LCPServerConfiguration.provider_name.category + ) # passphrase_hint assert ( - LCPAPI.SETTINGS[6][ConfigurationAttribute.KEY.value] == - LCPServerConfiguration.passphrase_hint.key) - assert ( - LCPAPI.SETTINGS[6][ConfigurationAttribute.LABEL.value] == - LCPServerConfiguration.passphrase_hint.label) + LCPAPI.SETTINGS[6][ConfigurationAttribute.KEY.value] + == LCPServerConfiguration.passphrase_hint.key + ) assert ( - LCPAPI.SETTINGS[6][ConfigurationAttribute.DESCRIPTION.value] == - LCPServerConfiguration.passphrase_hint.description) + LCPAPI.SETTINGS[6][ConfigurationAttribute.LABEL.value] + == LCPServerConfiguration.passphrase_hint.label + ) assert ( - LCPAPI.SETTINGS[6][ConfigurationAttribute.TYPE.value] == - None) + LCPAPI.SETTINGS[6][ConfigurationAttribute.DESCRIPTION.value] + == LCPServerConfiguration.passphrase_hint.description + ) + assert LCPAPI.SETTINGS[6][ConfigurationAttribute.TYPE.value] == None assert ( - LCPAPI.SETTINGS[6][ConfigurationAttribute.REQUIRED.value] == - LCPServerConfiguration.passphrase_hint.required) + LCPAPI.SETTINGS[6][ConfigurationAttribute.REQUIRED.value] + == LCPServerConfiguration.passphrase_hint.required + ) assert ( - LCPAPI.SETTINGS[6][ConfigurationAttribute.DEFAULT.value] == - LCPServerConfiguration.passphrase_hint.default) + LCPAPI.SETTINGS[6][ConfigurationAttribute.DEFAULT.value] + == LCPServerConfiguration.passphrase_hint.default + ) assert ( - LCPAPI.SETTINGS[6][ConfigurationAttribute.CATEGORY.value] == - LCPServerConfiguration.passphrase_hint.category) + LCPAPI.SETTINGS[6][ConfigurationAttribute.CATEGORY.value] + == LCPServerConfiguration.passphrase_hint.category + ) # encryption_algorithm assert ( - LCPAPI.SETTINGS[7][ConfigurationAttribute.KEY.value] == - LCPServerConfiguration.encryption_algorithm.key) + LCPAPI.SETTINGS[7][ConfigurationAttribute.KEY.value] + == LCPServerConfiguration.encryption_algorithm.key + ) assert ( - LCPAPI.SETTINGS[7][ConfigurationAttribute.LABEL.value] == - LCPServerConfiguration.encryption_algorithm.label) + LCPAPI.SETTINGS[7][ConfigurationAttribute.LABEL.value] + == LCPServerConfiguration.encryption_algorithm.label + ) assert ( - LCPAPI.SETTINGS[7][ConfigurationAttribute.DESCRIPTION.value] == - LCPServerConfiguration.encryption_algorithm.description) + LCPAPI.SETTINGS[7][ConfigurationAttribute.DESCRIPTION.value] + == LCPServerConfiguration.encryption_algorithm.description + ) assert ( - LCPAPI.SETTINGS[7][ConfigurationAttribute.TYPE.value] == - LCPServerConfiguration.encryption_algorithm.type.value) + LCPAPI.SETTINGS[7][ConfigurationAttribute.TYPE.value] + == LCPServerConfiguration.encryption_algorithm.type.value + ) assert ( - LCPAPI.SETTINGS[7][ConfigurationAttribute.REQUIRED.value] == - LCPServerConfiguration.encryption_algorithm.required) + LCPAPI.SETTINGS[7][ConfigurationAttribute.REQUIRED.value] + == LCPServerConfiguration.encryption_algorithm.required + ) assert ( - LCPAPI.SETTINGS[7][ConfigurationAttribute.DEFAULT.value] == - LCPServerConfiguration.encryption_algorithm.default) + LCPAPI.SETTINGS[7][ConfigurationAttribute.DEFAULT.value] + == LCPServerConfiguration.encryption_algorithm.default + ) assert ( - LCPAPI.SETTINGS[7][ConfigurationAttribute.CATEGORY.value] == - LCPServerConfiguration.encryption_algorithm.category) + LCPAPI.SETTINGS[7][ConfigurationAttribute.CATEGORY.value] + == LCPServerConfiguration.encryption_algorithm.category + ) # max_printable_pages assert ( - LCPAPI.SETTINGS[8][ConfigurationAttribute.KEY.value] == - LCPServerConfiguration.max_printable_pages.key) + LCPAPI.SETTINGS[8][ConfigurationAttribute.KEY.value] + == LCPServerConfiguration.max_printable_pages.key + ) assert ( - LCPAPI.SETTINGS[8][ConfigurationAttribute.LABEL.value] == - LCPServerConfiguration.max_printable_pages.label) + LCPAPI.SETTINGS[8][ConfigurationAttribute.LABEL.value] + == LCPServerConfiguration.max_printable_pages.label + ) assert ( - LCPAPI.SETTINGS[8][ConfigurationAttribute.DESCRIPTION.value] == - LCPServerConfiguration.max_printable_pages.description) + LCPAPI.SETTINGS[8][ConfigurationAttribute.DESCRIPTION.value] + == LCPServerConfiguration.max_printable_pages.description + ) + assert LCPAPI.SETTINGS[8][ConfigurationAttribute.TYPE.value] == "number" assert ( - LCPAPI.SETTINGS[8][ConfigurationAttribute.TYPE.value] == - 'number') + LCPAPI.SETTINGS[8][ConfigurationAttribute.REQUIRED.value] + == LCPServerConfiguration.max_printable_pages.required + ) assert ( - LCPAPI.SETTINGS[8][ConfigurationAttribute.REQUIRED.value] == - LCPServerConfiguration.max_printable_pages.required) + LCPAPI.SETTINGS[8][ConfigurationAttribute.DEFAULT.value] + == LCPServerConfiguration.max_printable_pages.default + ) assert ( - LCPAPI.SETTINGS[8][ConfigurationAttribute.DEFAULT.value] == - LCPServerConfiguration.max_printable_pages.default) - assert ( - LCPAPI.SETTINGS[8][ConfigurationAttribute.CATEGORY.value] == - LCPServerConfiguration.max_printable_pages.category) + LCPAPI.SETTINGS[8][ConfigurationAttribute.CATEGORY.value] + == LCPServerConfiguration.max_printable_pages.category + ) # max_copiable_pages assert ( - LCPAPI.SETTINGS[9][ConfigurationAttribute.KEY.value] == - LCPServerConfiguration.max_copiable_pages.key) - assert ( - LCPAPI.SETTINGS[9][ConfigurationAttribute.LABEL.value] == - LCPServerConfiguration.max_copiable_pages.label) + LCPAPI.SETTINGS[9][ConfigurationAttribute.KEY.value] + == LCPServerConfiguration.max_copiable_pages.key + ) assert ( - LCPAPI.SETTINGS[9][ConfigurationAttribute.DESCRIPTION.value] == - LCPServerConfiguration.max_copiable_pages.description) + LCPAPI.SETTINGS[9][ConfigurationAttribute.LABEL.value] + == LCPServerConfiguration.max_copiable_pages.label + ) assert ( - LCPAPI.SETTINGS[9][ConfigurationAttribute.TYPE.value] == - 'number') + LCPAPI.SETTINGS[9][ConfigurationAttribute.DESCRIPTION.value] + == LCPServerConfiguration.max_copiable_pages.description + ) + assert LCPAPI.SETTINGS[9][ConfigurationAttribute.TYPE.value] == "number" assert ( - LCPAPI.SETTINGS[9][ConfigurationAttribute.REQUIRED.value] == - LCPServerConfiguration.max_copiable_pages.required) + LCPAPI.SETTINGS[9][ConfigurationAttribute.REQUIRED.value] + == LCPServerConfiguration.max_copiable_pages.required + ) assert ( - LCPAPI.SETTINGS[9][ConfigurationAttribute.DEFAULT.value] == - LCPServerConfiguration.max_copiable_pages.default) + LCPAPI.SETTINGS[9][ConfigurationAttribute.DEFAULT.value] + == LCPServerConfiguration.max_copiable_pages.default + ) assert ( - LCPAPI.SETTINGS[9][ConfigurationAttribute.CATEGORY.value] == - LCPServerConfiguration.max_copiable_pages.category) + LCPAPI.SETTINGS[9][ConfigurationAttribute.CATEGORY.value] + == LCPServerConfiguration.max_copiable_pages.category + ) # lcpencrypt_location assert ( - LCPAPI.SETTINGS[10][ConfigurationAttribute.KEY.value] == - LCPEncryptionConfiguration.lcpencrypt_location.key) - assert ( - LCPAPI.SETTINGS[10][ConfigurationAttribute.LABEL.value] == - LCPEncryptionConfiguration.lcpencrypt_location.label) + LCPAPI.SETTINGS[10][ConfigurationAttribute.KEY.value] + == LCPEncryptionConfiguration.lcpencrypt_location.key + ) assert ( - LCPAPI.SETTINGS[10][ConfigurationAttribute.DESCRIPTION.value] == - LCPEncryptionConfiguration.lcpencrypt_location.description) + LCPAPI.SETTINGS[10][ConfigurationAttribute.LABEL.value] + == LCPEncryptionConfiguration.lcpencrypt_location.label + ) assert ( - LCPAPI.SETTINGS[10][ConfigurationAttribute.TYPE.value] == - None) + LCPAPI.SETTINGS[10][ConfigurationAttribute.DESCRIPTION.value] + == LCPEncryptionConfiguration.lcpencrypt_location.description + ) + assert LCPAPI.SETTINGS[10][ConfigurationAttribute.TYPE.value] == None assert ( - LCPAPI.SETTINGS[10][ConfigurationAttribute.REQUIRED.value] == - LCPEncryptionConfiguration.lcpencrypt_location.required) + LCPAPI.SETTINGS[10][ConfigurationAttribute.REQUIRED.value] + == LCPEncryptionConfiguration.lcpencrypt_location.required + ) assert ( - LCPAPI.SETTINGS[10][ConfigurationAttribute.DEFAULT.value] == - LCPEncryptionConfiguration.lcpencrypt_location.default) + LCPAPI.SETTINGS[10][ConfigurationAttribute.DEFAULT.value] + == LCPEncryptionConfiguration.lcpencrypt_location.default + ) assert ( - LCPAPI.SETTINGS[10][ConfigurationAttribute.CATEGORY.value] == - LCPEncryptionConfiguration.lcpencrypt_location.category) + LCPAPI.SETTINGS[10][ConfigurationAttribute.CATEGORY.value] + == LCPEncryptionConfiguration.lcpencrypt_location.category + ) # lcpencrypt_output_directory assert ( - LCPAPI.SETTINGS[11][ConfigurationAttribute.KEY.value] == - LCPEncryptionConfiguration.lcpencrypt_output_directory.key) + LCPAPI.SETTINGS[11][ConfigurationAttribute.KEY.value] + == LCPEncryptionConfiguration.lcpencrypt_output_directory.key + ) assert ( - LCPAPI.SETTINGS[11][ConfigurationAttribute.LABEL.value] == - LCPEncryptionConfiguration.lcpencrypt_output_directory.label) + LCPAPI.SETTINGS[11][ConfigurationAttribute.LABEL.value] + == LCPEncryptionConfiguration.lcpencrypt_output_directory.label + ) assert ( - LCPAPI.SETTINGS[11][ConfigurationAttribute.DESCRIPTION.value] == - LCPEncryptionConfiguration.lcpencrypt_output_directory.description) + LCPAPI.SETTINGS[11][ConfigurationAttribute.DESCRIPTION.value] + == LCPEncryptionConfiguration.lcpencrypt_output_directory.description + ) + assert LCPAPI.SETTINGS[11][ConfigurationAttribute.TYPE.value] == None assert ( - LCPAPI.SETTINGS[11][ConfigurationAttribute.TYPE.value] == - None) + LCPAPI.SETTINGS[11][ConfigurationAttribute.REQUIRED.value] + == LCPEncryptionConfiguration.lcpencrypt_output_directory.required + ) assert ( - LCPAPI.SETTINGS[11][ConfigurationAttribute.REQUIRED.value] == - LCPEncryptionConfiguration.lcpencrypt_output_directory.required) + LCPAPI.SETTINGS[11][ConfigurationAttribute.DEFAULT.value] + == LCPEncryptionConfiguration.lcpencrypt_output_directory.default + ) assert ( - LCPAPI.SETTINGS[11][ConfigurationAttribute.DEFAULT.value] == - LCPEncryptionConfiguration.lcpencrypt_output_directory.default) - assert ( - LCPAPI.SETTINGS[11][ConfigurationAttribute.CATEGORY.value] == - LCPEncryptionConfiguration.lcpencrypt_output_directory.category) + LCPAPI.SETTINGS[11][ConfigurationAttribute.CATEGORY.value] + == LCPEncryptionConfiguration.lcpencrypt_output_directory.category + ) @freeze_time("2020-01-01 00:00:00") def test_checkout_without_existing_loan(self): @@ -316,29 +373,39 @@ def test_checkout_without_existing_loan(self): end_date = start_date + datetime.timedelta(days=days) data_source = DataSource.lookup(self._db, DataSource.LCP, autocreate=True) data_source_name = data_source.name - edition = self._edition(data_source_name=data_source_name, identifier_id=fixtures.CONTENT_ID) + edition = self._edition( + data_source_name=data_source_name, identifier_id=fixtures.CONTENT_ID + ) license_pool = self._licensepool( - edition=edition, data_source_name=data_source_name, collection=self._lcp_collection) + edition=edition, + data_source_name=data_source_name, + collection=self._lcp_collection, + ) lcp_license = json.loads(fixtures.LCPSERVER_LICENSE) lcp_server_mock = create_autospec(spec=LCPServer) lcp_server_mock.generate_license = MagicMock(return_value=lcp_license) with self._configuration_factory.create( - self._configuration_storage, self._db, LCPServerConfiguration) as configuration: + self._configuration_storage, self._db, LCPServerConfiguration + ) as configuration: - with patch('api.lcp.collection.LCPServer') as lcp_server_constructor: + with patch("api.lcp.collection.LCPServer") as lcp_server_constructor: lcp_server_constructor.return_value = lcp_server_mock configuration.lcpserver_url = fixtures.LCPSERVER_URL configuration.lcpserver_user = fixtures.LCPSERVER_USER configuration.lcpserver_password = fixtures.LCPSERVER_PASSWORD - configuration.lcpserver_input_directory = fixtures.LCPSERVER_INPUT_DIRECTORY + configuration.lcpserver_input_directory = ( + fixtures.LCPSERVER_INPUT_DIRECTORY + ) configuration.provider_name = fixtures.PROVIDER_NAME configuration.passphrase_hint = fixtures.TEXT_HINT - configuration.encryption_algorithm = LCPServerConfiguration.DEFAULT_ENCRYPTION_ALGORITHM + configuration.encryption_algorithm = ( + LCPServerConfiguration.DEFAULT_ENCRYPTION_ALGORITHM + ) # Act - loan = lcp_api.checkout(patron, 'pin', license_pool, 'internal format') + loan = lcp_api.checkout(patron, "pin", license_pool, "internal format") # Assert assert loan.collection_id == self._lcp_collection.id @@ -346,12 +413,13 @@ def test_checkout_without_existing_loan(self): assert loan.license_pool(self._db) == license_pool assert loan.data_source_name == data_source_name assert loan.identifier_type == license_pool.identifier.type - assert loan.external_identifier == lcp_license['id'] + assert loan.external_identifier == lcp_license["id"] assert loan.start_date == start_date assert loan.end_date == end_date lcp_server_mock.generate_license.assert_called_once_with( - self._db, fixtures.CONTENT_ID, patron, start_date, end_date) + self._db, fixtures.CONTENT_ID, patron, start_date, end_date + ) @freeze_time("2020-01-01 00:00:00") def test_checkout_with_existing_loan(self): @@ -363,31 +431,41 @@ def test_checkout_with_existing_loan(self): end_date = start_date + datetime.timedelta(days=days) data_source = DataSource.lookup(self._db, DataSource.LCP, autocreate=True) data_source_name = data_source.name - edition = self._edition(data_source_name=data_source_name, identifier_id=fixtures.CONTENT_ID) + edition = self._edition( + data_source_name=data_source_name, identifier_id=fixtures.CONTENT_ID + ) license_pool = self._licensepool( - edition=edition, data_source_name=data_source_name, collection=self._lcp_collection) + edition=edition, + data_source_name=data_source_name, + collection=self._lcp_collection, + ) lcp_license = json.loads(fixtures.LCPSERVER_LICENSE) lcp_server_mock = create_autospec(spec=LCPServer) lcp_server_mock.get_license = MagicMock(return_value=lcp_license) - loan_identifier = 'e99be177-4902-426a-9b96-0872ae877e2f' + loan_identifier = "e99be177-4902-426a-9b96-0872ae877e2f" license_pool.loan_to(patron, external_identifier=loan_identifier) with self._configuration_factory.create( - self._configuration_storage, self._db, LCPServerConfiguration) as configuration: - with patch('api.lcp.collection.LCPServer') as lcp_server_constructor: + self._configuration_storage, self._db, LCPServerConfiguration + ) as configuration: + with patch("api.lcp.collection.LCPServer") as lcp_server_constructor: lcp_server_constructor.return_value = lcp_server_mock configuration.lcpserver_url = fixtures.LCPSERVER_URL configuration.lcpserver_user = fixtures.LCPSERVER_USER configuration.lcpserver_password = fixtures.LCPSERVER_PASSWORD - configuration.lcpserver_input_directory = fixtures.LCPSERVER_INPUT_DIRECTORY + configuration.lcpserver_input_directory = ( + fixtures.LCPSERVER_INPUT_DIRECTORY + ) configuration.provider_name = fixtures.PROVIDER_NAME configuration.passphrase_hint = fixtures.TEXT_HINT - configuration.encryption_algorithm = LCPServerConfiguration.DEFAULT_ENCRYPTION_ALGORITHM + configuration.encryption_algorithm = ( + LCPServerConfiguration.DEFAULT_ENCRYPTION_ALGORITHM + ) # Act - loan = lcp_api.checkout(patron, 'pin', license_pool, 'internal format') + loan = lcp_api.checkout(patron, "pin", license_pool, "internal format") # Assert assert loan.collection_id == self._lcp_collection.id @@ -400,7 +478,8 @@ def test_checkout_with_existing_loan(self): assert loan.end_date == end_date lcp_server_mock.get_license.assert_called_once_with( - self._db, loan_identifier, patron) + self._db, loan_identifier, patron + ) @freeze_time("2020-01-01 00:00:00") def test_fulfil(self): @@ -413,28 +492,43 @@ def test_fulfil(self): data_source = DataSource.lookup(self._db, DataSource.LCP, autocreate=True) data_source_name = data_source.name license_pool = self._licensepool( - edition=None, data_source_name=data_source_name, collection=self._lcp_collection) + edition=None, + data_source_name=data_source_name, + collection=self._lcp_collection, + ) lcp_license = json.loads(fixtures.LCPSERVER_LICENSE) lcp_server_mock = create_autospec(spec=LCPServer) lcp_server_mock.get_license = MagicMock(return_value=lcp_license) with self._configuration_factory.create( - self._configuration_storage, self._db, LCPServerConfiguration) as configuration: - with patch('api.lcp.collection.LCPServer') as lcp_server_constructor: + self._configuration_storage, self._db, LCPServerConfiguration + ) as configuration: + with patch("api.lcp.collection.LCPServer") as lcp_server_constructor: lcp_server_constructor.return_value = lcp_server_mock configuration.lcpserver_url = fixtures.LCPSERVER_URL configuration.lcpserver_user = fixtures.LCPSERVER_USER configuration.lcpserver_password = fixtures.LCPSERVER_PASSWORD - configuration.lcpserver_input_directory = fixtures.LCPSERVER_INPUT_DIRECTORY + configuration.lcpserver_input_directory = ( + fixtures.LCPSERVER_INPUT_DIRECTORY + ) configuration.provider_name = fixtures.PROVIDER_NAME configuration.passphrase_hint = fixtures.TEXT_HINT - configuration.encryption_algorithm = LCPServerConfiguration.DEFAULT_ENCRYPTION_ALGORITHM + configuration.encryption_algorithm = ( + LCPServerConfiguration.DEFAULT_ENCRYPTION_ALGORITHM + ) # Act - license_pool.loan_to(patron, start=today, end=expires, external_identifier=lcp_license['id']) - fulfilment_info = lcp_api.fulfill(patron, 'pin', license_pool, 'internal format') + license_pool.loan_to( + patron, + start=today, + end=expires, + external_identifier=lcp_license["id"], + ) + fulfilment_info = lcp_api.fulfill( + patron, "pin", license_pool, "internal format" + ) # Assert assert isinstance(fulfilment_info, LCPFulfilmentInfo) == True @@ -445,7 +539,8 @@ def test_fulfil(self): assert fulfilment_info.identifier_type == license_pool.identifier.type lcp_server_mock.get_license.assert_called_once_with( - self._db, lcp_license['id'], patron) + self._db, lcp_license["id"], patron + ) def test_patron_activity_returns_correct_result(self): # Arrange @@ -458,37 +553,60 @@ def test_patron_activity_returns_correct_result(self): expires = today + datetime.timedelta(days=days) data_source = DataSource.lookup(self._db, DataSource.LCP, autocreate=True) data_source_name = data_source.name - external_identifier = '1' + external_identifier = "1" license_pool = self._licensepool( - edition=None, data_source_name=data_source_name, collection=self._lcp_collection) - license_pool.loan_to(patron, start=today, end=expires, external_identifier=external_identifier) + edition=None, + data_source_name=data_source_name, + collection=self._lcp_collection, + ) + license_pool.loan_to( + patron, start=today, end=expires, external_identifier=external_identifier + ) # 2. Loan from a different collection other_collection = self._collection(protocol=ExternalIntegration.MANUAL) - other_external_identifier = '2' + other_external_identifier = "2" other_license_pool = self._licensepool( - edition=None, data_source_name=data_source_name, collection=other_collection) - other_license_pool.loan_to(patron, start=today, end=expires, external_identifier=other_external_identifier) + edition=None, data_source_name=data_source_name, collection=other_collection + ) + other_license_pool.loan_to( + patron, + start=today, + end=expires, + external_identifier=other_external_identifier, + ) # 3. Other patron's loan other_patron = self._patron() other_license_pool = self._licensepool( - edition=None, data_source_name=data_source_name, collection=other_collection) + edition=None, data_source_name=data_source_name, collection=other_collection + ) other_license_pool.loan_to(other_patron, start=today, end=expires) # 4. Expired loan other_license_pool = self._licensepool( - edition=None, data_source_name=data_source_name, collection=self._lcp_collection) - other_license_pool.loan_to(patron, start=today, end=today - datetime.timedelta(days=1)) + edition=None, + data_source_name=data_source_name, + collection=self._lcp_collection, + ) + other_license_pool.loan_to( + patron, start=today, end=today - datetime.timedelta(days=1) + ) # 5. Not started loan other_license_pool = self._licensepool( - edition=None, data_source_name=data_source_name, collection=self._lcp_collection) + edition=None, + data_source_name=data_source_name, + collection=self._lcp_collection, + ) other_license_pool.loan_to( - patron, start=today + datetime.timedelta(days=1), end=today + datetime.timedelta(days=2)) + patron, + start=today + datetime.timedelta(days=1), + end=today + datetime.timedelta(days=2), + ) # Act - loans = lcp_api.patron_activity(patron, 'pin') + loans = lcp_api.patron_activity(patron, "pin") # Assert assert len(loans) == 1 diff --git a/tests/lcp/test_controller.py b/tests/lcp/test_controller.py index a5d025995b..e62e9a7425 100644 --- a/tests/lcp/test_controller.py +++ b/tests/lcp/test_controller.py @@ -1,7 +1,7 @@ import json from flask import request -from mock import MagicMock, create_autospec, patch, call +from mock import MagicMock, call, create_autospec, patch from api.controller import CirculationManager from api.lcp.collection import LCPAPI @@ -15,21 +15,29 @@ class TestLCPController(ControllerTest): - def test_get_lcp_passphrase_returns_the_same_passphrase_for_authenticated_patron(self): + def test_get_lcp_passphrase_returns_the_same_passphrase_for_authenticated_patron( + self, + ): # Arrange - expected_passphrase = '1cde00b4-bea9-48fc-819b-bd17c578a22c' + expected_passphrase = "1cde00b4-bea9-48fc-819b-bd17c578a22c" - with patch('api.lcp.controller.LCPCredentialFactory') as credential_factory_constructor_mock: + with patch( + "api.lcp.controller.LCPCredentialFactory" + ) as credential_factory_constructor_mock: credential_factory = create_autospec(spec=LCPCredentialFactory) - credential_factory.get_patron_passphrase = MagicMock(return_value=expected_passphrase) + credential_factory.get_patron_passphrase = MagicMock( + return_value=expected_passphrase + ) credential_factory_constructor_mock.return_value = credential_factory patron = self.default_patron manager = CirculationManager(self._db, testing=True) controller = LCPController(manager) - controller.authenticated_patron_from_request = MagicMock(return_value=patron) + controller.authenticated_patron_from_request = MagicMock( + return_value=patron + ) - url = 'http://circulationmanager.org/lcp/hint' + url = "http://circulationmanager.org/lcp/hint" with self.app.test_request_context(url): request.library = self._default_library @@ -41,20 +49,17 @@ def test_get_lcp_passphrase_returns_the_same_passphrase_for_authenticated_patron # Assert for result in [result1, result2]: assert result.status_code == 200 - assert ('passphrase' in result.json) == True - assert result.json['passphrase'] == expected_passphrase + assert ("passphrase" in result.json) == True + assert result.json["passphrase"] == expected_passphrase credential_factory.get_patron_passphrase.assert_has_calls( - [ - call(self._db, patron), - call(self._db, patron) - ] + [call(self._db, patron), call(self._db, patron)] ) def test_get_lcp_license_returns_problem_detail_when_collection_is_missing(self): # Arrange - missing_collection_name = 'missing-collection' - license_id = 'e99be177-4902-426a-9b96-0872ae877e2f' + missing_collection_name = "missing-collection" + license_id = "e99be177-4902-426a-9b96-0872ae877e2f" expected_license = json.loads(fixtures.LCPSERVER_LICENSE) lcp_server = create_autospec(spec=LCPServer) lcp_server.get_license = MagicMock(return_value=expected_license) @@ -62,7 +67,9 @@ def test_get_lcp_license_returns_problem_detail_when_collection_is_missing(self) lcp_collection = self._collection(LCPAPI.NAME, ExternalIntegration.LCP) library.collections.append(lcp_collection) - with patch('api.lcp.controller.LCPServerFactory') as lcp_server_factory_constructor_mock: + with patch( + "api.lcp.controller.LCPServerFactory" + ) as lcp_server_factory_constructor_mock: lcp_server_factory = create_autospec(spec=LCPServerFactory) lcp_server_factory.create = MagicMock(return_value=lcp_server) lcp_server_factory_constructor_mock.return_value = lcp_server_factory @@ -70,10 +77,13 @@ def test_get_lcp_license_returns_problem_detail_when_collection_is_missing(self) patron = self.default_patron manager = CirculationManager(self._db, testing=True) controller = LCPController(manager) - controller.authenticated_patron_from_request = MagicMock(return_value=patron) + controller.authenticated_patron_from_request = MagicMock( + return_value=patron + ) - url = 'http://circulationmanager.org/{0}/licenses{1}'.format( - missing_collection_name, license_id) + url = "http://circulationmanager.org/{0}/licenses{1}".format( + missing_collection_name, license_id + ) with self.app.test_request_context(url): request.library = self._default_library @@ -86,7 +96,7 @@ def test_get_lcp_license_returns_problem_detail_when_collection_is_missing(self) def test_get_lcp_license_returns_the_same_license_for_authenticated_patron(self): # Arrange - license_id = 'e99be177-4902-426a-9b96-0872ae877e2f' + license_id = "e99be177-4902-426a-9b96-0872ae877e2f" expected_license = json.loads(fixtures.LCPSERVER_LICENSE) lcp_server = create_autospec(spec=LCPServer) lcp_server.get_license = MagicMock(return_value=expected_license) @@ -94,7 +104,9 @@ def test_get_lcp_license_returns_the_same_license_for_authenticated_patron(self) lcp_collection = self._collection(LCPAPI.NAME, ExternalIntegration.LCP) library.collections.append(lcp_collection) - with patch('api.lcp.controller.LCPServerFactory') as lcp_server_factory_constructor_mock: + with patch( + "api.lcp.controller.LCPServerFactory" + ) as lcp_server_factory_constructor_mock: lcp_server_factory = create_autospec(spec=LCPServerFactory) lcp_server_factory.create = MagicMock(return_value=lcp_server) lcp_server_factory_constructor_mock.return_value = lcp_server_factory @@ -102,10 +114,13 @@ def test_get_lcp_license_returns_the_same_license_for_authenticated_patron(self) patron = self.default_patron manager = CirculationManager(self._db, testing=True) controller = LCPController(manager) - controller.authenticated_patron_from_request = MagicMock(return_value=patron) + controller.authenticated_patron_from_request = MagicMock( + return_value=patron + ) - url = 'http://circulationmanager.org/{0}/licenses{1}'.format( - LCPAPI.NAME, license_id) + url = "http://circulationmanager.org/{0}/licenses{1}".format( + LCPAPI.NAME, license_id + ) with self.app.test_request_context(url): request.library = self._default_library diff --git a/tests/lcp/test_encrypt.py b/tests/lcp/test_encrypt.py index 8543d4d56e..424b34f296 100644 --- a/tests/lcp/test_encrypt.py +++ b/tests/lcp/test_encrypt.py @@ -1,103 +1,133 @@ import pytest -from mock import patch, create_autospec, MagicMock +from mock import MagicMock, create_autospec, patch from parameterized import parameterized from pyfakefs.fake_filesystem_unittest import Patcher -from api.lcp.encrypt import LCPEncryptor, LCPEncryptionException, LCPEncryptionConfiguration, LCPEncryptionResult +from api.lcp.encrypt import ( + LCPEncryptionConfiguration, + LCPEncryptionException, + LCPEncryptionResult, + LCPEncryptor, +) from core.model import Identifier -from core.model.configuration import HasExternalIntegration, ConfigurationStorage, ConfigurationFactory +from core.model.configuration import ( + ConfigurationFactory, + ConfigurationStorage, + HasExternalIntegration, +) from tests.lcp import fixtures from tests.lcp.database_test import DatabaseTest class TestLCPEncryptor(DatabaseTest): - @parameterized.expand([ - ( - 'non_existing_directory', - fixtures.NOT_EXISTING_BOOK_FILE_PATH, - fixtures.LCPENCRYPT_NOT_EXISTING_DIRECTORY_RESULT, - None, - LCPEncryptionException(fixtures.LCPENCRYPT_NOT_EXISTING_DIRECTORY_RESULT.strip()), - False - ), - ( - 'failed_encryption', - fixtures.NOT_EXISTING_BOOK_FILE_PATH, - fixtures.LCPENCRYPT_FAILED_ENCRYPTION_RESULT, - None, - LCPEncryptionException('Encryption failed') - ), - ( - 'successful_encryption', - fixtures.EXISTING_BOOK_FILE_PATH, - fixtures.LCPENCRYPT_SUCCESSFUL_ENCRYPTION_RESULT, - LCPEncryptionResult( - content_id=fixtures.BOOK_IDENTIFIER, - content_encryption_key=fixtures.CONTENT_ENCRYPTION_KEY, - protected_content_location=fixtures.PROTECTED_CONTENT_LOCATION, - protected_content_disposition=fixtures.PROTECTED_CONTENT_DISPOSITION, - protected_content_type=fixtures.PROTECTED_CONTENT_TYPE, - protected_content_length=fixtures.PROTECTED_CONTENT_LENGTH, - protected_content_sha256=fixtures.PROTECTED_CONTENT_SHA256 - ) - ), - ( - 'failed_lcp_server_notification', - fixtures.EXISTING_BOOK_FILE_PATH, - fixtures.LCPENCRYPT_FAILED_LCPSERVER_NOTIFICATION, - None, - LCPEncryptionException(fixtures.LCPENCRYPT_FAILED_LCPSERVER_NOTIFICATION.strip()) - ), - ( - 'successful_lcp_server_notification', - fixtures.EXISTING_BOOK_FILE_PATH, - fixtures.LCPENCRYPT_SUCCESSFUL_NOTIFICATION_RESULT, - LCPEncryptionResult( - content_id=fixtures.BOOK_IDENTIFIER, - content_encryption_key=fixtures.CONTENT_ENCRYPTION_KEY, - protected_content_location=fixtures.PROTECTED_CONTENT_LOCATION, - protected_content_disposition=fixtures.PROTECTED_CONTENT_DISPOSITION, - protected_content_type=fixtures.PROTECTED_CONTENT_TYPE, - protected_content_length=fixtures.PROTECTED_CONTENT_LENGTH, - protected_content_sha256=fixtures.PROTECTED_CONTENT_SHA256 - ) - ), - ]) + @parameterized.expand( + [ + ( + "non_existing_directory", + fixtures.NOT_EXISTING_BOOK_FILE_PATH, + fixtures.LCPENCRYPT_NOT_EXISTING_DIRECTORY_RESULT, + None, + LCPEncryptionException( + fixtures.LCPENCRYPT_NOT_EXISTING_DIRECTORY_RESULT.strip() + ), + False, + ), + ( + "failed_encryption", + fixtures.NOT_EXISTING_BOOK_FILE_PATH, + fixtures.LCPENCRYPT_FAILED_ENCRYPTION_RESULT, + None, + LCPEncryptionException("Encryption failed"), + ), + ( + "successful_encryption", + fixtures.EXISTING_BOOK_FILE_PATH, + fixtures.LCPENCRYPT_SUCCESSFUL_ENCRYPTION_RESULT, + LCPEncryptionResult( + content_id=fixtures.BOOK_IDENTIFIER, + content_encryption_key=fixtures.CONTENT_ENCRYPTION_KEY, + protected_content_location=fixtures.PROTECTED_CONTENT_LOCATION, + protected_content_disposition=fixtures.PROTECTED_CONTENT_DISPOSITION, + protected_content_type=fixtures.PROTECTED_CONTENT_TYPE, + protected_content_length=fixtures.PROTECTED_CONTENT_LENGTH, + protected_content_sha256=fixtures.PROTECTED_CONTENT_SHA256, + ), + ), + ( + "failed_lcp_server_notification", + fixtures.EXISTING_BOOK_FILE_PATH, + fixtures.LCPENCRYPT_FAILED_LCPSERVER_NOTIFICATION, + None, + LCPEncryptionException( + fixtures.LCPENCRYPT_FAILED_LCPSERVER_NOTIFICATION.strip() + ), + ), + ( + "successful_lcp_server_notification", + fixtures.EXISTING_BOOK_FILE_PATH, + fixtures.LCPENCRYPT_SUCCESSFUL_NOTIFICATION_RESULT, + LCPEncryptionResult( + content_id=fixtures.BOOK_IDENTIFIER, + content_encryption_key=fixtures.CONTENT_ENCRYPTION_KEY, + protected_content_location=fixtures.PROTECTED_CONTENT_LOCATION, + protected_content_disposition=fixtures.PROTECTED_CONTENT_DISPOSITION, + protected_content_type=fixtures.PROTECTED_CONTENT_TYPE, + protected_content_length=fixtures.PROTECTED_CONTENT_LENGTH, + protected_content_sha256=fixtures.PROTECTED_CONTENT_SHA256, + ), + ), + ] + ) def test_local_lcpencrypt( - self, - _, - file_path, - lcpencrypt_output, - expected_result, - expected_exception=None, - create_file=True): + self, + _, + file_path, + lcpencrypt_output, + expected_result, + expected_exception=None, + create_file=True, + ): # Arrange integration_owner = create_autospec(spec=HasExternalIntegration) - integration_owner.external_integration = MagicMock(return_value=self._integration) + integration_owner.external_integration = MagicMock( + return_value=self._integration + ) configuration_storage = ConfigurationStorage(integration_owner) configuration_factory = ConfigurationFactory() encryptor = LCPEncryptor(configuration_storage, configuration_factory) identifier = Identifier(identifier=fixtures.BOOK_IDENTIFIER) - with configuration_factory.create(configuration_storage, self._db, LCPEncryptionConfiguration) as configuration: - configuration.lcpencrypt_location = LCPEncryptionConfiguration.DEFAULT_LCPENCRYPT_LOCATION + with configuration_factory.create( + configuration_storage, self._db, LCPEncryptionConfiguration + ) as configuration: + configuration.lcpencrypt_location = ( + LCPEncryptionConfiguration.DEFAULT_LCPENCRYPT_LOCATION + ) with Patcher() as patcher: - patcher.fs.create_file(LCPEncryptionConfiguration.DEFAULT_LCPENCRYPT_LOCATION) + patcher.fs.create_file( + LCPEncryptionConfiguration.DEFAULT_LCPENCRYPT_LOCATION + ) if create_file: patcher.fs.create_file(file_path) - with patch('subprocess.check_output') as subprocess_check_output_mock: + with patch("subprocess.check_output") as subprocess_check_output_mock: subprocess_check_output_mock.return_value = lcpencrypt_output if expected_exception: - with pytest.raises(expected_exception.__class__) as exception_metadata: - encryptor.encrypt(self._db, file_path, identifier.identifier) + with pytest.raises( + expected_exception.__class__ + ) as exception_metadata: + encryptor.encrypt( + self._db, file_path, identifier.identifier + ) # Assert assert exception_metadata.value == expected_exception else: # Assert - result = encryptor.encrypt(self._db, file_path, identifier.identifier) + result = encryptor.encrypt( + self._db, file_path, identifier.identifier + ) assert result == expected_result diff --git a/tests/lcp/test_hash.py b/tests/lcp/test_hash.py index bf8d1a5a32..e2d89ceb8e 100644 --- a/tests/lcp/test_hash.py +++ b/tests/lcp/test_hash.py @@ -4,32 +4,34 @@ class TestHasherFactory(object): - @parameterized.expand([ - ( - 'sha256', + @parameterized.expand( + [ + ( + "sha256", HashingAlgorithm.SHA256, - '12345', - '5994471abb01112afcc18159f6cc74b4f511b99806da59b3caf5a9c173cacfc5' - ), - ( - 'sha256_value', + "12345", + "5994471abb01112afcc18159f6cc74b4f511b99806da59b3caf5a9c173cacfc5", + ), + ( + "sha256_value", HashingAlgorithm.SHA256.value, - '12345', - '5994471abb01112afcc18159f6cc74b4f511b99806da59b3caf5a9c173cacfc5' - ), - ( - 'sha512', + "12345", + "5994471abb01112afcc18159f6cc74b4f511b99806da59b3caf5a9c173cacfc5", + ), + ( + "sha512", HashingAlgorithm.SHA512, - '12345', - '3627909a29c31381a071ec27f7c9ca97726182aed29a7ddd2e54353322cfb30abb9e3a6df2ac2c20fe23436311d678564d0c8d305930575f60e2d3d048184d79' - ), - ( - 'sha512_value', + "12345", + "3627909a29c31381a071ec27f7c9ca97726182aed29a7ddd2e54353322cfb30abb9e3a6df2ac2c20fe23436311d678564d0c8d305930575f60e2d3d048184d79", + ), + ( + "sha512_value", HashingAlgorithm.SHA512.value, - '12345', - '3627909a29c31381a071ec27f7c9ca97726182aed29a7ddd2e54353322cfb30abb9e3a6df2ac2c20fe23436311d678564d0c8d305930575f60e2d3d048184d79' - ) - ]) + "12345", + "3627909a29c31381a071ec27f7c9ca97726182aed29a7ddd2e54353322cfb30abb9e3a6df2ac2c20fe23436311d678564d0c8d305930575f60e2d3d048184d79", + ), + ] + ) def test_create(self, _, hashing_algorithm, value, expected_value): # hasher_factory = HasherFactory() diff --git a/tests/lcp/test_importer.py b/tests/lcp/test_importer.py index 8c3db92306..d24fba0d61 100644 --- a/tests/lcp/test_importer.py +++ b/tests/lcp/test_importer.py @@ -9,16 +9,16 @@ class TestLCPImporter(object): def test_import_book(self): # Arrange - file_path = '/opt/readium/raw_books/book.epub' - identifier = '123456789' + file_path = "/opt/readium/raw_books/book.epub" + identifier = "123456789" encrypted_content = LCPEncryptionResult( - content_id='1', - content_encryption_key='12345', - protected_content_location='/opt/readium/files/encrypted', - protected_content_disposition='encrypted_book', - protected_content_type='application/epub+zip', + content_id="1", + content_encryption_key="12345", + protected_content_location="/opt/readium/files/encrypted", + protected_content_disposition="encrypted_book", + protected_content_type="application/epub+zip", protected_content_length=12345, - protected_content_sha256='12345' + protected_content_sha256="12345", ) lcp_encryptor = create_autospec(spec=LCPEncryptor) lcp_encryptor.encrypt = MagicMock(return_value=encrypted_content) @@ -33,4 +33,3 @@ def test_import_book(self): # Assert lcp_encryptor.encrypt.assert_called_once_with(db, file_path, identifier) lcp_server.add_content.assert_called_once_with(db, encrypted_content) - diff --git a/tests/lcp/test_mirror.py b/tests/lcp/test_mirror.py index 8661505713..7f96975ad4 100644 --- a/tests/lcp/test_mirror.py +++ b/tests/lcp/test_mirror.py @@ -1,9 +1,9 @@ -from mock import create_autospec, patch, ANY +from mock import ANY, create_autospec, patch from api.lcp.importer import LCPImporter from api.lcp.mirror import LCPMirror -from core.model import ExternalIntegration, Identifier, DataSource, Representation -from core.s3 import S3UploaderConfiguration, MinIOUploaderConfiguration +from core.model import DataSource, ExternalIntegration, Identifier, Representation +from core.s3 import MinIOUploaderConfiguration, S3UploaderConfiguration from tests.lcp.database_test import DatabaseTest @@ -12,11 +12,13 @@ def setup_method(self): super(TestLCPMirror, self).setup_method() settings = { - S3UploaderConfiguration.PROTECTED_CONTENT_BUCKET_KEY: 'encrypted-books', - MinIOUploaderConfiguration.ENDPOINT_URL: 'http://minio' + S3UploaderConfiguration.PROTECTED_CONTENT_BUCKET_KEY: "encrypted-books", + MinIOUploaderConfiguration.ENDPOINT_URL: "http://minio", } integration = self._external_integration( - ExternalIntegration.LCP, goal=ExternalIntegration.STORAGE_GOAL, settings=settings + ExternalIntegration.LCP, + goal=ExternalIntegration.STORAGE_GOAL, + settings=settings, ) self._lcp_collection = self._collection(protocol=ExternalIntegration.LCP) self._lcp_mirror = LCPMirror(integration) @@ -24,25 +26,31 @@ def setup_method(self): def test_book_url(self): # Arrange data_source = DataSource.lookup(self._db, DataSource.LCP, autocreate=True) - identifier = Identifier(identifier='12345', type=Identifier.ISBN) + identifier = Identifier(identifier="12345", type=Identifier.ISBN) # Act result = self._lcp_mirror.book_url(identifier, data_source=data_source) # Assert - assert result == 'http://encrypted-books.minio/12345' + assert result == "http://encrypted-books.minio/12345" def test_mirror_one(self): # Arrange - expected_identifier = '12345' - mirror_url = 'http://encrypted-books.minio/' + expected_identifier + expected_identifier = "12345" + mirror_url = "http://encrypted-books.minio/" + expected_identifier lcp_importer = create_autospec(spec=LCPImporter) - representation, _ = self._representation(media_type=Representation.EPUB_MEDIA_TYPE, content='12345') + representation, _ = self._representation( + media_type=Representation.EPUB_MEDIA_TYPE, content="12345" + ) # Act - with patch('api.lcp.mirror.LCPImporter') as lcp_importer_constructor: + with patch("api.lcp.mirror.LCPImporter") as lcp_importer_constructor: lcp_importer_constructor.return_value = lcp_importer - self._lcp_mirror.mirror_one(representation, mirror_to=mirror_url, collection=self._lcp_collection) + self._lcp_mirror.mirror_one( + representation, mirror_to=mirror_url, collection=self._lcp_collection + ) # Assert - lcp_importer.import_book.assert_called_once_with(self._db, ANY, expected_identifier) + lcp_importer.import_book.assert_called_once_with( + self._db, ANY, expected_identifier + ) diff --git a/tests/lcp/test_server.py b/tests/lcp/test_server.py index 51eef2eba2..d40f84449b 100644 --- a/tests/lcp/test_server.py +++ b/tests/lcp/test_server.py @@ -4,7 +4,7 @@ import urllib.parse import requests_mock -from mock import create_autospec, MagicMock +from mock import MagicMock, create_autospec from parameterized import parameterized from api.lcp import utils @@ -12,8 +12,12 @@ from api.lcp.hash import HasherFactory from api.lcp.server import LCPServer, LCPServerConfiguration from core.lcp.credential import LCPCredentialFactory -from core.model.configuration import HasExternalIntegration, ConfigurationStorage, ConfigurationFactory, \ - ExternalIntegration +from core.model.configuration import ( + ConfigurationFactory, + ConfigurationStorage, + ExternalIntegration, + HasExternalIntegration, +) from tests.lcp import fixtures from tests.lcp.database_test import DatabaseTest @@ -25,49 +29,64 @@ def setup_method(self): self._lcp_collection = self._collection(protocol=ExternalIntegration.LCP) self._integration = self._lcp_collection.external_integration integration_owner = create_autospec(spec=HasExternalIntegration) - integration_owner.external_integration = MagicMock(return_value=self._integration) + integration_owner.external_integration = MagicMock( + return_value=self._integration + ) self._configuration_storage = ConfigurationStorage(integration_owner) self._configuration_factory = ConfigurationFactory() self._hasher_factory = HasherFactory() self._credential_factory = LCPCredentialFactory() self._lcp_server = LCPServer( - self._configuration_storage, self._configuration_factory, self._hasher_factory, self._credential_factory + self._configuration_storage, + self._configuration_factory, + self._hasher_factory, + self._credential_factory, ) @parameterized.expand( [ - ('empty_input_directory', ''), - ('non_empty_input_directory', '/tmp/encrypted_books') + ("empty_input_directory", ""), + ("non_empty_input_directory", "/tmp/encrypted_books"), ] ) def test_add_content(self, _, input_directory): # Arrange lcp_server = LCPServer( - self._configuration_storage, self._configuration_factory, self._hasher_factory, self._credential_factory) + self._configuration_storage, + self._configuration_factory, + self._hasher_factory, + self._credential_factory, + ) encrypted_content = LCPEncryptionResult( content_id=fixtures.CONTENT_ID, - content_encryption_key='12345', - protected_content_location='/opt/readium/files/encrypted', - protected_content_disposition='encrypted_book', - protected_content_type='application/epub+zip', + content_encryption_key="12345", + protected_content_location="/opt/readium/files/encrypted", + protected_content_disposition="encrypted_book", + protected_content_type="application/epub+zip", protected_content_length=12345, - protected_content_sha256='12345' + protected_content_sha256="12345", ) expected_protected_content_disposition = os.path.join( - input_directory, encrypted_content.protected_content_disposition) + input_directory, encrypted_content.protected_content_disposition + ) with self._configuration_factory.create( - self._configuration_storage, self._db, LCPServerConfiguration) as configuration: + self._configuration_storage, self._db, LCPServerConfiguration + ) as configuration: configuration.lcpserver_url = fixtures.LCPSERVER_URL configuration.lcpserver_user = fixtures.LCPSERVER_USER configuration.lcpserver_password = fixtures.LCPSERVER_PASSWORD configuration.lcpserver_input_directory = input_directory configuration.provider_name = fixtures.PROVIDER_NAME configuration.passphrase_hint = fixtures.TEXT_HINT - configuration.encryption_algorithm = LCPServerConfiguration.DEFAULT_ENCRYPTION_ALGORITHM + configuration.encryption_algorithm = ( + LCPServerConfiguration.DEFAULT_ENCRYPTION_ALGORITHM + ) with requests_mock.Mocker() as request_mock: - url = urllib.parse.urljoin(fixtures.LCPSERVER_URL, '/contents/{0}'.format(fixtures.CONTENT_ID)) + url = urllib.parse.urljoin( + fixtures.LCPSERVER_URL, "/contents/{0}".format(fixtures.CONTENT_ID) + ) request_mock.put(url) # Act @@ -77,142 +96,161 @@ def test_add_content(self, _, input_directory): assert request_mock.called == True json_request = json.loads(request_mock.last_request.text) - assert json_request['content-id'] == encrypted_content.content_id - assert json_request['content-encryption-key'] == encrypted_content.content_encryption_key - assert json_request['protected-content-location'] == expected_protected_content_disposition - assert json_request['protected-content-disposition'] == encrypted_content.protected_content_disposition - assert json_request['protected-content-type'] == encrypted_content.protected_content_type - assert json_request['protected-content-length'] == encrypted_content.protected_content_length - assert json_request['protected-content-sha256'] == encrypted_content.protected_content_sha256 - - @parameterized.expand([ - ('none_rights', None, None, None, None), - ( - 'license_start', + assert json_request["content-id"] == encrypted_content.content_id + assert ( + json_request["content-encryption-key"] + == encrypted_content.content_encryption_key + ) + assert ( + json_request["protected-content-location"] + == expected_protected_content_disposition + ) + assert ( + json_request["protected-content-disposition"] + == encrypted_content.protected_content_disposition + ) + assert ( + json_request["protected-content-type"] + == encrypted_content.protected_content_type + ) + assert ( + json_request["protected-content-length"] + == encrypted_content.protected_content_length + ) + assert ( + json_request["protected-content-sha256"] + == encrypted_content.protected_content_sha256 + ) + + @parameterized.expand( + [ + ("none_rights", None, None, None, None), + ( + "license_start", datetime.datetime(2020, 1, 1, 00, 00, 00), None, None, - None - ), - ( - 'license_end', - None, - datetime.datetime(2020, 12, 31, 23, 59, 59), - None, - None - ), - ( - 'max_printable_pages', - None, - None, - 10, - None - ), - ( - 'max_printable_pages_empty_max_copiable_pages', - None, - None, - 10, - '' - ), - ( - 'empty_max_printable_pages', - None, - None, - '', - None - ), - ( - 'max_copiable_pages', None, + ), + ( + "license_end", None, - None, - 1024 - ), - ( - 'empty_max_printable_pages_max_copiable_pages', - None, - None, - '', - 1024 - ), - ( - 'empty_max_copiable_pages', - None, + datetime.datetime(2020, 12, 31, 23, 59, 59), None, None, - '' - ), - ( - 'dates', + ), + ("max_printable_pages", None, None, 10, None), + ("max_printable_pages_empty_max_copiable_pages", None, None, 10, ""), + ("empty_max_printable_pages", None, None, "", None), + ("max_copiable_pages", None, None, None, 1024), + ("empty_max_printable_pages_max_copiable_pages", None, None, "", 1024), + ("empty_max_copiable_pages", None, None, None, ""), + ( + "dates", datetime.datetime(2020, 1, 1, 00, 00, 00), datetime.datetime(2020, 12, 31, 23, 59, 59), None, - None - ), - ( - 'full_rights', + None, + ), + ( + "full_rights", datetime.datetime(2020, 1, 1, 00, 00, 00), datetime.datetime(2020, 12, 31, 23, 59, 59), 10, - 1024 - ), - ]) - def test_generate_license(self, _, license_start, license_end, max_printable_pages, max_copiable_pages): + 1024, + ), + ] + ) + def test_generate_license( + self, _, license_start, license_end, max_printable_pages, max_copiable_pages + ): # Arrange patron = self._patron() - expected_patron_id = '52a190d1-cd69-4794-9d7a-1ec50392697f' - expected_patron_passphrase = '52a190d1-cd69-4794-9d7a-1ec50392697a' - expected_patron_key = self._hasher_factory \ - .create(LCPServerConfiguration.DEFAULT_ENCRYPTION_ALGORITHM) \ - .hash(expected_patron_passphrase) + expected_patron_id = "52a190d1-cd69-4794-9d7a-1ec50392697f" + expected_patron_passphrase = "52a190d1-cd69-4794-9d7a-1ec50392697a" + expected_patron_key = self._hasher_factory.create( + LCPServerConfiguration.DEFAULT_ENCRYPTION_ALGORITHM + ).hash(expected_patron_passphrase) with self._configuration_factory.create( - self._configuration_storage, self._db, LCPServerConfiguration) as configuration: + self._configuration_storage, self._db, LCPServerConfiguration + ) as configuration: configuration.lcpserver_url = fixtures.LCPSERVER_URL configuration.lcpserver_user = fixtures.LCPSERVER_USER configuration.lcpserver_password = fixtures.LCPSERVER_PASSWORD configuration.provider_name = fixtures.PROVIDER_NAME configuration.passphrase_hint = fixtures.TEXT_HINT - configuration.encryption_algorithm = LCPServerConfiguration.DEFAULT_ENCRYPTION_ALGORITHM + configuration.encryption_algorithm = ( + LCPServerConfiguration.DEFAULT_ENCRYPTION_ALGORITHM + ) configuration.max_printable_pages = max_printable_pages configuration.max_copiable_pages = max_copiable_pages - self._credential_factory.get_patron_id = MagicMock(return_value=expected_patron_id) - self._credential_factory.get_patron_passphrase = MagicMock(return_value=expected_patron_passphrase) + self._credential_factory.get_patron_id = MagicMock( + return_value=expected_patron_id + ) + self._credential_factory.get_patron_passphrase = MagicMock( + return_value=expected_patron_passphrase + ) with requests_mock.Mocker() as request_mock: - url = urllib.parse.urljoin(fixtures.LCPSERVER_URL, '/contents/{0}/license'.format(fixtures.CONTENT_ID)) + url = urllib.parse.urljoin( + fixtures.LCPSERVER_URL, + "/contents/{0}/license".format(fixtures.CONTENT_ID), + ) request_mock.post(url, json=fixtures.LCPSERVER_LICENSE) # Act license = self._lcp_server.generate_license( - self._db, fixtures.CONTENT_ID, patron, license_start, license_end) + self._db, fixtures.CONTENT_ID, patron, license_start, license_end + ) # Assert assert request_mock.called == True assert license == fixtures.LCPSERVER_LICENSE json_request = json.loads(request_mock.last_request.text) - assert json_request['provider'] == fixtures.PROVIDER_NAME - assert json_request['user']['id'] == expected_patron_id - assert json_request['encryption']['user_key']['text_hint'] == fixtures.TEXT_HINT - assert json_request['encryption']['user_key']['hex_value'] == expected_patron_key + assert json_request["provider"] == fixtures.PROVIDER_NAME + assert json_request["user"]["id"] == expected_patron_id + assert ( + json_request["encryption"]["user_key"]["text_hint"] + == fixtures.TEXT_HINT + ) + assert ( + json_request["encryption"]["user_key"]["hex_value"] + == expected_patron_key + ) if license_start is not None: - assert json_request['rights']['start'] == utils.format_datetime(license_start) + assert json_request["rights"]["start"] == utils.format_datetime( + license_start + ) if license_end is not None: - assert json_request['rights']['end'] == utils.format_datetime(license_end) - if max_printable_pages is not None and max_printable_pages != '': - assert json_request['rights']['print'] == max_printable_pages - if max_copiable_pages is not None and max_copiable_pages != '': - assert json_request['rights']['copy'] == max_copiable_pages + assert json_request["rights"]["end"] == utils.format_datetime( + license_end + ) + if max_printable_pages is not None and max_printable_pages != "": + assert json_request["rights"]["print"] == max_printable_pages + if max_copiable_pages is not None and max_copiable_pages != "": + assert json_request["rights"]["copy"] == max_copiable_pages all_rights_fields_are_empty = all( - [rights_field is None or rights_field == '' for rights_field in [license_start, license_end, max_printable_pages, max_copiable_pages]] + [ + rights_field is None or rights_field == "" + for rights_field in [ + license_start, + license_end, + max_printable_pages, + max_copiable_pages, + ] + ] ) if all_rights_fields_are_empty: - assert ('rights' in json_request) == False + assert ("rights" in json_request) == False - self._credential_factory.get_patron_id.assert_called_once_with(self._db, patron) - self._credential_factory.get_patron_passphrase.assert_called_once_with(self._db, patron) + self._credential_factory.get_patron_id.assert_called_once_with( + self._db, patron + ) + self._credential_factory.get_patron_passphrase.assert_called_once_with( + self._db, patron + ) diff --git a/tests/mock_authentication_provider.py b/tests/mock_authentication_provider.py index 02e3922d62..24649597a5 100644 --- a/tests/mock_authentication_provider.py +++ b/tests/mock_authentication_provider.py @@ -3,7 +3,16 @@ class MockExplodingAuthenticationProvider(BasicAuthenticationProvider): - def __init__(self, library, integration, analytics=None, patron=None, patrondata=None, *args, **kwargs): + def __init__( + self, + library, + integration, + analytics=None, + patron=None, + patrondata=None, + *args, + **kwargs + ): raise RemoteIntegrationException("Mock", "Mock exploded.") def authenticate(self, _db, header): diff --git a/tests/proquest/test_client.py b/tests/proquest/test_client.py index 64f3f1812a..bd2c102bf9 100644 --- a/tests/proquest/test_client.py +++ b/tests/proquest/test_client.py @@ -344,7 +344,8 @@ def test_get_book_correctly_extracts_acsm_books( """ download_link = "https://proquest.com/fulfill?documentID=12345" expected_acsm_book = ProQuestBook( - content=acsm_file_content.encode("utf-8"), content_type=DeliveryMechanism.ADOBE_DRM + content=acsm_file_content.encode("utf-8"), + content_type=DeliveryMechanism.ADOBE_DRM, ) first_response_arguments = { diff --git a/tests/proquest/test_importer.py b/tests/proquest/test_importer.py index d386aef6c1..15c3604619 100644 --- a/tests/proquest/test_importer.py +++ b/tests/proquest/test_importer.py @@ -53,7 +53,7 @@ ) from core.opds2_import import RWPMManifestParser from core.testing import DatabaseTest -from core.util.datetime_helpers import utc_now, datetime_utc +from core.util.datetime_helpers import datetime_utc, utc_now from tests.proquest import fixtures @@ -641,9 +641,7 @@ def test_fulfil_lookups_for_existing_token(self): # Arrange proquest_token = "1234567890" - proquest_token_expires_in = utc_now() + datetime.timedelta( - hours=1 - ) + proquest_token_expires_in = utc_now() + datetime.timedelta(hours=1) proquest_credential = Credential( credential=proquest_token, expires=proquest_token_expires_in ) @@ -956,16 +954,12 @@ def test_fulfil_refreshes_expired_token(self): # Arrange affiliation_id = "12345" expired_proquest_token = "1234567890" - expired_proquest_token_expired_in = ( - utc_now() - datetime.timedelta(minutes=1) - ) + expired_proquest_token_expired_in = utc_now() - datetime.timedelta(minutes=1) expired_proquest_token_credential = Credential( credential=expired_proquest_token, expires=expired_proquest_token_expired_in ) new_proquest_token = "1234567890_" - new_proquest_token_expires_in = utc_now() + datetime.timedelta( - hours=1 - ) + new_proquest_token_expires_in = utc_now() + datetime.timedelta(hours=1) new_proquest_token_credential = Credential( credential=new_proquest_token, expires=new_proquest_token_expires_in ) @@ -1176,7 +1170,9 @@ def test_monitor_correctly_processes_pages(self, _, feeds, expected_calls): ProQuestOPDS2Importer, RWPMManifestParser(OPDS2FeedParserFactory()), ) - monitor._get_feeds = MagicMock(return_value=list(zip([None] * len(feeds), feeds))) + monitor._get_feeds = MagicMock( + return_value=list(zip([None] * len(feeds), feeds)) + ) monitor.import_one_feed = MagicMock(return_value=([], [])) # Act @@ -1428,7 +1424,9 @@ def test_monitor_correctly_does_not_process_already_processed_pages(self): ProQuestOPDS2Importer, RWPMManifestParser(OPDS2FeedParserFactory()), ) - monitor._get_feeds = MagicMock(return_value=list(zip([None] * len(feeds), feeds))) + monitor._get_feeds = MagicMock( + return_value=list(zip([None] * len(feeds), feeds)) + ) monitor.import_one_feed = MagicMock(return_value=([], [])) # Act diff --git a/tests/saml/configuration/test_model.py b/tests/saml/configuration/test_model.py index 6036ec411e..666cc45256 100644 --- a/tests/saml/configuration/test_model.py +++ b/tests/saml/configuration/test_model.py @@ -140,10 +140,14 @@ def test_get_identity_providers_returns_non_federated_idps(self): # Assert assert 2 == len(identity_providers) - assert True == isinstance(identity_providers[0], SAMLIdentityProviderMetadata) + assert True == isinstance( + identity_providers[0], SAMLIdentityProviderMetadata + ) assert fixtures.IDP_1_ENTITY_ID == identity_providers[0].entity_id - assert True == isinstance(identity_providers[1], SAMLIdentityProviderMetadata) + assert True == isinstance( + identity_providers[1], SAMLIdentityProviderMetadata + ) assert fixtures.IDP_2_ENTITY_ID == identity_providers[1].entity_id configuration_storage.load.assert_has_calls( @@ -202,10 +206,14 @@ def test_get_identity_providers_returns_federated_idps(self): # Assert assert 2 == len(identity_providers) - assert True == isinstance(identity_providers[0], SAMLIdentityProviderMetadata) + assert True == isinstance( + identity_providers[0], SAMLIdentityProviderMetadata + ) assert fixtures.IDP_1_ENTITY_ID == identity_providers[0].entity_id - assert True == isinstance(identity_providers[1], SAMLIdentityProviderMetadata) + assert True == isinstance( + identity_providers[1], SAMLIdentityProviderMetadata + ) assert fixtures.IDP_2_ENTITY_ID == identity_providers[1].entity_id configuration_storage.load.assert_has_calls( @@ -273,16 +281,24 @@ def test_get_identity_providers_returns_both_non_federated_and_federated_idps(se # Assert assert 4 == len(identity_providers) - assert True == isinstance(identity_providers[0], SAMLIdentityProviderMetadata) + assert True == isinstance( + identity_providers[0], SAMLIdentityProviderMetadata + ) assert fixtures.IDP_1_ENTITY_ID == identity_providers[0].entity_id - assert True == isinstance(identity_providers[1], SAMLIdentityProviderMetadata) + assert True == isinstance( + identity_providers[1], SAMLIdentityProviderMetadata + ) assert fixtures.IDP_2_ENTITY_ID == identity_providers[1].entity_id - assert True == isinstance(identity_providers[2], SAMLIdentityProviderMetadata) + assert True == isinstance( + identity_providers[2], SAMLIdentityProviderMetadata + ) assert fixtures.IDP_1_ENTITY_ID == identity_providers[2].entity_id - assert True == isinstance(identity_providers[3], SAMLIdentityProviderMetadata) + assert True == isinstance( + identity_providers[3], SAMLIdentityProviderMetadata + ) assert fixtures.IDP_2_ENTITY_ID == identity_providers[3].entity_id configuration_storage.load.assert_has_calls( @@ -515,7 +531,7 @@ def test_get_settings_returns_correct_result(self): "wantAssertionsEncrypted": False, "nameIdEncrypted": False, "signatureAlgorithm": "http://www.w3.org/2000/09/xmldsig#rsa-sha1", - "allowRepeatAttributeName": False + "allowRepeatAttributeName": False, }, } db = create_autospec(spec=sqlalchemy.orm.session.Session) diff --git a/tests/saml/configuration/test_validator.py b/tests/saml/configuration/test_validator.py index b27b15fffb..011fc36cbd 100644 --- a/tests/saml/configuration/test_validator.py +++ b/tests/saml/configuration/test_validator.py @@ -107,7 +107,12 @@ def setup_class(cls): ] ) def test_validate( - self, _, sp_xml_metadata, idp_xml_metadata, patron_id_regular_expression, expected_validation_result + self, + _, + sp_xml_metadata, + idp_xml_metadata, + patron_id_regular_expression, + expected_validation_result, ): """Ensure that SAMLSettingsValidator correctly validates the input data. diff --git a/tests/saml/controller_test.py b/tests/saml/controller_test.py index b664de5003..ba2d94bcec 100644 --- a/tests/saml/controller_test.py +++ b/tests/saml/controller_test.py @@ -4,7 +4,6 @@ class ControllerTest(BaseControllerTest): - def setup_method(self): self._integration = None super(ControllerTest, self).setup_method() diff --git a/tests/saml/metadata/test_model.py b/tests/saml/metadata/test_model.py index 100ccbb533..3b35ac7586 100644 --- a/tests/saml/metadata/test_model.py +++ b/tests/saml/metadata/test_model.py @@ -27,18 +27,19 @@ def test_init_accepts_list_of_attributes(self): # Assert assert True == (SAMLAttributeType.uid.name in attribute_statement.attributes) assert ( - attributes[0].values == - attribute_statement.attributes[SAMLAttributeType.uid.name].values) + attributes[0].values + == attribute_statement.attributes[SAMLAttributeType.uid.name].values + ) + assert True == ( + SAMLAttributeType.eduPersonTargetedID.name in attribute_statement.attributes + ) assert ( - True == - (SAMLAttributeType.eduPersonTargetedID.name - in attribute_statement.attributes)) - assert ( - attributes[1].values == - attribute_statement.attributes[ + attributes[1].values + == attribute_statement.attributes[ SAMLAttributeType.eduPersonTargetedID.name - ].values) + ].values + ) class TestSAMLSubjectPatronIDExtractor(object): diff --git a/tests/saml/metadata/test_parser.py b/tests/saml/metadata/test_parser.py index 014f4b1fd4..7d1c65275f 100644 --- a/tests/saml/metadata/test_parser.py +++ b/tests/saml/metadata/test_parser.py @@ -77,8 +77,9 @@ def test_parse_does_not_raise_exception_when_xml_metadata_does_not_have_display_ [parsing_result] = parsing_results assert True == isinstance(parsing_result, SAMLMetadataParsingResult) assert True == isinstance(parsing_result.provider, SAMLIdentityProviderMetadata) - assert ( - True == isinstance(parsing_result.xml_node, onelogin.saml2.xmlparser.RestrictedElement)) + assert True == isinstance( + parsing_result.xml_node, onelogin.saml2.xmlparser.RestrictedElement + ) assert ( SAMLIdentityProviderMetadata( entity_id=fixtures.IDP_1_ENTITY_ID, @@ -95,8 +96,9 @@ def test_parse_does_not_raise_exception_when_xml_metadata_does_not_have_display_ encryption_certificates=[ fixtures.strip_certificate(fixtures.ENCRYPTION_CERTIFICATE) ], - ) == - parsing_result.provider) + ) + == parsing_result.provider + ) def test_parse_correctly_parses_one_idp_metadata(self): # Arrange @@ -111,8 +113,9 @@ def test_parse_correctly_parses_one_idp_metadata(self): [parsing_result] = parsing_results assert True == isinstance(parsing_result, SAMLMetadataParsingResult) assert True == isinstance(parsing_result.provider, SAMLIdentityProviderMetadata) - assert ( - True == isinstance(parsing_result.xml_node, onelogin.saml2.xmlparser.RestrictedElement)) + assert True == isinstance( + parsing_result.xml_node, onelogin.saml2.xmlparser.RestrictedElement + ) assert ( SAMLIdentityProviderMetadata( entity_id=fixtures.IDP_1_ENTITY_ID, @@ -181,8 +184,9 @@ def test_parse_correctly_parses_one_idp_metadata(self): encryption_certificates=[ fixtures.strip_certificate(fixtures.ENCRYPTION_CERTIFICATE) ], - ) == - parsing_result.provider) + ) + == parsing_result.provider + ) def test_parse_correctly_parses_idp_metadata_without_name_id_format(self): # Arrange @@ -197,8 +201,9 @@ def test_parse_correctly_parses_idp_metadata_without_name_id_format(self): [parsing_result] = parsing_results assert True == isinstance(parsing_result, SAMLMetadataParsingResult) assert True == isinstance(parsing_result.provider, SAMLIdentityProviderMetadata) - assert ( - True == isinstance(parsing_result.xml_node, onelogin.saml2.xmlparser.RestrictedElement)) + assert True == isinstance( + parsing_result.xml_node, onelogin.saml2.xmlparser.RestrictedElement + ) assert ( SAMLIdentityProviderMetadata( entity_id=fixtures.IDP_1_ENTITY_ID, @@ -267,8 +272,9 @@ def test_parse_correctly_parses_idp_metadata_without_name_id_format(self): encryption_certificates=[ fixtures.strip_certificate(fixtures.ENCRYPTION_CERTIFICATE) ], - ) == - parsing_result.provider) + ) + == parsing_result.provider + ) def test_parse_correctly_parses_idp_metadata_with_one_certificate(self): # Arrange @@ -285,8 +291,9 @@ def test_parse_correctly_parses_idp_metadata_with_one_certificate(self): assert True == isinstance(parsing_result, SAMLMetadataParsingResult) assert True == isinstance(parsing_result.provider, SAMLIdentityProviderMetadata) - assert ( - True == isinstance(parsing_result.xml_node, onelogin.saml2.xmlparser.RestrictedElement)) + assert True == isinstance( + parsing_result.xml_node, onelogin.saml2.xmlparser.RestrictedElement + ) assert ( SAMLIdentityProviderMetadata( entity_id=fixtures.IDP_1_ENTITY_ID, @@ -355,8 +362,9 @@ def test_parse_correctly_parses_idp_metadata_with_one_certificate(self): encryption_certificates=[ fixtures.strip_certificate(fixtures.SIGNING_CERTIFICATE) ], - ) == - parsing_result.provider) + ) + == parsing_result.provider + ) def test_parse_correctly_parses_metadata_with_multiple_descriptors(self): # Arrange @@ -368,10 +376,12 @@ def test_parse_correctly_parses_metadata_with_multiple_descriptors(self): # Assert assert 2 == len(parsing_results) assert True == isinstance(parsing_results[0], SAMLMetadataParsingResult) - assert True == isinstance(parsing_results[0].provider, SAMLIdentityProviderMetadata) - assert ( - True == - isinstance(parsing_results[0].xml_node, onelogin.saml2.xmlparser.RestrictedElement)) + assert True == isinstance( + parsing_results[0].provider, SAMLIdentityProviderMetadata + ) + assert True == isinstance( + parsing_results[0].xml_node, onelogin.saml2.xmlparser.RestrictedElement + ) assert ( SAMLIdentityProviderMetadata( entity_id=fixtures.IDP_1_ENTITY_ID, @@ -424,14 +434,17 @@ def test_parse_correctly_parses_metadata_with_multiple_descriptors(self): encryption_certificates=[ fixtures.strip_certificate(fixtures.ENCRYPTION_CERTIFICATE) ], - ) == - parsing_results[0].provider) + ) + == parsing_results[0].provider + ) assert True == isinstance(parsing_results[1], SAMLMetadataParsingResult) - assert True == isinstance(parsing_results[1].provider, SAMLIdentityProviderMetadata) - assert ( - True == - isinstance(parsing_results[1].xml_node, onelogin.saml2.xmlparser.RestrictedElement)) + assert True == isinstance( + parsing_results[1].provider, SAMLIdentityProviderMetadata + ) + assert True == isinstance( + parsing_results[1].xml_node, onelogin.saml2.xmlparser.RestrictedElement + ) assert ( SAMLIdentityProviderMetadata( entity_id=fixtures.IDP_2_ENTITY_ID, @@ -484,8 +497,9 @@ def test_parse_correctly_parses_metadata_with_multiple_descriptors(self): encryption_certificates=[ fixtures.strip_certificate(fixtures.ENCRYPTION_CERTIFICATE) ], - ) == - parsing_results[1].provider) + ) + == parsing_results[1].provider + ) def test_parse_raises_exception_when_sp_metadata_does_not_contain_acs_service(self): # Arrange @@ -510,8 +524,9 @@ def test_parse_correctly_parses_one_sp_metadata(self): [parsing_result] = parsing_results assert True == isinstance(parsing_result, SAMLMetadataParsingResult) assert True == isinstance(parsing_result.provider, SAMLServiceProviderMetadata) - assert ( - True == isinstance(parsing_result.xml_node, onelogin.saml2.xmlparser.RestrictedElement)) + assert True == isinstance( + parsing_result.xml_node, onelogin.saml2.xmlparser.RestrictedElement + ) assert ( SAMLServiceProviderMetadata( @@ -569,8 +584,9 @@ def test_parse_correctly_parses_one_sp_metadata(self): authn_requests_signed=False, want_assertions_signed=False, certificate=fixtures.strip_certificate(fixtures.SIGNING_CERTIFICATE), - ) == - parsing_result.provider) + ) + == parsing_result.provider + ) class TestSAMLSubjectParser(object): diff --git a/tests/saml/test_auth.py b/tests/saml/test_auth.py index 21ec617ad3..c09943068a 100644 --- a/tests/saml/test_auth.py +++ b/tests/saml/test_auth.py @@ -220,8 +220,9 @@ def test_start_authentication(self, _, service_provider, identity_providers): acs_binding = saml_request_dom.get("ProtocolBinding") assert ( - acs_binding == - SERVICE_PROVIDER_WITH_UNSIGNED_REQUESTS.acs_service.binding.value) + acs_binding + == SERVICE_PROVIDER_WITH_UNSIGNED_REQUESTS.acs_service.binding.value + ) sso_url = saml_request_dom.get("Destination") assert sso_url == IDENTITY_PROVIDERS[0].sso_service.url @@ -236,7 +237,9 @@ def test_start_authentication(self, _, service_provider, identity_providers): name_id_policy_node = name_id_policy_nodes[0] name_id_format = name_id_policy_node.get("Format") - assert name_id_format == SERVICE_PROVIDER_WITH_UNSIGNED_REQUESTS.name_id_format + assert ( + name_id_format == SERVICE_PROVIDER_WITH_UNSIGNED_REQUESTS.name_id_format + ) @parameterized.expand( [ diff --git a/tests/saml/test_controller.py b/tests/saml/test_controller.py index 046b890101..a802d25de2 100644 --- a/tests/saml/test_controller.py +++ b/tests/saml/test_controller.py @@ -224,8 +224,9 @@ def test_saml_authentication_redirect( assert result.response == expected_problem.response else: assert 302 == result.status_code - assert ( - expected_authentication_redirect_uri == result.headers.get("Location")) + assert expected_authentication_redirect_uri == result.headers.get( + "Location" + ) authentication_manager.start_authentication.assert_called_once_with( self._db, idp_entity_id, expected_relay_state @@ -427,7 +428,9 @@ def test_saml_authentication_callback( else: assert result.status_code == 302 assert ( - result.headers.get("Location") == expected_authentication_redirect_uri) + result.headers.get("Location") + == expected_authentication_redirect_uri + ) authentication_manager.finish_authentication.assert_called_once_with( self._db, IDENTITY_PROVIDERS[0].entity_id diff --git a/tests/sip/test_authentication_provider.py b/tests/sip/test_authentication_provider.py index 89fb1015db..0e15eb1f1d 100644 --- a/tests/sip/test_authentication_provider.py +++ b/tests/sip/test_authentication_provider.py @@ -1,18 +1,15 @@ +import json from datetime import datetime import pytest -from api.sip.client import ( - MockSIPClient, - MockSIPClientFactory, -) -from api.sip import SIP2AuthenticationProvider -from core.util.http import RemoteIntegrationException from api.authenticator import PatronData -import json +from api.sip import SIP2AuthenticationProvider +from api.sip.client import MockSIPClient, MockSIPClientFactory from core.config import CannotLoadConfiguration - from core.testing import DatabaseTest +from core.util.http import RemoteIntegrationException + class TestSIP2AuthenticationProvider(DatabaseTest): @@ -57,7 +54,7 @@ def test_initialize_from_integration(self): integration.setting(p.FIELD_SEPARATOR).value = "\t" integration.setting(p.INSTITUTION_ID).value = "MAIN" provider = p(self._default_library, integration) - + # A SIP2AuthenticationProvider was initialized based on the # integration values. assert "user1" == provider.login_user_id @@ -170,8 +167,7 @@ def test_remote_authenticate(self): assert "foo@bar.com" == patrondata.email_address assert 9.25 == patrondata.fines assert "Falk, Jen" == patrondata.personal_name - assert (datetime(2018, 6, 9, 23, 59, 59) == - patrondata.authorization_expires) + assert datetime(2018, 6, 9, 23, 59, 59) == patrondata.authorization_expires client.queue_response(self.polaris_wrong_pin) client.queue_response(self.end_session_response) @@ -181,8 +177,7 @@ def test_remote_authenticate(self): client.queue_response(self.polaris_expired_card) client.queue_response(self.end_session_response) patrondata = auth.remote_authenticate("user", "pass") - assert (datetime(2016, 10, 25, 23, 59, 59) == - patrondata.authorization_expires) + assert datetime(2016, 10, 25, 23, 59, 59) == patrondata.authorization_expires client.queue_response(self.polaris_excess_fines) client.queue_response(self.end_session_response) @@ -209,9 +204,7 @@ def test_remote_authenticate_no_password(self): integration = self._external_integration(self._str) p = SIP2AuthenticationProvider integration.setting(p.PASSWORD_KEYBOARD).value = p.NULL_KEYBOARD - auth = p( - self._default_library, integration, client=MockSIPClientFactory() - ) + auth = p(self._default_library, integration, client=MockSIPClientFactory()) client = auth._client # This Evergreen instance doesn't use passwords. client.queue_response(self.evergreen_active_user) @@ -230,8 +223,8 @@ def test_remote_authenticate_no_password(self): patrondata = auth.remote_authenticate("user2", "some password") assert "12345" == patrondata.authorization_identifier request = client.requests[-1] - assert b'user2' in request - assert b'some password' not in request + assert b"user2" in request + assert b"some password" not in request def test_encoding(self): # It's possible to specify an encoding other than CP850 @@ -241,9 +234,7 @@ def test_encoding(self): p = SIP2AuthenticationProvider integration = self._external_integration(self._str) integration.setting(p.ENCODING).value = "utf-8" - auth = p( - self._default_library, integration, client=MockSIPClientFactory() - ) + auth = p(self._default_library, integration, client=MockSIPClientFactory()) # Queue the UTF-8 version of the patron information # as opposed to the CP850 version. @@ -267,6 +258,7 @@ def test_ioerror_during_connect_becomes_remoteintegrationexception(self): """If the IP of the circulation manager has not been whitelisted, we generally can't even connect to the server. """ + class CannotConnect(MockSIPClient): def connect(self): raise IOError("Doom!") @@ -277,24 +269,31 @@ def connect(self): ) with pytest.raises(RemoteIntegrationException) as excinfo: - provider.remote_authenticate("username", "password",) + provider.remote_authenticate( + "username", + "password", + ) assert "Error accessing unknown server: Doom!" in str(excinfo.value) def test_ioerror_during_send_becomes_remoteintegrationexception(self): """If there's an IOError communicating with the server, it becomes a RemoteIntegrationException. """ + class CannotSend(MockSIPClient): def do_send(self, data): raise IOError("Doom!") integration = self._external_integration(self._str) - integration.url = 'server.local' + integration.url = "server.local" provider = SIP2AuthenticationProvider( self._default_library, integration, client=CannotSend ) with pytest.raises(RemoteIntegrationException) as excinfo: - provider.remote_authenticate("username", "password",) + provider.remote_authenticate( + "username", + "password", + ) assert "Error accessing server.local: Doom!" in str(excinfo.value) def test_parse_date(self): @@ -304,17 +303,20 @@ def test_parse_date(self): assert datetime(2011, 1, 2, 10, 20, 30) == parse("20110102UTC102030") def test_remote_patron_lookup(self): - #When the SIP authentication provider needs to look up a patron, - #it calls patron_information on its SIP client and passes in None - #for the password. + # When the SIP authentication provider needs to look up a patron, + # it calls patron_information on its SIP client and passes in None + # for the password. patron = self._patron() patron.authorization_identifier = "1234" integration = self._external_integration(self._str) + class Mock(MockSIPClient): def patron_information(self, identifier, password): self.patron_information = identifier self.password = password - return self.patron_information_parser(TestSIP2AuthenticationProvider.polaris_wrong_pin) + return self.patron_information_parser( + TestSIP2AuthenticationProvider.polaris_wrong_pin + ) client = Mock() client.queue_response(self.end_session_response) @@ -333,14 +335,16 @@ def patron_information(self, identifier, password): def test_info_to_patrondata_validate_password(self): integration = self._external_integration(self._str) - integration.url = 'server.local' + integration.url = "server.local" provider = SIP2AuthenticationProvider( self._default_library, integration, client=MockSIPClientFactory() ) client = provider._client # Test with valid login, should return PatronData - info = client.patron_information_parser(TestSIP2AuthenticationProvider.sierra_valid_login) + info = client.patron_information_parser( + TestSIP2AuthenticationProvider.sierra_valid_login + ) patron = provider.info_to_patrondata(info) assert patron.__class__ == PatronData assert "12345" == patron.authorization_identifier @@ -352,20 +356,24 @@ def test_info_to_patrondata_validate_password(self): assert PatronData.NO_VALUE == patron.block_reason # Test with invalid login, should return None - info = client.patron_information_parser(TestSIP2AuthenticationProvider.sierra_invalid_login) + info = client.patron_information_parser( + TestSIP2AuthenticationProvider.sierra_invalid_login + ) patron = provider.info_to_patrondata(info) assert None == patron def test_info_to_patrondata_no_validate_password(self): integration = self._external_integration(self._str) - integration.url = 'server.local' + integration.url = "server.local" provider = SIP2AuthenticationProvider( self._default_library, integration, client=MockSIPClientFactory() ) client = provider._client # Test with valid login, should return PatronData - info = client.patron_information_parser(TestSIP2AuthenticationProvider.sierra_valid_login) + info = client.patron_information_parser( + TestSIP2AuthenticationProvider.sierra_valid_login + ) patron = provider.info_to_patrondata(info, validate_password=False) assert patron.__class__ == PatronData assert "12345" == patron.authorization_identifier @@ -377,7 +385,9 @@ def test_info_to_patrondata_no_validate_password(self): assert PatronData.NO_VALUE == patron.block_reason # Test with invalid login, should return PatronData - info = client.patron_information_parser(TestSIP2AuthenticationProvider.sierra_invalid_login) + info = client.patron_information_parser( + TestSIP2AuthenticationProvider.sierra_invalid_login + ) patron = provider.info_to_patrondata(info, validate_password=False) assert patron.__class__ == PatronData assert "12345" == patron.authorization_identifier @@ -386,16 +396,25 @@ def test_info_to_patrondata_no_validate_password(self): assert 0 == patron.fines assert None == patron.authorization_expires assert None == patron.external_type - assert 'no borrowing privileges' == patron.block_reason + assert "no borrowing privileges" == patron.block_reason def test_patron_block_setting(self): - integration_block = self._external_integration(self._str, settings={SIP2AuthenticationProvider.PATRON_STATUS_BLOCK: "true"}) - integration_noblock = self._external_integration(self._str, settings={SIP2AuthenticationProvider.PATRON_STATUS_BLOCK: "false"}) + integration_block = self._external_integration( + self._str, settings={SIP2AuthenticationProvider.PATRON_STATUS_BLOCK: "true"} + ) + integration_noblock = self._external_integration( + self._str, + settings={SIP2AuthenticationProvider.PATRON_STATUS_BLOCK: "false"}, + ) # Test with blocked patron, block should be set - p = SIP2AuthenticationProvider(self._default_library, integration_block, client=MockSIPClientFactory()) + p = SIP2AuthenticationProvider( + self._default_library, integration_block, client=MockSIPClientFactory() + ) client = p._client - info = client.patron_information_parser(TestSIP2AuthenticationProvider.evergreen_expired_card) + info = client.patron_information_parser( + TestSIP2AuthenticationProvider.evergreen_expired_card + ) patron = p.info_to_patrondata(info) assert patron.__class__ == PatronData assert "12345" == patron.authorization_identifier @@ -407,11 +426,12 @@ def test_patron_block_setting(self): # Test with blocked patron, block should not be set p = SIP2AuthenticationProvider( - self._default_library, integration_noblock, - client=MockSIPClientFactory() + self._default_library, integration_noblock, client=MockSIPClientFactory() ) client = p._client - info = client.patron_information_parser(TestSIP2AuthenticationProvider.evergreen_expired_card) + info = client.patron_information_parser( + TestSIP2AuthenticationProvider.evergreen_expired_card + ) patron = p.info_to_patrondata(info) assert patron.__class__ == PatronData assert "12345" == patron.authorization_identifier @@ -433,11 +453,15 @@ def connect(self): class MockSIPLogin(MockSIPClient): def now(self): return datetime(2019, 1, 1).strftime("%Y%m%d0000%H%M%S") + def login(self): if not self.login_user_id and not self.login_password: raise IOError("Error logging in") + def patron_information(self, username, password): - return self.patron_information_parser(TestSIP2AuthenticationProvider.sierra_valid_login) + return self.patron_information_parser( + TestSIP2AuthenticationProvider.sierra_valid_login + ) auth = SIP2AuthenticationProvider( self._default_library, integration, client=MockBadConnection @@ -448,7 +472,7 @@ def patron_information(self, username, password): assert len(results) == 1 assert results[0].name == "Test Connection" assert results[0].success == False - assert(results[0].exception, IOError("Could not connect")) + assert (results[0].exception, IOError("Could not connect")) auth = SIP2AuthenticationProvider( self._default_library, integration, client=MockSIPLogin @@ -461,7 +485,7 @@ def patron_information(self, username, password): assert results[1].name == "Test Login with username 'None' and password 'None'" assert results[1].success == False - assert(results[1].exception, IOError("Error logging in")) + assert (results[1].exception, IOError("Error logging in")) # Set the log in username and password integration.username = "user1" @@ -476,17 +500,24 @@ def patron_information(self, username, password): assert results[0].name == "Test Connection" assert results[0].success == True - assert results[1].name == "Test Login with username 'user1' and password 'pass1'" + assert ( + results[1].name == "Test Login with username 'user1' and password 'pass1'" + ) assert results[1].success == True assert results[2].name == "Authenticating test patron" assert results[2].success == False - assert(results[2].exception, CannotLoadConfiguration("No test patron identifier is configured.")) - + assert ( + results[2].exception, + CannotLoadConfiguration("No test patron identifier is configured."), + ) # Now add the test patron credentials into the mocked client and SIP2 authenticator provider patronDataClient = MockSIPLogin(login_user_id="user1", login_password="pass1") - valid_login_patron = patronDataClient.patron_information_parser(TestSIP2AuthenticationProvider.sierra_valid_login) + valid_login_patron = patronDataClient.patron_information_parser( + TestSIP2AuthenticationProvider.sierra_valid_login + ) + class MockSIP2PatronInformation(SIP2AuthenticationProvider): def patron_information(self, username, password): return valid_login_patron @@ -503,7 +534,9 @@ def patron_information(self, username, password): assert results[0].name == "Test Connection" assert results[0].success == True - assert results[1].name == "Test Login with username 'user1' and password 'pass1'" + assert ( + results[1].name == "Test Login with username 'user1' and password 'pass1'" + ) assert results[1].success == True assert results[2].name == "Authenticating test patron" @@ -516,9 +549,10 @@ def patron_information(self, username, password): assert results[4].name == "Patron information request" assert results[4].success == True - assert results[4].result == patronDataClient.patron_information_request("usertest1", "userpassword1") + assert results[4].result == patronDataClient.patron_information_request( + "usertest1", "userpassword1" + ) assert results[5].name == "Raw test patron information" assert results[5].success == True assert results[5].result == json.dumps(valid_login_patron, indent=1) - diff --git a/tests/sip/test_client.py b/tests/sip/test_client.py index 5a968923c3..147f7b6824 100644 --- a/tests/sip/test_client.py +++ b/tests/sip/test_client.py @@ -1,20 +1,17 @@ """Standalone tests of the SIP2 client.""" -import pytest import os import socket import ssl -from api.sip.client import ( - MockSIPClient, - SIPClient, -) -from api.sip.dialect import ( - GenericILS, - AutoGraphicsVerso -) + +import pytest + +from api.sip.client import MockSIPClient, SIPClient +from api.sip.dialect import AutoGraphicsVerso, GenericILS + class MockSocket(object): def __init__(self, *args, **kwargs): - self.data = b'' + self.data = b"" self.args = args self.kwargs = kwargs self.timeout = None @@ -73,9 +70,7 @@ def test_secure_connect(self): target_server = object() insecure = SIPClient(target_server, 999, use_ssl=False) no_cert = SIPClient(target_server, 999, use_ssl=True) - with_cert = SIPClient( - target_server, 999, ssl_cert="cert", ssl_key="key" - ) + with_cert = SIPClient(target_server, 999, ssl_cert="cert", ssl_key="key") # Mock the socket.socket function. old_socket = socket.socket @@ -109,7 +104,7 @@ def test_secure_connect(self): with_cert.connect() connection, kwargs = wrap_socket.called_with assert isinstance(connection, MockSocket) - assert set(['keyfile', 'certfile']) == set(kwargs.keys()) + assert set(["keyfile", "certfile"]) == set(kwargs.keys()) for tmpfile in list(kwargs.values()): tmpfile = os.path.abspath(tmpfile) assert os.path.basename(tmpfile).startswith("tmp") @@ -140,10 +135,8 @@ def test_read_message(self): for data in ( # Simple message. b"abcd\n", - # Message that contains non-ASCII characters. "LE CARRÉ, JOHN\r".encode("cp850"), - # Message that spans multiple blocks. (b"a" * 4097) + b"\n", ): @@ -163,39 +156,39 @@ def test_read_message(self): # Un-mock the socket.socket function socket.socket = old_socket -class TestBasicProtocol(object): +class TestBasicProtocol(object): def test_login_message(self): sip = MockSIPClient() - message = sip.login_message('user_id', 'password') - assert '9300CNuser_id|COpassword' == message + message = sip.login_message("user_id", "password") + assert "9300CNuser_id|COpassword" == message def test_append_checksum(self): sip = MockSIPClient() - sip.sequence_number=7 + sip.sequence_number = 7 data = "some data" new_data = sip.append_checksum(data) assert "some data|AY7AZFAAA" == new_data def test_sequence_number_increment(self): - sip = MockSIPClient(login_user_id='user_id', login_password='password') - sip.sequence_number=0 - sip.queue_response('941') + sip = MockSIPClient(login_user_id="user_id", login_password="password") + sip.sequence_number = 0 + sip.queue_response("941") response = sip.login() assert 1 == sip.sequence_number # Test wraparound from 9 to 0 - sip.sequence_number=9 - sip.queue_response('941') + sip.sequence_number = 9 + sip.queue_response("941") response = sip.login() assert 0 == sip.sequence_number def test_resend(self): - sip = MockSIPClient(login_user_id='user_id', login_password='password') + sip = MockSIPClient(login_user_id="user_id", login_password="password") # The first response will be a request to resend the original message. - sip.queue_response('96') + sip.queue_response("96") # The second response will indicate a successful login. - sip.queue_response('941') + sip.queue_response("941") response = sip.login() @@ -203,24 +196,24 @@ def test_resend(self): req1, req2 = sip.requests # The first request includes a sequence ID field, "AY", with # the value "0". - assert b'9300CNuser_id|COpassword|AY0AZF556\r' == req1 + assert b"9300CNuser_id|COpassword|AY0AZF556\r" == req1 # The second request does not include a sequence ID field. As # a consequence its checksum is different. - assert b'9300CNuser_id|COpassword|AZF620\r' == req2 + assert b"9300CNuser_id|COpassword|AZF620\r" == req2 # The login request eventually succeeded. - assert {'login_ok': '1', '_status': '94'} == response + assert {"login_ok": "1", "_status": "94"} == response def test_maximum_resend(self): - sip = MockSIPClient(login_user_id='user_id', login_password='password') + sip = MockSIPClient(login_user_id="user_id", login_password="password") # We will keep sending retry messages until we reach the maximum - sip.queue_response('96') - sip.queue_response('96') - sip.queue_response('96') - sip.queue_response('96') - sip.queue_response('96') + sip.queue_response("96") + sip.queue_response("96") + sip.queue_response("96") + sip.queue_response("96") + sip.queue_response("96") # After reaching the maximum the client should give an IOError pytest.raises(IOError, sip.login) @@ -228,133 +221,149 @@ def test_maximum_resend(self): # We should send as many requests as we are allowed retries assert sip.MAXIMUM_RETRIES == len(sip.requests) -class TestLogin(object): +class TestLogin(object): def test_login_success(self): - sip = MockSIPClient(login_user_id='user_id', login_password='password') - sip.queue_response('941') + sip = MockSIPClient(login_user_id="user_id", login_password="password") + sip.queue_response("941") response = sip.login() - assert {'login_ok': '1', '_status': '94'} == response + assert {"login_ok": "1", "_status": "94"} == response def test_login_password_is_optional(self): """You can specify a login_id without specifying a login_password.""" - sip = MockSIPClient(login_user_id='user_id') - sip.queue_response('941') + sip = MockSIPClient(login_user_id="user_id") + sip.queue_response("941") response = sip.login() - assert {'login_ok': '1', '_status': '94'} == response + assert {"login_ok": "1", "_status": "94"} == response def test_login_failure(self): - sip = MockSIPClient(login_user_id='user_id', login_password='password') - sip.queue_response('940') + sip = MockSIPClient(login_user_id="user_id", login_password="password") + sip.queue_response("940") pytest.raises(IOError, sip.login) def test_login_happens_when_user_id_and_password_specified(self): - sip = MockSIPClient(login_user_id='user_id', login_password='password') + sip = MockSIPClient(login_user_id="user_id", login_password="password") # We're not logged in, and we must log in before sending a real # message. assert True == sip.must_log_in - sip.queue_response('941') - sip.queue_response('64Y 201610050000114734 AOnypl |AA12345|AENo Name|BLN|AFYour library card number cannot be located. Please see a staff member for assistance.|AY1AZC9DE') + sip.queue_response("941") + sip.queue_response( + "64Y 201610050000114734 AOnypl |AA12345|AENo Name|BLN|AFYour library card number cannot be located. Please see a staff member for assistance.|AY1AZC9DE" + ) sip.login() - response = sip.patron_information('patron_identifier') + response = sip.patron_information("patron_identifier") # Two requests were made. assert 2 == len(sip.requests) assert 2 == sip.sequence_number # We ended up with the right data. - assert '12345' == response['patron_identifier'] + assert "12345" == response["patron_identifier"] def test_no_login_when_user_id_and_password_not_specified(self): sip = MockSIPClient() assert False == sip.must_log_in - sip.queue_response('64Y 201610050000114734 AOnypl |AA12345|AENo Name|BLN|AFYour library card number cannot be located. Please see a staff member for assistance.|AY1AZC9DE') + sip.queue_response( + "64Y 201610050000114734 AOnypl |AA12345|AENo Name|BLN|AFYour library card number cannot be located. Please see a staff member for assistance.|AY1AZC9DE" + ) sip.login() # Zero requests made assert 0 == len(sip.requests) assert 0 == sip.sequence_number - response = sip.patron_information('patron_identifier') + response = sip.patron_information("patron_identifier") # One request made. assert 1 == len(sip.requests) assert 1 == sip.sequence_number # We ended up with the right data. - assert '12345' == response['patron_identifier'] + assert "12345" == response["patron_identifier"] def test_login_failure_interrupts_other_request(self): - sip = MockSIPClient(login_user_id='user_id', login_password='password') - sip.queue_response('940') + sip = MockSIPClient(login_user_id="user_id", login_password="password") + sip.queue_response("940") # We don't even get a chance to make the patron information request # because our login attempt fails. - pytest.raises(IOError, sip.patron_information, 'patron_identifier') + pytest.raises(IOError, sip.patron_information, "patron_identifier") - def test_login_does_not_happen_implicitly_when_user_id_and_password_not_specified(self): + def test_login_does_not_happen_implicitly_when_user_id_and_password_not_specified( + self, + ): sip = MockSIPClient() # We're implicitly logged in. assert False == sip.must_log_in - sip.queue_response('64Y 201610050000114734 AOnypl |AA12345|AENo Name|BLN|AFYour library card number cannot be located. Please see a staff member for assistance.|AY1AZC9DE') - response = sip.patron_information('patron_identifier') + sip.queue_response( + "64Y 201610050000114734 AOnypl |AA12345|AENo Name|BLN|AFYour library card number cannot be located. Please see a staff member for assistance.|AY1AZC9DE" + ) + response = sip.patron_information("patron_identifier") # One request was made. assert 1 == len(sip.requests) assert 1 == sip.sequence_number # We ended up with the right data. - assert '12345' == response['patron_identifier'] + assert "12345" == response["patron_identifier"] class TestPatronResponse(object): - def setup_method(self): self.sip = MockSIPClient() def test_incorrect_card_number(self): - self.sip.queue_response("64Y 201610050000114734 AOnypl |AA240|AENo Name|BLN|AFYour library card number cannot be located.|AY1AZC9DE") - response = self.sip.patron_information('identifier') + self.sip.queue_response( + "64Y 201610050000114734 AOnypl |AA240|AENo Name|BLN|AFYour library card number cannot be located.|AY1AZC9DE" + ) + response = self.sip.patron_information("identifier") # Test some of the basic fields. - assert response['institution_id'] == 'nypl ' - assert response['personal_name'] == 'No Name' - assert response['screen_message'] == ['Your library card number cannot be located.'] - assert response['valid_patron'] == 'N' - assert response['patron_status'] == 'Y ' - parsed = response['patron_status_parsed'] - assert True == parsed['charge privileges denied'] - assert False == parsed['too many items charged'] + assert response["institution_id"] == "nypl " + assert response["personal_name"] == "No Name" + assert response["screen_message"] == [ + "Your library card number cannot be located." + ] + assert response["valid_patron"] == "N" + assert response["patron_status"] == "Y " + parsed = response["patron_status_parsed"] + assert True == parsed["charge privileges denied"] + assert False == parsed["too many items charged"] def test_hold_items(self): "A patron has multiple items on hold." - self.sip.queue_response("64 000201610050000114837000300020002000000000000AOnypl |AA233|AEBAR, FOO|BZ0030|CA0050|CB0050|BLY|CQY|BV0|CC15.00|AS123|AS456|AS789|BEFOO@BAR.COM|AY1AZC848") - response = self.sip.patron_information('identifier') - assert '0003' == response['hold_items_count'] - assert ['123', '456', '789'] == response['hold_items'] + self.sip.queue_response( + "64 000201610050000114837000300020002000000000000AOnypl |AA233|AEBAR, FOO|BZ0030|CA0050|CB0050|BLY|CQY|BV0|CC15.00|AS123|AS456|AS789|BEFOO@BAR.COM|AY1AZC848" + ) + response = self.sip.patron_information("identifier") + assert "0003" == response["hold_items_count"] + assert ["123", "456", "789"] == response["hold_items"] def test_multiple_screen_messages(self): - self.sip.queue_response("64Y YYYYYYYYYYY000201610050000115040000000000000000000000000AOnypl |AA233|AESHELDON, ALICE|BZ0030|CA0050|CB0050|BLY|CQN|BV0|CC15.00|AFInvalid PIN entered. Please try again or see a staff member for assistance.|AFThere are unresolved issues with your account. Please see a staff member for assistance.|AY2AZ9B64") - response = self.sip.patron_information('identifier') - assert 2 == len(response['screen_message']) + self.sip.queue_response( + "64Y YYYYYYYYYYY000201610050000115040000000000000000000000000AOnypl |AA233|AESHELDON, ALICE|BZ0030|CA0050|CB0050|BLY|CQN|BV0|CC15.00|AFInvalid PIN entered. Please try again or see a staff member for assistance.|AFThere are unresolved issues with your account. Please see a staff member for assistance.|AY2AZ9B64" + ) + response = self.sip.patron_information("identifier") + assert 2 == len(response["screen_message"]) def test_extension_field_captured(self): - """This SIP2 message includes an extension field with the code XI. - """ - self.sip.queue_response("64 Y 00020161005 122942000000000000000000000000AA240|AEBooth Active Test|BHUSD|BDAdult Circ Desk 1 Newtown, CT USA 06470|AQNEWTWN|BLY|CQN|PA20191004|PCAdult|PIAllowed|XI86371|AOBiblioTest|ZZfoo|AY2AZ0000") - response = self.sip.patron_information('identifier') + """This SIP2 message includes an extension field with the code XI.""" + self.sip.queue_response( + "64 Y 00020161005 122942000000000000000000000000AA240|AEBooth Active Test|BHUSD|BDAdult Circ Desk 1 Newtown, CT USA 06470|AQNEWTWN|BLY|CQN|PA20191004|PCAdult|PIAllowed|XI86371|AOBiblioTest|ZZfoo|AY2AZ0000" + ) + response = self.sip.patron_information("identifier") # The Evergreen XI field is a known extension and is picked up # as sipserver_internal_id. - assert "86371" == response['sipserver_internal_id'] + assert "86371" == response["sipserver_internal_id"] # The ZZ field is an unknown extension and is captured under # its SIP code. - assert ["foo"] == response['ZZ'] + assert ["foo"] == response["ZZ"] def test_variant_encoding(self): response_unicode = "64 000201610210000142637000000000000000000000000AOnypl |AA12345|AELE CARRÉ, JOHN|BZ0030|CA0050|CB0050|BLY|CQY|BV0|CC15.00|BEfoo@example.com|AY1AZD1B7\r" @@ -363,46 +372,48 @@ def test_variant_encoding(self): # as CP850. assert "cp850" == self.sip.encoding self.sip.queue_response(response_unicode.encode("cp850")) - response = self.sip.patron_information('identifier') - assert "LE CARRÉ, JOHN" == response['personal_name'] + response = self.sip.patron_information("identifier") + assert "LE CARRÉ, JOHN" == response["personal_name"] # But a SIP2 server may send some other encoding, such as # UTF-8. This can cause odd results if the circulation manager # tries to parse the data as CP850. self.sip.queue_response(response_unicode.encode("utf-8")) - response = self.sip.patron_information('identifier') - assert "LE CARRÉ, JOHN" == response['personal_name'] + response = self.sip.patron_information("identifier") + assert "LE CARRÉ, JOHN" == response["personal_name"] # Giving SIPClient the right encoding means the data is # converted correctly. sip = MockSIPClient(encoding="utf-8") assert "utf-8" == sip.encoding sip.queue_response(response_unicode.encode("utf-8")) - response = sip.patron_information('identifier') - assert "LE CARRÉ, JOHN" == response['personal_name'] + response = sip.patron_information("identifier") + assert "LE CARRÉ, JOHN" == response["personal_name"] def test_embedded_pipe(self): """In most cases we can handle data even if it contains embedded instances of the separator character. """ - self.sip.queue_response('64 000201610050000134405000000000000000000000000AOnypl |AA12345|AERICHARDSON, LEONARD|BZ0030|CA0050|CB0050|BLY|CQY|BV0|CC15.00|BEleona|rdr@|bar.com|AY1AZD1BB\r') - response = self.sip.patron_information('identifier') - assert "leona|rdr@|bar.com" == response['email_address'] + self.sip.queue_response( + "64 000201610050000134405000000000000000000000000AOnypl |AA12345|AERICHARDSON, LEONARD|BZ0030|CA0050|CB0050|BLY|CQY|BV0|CC15.00|BEleona|rdr@|bar.com|AY1AZD1BB\r" + ) + response = self.sip.patron_information("identifier") + assert "leona|rdr@|bar.com" == response["email_address"] def test_different_separator(self): """When you create the SIPClient you get to specify which character to use as the field separator. """ - sip = MockSIPClient(separator='^') - sip.queue_response("64Y 201610050000114734 AOnypl ^AA240^AENo Name^BLN^AFYour library card number cannot be located.^AY1AZC9DE") - response = sip.patron_information('identifier') - assert '240' == response['patron_identifier'] + sip = MockSIPClient(separator="^") + sip.queue_response( + "64Y 201610050000114734 AOnypl ^AA240^AENo Name^BLN^AFYour library card number cannot be located.^AY1AZC9DE" + ) + response = sip.patron_information("identifier") + assert "240" == response["patron_identifier"] def test_location_code_is_optional(self): """You can specify a location_code when logging in, or not.""" - without_code = self.sip.login_message( - "login_id", "login_password" - ) + without_code = self.sip.login_message("login_id", "login_password") assert without_code.endswith("COlogin_password") with_code = self.sip.login_message( "login_id", "login_password", "location_code" @@ -413,27 +424,23 @@ def test_institution_id_field_is_always_provided(self): without_institution_arg = self.sip.patron_information_request( "patron_identifier", "patron_password" ) - assert without_institution_arg.startswith('AO|', 33) + assert without_institution_arg.startswith("AO|", 33) def test_institution_id_field_value_provided(self): # Fake value retrieved from DB - sip = MockSIPClient(institution_id='MAIN') + sip = MockSIPClient(institution_id="MAIN") with_institution_provided = sip.patron_information_request( "patron_identifier", "patron_password" ) - assert with_institution_provided.startswith('AOMAIN|', 33) + assert with_institution_provided.startswith("AOMAIN|", 33) def test_patron_password_is_optional(self): - without_password = self.sip.patron_information_request( - "patron_identifier" - ) - assert without_password.endswith('AApatron_identifier|AC') + without_password = self.sip.patron_information_request("patron_identifier") + assert without_password.endswith("AApatron_identifier|AC") with_password = self.sip.patron_information_request( "patron_identifier", "patron_password" ) - assert with_password.endswith( - 'AApatron_identifier|AC|ADpatron_password' - ) + assert with_password.endswith("AApatron_identifier|AC|ADpatron_password") def test_parse_patron_status(self): m = MockSIPClient.parse_patron_status @@ -442,43 +449,43 @@ def test_parse_patron_status(self): pytest.raises(ValueError, m, " " * 20) parsed = m("Y Y Y Y Y Y Y ") for yes in [ - 'charge privileges denied', - #'renewal privileges denied', - 'recall privileges denied', - #'hold privileges denied', - 'card reported lost', - #'too many items charged', - 'too many items overdue', - #'too many renewals', - 'too many claims of items returned', - #'too many items lost', - 'excessive outstanding fines', - #'excessive outstanding fees', - 'recall overdue', - #'too many items billed', + "charge privileges denied", + #'renewal privileges denied', + "recall privileges denied", + #'hold privileges denied', + "card reported lost", + #'too many items charged', + "too many items overdue", + #'too many renewals', + "too many claims of items returned", + #'too many items lost', + "excessive outstanding fines", + #'excessive outstanding fees', + "recall overdue", + #'too many items billed', ]: assert parsed[yes] == True for no in [ - #'charge privileges denied', - 'renewal privileges denied', - #'recall privileges denied', - 'hold privileges denied', - #'card reported lost', - 'too many items charged', - #'too many items overdue', - 'too many renewals', - #'too many claims of items returned', - 'too many items lost', - #'excessive outstanding fines', - 'excessive outstanding fees', - #'recall overdue', - 'too many items billed', + #'charge privileges denied', + "renewal privileges denied", + #'recall privileges denied', + "hold privileges denied", + #'card reported lost', + "too many items charged", + #'too many items overdue', + "too many renewals", + #'too many claims of items returned', + "too many items lost", + #'excessive outstanding fines', + "excessive outstanding fees", + #'recall overdue', + "too many items billed", ]: assert parsed[no] == False -class TestClientDialects(object): +class TestClientDialects(object): def setup_method(self): self.sip = MockSIPClient() @@ -486,13 +493,13 @@ def test_generic_dialect(self): # Generic ILS should send end_session message self.sip.dialect = GenericILS self.sip.queue_response("36Y201610210000142637AO3|AA25891000331441|AF|AG") - self.sip.end_session('username', 'password') + self.sip.end_session("username", "password") assert self.sip.read_count == 1 assert self.sip.write_count == 1 def test_ag_dialect(self): # AG VERSO ILS shouldn't end_session message self.sip.dialect = AutoGraphicsVerso - self.sip.end_session('username', 'password') + self.sip.end_session("username", "password") assert self.sip.read_count == 0 assert self.sip.write_count == 0 diff --git a/tests/test_adobe_vendor_id.py b/tests/test_adobe_vendor_id.py index 7dee01d094..65caaf3d0c 100644 --- a/tests/test_adobe_vendor_id.py +++ b/tests/test_adobe_vendor_id.py @@ -1,30 +1,26 @@ +import base64 +import datetime import json - -import pytest +import re import jwt -from jwt.exceptions import ( - DecodeError, - ExpiredSignatureError, - InvalidIssuedAtError -) -import re -import datetime +import pytest +from jwt.exceptions import DecodeError, ExpiredSignatureError, InvalidIssuedAtError -from api.problem_details import * from api.adobe_vendor_id import ( - AdobeSignInRequestParser, AdobeAccountInfoRequestParser, + AdobeSignInRequestParser, AdobeVendorIDController, - AdobeVendorIDRequestHandler, AdobeVendorIDModel, + AdobeVendorIDRequestHandler, AuthdataUtility, DeviceManagementRequestHandler, ) - +from api.config import CannotLoadConfiguration, Configuration, temp_config from api.opds import CirculationManagerAnnotator +from api.problem_details import * +from api.simple_authentication import SimpleAuthenticationProvider from api.testing import VendorIDTest - from core.model import ( ConfigurationSetting, Credential, @@ -33,20 +29,8 @@ ExternalIntegration, Library, ) -from core.util.datetime_helpers import ( - datetime_utc, - utc_now, -) +from core.util.datetime_helpers import datetime_utc, utc_now from core.util.problem_detail import ProblemDetail -import base64 - -from api.config import ( - CannotLoadConfiguration, - Configuration, - temp_config, -) - -from api.simple_authentication import SimpleAuthenticationProvider class TestVendorIDModel(VendorIDTest): @@ -60,15 +44,11 @@ def setup_method(self): self.vendor_id_library = self._default_library # This library can create Short Client Tokens that the Vendor # ID server will recognize. - self.short_client_token_library = self._library( - short_name="shortclienttoken" - ) + self.short_client_token_library = self._library(short_name="shortclienttoken") # Initialize the Adobe-specific ExternalIntegrations for both # libraries. - self.initialize_adobe( - self.vendor_id_library, [self.short_client_token_library] - ) + self.initialize_adobe(self.vendor_id_library, [self.short_client_token_library]) # Set up a simple authentication provider that validates # one specific patron. @@ -81,20 +61,20 @@ def setup_method(self): ) self.model = AdobeVendorIDModel( - self._db, self._default_library, self.authenticator, - self.TEST_NODE_VALUE + self._db, self._default_library, self.authenticator, self.TEST_NODE_VALUE ) self.data_source = DataSource.lookup(self._db, DataSource.ADOBE) self.bob_patron = self.authenticator.authenticated_patron( - self._db, dict(username="validpatron", password="password")) + self._db, dict(username="validpatron", password="password") + ) def test_uuid(self): u = self.model.uuid() # All UUIDs need to start with a 0 and end with the same node # value. - assert u.startswith('urn:uuid:0') - assert u.endswith('685b35c00f05') + assert u.startswith("urn:uuid:0") + assert u.endswith("685b35c00f05") def test_uuid_and_label_respects_existing_id(self): uuid, label = self.model.uuid_and_label(self.bob_patron) @@ -109,11 +89,17 @@ def test_uuid_and_label_creates_delegatedpatronid_from_credential(self): # reason, the migration script did not give them a # DelegatedPatronIdentifier. adobe = self.data_source + def set_value(credential): credential.credential = "A dummy value" + old_style_credential = Credential.lookup( - self._db, adobe, self.model.VENDOR_ID_UUID_TOKEN_TYPE, - self.bob_patron, set_value, True + self._db, + adobe, + self.model.VENDOR_ID_UUID_TOKEN_TYPE, + self.bob_patron, + set_value, + True, ) # Now uuid_and_label works. @@ -125,21 +111,25 @@ def set_value(credential): # patron account. internal = DataSource.lookup(self._db, DataSource.INTERNAL_PROCESSING) bob_anonymized_identifier = Credential.lookup( - self._db, internal, + self._db, + internal, AuthdataUtility.ADOBE_ACCOUNT_ID_PATRON_IDENTIFIER, - self.bob_patron, None + self.bob_patron, + None, ) # That anonymized identifier is associated with a # DelegatedPatronIdentifier whose delegated_identifier is # taken from the old-style Credential. - [bob_delegated_patron_identifier] = self._db.query( - DelegatedPatronIdentifier).filter( + [bob_delegated_patron_identifier] = ( + self._db.query(DelegatedPatronIdentifier) + .filter( DelegatedPatronIdentifier.patron_identifier - ==bob_anonymized_identifier.credential - ).all() - assert ("A dummy value" == - bob_delegated_patron_identifier.delegated_identifier) + == bob_anonymized_identifier.credential + ) + .all() + ) + assert "A dummy value" == bob_delegated_patron_identifier.delegated_identifier # If the DelegatedPatronIdentifier and the Credential # have different values, the DelegatedPatronIdentifier wins. @@ -155,15 +145,18 @@ def set_value(credential): uuid, label = self.model.uuid_and_label(self.bob_patron) assert "A dummy value" == uuid - def test_create_authdata(self): credential = self.model.create_authdata(self.bob_patron) # There's now a persistent token associated with Bob's # patron account, and that's the token returned by create_authdata() bob_authdata = Credential.lookup( - self._db, self.data_source, self.model.AUTHDATA_TOKEN_TYPE, - self.bob_patron, None) + self._db, + self.data_source, + self.model.AUTHDATA_TOKEN_TYPE, + self.bob_patron, + None, + ) assert credential.credential == bob_authdata.credential def test_to_delegated_patron_identifier_uuid(self): @@ -172,12 +165,12 @@ def test_to_delegated_patron_identifier_uuid(self): foreign_identifier = "foreign ID" # Pass in nothing and you get nothing. - assert ((None, None) == - self.model.to_delegated_patron_identifier_uuid(foreign_uri, None)) - assert ((None, None) == - self.model.to_delegated_patron_identifier_uuid( - None, foreign_identifier - )) + assert (None, None) == self.model.to_delegated_patron_identifier_uuid( + foreign_uri, None + ) + assert (None, None) == self.model.to_delegated_patron_identifier_uuid( + None, foreign_identifier + ) # Pass in a URI and identifier and you get a UUID and a label. uuid, label = self.model.to_delegated_patron_identifier_uuid( @@ -190,10 +183,12 @@ def test_to_delegated_patron_identifier_uuid(self): # And we can verify that a DelegatedPatronIdentifier was # created for the URI+identifier, and that it contains the # UUID. - [dpi] = self._db.query(DelegatedPatronIdentifier).filter( - DelegatedPatronIdentifier.library_uri==foreign_uri).filter( - DelegatedPatronIdentifier.patron_identifier==foreign_identifier - ).all() + [dpi] = ( + self._db.query(DelegatedPatronIdentifier) + .filter(DelegatedPatronIdentifier.library_uri == foreign_uri) + .filter(DelegatedPatronIdentifier.patron_identifier == foreign_identifier) + .all() + ) assert uuid == dpi.delegated_identifier def test_authdata_lookup_delegated_patron_identifier_success(self): @@ -213,8 +208,10 @@ def test_authdata_lookup_delegated_patron_identifier_success(self): # The Vendor ID library knows the secret it shares with the # other library -- initialize_adobe() took care of that. sct_library_uri = sct_library.setting(Configuration.WEBSITE_URL).value - assert ("%s token secret" % sct_library.short_name == - vendor_id_utility.secrets_by_library_uri[sct_library_uri]) + assert ( + "%s token secret" % sct_library.short_name + == vendor_id_utility.secrets_by_library_uri[sct_library_uri] + ) # Because this library shares the other library's secret, # it can decode a JWT issued by the other library, and @@ -230,10 +227,12 @@ def test_authdata_lookup_delegated_patron_identifier_success(self): # The UUID corresponds to a DelegatedPatronIdentifier, # associated with the foreign library and the patron # identifier that library encoded in its JWT. - [dpi] = self._db.query(DelegatedPatronIdentifier).filter( - DelegatedPatronIdentifier.library_uri==sct_library_uri).filter( - DelegatedPatronIdentifier.patron_identifier=="Foreign patron" - ).all() + [dpi] = ( + self._db.query(DelegatedPatronIdentifier) + .filter(DelegatedPatronIdentifier.library_uri == sct_library_uri) + .filter(DelegatedPatronIdentifier.patron_identifier == "Foreign patron") + .all() + ) assert uuid == dpi.delegated_identifier assert "Delegated account ID %s" % uuid == label @@ -256,24 +255,26 @@ def test_short_client_token_lookup_delegated_patron_identifier_success(self): # The Vendor ID library knows the secret it shares with the # other library -- initialize_adobe() took care of that. sct_library_url = sct_library.setting(Configuration.WEBSITE_URL).value - assert ("%s token secret" % sct_library.short_name == - vendor_id_utility.secrets_by_library_uri[sct_library_url]) + assert ( + "%s token secret" % sct_library.short_name + == vendor_id_utility.secrets_by_library_uri[sct_library_url] + ) # Because the Vendor ID library shares the Short Client Token # library's secret, it can decode a short client token issued # by that library, and issue an Adobe ID (UUID). token, signature = short_client_token.rsplit("|", 1) - uuid, label = self.model.short_client_token_lookup( - token, signature - ) + uuid, label = self.model.short_client_token_lookup(token, signature) # The UUID corresponds to a DelegatedPatronIdentifier, # associated with the foreign library and the patron # identifier that library encoded in its JWT. - [dpi] = self._db.query(DelegatedPatronIdentifier).filter( - DelegatedPatronIdentifier.library_uri==sct_library_url).filter( - DelegatedPatronIdentifier.patron_identifier=="Foreign patron" - ).all() + [dpi] = ( + self._db.query(DelegatedPatronIdentifier) + .filter(DelegatedPatronIdentifier.library_uri == sct_library_url) + .filter(DelegatedPatronIdentifier.patron_identifier == "Foreign patron") + .all() + ) assert uuid == dpi.delegated_identifier assert "Delegated account ID %s" % uuid == label @@ -287,9 +288,7 @@ def test_short_client_token_lookup_delegated_patron_identifier_success(self): assert new_label == label def test_short_client_token_lookup_delegated_patron_identifier_failure(self): - uuid, label = self.model.short_client_token_lookup( - "bad token", "bad signature" - ) + uuid, label = self.model.short_client_token_lookup("bad token", "bad signature") assert None == uuid assert None == label @@ -300,32 +299,36 @@ def test_username_password_lookup_success(self): # patron account. internal = DataSource.lookup(self._db, DataSource.INTERNAL_PROCESSING) bob_anonymized_identifier = Credential.lookup( - self._db, internal, + self._db, + internal, AuthdataUtility.ADOBE_ACCOUNT_ID_PATRON_IDENTIFIER, - self.bob_patron, None + self.bob_patron, + None, ) # That anonymized identifier is associated with a # DelegatedPatronIdentifier whose delegated_identifier is a # UUID. - [bob_delegated_patron_identifier] = self._db.query( - DelegatedPatronIdentifier).filter( + [bob_delegated_patron_identifier] = ( + self._db.query(DelegatedPatronIdentifier) + .filter( DelegatedPatronIdentifier.patron_identifier - ==bob_anonymized_identifier.credential - ).all() + == bob_anonymized_identifier.credential + ) + .all() + ) assert "Delegated account ID %s" % urn == label assert urn == bob_delegated_patron_identifier.delegated_identifier assert urn.startswith("urn:uuid:0") - assert urn.endswith('685b35c00f05') + assert urn.endswith("685b35c00f05") def test_authdata_token_credential_lookup_success(self): # Create an authdata token Credential for Bob. now = utc_now() token, ignore = Credential.persistent_token_create( - self._db, self.data_source, self.model.AUTHDATA_TOKEN_TYPE, - self.bob_patron + self._db, self.data_source, self.model.AUTHDATA_TOKEN_TYPE, self.bob_patron ) # The token is persistent. @@ -339,19 +342,24 @@ def test_authdata_token_credential_lookup_success(self): # patron account. internal = DataSource.lookup(self._db, DataSource.INTERNAL_PROCESSING) bob_anonymized_identifier = Credential.lookup( - self._db, internal, + self._db, + internal, AuthdataUtility.ADOBE_ACCOUNT_ID_PATRON_IDENTIFIER, - self.bob_patron, None + self.bob_patron, + None, ) # That anonymized identifier is associated with a # DelegatedPatronIdentifier whose delegated_identifier is a # UUID. - [bob_delegated_patron_identifier] = self._db.query( - DelegatedPatronIdentifier).filter( + [bob_delegated_patron_identifier] = ( + self._db.query(DelegatedPatronIdentifier) + .filter( DelegatedPatronIdentifier.patron_identifier - ==bob_anonymized_identifier.credential - ).all() + == bob_anonymized_identifier.credential + ) + .all() + ) # That UUID is the one returned by authdata_lookup. assert urn == bob_delegated_patron_identifier.delegated_identifier @@ -360,43 +368,43 @@ def test_smuggled_authdata_credential_success(self): # Bob's client has created a persistent token to authenticate him. now = utc_now() token, ignore = Credential.persistent_token_create( - self._db, self.data_source, self.model.AUTHDATA_TOKEN_TYPE, - self.bob_patron + self._db, self.data_source, self.model.AUTHDATA_TOKEN_TYPE, self.bob_patron ) # But Bob's client can't trigger the operation that will cause # Adobe to authenticate him via that token, so it passes in # the token credential as the 'username' and leaves the # password blank. - urn, label = self.model.standard_lookup( - dict(username=token.credential) - ) + urn, label = self.model.standard_lookup(dict(username=token.credential)) # There is now an anonymized identifier associated with Bob's # patron account. internal = DataSource.lookup(self._db, DataSource.INTERNAL_PROCESSING) bob_anonymized_identifier = Credential.lookup( - self._db, internal, + self._db, + internal, AuthdataUtility.ADOBE_ACCOUNT_ID_PATRON_IDENTIFIER, - self.bob_patron, None + self.bob_patron, + None, ) # That anonymized identifier is associated with a # DelegatedPatronIdentifier whose delegated_identifier is a # UUID. - [bob_delegated_patron_identifier] = self._db.query( - DelegatedPatronIdentifier).filter( + [bob_delegated_patron_identifier] = ( + self._db.query(DelegatedPatronIdentifier) + .filter( DelegatedPatronIdentifier.patron_identifier - ==bob_anonymized_identifier.credential - ).all() + == bob_anonymized_identifier.credential + ) + .all() + ) # That UUID is the one returned by standard_lookup. assert urn == bob_delegated_patron_identifier.delegated_identifier # A future attempt to authenticate with the token will succeed. - urn, label = self.model.standard_lookup( - dict(username=token.credential) - ) + urn, label = self.model.standard_lookup(dict(username=token.credential)) assert urn == bob_delegated_patron_identifier.delegated_identifier def test_authdata_lookup_failure_no_token(self): @@ -407,8 +415,7 @@ def test_authdata_lookup_failure_no_token(self): def test_authdata_lookup_failure_wrong_token(self): # Bob has an authdata token. token, ignore = Credential.persistent_token_create( - self._db, self.data_source, self.model.AUTHDATA_TOKEN_TYPE, - self.bob_patron + self._db, self.data_source, self.model.AUTHDATA_TOKEN_TYPE, self.bob_patron ) # But we look up a different token and get nothing. @@ -441,21 +448,28 @@ class TestVendorIDRequestParsers(object): def test_username_sign_in_request(self): parser = AdobeSignInRequestParser() data = parser.process(self.username_sign_in_request) - assert {'username': 'Vendor username', - 'password': 'Vendor password', 'method': 'standard'} == data + assert { + "username": "Vendor username", + "password": "Vendor password", + "method": "standard", + } == data def test_authdata_sign_in_request(self): parser = AdobeSignInRequestParser() data = parser.process(self.authdata_sign_in_request) - assert ({'authData': 'this data was base64 encoded', 'method': 'authData'} == - data) + assert { + "authData": "this data was base64 encoded", + "method": "authData", + } == data def test_accountinfo_request(self): parser = AdobeAccountInfoRequestParser() data = parser.process(self.accountinfo_request) - assert ({'method': 'standard', - 'user': 'urn:uuid:0xxxxxxx-xxxx-1xxx-xxxx-yyyyyyyyyyyy'} == - data) + assert { + "method": "standard", + "user": "urn:uuid:0xxxxxxx-xxxx-1xxx-xxxx-yyyyyyyyyyyy", + } == data + class TestVendorIDRequestHandler(object): @@ -476,25 +490,21 @@ class TestVendorIDRequestHandler(object): user1_uuid = "test-uuid" user1_label = "Human-readable label for user1" - username_password_lookup = { - ("user1", "pass1") : (user1_uuid, user1_label) - } + username_password_lookup = {("user1", "pass1"): (user1_uuid, user1_label)} - authdata_lookup = { - "The secret token" : (user1_uuid, user1_label) - } + authdata_lookup = {"The secret token": (user1_uuid, user1_label)} - userinfo_lookup = { user1_uuid : user1_label } + userinfo_lookup = {user1_uuid: user1_label} @property def _handler(self): - return AdobeVendorIDRequestHandler( - self.TEST_VENDOR_ID) + return AdobeVendorIDRequestHandler(self.TEST_VENDOR_ID) @classmethod def _standard_login(cls, data): return cls.username_password_lookup.get( - (data.get('username'), data.get('password')), (None, None)) + (data.get("username"), data.get("password")), (None, None) + ) @classmethod def _authdata_login(cls, authdata): @@ -505,84 +515,111 @@ def _userinfo(cls, uuid): return cls.userinfo_lookup.get(uuid) def test_error_document(self): - doc = self._handler.error_document( - "VENDORID", "Some random error") - assert '' == doc + doc = self._handler.error_document("VENDORID", "Some random error") + assert ( + '' + == doc + ) def test_handle_username_sign_in_request_success(self): - doc = self.username_sign_in_request % dict( - username="user1", password="pass1") + doc = self.username_sign_in_request % dict(username="user1", password="pass1") result = self._handler.handle_signin_request( - doc, self._standard_login, self._authdata_login) - assert result.startswith('\ntest-uuid\n\n') + doc, self._standard_login, self._authdata_login + ) + assert result.startswith( + '\ntest-uuid\n\n' + ) def test_handle_username_sign_in_request_failure(self): doc = self.username_sign_in_request % dict( - username="user1", password="wrongpass") + username="user1", password="wrongpass" + ) result = self._handler.handle_signin_request( - doc, self._standard_login, self._authdata_login) - assert '' == result + doc, self._standard_login, self._authdata_login + ) + assert ( + '' + == result + ) def test_handle_username_authdata_request_success(self): doc = self.authdata_sign_in_request % dict( - authdata=base64.b64encode(b"The secret token").decode("utf-8")) + authdata=base64.b64encode(b"The secret token").decode("utf-8") + ) result = self._handler.handle_signin_request( - doc, self._standard_login, self._authdata_login) - assert result.startswith('\ntest-uuid\n\n') + doc, self._standard_login, self._authdata_login + ) + assert result.startswith( + '\ntest-uuid\n\n' + ) def test_handle_username_authdata_request_invalid(self): - doc = self.authdata_sign_in_request % dict( - authdata="incorrect") + doc = self.authdata_sign_in_request % dict(authdata="incorrect") result = self._handler.handle_signin_request( - doc, self._standard_login, self._authdata_login) - assert result.startswith('' == result + doc, self._standard_login, self._authdata_login + ) + assert ( + '' + == result + ) def test_failure_send_login_request_to_accountinfo(self): doc = self.authdata_sign_in_request % dict( - authdata=base64.b64encode(b"incorrect")) - result = self._handler.handle_accountinfo_request( - doc, self._userinfo) - assert '' == result + authdata=base64.b64encode(b"incorrect") + ) + result = self._handler.handle_accountinfo_request(doc, self._userinfo) + assert ( + '' + == result + ) def test_failure_send_accountinfo_request_to_login(self): - doc = self.accountinfo_request % dict( - uuid=self.user1_uuid) + doc = self.accountinfo_request % dict(uuid=self.user1_uuid) result = self._handler.handle_signin_request( - doc, self._standard_login, self._authdata_login) - assert '' == result + doc, self._standard_login, self._authdata_login + ) + assert ( + '' + == result + ) def test_handle_accountinfo_success(self): - doc = self.accountinfo_request % dict( - uuid=self.user1_uuid) - result = self._handler.handle_accountinfo_request( - doc, self._userinfo) - assert '\n\n' == result + doc = self.accountinfo_request % dict(uuid=self.user1_uuid) + result = self._handler.handle_accountinfo_request(doc, self._userinfo) + assert ( + '\n\n' + == result + ) def test_handle_accountinfo_failure(self): - doc = self.accountinfo_request % dict( - uuid="not the uuid") - result = self._handler.handle_accountinfo_request( - doc, self._userinfo) - assert '' == result + doc = self.accountinfo_request % dict(uuid="not the uuid") + result = self._handler.handle_accountinfo_request(doc, self._userinfo) + assert ( + '' + == result + ) class TestAuthdataUtility(VendorIDTest): - def setup_method(self): super(TestAuthdataUtility, self).setup_method() self.authdata = AuthdataUtility( - vendor_id = "The Vendor ID", - library_uri = "http://my-library.org/", - library_short_name = "MyLibrary", - secret = "My library secret", - other_libraries = { + vendor_id="The Vendor ID", + library_uri="http://my-library.org/", + library_short_name="MyLibrary", + secret="My library secret", + other_libraries={ "http://your-library.org/": ("you", "Your library secret") }, ) @@ -597,27 +634,35 @@ def test_from_config(self): utility = AuthdataUtility.from_config(library) registry = ExternalIntegration.lookup( - self._db, ExternalIntegration.OPDS_REGISTRATION, - ExternalIntegration.DISCOVERY_GOAL, library=library + self._db, + ExternalIntegration.OPDS_REGISTRATION, + ExternalIntegration.DISCOVERY_GOAL, + library=library, + ) + assert ( + library.short_name + "token" + == ConfigurationSetting.for_library_and_externalintegration( + self._db, ExternalIntegration.USERNAME, library, registry + ).value + ) + assert ( + library.short_name + " token secret" + == ConfigurationSetting.for_library_and_externalintegration( + self._db, ExternalIntegration.PASSWORD, library, registry + ).value ) - assert (library.short_name + "token" == - ConfigurationSetting.for_library_and_externalintegration( - self._db, ExternalIntegration.USERNAME, library, registry).value) - assert (library.short_name + " token secret" == - ConfigurationSetting.for_library_and_externalintegration( - self._db, ExternalIntegration.PASSWORD, library, registry).value) assert self.TEST_VENDOR_ID == utility.vendor_id assert library_url == utility.library_uri - assert ( - {library2_url : "%s token secret" % library2.short_name, - library_url : "%s token secret" % library.short_name} == - utility.secrets_by_library_uri) + assert { + library2_url: "%s token secret" % library2.short_name, + library_url: "%s token secret" % library.short_name, + } == utility.secrets_by_library_uri - assert ( - {"%sTOKEN" % library.short_name.upper() : library_url, - "%sTOKEN" % library2.short_name.upper() : library2_url } == - utility.library_uris_by_short_name) + assert { + "%sTOKEN" % library.short_name.upper(): library_url, + "%sTOKEN" % library2.short_name.upper(): library2_url, + } == utility.library_uris_by_short_name # If the Library object is disconnected from its database # session, as may happen in production... @@ -627,66 +672,66 @@ def test_from_config(self): # will fail... with pytest.raises(ValueError) as excinfo: AuthdataUtility.from_config(library) - assert "No database connection provided and could not derive one from Library object!" in str(excinfo.value) + assert ( + "No database connection provided and could not derive one from Library object!" + in str(excinfo.value) + ) # ...unless a database session is provided in the constructor. authdata = AuthdataUtility.from_config(library, self._db) - assert ( - {"%sTOKEN" % library.short_name.upper() : library_url, - "%sTOKEN" % library2.short_name.upper() : library2_url } == - authdata.library_uris_by_short_name) + assert { + "%sTOKEN" % library.short_name.upper(): library_url, + "%sTOKEN" % library2.short_name.upper(): library2_url, + } == authdata.library_uris_by_short_name library = self._db.merge(library) self._db.commit() # If an integration is set up but incomplete, from_config # raises CannotLoadConfiguration. setting = ConfigurationSetting.for_library_and_externalintegration( - self._db, ExternalIntegration.USERNAME, library, registry) + self._db, ExternalIntegration.USERNAME, library, registry + ) old_short_name = setting.value setting.value = None - pytest.raises( - CannotLoadConfiguration, AuthdataUtility.from_config, - library - ) + pytest.raises(CannotLoadConfiguration, AuthdataUtility.from_config, library) setting.value = old_short_name setting = library.setting(Configuration.WEBSITE_URL) old_value = setting.value setting.value = None - pytest.raises( - CannotLoadConfiguration, AuthdataUtility.from_config, library - ) + pytest.raises(CannotLoadConfiguration, AuthdataUtility.from_config, library) setting.value = old_value setting = ConfigurationSetting.for_library_and_externalintegration( - self._db, ExternalIntegration.PASSWORD, library, registry) + self._db, ExternalIntegration.PASSWORD, library, registry + ) old_secret = setting.value setting.value = None - pytest.raises( - CannotLoadConfiguration, AuthdataUtility.from_config, library - ) + pytest.raises(CannotLoadConfiguration, AuthdataUtility.from_config, library) setting.value = old_secret # If other libraries are not configured, that's fine. We'll # only have a configuration for ourselves. - self.adobe_vendor_id.set_setting( - AuthdataUtility.OTHER_LIBRARIES_KEY, None - ) + self.adobe_vendor_id.set_setting(AuthdataUtility.OTHER_LIBRARIES_KEY, None) authdata = AuthdataUtility.from_config(library) - assert ({library_url : "%s token secret" % library.short_name} == - authdata.secrets_by_library_uri) - assert ({"%sTOKEN" % library.short_name.upper(): library_url} == - authdata.library_uris_by_short_name) + assert { + library_url: "%s token secret" % library.short_name + } == authdata.secrets_by_library_uri + assert { + "%sTOKEN" % library.short_name.upper(): library_url + } == authdata.library_uris_by_short_name # Short library names are case-insensitive. If the # configuration has the same library short name twice, you # can't create an AuthdataUtility. self.adobe_vendor_id.set_setting( AuthdataUtility.OTHER_LIBRARIES_KEY, - json.dumps({ - "http://a/" : ("a", "secret1"), - "http://b/" : ("A", "secret2"), - }) + json.dumps( + { + "http://a/": ("a", "secret1"), + "http://b/": ("A", "secret2"), + } + ), ) pytest.raises(ValueError, AuthdataUtility.from_config, library) @@ -699,12 +744,15 @@ def test_short_client_token_for_patron(self): class MockAuthdataUtility(AuthdataUtility): def __init__(self): pass + def encode_short_client_token(self, patron_identifier): self.encode_sct_called_with = patron_identifier return "a", "b" + def _adobe_patron_identifier(self, patron_information): self.patron_identifier_called_with = patron_information return "patron identifier" + # A patron is passed in; we get their identifier for Adobe ID purposes, # and generate a short client token based on it patron = self._patron() @@ -749,17 +797,20 @@ def test_encode(self): self.authdata.library_uri, patron_identifier, now, expires ) assert ( - base64.encodebytes(b'eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiJ9.eyJpc3MiOiJodHRwOi8vbXktbGlicmFyeS5vcmcvIiwic3ViIjoiUGF0cm9uIGlkZW50aWZpZXIiLCJpYXQiOjE0NTE2NDk2MDAuMCwiZXhwIjoxNTE0ODA4MDAwLjB9.Ua11tFCpC4XAgwhR6jFyoxfHy4s1zt2Owg4dOoCefYA') == - authdata) + base64.encodebytes( + b"eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiJ9.eyJpc3MiOiJodHRwOi8vbXktbGlicmFyeS5vcmcvIiwic3ViIjoiUGF0cm9uIGlkZW50aWZpZXIiLCJpYXQiOjE0NTE2NDk2MDAuMCwiZXhwIjoxNTE0ODA4MDAwLjB9.Ua11tFCpC4XAgwhR6jFyoxfHy4s1zt2Owg4dOoCefYA" + ) + == authdata + ) def test_decode_from_another_library(self): # Here's the AuthdataUtility used by another library. foreign_authdata = AuthdataUtility( - vendor_id = "The Vendor ID", - library_uri = "http://your-library.org/", - library_short_name = "you", - secret = "Your library secret", + vendor_id="The Vendor ID", + library_uri="http://your-library.org/", + library_short_name="you", + secret="Your library secret", ) patron_identifier = "Patron identifier" @@ -772,7 +823,7 @@ def test_decode_from_another_library(self): # If our secret doesn't match the other library's secret, # we can't decode the authdata - foreign_authdata.secret = 'A new secret' + foreign_authdata.secret = "A new secret" vendor_id, authdata = foreign_authdata.encode(patron_identifier) with pytest.raises(DecodeError) as excinfo: self.authdata.decode(authdata) @@ -783,10 +834,10 @@ def test_decode_from_unknown_library_fails(self): # Here's the AuthdataUtility used by a library we don't know # about. foreign_authdata = AuthdataUtility( - vendor_id = "The Vendor ID", - library_uri = "http://some-other-library.org/", - library_short_name = "SomeOther", - secret = "Some other library secret", + vendor_id="The Vendor ID", + library_uri="http://some-other-library.org/", + library_short_name="SomeOther", + secret="Some other library secret", ) vendor_id, authdata = foreign_authdata.encode("A patron") # They can encode, but we cna't decode. @@ -796,21 +847,13 @@ def test_decode_from_unknown_library_fails(self): def test_cannot_decode_token_from_future(self): future = utc_now() + datetime.timedelta(days=365) - authdata = self.authdata._encode( - "Patron identifier", iat=future - ) - pytest.raises( - InvalidIssuedAtError, self.authdata.decode, authdata - ) + authdata = self.authdata._encode("Patron identifier", iat=future) + pytest.raises(InvalidIssuedAtError, self.authdata.decode, authdata) def test_cannot_decode_expired_token(self): expires = datetime_utc(2016, 1, 1, 12, 0, 0) - authdata = self.authdata._encode( - "Patron identifier", exp=expires - ) - pytest.raises( - ExpiredSignatureError, self.authdata.decode, authdata - ) + authdata = self.authdata._encode("Patron identifier", exp=expires) + pytest.raises(ExpiredSignatureError, self.authdata.decode, authdata) def test_cannot_encode_null_patron_identifier(self): with pytest.raises(ValueError) as excinfo: @@ -820,7 +863,8 @@ def test_cannot_encode_null_patron_identifier(self): def test_cannot_decode_null_patron_identifier(self): authdata = self.authdata._encode( - self.authdata.library_uri, None, + self.authdata.library_uri, + None, ) with pytest.raises(DecodeError) as excinfo: self.authdata.decode(authdata) @@ -847,8 +891,10 @@ def test_short_client_token_encode_known_value(self): # what would otherwise be normal base64 text. Similarly for # the semicolon which replaced the slash, and the at sign which # replaced the equals sign. - assert ('a library|1234.5|a patron identifier|YoNGn7f38mF531KSWJ;o1H0Z3chbC:uTE:t7pAwqYxM@' == - value) + assert ( + "a library|1234.5|a patron identifier|YoNGn7f38mF531KSWJ;o1H0Z3chbC:uTE:t7pAwqYxM@" + == value + ) # Dissect the known value to show how it works. token, signature = value.rsplit("|", 1) @@ -870,10 +916,10 @@ def test_short_client_token_encode_known_value(self): def test_encode_short_client_token_expiry(self, monkeypatch): authdata = AuthdataUtility( - vendor_id = "The Vendor ID", - library_uri = "http://your-library.org/", - library_short_name = "you", - secret = "Your library secret", + vendor_id="The Vendor ID", + library_uri="http://your-library.org/", + library_short_name="you", + secret="Your library secret", ) test_date = datetime_utc(2021, 5, 5) monkeypatch.setattr(authdata, "_now", lambda: test_date) @@ -883,33 +929,37 @@ def test_encode_short_client_token_expiry(self, monkeypatch): # Test with no expiry set vendor_id, token = authdata.encode_short_client_token(patron_identifier) - assert token.split('|')[0:-1] == ['YOU', '1620176400', 'Patron identifier'] + assert token.split("|")[0:-1] == ["YOU", "1620176400", "Patron identifier"] # Test with expiry set to 20 min - vendor_id, token = authdata.encode_short_client_token(patron_identifier, {'minutes': 20}) - assert token.split('|')[0:-1] == ['YOU', '1620174000', 'Patron identifier'] + vendor_id, token = authdata.encode_short_client_token( + patron_identifier, {"minutes": 20} + ) + assert token.split("|")[0:-1] == ["YOU", "1620174000", "Patron identifier"] # Test with expiry set to 2 days - vendor_id, token = authdata.encode_short_client_token(patron_identifier, {'days': 2}) - assert token.split('|')[0:-1] == ['YOU', '1620345600', 'Patron identifier'] + vendor_id, token = authdata.encode_short_client_token( + patron_identifier, {"days": 2} + ) + assert token.split("|")[0:-1] == ["YOU", "1620345600", "Patron identifier"] # Test with expiry set to 4 hours - vendor_id, token = authdata.encode_short_client_token(patron_identifier, {'hours': 4}) - assert token.split('|')[0:-1] == ['YOU', '1620187200', 'Patron identifier'] + vendor_id, token = authdata.encode_short_client_token( + patron_identifier, {"hours": 4} + ) + assert token.split("|")[0:-1] == ["YOU", "1620187200", "Patron identifier"] def test_decode_short_client_token_from_another_library(self): # Here's the AuthdataUtility used by another library. foreign_authdata = AuthdataUtility( - vendor_id = "The Vendor ID", - library_uri = "http://your-library.org/", - library_short_name = "you", - secret = "Your library secret", + vendor_id="The Vendor ID", + library_uri="http://your-library.org/", + library_short_name="you", + secret="Your library secret", ) patron_identifier = "Patron identifier" - vendor_id, token = foreign_authdata.encode_short_client_token( - patron_identifier - ) + vendor_id, token = foreign_authdata.encode_short_client_token(patron_identifier) # Because we know the other library's secret, we're able to # decode the authdata. @@ -919,10 +969,8 @@ def test_decode_short_client_token_from_another_library(self): # If our secret for a library doesn't match the other # library's short token signing key, we can't decode the # authdata. - foreign_authdata.short_token_signing_key = b'A new secret' - vendor_id, token = foreign_authdata.encode_short_client_token( - patron_identifier - ) + foreign_authdata.short_token_signing_key = b"A new secret" + vendor_id, token = foreign_authdata.encode_short_client_token(patron_identifier) with pytest.raises(ValueError) as excinfo: self.authdata.decode_short_client_token(token) assert "Invalid signature for" in str(excinfo.value) @@ -944,28 +992,34 @@ def test_decode_client_token_errors(self): # The patron identifier must not be blank. with pytest.raises(ValueError) as excinfo: m("library|1234|", "signature") - assert 'Token library|1234| has empty patron identifier' in str(excinfo.value) + assert "Token library|1234| has empty patron identifier" in str(excinfo.value) # The library must be a known one. with pytest.raises(ValueError) as excinfo: m("library|1234|patron", "signature") - assert 'I don\'t know how to handle tokens from library "LIBRARY"' in str(excinfo.value) + assert 'I don\'t know how to handle tokens from library "LIBRARY"' in str( + excinfo.value + ) # We must have the shared secret for the given library. - self.authdata.library_uris_by_short_name['LIBRARY'] = 'http://a-library.com/' + self.authdata.library_uris_by_short_name["LIBRARY"] = "http://a-library.com/" with pytest.raises(ValueError) as excinfo: m("library|1234|patron", "signature") - assert 'I don\'t know the secret for library http://a-library.com/' in str(excinfo.value) + assert "I don't know the secret for library http://a-library.com/" in str( + excinfo.value + ) # The token must not have expired. with pytest.raises(ValueError) as excinfo: m("mylibrary|1234|patron", "signature") - assert 'Token mylibrary|1234|patron expired at 1970-01-01 00:20:34' in str(excinfo.value) + assert "Token mylibrary|1234|patron expired at 1970-01-01 00:20:34" in str( + excinfo.value + ) # Finally, the signature must be valid. with pytest.raises(ValueError) as excinfo: m("mylibrary|99999999999|patron", "signature") - assert 'Invalid signature for' in str(excinfo.value) + assert "Invalid signature for" in str(excinfo.value) def test_adobe_base64_encode_decode(self): # Test our special variant of base64 encoding designed to avoid @@ -973,15 +1027,15 @@ def test_adobe_base64_encode_decode(self): value = "!\tFN6~'Es52?X!#)Z*_S" encoded = AuthdataUtility.adobe_base64_encode(value) - assert 'IQlGTjZ:J0VzNTI;WCEjKVoqX1M@' == encoded + assert "IQlGTjZ:J0VzNTI;WCEjKVoqX1M@" == encoded # This is like normal base64 encoding, but with a colon # replacing the plus character, a semicolon replacing the # slash, an at sign replacing the equal sign and the final # newline stripped. - assert ( - encoded.replace(":", "+").replace(";", "/").replace("@", "=") + "\n" == - base64.encodebytes(value.encode("utf-8")).decode("utf-8")) + assert encoded.replace(":", "+").replace(";", "/").replace( + "@", "=" + ) + "\n" == base64.encodebytes(value.encode("utf-8")).decode("utf-8") # We can reverse the encoding to get the original value. assert value == AuthdataUtility.adobe_base64_decode(encoded).decode("utf-8") @@ -993,22 +1047,23 @@ def sign(self, value, key): plus sign, a slash and an equal sign when base64-encoded. """ return "!\tFN6~'Es52?X!#)Z*_S" + self.authdata.short_token_signer = MockSigner() token = self.authdata._encode_short_client_token("lib", "1234", 0) # The signature part of the token has been encoded with our # custom encoding, not vanilla base64. - assert 'lib|0|1234|IQlGTjZ:J0VzNTI;WCEjKVoqX1M@' == token + assert "lib|0|1234|IQlGTjZ:J0VzNTI;WCEjKVoqX1M@" == token def test_decode_two_part_short_client_token_uses_adobe_base64_encoding(self): # The base64 encoding of this signature has a plus sign in it. - signature = 'LbU}66%\\-4zt>R>_)\n2Q' + signature = "LbU}66%\\-4zt>R>_)\n2Q" encoded_signature = AuthdataUtility.adobe_base64_encode(signature) # We replace the plus sign with a colon. - assert ':' in encoded_signature - assert '+' not in encoded_signature + assert ":" in encoded_signature + assert "+" not in encoded_signature # Make sure that decode_two_part_short_client_token properly # reverses that change when decoding the 'password'. @@ -1017,22 +1072,19 @@ def _decode_short_client_token(self, token, supposed_signature): assert supposed_signature.decode("utf-8") == signature self.test_code_ran = True - utility = MockAuthdataUtility( - vendor_id = "The Vendor ID", - library_uri = "http://your-library.org/", - library_short_name = "you", - secret = "Your library secret", + utility = MockAuthdataUtility( + vendor_id="The Vendor ID", + library_uri="http://your-library.org/", + library_short_name="you", + secret="Your library secret", ) utility.test_code_ran = False - utility.decode_two_part_short_client_token( - "username", encoded_signature - ) + utility.decode_two_part_short_client_token("username", encoded_signature) # The code in _decode_short_client_token ran. Since there was no # test failure, it ran successfully. assert True == utility.test_code_ran - # Tests of code that is used only in a migration script. This can # be deleted once # 20161102-adobe-id-is-delegated-patron-identifier.py is run on @@ -1047,14 +1099,16 @@ def test_migrate_adobe_id_noop(self): def test_migrate_adobe_id_success(self): from api.opds import CirculationManagerAnnotator + patron = self._patron() # This patron has a Credential containing their Adobe ID data_source = DataSource.lookup(self._db, DataSource.ADOBE) adobe_id = Credential( - patron=patron, data_source=data_source, + patron=patron, + data_source=data_source, type=AdobeVendorIDModel.VENDOR_ID_UUID_TOKEN_TYPE, - credential="My Adobe ID" + credential="My Adobe ID", ) # Run the migration. @@ -1066,8 +1120,7 @@ def test_migrate_adobe_id_success(self): # The new credential contains an anonymized patron identifier # used solely to connect the patron to their Adobe ID. - assert (AuthdataUtility.ADOBE_ACCOUNT_ID_PATRON_IDENTIFIER == - new_credential.type) + assert AuthdataUtility.ADOBE_ACCOUNT_ID_PATRON_IDENTIFIER == new_credential.type # We can use that identifier to look up a DelegatedPatronIdentifier # @@ -1075,9 +1128,13 @@ def explode(): # This method won't be called because the # DelegatedPatronIdentifier already exists. raise Exception() + identifier, is_new = DelegatedPatronIdentifier.get_one_or_create( - self._db, self.authdata.library_uri, new_credential.credential, - DelegatedPatronIdentifier.ADOBE_ACCOUNT_ID, explode + self._db, + self.authdata.library_uri, + new_credential.credential, + DelegatedPatronIdentifier.ADOBE_ACCOUNT_ID, + explode, ) assert delegated_identifier == identifier assert False == is_new @@ -1091,10 +1148,12 @@ def explode(): self.authdata.library_uri, new_credential.credential ) assert "My Adobe ID" == uuid - assert 'Delegated account ID My Adobe ID' == label + assert "Delegated account ID My Adobe ID" == label # If we run the migration again, nothing new happens. - new_credential_2, delegated_identifier_2 = self.authdata.migrate_adobe_id(patron) + new_credential_2, delegated_identifier_2 = self.authdata.migrate_adobe_id( + patron + ) assert new_credential == new_credential_2 assert delegated_identifier == delegated_identifier_2 assert 2 == len(patron.credentials) @@ -1102,18 +1161,17 @@ def explode(): self.authdata.library_uri, new_credential.credential ) assert "My Adobe ID" == uuid - assert 'Delegated account ID My Adobe ID' == label + assert "Delegated account ID My Adobe ID" == label class TestDeviceManagementRequestHandler(VendorIDTest): - def test_register_drm_device_identifier(self): credential = self._credential() handler = DeviceManagementRequestHandler(credential) handler.register_device("device1") - assert ( - ['device1'] == - [x.device_identifier for x in credential.drm_device_identifiers]) + assert ["device1"] == [ + x.device_identifier for x in credential.drm_device_identifiers + ] def test_register_drm_device_identifier_does_nothing_on_no_input(self): credential = self._credential() @@ -1154,12 +1212,14 @@ def test_device_list(self): class TestAdobeVendorIDController(VendorIDTest): - def test_create_authdata_handler(self): controller = AdobeVendorIDController( - self._db, self._default_library, self.TEST_VENDOR_ID, - self.TEST_NODE_VALUE, object() + self._db, + self._default_library, + self.TEST_VENDOR_ID, + self.TEST_NODE_VALUE, + object(), ) patron = self._patron() response = controller.create_authdata_handler(patron) diff --git a/tests/test_annotations.py b/tests/test_annotations.py index 1a300ce0d6..d64259491c 100644 --- a/tests/test_annotations.py +++ b/tests/test_annotations.py @@ -1,24 +1,18 @@ -import json import datetime +import json + from pyld import jsonld +from api.annotations import AnnotationParser, AnnotationWriter +from api.problem_details import * +from core.model import Annotation, create from core.testing import DatabaseTest from core.util.datetime_helpers import utc_now -from .test_controller import ControllerTest -from core.model import ( - Annotation, - create, -) +from .test_controller import ControllerTest -from api.annotations import ( - AnnotationWriter, - AnnotationParser, -) -from api.problem_details import * class AnnotationTest(DatabaseTest): - def _patron(self): """Create a test patron who has opted in to annotation sync.""" patron = super(AnnotationTest, self)._patron() @@ -27,7 +21,6 @@ def _patron(self): class TestAnnotationWriter(AnnotationTest, ControllerTest): - def test_annotations_for(self): patron = self._patron() @@ -36,7 +29,8 @@ def test_annotations_for(self): identifier = self._identifier() annotation, ignore = create( - self._db, Annotation, + self._db, + Annotation, patron=patron, identifier=identifier, motivation=Annotation.IDLING, @@ -48,14 +42,17 @@ def test_annotations_for(self): identifier2 = self._identifier() annotation2, ignore = create( - self._db, Annotation, + self._db, + Annotation, patron=patron, identifier=identifier2, motivation=Annotation.IDLING, ) # The patron has two annotations for different identifiers. - assert set([annotation, annotation2]) == set(AnnotationWriter.annotations_for(patron)) + assert set([annotation, annotation2]) == set( + AnnotationWriter.annotations_for(patron) + ) assert [annotation] == AnnotationWriter.annotations_for(patron, identifier) assert [annotation2] == AnnotationWriter.annotations_for(patron, identifier2) @@ -65,20 +62,23 @@ def test_annotation_container_for(self): with self.app.test_request_context("/"): container, timestamp = AnnotationWriter.annotation_container_for(patron) - assert (set([AnnotationWriter.JSONLD_CONTEXT, AnnotationWriter.LDP_CONTEXT]) == - set(container['@context'])) + assert set( + [AnnotationWriter.JSONLD_CONTEXT, AnnotationWriter.LDP_CONTEXT] + ) == set(container["@context"]) assert "annotations" in container["id"] - assert set(["BasicContainer", "AnnotationCollection"]) == set(container["type"]) + assert set(["BasicContainer", "AnnotationCollection"]) == set( + container["type"] + ) assert 0 == container["total"] first_page = container["first"] assert "AnnotationPage" == first_page["type"] # The page doesn't have a context, since it's in the container. - assert None == first_page.get('@context') + assert None == first_page.get("@context") # The patron doesn't have any annotations yet. - assert 0 == container['total'] + assert 0 == container["total"] # There's no timestamp since the container is empty. assert None == timestamp @@ -86,7 +86,8 @@ def test_annotation_container_for(self): # Now, add an annotation. identifier = self._identifier() annotation, ignore = create( - self._db, Annotation, + self._db, + Annotation, patron=patron, identifier=identifier, motivation=Annotation.IDLING, @@ -96,22 +97,25 @@ def test_annotation_container_for(self): container, timestamp = AnnotationWriter.annotation_container_for(patron) # The context, type, and id stay the same. - assert (set([AnnotationWriter.JSONLD_CONTEXT, AnnotationWriter.LDP_CONTEXT]) == - set(container['@context'])) + assert set( + [AnnotationWriter.JSONLD_CONTEXT, AnnotationWriter.LDP_CONTEXT] + ) == set(container["@context"]) assert "annotations" in container["id"] assert identifier.identifier not in container["id"] - assert set(["BasicContainer", "AnnotationCollection"]) == set(container["type"]) + assert set(["BasicContainer", "AnnotationCollection"]) == set( + container["type"] + ) # But now there is one item. - assert 1 == container['total'] + assert 1 == container["total"] first_page = container["first"] - assert 1 == len(first_page['items']) + assert 1 == len(first_page["items"]) # The item doesn't have a context, since it's in the container. - first_item = first_page['items'][0] - assert None == first_item.get('@context') + first_item = first_page["items"][0] + assert None == first_item.get("@context") # The timestamp is the annotation's timestamp. assert annotation.timestamp == timestamp @@ -120,7 +124,7 @@ def test_annotation_container_for(self): annotation.active = False container, timestamp = AnnotationWriter.annotation_container_for(patron) - assert 0 == container['total'] + assert 0 == container["total"] assert None == timestamp def test_annotation_container_for_with_identifier(self): @@ -128,30 +132,36 @@ def test_annotation_container_for_with_identifier(self): identifier = self._identifier() with self.app.test_request_context("/"): - container, timestamp = AnnotationWriter.annotation_container_for(patron, identifier) + container, timestamp = AnnotationWriter.annotation_container_for( + patron, identifier + ) - assert (set([AnnotationWriter.JSONLD_CONTEXT, AnnotationWriter.LDP_CONTEXT]) == - set(container['@context'])) + assert set( + [AnnotationWriter.JSONLD_CONTEXT, AnnotationWriter.LDP_CONTEXT] + ) == set(container["@context"]) assert "annotations" in container["id"] assert identifier.identifier in container["id"] - assert set(["BasicContainer", "AnnotationCollection"]) == set(container["type"]) + assert set(["BasicContainer", "AnnotationCollection"]) == set( + container["type"] + ) assert 0 == container["total"] first_page = container["first"] assert "AnnotationPage" == first_page["type"] # The page doesn't have a context, since it's in the container. - assert None == first_page.get('@context') + assert None == first_page.get("@context") # The patron doesn't have any annotations yet. - assert 0 == container['total'] + assert 0 == container["total"] # There's no timestamp since the container is empty. assert None == timestamp # Now, add an annotation for this identifier, and one for a different identifier. annotation, ignore = create( - self._db, Annotation, + self._db, + Annotation, patron=patron, identifier=identifier, motivation=Annotation.IDLING, @@ -159,31 +169,37 @@ def test_annotation_container_for_with_identifier(self): annotation.timestamp = utc_now() other_annotation, ignore = create( - self._db, Annotation, + self._db, + Annotation, patron=patron, identifier=self._identifier(), motivation=Annotation.IDLING, ) - container, timestamp = AnnotationWriter.annotation_container_for(patron, identifier) + container, timestamp = AnnotationWriter.annotation_container_for( + patron, identifier + ) # The context, type, and id stay the same. - assert (set([AnnotationWriter.JSONLD_CONTEXT, AnnotationWriter.LDP_CONTEXT]) == - set(container['@context'])) + assert set( + [AnnotationWriter.JSONLD_CONTEXT, AnnotationWriter.LDP_CONTEXT] + ) == set(container["@context"]) assert "annotations" in container["id"] assert identifier.identifier in container["id"] - assert set(["BasicContainer", "AnnotationCollection"]) == set(container["type"]) + assert set(["BasicContainer", "AnnotationCollection"]) == set( + container["type"] + ) # But now there is one item. - assert 1 == container['total'] + assert 1 == container["total"] first_page = container["first"] - assert 1 == len(first_page['items']) + assert 1 == len(first_page["items"]) # The item doesn't have a context, since it's in the container. - first_item = first_page['items'][0] - assert None == first_item.get('@context') + first_item = first_page["items"][0] + assert None == first_item.get("@context") # The timestamp is the annotation's timestamp. assert annotation.timestamp == timestamp @@ -191,8 +207,10 @@ def test_annotation_container_for_with_identifier(self): # If the annotation is deleted, the container will be empty again. annotation.active = False - container, timestamp = AnnotationWriter.annotation_container_for(patron, identifier) - assert 0 == container['total'] + container, timestamp = AnnotationWriter.annotation_container_for( + patron, identifier + ) + assert 0 == container["total"] assert None == timestamp def test_annotation_page_for(self): @@ -202,15 +220,16 @@ def test_annotation_page_for(self): page = AnnotationWriter.annotation_page_for(patron) # The patron doesn't have any annotations, so the page is empty. - assert AnnotationWriter.JSONLD_CONTEXT == page['@context'] - assert 'annotations' in page['id'] - assert 'AnnotationPage' == page['type'] - assert 0 == len(page['items']) + assert AnnotationWriter.JSONLD_CONTEXT == page["@context"] + assert "annotations" in page["id"] + assert "AnnotationPage" == page["type"] + assert 0 == len(page["items"]) # If we add an annotation, the page will have an item. identifier = self._identifier() annotation, ignore = create( - self._db, Annotation, + self._db, + Annotation, patron=patron, identifier=identifier, motivation=Annotation.IDLING, @@ -218,14 +237,14 @@ def test_annotation_page_for(self): page = AnnotationWriter.annotation_page_for(patron) - assert 1 == len(page['items']) + assert 1 == len(page["items"]) # But if the annotation is deleted, the page will be empty again. annotation.active = False page = AnnotationWriter.annotation_page_for(patron) - assert 0 == len(page['items']) + assert 0 == len(page["items"]) def test_annotation_page_for_with_identifier(self): patron = self._patron() @@ -235,55 +254,56 @@ def test_annotation_page_for_with_identifier(self): page = AnnotationWriter.annotation_page_for(patron, identifier) # The patron doesn't have any annotations, so the page is empty. - assert AnnotationWriter.JSONLD_CONTEXT == page['@context'] - assert 'annotations' in page['id'] - assert identifier.identifier in page['id'] - assert 'AnnotationPage' == page['type'] - assert 0 == len(page['items']) + assert AnnotationWriter.JSONLD_CONTEXT == page["@context"] + assert "annotations" in page["id"] + assert identifier.identifier in page["id"] + assert "AnnotationPage" == page["type"] + assert 0 == len(page["items"]) # If we add an annotation, the page will have an item. annotation, ignore = create( - self._db, Annotation, + self._db, + Annotation, patron=patron, identifier=identifier, motivation=Annotation.IDLING, ) page = AnnotationWriter.annotation_page_for(patron, identifier) - assert 1 == len(page['items']) + assert 1 == len(page["items"]) # If a different identifier has an annotation, the page will still have one item. other_annotation, ignore = create( - self._db, Annotation, + self._db, + Annotation, patron=patron, identifier=self._identifier(), motivation=Annotation.IDLING, ) page = AnnotationWriter.annotation_page_for(patron, identifier) - assert 1 == len(page['items']) + assert 1 == len(page["items"]) # But if the annotation is deleted, the page will be empty again. annotation.active = False page = AnnotationWriter.annotation_page_for(patron, identifier) - assert 0 == len(page['items']) + assert 0 == len(page["items"]) def test_detail_target(self): patron = self._patron() identifier = self._identifier() target = { - "http://www.w3.org/ns/oa#hasSource": { - "@id": identifier.urn - }, + "http://www.w3.org/ns/oa#hasSource": {"@id": identifier.urn}, "http://www.w3.org/ns/oa#hasSelector": { "@type": "http://www.w3.org/ns/oa#FragmentSelector", - "http://www.w3.org/1999/02/22-rdf-syntax-ns#value": "epubcfi(/6/4[chap01ref]!/4[body01]/10[para05]/3:10)" - } + "http://www.w3.org/1999/02/22-rdf-syntax-ns#value": "epubcfi(/6/4[chap01ref]!/4[body01]/10[para05]/3:10)", + }, } annotation, ignore = create( - self._db, Annotation, + self._db, + Annotation, patron=patron, identifier=identifier, motivation=Annotation.IDLING, @@ -294,14 +314,14 @@ def test_detail_target(self): detail = AnnotationWriter.detail(annotation) assert "annotations/%i" % annotation.id in detail["id"] - assert "Annotation" == detail['type'] - assert Annotation.IDLING == detail['motivation'] + assert "Annotation" == detail["type"] + assert Annotation.IDLING == detail["motivation"] compacted_target = { "source": identifier.urn, "selector": { "type": "FragmentSelector", - "value": "epubcfi(/6/4[chap01ref]!/4[body01]/10[para05]/3:10)" - } + "value": "epubcfi(/6/4[chap01ref]!/4[body01]/10[para05]/3:10)", + }, } assert compacted_target == detail["target"] @@ -313,11 +333,12 @@ def test_detail_body(self): "http://www.w3.org/ns/oa#bodyValue": "A good description of the topic that bears further investigation", "http://www.w3.org/ns/oa#hasPurpose": { "@id": "http://www.w3.org/ns/oa#describing" - } + }, } annotation, ignore = create( - self._db, Annotation, + self._db, + Annotation, patron=patron, identifier=identifier, motivation=Annotation.IDLING, @@ -328,12 +349,12 @@ def test_detail_body(self): detail = AnnotationWriter.detail(annotation) assert "annotations/%i" % annotation.id in detail["id"] - assert "Annotation" == detail['type'] - assert Annotation.IDLING == detail['motivation'] + assert "Annotation" == detail["type"] + assert Annotation.IDLING == detail["motivation"] compacted_body = { "type": "TextualBody", "bodyValue": "A good description of the topic that bears further investigation", - "purpose": "describing" + "purpose": "describing", } assert compacted_body == detail["body"] @@ -347,23 +368,25 @@ def setup_method(self): def _sample_jsonld(self, motivation=Annotation.IDLING): data = dict() - data["@context"] = [AnnotationWriter.JSONLD_CONTEXT, - {'ls': Annotation.LS_NAMESPACE}] + data["@context"] = [ + AnnotationWriter.JSONLD_CONTEXT, + {"ls": Annotation.LS_NAMESPACE}, + ] data["type"] = "Annotation" - motivation = motivation.replace(Annotation.LS_NAMESPACE, 'ls:') - motivation = motivation.replace(Annotation.OA_NAMESPACE, 'oa:') + motivation = motivation.replace(Annotation.LS_NAMESPACE, "ls:") + motivation = motivation.replace(Annotation.OA_NAMESPACE, "oa:") data["motivation"] = motivation data["body"] = { "type": "TextualBody", "bodyValue": "A good description of the topic that bears further investigation", - "purpose": "describing" + "purpose": "describing", } data["target"] = { "source": self.identifier.urn, "selector": { "type": "oa:FragmentSelector", - "value": "epubcfi(/6/4[chap01ref]!/4[body01]/10[para05]/3:10)" - } + "value": "epubcfi(/6/4[chap01ref]!/4[body01]/10[para05]/3:10)", + }, } return data @@ -375,10 +398,8 @@ def test_invalid_identifier(self): # If the target source can't be parsed as a URN we send # INVALID_ANNOTATION_TARGET data = self._sample_jsonld() - data['target']['source'] = 'not a URN' - annotation = AnnotationParser.parse( - self._db, json.dumps(data), self.patron - ) + data["target"]["source"] = "not a URN" + annotation = AnnotationParser.parse(self._db, json.dumps(data), self.patron) assert INVALID_ANNOTATION_TARGET == annotation def test_null_id(self): @@ -387,40 +408,44 @@ def test_null_id(self): # can't handle this, so we need to test it specially. self.pool.loan_to(self.patron) data = self._sample_jsonld() - data['id'] = None - annotation = AnnotationParser.parse( - self._db, json.dumps(data), self.patron - ) + data["id"] = None + annotation = AnnotationParser.parse(self._db, json.dumps(data), self.patron) assert isinstance(annotation, Annotation) def test_parse_expanded_jsonld(self): self.pool.loan_to(self.patron) data = dict() - data['@type'] = ["http://www.w3.org/ns/oa#Annotation"] - data["http://www.w3.org/ns/oa#motivatedBy"] = [{ - "@id": Annotation.IDLING - }] - data["http://www.w3.org/ns/oa#hasBody"] = [{ - "@type" : ["http://www.w3.org/ns/oa#TextualBody"], - "http://www.w3.org/ns/oa#bodyValue": [{ - "@value": "A good description of the topic that bears further investigation" - }], - "http://www.w3.org/ns/oa#hasPurpose": [{ - "@id": "http://www.w3.org/ns/oa#describing" - }] - }] - data["http://www.w3.org/ns/oa#hasTarget"] = [{ - "http://www.w3.org/ns/oa#hasSelector": [{ - "@type": ["http://www.w3.org/ns/oa#FragmentSelector"], - "http://www.w3.org/1999/02/22-rdf-syntax-ns#value": [{ - "@value": "epubcfi(/6/4[chap01ref]!/4[body01]/10[para05]/3:10)" - }] - }], - "http://www.w3.org/ns/oa#hasSource": [{ - "@id": self.identifier.urn - }], - }] + data["@type"] = ["http://www.w3.org/ns/oa#Annotation"] + data["http://www.w3.org/ns/oa#motivatedBy"] = [{"@id": Annotation.IDLING}] + data["http://www.w3.org/ns/oa#hasBody"] = [ + { + "@type": ["http://www.w3.org/ns/oa#TextualBody"], + "http://www.w3.org/ns/oa#bodyValue": [ + { + "@value": "A good description of the topic that bears further investigation" + } + ], + "http://www.w3.org/ns/oa#hasPurpose": [ + {"@id": "http://www.w3.org/ns/oa#describing"} + ], + } + ] + data["http://www.w3.org/ns/oa#hasTarget"] = [ + { + "http://www.w3.org/ns/oa#hasSelector": [ + { + "@type": ["http://www.w3.org/ns/oa#FragmentSelector"], + "http://www.w3.org/1999/02/22-rdf-syntax-ns#value": [ + { + "@value": "epubcfi(/6/4[chap01ref]!/4[body01]/10[para05]/3:10)" + } + ], + } + ], + "http://www.w3.org/ns/oa#hasSource": [{"@id": self.identifier.urn}], + } + ] data_json = json.dumps(data) @@ -429,32 +454,33 @@ def test_parse_expanded_jsonld(self): assert self.identifier.id == annotation.identifier_id assert Annotation.IDLING == annotation.motivation assert True == annotation.active - assert json.dumps(data["http://www.w3.org/ns/oa#hasTarget"][0]) == annotation.target - assert json.dumps(data["http://www.w3.org/ns/oa#hasBody"][0]) == annotation.content + assert ( + json.dumps(data["http://www.w3.org/ns/oa#hasTarget"][0]) + == annotation.target + ) + assert ( + json.dumps(data["http://www.w3.org/ns/oa#hasBody"][0]) == annotation.content + ) def test_parse_compacted_jsonld(self): self.pool.loan_to(self.patron) data = dict() data["@type"] = "http://www.w3.org/ns/oa#Annotation" - data["http://www.w3.org/ns/oa#motivatedBy"] = { - "@id": Annotation.IDLING - } + data["http://www.w3.org/ns/oa#motivatedBy"] = {"@id": Annotation.IDLING} data["http://www.w3.org/ns/oa#hasBody"] = { "@type": "http://www.w3.org/ns/oa#TextualBody", "http://www.w3.org/ns/oa#bodyValue": "A good description of the topic that bears further investigation", "http://www.w3.org/ns/oa#hasPurpose": { "@id": "http://www.w3.org/ns/oa#describing" - } + }, } data["http://www.w3.org/ns/oa#hasTarget"] = { - "http://www.w3.org/ns/oa#hasSource": { - "@id": self.identifier.urn - }, + "http://www.w3.org/ns/oa#hasSource": {"@id": self.identifier.urn}, "http://www.w3.org/ns/oa#hasSelector": { "@type": "http://www.w3.org/ns/oa#FragmentSelector", - "http://www.w3.org/1999/02/22-rdf-syntax-ns#value": "epubcfi(/6/4[chap01ref]!/4[body01]/10[para05]/3:10)" - } + "http://www.w3.org/1999/02/22-rdf-syntax-ns#value": "epubcfi(/6/4[chap01ref]!/4[body01]/10[para05]/3:10)", + }, } data_json = json.dumps(data) @@ -465,8 +491,14 @@ def test_parse_compacted_jsonld(self): assert self.identifier.id == annotation.identifier_id assert Annotation.IDLING == annotation.motivation assert True == annotation.active - assert json.dumps(expanded["http://www.w3.org/ns/oa#hasTarget"][0]) == annotation.target - assert json.dumps(expanded["http://www.w3.org/ns/oa#hasBody"][0]) == annotation.content + assert ( + json.dumps(expanded["http://www.w3.org/ns/oa#hasTarget"][0]) + == annotation.target + ) + assert ( + json.dumps(expanded["http://www.w3.org/ns/oa#hasBody"][0]) + == annotation.content + ) def test_parse_jsonld_with_context(self): self.pool.loan_to(self.patron) @@ -481,8 +513,14 @@ def test_parse_jsonld_with_context(self): assert self.identifier.id == annotation.identifier_id assert Annotation.IDLING == annotation.motivation assert True == annotation.active - assert json.dumps(expanded["http://www.w3.org/ns/oa#hasTarget"][0]) == annotation.target - assert json.dumps(expanded["http://www.w3.org/ns/oa#hasBody"][0]) == annotation.content + assert ( + json.dumps(expanded["http://www.w3.org/ns/oa#hasTarget"][0]) + == annotation.target + ) + assert ( + json.dumps(expanded["http://www.w3.org/ns/oa#hasBody"][0]) + == annotation.content + ) def test_parse_jsonld_with_bookmarking_motivation(self): """You can create multiple bookmarks in a single book.""" @@ -501,7 +539,9 @@ def test_parse_jsonld_with_bookmarking_motivation(self): # But unlike with IDLING, you _can_ create multiple bookmarks # for the same identifier, so long as the selector value # (ie. the location within the book) is different. - data['target']['selector']['value'] = 'epubcfi(/3/4[chap01ref]!/4[body01]/15[para05]/3:10)' + data["target"]["selector"][ + "value" + ] = "epubcfi(/3/4[chap01ref]!/4[body01]/15[para05]/3:10)" data_json = json.dumps(data) annotation3 = AnnotationParser.parse(self._db, data_json, self.patron) assert annotation3 != annotation @@ -528,7 +568,7 @@ def test_parse_jsonld_with_no_loan(self): def test_parse_jsonld_with_no_target(self): data = self._sample_jsonld() - del data['target'] + del data["target"] data_json = json.dumps(data) annotation = AnnotationParser.parse(self._db, data_json, self.patron) @@ -539,7 +579,8 @@ def test_parse_updates_existing_annotation(self): self.pool.loan_to(self.patron) original_annotation, ignore = create( - self._db, Annotation, + self._db, + Annotation, patron_id=self.patron.id, identifier_id=self.identifier.id, motivation=Annotation.IDLING, @@ -563,14 +604,16 @@ def test_parse_treats_duplicates_as_interchangeable(self): # Due to an earlier race condition, two duplicate annotations # were put in the database. a1, ignore = create( - self._db, Annotation, + self._db, + Annotation, patron_id=self.patron.id, identifier_id=self.identifier.id, motivation=Annotation.IDLING, ) a2, ignore = create( - self._db, Annotation, + self._db, + Annotation, patron_id=self.patron.id, identifier_id=self.identifier.id, motivation=Annotation.IDLING, @@ -591,8 +634,6 @@ def test_parse_jsonld_with_patron_opt_out(self): data = self._sample_jsonld() data_json = json.dumps(data) - self.patron.synchronize_annotations=False - annotation = AnnotationParser.parse( - self._db, data_json, self.patron - ) + self.patron.synchronize_annotations = False + annotation = AnnotationParser.parse(self._db, data_json, self.patron) assert PATRON_NOT_OPTED_IN_TO_ANNOTATION_SYNC == annotation diff --git a/tests/test_announcements.py b/tests/test_announcements.py index f62469fd4f..329e0d938f 100644 --- a/tests/test_announcements.py +++ b/tests/test_announcements.py @@ -1,13 +1,11 @@ import json -from api.testing import AnnouncementTest from api.admin.announcement_list_validator import AnnouncementListValidator -from api.announcements import ( - Announcements, - Announcement -) +from api.announcements import Announcement, Announcements +from api.testing import AnnouncementTest from core.testing import DatabaseTest + class TestAnnouncements(AnnouncementTest, DatabaseTest): """Test the Announcements object.""" @@ -35,8 +33,8 @@ def test_for_library(self): # treated as an empty list. In real life this would only # happen due to a bug or a bad bit of manually entered SQL. invalid = dict(self.active) - invalid['id'] = 'Another ID' - invalid['finish'] = 'Not a date' + invalid["id"] = "Another ID" + invalid["finish"] = "Not a date" setting.value = json.dumps([self.active, invalid, self.expired]) assert [] == Announcements.for_library(l).announcements @@ -57,7 +55,7 @@ def test_is_active(self): # An announcement that ends today is still active. expires_today = dict(self.active) - expires_today['finish'] = self.today + expires_today["finish"] = self.today assert True == Announcement(**self.active).is_active def test_for_authentication_document(self): @@ -66,8 +64,10 @@ def test_for_authentication_document(self): # 'start' and 'finish' will be ignored, as will the extra value # that has no meaning within Announcement. announcement = Announcement(extra="extra value", **self.active) - assert (dict(id="active", content="A sample announcement.") == - announcement.for_authentication_document) + assert ( + dict(id="active", content="A sample announcement.") + == announcement.for_authentication_document + ) def test_json_ready(self): # Demonstrate the form of an Announcement used to store in the database. @@ -79,7 +79,12 @@ def test_json_ready(self): dict( id="active", content="A sample announcement.", - start=announcement.start.strftime(AnnouncementListValidator.DATE_FORMAT), - finish=announcement.finish.strftime(AnnouncementListValidator.DATE_FORMAT), - ) == - announcement.json_ready) + start=announcement.start.strftime( + AnnouncementListValidator.DATE_FORMAT + ), + finish=announcement.finish.strftime( + AnnouncementListValidator.DATE_FORMAT + ), + ) + == announcement.json_ready + ) diff --git a/tests/test_authenticator.py b/tests/test_authenticator.py index 984c12520f..7435d11d9b 100644 --- a/tests/test_authenticator.py +++ b/tests/test_authenticator.py @@ -1,80 +1,72 @@ """Test the base authentication framework: that is, the classes that don't interact with any particular source of truth. """ -import pytest -from flask_babel import lazy_gettext as _ import datetime -from decimal import Decimal import json import os -from money import Money import re -import urllib.request, urllib.parse, urllib.error +import urllib.error import urllib.parse -import flask -from flask import url_for, Flask -from core.opds import OPDSFeed -from core.user_profile import ProfileController -from core.model import ( - CirculationEvent, - ConfigurationSetting, - Credential, - DataSource, - ExternalIntegration, - Library, - Patron, - create, - Session, -) +import urllib.request +from decimal import Decimal -from core.util.datetime_helpers import utc_now -from core.util.problem_detail import ( - ProblemDetail, -) -from core.util.authentication_for_opds import ( - AuthenticationForOPDSDocument, -) -from core.util.http import IntegrationException -from core.mock_analytics_provider import MockAnalyticsProvider +import flask +import pytest +from flask import Flask, url_for +from flask_babel import lazy_gettext as _ +from money import Money -from api.announcements import Announcements -from api.millenium_patron import MilleniumPatronAPI -from api.firstbook import FirstBookAuthenticationAPI -from api.clever import CleverAuthenticationAPI -from api.util.patron import PatronUtility from api.annotations import AnnotationWriter +from api.announcements import Announcements from api.authenticator import ( + AuthenticationProvider, Authenticator, + BasicAuthenticationProvider, CirculationPatronProfileStorage, LibraryAuthenticator, - AuthenticationProvider, - BasicAuthenticationProvider, - OAuthController, OAuthAuthenticationProvider, + OAuthController, PatronData, ) -from api.problem_details import PATRON_OF_ANOTHER_LIBRARY -from api.simple_authentication import SimpleAuthenticationProvider +from api.clever import CleverAuthenticationAPI +from api.config import CannotLoadConfiguration, Configuration, temp_config +from api.firstbook import FirstBookAuthenticationAPI from api.millenium_patron import MilleniumPatronAPI from api.opds import LibraryAnnotator - -from api.config import ( - CannotLoadConfiguration, - Configuration, - temp_config, -) - from api.problem_details import * +from api.problem_details import PATRON_OF_ANOTHER_LIBRARY +from api.simple_authentication import SimpleAuthenticationProvider from api.testing import VendorIDTest - +from api.util.patron import PatronUtility +from core.mock_analytics_provider import MockAnalyticsProvider +from core.model import ( + CirculationEvent, + ConfigurationSetting, + Credential, + DataSource, + ExternalIntegration, + Library, + Patron, + Session, + create, +) +from core.opds import OPDSFeed from core.testing import DatabaseTest +from core.user_profile import ProfileController +from core.util.authentication_for_opds import AuthenticationForOPDSDocument +from core.util.datetime_helpers import utc_now +from core.util.http import IntegrationException +from core.util.problem_detail import ProblemDetail + from .test_controller import ControllerTest + class MockAuthenticationProvider(object): """An AuthenticationProvider that always authenticates requests for the given Patron and always returns the given PatronData when asked to look up data. """ + def __init__(self, patron=None, patrondata=None): self.patron = patron self.patrondata = patrondata @@ -84,15 +76,25 @@ def authenticate(self, _db, header): class MockBasicAuthenticationProvider( - BasicAuthenticationProvider, - MockAuthenticationProvider + BasicAuthenticationProvider, MockAuthenticationProvider ): """A mock basic authentication provider for use in testing the overall authentication process. """ - def __init__(self, library, integration, analytics=None, patron=None, patrondata=None, *args, **kwargs): + + def __init__( + self, + library, + integration, + analytics=None, + patron=None, + patrondata=None, + *args, + **kwargs + ): super(MockBasicAuthenticationProvider, self).__init__( - library, integration, analytics, *args, **kwargs) + library, integration, analytics, *args, **kwargs + ) self.patron = patron self.patrondata = patrondata @@ -105,15 +107,25 @@ def remote_authenticate(self, username, password): def remote_patron_lookup(self, patrondata): return self.patrondata + class MockBasic(BasicAuthenticationProvider): """A second mock basic authentication provider for use in testing the workflow around Basic Auth. """ - NAME = 'Mock Basic Auth provider' + + NAME = "Mock Basic Auth provider" LOGIN_BUTTON_IMAGE = "BasicButton.png" - def __init__(self, library, integration, analytics=None, patrondata=None, - remote_patron_lookup_patrondata=None, - *args, **kwargs): + + def __init__( + self, + library, + integration, + analytics=None, + patrondata=None, + remote_patron_lookup_patrondata=None, + *args, + **kwargs + ): super(MockBasic, self).__init__(library, integration, analytics) self.patrondata = patrondata self.remote_patron_lookup_patrondata = remote_patron_lookup_patrondata @@ -126,12 +138,12 @@ def remote_patron_lookup(self, patrondata): class MockOAuthAuthenticationProvider( - OAuthAuthenticationProvider, - MockAuthenticationProvider + OAuthAuthenticationProvider, MockAuthenticationProvider ): """A mock OAuth authentication provider for use in testing the overall authentication process. """ + def __init__(self, library, provider_name, patron=None, patrondata=None): self.library_id = library.id self.NAME = provider_name @@ -146,6 +158,7 @@ class MockOAuth(OAuthAuthenticationProvider): """A second mock basic authentication provider for use in testing the workflow around OAuth. """ + URI = "http://example.org/" NAME = "Mock provider" TOKEN_TYPE = "test token" @@ -160,7 +173,9 @@ def __init__(self, library, name="Mock OAuth", integration=None, analytics=None) @classmethod def _mock_integration(self, _db, name): integration, ignore = create( - _db, ExternalIntegration, protocol="OAuth", + _db, + ExternalIntegration, + protocol="OAuth", goal=ExternalIntegration.PATRON_AUTH_GOAL, ) integration.username = name @@ -170,7 +185,6 @@ def _mock_integration(self, _db, name): class AuthenticatorTest(DatabaseTest): - def mock_basic(self, *args, **kwargs): """Convenience method to instantiate a MockBasic object with the default library. @@ -184,7 +198,6 @@ def mock_basic(self, *args, **kwargs): class TestPatronData(AuthenticatorTest): - def setup_method(self): super(TestPatronData, self).setup_method() self.expiration_time = utc_now() @@ -211,26 +224,26 @@ def test_to_dict(self): email_address="5", authorization_expires=self.expiration_time.strftime("%Y-%m-%d"), fines="6", - block_reason=None + block_reason=None, ) assert data == expect # Test with an empty fines field self.data.fines = PatronData.NO_VALUE data = self.data.to_dict - expect['fines'] = None + expect["fines"] = None assert data == expect # Test with a zeroed-out fines field self.data.fines = Decimal(0.0) data = self.data.to_dict - expect['fines'] = '0' + expect["fines"] = "0" assert data == expect # Test with an empty expiration time self.data.authorization_expires = PatronData.NO_VALUE data = self.data.to_dict - expect['authorization_expires'] = None + expect["authorization_expires"] = None assert data == expect def test_apply(self): @@ -248,9 +261,9 @@ def test_apply(self): # This data is stored in PatronData but not applied to Patron. assert "4" == self.data.personal_name - assert False == hasattr(patron, 'personal_name') + assert False == hasattr(patron, "personal_name") assert "5" == self.data.email_address - assert False == hasattr(patron, 'email_address') + assert False == hasattr(patron, "email_address") # This data is stored on the Patron object as a convenience, # but it's not stored in the database. @@ -271,28 +284,19 @@ def test_apply_multiple_authorization_identifiers(self): """ patron = self._patron() patron.authorization_identifier = None - data = PatronData( - authorization_identifier=["2", "3"], - complete=True - ) + data = PatronData(authorization_identifier=["2", "3"], complete=True) data.apply(patron) assert "2" == patron.authorization_identifier # If Patron.authorization_identifier is already set, it will # not be changed, so long as its current value is acceptable. - data = PatronData( - authorization_identifier=["3", "2"], - complete=True - ) + data = PatronData(authorization_identifier=["3", "2"], complete=True) data.apply(patron) assert "2" == patron.authorization_identifier # If Patron.authorization_identifier ever turns out not to be # an acceptable value, it will be changed. - data = PatronData( - authorization_identifier=["3", "4"], - complete=True - ) + data = PatronData(authorization_identifier=["3", "4"], complete=True) data.apply(patron) assert "3" == patron.authorization_identifier @@ -327,9 +331,9 @@ def test_apply_leaves_valid_authorization_identifier_alone(self): """ patron = self._patron() patron.authorization_identifier = "old identifier" - self.data.set_authorization_identifier([ - "new identifier", patron.authorization_identifier - ]) + self.data.set_authorization_identifier( + ["new identifier", patron.authorization_identifier] + ) self.data.apply(patron) assert "old identifier" == patron.authorization_identifier @@ -339,9 +343,7 @@ def test_apply_overwrites_invalid_authorization_identifier(self): authorization identifier that no longer works, we change it. """ patron = self._patron() - self.data.set_authorization_identifier([ - "identifier 1", "identifier 2" - ]) + self.data.set_authorization_identifier(["identifier 1", "identifier 2"]) self.data.apply(patron) assert "identifier 1" == patron.authorization_identifier @@ -358,9 +360,7 @@ def test_apply_on_incomplete_information(self): # Patron.authorization_identifier to that string but we also # indicate that we need to perform an external sync on them # ASAP. - authenticated = PatronData( - authorization_identifier="1234", complete=False - ) + authenticated = PatronData(authorization_identifier="1234", complete=False) patron = self._patron() patron.authorization_identifier = None patron.last_external_sync = now @@ -401,7 +401,7 @@ def test_get_or_create_patron(self): patron, is_new = self.data.get_or_create_patron( self._db, self._default_library.id, analytics ) - assert '2' == patron.authorization_identifier + assert "2" == patron.authorization_identifier assert self._default_library == patron.library assert True == is_new assert CirculationEvent.NEW_PATRON == analytics.event_type @@ -419,7 +419,7 @@ def test_get_or_create_patron(self): patron, is_new = self.data.get_or_create_patron( self._db, self._default_library.id, analytics ) - assert '2' == patron.authorization_identifier + assert "2" == patron.authorization_identifier assert False == is_new assert "Achewood" == patron.neighborhood assert 1 == analytics.count @@ -435,38 +435,54 @@ def test_to_response_parameters(self): class TestCirculationPatronProfileStorage(ControllerTest): - def test_profile_document(self): def mock_url_for(endpoint, library_short_name, _external=True): - return "http://host/" + endpoint + "?" + "library_short_name=" + library_short_name + return ( + "http://host/" + + endpoint + + "?" + + "library_short_name=" + + library_short_name + ) patron = self._patron() storage = CirculationPatronProfileStorage(patron, mock_url_for) doc = storage.profile_document - assert 'settings' in doc - #Since there's no authdata configured, the DRM fields are not present - assert 'drm:vendor' not in doc - assert 'drm:clientToken' not in doc - assert 'drm:scheme' not in doc - assert 'links' not in doc - - #Now there's authdata configured, and the DRM fields are populated with - #the vendor ID and a short client token + assert "settings" in doc + # Since there's no authdata configured, the DRM fields are not present + assert "drm:vendor" not in doc + assert "drm:clientToken" not in doc + assert "drm:scheme" not in doc + assert "links" not in doc + + # Now there's authdata configured, and the DRM fields are populated with + # the vendor ID and a short client token self.initialize_adobe(patron.library) doc = storage.profile_document - [adobe] = doc['drm'] + [adobe] = doc["drm"] assert adobe["drm:vendor"] == "vendor id" assert adobe["drm:clientToken"].startswith( patron.library.short_name.upper() + "TOKEN" ) - assert adobe["drm:scheme"] == "http://librarysimplified.org/terms/drm/scheme/ACS" - [device_link, annotations_link] = doc['links'] - assert device_link['rel'] == "http://librarysimplified.org/terms/drm/rel/devices" - assert device_link['href'] == "http://host/adobe_drm_devices?library_short_name=default" - assert annotations_link['rel'] == "http://www.w3.org/ns/oa#annotationService" - assert annotations_link['href'] == "http://host/annotations?library_short_name=default" - assert annotations_link['type'] == AnnotationWriter.CONTENT_TYPE + assert ( + adobe["drm:scheme"] == "http://librarysimplified.org/terms/drm/scheme/ACS" + ) + [device_link, annotations_link] = doc["links"] + assert ( + device_link["rel"] == "http://librarysimplified.org/terms/drm/rel/devices" + ) + assert ( + device_link["href"] + == "http://host/adobe_drm_devices?library_short_name=default" + ) + assert annotations_link["rel"] == "http://www.w3.org/ns/oa#annotationService" + assert ( + annotations_link["href"] + == "http://host/annotations?library_short_name=default" + ) + assert annotations_link["type"] == AnnotationWriter.CONTENT_TYPE + class MockAuthenticator(Authenticator): """Allows testing Authenticator methods outside of a request context.""" @@ -486,12 +502,11 @@ def current_library_short_name(self): class TestAuthenticator(ControllerTest): - def test_init(self): # The default library has already been configured to use the # SimpleAuthenticationProvider for its basic auth. l1 = self._default_library - l1.short_name = 'l1' + l1.short_name = "l1" # This library uses Millenium Patron. l2, ignore = create(self._db, Library, short_name="l2") @@ -508,53 +523,62 @@ def test_init(self): auth = Authenticator(self._db, analytics) # A LibraryAuthenticator has been created for each Library. - assert 'l1' in auth.library_authenticators - assert 'l2' in auth.library_authenticators - assert isinstance(auth.library_authenticators['l1'], LibraryAuthenticator) - assert isinstance(auth.library_authenticators['l2'], LibraryAuthenticator) + assert "l1" in auth.library_authenticators + assert "l2" in auth.library_authenticators + assert isinstance(auth.library_authenticators["l1"], LibraryAuthenticator) + assert isinstance(auth.library_authenticators["l2"], LibraryAuthenticator) # Each LibraryAuthenticator has been associated with an # appropriate AuthenticationProvider. assert isinstance( - auth.library_authenticators['l1'].basic_auth_provider, - SimpleAuthenticationProvider + auth.library_authenticators["l1"].basic_auth_provider, + SimpleAuthenticationProvider, ) assert isinstance( - auth.library_authenticators['l2'].basic_auth_provider, - MilleniumPatronAPI + auth.library_authenticators["l2"].basic_auth_provider, MilleniumPatronAPI ) # Each provider has the analytics set. - assert analytics == auth.library_authenticators['l1'].basic_auth_provider.analytics - assert analytics == auth.library_authenticators['l2'].basic_auth_provider.analytics + assert ( + analytics == auth.library_authenticators["l1"].basic_auth_provider.analytics + ) + assert ( + analytics == auth.library_authenticators["l2"].basic_auth_provider.analytics + ) def test_methods_call_library_authenticators(self): class MockLibraryAuthenticator(LibraryAuthenticator): def __init__(self, name): self.name = name + def authenticated_patron(self, _db, header): return "authenticated patron for %s" % self.name + def create_authentication_document(self): return "authentication document for %s" % self.name + def create_authentication_headers(self): return "authentication headers for %s" % self.name + def get_credential_from_header(self, header): return "credential for %s" % self.name + def create_bearer_token(self, *args, **kwargs): return "bearer token for %s" % self.name + def oauth_provider_lookup(self, *args, **kwargs): return "oauth provider for %s" % self.name + def decode_bearer_token(self, *args, **kwargs): return "decoded bearer token for %s" % self.name - l1, ignore = create(self._db, Library, short_name="l1") l2, ignore = create(self._db, Library, short_name="l2") auth = Authenticator(self._db) - auth.library_authenticators['l1'] = MockLibraryAuthenticator("l1") - auth.library_authenticators['l2'] = MockLibraryAuthenticator("l2") + auth.library_authenticators["l1"] = MockLibraryAuthenticator("l1") + auth.library_authenticators["l2"] = MockLibraryAuthenticator("l2") # This new library isn't in the authenticator. l3, ignore = create(self._db, Library, short_name="l3") @@ -571,9 +595,16 @@ def decode_bearer_token(self, *args, **kwargs): # The other libraries are in the authenticator. with self.app.test_request_context("/"): flask.request.library = l1 - assert "authenticated patron for l1" == auth.authenticated_patron(self._db, {}) - assert "authentication document for l1" == auth.create_authentication_document() - assert "authentication headers for l1" == auth.create_authentication_headers() + assert "authenticated patron for l1" == auth.authenticated_patron( + self._db, {} + ) + assert ( + "authentication document for l1" + == auth.create_authentication_document() + ) + assert ( + "authentication headers for l1" == auth.create_authentication_headers() + ) assert "credential for l1" == auth.get_credential_from_header({}) assert "bearer token for l1" == auth.create_bearer_token() assert "oauth provider for l1" == auth.oauth_provider_lookup() @@ -581,9 +612,16 @@ def decode_bearer_token(self, *args, **kwargs): with self.app.test_request_context("/"): flask.request.library = l2 - assert "authenticated patron for l2" == auth.authenticated_patron(self._db, {}) - assert "authentication document for l2" == auth.create_authentication_document() - assert "authentication headers for l2" == auth.create_authentication_headers() + assert "authenticated patron for l2" == auth.authenticated_patron( + self._db, {} + ) + assert ( + "authentication document for l2" + == auth.create_authentication_document() + ) + assert ( + "authentication headers for l2" == auth.create_authentication_headers() + ) assert "credential for l2" == auth.get_credential_from_header({}) assert "bearer token for l2" == auth.create_bearer_token() assert "oauth provider for l2" == auth.oauth_provider_lookup() @@ -591,12 +629,12 @@ def decode_bearer_token(self, *args, **kwargs): class TestLibraryAuthenticator(AuthenticatorTest): - def test_from_config_basic_auth_only(self): # Only a basic auth provider. millenium = self._external_integration( - "api.millenium_patron", ExternalIntegration.PATRON_AUTH_GOAL, - libraries=[self._default_library] + "api.millenium_patron", + ExternalIntegration.PATRON_AUTH_GOAL, + libraries=[self._default_library], ) millenium.url = "http://url/" auth = LibraryAuthenticator.from_config(self._db, self._default_library) @@ -609,14 +647,16 @@ def test_from_config_basic_auth_and_oauth(self): library = self._default_library # A basic auth provider and an oauth provider. firstbook = self._external_integration( - "api.firstbook", ExternalIntegration.PATRON_AUTH_GOAL, + "api.firstbook", + ExternalIntegration.PATRON_AUTH_GOAL, ) firstbook.url = "http://url/" firstbook.password = "secret" library.integrations.append(firstbook) oauth = self._external_integration( - "api.clever", ExternalIntegration.PATRON_AUTH_GOAL, + "api.clever", + ExternalIntegration.PATRON_AUTH_GOAL, ) oauth.username = "client_id" oauth.password = "client_secret" @@ -626,14 +666,11 @@ def test_from_config_basic_auth_and_oauth(self): auth = LibraryAuthenticator.from_config(self._db, library, analytics) assert auth.basic_auth_provider != None - assert isinstance(auth.basic_auth_provider, - FirstBookAuthenticationAPI) + assert isinstance(auth.basic_auth_provider, FirstBookAuthenticationAPI) assert analytics == auth.basic_auth_provider.analytics assert 1 == len(auth.oauth_providers_by_name) - clever = auth.oauth_providers_by_name[ - CleverAuthenticationAPI.NAME - ] + clever = auth.oauth_providers_by_name[CleverAuthenticationAPI.NAME] assert isinstance(clever, CleverAuthenticationAPI) assert analytics == clever.analytics @@ -642,6 +679,7 @@ def test_with_custom_patron_catalog(self): include instantiation of a CustomPatronCatalog. """ mock_catalog = object() + class MockCustomPatronCatalog(object): @classmethod def for_library(self, library): @@ -649,8 +687,9 @@ def for_library(self, library): return mock_catalog authenticator = LibraryAuthenticator.from_config( - self._db, self._default_library, - custom_catalog_source=MockCustomPatronCatalog + self._db, + self._default_library, + custom_catalog_source=MockCustomPatronCatalog, ) assert self._default_library == MockCustomPatronCatalog.called_with @@ -679,7 +718,8 @@ def test_configuration_exception_during_from_config_stored(self): # Create an integration destined to raise CannotLoadConfiguration.. misconfigured = self._external_integration( - "api.firstbook", ExternalIntegration.PATRON_AUTH_GOAL, + "api.firstbook", + ExternalIntegration.PATRON_AUTH_GOAL, ) # ... and one destined to raise ImportError. @@ -698,20 +738,21 @@ def test_configuration_exception_during_from_config_stored(self): # initialization_exceptions. not_configured = auth.initialization_exceptions[misconfigured.id] assert isinstance(not_configured, CannotLoadConfiguration) - assert 'First Book server not configured.' == str(not_configured) + assert "First Book server not configured." == str(not_configured) not_found = auth.initialization_exceptions[unknown.id] assert isinstance(not_found, ImportError) assert "No module named 'unknown protocol'" == str(not_found) def test_register_fails_when_integration_has_wrong_goal(self): - integration = self._external_integration( - "protocol", "some other goal" - ) + integration = self._external_integration("protocol", "some other goal") auth = LibraryAuthenticator(_db=self._db, library=self._default_library) with pytest.raises(CannotLoadConfiguration) as excinfo: auth.register_provider(integration) - assert "Was asked to register an integration with goal=some other goal as though it were a way of authenticating patrons." in str(excinfo.value) + assert ( + "Was asked to register an integration with goal=some other goal as though it were a way of authenticating patrons." + in str(excinfo.value) + ) def test_register_fails_when_integration_not_associated_with_library(self): integration = self._external_integration( @@ -720,10 +761,15 @@ def test_register_fails_when_integration_not_associated_with_library(self): auth = LibraryAuthenticator(_db=self._db, library=self._default_library) with pytest.raises(CannotLoadConfiguration) as excinfo: auth.register_provider(integration) - assert "Was asked to register an integration with library {}, which doesn't use it."\ - .format(self._default_library.name) in str(excinfo.value) + assert "Was asked to register an integration with library {}, which doesn't use it.".format( + self._default_library.name + ) in str( + excinfo.value + ) - def test_register_fails_when_integration_module_does_not_contain_provider_class(self): + def test_register_fails_when_integration_module_does_not_contain_provider_class( + self, + ): library = self._default_library integration = self._external_integration( "api.lanes", ExternalIntegration.PATRON_AUTH_GOAL @@ -732,9 +778,14 @@ def test_register_fails_when_integration_module_does_not_contain_provider_class( auth = LibraryAuthenticator(_db=self._db, library=library) with pytest.raises(CannotLoadConfiguration) as excinfo: auth.register_provider(integration) - assert "Loaded module api.lanes but could not find a class called AuthenticationProvider inside." in str(excinfo.value) + assert ( + "Loaded module api.lanes but could not find a class called AuthenticationProvider inside." + in str(excinfo.value) + ) - def test_register_provider_fails_but_does_not_explode_on_remote_integration_error(self): + def test_register_provider_fails_but_does_not_explode_on_remote_integration_error( + self, + ): library = self._default_library # We're going to instantiate the a mock authentication provider that # immediately raises a RemoteIntegrationException, which will become @@ -747,25 +798,28 @@ def test_register_provider_fails_but_does_not_explode_on_remote_integration_erro with pytest.raises(CannotLoadConfiguration) as excinfo: auth.register_provider(integration) assert "Could not instantiate" in str(excinfo.value) - assert "authentication provider for library {}, possibly due to a network connection problem."\ - .format(self._default_library.name) in str(excinfo.value) + assert "authentication provider for library {}, possibly due to a network connection problem.".format( + self._default_library.name + ) in str( + excinfo.value + ) def test_register_provider_basic_auth(self): firstbook = self._external_integration( - "api.firstbook", ExternalIntegration.PATRON_AUTH_GOAL, + "api.firstbook", + ExternalIntegration.PATRON_AUTH_GOAL, ) firstbook.url = "http://url/" firstbook.password = "secret" self._default_library.integrations.append(firstbook) auth = LibraryAuthenticator(_db=self._db, library=self._default_library) auth.register_provider(firstbook) - assert isinstance( - auth.basic_auth_provider, FirstBookAuthenticationAPI - ) + assert isinstance(auth.basic_auth_provider, FirstBookAuthenticationAPI) def test_register_oauth_provider(self): oauth = self._external_integration( - "api.clever", ExternalIntegration.PATRON_AUTH_GOAL, + "api.clever", + ExternalIntegration.PATRON_AUTH_GOAL, ) oauth.username = "client_id" oauth.password = "client_secret" @@ -773,27 +827,19 @@ def test_register_oauth_provider(self): auth = LibraryAuthenticator(_db=self._db, library=self._default_library) auth.register_provider(oauth) assert 1 == len(auth.oauth_providers_by_name) - clever = auth.oauth_providers_by_name[ - CleverAuthenticationAPI.NAME - ] + clever = auth.oauth_providers_by_name[CleverAuthenticationAPI.NAME] assert isinstance(clever, CleverAuthenticationAPI) def test_oauth_provider_requires_secret(self): integration = self._external_integration(self._str) - basic = MockBasicAuthenticationProvider( - self._default_library, integration - ) - oauth = MockOAuthAuthenticationProvider( - self._default_library, "provider1" - ) + basic = MockBasicAuthenticationProvider(self._default_library, integration) + oauth = MockOAuthAuthenticationProvider(self._default_library, "provider1") # You can create an Authenticator that only uses Basic Auth # without providing a secret. LibraryAuthenticator( - _db=self._db, - library=self._default_library, - basic_auth_provider=basic + _db=self._db, library=self._default_library, basic_auth_provider=basic ) # You can create an Authenticator that uses OAuth if you @@ -801,14 +847,20 @@ def test_oauth_provider_requires_secret(self): LibraryAuthenticator( _db=self._db, library=self._default_library, - oauth_providers=[oauth], bearer_token_signing_secret="foo" + oauth_providers=[oauth], + bearer_token_signing_secret="foo", ) # But you can't create an Authenticator that uses OAuth # without providing a secret. with pytest.raises(CannotLoadConfiguration) as excinfo: - LibraryAuthenticator(_db=self._db, library=self._default_library, oauth_providers=[oauth]) - assert "OAuth providers are configured, but secret for signing bearer tokens is not." in str(excinfo.value) + LibraryAuthenticator( + _db=self._db, library=self._default_library, oauth_providers=[oauth] + ) + assert ( + "OAuth providers are configured, but secret for signing bearer tokens is not." + in str(excinfo.value) + ) def test_supports_patron_authentication(self): authenticator = LibraryAuthenticator.from_config( @@ -836,7 +888,8 @@ def test_identifies_individuals(self): # This LibraryAuthenticator does not authenticate patrons at # all, so it does not identify patrons as individuals. authenticator = LibraryAuthenticator( - _db=self._db, library=self._default_library, + _db=self._db, + library=self._default_library, ) # This LibraryAuthenticator has two Authenticators, but @@ -844,12 +897,15 @@ def test_identifies_individuals(self): class MockAuthenticator(object): NAME = "mock" IDENTIFIES_INDIVIDUALS = False + basic = MockAuthenticator() oauth = MockAuthenticator() authenticator = LibraryAuthenticator( - _db=self._db, library=self._default_library, - basic_auth_provider=basic, oauth_providers=[oauth], - bearer_token_signing_secret=self._str + _db=self._db, + library=self._default_library, + basic_auth_provider=basic, + oauth_providers=[oauth], + bearer_token_signing_secret=self._str, ) assert False == authenticator.identifies_individuals @@ -864,20 +920,18 @@ class MockAuthenticator(object): oauth.IDENTIFIES_INDIVIDUALS = True assert True == authenticator.identifies_individuals - def test_providers(self): integration = self._external_integration(self._str) - basic = MockBasicAuthenticationProvider( - self._default_library, integration - ) + basic = MockBasicAuthenticationProvider(self._default_library, integration) oauth1 = MockOAuthAuthenticationProvider(self._default_library, "provider1") oauth2 = MockOAuthAuthenticationProvider(self._default_library, "provider2") authenticator = LibraryAuthenticator( _db=self._db, library=self._default_library, - basic_auth_provider=basic, oauth_providers=[oauth1, oauth2], - bearer_token_signing_secret='foo' + basic_auth_provider=basic, + oauth_providers=[oauth1, oauth2], + bearer_token_signing_secret="foo", ) assert [basic, oauth1, oauth2] == list(authenticator.providers) @@ -890,18 +944,16 @@ def test_provider_registration(self): authenticator = LibraryAuthenticator( _db=self._db, library=self._default_library, - bearer_token_signing_secret='foo' + bearer_token_signing_secret="foo", ) integration = self._external_integration(self._str) - basic1 = MockBasicAuthenticationProvider( - self._default_library, integration - ) - basic2 = MockBasicAuthenticationProvider( - self._default_library, integration - ) + basic1 = MockBasicAuthenticationProvider(self._default_library, integration) + basic2 = MockBasicAuthenticationProvider(self._default_library, integration) oauth1 = MockOAuthAuthenticationProvider(self._default_library, "provider1") oauth2 = MockOAuthAuthenticationProvider(self._default_library, "provider2") - oauth1_dupe = MockOAuthAuthenticationProvider(self._default_library, "provider1") + oauth1_dupe = MockOAuthAuthenticationProvider( + self._default_library, "provider1" + ) authenticator.register_basic_auth_provider(basic1) authenticator.register_basic_auth_provider(basic1) @@ -916,19 +968,17 @@ def test_provider_registration(self): with pytest.raises(CannotLoadConfiguration) as excinfo: authenticator.register_oauth_provider(oauth1_dupe) - assert 'Two different OAuth providers claim the name "provider1"' in str(excinfo.value) + assert 'Two different OAuth providers claim the name "provider1"' in str( + excinfo.value + ) def test_oauth_provider_lookup(self): # If there are no OAuth providers we cannot look one up. integration = self._external_integration(self._str) - basic = MockBasicAuthenticationProvider( - self._default_library, integration - ) + basic = MockBasicAuthenticationProvider(self._default_library, integration) authenticator = LibraryAuthenticator( - _db=self._db, - library=self._default_library, - basic_auth_provider=basic + _db=self._db, library=self._default_library, basic_auth_provider=basic ) problem = authenticator.oauth_provider_lookup("provider1") assert problem.uri == UNKNOWN_OAUTH_PROVIDER.uri @@ -942,7 +992,7 @@ def test_oauth_provider_lookup(self): _db=self._db, library=self._default_library, oauth_providers=[oauth1, oauth2], - bearer_token_signing_secret='foo' + bearer_token_signing_secret="foo", ) provider = authenticator.oauth_provider_lookup("provider1") @@ -951,31 +1001,30 @@ def test_oauth_provider_lookup(self): problem = authenticator.oauth_provider_lookup("provider3") assert problem.uri == UNKNOWN_OAUTH_PROVIDER.uri assert ( - _("The specified OAuth provider name isn't one of the known providers. The known providers are: provider1, provider2") == - problem.detail) + _( + "The specified OAuth provider name isn't one of the known providers. The known providers are: provider1, provider2" + ) + == problem.detail + ) def test_authenticated_patron_basic(self): patron = self._patron() patrondata = PatronData( permanent_id=patron.external_identifier, authorization_identifier=patron.authorization_identifier, - username=patron.username, neighborhood="Achewood" + username=patron.username, + neighborhood="Achewood", ) integration = self._external_integration(self._str) basic = MockBasicAuthenticationProvider( - self._default_library, integration, patron=patron, - patrondata=patrondata + self._default_library, integration, patron=patron, patrondata=patrondata ) authenticator = LibraryAuthenticator( - _db=self._db, - library=self._default_library, - basic_auth_provider=basic + _db=self._db, library=self._default_library, basic_auth_provider=basic + ) + assert patron == authenticator.authenticated_patron( + self._db, dict(username="foo", password="bar") ) - assert ( - patron == - authenticator.authenticated_patron( - self._db, dict(username="foo", password="bar") - )) # Neighborhood information is being temporarily stored in the # Patron object for use elsewhere in request processing. It @@ -984,27 +1033,27 @@ def test_authenticated_patron_basic(self): assert "Achewood" == patron.neighborhood # OAuth doesn't work. - problem = authenticator.authenticated_patron( - self._db, "Bearer abcd" - ) + problem = authenticator.authenticated_patron(self._db, "Bearer abcd") assert UNSUPPORTED_AUTHENTICATION_MECHANISM == problem def test_authenticated_patron_oauth(self): patron1 = self._patron() patron2 = self._patron() - oauth1 = MockOAuthAuthenticationProvider(self._default_library, "oauth1", patron=patron1) - oauth2 = MockOAuthAuthenticationProvider(self._default_library, "oauth2", patron=patron2) + oauth1 = MockOAuthAuthenticationProvider( + self._default_library, "oauth1", patron=patron1 + ) + oauth2 = MockOAuthAuthenticationProvider( + self._default_library, "oauth2", patron=patron2 + ) authenticator = LibraryAuthenticator( _db=self._db, library=self._default_library, oauth_providers=[oauth1, oauth2], - bearer_token_signing_secret='foo' + bearer_token_signing_secret="foo", ) # Ask oauth1 to create a bearer token. - token = authenticator.create_bearer_token( - oauth1.NAME, "some token" - ) + token = authenticator.create_bearer_token(oauth1.NAME, "some token") # The authenticator will decode the bearer token into a # provider and a provider token. It will look up the oauth1 @@ -1012,9 +1061,7 @@ def test_authenticated_patron_oauth(self): # the provider token. # # This gives us patron1, as opposed to patron2. - authenticated = authenticator.authenticated_patron( - self._db, "Bearer " + token - ) + authenticated = authenticator.authenticated_patron(self._db, "Bearer " + token) assert patron1 == authenticated # Basic auth doesn't work. @@ -1028,9 +1075,7 @@ def test_authenticated_patron_unsupported_mechanism(self): _db=self._db, library=self._default_library, ) - problem = authenticator.authenticated_patron( - self._db, object() - ) + problem = authenticator.authenticated_patron(self._db, object()) assert UNSUPPORTED_AUTHENTICATION_MECHANISM == problem def test_get_credential_from_header(self): @@ -1043,24 +1088,23 @@ def test_get_credential_from_header(self): authenticator = LibraryAuthenticator( _db=self._db, library=self._default_library, - basic_auth_provider=basic, oauth_providers=[oauth], - bearer_token_signing_secret="secret" + basic_auth_provider=basic, + oauth_providers=[oauth], + bearer_token_signing_secret="secret", ) credential = dict(password="foo") - assert ("foo" == - authenticator.get_credential_from_header(credential)) + assert "foo" == authenticator.get_credential_from_header(credential) # We can't pull the password out if only OAuth authentication # providers are configured. authenticator = LibraryAuthenticator( _db=self._db, library=self._default_library, - basic_auth_provider=None, oauth_providers=[oauth], - bearer_token_signing_secret="secret" + basic_auth_provider=None, + oauth_providers=[oauth], + bearer_token_signing_secret="secret", ) - assert (None == - authenticator.get_credential_from_header(credential)) - + assert None == authenticator.get_credential_from_header(credential) def test_create_bearer_token(self): oauth1 = MockOAuthAuthenticationProvider(self._default_library, "oauth1") @@ -1069,13 +1113,15 @@ def test_create_bearer_token(self): _db=self._db, library=self._default_library, oauth_providers=[oauth1, oauth2], - bearer_token_signing_secret='foo' + bearer_token_signing_secret="foo", ) # A token is created and signed with the bearer token. token1 = authenticator.create_bearer_token(oauth1.NAME, "some token") - assert ("eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiJ9.eyJ0b2tlbiI6InNvbWUgdG9rZW4iLCJpc3MiOiJvYXV0aDEifQ.toy4qdoziL99SN4q9DRMdN-3a0v81CfVjwJVFNUt_mk" == - token1) + assert ( + "eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiJ9.eyJ0b2tlbiI6InNvbWUgdG9rZW4iLCJpc3MiOiJvYXV0aDEifQ.toy4qdoziL99SN4q9DRMdN-3a0v81CfVjwJVFNUt_mk" + == token1 + ) # Varying the name of the OAuth provider varies the bearer # token. @@ -1084,9 +1130,7 @@ def test_create_bearer_token(self): # Varying the token sent by the OAuth provider varies the # bearer token. - token3 = authenticator.create_bearer_token( - oauth1.NAME, "some other token" - ) + token3 = authenticator.create_bearer_token(oauth1.NAME, "some other token") assert token3 != token1 # Varying the secret used to sign the token varies the bearer @@ -1101,7 +1145,7 @@ def test_decode_bearer_token(self): _db=self._db, library=self._default_library, oauth_providers=[oauth], - bearer_token_signing_secret='secret' + bearer_token_signing_secret="secret", ) # A token is created and signed with the secret. @@ -1110,15 +1154,13 @@ def test_decode_bearer_token(self): decoded = authenticator.decode_bearer_token(encoded) assert token_value == decoded - decoded = authenticator.decode_bearer_token_from_header( - "Bearer " + encoded - ) + decoded = authenticator.decode_bearer_token_from_header("Bearer " + encoded) assert token_value == decoded def test_create_authentication_document(self): - class MockAuthenticator(LibraryAuthenticator): """Mock the _geographic_areas method.""" + AREAS = ["focus area", "service area"] @classmethod @@ -1133,17 +1175,16 @@ def _geographic_areas(cls, library): library.name = "A Fabulous Library" authenticator = MockAuthenticator( _db=self._db, - library = library, - basic_auth_provider=basic, oauth_providers=[oauth], - bearer_token_signing_secret='secret' + library=library, + basic_auth_provider=basic, + oauth_providers=[oauth], + bearer_token_signing_secret="secret", ) class MockAuthenticationDocumentAnnotator(object): - def annotate_authentication_document( - self, library, doc, url_for - ): + def annotate_authentication_document(self, library, doc, url_for): self.called_with = library, doc, url_for - doc['modified'] = 'Kilroy was here' + doc["modified"] = "Kilroy was here" return doc annotator = MockAuthenticationDocumentAnnotator() @@ -1151,10 +1192,11 @@ def annotate_authentication_document( # We're about to call url_for, so we must create an # application context. - os.environ['AUTOINITIALIZE'] = "False" + os.environ["AUTOINITIALIZE"] = "False" from api.app import app + self.app = app - del os.environ['AUTOINITIALIZE'] + del os.environ["AUTOINITIALIZE"] # Set up configuration settings for links. link_config = { @@ -1177,32 +1219,39 @@ def annotate_authentication_document( # Set the URL to the library's web page. ConfigurationSetting.for_library( - Configuration.WEBSITE_URL, library).value = "http://library/" + Configuration.WEBSITE_URL, library + ).value = "http://library/" # Set the color scheme a mobile client should use. ConfigurationSetting.for_library( - Configuration.COLOR_SCHEME, library).value = "plaid" + Configuration.COLOR_SCHEME, library + ).value = "plaid" # Set the colors a web client should use. ConfigurationSetting.for_library( - Configuration.WEB_PRIMARY_COLOR, library).value = "#012345" + Configuration.WEB_PRIMARY_COLOR, library + ).value = "#012345" ConfigurationSetting.for_library( - Configuration.WEB_SECONDARY_COLOR, library).value = "#abcdef" + Configuration.WEB_SECONDARY_COLOR, library + ).value = "#abcdef" # Configure the various ways a patron can get help. ConfigurationSetting.for_library( - Configuration.HELP_EMAIL, library).value = "help@library" + Configuration.HELP_EMAIL, library + ).value = "help@library" ConfigurationSetting.for_library( - Configuration.HELP_WEB, library).value = "http://library.help/" + Configuration.HELP_WEB, library + ).value = "http://library.help/" ConfigurationSetting.for_library( - Configuration.HELP_URI, library).value = "custom:uri" + Configuration.HELP_URI, library + ).value = "custom:uri" base_url = ConfigurationSetting.sitewide(self._db, Configuration.BASE_URL_KEY) - base_url.value = 'http://circulation-manager/' + base_url.value = "http://circulation-manager/" # Configure three announcements: two active and one # inactive. - format = '%Y-%m-%d' + format = "%Y-%m-%d" today = datetime.date.today() tomorrow = (today + datetime.timedelta(days=1)).strftime(format) yesterday = (today - datetime.timedelta(days=1)).strftime(format) @@ -1210,16 +1259,22 @@ def annotate_authentication_document( today = today.strftime(format) announcements = [ dict( - id='a1', content='this is announcement 1', - start=yesterday, finish=today, + id="a1", + content="this is announcement 1", + start=yesterday, + finish=today, ), dict( - id='a2', content='this is announcement 2', - start=two_days_ago, finish=yesterday, + id="a2", + content="this is announcement 2", + start=two_days_ago, + finish=yesterday, ), dict( - id='a3', content='this is announcement 3', - start=yesterday, finish=today, + id="a3", + content="this is announcement 3", + start=yesterday, + finish=today, ), ] announcement_setting = ConfigurationSetting.for_library( @@ -1229,16 +1284,14 @@ def annotate_authentication_document( with self.app.test_request_context("/"): url = authenticator.authentication_document_url(library) - assert url.endswith( - "/%s/authentication_document" % library.short_name - ) + assert url.endswith("/%s/authentication_document" % library.short_name) doc = json.loads(authenticator.create_authentication_document()) # The main thing we need to test is that the # authentication sub-documents are assembled properly and # placed in the right position. - flows = doc['authentication'] - oauth_doc, basic_doc = sorted(flows, key=lambda x: x['type']) + flows = doc["authentication"] + oauth_doc, basic_doc = sorted(flows, key=lambda x: x["type"]) expect_basic = basic.authentication_flow_document(self._db) assert expect_basic == basic_doc @@ -1248,14 +1301,14 @@ def annotate_authentication_document( # We also need to test that the library's name and ID # were placed in the document. - assert "A Fabulous Library" == doc['title'] - assert "Just the best." == doc['service_description'] - assert url == doc['id'] + assert "A Fabulous Library" == doc["title"] + assert "Just the best." == doc["service_description"] + assert url == doc["id"] # The mobile color scheme and web colors are correctly reported. - assert "plaid" == doc['color_scheme'] - assert "#012345" == doc['web_color_scheme']['primary'] - assert "#abcdef" == doc['web_color_scheme']['secondary'] + assert "plaid" == doc["color_scheme"] + assert "#012345" == doc["web_color_scheme"]["primary"] + assert "#abcdef" == doc["web_color_scheme"]["secondary"] # _geographic_areas was called and provided the library's # focus area and service area. @@ -1264,112 +1317,125 @@ def annotate_authentication_document( # We also need to test that the links got pulled in # from the configuration. - (about, alternate, copyright, help_uri, help_web, help_email, - copyright_agent, profile, loans, license, logo, privacy_policy, register, start, - stylesheet, terms_of_service) = sorted( - doc['links'], key=lambda x: (x['rel'], x['href']) - ) - assert "http://terms" == terms_of_service['href'] - assert "http://privacy" == privacy_policy['href'] - assert "http://copyright" == copyright['href'] - assert "http://about" == about['href'] - assert "http://license/" == license['href'] - assert "image data" == logo['href'] - assert "http://style.css" == stylesheet['href'] - - assert ("/loans" in loans['href']) - assert "http://opds-spec.org/shelf" == loans['rel'] - assert OPDSFeed.ACQUISITION_FEED_TYPE == loans['type'] - - assert ("/patrons/me" in profile['href']) - assert ProfileController.LINK_RELATION == profile['rel'] - assert ProfileController.MEDIA_TYPE == profile['type'] + ( + about, + alternate, + copyright, + help_uri, + help_web, + help_email, + copyright_agent, + profile, + loans, + license, + logo, + privacy_policy, + register, + start, + stylesheet, + terms_of_service, + ) = sorted(doc["links"], key=lambda x: (x["rel"], x["href"])) + assert "http://terms" == terms_of_service["href"] + assert "http://privacy" == privacy_policy["href"] + assert "http://copyright" == copyright["href"] + assert "http://about" == about["href"] + assert "http://license/" == license["href"] + assert "image data" == logo["href"] + assert "http://style.css" == stylesheet["href"] + + assert "/loans" in loans["href"] + assert "http://opds-spec.org/shelf" == loans["rel"] + assert OPDSFeed.ACQUISITION_FEED_TYPE == loans["type"] + + assert "/patrons/me" in profile["href"] + assert ProfileController.LINK_RELATION == profile["rel"] + assert ProfileController.MEDIA_TYPE == profile["type"] expect_start = url_for( - "index", library_short_name=self._default_library.short_name, - _external=True + "index", + library_short_name=self._default_library.short_name, + _external=True, ) - assert expect_start == start['href'] + assert expect_start == start["href"] # The start link points to an OPDS feed. - assert OPDSFeed.ACQUISITION_FEED_TYPE == start['type'] + assert OPDSFeed.ACQUISITION_FEED_TYPE == start["type"] # Most of the other links have type='text/html' - assert "text/html" == about['type'] + assert "text/html" == about["type"] # The registration link doesn't have a type, because it # uses a non-HTTP URI scheme. - assert 'type' not in register - assert 'custom-registration-hook://library/' == register['href'] + assert "type" not in register + assert "custom-registration-hook://library/" == register["href"] # The logo link has type "image/png". assert "image/png" == logo["type"] # We have three help links. - assert "custom:uri" == help_uri['href'] - assert "http://library.help/" == help_web['href'] - assert "text/html" == help_web['type'] - assert "mailto:help@library" == help_email['href'] + assert "custom:uri" == help_uri["href"] + assert "http://library.help/" == help_web["href"] + assert "text/html" == help_web["type"] + assert "mailto:help@library" == help_email["href"] # Since no special address was given for the copyright # designated agent, the help address was reused. - copyright_rel = "http://librarysimplified.org/rel/designated-agent/copyright" - assert copyright_rel == copyright_agent['rel'] - assert "mailto:help@library" == copyright_agent['href'] + copyright_rel = ( + "http://librarysimplified.org/rel/designated-agent/copyright" + ) + assert copyright_rel == copyright_agent["rel"] + assert "mailto:help@library" == copyright_agent["href"] # The public key is correct. - assert authenticator.public_key == doc['public_key']['value'] - assert "RSA" == doc['public_key']['type'] - + assert authenticator.public_key == doc["public_key"]["value"] + assert "RSA" == doc["public_key"]["type"] # The library's web page shows up as an HTML alternate # to the OPDS server. assert ( - dict(rel="alternate", type="text/html", href="http://library/") == - alternate) + dict(rel="alternate", type="text/html", href="http://library/") + == alternate + ) # Active announcements are published; inactive announcements are not. - a1, a3 = doc['announcements'] - assert ( - dict(id='a1', content='this is announcement 1') == - a1) - assert ( - dict(id='a3', content='this is announcement 3') == - a3) + a1, a3 = doc["announcements"] + assert dict(id="a1", content="this is announcement 1") == a1 + assert dict(id="a3", content="this is announcement 3") == a3 # Features that are enabled for this library are communicated # through the 'features' item. - features = doc['features'] - assert [] == features['disabled'] - assert [Configuration.RESERVATIONS_FEATURE] == features['enabled'] + features = doc["features"] + assert [] == features["disabled"] + assert [Configuration.RESERVATIONS_FEATURE] == features["enabled"] # If a separate copyright designated agent is configured, # that email address is used instead of the default # patron support address. ConfigurationSetting.for_library( - Configuration.COPYRIGHT_DESIGNATED_AGENT_EMAIL, library).value = "mailto:dmca@library.org" + Configuration.COPYRIGHT_DESIGNATED_AGENT_EMAIL, library + ).value = "mailto:dmca@library.org" doc = json.loads(authenticator.create_authentication_document()) - [agent] = [x for x in doc['links'] if x['rel'] == copyright_rel] + [agent] = [x for x in doc["links"] if x["rel"] == copyright_rel] assert "mailto:dmca@library.org" == agent["href"] # If no focus area or service area are provided, those fields # are not added to the document. MockAuthenticator.AREAS = [None, None] doc = json.loads(authenticator.create_authentication_document()) - for key in ('focus_area', 'service_area'): + for key in ("focus_area", "service_area"): assert key not in doc # If there are no announcements, the list of announcements is present # but empty. announcement_setting.value = None doc = json.loads(authenticator.create_authentication_document()) - assert [] == doc['announcements'] + assert [] == doc["announcements"] # The annotator's annotate_authentication_document method # was called and successfully modified the authentication # document. assert (library, doc, url_for) == annotator.called_with - assert 'Kilroy was here' == doc['modified'] + assert "Kilroy was here" == doc["modified"] # While we're in this context, let's also test # create_authentication_headers. @@ -1378,17 +1444,16 @@ def annotate_authentication_document( # provider, that provider's .authentication_header is used # for WWW-Authenticate. headers = authenticator.create_authentication_headers() - assert AuthenticationForOPDSDocument.MEDIA_TYPE == headers['Content-Type'] - assert basic.authentication_header == headers['WWW-Authenticate'] + assert AuthenticationForOPDSDocument.MEDIA_TYPE == headers["Content-Type"] + assert basic.authentication_header == headers["WWW-Authenticate"] # The response contains a Link header pointing to the authentication # document expect = "<%s>; rel=%s" % ( authenticator.authentication_document_url(self._default_library), - AuthenticationForOPDSDocument.LINK_RELATION + AuthenticationForOPDSDocument.LINK_RELATION, ) - assert expect == headers['Link'] - + assert expect == headers["Link"] # If the authenticator does not include a basic auth provider, # no WWW-Authenticate header is provided. @@ -1396,10 +1461,10 @@ def annotate_authentication_document( _db=self._db, library=library, oauth_providers=[oauth], - bearer_token_signing_secret='secret' + bearer_token_signing_secret="secret", ) headers = authenticator.create_authentication_headers() - assert 'WWW-Authenticate' not in headers + assert "WWW-Authenticate" not in headers def test_key_pair(self): """Test the public/private key pair associated with a library.""" @@ -1411,14 +1476,15 @@ def keys(): return ConfigurationSetting.for_library( Configuration.KEY_PAIR, library ).json_value + assert None == keys() # Instantiating a LibraryAuthenticator for a library automatically # generates a public/private key pair. auth = LibraryAuthenticator.from_config(self._db, library) public, private = keys() - assert 'BEGIN PUBLIC KEY' in public - assert 'BEGIN RSA PRIVATE KEY' in private + assert "BEGIN PUBLIC KEY" in public + assert "BEGIN RSA PRIVATE KEY" in private # The public key is stored in the # LibraryAuthenticator.public_key property. @@ -1427,7 +1493,7 @@ def keys(): # The private key is not stored in the LibraryAuthenticator # object, but it can be obtained from the database by # using the key_pair property. - assert not hasattr(auth, 'private_key') + assert not hasattr(auth, "private_key") assert (public, private) == auth.key_pair # Each library has its own key pair. @@ -1437,11 +1503,13 @@ def keys(): def test__geographic_areas(self): """Test the _geographic_areas helper method.""" + class Mock(LibraryAuthenticator): values = { - Configuration.LIBRARY_FOCUS_AREA : "focus", - Configuration.LIBRARY_SERVICE_AREA : "service", + Configuration.LIBRARY_FOCUS_AREA: "focus", + Configuration.LIBRARY_SERVICE_AREA: "service", } + @classmethod def _geographic_area(cls, key, library): cls.called_with = library @@ -1493,27 +1561,20 @@ def m(): class TestAuthenticationProvider(AuthenticatorTest): - credentials = dict(username='user', password='') + credentials = dict(username="user", password="") def test_external_integration(self): provider = self.mock_basic(patrondata=None) - assert (self.mock_basic_integration == - provider.external_integration(self._db)) + assert self.mock_basic_integration == provider.external_integration(self._db) def test_authenticated_patron_passes_on_none(self): provider = self.mock_basic(patrondata=None) - patron = provider.authenticated_patron( - self._db, self.credentials - ) + patron = provider.authenticated_patron(self._db, self.credentials) assert None == patron def test_authenticated_patron_passes_on_problem_detail(self): - provider = self.mock_basic( - patrondata=UNSUPPORTED_AUTHENTICATION_MECHANISM - ) - patron = provider.authenticated_patron( - self._db, self.credentials - ) + provider = self.mock_basic(patrondata=UNSUPPORTED_AUTHENTICATION_MECHANISM) + patron = provider.authenticated_patron(self._db, self.credentials) assert UNSUPPORTED_AUTHENTICATION_MECHANISM == patron def test_authenticated_patron_allows_access_to_expired_credentials(self): @@ -1522,15 +1583,15 @@ def test_authenticated_patron_allows_access_to_expired_credentials(self): """ yesterday = utc_now() - datetime.timedelta(days=1) - expired = PatronData(permanent_id="1", authorization_identifier="2", - authorization_expires=yesterday) - provider = self.mock_basic( - patrondata=expired, - remote_patron_lookup_patrondata=expired + expired = PatronData( + permanent_id="1", + authorization_identifier="2", + authorization_expires=yesterday, ) - patron = provider.authenticated_patron( - self._db, self.credentials + provider = self.mock_basic( + patrondata=expired, remote_patron_lookup_patrondata=expired ) + patron = provider.authenticated_patron(self._db, self.credentials) assert "1" == patron.external_identifier assert "2" == patron.authorization_identifier @@ -1545,7 +1606,7 @@ def test_authenticated_patron_updates_metadata_if_necessary(self): incomplete_data = PatronData( permanent_id=patron.external_identifier, authorization_identifier=username, - complete=False + complete=False, ) # If we do a lookup for this patron we will get more complete @@ -1553,17 +1614,15 @@ def test_authenticated_patron_updates_metadata_if_necessary(self): complete_data = PatronData( permanent_id=patron.external_identifier, authorization_identifier=barcode, - username=username, cached_neighborhood="Little Homeworld", - complete=True + username=username, + cached_neighborhood="Little Homeworld", + complete=True, ) provider = self.mock_basic( - patrondata=incomplete_data, - remote_patron_lookup_patrondata=complete_data - ) - patron2 = provider.authenticated_patron( - self._db, self.credentials + patrondata=incomplete_data, remote_patron_lookup_patrondata=complete_data ) + patron2 = provider.authenticated_patron(self._db, self.credentials) # We found the right patron. assert patron == patron2 @@ -1588,9 +1647,7 @@ def test_authenticated_patron_updates_metadata_if_necessary(self): # patron has borrowing privileges. last_sync = patron.last_external_sync assert False == PatronUtility.needs_external_sync(patron) - patron = provider.authenticated_patron( - self._db, dict(username=username) - ) + patron = provider.authenticated_patron(self._db, dict(username=username)) assert last_sync == patron.last_external_sync assert barcode == patron.authorization_identifier assert username == patron.username @@ -1606,7 +1663,7 @@ def test_authenticated_patron_updates_metadata_if_necessary(self): incomplete_data = PatronData( permanent_id=patron.external_identifier, authorization_identifier="some other identifier", - complete=False + complete=False, ) provider.patrondata = incomplete_data patron = provider.authenticated_patron( @@ -1622,13 +1679,11 @@ def test_authenticated_patron_updates_metadata_if_necessary(self): def test_update_patron_metadata(self): patron = self._patron() - patron.authorization_identifier="2345" + patron.authorization_identifier = "2345" assert None == patron.last_external_sync assert None == patron.username - patrondata = PatronData( - username="user", neighborhood="Little Homeworld" - ) + patrondata = PatronData(username="user", neighborhood="Little Homeworld") provider = self.mock_basic(remote_patron_lookup_patrondata=patrondata) provider.external_type_regular_expression = re.compile("^(.)") provider.update_patron_metadata(patron) @@ -1680,16 +1735,16 @@ class MockProvider(AuthenticationProvider): NAME = "Just a mock" setting = ConfigurationSetting.for_library_and_externalintegration( - self._db, MockProvider.EXTERNAL_TYPE_REGULAR_EXPRESSION, - library, integration + self._db, + MockProvider.EXTERNAL_TYPE_REGULAR_EXPRESSION, + library, + integration, ) setting.value = None # If there is no EXTERNAL_TYPE_REGULAR_EXPRESSION, calling # update_patron_external_type does nothing. - MockProvider(library, integration).update_patron_external_type( - patron - ) + MockProvider(library, integration).update_patron_external_type(patron) assert "old value" == patron.external_type setting.value = "([A-Z])" @@ -1716,36 +1771,114 @@ def test_restriction_matches(self): m = AuthenticationProvider._restriction_matches # If restriction is none, we always return True. - assert True == m("123", None, AuthenticationProvider.LIBRARY_IDENTIFIER_RESTRICTION_TYPE_PREFIX) - assert True == m("123", None, AuthenticationProvider.LIBRARY_IDENTIFIER_RESTRICTION_TYPE_STRING) - assert True == m("123", None, AuthenticationProvider.LIBRARY_IDENTIFIER_RESTRICTION_TYPE_REGEX) - assert True == m("123", None, AuthenticationProvider.LIBRARY_IDENTIFIER_RESTRICTION_TYPE_LIST) + assert True == m( + "123", + None, + AuthenticationProvider.LIBRARY_IDENTIFIER_RESTRICTION_TYPE_PREFIX, + ) + assert True == m( + "123", + None, + AuthenticationProvider.LIBRARY_IDENTIFIER_RESTRICTION_TYPE_STRING, + ) + assert True == m( + "123", + None, + AuthenticationProvider.LIBRARY_IDENTIFIER_RESTRICTION_TYPE_REGEX, + ) + assert True == m( + "123", None, AuthenticationProvider.LIBRARY_IDENTIFIER_RESTRICTION_TYPE_LIST + ) # If field is None we always return False. - assert False == m(None, "1234", AuthenticationProvider.LIBRARY_IDENTIFIER_RESTRICTION_TYPE_PREFIX) - assert False == m(None, "1234", AuthenticationProvider.LIBRARY_IDENTIFIER_RESTRICTION_TYPE_STRING) - assert False == m(None, re.compile(".*"), AuthenticationProvider.LIBRARY_IDENTIFIER_RESTRICTION_TYPE_REGEX) - assert False == m(None, ['1','2'], AuthenticationProvider.LIBRARY_IDENTIFIER_RESTRICTION_TYPE_LIST) + assert False == m( + None, + "1234", + AuthenticationProvider.LIBRARY_IDENTIFIER_RESTRICTION_TYPE_PREFIX, + ) + assert False == m( + None, + "1234", + AuthenticationProvider.LIBRARY_IDENTIFIER_RESTRICTION_TYPE_STRING, + ) + assert False == m( + None, + re.compile(".*"), + AuthenticationProvider.LIBRARY_IDENTIFIER_RESTRICTION_TYPE_REGEX, + ) + assert False == m( + None, + ["1", "2"], + AuthenticationProvider.LIBRARY_IDENTIFIER_RESTRICTION_TYPE_LIST, + ) # Test prefix - assert True == m("12345a", "1234", AuthenticationProvider.LIBRARY_IDENTIFIER_RESTRICTION_TYPE_PREFIX) - assert False == m("a1234", "1234", AuthenticationProvider.LIBRARY_IDENTIFIER_RESTRICTION_TYPE_PREFIX) + assert True == m( + "12345a", + "1234", + AuthenticationProvider.LIBRARY_IDENTIFIER_RESTRICTION_TYPE_PREFIX, + ) + assert False == m( + "a1234", + "1234", + AuthenticationProvider.LIBRARY_IDENTIFIER_RESTRICTION_TYPE_PREFIX, + ) # Test string - assert False == m("12345a", "1234", AuthenticationProvider.LIBRARY_IDENTIFIER_RESTRICTION_TYPE_STRING) - assert False == m("a1234", "1234", AuthenticationProvider.LIBRARY_IDENTIFIER_RESTRICTION_TYPE_STRING) - assert True == m("1234", "1234", AuthenticationProvider.LIBRARY_IDENTIFIER_RESTRICTION_TYPE_STRING) + assert False == m( + "12345a", + "1234", + AuthenticationProvider.LIBRARY_IDENTIFIER_RESTRICTION_TYPE_STRING, + ) + assert False == m( + "a1234", + "1234", + AuthenticationProvider.LIBRARY_IDENTIFIER_RESTRICTION_TYPE_STRING, + ) + assert True == m( + "1234", + "1234", + AuthenticationProvider.LIBRARY_IDENTIFIER_RESTRICTION_TYPE_STRING, + ) # Test list - assert True == m("1234", ["1234","4321"], AuthenticationProvider.LIBRARY_IDENTIFIER_RESTRICTION_TYPE_LIST) - assert True == m("4321", ["1234", "4321"], AuthenticationProvider.LIBRARY_IDENTIFIER_RESTRICTION_TYPE_LIST) - assert False == m("12345", ["1234", "4321"], AuthenticationProvider.LIBRARY_IDENTIFIER_RESTRICTION_TYPE_LIST) - assert False == m("54321", ["1234", "4321"], AuthenticationProvider.LIBRARY_IDENTIFIER_RESTRICTION_TYPE_LIST) + assert True == m( + "1234", + ["1234", "4321"], + AuthenticationProvider.LIBRARY_IDENTIFIER_RESTRICTION_TYPE_LIST, + ) + assert True == m( + "4321", + ["1234", "4321"], + AuthenticationProvider.LIBRARY_IDENTIFIER_RESTRICTION_TYPE_LIST, + ) + assert False == m( + "12345", + ["1234", "4321"], + AuthenticationProvider.LIBRARY_IDENTIFIER_RESTRICTION_TYPE_LIST, + ) + assert False == m( + "54321", + ["1234", "4321"], + AuthenticationProvider.LIBRARY_IDENTIFIER_RESTRICTION_TYPE_LIST, + ) # Test Regex - assert True == m("123", re.compile("^(12|34)"), AuthenticationProvider.LIBRARY_IDENTIFIER_RESTRICTION_TYPE_REGEX) - assert True == m("345", re.compile("^(12|34)"), AuthenticationProvider.LIBRARY_IDENTIFIER_RESTRICTION_TYPE_REGEX) - assert False == m("abc", re.compile("^bc"), AuthenticationProvider.LIBRARY_IDENTIFIER_RESTRICTION_TYPE_REGEX) + assert True == m( + "123", + re.compile("^(12|34)"), + AuthenticationProvider.LIBRARY_IDENTIFIER_RESTRICTION_TYPE_REGEX, + ) + assert True == m( + "345", + re.compile("^(12|34)"), + AuthenticationProvider.LIBRARY_IDENTIFIER_RESTRICTION_TYPE_REGEX, + ) + assert False == m( + "abc", + re.compile("^bc"), + AuthenticationProvider.LIBRARY_IDENTIFIER_RESTRICTION_TYPE_REGEX, + ) def test_enforce_library_identifier_restriction(self): """Test the enforce_library_identifier_restriction method.""" @@ -1754,41 +1887,59 @@ def test_enforce_library_identifier_restriction(self): patron = self._patron() patrondata = PatronData() - #Test with patron rather than patrondata as argument + # Test with patron rather than patrondata as argument assert patron == m(object(), patron) patron.library_id = -1 assert False == m(object(), patron) # Test no restriction - provider.library_identifier_restriction_type = MockBasic.LIBRARY_IDENTIFIER_RESTRICTION_TYPE_NONE + provider.library_identifier_restriction_type = ( + MockBasic.LIBRARY_IDENTIFIER_RESTRICTION_TYPE_NONE + ) provider.library_identifier_restriction = "2345" - provider.library_identifier_field = MockBasic.LIBRARY_IDENTIFIER_RESTRICTION_BARCODE + provider.library_identifier_field = ( + MockBasic.LIBRARY_IDENTIFIER_RESTRICTION_BARCODE + ) assert patrondata == m("12365", patrondata) # Test regex against barcode - provider.library_identifier_restriction_type = MockBasic.LIBRARY_IDENTIFIER_RESTRICTION_TYPE_REGEX + provider.library_identifier_restriction_type = ( + MockBasic.LIBRARY_IDENTIFIER_RESTRICTION_TYPE_REGEX + ) provider.library_identifier_restriction = re.compile("23[46]5") - provider.library_identifier_field = MockBasic.LIBRARY_IDENTIFIER_RESTRICTION_BARCODE + provider.library_identifier_field = ( + MockBasic.LIBRARY_IDENTIFIER_RESTRICTION_BARCODE + ) assert patrondata == m("23456", patrondata) assert patrondata == m("2365", patrondata) assert False == m("2375", provider.patrondata) # Test prefix against barcode - provider.library_identifier_restriction_type = MockBasic.LIBRARY_IDENTIFIER_RESTRICTION_TYPE_PREFIX + provider.library_identifier_restriction_type = ( + MockBasic.LIBRARY_IDENTIFIER_RESTRICTION_TYPE_PREFIX + ) provider.library_identifier_restriction = "2345" - provider.library_identifier_field = MockBasic.LIBRARY_IDENTIFIER_RESTRICTION_BARCODE + provider.library_identifier_field = ( + MockBasic.LIBRARY_IDENTIFIER_RESTRICTION_BARCODE + ) assert patrondata == m("23456", patrondata) assert False == m("123456", patrondata) # Test string against barcode - provider.library_identifier_restriction_type = MockBasic.LIBRARY_IDENTIFIER_RESTRICTION_TYPE_STRING + provider.library_identifier_restriction_type = ( + MockBasic.LIBRARY_IDENTIFIER_RESTRICTION_TYPE_STRING + ) provider.library_identifier_restriction = "2345" - provider.library_identifier_field = MockBasic.LIBRARY_IDENTIFIER_RESTRICTION_BARCODE + provider.library_identifier_field = ( + MockBasic.LIBRARY_IDENTIFIER_RESTRICTION_BARCODE + ) assert False == m("123456", patrondata) assert patrondata == m("2345", patrondata) # Test match applied to field on patrondata not barcode - provider.library_identifier_restriction_type = MockBasic.LIBRARY_IDENTIFIER_RESTRICTION_TYPE_STRING + provider.library_identifier_restriction_type = ( + MockBasic.LIBRARY_IDENTIFIER_RESTRICTION_TYPE_STRING + ) provider.library_identifier_restriction = "2345" provider.library_identifier_field = "agent" patrondata.library_identifier = "2345" @@ -1804,13 +1955,14 @@ class MockProvider(AuthenticationProvider): NAME = "Just a mock" string_setting = ConfigurationSetting.for_library_and_externalintegration( - self._db, MockProvider.LIBRARY_IDENTIFIER_RESTRICTION, - library, integration + self._db, MockProvider.LIBRARY_IDENTIFIER_RESTRICTION, library, integration ) type_setting = ConfigurationSetting.for_library_and_externalintegration( - self._db, MockProvider.LIBRARY_IDENTIFIER_RESTRICTION_TYPE, - library, integration + self._db, + MockProvider.LIBRARY_IDENTIFIER_RESTRICTION_TYPE, + library, + integration, ) # If the type is regex its converted into a regular expression. @@ -1823,19 +1975,19 @@ class MockProvider(AuthenticationProvider): type_setting.value = MockProvider.LIBRARY_IDENTIFIER_RESTRICTION_TYPE_LIST string_setting.value = "a,b,c" provider = MockProvider(library, integration) - assert ['a', 'b', 'c'] == provider.library_identifier_restriction + assert ["a", "b", "c"] == provider.library_identifier_restriction # If its type is prefix make sure its a string type_setting.value = MockProvider.LIBRARY_IDENTIFIER_RESTRICTION_TYPE_PREFIX string_setting.value = "abc" provider = MockProvider(library, integration) - assert 'abc' == provider.library_identifier_restriction + assert "abc" == provider.library_identifier_restriction # If its type is string make sure its a string type_setting.value = MockProvider.LIBRARY_IDENTIFIER_RESTRICTION_TYPE_STRING string_setting.value = "abc" provider = MockProvider(library, integration) - assert 'abc' == provider.library_identifier_restriction + assert "abc" == provider.library_identifier_restriction # If its type is none make sure its actually None type_setting.value = MockProvider.LIBRARY_IDENTIFIER_RESTRICTION_TYPE_NONE @@ -1845,7 +1997,6 @@ class MockProvider(AuthenticationProvider): class TestBasicAuthenticationProvider(AuthenticatorTest): - def test_constructor(self): b = BasicAuthenticationProvider @@ -1853,7 +2004,6 @@ def test_constructor(self): class ConfigAuthenticationProvider(b): NAME = "Config loading test" - integration = self._external_integration( self._str, goal=ExternalIntegration.PATRON_AUTH_GOAL ) @@ -1863,9 +2013,7 @@ class ConfigAuthenticationProvider(b): integration.setting(b.TEST_IDENTIFIER).value = "username" integration.setting(b.TEST_PASSWORD).value = "pw" - provider = ConfigAuthenticationProvider( - self._default_library, integration - ) + provider = ConfigAuthenticationProvider(self._default_library, integration) assert "idre" == provider.identifier_re.pattern assert "pwre" == provider.password_re.pattern assert "username" == provider.test_username @@ -1875,19 +2023,16 @@ class ConfigAuthenticationProvider(b): integration = self._external_integration( self._str, goal=ExternalIntegration.PATRON_AUTH_GOAL ) - provider = ConfigAuthenticationProvider( - self._default_library, integration - ) - assert (b.DEFAULT_IDENTIFIER_REGULAR_EXPRESSION == - provider.identifier_re) + provider = ConfigAuthenticationProvider(self._default_library, integration) + assert b.DEFAULT_IDENTIFIER_REGULAR_EXPRESSION == provider.identifier_re assert None == provider.password_re - def test_testing_patron(self): - class MockAuthenticatedPatron(MockBasicAuthenticationProvider): def __init__(self, *args, **kwargs): - self._authenticated_patron_returns = kwargs.pop("_authenticated_patron_returns", None) + self._authenticated_patron_returns = kwargs.pop( + "_authenticated_patron_returns", None + ) super(MockAuthenticatedPatron, self).__init__(*args, **kwargs) def authenticated_patron(self, *args, **kwargs): @@ -1911,8 +2056,8 @@ def authenticated_patron(self, *args, **kwargs): # but we can't look up the testing patron either. b = BasicAuthenticationProvider integration = self._external_integration(self._str) - integration.setting(b.TEST_IDENTIFIER).value = '1' - integration.setting(b.TEST_PASSWORD).value = '2' + integration.setting(b.TEST_IDENTIFIER).value = "1" + integration.setting(b.TEST_PASSWORD).value = "2" missing_patron = MockBasicAuthenticationProvider( self._default_library, integration, patron=None ) @@ -1929,11 +2074,13 @@ def authenticated_patron(self, *args, **kwargs): b = BasicAuthenticationProvider patron = self._patron() integration = self._external_integration(self._str) - integration.setting(b.TEST_IDENTIFIER).value = '1' - integration.setting(b.TEST_PASSWORD).value = '2' + integration.setting(b.TEST_IDENTIFIER).value = "1" + integration.setting(b.TEST_PASSWORD).value = "2" problem_patron = MockAuthenticatedPatron( - self._default_library, integration, patron=patron, - _authenticated_patron_returns=PATRON_OF_ANOTHER_LIBRARY + self._default_library, + integration, + patron=patron, + _authenticated_patron_returns=PATRON_OF_ANOTHER_LIBRARY, ) value = problem_patron.testing_patron(self._db) assert patron != PATRON_OF_ANOTHER_LIBRARY @@ -1951,11 +2098,13 @@ def authenticated_patron(self, *args, **kwargs): b = BasicAuthenticationProvider patron = self._patron() integration = self._external_integration(self._str) - integration.setting(b.TEST_IDENTIFIER).value = '1' - integration.setting(b.TEST_PASSWORD).value = '2' + integration.setting(b.TEST_IDENTIFIER).value = "1" + integration.setting(b.TEST_PASSWORD).value = "2" problem_patron = MockAuthenticatedPatron( - self._default_library, integration, patron=patron, - _authenticated_patron_returns=not_a_patron + self._default_library, + integration, + patron=patron, + _authenticated_patron_returns=not_a_patron, ) value = problem_patron.testing_patron(self._db) assert patron != not_a_patron @@ -1964,14 +2113,15 @@ def authenticated_patron(self, *args, **kwargs): # And testing_patron_or_bust() still doesn't work. with pytest.raises(IntegrationException) as excinfo: problem_patron.testing_patron_or_bust(self._db) - assert "Test patron lookup returned invalid value for patron" in str(excinfo.value) + assert "Test patron lookup returned invalid value for patron" in str( + excinfo.value + ) # Here, we configure a testing patron who is authenticated by # their username and password. patron = self._patron() present_patron = MockBasicAuthenticationProvider( - self._default_library, integration, - patron=patron + self._default_library, integration, patron=patron ) value = present_patron.testing_patron(self._db) assert (patron, "2") == value @@ -1980,12 +2130,13 @@ def authenticated_patron(self, *args, **kwargs): # value as testing_patron() assert value == present_patron.testing_patron_or_bust(self._db) - def test__run_self_tests(self): _db = object() + class CantAuthenticateTestPatron(BasicAuthenticationProvider): def __init__(self): pass + def testing_patron_or_bust(self, _db): self.called_with = _db raise Exception("Nope") @@ -2031,12 +2182,11 @@ def test_client_configuration(self): """ b = BasicAuthenticationProvider integration = self._external_integration(self._str) - integration.setting( - b.IDENTIFIER_KEYBOARD).value = b.EMAIL_ADDRESS_KEYBOARD + integration.setting(b.IDENTIFIER_KEYBOARD).value = b.EMAIL_ADDRESS_KEYBOARD integration.setting(b.PASSWORD_KEYBOARD).value = b.NUMBER_PAD integration.setting(b.IDENTIFIER_LABEL).value = "Your Library Card" - integration.setting(b.PASSWORD_LABEL).value = 'Password' - integration.setting(b.IDENTIFIER_BARCODE_FORMAT).value = 'some barcode' + integration.setting(b.PASSWORD_LABEL).value = "Password" + integration.setting(b.IDENTIFIER_BARCODE_FORMAT).value = "some barcode" provider = b(self._default_library, integration) @@ -2049,8 +2199,8 @@ def test_client_configuration(self): def test_server_side_validation(self): b = BasicAuthenticationProvider integration = self._external_integration(self._str) - integration.setting(b.IDENTIFIER_REGULAR_EXPRESSION).value = 'foo' - integration.setting(b.PASSWORD_REGULAR_EXPRESSION).value = 'bar' + integration.setting(b.IDENTIFIER_REGULAR_EXPRESSION).value = "foo" + integration.setting(b.PASSWORD_REGULAR_EXPRESSION).value = "bar" provider = b(self._default_library, integration) @@ -2075,8 +2225,10 @@ def test_server_side_validation(self): integration.setting(b.IDENTIFIER_REGULAR_EXPRESSION).value = None integration.setting(b.PASSWORD_REGULAR_EXPRESSION).value = None provider = b(self._default_library, integration) - assert (b.DEFAULT_IDENTIFIER_REGULAR_EXPRESSION.pattern == - provider.identifier_re.pattern) + assert ( + b.DEFAULT_IDENTIFIER_REGULAR_EXPRESSION.pattern + == provider.identifier_re.pattern + ) assert None == provider.password_re assert True == provider.server_side_validation("food", "barbecue") assert True == provider.server_side_validation("a", None) @@ -2101,9 +2253,7 @@ def test_local_patron_lookup(self): # This patron of another library looks just like the patron # we're about to create, but will never be selected. other_library = self._library() - other_library_patron = self._patron( - "patron1_ext_id", library=other_library - ) + other_library_patron = self._patron("patron1_ext_id", library=other_library) other_library_patron.authorization_identifier = "patron1_auth_id" other_library_patron.username = "patron1" @@ -2122,34 +2272,30 @@ def test_local_patron_lookup(self): # patron1, even though we provided the username associated # with patron2. for patrondata_args in [ - dict(permanent_id=patron1.external_identifier), - dict(authorization_identifier=patron1.authorization_identifier), - dict(username=patron1.username), - dict(permanent_id=PatronData.NO_VALUE, - username=PatronData.NO_VALUE, - authorization_identifier=patron1.authorization_identifier) + dict(permanent_id=patron1.external_identifier), + dict(authorization_identifier=patron1.authorization_identifier), + dict(username=patron1.username), + dict( + permanent_id=PatronData.NO_VALUE, + username=PatronData.NO_VALUE, + authorization_identifier=patron1.authorization_identifier, + ), ]: patrondata = PatronData(**patrondata_args) - assert ( - patron1 == provider.local_patron_lookup( - self._db, patron2.authorization_identifier, patrondata - )) + assert patron1 == provider.local_patron_lookup( + self._db, patron2.authorization_identifier, patrondata + ) # If no PatronData is provided, we can look up patron1 either # by authorization identifier or username, but not by # permanent identifier. - assert ( - patron1 == provider.local_patron_lookup( - self._db, patron1.authorization_identifier, None - )) - assert ( - patron1 == provider.local_patron_lookup( - self._db, patron1.username, None - )) - assert ( - None == provider.local_patron_lookup( - self._db, patron1.external_identifier, None - )) + assert patron1 == provider.local_patron_lookup( + self._db, patron1.authorization_identifier, None + ) + assert patron1 == provider.local_patron_lookup(self._db, patron1.username, None) + assert None == provider.local_patron_lookup( + self._db, patron1.external_identifier, None + ) def test_get_credential_from_header(self): provider = self.mock_basic() @@ -2160,54 +2306,60 @@ def test_get_credential_from_header(self): def test_authentication_flow_document(self): """Test the default authentication provider document.""" provider = self.mock_basic() - provider.identifier_maximum_length=22 - provider.password_maximum_length=7 + provider.identifier_maximum_length = 22 + provider.password_maximum_length = 7 provider.identifier_barcode_format = provider.BARCODE_FORMAT_CODABAR # We're about to call url_for, so we must create an # application context. - os.environ['AUTOINITIALIZE'] = "False" + os.environ["AUTOINITIALIZE"] = "False" from api.app import app + self.app = app - del os.environ['AUTOINITIALIZE'] + del os.environ["AUTOINITIALIZE"] with self.app.test_request_context("/"): doc = provider.authentication_flow_document(self._db) - assert _(provider.DISPLAY_NAME) == doc['description'] - assert provider.FLOW_TYPE == doc['type'] + assert _(provider.DISPLAY_NAME) == doc["description"] + assert provider.FLOW_TYPE == doc["type"] - labels = doc['labels'] - assert provider.identifier_label == labels['login'] - assert provider.password_label == labels['password'] + labels = doc["labels"] + assert provider.identifier_label == labels["login"] + assert provider.password_label == labels["password"] - inputs = doc['inputs'] - assert (provider.identifier_keyboard == - inputs['login']['keyboard']) - assert (provider.password_keyboard == - inputs['password']['keyboard']) + inputs = doc["inputs"] + assert provider.identifier_keyboard == inputs["login"]["keyboard"] + assert provider.password_keyboard == inputs["password"]["keyboard"] - assert provider.BARCODE_FORMAT_CODABAR == inputs['login']['barcode_format'] + assert provider.BARCODE_FORMAT_CODABAR == inputs["login"]["barcode_format"] - assert (provider.identifier_maximum_length == - inputs['login']['maximum_length']) - assert (provider.password_maximum_length == - inputs['password']['maximum_length']) + assert ( + provider.identifier_maximum_length == inputs["login"]["maximum_length"] + ) + assert ( + provider.password_maximum_length == inputs["password"]["maximum_length"] + ) - [logo_link] = doc['links'] + [logo_link] = doc["links"] assert "logo" == logo_link["rel"] - assert "http://localhost/images/" + MockBasic.LOGIN_BUTTON_IMAGE == logo_link["href"] + assert ( + "http://localhost/images/" + MockBasic.LOGIN_BUTTON_IMAGE + == logo_link["href"] + ) def test_remote_patron_lookup(self): - #remote_patron_lookup does the lookup by calling _remote_patron_lookup, - #then calls enforce_library_identifier_restriction to make sure that the patron - #is associated with the correct library + # remote_patron_lookup does the lookup by calling _remote_patron_lookup, + # then calls enforce_library_identifier_restriction to make sure that the patron + # is associated with the correct library class Mock(BasicAuthenticationProvider): def _remote_patron_lookup(self, patron_or_patrondata): self._remote_patron_lookup_called_with = patron_or_patrondata return patron_or_patrondata + def enforce_library_identifier_restriction(self, identifier, patrondata): self.enforce_library_identifier_restriction_called_with = ( - identifier, patrondata + identifier, + patrondata, ) return "Result" @@ -2219,8 +2371,9 @@ def enforce_library_identifier_restriction(self, identifier, patrondata): assert "Result" == provider.remote_patron_lookup(patron) assert provider._remote_patron_lookup_called_with == patron assert provider.enforce_library_identifier_restriction_called_with == ( - patron.authorization_identifier, patron - ) + patron.authorization_identifier, + patron, + ) def test_scrub_credential(self): # Verify that the scrub_credential helper method strips extra whitespace @@ -2241,6 +2394,7 @@ def test_scrub_credential(self): assert "user" == provider.scrub_credential(" \ruser\t ") assert b"user" == provider.scrub_credential(b" user ") + class TestBasicAuthenticationProviderAuthenticate(AuthenticatorTest): """Test the complex BasicAuthenticationProvider.authenticate method.""" @@ -2289,8 +2443,8 @@ def _inactive_patron(self): patrondata = PatronData( permanent_id=patron.external_identifier, username="new username", - authorization_identifier = "new authorization identifier", - complete=True + authorization_identifier="new authorization identifier", + complete=True, ) return patron, patrondata @@ -2305,8 +2459,7 @@ def test_success_but_local_patron_needs_sync(self): # It will respond to a patron lookup request with more detailed # information. minimal_patrondata = PatronData( - permanent_id=patron.external_identifier, - complete=False + permanent_id=patron.external_identifier, complete=False ) provider = self.mock_basic( patrondata=minimal_patrondata, @@ -2323,9 +2476,7 @@ def test_success_but_local_patron_needs_sync(self): # that the patron's details were correct going forward. assert "new username" == patron.username assert "new authorization identifier" == patron.authorization_identifier - assert ( - utc_now()-patron.last_external_sync - ).total_seconds() < 10 + assert (utc_now() - patron.last_external_sync).total_seconds() < 10 def test_success_with_immediate_patron_sync(self): # This patron has not logged on in a really long time. @@ -2335,8 +2486,7 @@ def test_success_with_immediate_patron_sync(self): # set of information. If a remote patron lookup were to happen, # it would explode. provider = self.mock_basic( - patrondata=complete_patrondata, - remote_patron_lookup_patrondata=object() + patrondata=complete_patrondata, remote_patron_lookup_patrondata=object() ) # The patron can be authenticated. @@ -2348,23 +2498,21 @@ def test_success_with_immediate_patron_sync(self): # patron lookup. assert "new username" == patron.username assert "new authorization identifier" == patron.authorization_identifier - assert ( - utc_now()-patron.last_external_sync - ).total_seconds() < 10 + assert (utc_now() - patron.last_external_sync).total_seconds() < 10 def test_failure_when_remote_authentication_returns_problemdetail(self): patron = self._patron() patrondata = PatronData(permanent_id=patron.external_identifier) provider = self.mock_basic(patrondata=UNSUPPORTED_AUTHENTICATION_MECHANISM) - assert (UNSUPPORTED_AUTHENTICATION_MECHANISM == - provider.authenticate(self._db, self.credentials)) + assert UNSUPPORTED_AUTHENTICATION_MECHANISM == provider.authenticate( + self._db, self.credentials + ) def test_failure_when_remote_authentication_returns_none(self): patron = self._patron() patrondata = PatronData(permanent_id=patron.external_identifier) provider = self.mock_basic(patrondata=None) - assert (None == - provider.authenticate(self._db, self.credentials)) + assert None == provider.authenticate(self._db, self.credentials) def test_server_side_validation_runs(self): patron = self._patron() @@ -2372,11 +2520,9 @@ def test_server_side_validation_runs(self): b = MockBasic integration = self._external_integration(self._str) - integration.setting(b.IDENTIFIER_REGULAR_EXPRESSION).value = 'foo' - integration.setting(b.PASSWORD_REGULAR_EXPRESSION).value = 'bar' - provider = b( - self._default_library, integration, patrondata=patrondata - ) + integration.setting(b.IDENTIFIER_REGULAR_EXPRESSION).value = "foo" + integration.setting(b.PASSWORD_REGULAR_EXPRESSION).value = "bar" + provider = b(self._default_library, integration, patrondata=patrondata) # This would succeed, but we don't get to remote_authenticate() # because we fail the regex test. @@ -2384,7 +2530,8 @@ def test_server_side_validation_runs(self): # This succeeds because we pass the regex test. assert patron == provider.authenticate( - self._db, dict(username="food", password="barbecue")) + self._db, dict(username="food", password="barbecue") + ) def test_authentication_succeeds_but_patronlookup_fails(self): """This case should never happen--it indicates a malfunctioning @@ -2399,7 +2546,6 @@ def test_authentication_succeeds_but_patronlookup_fails(self): # this point we give up -- there is no authenticated patron. assert None == provider.authenticate(self._db, self.credentials) - def test_authentication_creates_missing_patron(self): # The authentication provider knows about this patron, # but this is the first we've heard about them. @@ -2413,15 +2559,19 @@ def test_authentication_creates_missing_patron(self): integration = self._external_integration( self._str, ExternalIntegration.PATRON_AUTH_GOAL ) - provider = MockBasic(library, integration, patrondata=patrondata, remote_patron_lookup_patrondata=patrondata) + provider = MockBasic( + library, + integration, + patrondata=patrondata, + remote_patron_lookup_patrondata=patrondata, + ) patron = provider.authenticate(self._db, self.credentials) # A server side Patron was created from the PatronData. assert isinstance(patron, Patron) assert library == patron.library assert patrondata.permanent_id == patron.external_identifier - assert (patrondata.authorization_identifier == - patron.authorization_identifier) + assert patrondata.authorization_identifier == patron.authorization_identifier # Information not relevant to the patron's identity was stored # in the Patron object after it was created. @@ -2491,7 +2641,9 @@ def test_authentication_updates_outdated_patron_on_username_match(self): # new identifiers. assert new_identifier == patron.authorization_identifier - def test_authentication_updates_outdated_patron_on_authorization_identifier_match(self): + def test_authentication_updates_outdated_patron_on_authorization_identifier_match( + self, + ): # This patron has no permanent ID. Their username has # changed but their library card number has not. identifier = "1234" @@ -2526,8 +2678,8 @@ def test_authentication_updates_outdated_patron_on_authorization_identifier_matc # appear no different to us than a patron who has never used the # circulation manager before. -class TestOAuthAuthenticationProvider(AuthenticatorTest): +class TestOAuthAuthenticationProvider(AuthenticatorTest): def test_from_config(self): class ConfigAuthenticationProvider(OAuthAuthenticationProvider): NAME = "Config loading test" @@ -2535,14 +2687,12 @@ class ConfigAuthenticationProvider(OAuthAuthenticationProvider): integration = self._external_integration( self._str, goal=ExternalIntegration.PATRON_AUTH_GOAL ) - integration.username = 'client_id' - integration.password = 'client_secret' + integration.username = "client_id" + integration.password = "client_secret" integration.setting( ConfigAuthenticationProvider.OAUTH_TOKEN_EXPIRATION_DAYS ).value = 20 - provider = ConfigAuthenticationProvider( - self._default_library, integration - ) + provider = ConfigAuthenticationProvider(self._default_library, integration) assert "client_id" == provider.client_id assert "client_secret" == provider.client_secret assert 20 == provider.token_expiration_days @@ -2557,10 +2707,8 @@ def test_get_credential_from_header(self): def test_create_token(self): patron = self._patron() provider = MockOAuth(self._default_library) - in_twenty_days = ( - utc_now() + datetime.timedelta( - days=provider.token_expiration_days - ) + in_twenty_days = utc_now() + datetime.timedelta( + days=provider.token_expiration_days ) data_source = provider.token_data_source(self._db) token, is_new = provider.create_token(self._db, patron, "some token") @@ -2590,9 +2738,7 @@ def test_authenticated_patron_success(self): def test_oauth_callback(self): mock_patrondata = PatronData( - authorization_identifier="1234", - username="user", - personal_name="The User" + authorization_identifier="1234", username="user", personal_name="The User" ) class CallbackImplementation(MockOAuth): @@ -2603,29 +2749,29 @@ def remote_exchange_code_for_access_token(self, _db, access_code): def remote_patron_lookup(self, bearer_token): return mock_patrondata - integration = CallbackImplementation._mock_integration( - self._db, "Mock OAuth" - ) + integration = CallbackImplementation._mock_integration(self._db, "Mock OAuth") ConfigurationSetting.for_library_and_externalintegration( - self._db, CallbackImplementation.LIBRARY_IDENTIFIER_RESTRICTION, - self._default_library, integration + self._db, + CallbackImplementation.LIBRARY_IDENTIFIER_RESTRICTION, + self._default_library, + integration, ).value = "123" ConfigurationSetting.for_library_and_externalintegration( - self._db, CallbackImplementation.LIBRARY_IDENTIFIER_RESTRICTION_TYPE, - self._default_library, integration + self._db, + CallbackImplementation.LIBRARY_IDENTIFIER_RESTRICTION_TYPE, + self._default_library, + integration, ).value = CallbackImplementation.LIBRARY_IDENTIFIER_RESTRICTION_TYPE_PREFIX ConfigurationSetting.for_library_and_externalintegration( - self._db, CallbackImplementation.LIBRARY_IDENTIFIER_FIELD, - self._default_library, integration + self._db, + CallbackImplementation.LIBRARY_IDENTIFIER_FIELD, + self._default_library, + integration, ).value = CallbackImplementation.LIBRARY_IDENTIFIER_RESTRICTION_BARCODE - oauth = CallbackImplementation( - self._default_library, integration=integration - ) - credential, patron, patrondata = oauth.oauth_callback( - self._db, "a code" - ) + oauth = CallbackImplementation(self._default_library, integration=integration) + credential, patron, patrondata = oauth.oauth_callback(self._db, "a code") # remote_exchange_code_for_access_token was called with the # access code. @@ -2654,27 +2800,32 @@ def remote_patron_lookup(self, bearer_token): def test_authentication_flow_document(self): # We're about to call url_for, so we must create an # application context. - os.environ['AUTOINITIALIZE'] = "False" + os.environ["AUTOINITIALIZE"] = "False" from api.app import app + self.app = app - del os.environ['AUTOINITIALIZE'] + del os.environ["AUTOINITIALIZE"] provider = MockOAuth(self._default_library) with self.app.test_request_context("/"): doc = provider.authentication_flow_document(self._db) - assert provider.FLOW_TYPE == doc['type'] - assert provider.NAME == doc['description'] + assert provider.FLOW_TYPE == doc["type"] + assert provider.NAME == doc["description"] # To authenticate with this provider, you must follow the # 'authenticate' link. - [auth_link] = [x for x in doc['links'] if x['rel'] == 'authenticate'] - assert auth_link['href'] == provider._internal_authenticate_url(self._db) + [auth_link] = [x for x in doc["links"] if x["rel"] == "authenticate"] + assert auth_link["href"] == provider._internal_authenticate_url(self._db) - [logo_link] = [x for x in doc['links'] if x['rel'] == 'logo'] - assert "http://localhost/images/" + MockOAuth.LOGIN_BUTTON_IMAGE == logo_link["href"] + [logo_link] = [x for x in doc["links"] if x["rel"] == "logo"] + assert ( + "http://localhost/images/" + MockOAuth.LOGIN_BUTTON_IMAGE + == logo_link["href"] + ) def test_token_data_source_can_create_new_data_source(self): class OAuthWithUnusualDataSource(MockOAuth): TOKEN_DATA_SOURCE_NAME = "Unusual data source" + oauth = OAuthWithUnusualDataSource(self._default_library) source, is_new = oauth.token_data_source(self._db) assert True == is_new @@ -2692,21 +2843,27 @@ def test_external_authenticate_url_parameters(self): # application context. my_api = MockOAuth(self._default_library) my_api.client_id = "clientid" - os.environ['AUTOINITIALIZE'] = "False" + os.environ["AUTOINITIALIZE"] = "False" from api.app import app - del os.environ['AUTOINITIALIZE'] + + del os.environ["AUTOINITIALIZE"] with app.test_request_context("/"): params = my_api.external_authenticate_url_parameters("state", self._db) - assert "state" == params['state'] - assert "clientid" == params['client_id'] - expected_url = url_for("oauth_callback", library_short_name=self._default_library.short_name, _external=True) - assert expected_url == params['oauth_callback_url'] + assert "state" == params["state"] + assert "clientid" == params["client_id"] + expected_url = url_for( + "oauth_callback", + library_short_name=self._default_library.short_name, + _external=True, + ) + assert expected_url == params["oauth_callback_url"] -class TestOAuthController(AuthenticatorTest): +class TestOAuthController(AuthenticatorTest): def setup_method(self): super(TestOAuthController, self).setup_method() + class MockOAuthWithExternalAuthenticateURL(MockOAuth): def __init__(self, library, _db, external_authenticate_url, patron): super(MockOAuthWithExternalAuthenticateURL, self).__init__( @@ -2714,9 +2871,7 @@ def __init__(self, library, _db, external_authenticate_url, patron): ) self.url = external_authenticate_url self.patron = patron - self.token, ignore = self.create_token( - _db, self.patron, "a token" - ) + self.token, ignore = self.create_token(_db, self.patron, "a token") self.patrondata = PatronData(personal_name="Abcd") def external_authenticate_url(self, state, _db): @@ -2741,14 +2896,11 @@ def oauth_callback(self, _db, params): library=self._default_library, basic_auth_provider=self.basic, oauth_providers=[self.oauth1, self.oauth2], - bearer_token_signing_secret="a secret" + bearer_token_signing_secret="a secret", ) self.auth = MockAuthenticator( - self._default_library, - { - self._default_library.short_name : self.library_auth - } + self._default_library, {self._default_library.short_name: self.library_auth} ) self.controller = OAuthController(self.auth) @@ -2759,7 +2911,10 @@ def test_oauth_authentication_redirect(self): params = dict(provider=self.oauth1.NAME) response = self.controller.oauth_authentication_redirect(params, self._db) assert 302 == response.status_code - expected_state = dict(provider=self.oauth1.NAME, redirect_uri="", ) + expected_state = dict( + provider=self.oauth1.NAME, + redirect_uri="", + ) expected_state = urllib.parse.quote(json.dumps(expected_state)) assert "http://oauth1.com/?state=" + expected_state == response.location @@ -2772,16 +2927,15 @@ def test_oauth_authentication_redirect(self): # If we don't recognize the OAuth provider you get sent to # the redirect URI with a fragment containing an encoded # problem detail document. - params = dict(redirect_uri="http://foo.com/", - provider="not an oauth provider") + params = dict(redirect_uri="http://foo.com/", provider="not an oauth provider") response = self.controller.oauth_authentication_redirect(params, self._db) assert 302 == response.status_code assert response.location.startswith("http://foo.com/#") fragments = urllib.parse.parse_qs( urllib.parse.urlparse(response.location).fragment ) - error = json.loads(fragments.get('error')[0]) - assert UNKNOWN_OAUTH_PROVIDER.uri == error.get('type') + error = json.loads(fragments.get("error")[0]) + assert UNKNOWN_OAUTH_PROVIDER.uri == error.get("type") def test_oauth_authentication_callback(self): """Test the controller method that the OAuth provider is supposed @@ -2792,7 +2946,9 @@ def test_oauth_authentication_callback(self): params = dict(code="foo", state=json.dumps(dict(provider=self.oauth1.NAME))) response = self.controller.oauth_authentication_callback(self._db, params) assert 302 == response.status_code - fragments = urllib.parse.parse_qs(urllib.parse.urlparse(response.location).fragment) + fragments = urllib.parse.parse_qs( + urllib.parse.urlparse(response.location).fragment + ) token = fragments.get("access_token")[0] provider_name, provider_token = self.auth.decode_bearer_token(token) assert self.oauth1.NAME == provider_name @@ -2802,7 +2958,9 @@ def test_oauth_authentication_callback(self): params = dict(code="foo", state=json.dumps(dict(provider=self.oauth2.NAME))) response = self.controller.oauth_authentication_callback(self._db, params) assert 302 == response.status_code - fragments = urllib.parse.parse_qs(urllib.parse.urlparse(response.location).fragment) + fragments = urllib.parse.parse_qs( + urllib.parse.urlparse(response.location).fragment + ) token = fragments.get("access_token")[0] provider_name, provider_token = self.auth.decode_bearer_token(token) assert self.oauth2.NAME == provider_name @@ -2820,13 +2978,17 @@ def test_oauth_authentication_callback(self): # In this example we're pretending to be coming in after # authenticating with an OAuth provider that doesn't exist. - params = dict(code="foo", state=json.dumps(dict(provider=("not_an_oauth_provider")))) + params = dict( + code="foo", state=json.dumps(dict(provider=("not_an_oauth_provider"))) + ) response = self.controller.oauth_authentication_callback(self._db, params) assert 302 == response.status_code - fragments = urllib.parse.parse_qs(urllib.parse.urlparse(response.location).fragment) - assert None == fragments.get('access_token') - error = json.loads(fragments.get('error')[0]) - assert UNKNOWN_OAUTH_PROVIDER.uri == error.get('type') + fragments = urllib.parse.parse_qs( + urllib.parse.urlparse(response.location).fragment + ) + assert None == fragments.get("access_token") + error = json.loads(fragments.get("error")[0]) + assert UNKNOWN_OAUTH_PROVIDER.uri == error.get("type") def test_oauth_authentication_invalid_token(self): """If an invalid bearer token is provided, an appropriate problem diff --git a/tests/test_axis.py b/tests/test_axis.py index ceafac1ee3..140bf6464b 100644 --- a/tests/test_axis.py +++ b/tests/test_axis.py @@ -68,45 +68,39 @@ class Axis360Test(DatabaseTest): - def setup_method(self): - super(Axis360Test,self).setup_method() + super(Axis360Test, self).setup_method() self.collection = MockAxis360API.mock_collection(self._db) self.api = MockAxis360API(self._db, self.collection) @classmethod def sample_data(cls, filename): - return sample_data(filename, 'axis') + return sample_data(filename, "axis") # Sample bibliographic and availability data you can use in a test # without having to parse it from an XML file. BIBLIOGRAPHIC_DATA = Metadata( DataSource.AXIS_360, - publisher='Random House Inc', - language='eng', - title='Faith of My Fathers : A Family Memoir', - imprint='Random House Inc2', + publisher="Random House Inc", + language="eng", + title="Faith of My Fathers : A Family Memoir", + imprint="Random House Inc2", published=datetime_utc(2000, 3, 7, 0, 0), primary_identifier=IdentifierData( - type=Identifier.AXIS_360_ID, - identifier='0003642860' + type=Identifier.AXIS_360_ID, identifier="0003642860" ), - identifiers = [ - IdentifierData(type=Identifier.ISBN, identifier='9780375504587') - ], - contributors = [ - ContributorData(sort_name="McCain, John", - roles=[Contributor.PRIMARY_AUTHOR_ROLE] - ), - ContributorData(sort_name="Salter, Mark", - roles=[Contributor.AUTHOR_ROLE] - ), + identifiers=[IdentifierData(type=Identifier.ISBN, identifier="9780375504587")], + contributors=[ + ContributorData( + sort_name="McCain, John", roles=[Contributor.PRIMARY_AUTHOR_ROLE] + ), + ContributorData(sort_name="Salter, Mark", roles=[Contributor.AUTHOR_ROLE]), ], - subjects = [ - SubjectData(type=Subject.BISAC, - identifier='BIOGRAPHY & AUTOBIOGRAPHY / Political'), - SubjectData(type=Subject.FREEFORM_AUDIENCE, - identifier='Adult'), + subjects=[ + SubjectData( + type=Subject.BISAC, identifier="BIOGRAPHY & AUTOBIOGRAPHY / Political" + ), + SubjectData(type=Subject.FREEFORM_AUDIENCE, identifier="Adult"), ], ) @@ -122,11 +116,10 @@ def sample_data(cls, filename): class TestAxis360API(Axis360Test): - def test_external_integration(self): - assert ( - self.collection.external_integration == - self.api.external_integration(object())) + assert self.collection.external_integration == self.api.external_integration( + object() + ) def test__run_self_tests(self): # Verify that Axis360API._run_self_tests() calls the right @@ -143,7 +136,7 @@ def refresh_bearer_token(self): # give minutes. def recent_activity(self, since): self.recent_activity_called_with = since - return [(1,"a"),(2, "b"), (3, "c")] + return [(1, "a"), (2, "b"), (3, "c")] # Then we will count the loans and holds for the default # patron. @@ -161,7 +154,7 @@ def patron_activity(self, patron, pin): integration = self._external_integration( "api.simple_authentication", ExternalIntegration.PATRON_AUTH_GOAL, - libraries=[with_default_patron] + libraries=[with_default_patron], ) p = BasicAuthenticationProvider integration.setting(p.TEST_IDENTIFIER).value = "username1" @@ -170,42 +163,56 @@ def patron_activity(self, patron, pin): # Now that everything is set up, run the self-test. api = Mock(self._db, self.collection) now = utc_now() - [no_patron_credential, recent_circulation_events, patron_activity, - pools_without_delivery, refresh_bearer_token] = sorted( - api._run_self_tests(self._db), key=lambda x: x.name - ) + [ + no_patron_credential, + recent_circulation_events, + patron_activity, + pools_without_delivery, + refresh_bearer_token, + ] = sorted(api._run_self_tests(self._db), key=lambda x: x.name) assert "Refreshing bearer token" == refresh_bearer_token.name assert True == refresh_bearer_token.success assert "the new token" == refresh_bearer_token.result assert ( - "Acquiring test patron credentials for library %s" % no_default_patron.name == - no_patron_credential.name) + "Acquiring test patron credentials for library %s" % no_default_patron.name + == no_patron_credential.name + ) assert False == no_patron_credential.success - assert ("Library has no test patron configured." == - str(no_patron_credential.exception)) + assert "Library has no test patron configured." == str( + no_patron_credential.exception + ) - assert ("Asking for circulation events for the last five minutes" == - recent_circulation_events.name) + assert ( + "Asking for circulation events for the last five minutes" + == recent_circulation_events.name + ) assert True == recent_circulation_events.success assert "Found 3 event(s)" == recent_circulation_events.result since = api.recent_activity_called_with five_minutes_ago = utc_now() - datetime.timedelta(minutes=5) - assert (five_minutes_ago-since).total_seconds() < 5 + assert (five_minutes_ago - since).total_seconds() < 5 - assert ("Checking activity for test patron for library %s" % with_default_patron.name == - patron_activity.name) + assert ( + "Checking activity for test patron for library %s" + % with_default_patron.name + == patron_activity.name + ) assert True == patron_activity.success assert "Found 2 loans/holds" == patron_activity.result patron, pin = api.patron_activity_called_with assert "username1" == patron.authorization_identifier assert "password1" == pin - assert ("Checking for titles that have no delivery mechanisms." == - pools_without_delivery.name) + assert ( + "Checking for titles that have no delivery mechanisms." + == pools_without_delivery.name + ) assert True == pools_without_delivery.success - assert ("All titles in this collection have delivery mechanisms." == - pools_without_delivery.result) + assert ( + "All titles in this collection have delivery mechanisms." + == pools_without_delivery.result + ) def test__run_self_tests_short_circuit(self): # If we can't refresh the bearer token, the rest of the @@ -236,7 +243,7 @@ def test_availability_no_timeout(self): self.api.availability() request = self.api.requests.pop() kwargs = request[-1] - assert None == kwargs['timeout'] + assert None == kwargs["timeout"] def test_availability_exception(self): @@ -244,16 +251,17 @@ def test_availability_exception(self): with pytest.raises(RemoteIntegrationException) as excinfo: self.api.availability() - assert "Bad response from http://axis.test/availability/v2: Got status code 500 from external server, cannot continue." in str(excinfo.value) + assert ( + "Bad response from http://axis.test/availability/v2: Got status code 500 from external server, cannot continue." + in str(excinfo.value) + ) def test_refresh_bearer_token_after_401(self): # If we get a 401, we will fetch a new bearer token and try the # request again. self.api.queue_response(401) - self.api.queue_response( - 200, content=json.dumps(dict(access_token="foo")) - ) + self.api.queue_response(200, content=json.dumps(dict(access_token="foo"))) self.api.queue_response(200, content="The data") response = self.api.request("http://url/") assert b"The data" == response.content @@ -266,23 +274,26 @@ def test_refresh_bearer_token_error(self): api.queue_response(412) with pytest.raises(RemoteIntegrationException) as excinfo: api.refresh_bearer_token() - assert "Bad response from http://axis.test/accesstoken: Got status code 412 from external server, but can only continue on: 200." in str(excinfo.value) + assert ( + "Bad response from http://axis.test/accesstoken: Got status code 412 from external server, but can only continue on: 200." + in str(excinfo.value) + ) def test_exception_after_401_with_fresh_token(self): # If we get a 401 immediately after refreshing the token, we will # raise an exception. self.api.queue_response(401) - self.api.queue_response( - 200, content=json.dumps(dict(access_token="foo")) - ) + self.api.queue_response(200, content=json.dumps(dict(access_token="foo"))) self.api.queue_response(401) self.api.queue_response(301) with pytest.raises(RemoteIntegrationException) as excinfo: self.api.request("http://url/") - assert "Got status code 401 from external server, cannot continue." in str(excinfo.value) + assert "Got status code 401 from external server, cannot continue." in str( + excinfo.value + ) # The fourth request never got made. assert [301] == [x.status_code for x in self.api.responses] @@ -296,7 +307,7 @@ def test_update_availability(self): identifier_type=Identifier.AXIS_360_ID, data_source_name=DataSource.AXIS_360, with_license_pool=True, - collection=self.collection + collection=self.collection, ) # We have never checked the circulation information for this @@ -332,25 +343,24 @@ def test_checkin_success(self): edition, pool = self._edition( identifier_type=Identifier.AXIS_360_ID, data_source_name=DataSource.AXIS_360, - with_license_pool=True + with_license_pool=True, ) data = self.sample_data("checkin_success.xml") self.api.queue_response(200, content=data) patron = self._patron() barcode = self._str patron.authorization_identifier = barcode - response = self.api.checkin(patron, 'pin', pool) + response = self.api.checkin(patron, "pin", pool) assert response == True # Verify the format of the HTTP request that was made. [request] = self.api.requests [url, args, kwargs] = request - data = kwargs.pop('data') - assert kwargs['method'] == 'GET' - expect = ( - '/EarlyCheckInTitle/v3?itemID=%s&patronID=%s' % ( - pool.identifier.identifier, barcode - ) + data = kwargs.pop("data") + assert kwargs["method"] == "GET" + expect = "/EarlyCheckInTitle/v3?itemID=%s&patronID=%s" % ( + pool.identifier.identifier, + barcode, ) assert expect in url @@ -360,43 +370,41 @@ def test_checkin_failure(self): edition, pool = self._edition( identifier_type=Identifier.AXIS_360_ID, data_source_name=DataSource.AXIS_360, - with_license_pool=True + with_license_pool=True, ) data = self.sample_data("checkin_failure.xml") self.api.queue_response(200, content=data) patron = self._patron() patron.authorization_identifier = self._str - pytest.raises( - NotFoundOnRemote, self.api.checkin, patron, 'pin', pool - ) + pytest.raises(NotFoundOnRemote, self.api.checkin, patron, "pin", pool) def test_place_hold(self): edition, pool = self._edition( identifier_type=Identifier.AXIS_360_ID, data_source_name=DataSource.AXIS_360, - with_license_pool=True + with_license_pool=True, ) data = self.sample_data("place_hold_success.xml") self.api.queue_response(200, content=data) patron = self._patron() ConfigurationSetting.for_library( - Configuration.DEFAULT_NOTIFICATION_EMAIL_ADDRESS, - self._default_library).value = "notifications@example.com" - response = self.api.place_hold(patron, 'pin', pool, None) + Configuration.DEFAULT_NOTIFICATION_EMAIL_ADDRESS, self._default_library + ).value = "notifications@example.com" + response = self.api.place_hold(patron, "pin", pool, None) assert 1 == response.hold_position assert response.identifier_type == pool.identifier.type assert response.identifier == pool.identifier.identifier [request] = self.api.requests - params = request[-1]['params'] - assert 'notifications@example.com' == params['email'] + params = request[-1]["params"] + assert "notifications@example.com" == params["email"] def test_fulfill(self): # Test our ability to fulfill an Axis 360 title. edition, pool = self._edition( identifier_type=Identifier.AXIS_360_ID, - identifier_id='0015176429', + identifier_id="0015176429", data_source_name=DataSource.AXIS_360, - with_license_pool=True + with_license_pool=True, ) patron = self._patron() @@ -404,8 +412,7 @@ def test_fulfill(self): def fulfill(internal_format="not AxisNow"): return self.api.fulfill( - patron, "pin", licensepool=pool, - internal_format=internal_format + patron, "pin", licensepool=pool, internal_format=internal_format ) # If Axis 360 says a patron does not have a title checked out, @@ -440,7 +447,7 @@ def fulfill(internal_format="not AxisNow"): # If the title is checked out but Axis provides no fulfillment # info, the exception is CannotFulfill. - pool.identifier.identifier = '0015176429' + pool.identifier.identifier = "0015176429" data = self.sample_data("availability_without_fulfillment.xml") self.api.queue_response(200, content=data) pytest.raises(CannotFulfill, fulfill) @@ -471,22 +478,20 @@ def test_patron_activity(self): # We made a request that included the authorization identifier # of the patron in question. [url, args, kwargs] = self.api.requests.pop() - assert patron.authorization_identifier == kwargs['params']['patronId'] + assert patron.authorization_identifier == kwargs["params"]["patronId"] # We got three results -- two holds and one loan. - [hold1, loan, hold2] = sorted( - results, key=lambda x: x.identifier - ) + [hold1, loan, hold2] = sorted(results, key=lambda x: x.identifier) assert isinstance(hold1, HoldInfo) assert isinstance(hold2, HoldInfo) assert isinstance(loan, LoanInfo) def test_update_licensepools_for_identifiers(self): - class Mock(MockAxis360API): """Simulates an Axis 360 API that knows about some books but not others. """ + updated = [] reaped = [] @@ -495,18 +500,17 @@ def _fetch_remote_availability(self, identifiers): # The first identifer in the list is still # available. identifier_data = IdentifierData( - type=identifier.type, - identifier=identifier.identifier + type=identifier.type, identifier=identifier.identifier ) metadata = Metadata( data_source=DataSource.AXIS_360, - primary_identifier=identifier_data + primary_identifier=identifier_data, ) availability = CirculationData( data_source=DataSource.AXIS_360, primary_identifier=identifier_data, licenses_owned=7, - licenses_available=6 + licenses_available=6, ) yield metadata, availability @@ -517,9 +521,7 @@ def _reap(self, identifier): self.reaped.append(identifier) api = Mock(self._db, self.collection) - still_in_collection = self._identifier( - identifier_type=Identifier.AXIS_360_ID - ) + still_in_collection = self._identifier(identifier_type=Identifier.AXIS_360_ID) no_longer_in_collection = self._identifier( identifier_type=Identifier.AXIS_360_ID ) @@ -551,12 +553,15 @@ def test_fetch_remote_availability(self): # We asked for information on two identifiers. [request] = self.api.requests kwargs = request[-1] - assert {'titleIds': '2001,2002'} == kwargs['params'] + assert {"titleIds": "2001,2002"} == kwargs["params"] # We got information on only one. [(metadata, circulation)] = results assert (id1, False) == metadata.primary_identifier.load(self._db) - assert 'El caso de la gracia : Un periodista explora las evidencias de unas vidas transformadas' == metadata.title + assert ( + "El caso de la gracia : Un periodista explora las evidencias de unas vidas transformadas" + == metadata.title + ) assert 2 == circulation.licenses_owned def test_reap(self): @@ -574,8 +579,10 @@ def test_reap(self): # it's already been reaped, so nothing happens. edition, pool, = self._edition( data_source_name=DataSource.AXIS_360, - identifier_type=id1.type, identifier_id=id1.identifier, - with_license_pool=True, collection=self.collection + identifier_type=id1.type, + identifier_id=id1.identifier, + with_license_pool=True, + collection=self.collection, ) # This LicensePool has licenses, but it's not in a different @@ -584,8 +591,10 @@ def test_reap(self): collection2 = self._collection() edition2, pool2, = self._edition( data_source_name=DataSource.AXIS_360, - identifier_type=id1.type, identifier_id=id1.identifier, - with_license_pool=True, collection=collection2 + identifier_type=id1.type, + identifier_id=id1.identifier, + with_license_pool=True, + collection=collection2, ) pool.licenses_owned = 0 @@ -626,8 +635,8 @@ def test_get_fulfillment_info(self): # with the right keyword arguments and the right HTTP method. url, args, kwargs = api.requests.pop() assert url.endswith(api.fulfillment_endpoint) - assert 'POST' == kwargs['method'] - assert 'transaction ID' == kwargs['params']['TransactionID'] + assert "POST" == kwargs["method"] + assert "transaction ID" == kwargs["params"]["TransactionID"] def test_get_audiobook_metadata(self): # Test the get_audiobook_metadata method, which makes an API request. @@ -643,8 +652,8 @@ def test_get_audiobook_metadata(self): # with the right keyword arguments and the right HTTP method. url, args, kwargs = api.requests.pop() assert url.endswith(api.audiobook_metadata_endpoint) - assert 'POST' == kwargs['method'] - assert 'Findaway content ID' == kwargs['params']['fndcontentid'] + assert "POST" == kwargs["method"] + assert "Findaway content ID" == kwargs["params"]["fndcontentid"] def test_update_book(self): # Verify that the update_book method takes a Metadata and a @@ -654,8 +663,7 @@ def test_update_book(self): analytics = MockAnalyticsProvider() api = MockAxis360API(self._db, self.collection) e, e_new, lp, lp_new = api.update_book( - self.BIBLIOGRAPHIC_DATA, self.AVAILABILITY_DATA, - analytics=analytics + self.BIBLIOGRAPHIC_DATA, self.AVAILABILITY_DATA, analytics=analytics ) # A new LicensePool and Edition were created. assert True == lp_new @@ -670,7 +678,7 @@ def test_update_book(self): assert e == lp.work.presentation_edition # The Edition reflects what it said in BIBLIOGRAPHIC_DATA - assert 'Faith of My Fathers : A Family Memoir' == e.title + assert "Faith of My Fathers : A Family Memoir" == e.title # Three analytics events were sent out. # @@ -689,8 +697,7 @@ def test_update_book(self): ) e2, e_new, lp2, lp_new = api.update_book( - self.BIBLIOGRAPHIC_DATA, new_circulation, - analytics=analytics + self.BIBLIOGRAPHIC_DATA, new_circulation, analytics=analytics ) # The same LicensePool and Edition are returned -- no new ones @@ -716,9 +723,11 @@ def test_update_book(self): (Axis360API.VERIFY_SSL, None, "verify_certificate", True), (Axis360API.VERIFY_SSL, "True", "verify_certificate", True), (Axis360API.VERIFY_SSL, "False", "verify_certificate", False), - ] + ], ) - def test_integration_settings(self, setting, setting_value, attribute, attribute_value): + def test_integration_settings( + self, setting, setting_value, attribute, attribute_value + ): external_integration = self.collection.external_integration if setting_value is not None: external_integration.setting(setting).value = setting_value @@ -727,11 +736,11 @@ def test_integration_settings(self, setting, setting_value, attribute, attribute class TestCirculationMonitor(Axis360Test): - def test_run(self): class Mock(Axis360CirculationMonitor): def catch_up_from(self, start, cutoff, progress): self.called_with = (start, cutoff, progress) + monitor = Mock(self._db, self.collection, api_class=MockAxis360API) # The first time run() is called, catch_up_from() is asked to @@ -756,10 +765,11 @@ def test_catch_up_from(self): class MockAPI(MockAxis360API): def recent_activity(self, since): self.recent_activity_called_with = since - return [(1,"a"),(2, "b")] + return [(1, "a"), (2, "b")] class MockMonitor(Axis360CirculationMonitor): processed = [] + def process_book(self, bibliographic, circulation): self.processed.append((bibliographic, circulation)) @@ -773,7 +783,7 @@ def process_book(self, bibliographic, circulation): assert "start" == monitor.api.recent_activity_called_with # process_book was called on each item returned by recent_activity. - assert [(1,"a"),(2, "b")] == monitor.processed + assert [(1, "a"), (2, "b")] == monitor.processed # The number of books processed was stored in # TimestampData.achievements. @@ -781,38 +791,48 @@ def process_book(self, bibliographic, circulation): def test_process_book(self): integration, ignore = create( - self._db, ExternalIntegration, + self._db, + ExternalIntegration, goal=ExternalIntegration.ANALYTICS_GOAL, protocol="core.local_analytics_provider", ) monitor = Axis360CirculationMonitor( - self._db, self.collection, api_class=MockAxis360API, + self._db, + self.collection, + api_class=MockAxis360API, ) edition, license_pool = monitor.process_book( - self.BIBLIOGRAPHIC_DATA, self.AVAILABILITY_DATA) - assert 'Faith of My Fathers : A Family Memoir' == edition.title - assert 'eng' == edition.language - assert 'Random House Inc' == edition.publisher - assert 'Random House Inc2' == edition.imprint + self.BIBLIOGRAPHIC_DATA, self.AVAILABILITY_DATA + ) + assert "Faith of My Fathers : A Family Memoir" == edition.title + assert "eng" == edition.language + assert "Random House Inc" == edition.publisher + assert "Random House Inc2" == edition.imprint assert Identifier.AXIS_360_ID == edition.primary_identifier.type - assert '0003642860' == edition.primary_identifier.identifier + assert "0003642860" == edition.primary_identifier.identifier - [isbn] = [x for x in edition.equivalent_identifiers() - if x is not edition.primary_identifier] + [isbn] = [ + x + for x in edition.equivalent_identifiers() + if x is not edition.primary_identifier + ] assert Identifier.ISBN == isbn.type - assert '9780375504587' == isbn.identifier + assert "9780375504587" == isbn.identifier - assert (["McCain, John", "Salter, Mark"] == - sorted([x.sort_name for x in edition.contributors])) + assert ["McCain, John", "Salter, Mark"] == sorted( + [x.sort_name for x in edition.contributors] + ) subs = sorted( (x.subject.type, x.subject.identifier) for x in edition.primary_identifier.classifications ) - assert [(Subject.BISAC, 'BIOGRAPHY & AUTOBIOGRAPHY / Political'), - (Subject.FREEFORM_AUDIENCE, 'Adult')] == subs + assert [ + (Subject.BISAC, "BIOGRAPHY & AUTOBIOGRAPHY / Political"), + (Subject.FREEFORM_AUDIENCE, "Adult"), + ] == subs assert 9 == license_pool.licenses_owned assert 8 == license_pool.licenses_available @@ -822,8 +842,11 @@ def test_process_book(self): # Three circulation events were created, backdated to the # last_checked date of the license pool. events = license_pool.circulation_events - assert (['distributor_title_add', 'distributor_check_in', 'distributor_license_add'] == - [x.type for x in events]) + assert [ + "distributor_title_add", + "distributor_check_in", + "distributor_license_add", + ] == [x.type for x in events] for e in events: assert e.start == license_pool.last_checked @@ -836,18 +859,23 @@ def test_process_book(self): # 360 bibliographic coverage provider, so that in the future # it doesn't have to make a separate API request to ask about # this book. - records = [x for x in license_pool.identifier.coverage_records - if x.data_source.name == DataSource.AXIS_360 - and x.operation is None] + records = [ + x + for x in license_pool.identifier.coverage_records + if x.data_source.name == DataSource.AXIS_360 and x.operation is None + ] assert 1 == len(records) # Now, another collection with the same book shows up. collection2 = MockAxis360API.mock_collection(self._db, "coll2") monitor = Axis360CirculationMonitor( - self._db, collection2, api_class=MockAxis360API, + self._db, + collection2, + api_class=MockAxis360API, ) edition2, license_pool2 = monitor.process_book( - self.BIBLIOGRAPHIC_DATA, self.AVAILABILITY_DATA) + self.BIBLIOGRAPHIC_DATA, self.AVAILABILITY_DATA + ) # Both license pools have the same Work and the same presentation # edition. @@ -859,8 +887,9 @@ def test_process_book_updates_old_licensepool(self): updates it. """ edition, licensepool = self._edition( - with_license_pool=True, identifier_type=Identifier.AXIS_360_ID, - identifier_id='0003642860' + with_license_pool=True, + identifier_type=Identifier.AXIS_360_ID, + identifier_id="0003642860", ) # We start off with availability information based on the # default for test data. @@ -868,50 +897,47 @@ def test_process_book_updates_old_licensepool(self): identifier = IdentifierData( type=licensepool.identifier.type, - identifier=licensepool.identifier.identifier + identifier=licensepool.identifier.identifier, ) metadata = Metadata(DataSource.AXIS_360, primary_identifier=identifier) monitor = Axis360CirculationMonitor( - self._db, self.collection, api_class=MockAxis360API, - ) - edition, licensepool = monitor.process_book( - metadata, self.AVAILABILITY_DATA + self._db, + self.collection, + api_class=MockAxis360API, ) + edition, licensepool = monitor.process_book(metadata, self.AVAILABILITY_DATA) # Now we have information based on the CirculationData. assert 9 == licensepool.licenses_owned class TestReaper(Axis360Test): - def test_instantiate(self): # Validate the standard CollectionMonitor interface. monitor = AxisCollectionReaper( - self._db, self.collection, - api_class=MockAxis360API + self._db, self.collection, api_class=MockAxis360API ) -class TestParsers(Axis360Test): +class TestParsers(Axis360Test): def test_bibliographic_parser(self): # Make sure the bibliographic information gets properly # collated in preparation for creating Edition objects. data = self.sample_data("tiny_collection.xml") - [bib1, av1], [bib2, av2] = BibliographicParser( - False, True).process_all(data) + [bib1, av1], [bib2, av2] = BibliographicParser(False, True).process_all(data) # We didn't ask for availability information, so none was provided. assert None == av1 assert None == av2 - assert 'Faith of My Fathers : A Family Memoir' == bib1.title - assert 'eng' == bib1.language + assert "Faith of My Fathers : A Family Memoir" == bib1.title + assert "eng" == bib1.language assert datetime_utc(2000, 3, 7, 0, 0) == bib1.published - assert 'Simon & Schuster' == bib2.publisher - assert 'Pocket Books' == bib2.imprint + assert "Simon & Schuster" == bib2.publisher + assert "Pocket Books" == bib2.imprint assert Edition.BOOK_MEDIUM == bib1.medium @@ -923,25 +949,26 @@ def test_bibliographic_parser(self): [description, cover] = bib1.links assert Hyperlink.DESCRIPTION == description.rel assert Representation.TEXT_PLAIN == description.media_type - assert description.content.startswith( - "John McCain's deeply moving memoir" - ) + assert description.content.startswith("John McCain's deeply moving memoir") # The cover image simulates the current state of the B&T cover # service, where we get a thumbnail-sized image URL in the # Axis 360 API response and we can hack the URL to get the # full-sized image URL. assert LinkRelations.IMAGE == cover.rel - assert ("http://contentcafecloud.baker-taylor.com/Jacket.svc/D65D0665-050A-487B-9908-16E6D8FF5C3E/9780375504587/Large/Empty" == - cover.href) + assert ( + "http://contentcafecloud.baker-taylor.com/Jacket.svc/D65D0665-050A-487B-9908-16E6D8FF5C3E/9780375504587/Large/Empty" + == cover.href + ) assert MediaTypes.JPEG_MEDIA_TYPE == cover.media_type assert LinkRelations.THUMBNAIL_IMAGE == cover.thumbnail.rel - assert ("http://contentcafecloud.baker-taylor.com/Jacket.svc/D65D0665-050A-487B-9908-16E6D8FF5C3E/9780375504587/Medium/Empty" == - cover.thumbnail.href) + assert ( + "http://contentcafecloud.baker-taylor.com/Jacket.svc/D65D0665-050A-487B-9908-16E6D8FF5C3E/9780375504587/Medium/Empty" + == cover.thumbnail.href + ) assert MediaTypes.JPEG_MEDIA_TYPE == cover.thumbnail.media_type - # Book #1 has a primary author, another author and a narrator. # # TODO: The narrator data is simulated. we haven't actually @@ -963,23 +990,29 @@ def test_bibliographic_parser(self): assert [Contributor.PRIMARY_AUTHOR_ROLE] == cont.roles axis_id, isbn = sorted(bib1.identifiers, key=lambda x: x.identifier) - assert '0003642860' == axis_id.identifier - assert '9780375504587' == isbn.identifier + assert "0003642860" == axis_id.identifier + assert "9780375504587" == isbn.identifier # Check the subjects for #2 because it includes an audience, # unlike #1. - subjects = sorted(bib2.subjects, key = lambda x: x.identifier or "") - assert [Subject.BISAC, Subject.BISAC, Subject.BISAC, - Subject.AXIS_360_AUDIENCE] == [x.type for x in subjects] - general_fiction, women_sleuths, romantic_suspense = sorted([ - x.name for x in subjects if x.type==Subject.BISAC]) - assert 'FICTION / General' == general_fiction - assert 'FICTION / Mystery & Detective / Women Sleuths' == women_sleuths - assert 'FICTION / Romance / Suspense' == romantic_suspense - - [adult] = [x.identifier for x in subjects - if x.type==Subject.AXIS_360_AUDIENCE] - assert 'General Adult' == adult + subjects = sorted(bib2.subjects, key=lambda x: x.identifier or "") + assert [ + Subject.BISAC, + Subject.BISAC, + Subject.BISAC, + Subject.AXIS_360_AUDIENCE, + ] == [x.type for x in subjects] + general_fiction, women_sleuths, romantic_suspense = sorted( + [x.name for x in subjects if x.type == Subject.BISAC] + ) + assert "FICTION / General" == general_fiction + assert "FICTION / Mystery & Detective / Women Sleuths" == women_sleuths + assert "FICTION / Romance / Suspense" == romantic_suspense + + [adult] = [ + x.identifier for x in subjects if x.type == Subject.AXIS_360_AUDIENCE + ] + assert "General Adult" == adult # The second book has a cover image simulating some possible # future case, where B&T change their cover service so that @@ -1028,13 +1061,13 @@ def test_bibliographic_parser_audiobook(self): # Although the audiobook is also available in the "AxisNow" # format, no second delivery mechanism was created for it, the # way it would have been for an ebook. - assert b'AxisNow' in data + assert b"AxisNow" in data def test_bibliographic_parser_blio_format(self): # This book is available as 'Blio' but not 'AxisNow'. data = self.sample_data("availability_with_audiobook_fulfillment.xml") - data = data.replace(b'Acoustik', b'Blio') - data = data.replace(b'AxisNow', b'No Such Format') + data = data.replace(b"Acoustik", b"Blio") + data = data.replace(b"AxisNow", b"No Such Format") [[bib, av]] = BibliographicParser(False, True).process_all(data) @@ -1047,7 +1080,7 @@ def test_bibliographic_parser_blio_format(self): def test_bibliographic_parser_blio_and_axisnow_format(self): # This book is available as both 'Blio' and 'AxisNow'. data = self.sample_data("availability_with_audiobook_fulfillment.xml") - data = data.replace(b'Acoustik', b'Blio') + data = data.replace(b"Acoustik", b"Blio") [[bib, av]] = BibliographicParser(False, True).process_all(data) @@ -1059,8 +1092,8 @@ def test_bibliographic_parser_blio_and_axisnow_format(self): def test_bibliographic_parser_unsupported_format(self): data = self.sample_data("availability_with_audiobook_fulfillment.xml") - data = data.replace(b'Acoustik', b'No Such Format 1') - data = data.replace(b'AxisNow', b'No Such Format 2') + data = data.replace(b"Acoustik", b"No Such Format 1") + data = data.replace(b"AxisNow", b"No Such Format 2") [[bib, av]] = BibliographicParser(False, True).process_all(data) @@ -1094,11 +1127,11 @@ def test_parse_author_role(self): # force_role overwrites whatever other role might be # assigned. author = "Bob, Inc. (COR)" - c = parse(author, primary_author_found=False, - force_role=Contributor.NARRATOR_ROLE) + c = parse( + author, primary_author_found=False, force_role=Contributor.NARRATOR_ROLE + ) assert [Contributor.NARRATOR_ROLE] == c.roles - def test_availability_parser(self): """Make sure the availability information gets properly collated in preparation for updating a LicensePool. @@ -1106,8 +1139,7 @@ def test_availability_parser(self): data = self.sample_data("tiny_collection.xml") - [bib1, av1], [bib2, av2] = BibliographicParser( - True, False).process_all(data) + [bib1, av1], [bib2, av2] = BibliographicParser(True, False).process_all(data) # We didn't ask for bibliographic information, so none was provided. assert None == bib1 @@ -1120,14 +1152,12 @@ def test_availability_parser(self): class BaseParserTest(object): - @classmethod def sample_data(cls, filename): - return sample_data(filename, 'axis') + return sample_data(filename, "axis") class TestResponseParser(BaseParserTest): - def setup_method(self): # We don't need an actual Collection object to test most of # these classes, but we do need to test that whatever object @@ -1135,11 +1165,12 @@ def setup_method(self): # right spot of HoldInfo and LoanInfo objects. class MockCollection(object): pass + self._default_collection = MockCollection() self._default_collection.id = object() -class TestRaiseExceptionOnError(TestResponseParser): +class TestRaiseExceptionOnError(TestResponseParser): def test_internal_server_error(self): data = self.sample_data("internal_server_error.xml") parser = HoldReleaseResponseParser(None) @@ -1147,18 +1178,17 @@ def test_internal_server_error(self): parser.process_all(data) assert "Internal Server Error" in str(excinfo.value) - def test_ignore_error_codes(self): # A parser subclass can decide not to raise exceptions # when encountering specific error codes. data = self.sample_data("internal_server_error.xml") retval = object() + class IgnoreISE(HoldReleaseResponseParser): def process_one(self, e, namespaces): - self.raise_exception_on_error( - e, namespaces, ignore_error_codes = [5000] - ) + self.raise_exception_on_error(e, namespaces, ignore_error_codes=[5000]) return retval + # Unlike in test_internal_server_error, no exception is # raised, because we told the parser to ignore this particular # error code. @@ -1181,16 +1211,12 @@ def test_missing_error_code(self): class TestCheckinResponseParser(TestResponseParser): - def test_parse_checkin_success(self): # The response parser raises an exception if there's a problem, # and returne True otherwise. # # "Book is not on loan" is not treated as a problem. - for filename in ( - "checkin_success.xml", - "checkin_not_checked_out.xml" - ): + for filename in ("checkin_success.xml", "checkin_not_checked_out.xml"): data = self.sample_data(filename) parser = CheckinResponseParser(self._default_collection) parsed = parser.process_all(data) @@ -1203,7 +1229,6 @@ def test_parse_checkin_failure(self): class TestCheckoutResponseParser(TestResponseParser): - def test_parse_checkout_success(self): data = self.sample_data("checkout_success.xml") parser = CheckoutResponseParser(self._default_collection) @@ -1212,8 +1237,7 @@ def test_parse_checkout_success(self): assert self._default_collection.id == parsed.collection_id assert DataSource.AXIS_360 == parsed.data_source_name assert Identifier.AXIS_360_ID == parsed.identifier_type - assert (datetime_utc(2015, 8, 11, 18, 57, 42) == - parsed.end_date) + assert datetime_utc(2015, 8, 11, 18, 57, 42) == parsed.end_date # There is no FulfillmentInfo associated with the LoanInfo, # because we don't need it (checkout and fulfillment are @@ -1230,8 +1254,8 @@ def test_parse_not_found_on_remote(self): parser = CheckoutResponseParser(None) pytest.raises(NotFoundOnRemote, parser.process_all, data) -class TestHoldResponseParser(TestResponseParser): +class TestHoldResponseParser(TestResponseParser): def test_parse_hold_success(self): data = self.sample_data("place_hold_success.xml") parser = HoldResponseParser(self._default_collection) @@ -1248,8 +1272,8 @@ def test_parse_already_on_hold(self): parser = HoldResponseParser(None) pytest.raises(AlreadyOnHold, parser.process_all, data) -class TestHoldReleaseResponseParser(TestResponseParser): +class TestHoldReleaseResponseParser(TestResponseParser): def test_success(self): data = self.sample_data("release_hold_success.xml") parser = HoldReleaseResponseParser(None) @@ -1260,6 +1284,7 @@ def test_failure(self): parser = HoldReleaseResponseParser(None) pytest.raises(NotOnHold, parser.process_all, data) + class TestAvailabilityResponseParser(Axis360Test, BaseParserTest): """Unlike other response parser tests, this one needs access to a real database session, because it needs a real Collection @@ -1328,7 +1353,10 @@ def test_parse_ebook_availability(self): # as its content_link. assert isinstance(fulfillment, FulfillmentInfo) assert not isinstance(fulfillment, Axis360FulfillmentInfo) - assert "http://adobe.acsm/?src=library&transactionId=2a34598b-12af-41e4-a926-af5e42da7fe5&isbn=9780763654573&format=F2" == fulfillment.content_link + assert ( + "http://adobe.acsm/?src=library&transactionId=2a34598b-12af-41e4-a926-af5e42da7fe5&isbn=9780763654573&format=F2" + == fulfillment.content_link + ) # Next ask for AxisNow -- this will be more like # test_parse_audiobook_availability, since it requires an @@ -1346,7 +1374,6 @@ def test_parse_ebook_availability(self): class TestJSONResponseParser(object): - def test__required_key(self): m = JSONResponseParser._required_key parsed = dict(key="value") @@ -1357,7 +1384,10 @@ def test__required_key(self): # If not, it raises a RemoteInitiatedServerError. with pytest.raises(RemoteInitiatedServerError) as excinfo: m("absent", parsed) - assert "Required key absent not present in Axis 360 fulfillment document: {'key': 'value'}" in str(excinfo.value) + assert ( + "Required key absent not present in Axis 360 fulfillment document: {'key': 'value'}" + in str(excinfo.value) + ) def test_verify_status_code(self): success = dict(Status=dict(Code=0000)) @@ -1379,12 +1409,13 @@ def test_verify_status_code(self): # raised. with pytest.raises(RemoteInitiatedServerError) as excinfo: m(missing) - assert "Required key Status not present in Axis 360 fulfillment document" in str(excinfo.value) + assert ( + "Required key Status not present in Axis 360 fulfillment document" + in str(excinfo.value) + ) def test_parse(self): - class Mock(JSONResponseParser): - def _parse(self, parsed, *args, **kwargs): self.called_with = parsed, args, kwargs return "success" @@ -1398,24 +1429,22 @@ def _parse(self, parsed, *args, **kwargs): # arguments to parse() will be passed through to _parse(). result = parser.parse(json.dumps(doc), "value1", arg2="value2") assert "success" == result - assert ( - (doc, ("value1",), dict(arg2="value2")) == - parser.called_with) + assert (doc, ("value1",), dict(arg2="value2")) == parser.called_with # It also works if the JSON was already parsed. result = parser.parse(doc, "new_value") - assert ( - (doc, ("new_value",), {}) == parser.called_with) + assert (doc, ("new_value",), {}) == parser.called_with # Non-JSON input causes an error. with pytest.raises(RemoteInitiatedServerError) as excinfo: parser.parse("I'm not JSON") - assert "Invalid response from Axis 360 (was expecting JSON): I'm not JSON" in str(excinfo.value) - + assert ( + "Invalid response from Axis 360 (was expecting JSON): I'm not JSON" + in str(excinfo.value) + ) class TestAxis360FulfillmentInfoResponseParser(Axis360Test): - def test__parse_findaway(self): # _parse will create a valid FindawayManifest given a # complete document. @@ -1424,12 +1453,11 @@ def test__parse_findaway(self): m = parser._parse edition, pool = self._edition(with_license_pool=True) + def get_data(): # We'll be modifying this document to simulate failures, # so make it easy to load a fresh copy. - return json.loads( - self.sample_data("audiobook_fulfillment_info.json") - ) + return json.loads(self.sample_data("audiobook_fulfillment_info.json")) # This is the data we just got from a call to Axis 360's # getfulfillmentInfo endpoint. @@ -1449,21 +1477,21 @@ def get_data(): # The manifest contains information from the LicensePool's presentation # edition - assert edition.title == metadata['title'] + assert edition.title == metadata["title"] # It contains DRM licensing information from Findaway via the # Axis 360 API. - encrypted = metadata['encrypted'] + encrypted = metadata["encrypted"] assert ( - '0f547af1-38c1-4b1c-8a1a-169d353065d0' == - encrypted['findaway:sessionKey']) - assert '5babb89b16a4ed7d8238f498' == encrypted['findaway:checkoutId'] - assert '04960' == encrypted['findaway:fulfillmentId'] - assert '58ee81c6d3d8eb3b05597cdc' == encrypted['findaway:licenseId'] + "0f547af1-38c1-4b1c-8a1a-169d353065d0" == encrypted["findaway:sessionKey"] + ) + assert "5babb89b16a4ed7d8238f498" == encrypted["findaway:checkoutId"] + assert "04960" == encrypted["findaway:fulfillmentId"] + assert "58ee81c6d3d8eb3b05597cdc" == encrypted["findaway:licenseId"] # The spine items and duration have been filled in by the call to # the getaudiobookmetadata endpoint. - assert 8150.87 == metadata['duration'] + assert 8150.87 == metadata["duration"] assert 5 == len(manifest.readingOrder) # We also know when the licensing document expires. @@ -1473,8 +1501,10 @@ def get_data(): # document and verify that extraction fails. # for field in ( - 'FNDContentID', 'FNDLicenseID', 'FNDSessionKey', - 'ExpirationDate', + "FNDContentID", + "FNDLicenseID", + "FNDSessionKey", + "ExpirationDate", ): missing_field = get_data() del missing_field[field] @@ -1484,7 +1514,7 @@ def get_data(): # Try with a bad expiration date. bad_date = get_data() - bad_date['ExpirationDate'] = 'not-a-date' + bad_date["ExpirationDate"] = "not-a-date" with pytest.raises(RemoteInitiatedServerError) as excinfo: m(bad_date, pool) assert "Could not parse expiration date: not-a-date" in str(excinfo.value) @@ -1497,12 +1527,11 @@ def test__parse_axisnow(self): m = parser._parse edition, pool = self._edition(with_license_pool=True) + def get_data(): # We'll be modifying this document to simulate failures, # so make it easy to load a fresh copy. - return json.loads( - self.sample_data("ebook_fulfillment_info.json") - ) + return json.loads(self.sample_data("ebook_fulfillment_info.json")) # This is the data we just got from a call to Axis 360's # getfulfillmentInfo endpoint. @@ -1514,19 +1543,20 @@ def get_data(): manifest, expires = m(data, pool) assert isinstance(manifest, AxisNowManifest) - assert ({"book_vault_uuid": "1c11c31f-81c2-41bb-9179-491114c3f121", "isbn": "9780547351551"} == - json.loads(str(manifest))) + assert { + "book_vault_uuid": "1c11c31f-81c2-41bb-9179-491114c3f121", + "isbn": "9780547351551", + } == json.loads(str(manifest)) # Try with a bad expiration date. bad_date = get_data() - bad_date['ExpirationDate'] = 'not-a-date' + bad_date["ExpirationDate"] = "not-a-date" with pytest.raises(RemoteInitiatedServerError) as excinfo: m(bad_date, pool) assert "Could not parse expiration date: not-a-date" in str(excinfo.value) class TestAudiobookMetadataParser(Axis360Test): - def test__parse(self): # _parse will find the Findaway account ID and # the spine items. @@ -1536,15 +1566,12 @@ def _extract_spine_item(cls, part): return part + " (extracted)" metadata = dict( - fndaccountid="An account ID", - readingOrder=["Spine item 1", "Spine item 2"] + fndaccountid="An account ID", readingOrder=["Spine item 1", "Spine item 2"] ) account_id, spine_items = Mock(None)._parse(metadata) assert "An account ID" == account_id - assert (["Spine item 1 (extracted)", - "Spine item 2 (extracted)"] == - spine_items) + assert ["Spine item 1 (extracted)", "Spine item 2 (extracted)"] == spine_items # No data? Nothing will be parsed. account_id, spine_items = Mock(None)._parse({}) @@ -1556,9 +1583,7 @@ def test__extract_spine_item(self): # a SpineItem object. m = AudiobookMetadataParser._extract_spine_item item = m( - dict(duration=100.4, fndpart=2, fndsequence=3, - title="The Gathering Storm" - ) + dict(duration=100.4, fndpart=2, fndsequence=3, title="The Gathering Storm") ) assert isinstance(item, SpineItem) assert "The Gathering Storm" == item.title @@ -1598,8 +1623,11 @@ def test_fetch_audiobook(self): edition, pool = self._edition(with_license_pool=True) identifier = pool.identifier fulfillment = Axis360FulfillmentInfo( - self.api, pool.data_source.name, - identifier.type, identifier.identifier, 'transaction_id' + self.api, + pool.data_source.name, + identifier.type, + identifier.identifier, + "transaction_id", ) assert None == fulfillment._content_type @@ -1621,8 +1649,7 @@ def test_fetch_audiobook(self): # The content expiration date also comes from the fulfillment # document. - assert ( - datetime_utc(2018, 9, 29, 18, 34) == fulfillment.content_expires) + assert datetime_utc(2018, 9, 29, 18, 34) == fulfillment.content_expires def test_fetch_ebook(self): # When no Findaway information is present in the response from @@ -1636,8 +1663,11 @@ def test_fetch_ebook(self): edition, pool = self._edition(with_license_pool=True) identifier = pool.identifier fulfillment = Axis360FulfillmentInfo( - self.api, pool.data_source.name, - identifier.type, identifier.identifier, 'transaction_id' + self.api, + pool.data_source.name, + identifier.type, + identifier.identifier, + "transaction_id", ) assert None == fulfillment._content_type @@ -1648,13 +1678,13 @@ def test_fetch_ebook(self): # document derived from the fulfillment document. assert DeliveryMechanism.AXISNOW_DRM == fulfillment.content_type assert ( - '{"book_vault_uuid": "1c11c31f-81c2-41bb-9179-491114c3f121", "isbn": "9780547351551"}' == - fulfillment.content) + '{"book_vault_uuid": "1c11c31f-81c2-41bb-9179-491114c3f121", "isbn": "9780547351551"}' + == fulfillment.content + ) # The content expiration date also comes from the fulfillment # document. - assert ( - datetime_utc(2018, 9, 29, 18, 34) == fulfillment.content_expires) + assert datetime_utc(2018, 9, 29, 18, 34) == fulfillment.content_expires class TestAxisNowManifest(object): @@ -1663,9 +1693,7 @@ class TestAxisNowManifest(object): def test_unicode(self): manifest = AxisNowManifest("A UUID", "An ISBN") - assert ( - '{"book_vault_uuid": "A UUID", "isbn": "An ISBN"}' == - str(manifest)) + assert '{"book_vault_uuid": "A UUID", "isbn": "An ISBN"}' == str(manifest) assert DeliveryMechanism.AXISNOW_DRM == manifest.MEDIA_TYPE @@ -1684,8 +1712,7 @@ def test_script_instantiation(self): the coverage provider. """ script = RunCollectionCoverageProviderScript( - Axis360BibliographicCoverageProvider, self._db, - api_class=MockAxis360API + Axis360BibliographicCoverageProvider, self._db, api_class=MockAxis360API ) [provider] = script.providers assert isinstance(provider, Axis360BibliographicCoverageProvider) @@ -1700,7 +1727,7 @@ def test_process_item_creates_presentation_ready_work(self): # Here's the book mentioned in single_item.xml. identifier = self._identifier(identifier_type=Identifier.AXIS_360_ID) - identifier.identifier = '0003642860' + identifier.identifier = "0003642860" # This book has no LicensePool. assert [] == identifier.licensed_through @@ -1714,11 +1741,13 @@ def test_process_item_creates_presentation_ready_work(self): [pool] = identifier.licensed_through assert 9 == pool.licenses_owned [lpdm] = pool.delivery_mechanisms - assert ('application/epub+zip (application/vnd.adobe.adept+xml)' == - lpdm.delivery_mechanism.name) + assert ( + "application/epub+zip (application/vnd.adobe.adept+xml)" + == lpdm.delivery_mechanism.name + ) # A Work was created and made presentation ready. - assert 'Faith of My Fathers : A Family Memoir' == pool.work.title + assert "Faith of My Fathers : A Family Memoir" == pool.work.title assert True == pool.work.presentation_ready def test_transient_failure_if_requested_book_not_mentioned(self): @@ -1727,7 +1756,7 @@ def test_transient_failure_if_requested_book_not_mentioned(self): """ # We're going to ask about abcdef identifier = self._identifier(identifier_type=Identifier.AXIS_360_ID) - identifier.identifier = 'abcdef' + identifier.identifier = "abcdef" # But we're going to get told about 0003642860. data = self.sample_data("single_item.xml") @@ -1772,7 +1801,7 @@ def fulfillment_info(self): "identifier_type": None, "identifier": None, "verify": None, - "content_link": "https://fake.url" + "content_link": "https://fake.url", } return partial(Axis360AcsFulfillmentInfo, **params) @@ -1782,19 +1811,26 @@ def patch_urllib_urlopen(self, monkeypatch): # this function. def patch_urlopen(mock): monkeypatch.setattr(urllib.request, "urlopen", mock) + return patch_urlopen - def test_url_encoding_not_capitalized(self, patch_urllib_urlopen, mock_request, fulfillment_info): + def test_url_encoding_not_capitalized( + self, patch_urllib_urlopen, mock_request, fulfillment_info + ): # Mock the urllopen function to make sure that the URL is not actually requested # then make sure that when the request is built the %3a character encoded in the # string is not uppercased to be %3A. called_url = None + def mock_urlopen(url, **kwargs): nonlocal called_url called_url = url return mock_request + patch_urllib_urlopen(mock_urlopen) - fulfillment = fulfillment_info(content_link="https://test.com/?param=%3atest123") + fulfillment = fulfillment_info( + content_link="https://test.com/?param=%3atest123" + ) response = fulfillment.as_response assert called_url is not None assert called_url.selector == "/?param=%3atest123" @@ -1802,8 +1838,8 @@ def mock_urlopen(url, **kwargs): assert type(response) == Response mock_request.__enter__.assert_called() mock_request.__enter__.return_value.read.assert_called() - assert 'status' in dir(mock_request.__enter__.return_value) - assert 'headers' in dir(mock_request.__enter__.return_value) + assert "status" in dir(mock_request.__enter__.return_value) + assert "headers" in dir(mock_request.__enter__.return_value) mock_request.__exit__.assert_called() @pytest.mark.parametrize( @@ -1812,11 +1848,13 @@ def mock_urlopen(url, **kwargs): urllib.error.HTTPError(url="", code=301, msg="", hdrs={}, fp=Mock()), socket.timeout(), urllib.error.URLError(reason=""), - ssl.SSLError() + ssl.SSLError(), ], - ids=lambda val: val.__class__.__name__ + ids=lambda val: val.__class__.__name__, ) - def test_exception_returns_problem_detail(self, patch_urllib_urlopen, fulfillment_info, exception): + def test_exception_returns_problem_detail( + self, patch_urllib_urlopen, fulfillment_info, exception + ): # Check that when the urlopen function throws an exception, we catch the exception and # we turn it into a problem detail to be returned to the client. This mimics the behavior # of the http utils function that we are bypassing with this fulfillment method. @@ -1827,12 +1865,16 @@ def test_exception_returns_problem_detail(self, patch_urllib_urlopen, fulfillmen @pytest.mark.parametrize( ("verify", "verify_mode", "check_hostname"), - [ - (True, ssl.CERT_REQUIRED, True), - (False, ssl.CERT_NONE, False) - ] + [(True, ssl.CERT_REQUIRED, True), (False, ssl.CERT_NONE, False)], ) - def test_verify_ssl(self, patch_urllib_urlopen, fulfillment_info, verify, verify_mode, check_hostname): + def test_verify_ssl( + self, + patch_urllib_urlopen, + fulfillment_info, + verify, + verify_mode, + check_hostname, + ): # Make sure that when the verify parameter of the fulfillment method is set we use the # correct SSL context to either verify or not verify the ssl certificate for the # URL we are fetching. @@ -1841,7 +1883,7 @@ def test_verify_ssl(self, patch_urllib_urlopen, fulfillment_info, verify, verify response = fulfillment.as_response mock = urllib.request.urlopen mock.assert_called() - assert 'context' in mock.call_args[1] - context = mock.call_args[1]['context'] + assert "context" in mock.call_args[1] + context = mock.call_args[1]["context"] assert context.verify_mode == verify_mode assert context.check_hostname == check_hostname diff --git a/tests/test_bibliotheca.py b/tests/test_bibliotheca.py index 6c67938f31..d68d11a9b0 100644 --- a/tests/test_bibliotheca.py +++ b/tests/test_bibliotheca.py @@ -1,29 +1,37 @@ # encoding: utf-8 -import pytest - -from datetime import datetime, timedelta import json import os import pkgutil -import mock -from mock import MagicMock import random -from io import ( - BytesIO, - StringIO, -) +from datetime import datetime, timedelta +from io import BytesIO, StringIO +import mock +import pytest +import pytz +from mock import MagicMock from pymarc import parse_xml_to_array from pymarc.record import Record -import pytz -from core.testing import DatabaseTest -from . import sample_data - -from core.metadata_layer import ( - ReplacementPolicy, - TimestampData, +from api.authenticator import BasicAuthenticationProvider +from api.bibliotheca import ( + BibliothecaAPI, + BibliothecaBibliographicCoverageProvider, + BibliothecaCirculationSweep, + BibliothecaEventMonitor, + BibliothecaParser, + BibliothecaPurchaseMonitor, + CheckoutResponseParser, + ErrorParser, + EventParser, + ItemListParser, + MockBibliothecaAPI, + PatronCirculationParser, ) +from api.circulation import CirculationAPI, FulfillmentInfo, HoldInfo, LoanInfo +from api.circulation_exceptions import * +from api.web_publication_manifest import FindawayManifest +from core.metadata_layer import ReplacementPolicy, TimestampData from core.mock_analytics_provider import MockAnalyticsProvider from core.model import ( CirculationEvent, @@ -39,53 +47,26 @@ LicensePool, Loan, Measurement, - Resource, Representation, + Resource, Subject, Timestamp, Work, WorkCoverageRecord, create, ) -from core.util.datetime_helpers import ( - datetime_utc, - utc_now, -) -from core.util.http import ( - BadResponseException, -) -from core.util.web_publication_manifest import AudiobookManifest from core.scripts import RunCollectionCoverageProviderScript +from core.testing import DatabaseTest +from core.util.datetime_helpers import datetime_utc, utc_now +from core.util.http import BadResponseException +from core.util.web_publication_manifest import AudiobookManifest -from api.authenticator import BasicAuthenticationProvider -from api.circulation import ( - CirculationAPI, - FulfillmentInfo, - HoldInfo, - LoanInfo, -) -from api.circulation_exceptions import * -from api.bibliotheca import ( - BibliothecaCirculationSweep, - CheckoutResponseParser, - ErrorParser, - EventParser, - MockBibliothecaAPI, - PatronCirculationParser, - BibliothecaAPI, - BibliothecaEventMonitor, - BibliothecaParser, - BibliothecaPurchaseMonitor, - ItemListParser, - BibliothecaBibliographicCoverageProvider, -) -from api.web_publication_manifest import FindawayManifest +from . import sample_data class BibliothecaAPITest(DatabaseTest): - def setup_method(self): - super(BibliothecaAPITest,self).setup_method() + super(BibliothecaAPITest, self).setup_method() self.collection = MockBibliothecaAPI.mock_collection(self._db) self.api = MockBibliothecaAPI(self._db, self.collection) @@ -94,20 +75,19 @@ def setup_method(self): @classmethod def sample_data(self, filename): - return sample_data(filename, 'bibliotheca') + return sample_data(filename, "bibliotheca") -class TestBibliothecaAPI(BibliothecaAPITest): +class TestBibliothecaAPI(BibliothecaAPITest): def setup_method(self): super(TestBibliothecaAPI, self).setup_method() self.collection = MockBibliothecaAPI.mock_collection(self._db) self.api = MockBibliothecaAPI(self._db, self.collection) - def test_external_integration(self): - assert ( - self.collection.external_integration == - self.api.external_integration(object())) + assert self.collection.external_integration == self.api.external_integration( + object() + ) def test__run_self_tests(self): # Verify that BibliothecaAPI._run_self_tests() calls the right @@ -120,7 +100,7 @@ class Mock(MockBibliothecaAPI): # last five minutes. def get_events_between(self, start, finish): self.get_events_between_called_with = (start, finish) - return [1,2,3] + return [1, 2, 3] # Then we will count the loans and holds for the default # patron. @@ -138,7 +118,7 @@ def patron_activity(self, patron, pin): integration = self._external_integration( "api.simple_authentication", ExternalIntegration.PATRON_AUTH_GOAL, - libraries=[with_default_patron] + libraries=[with_default_patron], ) p = BasicAuthenticationProvider integration.setting(p.TEST_IDENTIFIER).value = "username1" @@ -152,22 +132,29 @@ def patron_activity(self, patron, pin): ) assert ( - "Acquiring test patron credentials for library %s" % no_default_patron.name == - no_patron_credential.name) + "Acquiring test patron credentials for library %s" % no_default_patron.name + == no_patron_credential.name + ) assert False == no_patron_credential.success - assert ("Library has no test patron configured." == - str(no_patron_credential.exception)) + assert "Library has no test patron configured." == str( + no_patron_credential.exception + ) - assert ("Asking for circulation events for the last five minutes" == - recent_circulation_events.name) + assert ( + "Asking for circulation events for the last five minutes" + == recent_circulation_events.name + ) assert True == recent_circulation_events.success assert "Found 3 event(s)" == recent_circulation_events.result start, end = api.get_events_between_called_with - assert 5*60 == (end-start).total_seconds() - assert (end-now).total_seconds() < 2 + assert 5 * 60 == (end - start).total_seconds() + assert (end - now).total_seconds() < 2 - assert ("Checking activity for test patron for library %s" % with_default_patron.name == - patron_activity.name) + assert ( + "Checking activity for test patron for library %s" + % with_default_patron.name + == patron_activity.name + ) assert "Found 2 loans/holds" == patron_activity.result patron, pin = api.patron_activity_called_with assert "username1" == patron.authorization_identifier @@ -177,15 +164,20 @@ def test_full_path(self): id = self.api.library_id assert "/cirrus/library/%s/foo" % id == self.api.full_path("foo") assert "/cirrus/library/%s/foo" % id == self.api.full_path("/foo") - assert ("/cirrus/library/%s/foo" % id == - self.api.full_path("/cirrus/library/%s/foo" % id)) + assert "/cirrus/library/%s/foo" % id == self.api.full_path( + "/cirrus/library/%s/foo" % id + ) def test_full_url(self): id = self.api.library_id - assert ("http://bibliotheca.test/cirrus/library/%s/foo" % id == - self.api.full_url("foo")) - assert ("http://bibliotheca.test/cirrus/library/%s/foo" % id == - self.api.full_url("/foo")) + assert ( + "http://bibliotheca.test/cirrus/library/%s/foo" % id + == self.api.full_url("foo") + ) + assert ( + "http://bibliotheca.test/cirrus/library/%s/foo" % id + == self.api.full_url("/foo") + ) def test_request_signing(self): # Confirm a known correct result for the Bibliotheca request signing @@ -194,11 +186,11 @@ def test_request_signing(self): self.api.queue_response(200) response = self.api.request("some_url") [request] = self.api.requests - headers = request[-1]['headers'] - assert 'Fri, 01 Jan 2016 00:00:00 GMT' == headers['3mcl-Datetime'] - assert '2.0' == headers['3mcl-Version'] - expect = '3MCLAUTH a:HZHNGfn6WVceakGrwXaJQ9zIY0Ai5opGct38j9/bHrE=' - assert expect == headers['3mcl-Authorization'] + headers = request[-1]["headers"] + assert "Fri, 01 Jan 2016 00:00:00 GMT" == headers["3mcl-Datetime"] + assert "2.0" == headers["3mcl-Version"] + expect = "3MCLAUTH a:HZHNGfn6WVceakGrwXaJQ9zIY0Ai5opGct38j9/bHrE=" + assert expect == headers["3mcl-Authorization"] # Tweak one of the variables that go into the signature, and # the signature changes. @@ -206,8 +198,8 @@ def test_request_signing(self): self.api.queue_response(200) response = self.api.request("some_url") request = self.api.requests[-1] - headers = request[-1]['headers'] - assert headers['3mcl-Authorization'] != expect + headers = request[-1]["headers"] + assert headers["3mcl-Authorization"] != expect def test_replacement_policy(self): mock_analytics = object() @@ -228,7 +220,6 @@ def test_bibliographic_lookup_request(self): assert b"some data" == response def test_bibliographic_lookup(self): - class MockItemListParser(object): def parse(self, data): self.parse_called_with = data @@ -237,20 +228,22 @@ def parse(self, data): class Mock(MockBibliothecaAPI): """Mock the functionality used by bibliographic_lookup_request.""" + def __init__(self): self.item_list_parser = MockItemListParser() def bibliographic_lookup_request(self, identifier_strings): self.bibliographic_lookup_request_called_with = identifier_strings return "parse me" + api = Mock() identifier = self._identifier() # We can pass in a list of identifier strings, a list of # Identifier objects, or a single example of each: for identifier, identifier_string in ( - ("id1", "id1"), - (identifier, identifier.identifier) + ("id1", "id1"), + (identifier, identifier.identifier), ): for identifier_list in ([identifier], identifier): api.item_list_parser.parse_called_with = None @@ -259,9 +252,9 @@ def bibliographic_lookup_request(self, identifier_strings): # A list of identifier strings is passed into # bibliographic_lookup_request(). - assert ( - [identifier_string] == - api.bibliographic_lookup_request_called_with) + assert [ + identifier_string + ] == api.bibliographic_lookup_request_called_with # The response content is passed into parse() assert "parse me" == api.item_list_parser.parse_called_with @@ -282,14 +275,14 @@ def test_put_request(self): # manager, which actually uses this functionality. self.api.queue_response(200, content="ok, you put something") - response = self.api.request('checkout', "put this!", method="PUT") + response = self.api.request("checkout", "put this!", method="PUT") # The PUT request went through to the correct URL and the right # payload was sent. [[method, url, args, kwargs]] = self.api.requests assert "PUT" == method assert self.api.full_url("checkout") == url - assert 'put this!' == kwargs['data'] + assert "put this!" == kwargs["data"] # The response is what we'd expect. assert 200 == response.status_code @@ -302,15 +295,14 @@ def test_get_events_between_success(self): an_hour_ago = now - timedelta(minutes=3600) response = self.api.get_events_between(an_hour_ago, now) [event] = list(response) - assert 'd5rf89' == event[0] + assert "d5rf89" == event[0] def test_get_events_between_failure(self): self.api.queue_response(500) now = utc_now() an_hour_ago = now - timedelta(minutes=3600) pytest.raises( - BadResponseException, - self.api.get_events_between, an_hour_ago, now + BadResponseException, self.api.get_events_between, an_hour_ago, now ) def test_update_availability(self): @@ -320,7 +312,8 @@ def test_update_availability(self): # Create an analytics integration so we can make sure # events are tracked. integration, ignore = create( - self._db, ExternalIntegration, + self._db, + ExternalIntegration, goal=ExternalIntegration.ANALYTICS_GOAL, protocol="core.local_analytics_provider", ) @@ -330,7 +323,7 @@ def test_update_availability(self): identifier_type=Identifier.THREEM_ID, data_source_name=DataSource.THREEM, with_license_pool=True, - collection=self.collection + collection=self.collection, ) # We have never checked the circulation information for this @@ -345,8 +338,9 @@ def test_update_availability(self): # change for it. work, is_new = pool.calculate_work() assert any( - x for x in work.coverage_records - if x.operation==WorkCoverageRecord.CLASSIFY_OPERATION + x + for x in work.coverage_records + if x.operation == WorkCoverageRecord.CLASSIFY_OPERATION ) # Prepare availability information. @@ -365,13 +359,23 @@ def test_update_availability(self): assert 1 == pool.licenses_available assert 0 == pool.patrons_in_hold_queue - circulation_events = self._db.query(CirculationEvent).join(LicensePool).filter(LicensePool.id==pool.id) + circulation_events = ( + self._db.query(CirculationEvent) + .join(LicensePool) + .filter(LicensePool.id == pool.id) + ) assert 3 == circulation_events.count() types = [e.type for e in circulation_events] - assert (sorted([CirculationEvent.DISTRIBUTOR_LICENSE_REMOVE, + assert ( + sorted( + [ + CirculationEvent.DISTRIBUTOR_LICENSE_REMOVE, CirculationEvent.DISTRIBUTOR_CHECKOUT, - CirculationEvent.DISTRIBUTOR_HOLD_RELEASE]) == - sorted(types)) + CirculationEvent.DISTRIBUTOR_HOLD_RELEASE, + ] + ) + == sorted(types) + ) old_last_checked = pool.last_checked assert old_last_checked is not None @@ -380,8 +384,9 @@ def test_update_availability(self): # removed. In the near future its coverage will be # recalculated to accommodate the new metadata. assert any( - x for x in work.coverage_records - if x.operation==WorkCoverageRecord.CLASSIFY_OPERATION + x + for x in work.coverage_records + if x.operation == WorkCoverageRecord.CLASSIFY_OPERATION ) # Now let's try update_availability again, with a file that @@ -398,7 +403,11 @@ def test_update_availability(self): assert pool.last_checked is not old_last_checked - circulation_events = self._db.query(CirculationEvent).join(LicensePool).filter(LicensePool.id==pool.id) + circulation_events = ( + self._db.query(CirculationEvent) + .join(LicensePool) + .filter(LicensePool.id == pool.id) + ) assert 5 == circulation_events.count() def test_marc_request(self): @@ -406,41 +415,38 @@ def test_marc_request(self): # call and yields a sequence of pymarc Record objects. start = datetime_utc(2012, 1, 2, 3, 4, 5) end = datetime_utc(2014, 5, 6, 7, 8, 9) - self.api.queue_response( - 200, content=self.sample_data("marc_records_two.xml") - ) + self.api.queue_response(200, content=self.sample_data("marc_records_two.xml")) records = [x for x in self.api.marc_request(start, end, 10, 20)] [(method, url, body, headers)] = self.api.requests # A GET request was sent to the expected endpoint assert method == "GET" for expect in ( - "/data/marc?" - "startdate=2012-01-02T03:04:05", - "enddate=2014-05-06T07:08:09", - "offset=10", - "limit=20" + "/data/marc?" "startdate=2012-01-02T03:04:05", + "enddate=2014-05-06T07:08:09", + "offset=10", + "limit=20", ): assert expect in url # The queued response was converted into pymarc Record objects. assert all(isinstance(x, Record) for x in records) - assert ['Siege and Storm', 'Red Island House A Novel/'] == [ + assert ["Siege and Storm", "Red Island House A Novel/"] == [ x.title() for x in records ] # If the API returns an error, an appropriate exception is raised. - self.api.queue_response( - 404, content=self.sample_data("error_unknown.xml") - ) + self.api.queue_response(404, content=self.sample_data("error_unknown.xml")) with pytest.raises(RemoteInitiatedServerError) as excinfo: [x for x in self.api.marc_request(start, end, 10, 20)] def test_sync_bookshelf(self): patron = self._patron() - circulation = CirculationAPI(self._db, self._default_library, api_map={ - self.collection.protocol : MockBibliothecaAPI - }) + circulation = CirculationAPI( + self._db, + self._default_library, + api_map={self.collection.protocol: MockBibliothecaAPI}, + ) api = circulation.api_for_collection[self.collection.id] api.queue_response(200, content=self.sample_data("checkouts.xml")) @@ -472,16 +478,17 @@ def test_place_hold(self): patron = self._patron() edition, pool = self._edition(with_license_pool=True) self.api.queue_response(200, content=self.sample_data("successful_hold.xml")) - response = self.api.place_hold(patron, 'pin', pool) + response = self.api.place_hold(patron, "pin", pool) assert pool.identifier.type == response.identifier_type assert pool.identifier.identifier == response.identifier def test_place_hold_fails_if_exceeded_hold_limit(self): patron = self._patron() edition, pool = self._edition(with_license_pool=True) - self.api.queue_response(400, content=self.sample_data("error_exceeded_hold_limit.xml")) - pytest.raises(PatronHoldLimitReached, self.api.place_hold, - patron, 'pin', pool) + self.api.queue_response( + 400, content=self.sample_data("error_exceeded_hold_limit.xml") + ) + pytest.raises(PatronHoldLimitReached, self.api.place_hold, patron, "pin", pool) def test_get_audio_fulfillment_file(self): """Verify that get_audio_fulfillment_file sends the @@ -492,8 +499,11 @@ def test_get_audio_fulfillment_file(self): [[method, url, args, kwargs]] = self.api.requests assert "POST" == method - assert url.endswith('GetItemAudioFulfillment') - assert 'bib idpatron id' == kwargs['data'] + assert url.endswith("GetItemAudioFulfillment") + assert ( + "bib idpatron id" + == kwargs["data"] + ) assert 200 == response.status_code assert b"A license" == response.content @@ -510,12 +520,11 @@ def test_fulfill(self): # Let's fulfill the EPUB first. self.api.queue_response( - 200, headers={"Content-Type": "presumably/an-acsm"}, - content="this is an ACSM" - ) - fulfillment = self.api.fulfill( - patron, 'password', pool, internal_format='ePub' + 200, + headers={"Content-Type": "presumably/an-acsm"}, + content="this is an ACSM", ) + fulfillment = self.api.fulfill(patron, "password", pool, internal_format="ePub") assert isinstance(fulfillment, FulfillmentInfo) assert b"this is an ACSM" == fulfillment.content assert pool.identifier.identifier == fulfillment.identifier @@ -528,12 +537,9 @@ def test_fulfill(self): # Now let's try the audio version. license = self.sample_data("sample_findaway_audiobook_license.json") self.api.queue_response( - 200, headers={"Content-Type": "application/json"}, - content=license - ) - fulfillment = self.api.fulfill( - patron, 'password', pool, internal_format='MP3' + 200, headers={"Content-Type": "application/json"}, content=license ) + fulfillment = self.api.fulfill(patron, "password", pool, internal_format="MP3") assert isinstance(fulfillment, FulfillmentInfo) # Here, the media type reported by the server is not passed @@ -548,21 +554,20 @@ def test_fulfill(self): # test_findaway_license_to_webpub_manifest. This just verifies # that the manifest contains information from the 'Findaway' # document as well as information from the Work. - metadata = manifest['metadata'] - assert 'abcdef01234789abcdef0123' == metadata['encrypted']['findaway:checkoutId'] - assert work.title == metadata['title'] + metadata = manifest["metadata"] + assert ( + "abcdef01234789abcdef0123" == metadata["encrypted"]["findaway:checkoutId"] + ) + assert work.title == metadata["title"] # Now let's see what happens to fulfillment when 'Findaway' or # 'Bibliotheca' sends bad information. bad_media_type = "application/error+json" bad_content = b"This is not my beautiful license document!" self.api.queue_response( - 200, headers={"Content-Type": bad_media_type}, - content=bad_content - ) - fulfillment = self.api.fulfill( - patron, 'password', pool, internal_format='MP3' + 200, headers={"Content-Type": bad_media_type}, content=bad_content ) + fulfillment = self.api.fulfill(patron, "password", pool, internal_format="MP3") assert isinstance(fulfillment, FulfillmentInfo) # The (apparently) bad document is just passed on to the @@ -579,7 +584,7 @@ def test_findaway_license_to_webpub_manifest(self): # Randomly scramble the Findaway manifest to make sure it gets # properly sorted when converted to a Webpub-like manifest. document = json.loads(document) - document['items'].sort(key=lambda x: random.random()) + document["items"].sort(key=lambda x: random.random()) document = json.dumps(document) m = BibliothecaAPI.findaway_license_to_webpub_manifest @@ -591,43 +596,44 @@ def test_findaway_license_to_webpub_manifest(self): # files, but we also define an extension context called # 'findaway', which lets us include terms coined by Findaway # in a normal Web Publication Manifest document. - context = manifest['@context'] + context = manifest["@context"] default, findaway = context assert AudiobookManifest.DEFAULT_CONTEXT == default - assert ({"findaway" : FindawayManifest.FINDAWAY_EXTENSION_CONTEXT} == - findaway) + assert {"findaway": FindawayManifest.FINDAWAY_EXTENSION_CONTEXT} == findaway - metadata = manifest['metadata'] + metadata = manifest["metadata"] # Information about the book has been added to metadata. # (This is tested more fully in # core/tests/util/test_util_web_publication_manifest.py.) - assert work.title == metadata['title'] - assert pool.identifier.urn == metadata['identifier'] - assert 'en' == metadata['language'] + assert work.title == metadata["title"] + assert pool.identifier.urn == metadata["identifier"] + assert "en" == metadata["language"] # Information about the license has been added to an 'encrypted' # object within metadata. - encrypted = metadata['encrypted'] - assert ('http://librarysimplified.org/terms/drm/scheme/FAE' == - encrypted['scheme']) - assert 'abcdef01234789abcdef0123' == encrypted['findaway:checkoutId'] - assert '1234567890987654321ababa' == encrypted['findaway:licenseId'] - assert '3M' == encrypted['findaway:accountId'] - assert '123456' == encrypted['findaway:fulfillmentId'] - assert ('aaaaaaaa-4444-cccc-dddd-666666666666' == - encrypted['findaway:sessionKey']) + encrypted = metadata["encrypted"] + assert ( + "http://librarysimplified.org/terms/drm/scheme/FAE" == encrypted["scheme"] + ) + assert "abcdef01234789abcdef0123" == encrypted["findaway:checkoutId"] + assert "1234567890987654321ababa" == encrypted["findaway:licenseId"] + assert "3M" == encrypted["findaway:accountId"] + assert "123456" == encrypted["findaway:fulfillmentId"] + assert ( + "aaaaaaaa-4444-cccc-dddd-666666666666" == encrypted["findaway:sessionKey"] + ) # Every entry in the license document's 'items' list has # become a readingOrder item in the manifest. - reading_order = manifest['readingOrder'] + reading_order = manifest["readingOrder"] assert 79 == len(reading_order) # The duration of each readingOrder item has been converted to # seconds. first = reading_order[0] - assert 16.201 == first['duration'] - assert "Track 1" == first['title'] + assert 16.201 == first["duration"] + assert "Track 1" == first["title"] # There is no 'href' value for the readingOrder items because the # files must be obtained through the Findaway SDK rather than @@ -637,17 +643,16 @@ def test_findaway_license_to_webpub_manifest(self): # part #0. Within that part, the items have been sorted by # their sequence. for i, item in enumerate(reading_order): - assert None == item.get('href', None) - assert Representation.MP3_MEDIA_TYPE == item['type'] - assert 0 == item['findaway:part'] - assert i+1 == item['findaway:sequence'] + assert None == item.get("href", None) + assert Representation.MP3_MEDIA_TYPE == item["type"] + assert 0 == item["findaway:part"] + assert i + 1 == item["findaway:sequence"] # The total duration, in seconds, has been added to metadata. - assert 28371 == int(metadata['duration']) + assert 28371 == int(metadata["duration"]) class TestBibliothecaCirculationSweep(BibliothecaAPITest): - def test_circulation_sweep_discovers_work(self): # Test what happens when BibliothecaCirculationSweep discovers a new # work. @@ -655,7 +660,8 @@ def test_circulation_sweep_discovers_work(self): # Create an analytics integration so we can make sure # events are tracked. integration, ignore = create( - self._db, ExternalIntegration, + self._db, + ExternalIntegration, goal=ExternalIntegration.ANALYTICS_GOAL, protocol="core.local_analytics_provider", ) @@ -690,21 +696,30 @@ def test_circulation_sweep_discovers_work(self): # Three circulation events were created for this license pool, # marking the creation of the license pool, the addition of # licenses owned, and the making of those licenses available. - circulation_events = self._db.query(CirculationEvent).join(LicensePool).filter(LicensePool.id==pool.id) + circulation_events = ( + self._db.query(CirculationEvent) + .join(LicensePool) + .filter(LicensePool.id == pool.id) + ) assert 3 == circulation_events.count() types = [e.type for e in circulation_events] - assert (sorted([CirculationEvent.DISTRIBUTOR_LICENSE_ADD, + assert ( + sorted( + [ + CirculationEvent.DISTRIBUTOR_LICENSE_ADD, CirculationEvent.DISTRIBUTOR_TITLE_ADD, - CirculationEvent.DISTRIBUTOR_CHECKIN - ]) == - sorted(types)) + CirculationEvent.DISTRIBUTOR_CHECKIN, + ] + ) + == sorted(types) + ) # Tests of the various parser classes. # -class TestBibliothecaParser(BibliothecaAPITest): +class TestBibliothecaParser(BibliothecaAPITest): def test_parse_date(self): parser = BibliothecaParser() v = parser.parse_date("2016-01-02T12:34:56") @@ -715,7 +730,6 @@ def test_parse_date(self): class TestEventParser(BibliothecaAPITest): - def test_parse_empty_list(self): data = self.sample_data("empty_event_batch.xml") @@ -729,23 +743,24 @@ def test_parse_empty_list(self): no_events_error = True with pytest.raises(RemoteInitiatedServerError) as excinfo: list(EventParser().process_all(data, no_events_error)) - assert "No events returned from server. This may not be an error, but treating it as one to be safe." in str(excinfo.value) + assert ( + "No events returned from server. This may not be an error, but treating it as one to be safe." + in str(excinfo.value) + ) def test_parse_empty_end_date_event(self): data = self.sample_data("empty_end_date_event.xml") [event] = list(EventParser().process_all(data)) - (threem_id, isbn, patron_id, start_time, end_time, - internal_event_type) = event - assert 'd5rf89' == threem_id - assert '9781101190623' == isbn + (threem_id, isbn, patron_id, start_time, end_time, internal_event_type) = event + assert "d5rf89" == threem_id + assert "9781101190623" == isbn assert None == patron_id assert datetime_utc(2016, 4, 28, 11, 4, 6) == start_time assert None == end_time - assert 'distributor_license_add' == internal_event_type + assert "distributor_license_add" == internal_event_type class TestPatronCirculationParser(BibliothecaAPITest): - def test_parse(self): data = self.sample_data("checkouts.xml") collection = self.collection @@ -793,33 +808,32 @@ def test_parse(self): class TestErrorParser(BibliothecaAPITest): - def test_exceeded_limit(self): """The normal case--we get a helpful error message which we turn into an appropriate circulation exception. """ - msg=self.sample_data("error_exceeded_limit.xml") + msg = self.sample_data("error_exceeded_limit.xml") error = ErrorParser().process_all(msg) assert isinstance(error, PatronLoanLimitReached) - assert 'Patron cannot loan more than 12 documents' == error.message + assert "Patron cannot loan more than 12 documents" == error.message def test_exceeded_hold_limit(self): - msg=self.sample_data("error_exceeded_hold_limit.xml") + msg = self.sample_data("error_exceeded_hold_limit.xml") error = ErrorParser().process_all(msg) assert isinstance(error, PatronHoldLimitReached) - assert 'Patron cannot have more than 15 holds' == error.message + assert "Patron cannot have more than 15 holds" == error.message def test_wrong_status(self): - msg=self.sample_data("error_no_licenses.xml") + msg = self.sample_data("error_no_licenses.xml") error = ErrorParser().process_all(msg) assert isinstance(error, NoLicenses) assert ( - 'the patron document status was CAN_WISH and not one of CAN_LOAN,RESERVATION' == - error.message) + "the patron document status was CAN_WISH and not one of CAN_LOAN,RESERVATION" + == error.message + ) problem = error.as_problem_detail_document() - assert ("The library currently has no licenses for this book." == - problem.detail) + assert "The library currently has no licenses for this book." == problem.detail assert 404 == problem.status_code def test_internal_server_error_beomces_remote_initiated_server_error(self): @@ -836,7 +850,7 @@ def test_internal_server_error_beomces_remote_initiated_server_error(self): def test_unknown_error_becomes_remote_initiated_server_error(self): """Simulate the message we get when the server gives a vague error.""" - msg=self.sample_data("error_unknown.xml") + msg = self.sample_data("error_unknown.xml") error = ErrorParser().process_all(msg) assert isinstance(error, RemoteInitiatedServerError) assert BibliothecaAPI.SERVICE_NAME == error.service_name @@ -847,7 +861,7 @@ def test_remote_authentication_failed_becomes_remote_initiated_server_error(self 'Authentication failed' but our authentication information is set up correctly. """ - msg=self.sample_data("error_authentication_failed.xml") + msg = self.sample_data("error_authentication_failed.xml") error = ErrorParser().process_all(msg) assert isinstance(error, RemoteInitiatedServerError) assert BibliothecaAPI.SERVICE_NAME == error.service_name @@ -867,6 +881,7 @@ def test_blank_error_message_becomes_remote_initiated_server_error(self): assert BibliothecaAPI.SERVICE_NAME == error.service_name assert "Unknown error" == error.message + class TestBibliothecaEventParser(object): # Sample event feed to test out the parser. @@ -905,8 +920,7 @@ def test_parse_event_batch(self): # Parsing the XML gives us two events. event1, event2 = EventParser().process_all(self.TWO_EVENTS) - (threem_id, isbn, patron_id, start_time, end_time, - internal_event_type) = event1 + (threem_id, isbn, patron_id, start_time, end_time, internal_event_type) = event1 assert "theitem1" == threem_id assert "900isbn1" == isbn @@ -914,8 +928,7 @@ def test_parse_event_batch(self): assert CirculationEvent.DISTRIBUTOR_CHECKIN == internal_event_type assert start_time == end_time - (threem_id, isbn, patron_id, start_time, end_time, - internal_event_type) = event2 + (threem_id, isbn, patron_id, start_time, end_time, internal_event_type) = event2 assert "theitem2" == threem_id assert "900isbn2" == isbn assert "patronid2" == patron_id @@ -977,44 +990,78 @@ def test_exception(self): error = parser.process_all(self.TRIED_TO_HOLD_BOOK_ON_LOAN) assert isinstance(error, CannotHold) -class TestBibliothecaPurchaseMonitor(BibliothecaAPITest): +class TestBibliothecaPurchaseMonitor(BibliothecaAPITest): @pytest.fixture() def default_monitor(self): return BibliothecaPurchaseMonitor( - self._db, self.collection, api_class=MockBibliothecaAPI, - analytics=MockAnalyticsProvider() + self._db, + self.collection, + api_class=MockBibliothecaAPI, + analytics=MockAnalyticsProvider(), ) @pytest.fixture() def initialized_monitor(self): - collection = MockBibliothecaAPI.mock_collection(self._db, name='Initialized Purchase Monitor Collection') + collection = MockBibliothecaAPI.mock_collection( + self._db, name="Initialized Purchase Monitor Collection" + ) monitor = BibliothecaPurchaseMonitor( self._db, collection, api_class=MockBibliothecaAPI ) Timestamp.stamp( - self._db, service=monitor.service_name, - service_type=Timestamp.MONITOR_TYPE, collection=collection + self._db, + service=monitor.service_name, + service_type=Timestamp.MONITOR_TYPE, + collection=collection, ) return monitor - @pytest.mark.parametrize('specified_default_start, expected_default_start', [ - ('2011', datetime_utc(year=2011, month=1, day=1)), - ('2011-10', datetime_utc(year=2011, month=10, day=1)), - ('2011-10-05', datetime_utc(year=2011, month=10, day=5)), - ('2011-10-05T15', datetime_utc(year=2011, month=10, day=5, hour=15)), - ('2011-10-05T15:27', datetime_utc(year=2011, month=10, day=5, hour=15, minute=27)), - ('2011-10-05T15:27:33', datetime_utc(year=2011, month=10, day=5, hour=15, minute=27, second=33)), - ('2011-10-05 15:27:33', datetime_utc(year=2011, month=10, day=5, hour=15, minute=27, second=33)), - ('2011-10-05T15:27:33.123456', - datetime_utc(year=2011, month=10, day=5, hour=15, minute=27, second=33, microsecond=123456)), - (datetime_utc(year=2011, month=10, day=5, hour=15, minute=27), - datetime_utc(year=2011, month=10, day=5, hour=15, minute=27)), - (None, None), - ]) - def test_optional_iso_date_valid_dates(self, specified_default_start, expected_default_start, default_monitor): + @pytest.mark.parametrize( + "specified_default_start, expected_default_start", + [ + ("2011", datetime_utc(year=2011, month=1, day=1)), + ("2011-10", datetime_utc(year=2011, month=10, day=1)), + ("2011-10-05", datetime_utc(year=2011, month=10, day=5)), + ("2011-10-05T15", datetime_utc(year=2011, month=10, day=5, hour=15)), + ( + "2011-10-05T15:27", + datetime_utc(year=2011, month=10, day=5, hour=15, minute=27), + ), + ( + "2011-10-05T15:27:33", + datetime_utc(year=2011, month=10, day=5, hour=15, minute=27, second=33), + ), + ( + "2011-10-05 15:27:33", + datetime_utc(year=2011, month=10, day=5, hour=15, minute=27, second=33), + ), + ( + "2011-10-05T15:27:33.123456", + datetime_utc( + year=2011, + month=10, + day=5, + hour=15, + minute=27, + second=33, + microsecond=123456, + ), + ), + ( + datetime_utc(year=2011, month=10, day=5, hour=15, minute=27), + datetime_utc(year=2011, month=10, day=5, hour=15, minute=27), + ), + (None, None), + ], + ) + def test_optional_iso_date_valid_dates( + self, specified_default_start, expected_default_start, default_monitor + ): # ISO 8601 strings, `datetime`s, or None are valid. - actual_default_start = default_monitor._optional_iso_date(specified_default_start) + actual_default_start = default_monitor._optional_iso_date( + specified_default_start + ) if expected_default_start is not None: assert isinstance(actual_default_start, datetime) assert actual_default_start == expected_default_start @@ -1030,21 +1077,44 @@ def test_monitor_intrinsic_start_time(self, default_monitor, initialized_monitor assert intrinsic_start == expected_intrinsic_start assert intrinsic_start == monitor.default_start_time - @pytest.mark.parametrize('specified_default_start, override_timestamp, expected_start', [ - ('2011-10-05T15:27', False, datetime_utc(year=2011, month=10, day=5, hour=15, minute=27)), - ('2011-10-05T15:27:33', False, datetime_utc(year=2011, month=10, day=5, hour=15, minute=27, second=33)), - (None, False, None), - (None, True, None), - ('2011-10-05T15:27', True, datetime_utc(year=2011, month=10, day=5, hour=15, minute=27)), - ('2011-10-05T15:27:33', True, datetime_utc(year=2011, month=10, day=5, hour=15, minute=27, second=33)), - ]) - def test_specified_start_trumps_intrinsic_default_start(self, specified_default_start, - override_timestamp, expected_start): + @pytest.mark.parametrize( + "specified_default_start, override_timestamp, expected_start", + [ + ( + "2011-10-05T15:27", + False, + datetime_utc(year=2011, month=10, day=5, hour=15, minute=27), + ), + ( + "2011-10-05T15:27:33", + False, + datetime_utc(year=2011, month=10, day=5, hour=15, minute=27, second=33), + ), + (None, False, None), + (None, True, None), + ( + "2011-10-05T15:27", + True, + datetime_utc(year=2011, month=10, day=5, hour=15, minute=27), + ), + ( + "2011-10-05T15:27:33", + True, + datetime_utc(year=2011, month=10, day=5, hour=15, minute=27, second=33), + ), + ], + ) + def test_specified_start_trumps_intrinsic_default_start( + self, specified_default_start, override_timestamp, expected_start + ): # When a valid `default_start` parameter is specified, it -- not the monitor's # intrinsic default -- will always become the monitor's `default_start_time`. monitor = BibliothecaPurchaseMonitor( - self._db, self.collection, api_class=MockBibliothecaAPI, - default_start=specified_default_start, override_timestamp=override_timestamp, + self._db, + self.collection, + api_class=MockBibliothecaAPI, + default_start=specified_default_start, + override_timestamp=override_timestamp, ) monitor_intrinsic_default = monitor._intrinsic_start_time(self._db) assert isinstance(monitor.default_start_time, datetime) @@ -1052,7 +1122,14 @@ def test_specified_start_trumps_intrinsic_default_start(self, specified_default_ if specified_default_start: assert monitor.default_start_time == expected_start else: - assert abs((monitor_intrinsic_default - monitor.default_start_time).total_seconds()) <= 1 + assert ( + abs( + ( + monitor_intrinsic_default - monitor.default_start_time + ).total_seconds() + ) + <= 1 + ) # If no `default_date` specified, then `override_timestamp` must be false. if not specified_default_start: @@ -1063,7 +1140,9 @@ def test_specified_start_trumps_intrinsic_default_start(self, specified_default_ # will be the actual start time. The cut-off will be roughly the current time, in # either case. expected_cutoff = utc_now() - with mock.patch.object(monitor, 'catch_up_from', return_value=None) as catch_up_from: + with mock.patch.object( + monitor, "catch_up_from", return_value=None + ) as catch_up_from: monitor.run() actual_start, actual_cutoff, progress = catch_up_from.call_args[0] assert abs((expected_cutoff - actual_cutoff).total_seconds()) <= 1 @@ -1071,31 +1150,60 @@ def test_specified_start_trumps_intrinsic_default_start(self, specified_default_ assert actual_start == monitor.default_start_time assert progress.start == monitor.default_start_time - @pytest.mark.parametrize('specified_default_start, override_timestamp, expected_start', [ - ('2011-10-05T15:27', False, datetime_utc(year=2011, month=10, day=5, hour=15, minute=27)), - ('2011-10-05T15:27:33', False, datetime_utc(year=2011, month=10, day=5, hour=15, minute=27, second=33)), - (None, False, None), - (None, True, None), - ('2011-10-05T15:27', True, datetime_utc(year=2011, month=10, day=5, hour=15, minute=27)), - ('2011-10-05T15:27:33', True, datetime_utc(year=2011, month=10, day=5, hour=15, minute=27, second=33)), - ]) - def test_specified_start_can_override_timestamp(self, specified_default_start, - override_timestamp, expected_start): + @pytest.mark.parametrize( + "specified_default_start, override_timestamp, expected_start", + [ + ( + "2011-10-05T15:27", + False, + datetime_utc(year=2011, month=10, day=5, hour=15, minute=27), + ), + ( + "2011-10-05T15:27:33", + False, + datetime_utc(year=2011, month=10, day=5, hour=15, minute=27, second=33), + ), + (None, False, None), + (None, True, None), + ( + "2011-10-05T15:27", + True, + datetime_utc(year=2011, month=10, day=5, hour=15, minute=27), + ), + ( + "2011-10-05T15:27:33", + True, + datetime_utc(year=2011, month=10, day=5, hour=15, minute=27, second=33), + ), + ], + ) + def test_specified_start_can_override_timestamp( + self, specified_default_start, override_timestamp, expected_start + ): monitor = BibliothecaPurchaseMonitor( - self._db, self.collection, api_class=MockBibliothecaAPI, - default_start=specified_default_start, override_timestamp=override_timestamp, + self._db, + self.collection, + api_class=MockBibliothecaAPI, + default_start=specified_default_start, + override_timestamp=override_timestamp, ) # For an initialized monitor, the `default_start_time` will be derived from # `timestamp.finish`, unless overridden by a specified `default_start` when # `override_timestamp` is specified as True. ts = Timestamp.stamp( - self._db, service=monitor.service_name, - service_type=Timestamp.MONITOR_TYPE, collection=monitor.collection + self._db, + service=monitor.service_name, + service_type=Timestamp.MONITOR_TYPE, + collection=monitor.collection, ) start_time_from_ts = ts.finish - BibliothecaPurchaseMonitor.OVERLAP - expected_actual_start_time = expected_start if monitor.override_timestamp else start_time_from_ts + expected_actual_start_time = ( + expected_start if monitor.override_timestamp else start_time_from_ts + ) expected_cutoff = utc_now() - with mock.patch.object(monitor, 'catch_up_from', return_value=None) as catch_up_from: + with mock.patch.object( + monitor, "catch_up_from", return_value=None + ) as catch_up_from: monitor.run() actual_start, actual_cutoff, progress = catch_up_from.call_args[0] assert abs((expected_cutoff - actual_cutoff).total_seconds()) <= 1 @@ -1103,9 +1211,7 @@ def test_specified_start_can_override_timestamp(self, specified_default_start, assert actual_start == expected_actual_start_time assert progress.start == expected_actual_start_time - @pytest.mark.parametrize('input', [ - ('invalid'), ('2020/10'), (['2020-10-05']) - ]) + @pytest.mark.parametrize("input", [("invalid"), ("2020/10"), (["2020-10-05"])]) def test_optional_iso_date_invalid_dates(self, input, default_monitor): with pytest.raises(ValueError) as excinfo: default_monitor._optional_iso_date(input) @@ -1119,41 +1225,23 @@ def test_catch_up_from(self, default_monitor): # _checkpoint() will be called after processing this slice # because it's a full slice that ends before today. - full_slice = [ - datetime_utc(2014, 1, 1), - datetime_utc(2014, 1, 2), - True - ] + full_slice = [datetime_utc(2014, 1, 1), datetime_utc(2014, 1, 2), True] # _checkpoint() is not called after processing this slice # because it's not a full slice. - incomplete_slice = [ - datetime_utc(2015, 1, 1), - datetime_utc(2015, 1, 2), - False - ] + incomplete_slice = [datetime_utc(2015, 1, 1), datetime_utc(2015, 1, 2), False] # _checkpoint() is not called after processing this slice, # even though it's supposedly complete, because today isn't # over yet. - today_slice = [ - today - timedelta(days=1), - today, - True - ] + today_slice = [today - timedelta(days=1), today, True] # _checkpoint() is not called after processing this slice # because it doesn't end in the past. - future_slice = [ - today + timedelta(days=1), - today + timedelta(days=2), - True - ] + future_slice = [today + timedelta(days=1), today + timedelta(days=2), True] default_monitor.slice_timespan = MagicMock( - return_value = [ - full_slice, incomplete_slice, today_slice, future_slice - ] + return_value=[full_slice, incomplete_slice, today_slice, future_slice] ) default_monitor.purchases = MagicMock(return_value=["A record"]) default_monitor.process_record = MagicMock() @@ -1172,8 +1260,9 @@ def test_catch_up_from(self, default_monitor): # purchases() was called on each slice it returned. default_monitor.purchases.assert_has_calls( - [mock.call(*x[:2]) for x in ( - full_slice, incomplete_slice, today_slice, future_slice) + [ + mock.call(*x[:2]) + for x in (full_slice, incomplete_slice, today_slice, future_slice) ] ) @@ -1181,8 +1270,10 @@ def test_catch_up_from(self, default_monitor): # passed into process_record along with the start date of the # current slice. default_monitor.process_record.assert_has_calls( - [mock.call("A record", x[0]) for x in - [full_slice, incomplete_slice, today_slice, future_slice]] + [ + mock.call("A record", x[0]) + for x in [full_slice, incomplete_slice, today_slice, future_slice] + ] ) # TimestampData.achievements was set to the total number of @@ -1263,12 +1354,14 @@ def test_process_record(self, default_monitor, caplog): multiple_control_numbers = b"""01034nam a22002413a 4500ehasb89abcde""" no_control_number = b"""01034nam a22002413a 4500""" for bad_record, expect_error in ( - (multiple_control_numbers, - "Ignoring MARC record with multiple Bibliotheca control numbers." - ), - (no_control_number, - "Ignoring MARC record with no Bibliotheca control number." - ) + ( + multiple_control_numbers, + "Ignoring MARC record with multiple Bibliotheca control numbers.", + ), + ( + no_control_number, + "Ignoring MARC record with no Bibliotheca control number.", + ), ): [marc] = parse_xml_to_array(BytesIO(bad_record)) assert default_monitor.process_record(marc, purchase_time) is None @@ -1287,9 +1380,7 @@ def test_process_record(self, default_monitor, caplog): assert pool.identifier.type == Identifier.BIBLIOTHECA_ID assert pool.data_source.name == DataSource.BIBLIOTHECA assert self.collection == pool.collection - ensure_coverage.assert_called_once_with( - pool.identifier, force=True - ) + ensure_coverage.assert_called_once_with(pool.identifier, force=True) # An analytics event is issued to mark the time at which the # book was first purchased. @@ -1300,12 +1391,15 @@ def test_process_record(self, default_monitor, caplog): # If the book is already in this collection, ensure_coverage # is not called. pool, ignore = LicensePool.for_foreign_id( - self._db, DataSource.BIBLIOTHECA, Identifier.BIBLIOTHECA_ID, - "3oock89", collection=self.collection + self._db, + DataSource.BIBLIOTHECA, + Identifier.BIBLIOTHECA_ID, + "3oock89", + collection=self.collection, ) pool2 = default_monitor.process_record(oock89, purchase_time) assert pool == pool2 - assert ensure_coverage.call_count == 1 # i.e. was not called again. + assert ensure_coverage.call_count == 1 # i.e. was not called again. # But an analytics event is still issued to mark the purchase. assert analytics.count == 2 @@ -1327,12 +1421,8 @@ def test_end_to_end(self, default_monitor): # book, and one to the metadata endpoint for information about # that book. api = default_monitor.api - api.queue_response( - 200, content=self.sample_data("marc_records_one.xml") - ) - api.queue_response( - 200, content=self.sample_data("item_metadata_single.xml") - ) + api.queue_response(200, content=self.sample_data("marc_records_one.xml")) + api.queue_response(200, content=self.sample_data("item_metadata_single.xml")) default_monitor.run() # One book was created. @@ -1346,14 +1436,14 @@ def test_end_to_end(self, default_monitor): # Licensing information was also taken from the coverage # provider. [lp] = work.license_pools - assert lp.identifier.identifier == 'ddf4gr9' + assert lp.identifier.identifier == "ddf4gr9" assert default_monitor.collection == lp.collection assert lp.licenses_owned == 1 assert lp.licenses_available == 1 # An analytics event was issued to commemorate the addition of # the book to the collection. - assert default_monitor.analytics.event_type == 'distributor_title_add' + assert default_monitor.analytics.event_type == "distributor_title_add" # The timestamp has been updated; the next time the monitor # runs it will ask for purchases that haven't happened yet. @@ -1362,8 +1452,8 @@ def test_end_to_end(self, default_monitor): assert timestamp.achievements == "MARC records processed: 1" assert timestamp.finish > start_time -class TestBibliothecaEventMonitor(BibliothecaAPITest): +class TestBibliothecaEventMonitor(BibliothecaAPITest): @pytest.fixture() def default_monitor(self): return BibliothecaEventMonitor( @@ -1372,17 +1462,20 @@ def default_monitor(self): @pytest.fixture() def initialized_monitor(self): - collection = MockBibliothecaAPI.mock_collection(self._db, name='Initialized Monitor Collection') + collection = MockBibliothecaAPI.mock_collection( + self._db, name="Initialized Monitor Collection" + ) monitor = BibliothecaEventMonitor( self._db, collection, api_class=MockBibliothecaAPI ) Timestamp.stamp( - self._db, service=monitor.service_name, - service_type=Timestamp.MONITOR_TYPE, collection=collection + self._db, + service=monitor.service_name, + service_type=Timestamp.MONITOR_TYPE, + collection=collection, ) return monitor - def test_run_once(self): # run_once() slices the time between its start date # and the current time into five-minute intervals, and asks for @@ -1393,23 +1486,17 @@ def test_run_once(self): two_hours_ago = now - timedelta(hours=2) # Simulate that this script last ran 24 hours ago - before_timestamp = TimestampData( - start=two_hours_ago, finish=one_hour_ago - ) + before_timestamp = TimestampData(start=two_hours_ago, finish=one_hour_ago) api = MockBibliothecaAPI(self._db, self.collection) - api.queue_response( - 200, content=self.sample_data("item_metadata_single.xml") - ) + api.queue_response(200, content=self.sample_data("item_metadata_single.xml")) # Setting up making requests in 5-minute intervals in the hour slice. for i in range(1, 15): api.queue_response( 200, content=self.sample_data("empty_end_date_event.xml") ) - monitor = BibliothecaEventMonitor( - self._db, self.collection, api_class=api - ) + monitor = BibliothecaEventMonitor(self._db, self.collection, api_class=api) after_timestamp = monitor.run_once(before_timestamp) # Fifteen requests were made to the API: @@ -1442,7 +1529,7 @@ def test_run_once(self): # # The events we found were both from 2016, but that's not # considered when setting the timestamp. - assert one_hour_ago-monitor.OVERLAP == after_timestamp.start + assert one_hour_ago - monitor.OVERLAP == after_timestamp.start self.time_eq(after_timestamp.finish, now) # The timestamp's achivements have been updated. assert "Events handled: 13." == after_timestamp.achievements @@ -1462,12 +1549,8 @@ def test_run_once(self): # # This is going to result in two more API calls, one for the # "5 minutes" and one for the "little bit". - api.queue_response( - 200, content=self.sample_data("empty_event_batch.xml") - ) - api.queue_response( - 200, content=self.sample_data("empty_event_batch.xml") - ) + api.queue_response(200, content=self.sample_data("empty_event_batch.xml")) + api.queue_response(200, content=self.sample_data("empty_event_batch.xml")) monitor.run_once(after_timestamp) # Two more requests were made, but no events were found for the @@ -1480,18 +1563,21 @@ def test_run_once(self): def test_handle_event(self): api = MockBibliothecaAPI(self._db, self.collection) - api.queue_response( - 200, content=self.sample_data("item_metadata_single.xml") - ) + api.queue_response(200, content=self.sample_data("item_metadata_single.xml")) analytics = MockAnalyticsProvider() monitor = BibliothecaEventMonitor( - self._db, self.collection, api_class=api, - analytics=analytics + self._db, self.collection, api_class=api, analytics=analytics ) now = utc_now() - monitor.handle_event("ddf4gr9", "9781250015280", None, now, None, - CirculationEvent.DISTRIBUTOR_LICENSE_ADD) + monitor.handle_event( + "ddf4gr9", + "9781250015280", + None, + now, + None, + CirculationEvent.DISTRIBUTOR_LICENSE_ADD, + ) # The collection now has a LicensePool corresponding to the book # we just loaded. @@ -1523,23 +1609,25 @@ def test_handle_event(self): class TestBibliothecaPurchaseMonitorWhenMultipleCollections(BibliothecaAPITest): - def test_multiple_service_type_timestamps_with_start_date(self): # Start with multiple collections that have timestamps # because they've run before. collections = [ - MockBibliothecaAPI.mock_collection(self._db, name='Collection 1'), - MockBibliothecaAPI.mock_collection(self._db, name='Collection 2'), + MockBibliothecaAPI.mock_collection(self._db, name="Collection 1"), + MockBibliothecaAPI.mock_collection(self._db, name="Collection 2"), ] for c in collections: Timestamp.stamp( - self._db, service=BibliothecaPurchaseMonitor.SERVICE_NAME, - service_type=Timestamp.MONITOR_TYPE, collection=c + self._db, + service=BibliothecaPurchaseMonitor.SERVICE_NAME, + service_type=Timestamp.MONITOR_TYPE, + collection=c, ) # Instantiate the associated monitors with a start date. monitors = [ - BibliothecaPurchaseMonitor(self._db, c, api_class=BibliothecaAPI, - default_start='2011-02-03') + BibliothecaPurchaseMonitor( + self._db, c, api_class=BibliothecaAPI, default_start="2011-02-03" + ) for c in collections ] assert len(monitors) == len(collections) @@ -1549,41 +1637,51 @@ def test_multiple_service_type_timestamps_with_start_date(self): class TestItemListParser(BibliothecaAPITest): - def test_contributors_for_string(cls): - authors = list(ItemListParser.contributors_from_string( - "Walsh, Jill Paton; Sayers, Dorothy L.")) - assert ([x.sort_name for x in authors] == - ["Walsh, Jill Paton", "Sayers, Dorothy L."]) - assert ([x.roles for x in authors] == - [[Contributor.AUTHOR_ROLE], [Contributor.AUTHOR_ROLE]]) + authors = list( + ItemListParser.contributors_from_string( + "Walsh, Jill Paton; Sayers, Dorothy L." + ) + ) + assert [x.sort_name for x in authors] == [ + "Walsh, Jill Paton", + "Sayers, Dorothy L.", + ] + assert [x.roles for x in authors] == [ + [Contributor.AUTHOR_ROLE], + [Contributor.AUTHOR_ROLE], + ] # Parentheticals are stripped. [author] = ItemListParser.contributors_from_string( - "Baum, Frank L. (Frank Lyell)") + "Baum, Frank L. (Frank Lyell)" + ) assert "Baum, Frank L." == author.sort_name # Contributors may have two levels of entity reference escaping, # one of which will have already been handled by the initial parse. # So, we'll test zero and one escapings here. - authors = list(ItemListParser.contributors_from_string( - u'Raji Codell, Esmé; Raji Codell, Esmé')) + authors = list( + ItemListParser.contributors_from_string( + u"Raji Codell, Esmé; Raji Codell, Esmé" + ) + ) author_names = [a.sort_name for a in authors] assert len(authors) == 2 assert len(set(author_names)) == 1 - assert all(u'Raji Codell, Esmé' == name for name in author_names) + assert all(u"Raji Codell, Esmé" == name for name in author_names) # It's possible to specify some role other than AUTHOR_ROLE. narrators = list( ItemListParser.contributors_from_string( - "Callow, Simon; Mann, Bruce; Hagon, Garrick", - Contributor.NARRATOR_ROLE + "Callow, Simon; Mann, Bruce; Hagon, Garrick", Contributor.NARRATOR_ROLE ) ) for narrator in narrators: assert [Contributor.NARRATOR_ROLE] == narrator.roles - assert (["Callow, Simon", "Mann, Bruce", "Hagon, Garrick"] == - [narrator.sort_name for narrator in narrators]) + assert ["Callow, Simon", "Mann, Bruce", "Hagon, Garrick"] == [ + narrator.sort_name for narrator in narrators + ] def test_parse_genre_string(self): def f(genre_string): @@ -1591,12 +1689,17 @@ def f(genre_string): assert all([x.type == Subject.BISAC for x in genres]) return [x.name for x in genres] - assert (["Children's Health", "Health"] == - f("Children&#39;s Health,Health,")) + assert ["Children's Health", "Health"] == f("Children&#39;s Health,Health,") - assert (["Action & Adventure", "Science Fiction", "Fantasy", "Magic", - "Renaissance"] == - f("Action &amp; Adventure,Science Fiction, Fantasy, Magic,Renaissance,")) + assert [ + "Action & Adventure", + "Science Fiction", + "Fantasy", + "Magic", + "Renaissance", + ] == f( + "Action &amp; Adventure,Science Fiction, Fantasy, Magic,Renaissance," + ) def test_item_list(cls): data = cls.sample_data("item_metadata_list_mini.xml") @@ -1612,18 +1715,16 @@ def test_item_list(cls): assert Edition.BOOK_MEDIUM == cooked.medium assert "eng" == cooked.language assert "St. Martin's Press" == cooked.publisher - assert (datetime_utc(year=2012, month=9, day=17) == - cooked.published) + assert datetime_utc(year=2012, month=9, day=17) == cooked.published primary = cooked.primary_identifier assert "ddf4gr9" == primary.identifier assert Identifier.THREEM_ID == primary.type - identifiers = sorted( - cooked.identifiers, key=lambda x: x.identifier - ) - assert (['9781250015280', '9781250031112', 'ddf4gr9'] == - [x.identifier for x in identifiers]) + identifiers = sorted(cooked.identifiers, key=lambda x: x.identifier) + assert ["9781250015280", "9781250031112", "ddf4gr9"] == [ + x.identifier for x in identifiers + ] [author] = cooked.contributors assert "Rowland, Laura Joh" == author.sort_name @@ -1636,8 +1737,7 @@ def test_item_list(cls): assert Measurement.PAGE_COUNT == pages.quantity_measured assert 304 == pages.value - [alternate, image, description] = sorted( - cooked.links, key = lambda x: x.rel) + [alternate, image, description] = sorted(cooked.links, key=lambda x: x.rel) assert "alternate" == alternate.rel assert alternate.href.startswith("http://ebook.3m.com/library") @@ -1645,8 +1745,8 @@ def test_item_list(cls): assert Hyperlink.IMAGE == image.rel assert Representation.JPEG_MEDIA_TYPE == image.media_type assert image.href.startswith("http://ebook.3m.com/delivery") - assert 'documentID=ddf4gr9' in image.href - assert '&size=NORMAL' not in image.href + assert "documentID=ddf4gr9" in image.href + assert "&size=NORMAL" not in image.href # ... and a thumbnail, which we obtained by adding an argument # to the main image URL. @@ -1669,11 +1769,17 @@ def test_multiple_contributor_roles(self): # We found one author and three narrators. assert ( - sorted([('Riggs, Ransom', 'Author'), - ('Callow, Simon', 'Narrator'), - ('Mann, Bruce', 'Narrator'), - ('Hagon, Garrick', 'Narrator')]) == - sorted(names_and_roles)) + sorted( + [ + ("Riggs, Ransom", "Author"), + ("Callow, Simon", "Narrator"), + ("Mann, Bruce", "Narrator"), + ("Hagon, Garrick", "Narrator"), + ] + ) + == sorted(names_and_roles) + ) + class TestBibliographicCoverageProvider(TestBibliothecaAPI): @@ -1684,19 +1790,19 @@ def test_script_instantiation(self): this coverage provider. """ script = RunCollectionCoverageProviderScript( - BibliothecaBibliographicCoverageProvider, self._db, - api_class=MockBibliothecaAPI + BibliothecaBibliographicCoverageProvider, + self._db, + api_class=MockBibliothecaAPI, ) [provider] = script.providers - assert isinstance(provider, - BibliothecaBibliographicCoverageProvider) + assert isinstance(provider, BibliothecaBibliographicCoverageProvider) assert isinstance(provider.api, MockBibliothecaAPI) def test_process_item_creates_presentation_ready_work(self): # Test the normal workflow where we ask Bibliotheca for data, # Bibliotheca provides it, and we create a presentation-ready work. identifier = self._identifier(identifier_type=Identifier.BIBLIOTHECA_ID) - identifier.identifier = 'ddf4gr9' + identifier.identifier = "ddf4gr9" # This book has no LicensePools. assert [] == identifier.licensed_through @@ -1721,8 +1827,9 @@ def test_process_item_creates_presentation_ready_work(self): assert 1 == pool.licenses_available [lpdm] = pool.delivery_mechanisms assert ( - 'application/epub+zip (application/vnd.adobe.adept+xml)' == - lpdm.delivery_mechanism.name) + "application/epub+zip (application/vnd.adobe.adept+xml)" + == lpdm.delivery_mechanism.name + ) # A Work was created and made presentation ready. assert "The Incense Game" == pool.work.title @@ -1731,6 +1838,7 @@ def test_process_item_creates_presentation_ready_work(self): def test_internal_formats(self): m = ItemListParser.internal_formats + def _check_format(input, expect_medium, expect_format, expect_drm): medium, formats = m(input) assert medium == expect_medium diff --git a/tests/test_circulation_exceptions.py b/tests/test_circulation_exceptions.py index 6749a81dba..4d2aa4ad8f 100644 --- a/tests/test_circulation_exceptions.py +++ b/tests/test_circulation_exceptions.py @@ -1,9 +1,10 @@ from flask_babel import lazy_gettext as _ -from core.util.problem_detail import ProblemDetail -from api.config import Configuration + from api.circulation_exceptions import * +from api.config import Configuration from api.problem_details import * from core.testing import DatabaseTest +from core.util.problem_detail import ProblemDetail class TestCirculationExceptions(object): @@ -35,13 +36,10 @@ class TestLimitReached(DatabaseTest): """ def test_as_problem_detail_document(self): - generic_message = _("You exceeded the limit, but I don't know what the limit was.") - pd = ProblemDetail( - "http://uri/", - 403, - _("Limit exceeded."), - generic_message + generic_message = _( + "You exceeded the limit, but I don't know what the limit was." ) + pd = ProblemDetail("http://uri/", 403, _("Limit exceeded."), generic_message) setting = "some setting" class Mock(LimitReached): @@ -75,10 +73,14 @@ def test_subclasses(self): library.setting(Configuration.LOAN_LIMIT).value = 2 pd = PatronLoanLimitReached(library=library).as_problem_detail_document() - assert ("You have reached your loan limit of 2. You cannot borrow anything further until you return something." == - pd.detail) + assert ( + "You have reached your loan limit of 2. You cannot borrow anything further until you return something." + == pd.detail + ) library.setting(Configuration.HOLD_LIMIT).value = 3 pd = PatronHoldLimitReached(library=library).as_problem_detail_document() - assert ("You have reached your hold limit of 3. You cannot place another item on hold until you borrow something or remove a hold." == - pd.detail) + assert ( + "You have reached your hold limit of 3. You cannot place another item on hold until you borrow something or remove a hold." + == pd.detail + ) diff --git a/tests/test_circulationapi.py b/tests/test_circulationapi.py index e48a61bd4f..0fe6f33d20 100644 --- a/tests/test_circulationapi.py +++ b/tests/test_circulationapi.py @@ -39,6 +39,7 @@ ) from core.testing import DatabaseTest from core.util.datetime_helpers import utc_now + from . import sample_data @@ -55,7 +56,8 @@ def setup_method(self): edition, self.pool = self._edition( data_source_name=DataSource.BIBLIOTHECA, identifier_type=Identifier.BIBLIOTHECA_ID, - with_license_pool=True, collection=self.collection + with_license_pool=True, + collection=self.collection, ) self.pool.open_access = False self.identifier = self.pool.identifier @@ -63,21 +65,20 @@ def setup_method(self): self.patron = self._patron() self.analytics = MockAnalyticsProvider() self.circulation = MockCirculationAPI( - self._db, self._default_library, analytics=self.analytics, api_map = { - ExternalIntegration.BIBLIOTHECA : MockBibliothecaAPI - } + self._db, + self._default_library, + analytics=self.analytics, + api_map={ExternalIntegration.BIBLIOTHECA: MockBibliothecaAPI}, ) self.remote = self.circulation.api_for_license_pool(self.pool) def borrow(self): return self.circulation.borrow( - self.patron, '1234', self.pool, self.delivery_mechanism + self.patron, "1234", self.pool, self.delivery_mechanism ) def sync_bookshelf(self): - return self.circulation.sync_bookshelf( - self.patron, '1234' - ) + return self.circulation.sync_bookshelf(self.patron, "1234") def test_circulationinfo_collection_id(self): # It's possible to instantiate CirculationInfo (the superclass of all @@ -95,10 +96,12 @@ def test_circulationinfo_collection_id(self): def test_borrow_sends_analytics_event(self): now = utc_now() loaninfo = LoanInfo( - self.pool.collection, self.pool.data_source, + self.pool.collection, + self.pool.data_source, self.pool.identifier.type, self.pool.identifier.identifier, - now, now + timedelta(seconds=3600), + now, + now + timedelta(seconds=3600), external_identifier=self._str, ) self.remote.queue_checkout(loaninfo) @@ -115,8 +118,7 @@ def test_borrow_sends_analytics_event(self): # An analytics event was created. assert 1 == self.analytics.count - assert (CirculationEvent.CM_CHECKOUT == - self.analytics.event_type) + assert CirculationEvent.CM_CHECKOUT == self.analytics.event_type # Try to 'borrow' the same book again. self.remote.queue_checkout(AlreadyCheckedOut()) @@ -159,8 +161,8 @@ def test_borrowing_of_self_hosted_book_succeeds(self): def test_borrowing_of_unlimited_access_book_succeeds(self): """Ensure that unlimited access books that don't belong to collections - having a custom CirculationAPI implementation (e.g., OPDS 1.x, OPDS 2.x collections) - are checked out in the same way as OA and self-hosted books.""" + having a custom CirculationAPI implementation (e.g., OPDS 1.x, OPDS 2.x collections) + are checked out in the same way as OA and self-hosted books.""" # Arrange # Reset the API map, this book belongs to the "basic" collection, @@ -186,8 +188,12 @@ def test_attempt_borrow_with_existing_remote_loan(self): """ # Remote loan. self.circulation.add_remote_loan( - self.pool.collection, self.pool.data_source, self.identifier.type, - self.identifier.identifier, self.YESTERDAY, self.IN_TWO_WEEKS + self.pool.collection, + self.pool.data_source, + self.identifier.type, + self.identifier.identifier, + self.YESTERDAY, + self.IN_TWO_WEEKS, ) self.remote.queue_checkout(AlreadyCheckedOut()) @@ -204,8 +210,8 @@ def test_attempt_borrow_with_existing_remote_loan(self): # but didn't give us any useful information on when that loan # was created. We've faked it with values that should be okay # until the next sync. - assert abs((loan.start-now).seconds) < 2 - assert 3600 == (loan.end-loan.start).seconds + assert abs((loan.start - now).seconds) < 2 + assert 3600 == (loan.end - loan.start).seconds def test_attempt_borrow_with_existing_remote_hold(self): """The patron has a remote hold that the circ manager doesn't know @@ -214,9 +220,13 @@ def test_attempt_borrow_with_existing_remote_hold(self): """ # Remote hold. self.circulation.add_remote_hold( - self.pool.collection, self.pool.data_source, - self.identifier.type, self.identifier.identifier, - self.YESTERDAY, self.IN_TWO_WEEKS, 10 + self.pool.collection, + self.pool.data_source, + self.identifier.type, + self.identifier.identifier, + self.YESTERDAY, + self.IN_TWO_WEEKS, + 10, ) self.remote.queue_checkout(AlreadyOnHold()) @@ -234,7 +244,7 @@ def test_attempt_borrow_with_existing_remote_hold(self): # created. We've set the hold start time to the time we found # out about it. We'll get the real information the next time # we do a sync. - assert abs((hold.start-now).seconds) < 2 + assert abs((hold.start - now).seconds) < 2 assert None == hold.end assert None == hold.position @@ -247,9 +257,12 @@ def test_attempt_premature_renew_with_local_loan(self): # Remote loan. self.circulation.add_remote_loan( - self.pool.collection, self.pool.data_source, - self.identifier.type, self.identifier.identifier, - self.YESTERDAY, self.IN_TWO_WEEKS + self.pool.collection, + self.pool.data_source, + self.identifier.type, + self.identifier.identifier, + self.YESTERDAY, + self.IN_TWO_WEEKS, ) # This is the expected behavior in most cases--you tried to @@ -257,7 +270,7 @@ def test_attempt_premature_renew_with_local_loan(self): self.remote.queue_checkout(CannotRenew()) with pytest.raises(CannotRenew) as excinfo: self.borrow() - assert 'CannotRenew' in str(excinfo.value) + assert "CannotRenew" in str(excinfo.value) def test_attempt_renew_with_local_loan_and_no_available_copies(self): """We have a local loan and a remote loan but the patron tried to @@ -268,9 +281,12 @@ def test_attempt_renew_with_local_loan_and_no_available_copies(self): # Remote loan. self.circulation.add_remote_loan( - self.pool.collection, self.pool.data_source, - self.identifier.type, self.identifier.identifier, - self.YESTERDAY, self.IN_TWO_WEEKS + self.pool.collection, + self.pool.data_source, + self.identifier.type, + self.identifier.identifier, + self.YESTERDAY, + self.IN_TWO_WEEKS, ) # NoAvailableCopies can happen if there are already people @@ -282,15 +298,21 @@ def test_attempt_renew_with_local_loan_and_no_available_copies(self): self.remote.queue_checkout(NoAvailableCopies()) with pytest.raises(CannotRenew) as excinfo: self.borrow() - assert "You cannot renew a loan if other patrons have the work on hold." in str(excinfo.value) + assert "You cannot renew a loan if other patrons have the work on hold." in str( + excinfo.value + ) def test_loan_becomes_hold_if_no_available_copies(self): # We want to borrow this book but there are no copies. self.remote.queue_checkout(NoAvailableCopies()) holdinfo = HoldInfo( - self.pool.collection, self.pool.data_source, - self.identifier.type, self.identifier.identifier, - None, None, 10 + self.pool.collection, + self.pool.data_source, + self.identifier.type, + self.identifier.identifier, + None, + None, + 10, ) self.remote.queue_hold(holdinfo) @@ -307,9 +329,13 @@ def test_borrow_creates_hold_if_api_returns_hold_info(self): # places a hold for us right away instead of raising # an error. holdinfo = HoldInfo( - self.pool.collection, self.pool.data_source, - self.identifier.type, self.identifier.identifier, - None, None, 10 + self.pool.collection, + self.pool.data_source, + self.identifier.type, + self.identifier.identifier, + None, + None, + 10, ) self.remote.queue_checkout(holdinfo) @@ -329,9 +355,13 @@ def test_vendor_side_loan_limit_allows_for_hold_placement(self): # But the point is moot because the book isn't even available. # Attempting to place a hold will succeed. holdinfo = HoldInfo( - self.pool.collection, self.pool.data_source, - self.identifier.type, self.identifier.identifier, - None, None, 10 + self.pool.collection, + self.pool.data_source, + self.identifier.type, + self.identifier.identifier, + None, + None, + 10, ) self.remote.queue_hold(holdinfo) @@ -355,8 +385,8 @@ def test_loan_exception_reraised_if_hold_placement_fails(self): # in the first place. self.remote.queue_hold(CurrentlyAvailable()) - assert len(self.remote.responses['checkout']) == 1 - assert len(self.remote.responses['hold']) == 1 + assert len(self.remote.responses["checkout"]) == 1 + assert len(self.remote.responses["hold"]) == 1 # The exception raised is PatronLoanLimitReached, the first # one we encountered... @@ -364,15 +394,19 @@ def test_loan_exception_reraised_if_hold_placement_fails(self): # ...but we made both requests and have no more responses # queued. - assert not self.remote.responses['checkout'] - assert not self.remote.responses['hold'] + assert not self.remote.responses["checkout"] + assert not self.remote.responses["hold"] def test_hold_sends_analytics_event(self): self.remote.queue_checkout(NoAvailableCopies()) holdinfo = HoldInfo( - self.pool.collection, self.pool.data_source, - self.identifier.type, self.identifier.identifier, - None, None, 10 + self.pool.collection, + self.pool.data_source, + self.identifier.type, + self.identifier.identifier, + None, + None, + 10, ) self.remote.queue_hold(holdinfo) @@ -386,8 +420,7 @@ def test_hold_sends_analytics_event(self): # An analytics event was created. assert 1 == self.analytics.count - assert (CirculationEvent.CM_HOLD_PLACE == - self.analytics.event_type) + assert CirculationEvent.CM_HOLD_PLACE == self.analytics.event_type # Try to 'borrow' the same book again. self.remote.queue_checkout(AlreadyOnHold()) @@ -407,9 +440,13 @@ def test_loan_becomes_hold_if_no_available_copies_and_preexisting_loan(self): # copies! self.remote.queue_checkout(NoAvailableCopies()) holdinfo = HoldInfo( - self.pool.collection, self.pool.data_source, - self.identifier.type, self.identifier.identifier, - None, None, 10 + self.pool.collection, + self.pool.data_source, + self.identifier.type, + self.identifier.identifier, + None, + None, + 10, ) self.remote.queue_hold(holdinfo) @@ -433,10 +470,12 @@ def test_borrow_with_expired_card_fails(self): # This checkout would succeed... now = utc_now() loaninfo = LoanInfo( - self.pool.collection, self.pool.data_source, + self.pool.collection, + self.pool.data_source, self.pool.identifier.type, self.pool.identifier.identifier, - now, now + timedelta(seconds=3600), + now, + now + timedelta(seconds=3600), ) self.remote.queue_checkout(loaninfo) @@ -452,10 +491,12 @@ def test_borrow_with_outstanding_fines(self): # This checkout would succeed... now = utc_now() loaninfo = LoanInfo( - self.pool.collection, self.pool.data_source, + self.pool.collection, + self.pool.data_source, self.pool.identifier.type, self.pool.identifier.identifier, - now, now + timedelta(seconds=3600), + now, + now + timedelta(seconds=3600), ) self.remote.queue_checkout(loaninfo) @@ -463,8 +504,7 @@ def test_borrow_with_outstanding_fines(self): old_fines = self.patron.fines self.patron.fines = 1000 setting = ConfigurationSetting.for_library( - Configuration.MAX_OUTSTANDING_FINES, - self._default_library + Configuration.MAX_OUTSTANDING_FINES, self._default_library ) setting.value = "$0.50" @@ -474,7 +514,6 @@ def test_borrow_with_outstanding_fines(self): setting.value = "$0" pytest.raises(OutstandingFines, self.borrow) - # Remove the fine policy, and borrow succeeds. setting.value = None loan, i1, i2 = self.borrow() @@ -486,10 +525,12 @@ def test_borrow_with_block_fails(self): # This checkout would succeed... now = utc_now() loaninfo = LoanInfo( - self.pool.collection, self.pool.data_source, + self.pool.collection, + self.pool.data_source, self.pool.identifier.type, self.pool.identifier.identifier, - now, now + timedelta(seconds=3600), + now, + now + timedelta(seconds=3600), ) self.remote.queue_checkout(loaninfo) @@ -526,6 +567,7 @@ def internal_format(self, *args, **kwargs): def checkout(self, *args, **kwargs): raise NotImplementedError() + api = MockVendorAPI() class MockCirculationAPI(CirculationAPI): @@ -539,6 +581,7 @@ def enforce_limits(self, patron, licensepool): def api_for_license_pool(self, pool): # Always return the same mock MockVendorAPI. return api + self.circulation = MockCirculationAPI(self._db, self._default_library) # checkout() raised the expected NotImplementedError @@ -723,6 +766,7 @@ def assert_enforce_limits_raises(expected_exception): # error message when the exception is converted to a # problem detail document. assert 12 == e.limit + assert_enforce_limits_raises(PatronLoanLimitReached) # We were able to deduce that the patron can't do anything @@ -809,10 +853,13 @@ def test_borrow_hold_limit_reached(self): self.remote.queue_checkout(NoAvailableCopies()) now = utc_now() holdinfo = HoldInfo( - self.pool.collection, self.pool.data_source, + self.pool.collection, + self.pool.data_source, self.pool.identifier.type, self.pool.identifier.identifier, - now, now + timedelta(seconds=3600), 10 + now, + now + timedelta(seconds=3600), + 10, ) self.remote.queue_hold(holdinfo) loan, hold, is_new = self.borrow() @@ -833,20 +880,27 @@ def test_fulfill_open_access(self): # fulfill_open_access() and fulfill() will both raise # FormatNotAvailable. - pytest.raises(FormatNotAvailable, self.circulation.fulfill_open_access, - self.pool, i_want_an_epub) + pytest.raises( + FormatNotAvailable, + self.circulation.fulfill_open_access, + self.pool, + i_want_an_epub, + ) - pytest.raises(FormatNotAvailable, self.circulation.fulfill, - self.patron, '1234', self.pool, - broken_lpdm, - sync_on_failure=False + pytest.raises( + FormatNotAvailable, + self.circulation.fulfill, + self.patron, + "1234", + self.pool, + broken_lpdm, + sync_on_failure=False, ) # Let's add a second LicensePoolDeliveryMechanism of the same # type which has an associated Resource. link, new = self.pool.identifier.add_link( - Hyperlink.OPEN_ACCESS_DOWNLOAD, self._url, - self.pool.data_source + Hyperlink.OPEN_ACCESS_DOWNLOAD, self._url, self.pool.data_source ) working_lpdm = self.pool.set_delivery_mechanism( @@ -859,20 +913,24 @@ def test_fulfill_open_access(self): # It's still not going to work because the Resource has no # Representation. assert None == link.resource.representation - pytest.raises(FormatNotAvailable, self.circulation.fulfill_open_access, - self.pool, i_want_an_epub) + pytest.raises( + FormatNotAvailable, + self.circulation.fulfill_open_access, + self.pool, + i_want_an_epub, + ) # Let's add a Representation to the Resource. representation, is_new = self._representation( - link.resource.url, i_want_an_epub.content_type, - "Dummy content", mirrored=True + link.resource.url, + i_want_an_epub.content_type, + "Dummy content", + mirrored=True, ) link.resource.representation = representation # We can finally fulfill a loan. - result = self.circulation.fulfill_open_access( - self.pool, broken_lpdm - ) + result = self.circulation.fulfill_open_access(self.pool, broken_lpdm) assert isinstance(result, FulfillmentInfo) assert result.content_link == link.resource.representation.public_url assert result.content_type == i_want_an_epub.content_type @@ -880,9 +938,7 @@ def test_fulfill_open_access(self): # Now, if we try to call fulfill() with the broken # LicensePoolDeliveryMechanism we get a result from the # working DeliveryMechanism with the same format. - result = self.circulation.fulfill( - self.patron, '1234', self.pool, broken_lpdm - ) + result = self.circulation.fulfill(self.patron, "1234", self.pool, broken_lpdm) assert isinstance(result, FulfillmentInfo) assert result.content_link == link.resource.representation.public_url assert result.content_type == i_want_an_epub.content_type @@ -891,9 +947,7 @@ def test_fulfill_open_access(self): # fulfill_open_access() is incorrectly written and passes in # the broken LicensePoolDeliveryMechanism (as opposed to its # generic DeliveryMechanism). - result = self.circulation.fulfill_open_access( - self.pool, broken_lpdm - ) + result = self.circulation.fulfill_open_access(self.pool, broken_lpdm) assert isinstance(result, FulfillmentInfo) assert result.content_link == link.resource.representation.public_url assert result.content_type == i_want_an_epub.content_type @@ -902,17 +956,20 @@ def test_fulfill_open_access(self): # media type than the one we're asking for, we're back to # FormatNotAvailable errors. irrelevant_delivery_mechanism, ignore = DeliveryMechanism.lookup( - self._db, "application/some-other-type", - DeliveryMechanism.NO_DRM + self._db, "application/some-other-type", DeliveryMechanism.NO_DRM ) working_lpdm.delivery_mechanism = irrelevant_delivery_mechanism - pytest.raises(FormatNotAvailable, self.circulation.fulfill_open_access, - self.pool, i_want_an_epub) + pytest.raises( + FormatNotAvailable, + self.circulation.fulfill_open_access, + self.pool, + i_want_an_epub, + ) def test_fulfilment_of_unlimited_access_book_succeeds(self): """Ensure that unlimited access books that don't belong to collections - having a custom CirculationAPI implementation (e.g., OPDS 1.x, OPDS 2.x collections) - are fulfilled in the same way as OA and self-hosted books.""" + having a custom CirculationAPI implementation (e.g., OPDS 1.x, OPDS 2.x collections) + are fulfilled in the same way as OA and self-hosted books.""" # Reset the API map, this book belongs to the "basic" collection, # i.e. collection without a custom CirculationAPI implementation. self.circulation.api_for_license_pool = MagicMock(return_value=None) @@ -924,9 +981,7 @@ def test_fulfilment_of_unlimited_access_book_succeeds(self): # Create a borrow link. link, _ = self.pool.identifier.add_link( - Hyperlink.BORROW, - self._url, - self.pool.data_source + Hyperlink.BORROW, self._url, self.pool.data_source ) # Create a license pool delivery mechanism. @@ -939,10 +994,7 @@ def test_fulfilment_of_unlimited_access_book_succeeds(self): # Create a representation. representation, _ = self._representation( - link.resource.url, - media_type, - "Dummy content", - mirrored=True + link.resource.url, media_type, "Dummy content", mirrored=True ) link.resource.representation = representation @@ -950,10 +1002,7 @@ def test_fulfilment_of_unlimited_access_book_succeeds(self): self.pool.loan_to(self.patron) result = self.circulation.fulfill( - self.patron, - '1234', - self.pool, - self.pool.delivery_mechanisms[0] + self.patron, "1234", self.pool, self.pool.delivery_mechanisms[0] ) # The fulfillment looks good. @@ -969,16 +1018,16 @@ def test_fulfill(self): fulfillment.content_link = None self.remote.queue_fulfill(fulfillment) - result = self.circulation.fulfill(self.patron, '1234', self.pool, - self.pool.delivery_mechanisms[0]) + result = self.circulation.fulfill( + self.patron, "1234", self.pool, self.pool.delivery_mechanisms[0] + ) # The fulfillment looks good. assert fulfillment == result # An analytics event was created. assert 1 == self.analytics.count - assert (CirculationEvent.CM_FULFILL == - self.analytics.event_type) + assert CirculationEvent.CM_FULFILL == self.analytics.event_type def test_fulfill_without_loan(self): # By default, a title cannot be fulfilled unless there is an active @@ -991,7 +1040,7 @@ def test_fulfill_without_loan(self): def try_to_fulfill(): # Note that we're passing None for `patron`. return self.circulation.fulfill( - None, '1234', self.pool, self.pool.delivery_mechanisms[0] + None, "1234", self.pool, self.pool.delivery_mechanisms[0] ) pytest.raises(NoActiveLoan, try_to_fulfill) @@ -1000,15 +1049,18 @@ def try_to_fulfill(): # okay, the title will be fulfilled anyway. def yes_we_can(*args, **kwargs): return True + self.circulation.can_fulfill_without_loan = yes_we_can result = try_to_fulfill() assert fulfillment == result - @parameterized.expand([ - ('open_access', True, False), - ('self_hosted', False, True), - ('neither', False, False), - ]) + @parameterized.expand( + [ + ("open_access", True, False), + ("self_hosted", False, True), + ("neither", False, False), + ] + ) def test_revoke_loan(self, _, open_access=False, self_hosted=False): self.pool.open_access = open_access self.pool.self_hosted = self_hosted @@ -1017,7 +1069,7 @@ def test_revoke_loan(self, _, open_access=False, self_hosted=False): self.pool.loan_to(self.patron) self.remote.queue_checkin(True) - result = self.circulation.revoke_loan(self.patron, '1234', self.pool) + result = self.circulation.revoke_loan(self.patron, "1234", self.pool) assert True == result # The patron's loan activity is now out of sync. @@ -1025,13 +1077,9 @@ def test_revoke_loan(self, _, open_access=False, self_hosted=False): # An analytics event was created. assert 1 == self.analytics.count - assert (CirculationEvent.CM_CHECKIN == - self.analytics.event_type) + assert CirculationEvent.CM_CHECKIN == self.analytics.event_type - @parameterized.expand([ - ('open_access', True, False), - ('self_hosted', False, True) - ]) + @parameterized.expand([("open_access", True, False), ("self_hosted", False, True)]) def test_release_hold(self, _, open_access=False, self_hosted=False): self.pool.open_access = open_access self.pool.self_hosted = self_hosted @@ -1040,7 +1088,7 @@ def test_release_hold(self, _, open_access=False, self_hosted=False): self.pool.on_hold_to(self.patron) self.remote.queue_release_hold(True) - result = self.circulation.release_hold(self.patron, '1234', self.pool) + result = self.circulation.release_hold(self.patron, "1234", self.pool) assert True == result # The patron's loan activity is now out of sync. @@ -1048,8 +1096,7 @@ def test_release_hold(self, _, open_access=False, self_hosted=False): # An analytics event was created. assert 1 == self.analytics.count - assert (CirculationEvent.CM_HOLD_RELEASE == - self.analytics.event_type) + assert CirculationEvent.CM_HOLD_RELEASE == self.analytics.event_type def test__collect_event(self): # Test the _collect_event method, which gathers information @@ -1060,9 +1107,7 @@ def __init__(self): self.events = [] def collect_event(self, library, licensepool, name, neighborhood): - self.events.append( - (library, licensepool, name, neighborhood) - ) + self.events.append((library, licensepool, name, neighborhood)) return True analytics = MockAnalytics() @@ -1094,25 +1139,16 @@ def assert_event(inp, outp): # Worst case scenario -- the only information we can find is # the Library associated with the CirculationAPI object itself. - assert_event( - (None, None, 'event'), - (l1, None, 'event', None) - ) + assert_event((None, None, "event"), (l1, None, "event", None)) # If a LicensePool is provided, it's passed right through # to Analytics.collect_event. - assert_event( - (None, lp2, 'event'), - (l1, lp2, 'event', None) - ) + assert_event((None, lp2, "event"), (l1, lp2, "event", None)) # If a Patron is provided, their Library takes precedence over # the Library associated with the CirculationAPI (though this # shouldn't happen). - assert_event( - (p2, None, 'event'), - (l2, None, 'event', None) - ) + assert_event((p2, None, "event"), (l2, None, "event", None)) # We must run the rest of the tests in a simulated Flask request # context. @@ -1122,20 +1158,14 @@ def assert_event(inp, outp): # associated with the CirculationAPI (though this # shouldn't happen). flask.request.library = l2 - assert_event( - (None, None, 'event'), - (l2, None, 'event', None) - ) + assert_event((None, None, "event"), (l2, None, "event", None)) with app.test_request_context(): # The library of the request patron also takes precedence # over both (though again, this shouldn't happen). flask.request.library = l1 flask.request.patron = p2 - assert_event( - (None, None, 'event'), - (l2, None, 'event', None) - ) + assert_event((None, None, "event"), (l2, None, "event", None)) # Now let's check neighborhood gathering. p2.neighborhood = "Compton" @@ -1143,39 +1173,24 @@ def assert_event(inp, outp): # Neighborhood is only gathered if we explicitly ask for # it. flask.request.patron = p2 - assert_event( - (p2, None, 'event'), - (l2, None, 'event', None) - ) - assert_event( - (p2, None, 'event', False), - (l2, None, 'event', None) - ) - assert_event( - (p2, None, 'event', True), - (l2, None, 'event', "Compton") - ) + assert_event((p2, None, "event"), (l2, None, "event", None)) + assert_event((p2, None, "event", False), (l2, None, "event", None)) + assert_event((p2, None, "event", True), (l2, None, "event", "Compton")) # Neighborhood is not gathered if the request's active # patron is not the patron who triggered the event. - assert_event( - (p1, None, 'event', True), - (l1, None, 'event', None) - ) + assert_event((p1, None, "event", True), (l1, None, "event", None)) with app.test_request_context(): # Even if we ask for it, neighborhood is not gathered if # the data isn't available. flask.request.patron = p1 - assert_event( - (p1, None, 'event', True), - (l1, None, 'event', None) - ) + assert_event((p1, None, "event", True), (l1, None, "event", None)) # Finally, remove the mock Analytics object entirely and # verify that calling _collect_event doesn't cause a crash. api.analytics = None - api._collect_event(p1, None, 'event') + api._collect_event(p1, None, "event") def test_sync_bookshelf_ignores_local_loan_with_no_identifier(self): loan, ignore = self.pool.loan_to(self.patron) @@ -1212,7 +1227,9 @@ def test_sync_bookshelf_ignores_local_hold_with_no_identifier(self): # But we can still sync without crashing. self.sync_bookshelf() - def test_sync_bookshelf_with_old_local_loan_and_no_remote_loan_deletes_local_loan(self): + def test_sync_bookshelf_with_old_local_loan_and_no_remote_loan_deletes_local_loan( + self, + ): # Local loan that was created yesterday. loan, ignore = self.pool.loan_to(self.patron) loan.start = self.YESTERDAY @@ -1227,7 +1244,9 @@ def test_sync_bookshelf_with_old_local_loan_and_no_remote_loan_deletes_local_loa loans = self._db.query(Loan).all() assert [] == loans - def test_sync_bookshelf_with_new_local_loan_and_no_remote_loan_keeps_local_loan(self): + def test_sync_bookshelf_with_new_local_loan_and_no_remote_loan_keeps_local_loan( + self, + ): # Local loan that was just created. loan, ignore = self.pool.loan_to(self.patron) loan.start = utc_now() @@ -1254,8 +1273,10 @@ def patron_activity(self, patron, pin): return [], [], False circulation = IncompleteCirculationAPI( - self._db, self._default_library, - api_map={ExternalIntegration.BIBLIOTHECA : MockBibliothecaAPI}) + self._db, + self._default_library, + api_map={ExternalIntegration.BIBLIOTHECA: MockBibliothecaAPI}, + ) circulation.sync_bookshelf(self.patron, "1234") # The loan is still in the db, since there was an @@ -1276,8 +1297,10 @@ def patron_activity(self, patron, pin): return [], [], True circulation = CompleteCirculationAPI( - self._db, self._default_library, - api_map={ExternalIntegration.BIBLIOTHECA : MockBibliothecaAPI}) + self._db, + self._default_library, + api_map={ExternalIntegration.BIBLIOTHECA: MockBibliothecaAPI}, + ) circulation.sync_bookshelf(self.patron, "1234") # Now the loan is gone. @@ -1287,7 +1310,7 @@ def patron_activity(self, patron, pin): # Since we know our picture of the patron's bookshelf is up-to-date, # patron.last_loan_activity_sync has been set to the current time. now = utc_now() - assert (now-self.patron.last_loan_activity_sync).total_seconds() < 2 + assert (now - self.patron.last_loan_activity_sync).total_seconds() < 2 def test_sync_bookshelf_updates_local_loan_and_hold_with_modified_timestamps(self): # We have a local loan that supposedly runs from yesterday @@ -1299,15 +1322,20 @@ def test_sync_bookshelf_updates_local_loan_and_hold_with_modified_timestamps(sel # But the remote thinks the loan runs from today until two # weeks from today. self.circulation.add_remote_loan( - self.pool.collection, self.pool.data_source, self.identifier.type, - self.identifier.identifier, self.TODAY, self.IN_TWO_WEEKS + self.pool.collection, + self.pool.data_source, + self.identifier.type, + self.identifier.identifier, + self.TODAY, + self.IN_TWO_WEEKS, ) # Similar situation for this hold on a different LicensePool. edition, pool2 = self._edition( data_source_name=DataSource.BIBLIOTHECA, identifier_type=Identifier.BIBLIOTHECA_ID, - with_license_pool=True, collection=self.collection + with_license_pool=True, + collection=self.collection, ) hold, ignore = pool2.on_hold_to(self.patron) @@ -1316,9 +1344,13 @@ def test_sync_bookshelf_updates_local_loan_and_hold_with_modified_timestamps(sel hold.position = 10 self.circulation.add_remote_hold( - pool2.collection, pool2.data_source, pool2.identifier.type, - pool2.identifier.identifier, self.TODAY, self.IN_TWO_WEEKS, - 0 + pool2.collection, + pool2.data_source, + pool2.identifier.type, + pool2.identifier.identifier, + self.TODAY, + self.IN_TWO_WEEKS, + 0, ) self.circulation.sync_bookshelf(self.patron, "1234") @@ -1340,12 +1372,13 @@ def test_sync_bookshelf_applies_locked_delivery_mechanism_to_loan(self): ) pool = self._licensepool(None) self.circulation.add_remote_loan( - pool.collection, pool.data_source.name, + pool.collection, + pool.data_source.name, pool.identifier.type, pool.identifier.identifier, utc_now(), None, - locked_to=mechanism + locked_to=mechanism, ) self.circulation.sync_bookshelf(self.patron, "1234") @@ -1368,9 +1401,12 @@ def test_sync_bookshelf_respects_last_loan_activity_sync(self): # Little do we know that they just used a vendor website to # create a loan. self.circulation.add_remote_loan( - self.pool.collection, self.pool.data_source, - self.identifier.type, self.identifier.identifier, - self.YESTERDAY, self.IN_TWO_WEEKS + self.pool.collection, + self.pool.data_source, + self.identifier.type, + self.identifier.identifier, + self.YESTERDAY, + self.IN_TWO_WEEKS, ) # Syncing our loans with the remote won't actually do anything. @@ -1389,7 +1425,7 @@ def test_sync_bookshelf_respects_last_loan_activity_sync(self): # Once that happens, patron.last_loan_activity_sync is updated to # the current time. updated = self.patron.last_loan_activity_sync - assert (updated-now).total_seconds() < 2 + assert (updated - now).total_seconds() < 2 # It's also possible to force a sync even when one wouldn't # normally happen, by passing force=True into sync_bookshelf. @@ -1408,9 +1444,10 @@ def test_sync_bookshelf_respects_last_loan_activity_sync(self): def test_patron_activity(self): # Get a CirculationAPI that doesn't mock out its API's patron activity. circulation = CirculationAPI( - self._db, self._default_library, api_map={ - ExternalIntegration.BIBLIOTHECA : MockBibliothecaAPI - }) + self._db, + self._default_library, + api_map={ExternalIntegration.BIBLIOTHECA: MockBibliothecaAPI}, + ) mock_bibliotheca = circulation.api_for_collection[self.collection.id] data = sample_data("checkouts.xml", "bibliotheca") mock_bibliotheca.queue_response(200, content=data) @@ -1431,6 +1468,7 @@ def test_can_fulfill_without_loan(self): """Can a title can be fulfilled without an active loan? It depends on the BaseCirculationAPI implementation for that title's colelction. """ + class Mock(BaseCirculationAPI): def can_fulfill_without_loan(self, patron, pool, lpdm): return "yep" @@ -1438,9 +1476,7 @@ def can_fulfill_without_loan(self, patron, pool, lpdm): pool = self._licensepool(None) circulation = CirculationAPI(self._db, self._default_library) circulation.api_for_collection[pool.collection.id] = Mock() - assert ( - "yep" == - circulation.can_fulfill_without_loan(None, pool, object())) + assert "yep" == circulation.can_fulfill_without_loan(None, pool, object()) # If format data is missing or the BaseCirculationAPI cannot # be found, we assume the title cannot be fulfilled. @@ -1456,14 +1492,12 @@ def can_fulfill_without_loan(self, patron, pool, lpdm): class TestBaseCirculationAPI(DatabaseTest): - def test_default_notification_email_address(self): # Test the ability to get the default notification email address # for a patron or a library. self._default_library.setting( - Configuration.DEFAULT_NOTIFICATION_EMAIL_ADDRESS).value = ( - "help@library" - ) + Configuration.DEFAULT_NOTIFICATION_EMAIL_ADDRESS + ).value = "help@library" m = BaseCirculationAPI.default_notification_email_address assert "help@library" == m(self._default_library, None) assert "help@library" == m(self._patron(), None) @@ -1501,9 +1535,7 @@ def _library_authenticator(self, library): # capable of providing an address. assert None == api.patron_email_address(patron) assert patron.library == api._library_authenticator_called_with - assert isinstance( - api._library_authenticator_returned, LibraryAuthenticator - ) + assert isinstance(api._library_authenticator_returned, LibraryAuthenticator) # Now we're going to pass in our own LibraryAuthenticator, # which we've populated with mock authentication providers, @@ -1516,16 +1548,16 @@ def _library_authenticator(self, library): # so it's no help. class MockOAuth(object): NAME = "mock oauth" + def remote_patron_lookup(self, patron): self.called_with = patron raise NotImplementedError() + mock_oauth = MockOAuth() authenticator.register_oauth_provider(mock_oauth) - assert ( - None == - api.patron_email_address( - patron, library_authenticator=authenticator - )) + assert None == api.patron_email_address( + patron, library_authenticator=authenticator + ) # But we can verify that remote_patron_lookup was in fact # called. assert patron == mock_oauth.called_with @@ -1537,13 +1569,12 @@ class MockBasic(object): def remote_patron_lookup(self, patron): self.called_with = patron return PatronData(authorization_identifier="patron") + basic = MockBasic() authenticator.register_basic_auth_provider(basic) - assert ( - None == - api.patron_email_address( - patron, library_authenticator=authenticator - )) + assert None == api.patron_email_address( + patron, library_authenticator=authenticator + ) assert patron == basic.called_with # This basic authentication provider gives us the information @@ -1552,26 +1583,22 @@ class MockBasic(object): def remote_patron_lookup(self, patron): self.called_with = patron return PatronData(email_address="me@email") + basic = MockBasic() authenticator.basic_auth_provider = basic - assert ( - "me@email" == - api.patron_email_address( - patron, library_authenticator=authenticator - )) + assert "me@email" == api.patron_email_address( + patron, library_authenticator=authenticator + ) assert patron == basic.called_with class TestDeliveryMechanismInfo(DatabaseTest): - def test_apply(self): # Here's a LicensePool with one non-open-access delivery mechanism. pool = self._licensepool(None) assert False == pool.open_access - [mechanism] = [ - lpdm.delivery_mechanism for lpdm in pool.delivery_mechanisms - ] + [mechanism] = [lpdm.delivery_mechanism for lpdm in pool.delivery_mechanisms] assert Representation.EPUB_MEDIA_TYPE == mechanism.content_type assert DeliveryMechanism.ADOBE_DRM == mechanism.drm_scheme @@ -1591,7 +1618,8 @@ def test_apply(self): # This results in the addition of a new delivery mechanism to # the LicensePool. [new_mechanism] = [ - lpdm.delivery_mechanism for lpdm in pool.delivery_mechanisms + lpdm.delivery_mechanism + for lpdm in pool.delivery_mechanisms if lpdm.delivery_mechanism != mechanism ] assert Representation.PDF_MEDIA_TYPE == new_mechanism.content_type @@ -1607,20 +1635,25 @@ def test_apply(self): # real life, it's possible for this operation to reveal a new # *open-access* delivery mechanism for a LicensePool. link, new = pool.identifier.add_link( - Hyperlink.OPEN_ACCESS_DOWNLOAD, self._url, - pool.data_source, Representation.EPUB_MEDIA_TYPE + Hyperlink.OPEN_ACCESS_DOWNLOAD, + self._url, + pool.data_source, + Representation.EPUB_MEDIA_TYPE, ) info = DeliveryMechanismInfo( - Representation.EPUB_MEDIA_TYPE, DeliveryMechanism.NO_DRM, - RightsStatus.CC0, link.resource + Representation.EPUB_MEDIA_TYPE, + DeliveryMechanism.NO_DRM, + RightsStatus.CC0, + link.resource, ) # Calling apply() on the loan we were using before will update # its associated LicensePoolDeliveryMechanism. info.apply(loan) [oa_lpdm] = [ - lpdm for lpdm in pool.delivery_mechanisms + lpdm + for lpdm in pool.delivery_mechanisms if lpdm.delivery_mechanism not in (mechanism, new_mechanism) ] assert oa_lpdm == loan.fulfillment @@ -1636,9 +1669,7 @@ def test_apply(self): class TestConfigurationFailures(DatabaseTest): - class MisconfiguredAPI(object): - def __init__(self, _db, collection): raise CannotLoadConfiguration("doomed!") @@ -1647,10 +1678,8 @@ def test_configuration_exception_is_stored(self): # CannotLoadConfiguration, the exception is stored with the # CirculationAPI rather than being propagated. - api_map = {self._default_collection.protocol : self.MisconfiguredAPI} - circulation = CirculationAPI( - self._db, self._default_library, api_map=api_map - ) + api_map = {self._default_collection.protocol: self.MisconfiguredAPI} + circulation = CirculationAPI(self._db, self._default_library, api_map=api_map) # Although the CirculationAPI was created, it has no functioning # APIs. @@ -1664,14 +1693,12 @@ def test_configuration_exception_is_stored(self): class TestFulfillmentInfo(DatabaseTest): - def test_as_response(self): # The default behavior of as_response is to do nothing # and let controller code turn the FulfillmentInfo # into a Flask Response. info = FulfillmentInfo( - self._default_collection, None, - None, None, None, None, None, None + self._default_collection, None, None, None, None, None, None, None ) assert None == info.as_response @@ -1686,6 +1713,7 @@ class MockAPIAwareFulfillmentInfo(APIAwareFulfillmentInfo): """An APIAwareFulfillmentInfo that implements do_fetch() by delegating to its API object. """ + def do_fetch(self): return self.api.do_fetch() @@ -1693,6 +1721,7 @@ class MockAPI(object): """An API class that sets a flag when do_fetch() is called. """ + def __init__(self, collection): self.collection = collection self.fetch_happened = False @@ -1715,8 +1744,11 @@ def make_info(self, api=None): # Create a MockAPIAwareFulfillmentInfo with # well-known mock values for its properties. return self.MockAPIAwareFulfillmentInfo( - api, self.mock_data_source_name, self.mock_identifier_type, - self.mock_identifier, self.mock_key + api, + self.mock_data_source_name, + self.mock_identifier_type, + self.mock_identifier, + self.mock_key, ) def test_constructor(self): diff --git a/tests/test_config.py b/tests/test_config.py index 136aadd6ec..fc356ce3ae 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -1,38 +1,34 @@ -from Crypto.PublicKey import RSA -from Crypto.Cipher import PKCS1_OAEP - -from collections import Counter import json +from collections import Counter +from Crypto.Cipher import PKCS1_OAEP +from Crypto.PublicKey import RSA + +from api.config import Configuration from core.config import Configuration as CoreConfiguration -from core.model import ( - ConfigurationSetting -) +from core.model import ConfigurationSetting from core.testing import DatabaseTest -from api.config import Configuration -class TestConfiguration(DatabaseTest): +class TestConfiguration(DatabaseTest): def test_key_pair(self): # Test the ability to create, replace, or look up a # public/private key pair in a ConfigurationSetting. - setting = ConfigurationSetting.sitewide( - self._db, Configuration.KEY_PAIR - ) + setting = ConfigurationSetting.sitewide(self._db, Configuration.KEY_PAIR) setting.value = "nonsense" # If you pass in a ConfigurationSetting that is missing its # value, or whose value is not a public key pair, a new key # pair is created. public_key, private_key = Configuration.key_pair(setting) - assert 'BEGIN PUBLIC KEY' in public_key - assert 'BEGIN RSA PRIVATE KEY' in private_key + assert "BEGIN PUBLIC KEY" in public_key + assert "BEGIN RSA PRIVATE KEY" in private_key assert [public_key, private_key] == setting.json_value setting.value = None public_key, private_key = Configuration.key_pair(setting) - assert 'BEGIN PUBLIC KEY' in public_key - assert 'BEGIN RSA PRIVATE KEY' in private_key + assert "BEGIN PUBLIC KEY" in public_key + assert "BEGIN RSA PRIVATE KEY" in private_key assert [public_key, private_key] == setting.json_value # If the setting has a good value already, the key pair is @@ -66,9 +62,11 @@ def test_collection_language_method_performs_estimate(self): library = self._default_library # We haven't set any of these values. - for key in [C.LARGE_COLLECTION_LANGUAGES, - C.SMALL_COLLECTION_LANGUAGES, - C.TINY_COLLECTION_LANGUAGES]: + for key in [ + C.LARGE_COLLECTION_LANGUAGES, + C.SMALL_COLLECTION_LANGUAGES, + C.TINY_COLLECTION_LANGUAGES, + ]: assert None == ConfigurationSetting.for_library(key, library).value # So how does this happen? @@ -84,10 +82,18 @@ def test_collection_language_method_performs_estimate(self): C.LARGE_COLLECTION_LANGUAGES, library ) assert ["eng"] == large_setting.json_value - assert [] == ConfigurationSetting.for_library( - C.SMALL_COLLECTION_LANGUAGES, library).json_value - assert [] == ConfigurationSetting.for_library( - C.TINY_COLLECTION_LANGUAGES, library).json_value + assert ( + [] + == ConfigurationSetting.for_library( + C.SMALL_COLLECTION_LANGUAGES, library + ).json_value + ) + assert ( + [] + == ConfigurationSetting.for_library( + C.TINY_COLLECTION_LANGUAGES, library + ).json_value + ) # We can change these values. large_setting.value = json.dumps(["spa", "jpn"]) @@ -107,14 +113,13 @@ def test_estimate_language_collection_for_library(self): # We thought we'd have big collections. old_settings = { - Configuration.LARGE_COLLECTION_LANGUAGES : ["spa", "fre"], - Configuration.SMALL_COLLECTION_LANGUAGES : ["chi"], - Configuration.TINY_COLLECTION_LANGUAGES : ["rus"], + Configuration.LARGE_COLLECTION_LANGUAGES: ["spa", "fre"], + Configuration.SMALL_COLLECTION_LANGUAGES: ["chi"], + Configuration.TINY_COLLECTION_LANGUAGES: ["rus"], } for key, value in list(old_settings.items()): - ConfigurationSetting.for_library( - key, library).value = json.dumps(value) + ConfigurationSetting.for_library(key, library).value = json.dumps(value) # But there's nothing in our database, so when we call # Configuration.estimate_language_collections_for_library... @@ -125,13 +130,19 @@ def test_estimate_language_collection_for_library(self): Configuration.LARGE_COLLECTION_LANGUAGES, library ).json_value - assert [] == ConfigurationSetting.for_library( - Configuration.SMALL_COLLECTION_LANGUAGES, library - ).json_value + assert ( + [] + == ConfigurationSetting.for_library( + Configuration.SMALL_COLLECTION_LANGUAGES, library + ).json_value + ) - assert [] == ConfigurationSetting.for_library( - Configuration.TINY_COLLECTION_LANGUAGES, library - ).json_value + assert ( + [] + == ConfigurationSetting.for_library( + Configuration.TINY_COLLECTION_LANGUAGES, library + ).json_value + ) def test_classify_holdings(self): @@ -149,10 +160,10 @@ def test_classify_holdings(self): # Otherwise, the classification of a collection depends on the # sheer number of items in that collection. Within a # classification, languages are ordered by holding size. - different_sizes = Counter(jpn=16000, fre=20000, spa=8000, - nav=6, ukr=4000, ira=1500) - assert ([['fre', 'jpn'], ['spa', 'ukr', 'ira'], ['nav']] == - m(different_sizes)) + different_sizes = Counter( + jpn=16000, fre=20000, spa=8000, nav=6, ukr=4000, ira=1500 + ) + assert [["fre", "jpn"], ["spa", "ukr", "ira"], ["nav"]] == m(different_sizes) def test_max_outstanding_fines(self): m = Configuration.max_outstanding_fines @@ -163,8 +174,7 @@ def test_max_outstanding_fines(self): # The maximum fine value is determined by this # ConfigurationSetting. setting = ConfigurationSetting.for_library( - Configuration.MAX_OUTSTANDING_FINES, - self._default_library + Configuration.MAX_OUTSTANDING_FINES, self._default_library ) # Any amount of fines is too much. @@ -180,4 +190,6 @@ def test_max_outstanding_fines(self): def test_default_opds_format(self): # Initializing the Configuration object modifies the corresponding # object in core, so that core code will behave appropriately. - assert Configuration.DEFAULT_OPDS_FORMAT == CoreConfiguration.DEFAULT_OPDS_FORMAT + assert ( + Configuration.DEFAULT_OPDS_FORMAT == CoreConfiguration.DEFAULT_OPDS_FORMAT + ) diff --git a/tests/test_controller.py b/tests/test_controller.py index a329019af1..ab7fd9157e 100644 --- a/tests/test_controller.py +++ b/tests/test_controller.py @@ -6,6 +6,7 @@ import os import random import time +import urllib.parse from contextlib import contextmanager from decimal import Decimal from time import mktime @@ -13,7 +14,6 @@ import feedparser import flask -import urllib.parse import pytest from flask import Response as FlaskResponse from flask import url_for @@ -128,35 +128,28 @@ from core.testing import DummyHTTPClient, MockRequestsResponse from core.user_profile import ProfileController, ProfileStorage from core.util.authentication_for_opds import AuthenticationForOPDSDocument -from core.util.datetime_helpers import ( - datetime_utc, - from_timestamp, - utc_now, -) +from core.util.datetime_helpers import datetime_utc, from_timestamp, utc_now from core.util.flask_util import Response from core.util.http import RemoteIntegrationException from core.util.opds_writer import OPDSFeed from core.util.problem_detail import ProblemDetail from core.util.string_helpers import base64 + class ControllerTest(VendorIDTest): """A test that requires a functional app server.""" # Authorization headers that will succeed (or fail) against the # SimpleAuthenticationProvider set up in ControllerTest.setup(). - valid_auth = 'Basic ' + base64.b64encode( - 'unittestuser:unittestpassword' - ) - invalid_auth = 'Basic ' + base64.b64encode('user1:password2') - valid_credentials = dict( - username="unittestuser", password="unittestpassword" - ) + valid_auth = "Basic " + base64.b64encode("unittestuser:unittestpassword") + invalid_auth = "Basic " + base64.b64encode("user1:password2") + valid_credentials = dict(username="unittestuser", password="unittestpassword") def setup_method(self): super(ControllerTest, self).setup_method() self.app = app - if not hasattr(self, 'setup_circulation_manager'): + if not hasattr(self, "setup_circulation_manager"): self.setup_circulation_manager = True # PRESERVE_CONTEXT_ON_EXCEPTION needs to be off in tests @@ -165,10 +158,10 @@ def setup_method(self): # from previous tests would cause flask to roll back the db # when you entered a new request context, deleting rows that # were created in the test setup. - app.config['PRESERVE_CONTEXT_ON_EXCEPTION'] = False + app.config["PRESERVE_CONTEXT_ON_EXCEPTION"] = False Configuration.instance[Configuration.INTEGRATIONS][ExternalIntegration.CDN] = { - "" : "http://cdn" + "": "http://cdn" } if self.setup_circulation_manager: @@ -180,7 +173,7 @@ def setup_method(self): def set_base_url(self, _db): base_url = ConfigurationSetting.sitewide(_db, Configuration.BASE_URL_KEY) - base_url.value = 'http://test-circulation-manager/' + base_url.value = "http://test-circulation-manager/" def circulation_manager_setup(self, _db): """Set up initial Library arrangements for this test. @@ -208,8 +201,7 @@ def circulation_manager_setup(self, _db): """ self.libraries = self.make_default_libraries(_db) self.collections = [ - self.make_default_collection(_db, library) - for library in self.libraries + self.make_default_collection(_db, library) for library in self.libraries ] self.default_patrons = {} @@ -227,24 +219,19 @@ def circulation_manager_setup(self, _db): self.authdata = AuthdataUtility.from_config(self.library) - self.manager = CirculationManager( - _db, testing=True - ) + self.manager = CirculationManager(_db, testing=True) # Set CirculationAPI and top-level lane for the default # library, for convenience in tests. - self.manager.d_circulation = self.manager.circulation_apis[ - self.library.id - ] - self.manager.d_top_level_lane = self.manager.top_level_lanes[ - self.library.id - ] + self.manager.d_circulation = self.manager.circulation_apis[self.library.id] + self.manager.d_top_level_lane = self.manager.top_level_lanes[self.library.id] self.controller = CirculationManagerController(self.manager) # Set a convenient default lane. [self.english_adult_fiction] = [ - x for x in self.library.lanes - if x.display_name=='Fiction' and x.languages==['eng'] + x + for x in self.library.lanes + if x.display_name == "Fiction" and x.languages == ["eng"] ] return self.manager @@ -254,24 +241,29 @@ def library_setup(self, library): _db = Session.object_session(library) # Create the patron used by the dummy authentication mechanism. default_patron, ignore = get_one_or_create( - _db, Patron, + _db, + Patron, library=library, authorization_identifier="unittestuser", - create_method_kwargs=dict( - external_identifier="unittestuser" - ) + create_method_kwargs=dict(external_identifier="unittestuser"), ) self.default_patrons[library] = default_patron # Create a simple authentication integration for this library, # unless it already has a way to authenticate patrons # (in which case we would just screw things up). - if not any([x for x in library.integrations if x.goal== - ExternalIntegration.PATRON_AUTH_GOAL]): + if not any( + [ + x + for x in library.integrations + if x.goal == ExternalIntegration.PATRON_AUTH_GOAL + ] + ): integration, ignore = create( - _db, ExternalIntegration, + _db, + ExternalIntegration, protocol="api.simple_authentication", - goal=ExternalIntegration.PATRON_AUTH_GOAL + goal=ExternalIntegration.PATRON_AUTH_GOAL, ) p = SimpleAuthenticationProvider integration.setting(p.TEST_IDENTIFIER).value = "unittestuser" @@ -280,9 +272,9 @@ def library_setup(self, library): library.integrations.append(integration) for k, v in [ - (Configuration.LARGE_COLLECTION_LANGUAGES, []), - (Configuration.SMALL_COLLECTION_LANGUAGES, ['eng']), - (Configuration.TINY_COLLECTION_LANGUAGES, ['spa','chi','fre']) + (Configuration.LARGE_COLLECTION_LANGUAGES, []), + (Configuration.SMALL_COLLECTION_LANGUAGES, ["eng"]), + (Configuration.TINY_COLLECTION_LANGUAGES, ["spa", "chi", "fre"]), ]: ConfigurationSetting.for_library(k, library).value = json.dumps(v) create_default_lanes(_db, library) @@ -295,8 +287,8 @@ def make_default_collection(self, _db, library): @contextmanager def request_context_with_library(self, route, *args, **kwargs): - if 'library' in kwargs: - library = kwargs.pop('library') + if "library" in kwargs: + library = kwargs.pop("library") else: library = self._default_library with self.app.test_request_context(route, *args, **kwargs) as c: @@ -316,8 +308,13 @@ def setup_method(self): super(CirculationControllerTest, self).setup_method() self.works = [] for (variable_name, title, author, language, fiction) in self.BOOKS: - work = self._work(title, author, language=language, fiction=fiction, - with_open_access_download=True) + work = self._work( + title, + author, + language=language, + fiction=fiction, + with_open_access_download=True, + ) setattr(self, variable_name, work) work.license_pools[0].collection = self.collection self.works.append(work) @@ -325,9 +322,7 @@ def setup_method(self): # Enable the audiobook entry point for the default library -- a lot of # tests verify that non-default entry points can be selected. - self._default_library.setting( - EntryPoint.ENABLED_SETTING - ).value = json.dumps( + self._default_library.setting(EntryPoint.ENABLED_SETTING).value = json.dumps( [EbooksEntryPoint.INTERNAL_NAME, AudiobooksEntryPoint.INTERNAL_NAME] ) @@ -345,10 +340,13 @@ def assert_bad_search_index_gives_problem_detail(self, test_function): response = test_function() assert 502 == response.status_code assert ( - "http://librarysimplified.org/terms/problem/remote-integration-failed" == - response.uri) - assert ('The search index for this site is not properly configured.' == - response.detail) + "http://librarysimplified.org/terms/problem/remote-integration-failed" + == response.uri + ) + assert ( + "The search index for this site is not properly configured." + == response.detail + ) self.manager.setup_external_search = old_setup self.manager._external_search = old_value @@ -362,8 +360,8 @@ def test_initialization(self): public, private = ConfigurationSetting.sitewide( self._db, Configuration.KEY_PAIR ).json_value - assert 'BEGIN PUBLIC KEY' in public - assert 'BEGIN RSA PRIVATE KEY' in private + assert "BEGIN PUBLIC KEY" in public + assert "BEGIN RSA PRIVATE KEY" in private def test_load_settings(self): # Here's a CirculationManager which we've been using for a while. @@ -405,24 +403,27 @@ def test_load_settings(self): # We also register a CustomIndexView for this new library. mock_custom_view = object() + @classmethod def mock_for_library(cls, incoming_library): if incoming_library == library: return mock_custom_view return None + old_for_library = CustomIndexView.for_library CustomIndexView.for_library = mock_for_library # We also set up some configuration settings that will # be loaded. ConfigurationSetting.sitewide( - self._db, Configuration.PATRON_WEB_HOSTNAMES).value = "http://sitewide/1234" + self._db, Configuration.PATRON_WEB_HOSTNAMES + ).value = "http://sitewide/1234" registry = self._external_integration( protocol="some protocol", goal=ExternalIntegration.DISCOVERY_GOAL ) ConfigurationSetting.for_library_and_externalintegration( - self._db, Registration.LIBRARY_REGISTRATION_WEB_CLIENT, - library, registry).value = "http://registration" + self._db, Registration.LIBRARY_REGISTRATION_WEB_CLIENT, library, registry + ).value = "http://registration" ConfigurationSetting.sitewide( self._db, Configuration.AUTHENTICATION_DOCUMENT_CACHE_TIME @@ -449,7 +450,7 @@ def mock_for_library(cls, incoming_library): # how to authenticate patrons of the new library. assert isinstance( manager.auth.library_authenticators[library.short_name], - LibraryAuthenticator + LibraryAuthenticator, ) # The ExternalSearch object has been reset. @@ -459,16 +460,19 @@ def mock_for_library(cls, incoming_library): assert isinstance(manager.oauth_controller, OAuthController) # So has the controller for the Device Management Protocol. - assert isinstance(manager.adobe_device_management, - DeviceManagementProtocolController) + assert isinstance( + manager.adobe_device_management, DeviceManagementProtocolController + ) # So has the SharecCollectionAPI. - assert isinstance(manager.shared_collection_api, - SharedCollectionAPI) + assert isinstance(manager.shared_collection_api, SharedCollectionAPI) # So have the patron web domains, and their paths have been # removed. - assert set(["http://sitewide", "http://registration"]) == manager.patron_web_domains + assert ( + set(["http://sitewide", "http://registration"]) + == manager.patron_web_domains + ) # The authentication document cache has been rebuilt with a # new max_age. @@ -483,23 +487,33 @@ def mock_for_library(cls, incoming_library): # The sitewide patron web domain can also be set to *. ConfigurationSetting.sitewide( - self._db, Configuration.PATRON_WEB_HOSTNAMES).value = "*" + self._db, Configuration.PATRON_WEB_HOSTNAMES + ).value = "*" self.manager.load_settings() assert set(["*", "http://registration"]) == manager.patron_web_domains # The sitewide patron web domain can have pipe separated domains, and will get spaces stripped ConfigurationSetting.sitewide( - self._db, Configuration.PATRON_WEB_HOSTNAMES).value = "https://1.com|http://2.com | http://subdomain.3.com|4.com" + self._db, Configuration.PATRON_WEB_HOSTNAMES + ).value = "https://1.com|http://2.com | http://subdomain.3.com|4.com" self.manager.load_settings() - assert set(["https://1.com", "http://2.com", "http://subdomain.3.com", "http://registration"]) == manager.patron_web_domains + assert ( + set( + [ + "https://1.com", + "http://2.com", + "http://subdomain.3.com", + "http://registration", + ] + ) + == manager.patron_web_domains + ) # Restore the CustomIndexView.for_library implementation CustomIndexView.for_library = old_for_library def test_exception_during_external_search_initialization_is_stored(self): - class BadSearch(CirculationManager): - @property def setup_search(self): raise Exception("doomed!") @@ -520,7 +534,8 @@ def test_exception_during_short_client_token_initialization_is_stored(self): # library. registry_integration = self._external_integration( protocol=ExternalIntegration.OPDS_REGISTRATION, - goal=ExternalIntegration.DISCOVERY_GOAL, libraries=[self.library] + goal=ExternalIntegration.DISCOVERY_GOAL, + libraries=[self.library], ) registry_integration.username = "something" registry_integration.set_setting(AuthdataUtility.VENDOR_ID_KEY, "vendorid") @@ -558,8 +573,8 @@ def test_sitewide_key_pair(self): # Calling sitewide_key_pair will create a new pair of keys. new_public, new_private = self.manager.sitewide_key_pair - assert 'BEGIN PUBLIC KEY' in new_public - assert 'BEGIN RSA PRIVATE KEY' in new_private + assert "BEGIN PUBLIC KEY" in new_public + assert "BEGIN RSA PRIVATE KEY" in new_private # The new values are stored in the appropriate # ConfigurationSetting. @@ -578,8 +593,10 @@ def test_annotator(self): facets = Facets.default(self._default_library) annotator = self.manager.annotator(lane, facets) assert isinstance(annotator, LibraryAnnotator) - assert (self.manager.circulation_apis[self._default_library.id] == - annotator.circulation) + assert ( + self.manager.circulation_apis[self._default_library.id] + == annotator.circulation + ) assert "All Books" == annotator.top_level_title() assert True == annotator.identifies_patrons @@ -607,13 +624,17 @@ class MockAnnotator(object): def __init__(self, *args, **kwargs): self.positional = args self.keyword = kwargs + annotator = self.manager.annotator( - lane, facets, "extra positional", - kw="extra keyword", annotator_class=MockAnnotator + lane, + facets, + "extra positional", + kw="extra keyword", + annotator_class=MockAnnotator, ) assert isinstance(annotator, MockAnnotator) - assert 'extra positional' == annotator.positional[-1] - assert 'extra keyword' == annotator.keyword.pop('kw') + assert "extra positional" == annotator.positional[-1] + assert "extra keyword" == annotator.keyword.pop("kw") # Now let's try more and more obscure ways of figuring out which # library should be used to build the LibraryAnnotator. @@ -657,6 +678,7 @@ class MockAdminSignInController(object): def authenticated_admin_from_request(self): return self.admin + admin = Admin() controller = MockAdminSignInController() @@ -767,6 +789,7 @@ def url_for(self, view, *args, **kwargs): # to see if it wants to disable caching. class MockFacets(BaseFacets): max_cache_age = None + kwargs_with_facets = dict(kwargs) kwargs_with_facets.update(_facets=MockFacets) url = manager.cdn_url_for("view", *args, **kwargs_with_facets) @@ -792,7 +815,6 @@ class MockFacets(BaseFacets): class TestBaseController(CirculationControllerTest): - def test_unscoped_session(self): """Compare to TestScopedSession.test_scoped_session to see @@ -822,10 +844,13 @@ def test_request_patron(self): # If not, authenticated_patron_from_request is called; it's # supposed to set flask.request.patron. o2 = object() + def set_patron(): flask.request.patron = o2 - mock = MagicMock(side_effect = set_patron, - return_value = "return value will be ignored") + + mock = MagicMock( + side_effect=set_patron, return_value="return value will be ignored" + ) self.controller.authenticated_patron_from_request = mock with self.app.test_request_context("/"): assert o2 == self.controller.request_patron @@ -844,8 +869,8 @@ def test_authenticated_patron_from_request(self): # No authorization header -> 401 error. with patch( - 'api.base_controller.BaseCirculationManagerController.authorization_header', - lambda x: None + "api.base_controller.BaseCirculationManagerController.authorization_header", + lambda x: None, ): with self.request_context_with_library("/"): result = self.controller.authenticated_patron_from_request() @@ -855,9 +880,10 @@ def test_authenticated_patron_from_request(self): # Exception contacting the authentication authority -> ProblemDetail def remote_failure(self, header): raise RemoteInitiatedServerError("argh", "service") + with patch( - 'api.base_controller.BaseCirculationManagerController.authenticated_patron', - remote_failure + "api.base_controller.BaseCirculationManagerController.authenticated_patron", + remote_failure, ): with self.request_context_with_library( "/", headers=dict(Authorization=self.valid_auth) @@ -871,8 +897,8 @@ def remote_failure(self, header): # Credentials provided but don't identify anyone in particular # -> 401 error. with patch( - 'api.base_controller.BaseCirculationManagerController.authenticated_patron', - lambda self, x: None + "api.base_controller.BaseCirculationManagerController.authenticated_patron", + lambda self, x: None, ): with self.request_context_with_library( "/", headers=dict(Authorization=self.valid_auth) @@ -896,14 +922,10 @@ def test_authenticated_patron_can_authenticate_with_expired_credentials(self): """ one_year_ago = utc_now() - datetime.timedelta(days=365) with self.request_context_with_library("/"): - patron = self.controller.authenticated_patron( - self.valid_credentials - ) + patron = self.controller.authenticated_patron(self.valid_credentials) patron.expires = one_year_ago - patron = self.controller.authenticated_patron( - self.valid_credentials - ) + patron = self.controller.authenticated_patron(self.valid_credentials) assert one_year_ago == patron.expires def test_authenticated_patron_correct_credentials(self): @@ -922,13 +944,15 @@ def test_authentication_sends_proper_headers(self): # Without quotes, some iOS versions don't recognize the header value. base_url = ConfigurationSetting.sitewide(self._db, Configuration.BASE_URL_KEY) - base_url.value = 'http://url' + base_url.value = "http://url" with self.request_context_with_library("/"): response = self.controller.authenticate() - assert response.headers['WWW-Authenticate'] == 'Basic realm="Library card"' + assert response.headers["WWW-Authenticate"] == 'Basic realm="Library card"' - with self.request_context_with_library("/", headers={"X-Requested-With": "XMLHttpRequest"}): + with self.request_context_with_library( + "/", headers={"X-Requested-With": "XMLHttpRequest"} + ): response = self.controller.authenticate() assert None == response.headers.get("WWW-Authenticate") @@ -945,9 +969,7 @@ def test_handle_conditional_request(self): # microseconds value of 'now'. now_datetime = now_datetime.replace(microsecond=random.randint(0, 999999)) - with self.app.test_request_context( - headers={"If-Modified-Since": now_string} - ): + with self.app.test_request_context(headers={"If-Modified-Since": now_string}): response = self.controller.handle_conditional_request(now_datetime) assert 304 == response.status_code @@ -955,13 +977,11 @@ def test_handle_conditional_request(self): # extent with the date-format spec. very_old = datetime_utc(2000, 1, 1) for value in [ - "Thu, 01 Aug 2019 10:00:40 -0000", - "Thu, 01 Aug 2019 10:00:40", - "01 Aug 2019 10:00:40", + "Thu, 01 Aug 2019 10:00:40 -0000", + "Thu, 01 Aug 2019 10:00:40", + "01 Aug 2019 10:00:40", ]: - with self.app.test_request_context( - headers={"If-Modified-Since": value} - ): + with self.app.test_request_context(headers={"If-Modified-Since": value}): response = self.controller.handle_conditional_request(very_old) assert 304 == response.status_code @@ -969,9 +989,7 @@ def test_handle_conditional_request(self): # the request is not a valid conditional request and the # method returns None. - with self.app.test_request_context( - headers={"If-Modified-Since": now_string} - ): + with self.app.test_request_context(headers={"If-Modified-Since": now_string}): # This request _would_ be a conditional request, but the # precondition fails: If-Modified-Since is earlier than # the 'last modified' date known by the server. @@ -1014,37 +1032,33 @@ def test_load_licensepools(self): data_source_name=DataSource.GUTENBERG, identifier_type=i1.type, identifier_id=i1.identifier, - with_license_pool = True, - collection=c1 + with_license_pool=True, + collection=c1, ) e2, lp2 = self._edition( data_source_name=DataSource.OVERDRIVE, identifier_type=i1.type, identifier_id=i1.identifier, - with_license_pool = True, - collection=c2 + with_license_pool=True, + collection=c2, ) e3, lp3 = self._edition( data_source_name=DataSource.BIBLIOTHECA, identifier_type=i1.type, identifier_id=i1.identifier, - with_license_pool = True, - collection=c3 + with_license_pool=True, + collection=c3, ) # The first collection also has a LicensePool for a totally # different Identifier. e4, lp4 = self._edition( - data_source_name=DataSource.GUTENBERG, - with_license_pool=True, - collection=c1 + data_source_name=DataSource.GUTENBERG, with_license_pool=True, collection=c1 ) # Same for the third collection e5, lp5 = self._edition( - data_source_name=DataSource.GUTENBERG, - with_license_pool=True, - collection=c3 + data_source_name=DataSource.GUTENBERG, with_license_pool=True, collection=c3 ) # Now let's try to load LicensePools for the first Identifier @@ -1059,7 +1073,7 @@ def test_load_licensepools(self): assert lp1 in loaded assert lp2 in loaded assert 2 == len(loaded) - assert all([lp.identifier==i1 for lp in loaded]) + assert all([lp.identifier == i1 for lp in loaded]) # Note that the LicensePool in c3 was not loaded, even though # the Identifier matches, because that collection is not @@ -1075,14 +1089,16 @@ def test_load_licensepools(self): self._default_library, "bad identifier type", i1.identifier ) assert NO_LICENSES.uri == problem_detail.uri - expect = "The item you're asking about (bad identifier type/%s) isn't in this collection." % i1.identifier + expect = ( + "The item you're asking about (bad identifier type/%s) isn't in this collection." + % i1.identifier + ) assert expect == problem_detail.detail # Try an identifier that would work except that it's not in a # Collection associated with the given Library. problem_detail = self.controller.load_licensepools( - self._default_library, lp5.identifier.type, - lp5.identifier.identifier + self._default_library, lp5.identifier.type, lp5.identifier.identifier ) assert NO_LICENSES.uri == problem_detail.uri @@ -1097,25 +1113,22 @@ def test_load_work(self): # Either identifier suffices to identify the Work. for i in [pool1.identifier, pool2.identifier]: with self.request_context_with_library("/"): - assert ( - work == - self.controller.load_work( - self._default_library, i.type, i.identifier - )) + assert work == self.controller.load_work( + self._default_library, i.type, i.identifier + ) # If a patron is authenticated, the requested Work must be # age-appropriate for that patron, or this method will return # a problem detail. headers = dict(Authorization=self.valid_auth) for retval, expect in ((True, work), (False, NOT_AGE_APPROPRIATE)): - work.age_appropriate_for_patron = MagicMock(return_value = retval) + work.age_appropriate_for_patron = MagicMock(return_value=retval) with self.request_context_with_library("/", headers=headers): - assert ( - expect == - self.controller.load_work( - self._default_library, pool1.identifier.type, - pool1.identifier.identifier - )) + assert expect == self.controller.load_work( + self._default_library, + pool1.identifier.type, + pool1.identifier.identifier, + ) work.age_appropriate_for_patron.called_with(self.default_patron) def test_load_licensepooldelivery(self): @@ -1174,10 +1187,7 @@ def test_apply_borrowing_policy_succeeds_for_unlimited_access_books(self): with self.request_context_with_library("/"): # Arrange patron = self.controller.authenticated_patron(self.valid_credentials) - work = self._work( - with_license_pool=True, - with_open_access_download=False - ) + work = self._work(with_license_pool=True, with_open_access_download=False) [pool] = work.license_pools pool.open_access = False pool.self_hosted = False @@ -1193,10 +1203,7 @@ def test_apply_borrowing_policy_succeeds_for_self_hosted_books(self): with self.request_context_with_library("/"): # Arrange patron = self.controller.authenticated_patron(self.valid_credentials) - work = self._work( - with_license_pool=True, - with_open_access_download=False - ) + work = self._work(with_license_pool=True, with_open_access_download=False) [pool] = work.license_pools pool.licenses_available = 0 pool.licenses_owned = 0 @@ -1217,8 +1224,7 @@ def test_apply_borrowing_policy_when_holds_prohibited(self): library.setting(library.ALLOW_HOLDS).value = "False" # This is an open-access work. - work = self._work(with_license_pool=True, - with_open_access_download=True) + work = self._work(with_license_pool=True, with_open_access_download=True) [pool] = work.license_pools pool.licenses_available = 0 assert True == pool.open_access @@ -1240,8 +1246,10 @@ def test_apply_borrowing_policy_for_age_inappropriate_book(self): # Set up lanes for different patron types. children_lane = self._lane() - children_lane.audiences = [Classifier.AUDIENCE_CHILDREN, - Classifier.AUDIENCE_YOUNG_ADULT] + children_lane.audiences = [ + Classifier.AUDIENCE_CHILDREN, + Classifier.AUDIENCE_YOUNG_ADULT, + ] children_lane.target_age = tuple_to_numericrange((9, 12)) children_lane.root_for_patron_type = ["child"] @@ -1252,13 +1260,11 @@ def test_apply_borrowing_policy_for_age_inappropriate_book(self): # This book is age-appropriate for anyone 13 years old or older. work = self._work(with_license_pool=True) work.audience = Classifier.AUDIENCE_CHILDREN - work.target_age = tuple_to_numericrange((13,15)) + work.target_age = tuple_to_numericrange((13, 15)) [pool] = work.license_pools with self.request_context_with_library("/"): - patron = self.controller.authenticated_patron( - self.valid_credentials - ) + patron = self.controller.authenticated_patron(self.valid_credentials) # This patron is restricted to a lane in which the 13-year-old # book would not appear. patron.external_type = "child" @@ -1269,7 +1275,7 @@ def test_apply_borrowing_policy_for_age_inappropriate_book(self): # If the lane is expanded to allow the book's age range, there's # no problem. - children_lane.target_age = tuple_to_numericrange((9,13)) + children_lane.target_age = tuple_to_numericrange((9, 13)) assert None == self.controller.apply_borrowing_policy(patron, pool) # Similarly if the patron has an external type @@ -1285,7 +1291,9 @@ def test_library_for_request(self): assert LIBRARY_NOT_FOUND == value with self.app.test_request_context("/"): - value = self.controller.library_for_request(self._default_library.short_name) + value = self.controller.library_for_request( + self._default_library.short_name + ) assert self._default_library == value assert self._default_library == flask.request.library @@ -1308,7 +1316,6 @@ def test_library_for_request_reloads_settings_if_necessary(self): problem = self.controller.library_for_request(new_name) assert LIBRARY_NOT_FOUND == problem - # Make the change. self._default_library.short_name = new_name self._db.commit() @@ -1342,9 +1349,7 @@ def test_load_lane(self): with self.request_context_with_library("/"): top_level = self.controller.load_lane(None) - expect = self.controller.manager.top_level_lanes[ - self._default_library.id - ] + expect = self.controller.manager.top_level_lanes[self._default_library.id] # expect and top_level are different ORM objects # representing the same lane. (They're different objects @@ -1361,15 +1366,15 @@ def test_load_lane(self): # If a lane cannot be looked up by ID, a problem detail # is returned. - for bad_id in ('nosuchlane', -1): + for bad_id in ("nosuchlane", -1): not_found = self.controller.load_lane(bad_id) assert isinstance(not_found, ProblemDetail) assert not_found.uri == NO_SUCH_LANE.uri assert ( - "Lane %s does not exist or is not associated with library %s" % ( - bad_id, self._default_library.id - ) == - not_found.detail) + "Lane %s does not exist or is not associated with library %s" + % (bad_id, self._default_library.id) + == not_found.detail + ) # If the requested lane exists but is not visible to the # authenticated patron, the server _acts_ like the lane does @@ -1395,18 +1400,18 @@ def test_load_lane(self): class TestIndexController(CirculationControllerTest): - def test_simple_redirect(self): - with self.app.test_request_context('/'): + with self.app.test_request_context("/"): flask.request.library = self.library response = self.manager.index_controller() assert 302 == response.status_code - assert "http://cdn/default/groups/" == response.headers['location'] + assert "http://cdn/default/groups/" == response.headers["location"] def test_custom_index_view(self): """If a custom index view is registered for a library, it is called instead of the normal IndexController code. """ + class MockCustomIndexView(object): def __call__(self, library, annotator): self.called_with = (library, annotator) @@ -1420,14 +1425,17 @@ def __call__(self, library, annotator): # Mock CirculationManager.annotator so it's easy to check # that it was called. mock_annotator = object() + def make_mock_annotator(lane): assert lane == None return mock_annotator + self.manager.annotator = make_mock_annotator # Make a request, and the custom index is invoked. with self.request_context_with_library( - "/", headers=dict(Authorization=self.invalid_auth)): + "/", headers=dict(Authorization=self.invalid_auth) + ): response = self.manager.index_controller() assert "fake response" == response @@ -1448,66 +1456,85 @@ def test_authenticated_patron_root_lane(self): self.default_patron.external_type = "1" with self.request_context_with_library( - "/", headers=dict(Authorization=self.invalid_auth)): + "/", headers=dict(Authorization=self.invalid_auth) + ): response = self.manager.index_controller() assert 401 == response.status_code with self.request_context_with_library( - "/", headers=dict(Authorization=self.valid_auth)): + "/", headers=dict(Authorization=self.valid_auth) + ): response = self.manager.index_controller() assert 302 == response.status_code - assert ("http://cdn/default/groups/%s" % root_1.id == - response.headers['location']) + assert ( + "http://cdn/default/groups/%s" % root_1.id + == response.headers["location"] + ) self.default_patron.external_type = "2" with self.request_context_with_library( - "/", headers=dict(Authorization=self.valid_auth)): + "/", headers=dict(Authorization=self.valid_auth) + ): response = self.manager.index_controller() assert 302 == response.status_code - assert "http://cdn/default/groups/%s" % root_1.id == response.headers['location'] + assert ( + "http://cdn/default/groups/%s" % root_1.id + == response.headers["location"] + ) self.default_patron.external_type = "3" with self.request_context_with_library( - "/", headers=dict(Authorization=self.valid_auth)): + "/", headers=dict(Authorization=self.valid_auth) + ): response = self.manager.index_controller() assert 302 == response.status_code - assert "http://cdn/default/groups/%s" % root_2.id == response.headers['location'] + assert ( + "http://cdn/default/groups/%s" % root_2.id + == response.headers["location"] + ) # Patrons with a different type get sent to the top-level lane. - self.default_patron.external_type = '4' + self.default_patron.external_type = "4" with self.request_context_with_library( - "/", headers=dict(Authorization=self.valid_auth)): + "/", headers=dict(Authorization=self.valid_auth) + ): response = self.manager.index_controller() assert 302 == response.status_code - assert "http://cdn/default/groups/" == response.headers['location'] + assert "http://cdn/default/groups/" == response.headers["location"] # Patrons with no type get sent to the top-level lane. self.default_patron.external_type = None with self.request_context_with_library( - "/", headers=dict(Authorization=self.valid_auth)): + "/", headers=dict(Authorization=self.valid_auth) + ): response = self.manager.index_controller() assert 302 == response.status_code - assert "http://cdn/default/groups/" == response.headers['location'] + assert "http://cdn/default/groups/" == response.headers["location"] def test_authentication_document(self): # Test the ability to retrieve an Authentication For OPDS document. library_name = self.library.short_name with self.request_context_with_library( - "/", headers=dict(Authorization=self.invalid_auth)): + "/", headers=dict(Authorization=self.invalid_auth) + ): response = self.manager.index_controller.authentication_document() assert 200 == response.status_code - assert AuthenticationForOPDSDocument.MEDIA_TYPE == response.headers['Content-Type'] + assert ( + AuthenticationForOPDSDocument.MEDIA_TYPE + == response.headers["Content-Type"] + ) data = response.get_data(as_text=True) assert self.manager.auth.create_authentication_document() == data # Make sure we got the A4OPDS document for the right library. doc = json.loads(data) - assert library_name == doc['title'] + assert library_name == doc["title"] # Currently, the authentication document cache is disabled by default. self.manager.authentication_for_opds_documents[library_name] = "Cached value" with self.request_context_with_library( - "/", headers=dict(Authorization=self.invalid_auth)): + "/", headers=dict(Authorization=self.invalid_auth) + ): response = self.manager.index_controller.authentication_document() assert "Cached value" != response.get_data(as_text=True) @@ -1516,51 +1543,54 @@ def test_authentication_document(self): cached_value = json.dumps(dict(key="Cached document")) self.manager.authentication_for_opds_documents[library_name] = cached_value with self.request_context_with_library( - "/?debug", headers=dict(Authorization=self.invalid_auth)): + "/?debug", headers=dict(Authorization=self.invalid_auth) + ): response = self.manager.index_controller.authentication_document() assert cached_value == response.get_data(as_text=True) # Note that WSGI debugging data was not provided, even # though we requested it, since WSGI debugging is # disabled. - assert '_debug' not in response.get_data(as_text=True) + assert "_debug" not in response.get_data(as_text=True) # When WSGI debugging is enabled and requested, an # authentication document includes some extra information in a # special '_debug' section. self.manager.wsgi_debug = True with self.request_context_with_library( - "/?debug", headers=dict(Authorization=self.invalid_auth)): + "/?debug", headers=dict(Authorization=self.invalid_auth) + ): response = self.manager.index_controller.authentication_document() doc = json.loads(response.data) - assert doc['key'] == 'Cached document' - debug = doc['_debug'] - assert all(x in debug for x in ('url', 'cache', 'environ')) + assert doc["key"] == "Cached document" + debug = doc["_debug"] + assert all(x in debug for x in ("url", "cache", "environ")) # WSGI debugging is not provided unless requested. with self.request_context_with_library( - "/", headers=dict(Authorization=self.invalid_auth)): + "/", headers=dict(Authorization=self.invalid_auth) + ): response = self.manager.index_controller.authentication_document() - assert '_debug' not in response.get_data(as_text=True) + assert "_debug" not in response.get_data(as_text=True) def test_public_key_integration_document(self): - base_url = ConfigurationSetting.sitewide(self._db, Configuration.BASE_URL_KEY).value + base_url = ConfigurationSetting.sitewide( + self._db, Configuration.BASE_URL_KEY + ).value # When a sitewide key pair exists (which should be all the # time), all of its data is included. - key_setting = ConfigurationSetting.sitewide( - self._db, Configuration.KEY_PAIR - ) - key_setting.value = json.dumps(['public key', 'private key']) - with self.app.test_request_context('/'): + key_setting = ConfigurationSetting.sitewide(self._db, Configuration.KEY_PAIR) + key_setting.value = json.dumps(["public key", "private key"]) + with self.app.test_request_context("/"): response = self.manager.index_controller.public_key_document() assert 200 == response.status_code - assert 'application/opds+json' == response.headers.get('Content-Type') + assert "application/opds+json" == response.headers.get("Content-Type") data = json.loads(response.get_data(as_text=True)) - assert 'RSA' == data.get('public_key', {}).get('type') - assert 'public key' == data.get('public_key', {}).get('value') + assert "RSA" == data.get("public_key", {}).get("type") + assert "public key" == data.get("public_key", {}).get("value") # If there is no sitewide key pair (which should never # happen), a new one is created. Library-specific public keys @@ -1568,28 +1598,30 @@ def test_public_key_integration_document(self): key_setting.value = None ConfigurationSetting.for_library( Configuration.KEY_PAIR, self.library - ).value = 'ignore me' + ).value = "ignore me" - with self.app.test_request_context('/'): + with self.app.test_request_context("/"): response = self.manager.index_controller.public_key_document() assert 200 == response.status_code - assert 'application/opds+json' == response.headers.get('Content-Type') + assert "application/opds+json" == response.headers.get("Content-Type") data = json.loads(response.get_data(as_text=True)) - assert 'http://test-circulation-manager/' == data.get('id') - key = data.get('public_key') - assert 'RSA' == key['type'] - assert 'BEGIN PUBLIC KEY' in key['value'] + assert "http://test-circulation-manager/" == data.get("id") + key = data.get("public_key") + assert "RSA" == key["type"] + assert "BEGIN PUBLIC KEY" in key["value"] -class TestMultipleLibraries(CirculationControllerTest): +class TestMultipleLibraries(CirculationControllerTest): def make_default_libraries(self, _db): return [self._library() for x in range(2)] def make_default_collection(self, _db, library): collection, ignore = get_one_or_create( - _db, Collection, name=self._str + " (for multi-library test)", + _db, + Collection, + name=self._str + " (for multi-library test)", ) collection.create_external_integration(ExternalIntegration.OPDS_IMPORT) library.collections.append(collection) @@ -1604,12 +1636,16 @@ def test_authentication(self): for library in self.libraries: headers = dict(Authorization=self.valid_auth) with self.request_context_with_library( - "/", headers=headers, library=library): + "/", headers=headers, library=library + ): patron = self.manager.loans.authenticated_patron_from_request() assert library == patron.library response = self.manager.index_controller() - assert ("http://cdn/%s/groups/" % library.short_name == - response.headers['location']) + assert ( + "http://cdn/%s/groups/" % library.short_name + == response.headers["location"] + ) + class TestLoanController(CirculationControllerTest): def setup_method(self): @@ -1617,8 +1653,10 @@ def setup_method(self): self.pool = self.english_1.license_pools[0] [self.mech1] = self.pool.delivery_mechanisms self.mech2 = self.pool.set_delivery_mechanism( - Representation.PDF_MEDIA_TYPE, DeliveryMechanism.NO_DRM, - RightsStatus.CC_BY, None + Representation.PDF_MEDIA_TYPE, + DeliveryMechanism.NO_DRM, + RightsStatus.CC_BY, + None, ) self.edition = self.pool.presentation_edition self.data_source = self.edition.data_source @@ -1644,12 +1682,15 @@ def test_can_fulfill_without_loan(self): # okay. class MockLibraryAuthenticator(object): identifies_individuals = False + self.manager.auth.library_authenticators[ self._default_library.short_name ] = MockLibraryAuthenticator() + def mock_can_fulfill_without_loan(patron, pool, lpdm): self.called_with = (patron, pool, lpdm) return True + with self.request_context_with_library("/"): self.manager.loans.circulation.can_fulfill_without_loan = ( mock_can_fulfill_without_loan @@ -1663,9 +1704,10 @@ def test_patron_circulation_retrieval(self): """ # Give the Work a second LicensePool. edition, other_pool = self._edition( - with_open_access_download=True, with_license_pool=True, + with_open_access_download=True, + with_license_pool=True, data_source_name=DataSource.BIBLIOTHECA, - collection=self.pool.collection + collection=self.pool.collection, ) other_pool.identifier = self.identifier other_pool.work = self.pool.work @@ -1675,42 +1717,37 @@ def test_patron_circulation_retrieval(self): ) with self.request_context_with_library( - "/", headers=dict(Authorization=self.valid_auth)): + "/", headers=dict(Authorization=self.valid_auth) + ): self.manager.loans.authenticated_patron_from_request() # Without a loan or a hold, nothing is returned. # No loans. - result = self.manager.loans.get_patron_loan( - self.default_patron, pools - ) + result = self.manager.loans.get_patron_loan(self.default_patron, pools) assert (None, None) == result # No holds. - result = self.manager.loans.get_patron_hold( - self.default_patron, pools - ) + result = self.manager.loans.get_patron_hold(self.default_patron, pools) assert (None, None) == result # When there's a loan, we retrieve it. loan, newly_created = self.pool.loan_to(self.default_patron) - result = self.manager.loans.get_patron_loan( - self.default_patron, pools - ) + result = self.manager.loans.get_patron_loan(self.default_patron, pools) assert (loan, self.pool) == result # When there's a hold, we retrieve it. hold, newly_created = other_pool.on_hold_to(self.default_patron) - result = self.manager.loans.get_patron_hold( - self.default_patron, pools - ) + result = self.manager.loans.get_patron_hold(self.default_patron, pools) assert (hold, other_pool) == result def test_borrow_success(self): with self.request_context_with_library( - "/", headers=dict(Authorization=self.valid_auth)): + "/", headers=dict(Authorization=self.valid_auth) + ): self.manager.loans.authenticated_patron_from_request() response = self.manager.loans.borrow( - self.identifier.type, self.identifier.identifier) + self.identifier.type, self.identifier.identifier + ) # A loan has been created for this license pool. loan = get_one(self._db, Loan, license_pool=self.pool) @@ -1722,9 +1759,12 @@ def test_borrow_success(self): # to fulfill the license. assert 201 == response.status_code feed = feedparser.parse(response.get_data()) - [entry] = feed['entries'] - fulfillment_links = [x['href'] for x in entry['links'] - if x['rel'] == OPDSFeed.ACQUISITION_REL] + [entry] = feed["entries"] + fulfillment_links = [ + x["href"] + for x in entry["links"] + if x["rel"] == OPDSFeed.ACQUISITION_REL + ] assert self.mech1.resource is not None @@ -1734,11 +1774,16 @@ def test_borrow_success(self): fulfillable_mechanism = self.mech1 self._db.commit() - expects = [url_for('fulfill', - license_pool_id=self.pool.id, - mechanism_id=mech.delivery_mechanism.id, - library_short_name=self.library.short_name, - _external=True) for mech in [self.mech1, self.mech2]] + expects = [ + url_for( + "fulfill", + license_pool_id=self.pool.id, + mechanism_id=mech.delivery_mechanism.id, + library_short_name=self.library.short_name, + _external=True, + ) + for mech in [self.mech1, self.mech2] + ] assert set(expects) == set(fulfillment_links) # Make sure the first delivery mechanism has the data necessary @@ -1749,13 +1794,17 @@ def test_borrow_success(self): # Now let's try to fulfill the loan using the first delivery mechanism. response = self.manager.loans.fulfill( - self.pool.id, fulfillable_mechanism.delivery_mechanism.id, + self.pool.id, + fulfillable_mechanism.delivery_mechanism.id, ) if isinstance(response, ProblemDetail): j, status, headers = response.response raise Exception(repr(j)) assert 302 == response.status_code - assert fulfillable_mechanism.resource.representation.public_url == response.headers.get("Location") + assert ( + fulfillable_mechanism.resource.representation.public_url + == response.headers.get("Location") + ) # The mechanism we used has been registered with the loan. assert fulfillable_mechanism == loan.fulfillment @@ -1774,16 +1823,15 @@ def test_borrow_success(self): content_link=fulfillable_mechanism.resource.url, content_type=fulfillable_mechanism.resource.representation.media_type, content=None, - content_expires=None) + content_expires=None, + ) # Now that we've set a mechanism, we can fulfill the loan # again without specifying a mechanism. self.manager.d_circulation.queue_fulfill(self.pool, fulfillment) http.queue_response(200, content="I am an ACSM file") - response = self.manager.loans.fulfill( - self.pool.id, do_get=http.do_get - ) + response = self.manager.loans.fulfill(self.pool.id, do_get=http.do_get) assert 200 == response.status_code assert "I am an ACSM file" == response.get_data(as_text=True) assert http.requests == [fulfillable_mechanism.resource.url] @@ -1795,16 +1843,18 @@ def test_borrow_success(self): ) assert 409 == response.status_code - assert "You already fulfilled this loan as application/epub+zip (DRM Scheme 1), you can't also do it as application/pdf (DRM Scheme 2)" in response.detail + assert ( + "You already fulfilled this loan as application/epub+zip (DRM Scheme 1), you can't also do it as application/pdf (DRM Scheme 2)" + in response.detail + ) # If the remote server fails, we get a problem detail. def doomed_get(url, headers, **kwargs): raise RemoteIntegrationException("fulfill service", "Error!") + self.manager.d_circulation.queue_fulfill(self.pool, fulfillment) - response = self.manager.loans.fulfill( - self.pool.id, do_get=doomed_get - ) + response = self.manager.loans.fulfill(self.pool.id, do_get=doomed_get) assert isinstance(response, ProblemDetail) assert 502 == response.status_code @@ -1815,26 +1865,29 @@ def test_borrow_and_fulfill_with_streaming_delivery_mechanism(self): pool = work.license_pools[0] pool.open_access = False streaming_mech = pool.set_delivery_mechanism( - DeliveryMechanism.STREAMING_TEXT_CONTENT_TYPE, DeliveryMechanism.OVERDRIVE_DRM, - RightsStatus.IN_COPYRIGHT, None + DeliveryMechanism.STREAMING_TEXT_CONTENT_TYPE, + DeliveryMechanism.OVERDRIVE_DRM, + RightsStatus.IN_COPYRIGHT, + None, ) identifier = edition.primary_identifier with self.request_context_with_library( - "/", headers=dict(Authorization=self.valid_auth)): + "/", headers=dict(Authorization=self.valid_auth) + ): self.manager.loans.authenticated_patron_from_request() self.manager.d_circulation.queue_checkout( pool, LoanInfo( - pool.collection, pool.data_source.name, + pool.collection, + pool.data_source.name, pool.identifier.type, pool.identifier.identifier, utc_now(), utc_now() + datetime.timedelta(seconds=3600), - ) + ), ) - response = self.manager.loans.borrow( - identifier.type, identifier.identifier) + response = self.manager.loans.borrow(identifier.type, identifier.identifier) # A loan has been created for this license pool. loan = get_one(self._db, Loan, license_pool=pool) @@ -1846,35 +1899,45 @@ def test_borrow_and_fulfill_with_streaming_delivery_mechanism(self): # to fulfill the license. assert 201 == response.status_code feed = feedparser.parse(response.get_data()) - [entry] = feed['entries'] - fulfillment_links = [x['href'] for x in entry['links'] - if x['rel'] == OPDSFeed.ACQUISITION_REL] + [entry] = feed["entries"] + fulfillment_links = [ + x["href"] + for x in entry["links"] + if x["rel"] == OPDSFeed.ACQUISITION_REL + ] [mech1, mech2] = sorted( pool.delivery_mechanisms, - key=lambda x: x.delivery_mechanism.is_streaming + key=lambda x: x.delivery_mechanism.is_streaming, ) streaming_mechanism = mech2 - expects = [url_for('fulfill', - license_pool_id=pool.id, - mechanism_id=mech.delivery_mechanism.id, - library_short_name=self.library.short_name, - _external=True) for mech in [mech1, mech2]] + expects = [ + url_for( + "fulfill", + license_pool_id=pool.id, + mechanism_id=mech.delivery_mechanism.id, + library_short_name=self.library.short_name, + _external=True, + ) + for mech in [mech1, mech2] + ] assert set(expects) == set(fulfillment_links) # Now let's try to fulfill the loan using the streaming mechanism. self.manager.d_circulation.queue_fulfill( pool, FulfillmentInfo( - pool.collection, pool.data_source.name, + pool.collection, + pool.data_source.name, pool.identifier.type, pool.identifier.identifier, "http://streaming-content-link", - Representation.TEXT_HTML_MEDIA_TYPE + DeliveryMechanism.STREAMING_PROFILE, + Representation.TEXT_HTML_MEDIA_TYPE + + DeliveryMechanism.STREAMING_PROFILE, None, None, - ) + ), ) response = self.manager.loans.fulfill( pool.id, streaming_mechanism.delivery_mechanism.id @@ -1882,18 +1945,24 @@ def test_borrow_and_fulfill_with_streaming_delivery_mechanism(self): # We get an OPDS entry. assert 200 == response.status_code - opds_entries = feedparser.parse(response.response[0])['entries'] + opds_entries = feedparser.parse(response.response[0])["entries"] assert 1 == len(opds_entries) - links = opds_entries[0]['links'] + links = opds_entries[0]["links"] # The entry includes one fulfill link. - fulfill_links = [link for link in links if link['rel'] == "http://opds-spec.org/acquisition"] + fulfill_links = [ + link + for link in links + if link["rel"] == "http://opds-spec.org/acquisition" + ] assert 1 == len(fulfill_links) - assert (Representation.TEXT_HTML_MEDIA_TYPE + DeliveryMechanism.STREAMING_PROFILE == - fulfill_links[0]['type']) - assert "http://streaming-content-link" == fulfill_links[0]['href'] - + assert ( + Representation.TEXT_HTML_MEDIA_TYPE + + DeliveryMechanism.STREAMING_PROFILE + == fulfill_links[0]["type"] + ) + assert "http://streaming-content-link" == fulfill_links[0]["href"] # The mechanism has not been set, since fulfilling a streaming # mechanism does not lock in the format. @@ -1906,7 +1975,8 @@ def test_borrow_and_fulfill_with_streaming_delivery_mechanism(self): self.manager.d_circulation.queue_fulfill( pool, FulfillmentInfo( - pool.collection, pool.data_source.name, + pool.collection, + pool.data_source.name, pool.identifier.type, pool.identifier.identifier, "http://other-content-link", @@ -1927,38 +1997,47 @@ def test_borrow_and_fulfill_with_streaming_delivery_mechanism(self): self.manager.d_circulation.queue_fulfill( pool, FulfillmentInfo( - pool.collection, pool.data_source.name, + pool.collection, + pool.data_source.name, pool.identifier.type, pool.identifier.identifier, "http://streaming-content-link", - Representation.TEXT_HTML_MEDIA_TYPE + DeliveryMechanism.STREAMING_PROFILE, + Representation.TEXT_HTML_MEDIA_TYPE + + DeliveryMechanism.STREAMING_PROFILE, None, None, - ) + ), ) response = self.manager.loans.fulfill( pool.id, streaming_mechanism.delivery_mechanism.id ) assert 200 == response.status_code - opds_entries = feedparser.parse(response.response[0])['entries'] + opds_entries = feedparser.parse(response.response[0])["entries"] assert 1 == len(opds_entries) - links = opds_entries[0]['links'] + links = opds_entries[0]["links"] - fulfill_links = [link for link in links if link['rel'] == "http://opds-spec.org/acquisition"] + fulfill_links = [ + link + for link in links + if link["rel"] == "http://opds-spec.org/acquisition" + ] assert 1 == len(fulfill_links) - assert (Representation.TEXT_HTML_MEDIA_TYPE + DeliveryMechanism.STREAMING_PROFILE == - fulfill_links[0]['type']) - assert "http://streaming-content-link" == fulfill_links[0]['href'] + assert ( + Representation.TEXT_HTML_MEDIA_TYPE + + DeliveryMechanism.STREAMING_PROFILE + == fulfill_links[0]["type"] + ) + assert "http://streaming-content-link" == fulfill_links[0]["href"] def test_borrow_nonexistent_delivery_mechanism(self): with self.request_context_with_library( - "/", headers=dict(Authorization=self.valid_auth)): + "/", headers=dict(Authorization=self.valid_auth) + ): self.manager.loans.authenticated_patron_from_request() response = self.manager.loans.borrow( - self.identifier.type, self.identifier.identifier, - -100 + self.identifier.type, self.identifier.identifier, -100 ) assert BAD_DELIVERY_MECHANISM == response @@ -1976,24 +2055,25 @@ def test_borrow_creates_hold_when_no_available_copies(self): pool.open_access = False with self.request_context_with_library( - "/", headers=dict(Authorization=self.valid_auth)): + "/", headers=dict(Authorization=self.valid_auth) + ): self.manager.loans.authenticated_patron_from_request() - self.manager.d_circulation.queue_checkout( - pool, NoAvailableCopies() - ) + self.manager.d_circulation.queue_checkout(pool, NoAvailableCopies()) self.manager.d_circulation.queue_hold( pool, HoldInfo( - pool.collection, pool.data_source.name, + pool.collection, + pool.data_source.name, pool.identifier.type, pool.identifier.identifier, utc_now(), utc_now() + datetime.timedelta(seconds=3600), 1, - ) + ), ) response = self.manager.loans.borrow( - pool.identifier.type, pool.identifier.identifier) + pool.identifier.type, pool.identifier.identifier + ) assert 201 == response.status_code # A hold has been created for this license pool. @@ -2009,12 +2089,14 @@ def test_borrow_nolicenses(self): ) with self.request_context_with_library( - "/", headers=dict(Authorization=self.valid_auth)): + "/", headers=dict(Authorization=self.valid_auth) + ): self.manager.loans.authenticated_patron_from_request() self.manager.d_circulation.queue_checkout(pool, NoLicenses()) response = self.manager.loans.borrow( - pool.identifier.type, pool.identifier.identifier) + pool.identifier.type, pool.identifier.identifier + ) assert 404 == response.status_code assert NOT_FOUND_ON_REMOTE == response @@ -2035,23 +2117,25 @@ def test_borrow_creates_local_hold_if_remote_hold_exists(self): pool.open_access = False with self.request_context_with_library( - "/", headers=dict(Authorization=self.valid_auth)): + "/", headers=dict(Authorization=self.valid_auth) + ): self.manager.loans.authenticated_patron_from_request() - self.manager.d_circulation.queue_checkout( - pool, AlreadyOnHold() - ) + self.manager.d_circulation.queue_checkout(pool, AlreadyOnHold()) self.manager.d_circulation.queue_hold( - pool, HoldInfo( - pool.collection, pool.data_source.name, + pool, + HoldInfo( + pool.collection, + pool.data_source.name, pool.identifier.type, pool.identifier.identifier, utc_now(), utc_now() + datetime.timedelta(seconds=3600), 1, - ) + ), ) response = self.manager.loans.borrow( - pool.identifier.type, pool.identifier.identifier) + pool.identifier.type, pool.identifier.identifier + ) assert 201 == response.status_code # A hold has been created for this license pool. @@ -2059,40 +2143,42 @@ def test_borrow_creates_local_hold_if_remote_hold_exists(self): assert hold != None def test_borrow_fails_when_work_not_present_on_remote(self): - threem_edition, pool = self._edition( - with_open_access_download=False, - data_source_name=DataSource.THREEM, - identifier_type=Identifier.THREEM_ID, - with_license_pool=True, - ) - threem_book = self._work( - presentation_edition=threem_edition, - ) - pool.licenses_available = 1 - pool.open_access = False - - with self.request_context_with_library( - "/", headers=dict(Authorization=self.valid_auth)): - self.manager.loans.authenticated_patron_from_request() - self.manager.d_circulation.queue_checkout( - pool, NotFoundOnRemote() - ) - response = self.manager.loans.borrow( - pool.identifier.type, pool.identifier.identifier) - assert 404 == response.status_code - assert "http://librarysimplified.org/terms/problem/not-found-on-remote" == response.uri + threem_edition, pool = self._edition( + with_open_access_download=False, + data_source_name=DataSource.THREEM, + identifier_type=Identifier.THREEM_ID, + with_license_pool=True, + ) + threem_book = self._work( + presentation_edition=threem_edition, + ) + pool.licenses_available = 1 + pool.open_access = False + + with self.request_context_with_library( + "/", headers=dict(Authorization=self.valid_auth) + ): + self.manager.loans.authenticated_patron_from_request() + self.manager.d_circulation.queue_checkout(pool, NotFoundOnRemote()) + response = self.manager.loans.borrow( + pool.identifier.type, pool.identifier.identifier + ) + assert 404 == response.status_code + assert ( + "http://librarysimplified.org/terms/problem/not-found-on-remote" + == response.uri + ) def test_borrow_succeeds_when_work_already_checked_out(self): # An attempt to borrow a book that's already on loan is # treated as success without even going to the remote API. loan, _ignore = get_one_or_create( - self._db, Loan, license_pool=self.pool, - patron=self.default_patron + self._db, Loan, license_pool=self.pool, patron=self.default_patron ) - with self.request_context_with_library( - "/", headers=dict(Authorization=self.valid_auth)): + "/", headers=dict(Authorization=self.valid_auth) + ): self.manager.loans.authenticated_patron_from_request() # Set it up that going to the remote API would raise an @@ -2101,29 +2187,47 @@ def test_borrow_succeeds_when_work_already_checked_out(self): circulation.queue_checkout(loan.license_pool, NotFoundOnRemote()) mock_remote = circulation.api_for_license_pool(loan.license_pool) - assert 1 == len(mock_remote.responses['checkout']) + assert 1 == len(mock_remote.responses["checkout"]) response = self.manager.loans.borrow( - self.identifier.type, self.identifier.identifier) + self.identifier.type, self.identifier.identifier + ) # No checkout request was actually made to the remote. - assert 1 == len(mock_remote.responses['checkout']) + assert 1 == len(mock_remote.responses["checkout"]) # We got an OPDS entry that includes at least one # fulfillment link, which is what we expect when we ask # about an active loan. assert 200 == response.status_code - [entry] = feedparser.parse(response.response[0])['entries'] - assert any([x for x in entry['links'] if x['rel'] == 'http://opds-spec.org/acquisition']) + [entry] = feedparser.parse(response.response[0])["entries"] + assert any( + [ + x + for x in entry["links"] + if x["rel"] == "http://opds-spec.org/acquisition" + ] + ) def test_fulfill(self): # Verify that arguments to the fulfill() method are propagated # correctly to the CirculationAPI. class MockCirculationAPI(object): - def fulfill(self, patron, credential, requested_license_pool, - mechanism, part, fulfill_part_url): + def fulfill( + self, + patron, + credential, + requested_license_pool, + mechanism, + part, + fulfill_part_url, + ): self.called_with = ( - patron, credential, requested_license_pool, - mechanism, part, fulfill_part_url + patron, + credential, + requested_license_pool, + mechanism, + part, + fulfill_part_url, ) raise CannotFulfill() @@ -2140,16 +2244,20 @@ def fulfill(self, patron, credential, requested_license_pool, # Try to fulfill a certain part of the loan. part = "part 1 million" - controller.fulfill( - self.pool.id, self.mech2.delivery_mechanism.id, part - ) + controller.fulfill(self.pool.id, self.mech2.delivery_mechanism.id, part) # Verify that the right arguments were passed into # CirculationAPI. - (patron, credential, pool, mechanism, part, - fulfill_part_url) = mock.called_with + ( + patron, + credential, + pool, + mechanism, + part, + fulfill_part_url, + ) = mock.called_with assert authenticated == patron - assert self.valid_credentials['password'] == credential + assert self.valid_credentials["password"] == credential assert self.pool == pool assert self.mech2 == mechanism assert "part 1 million" == part @@ -2158,10 +2266,12 @@ def fulfill(self, patron, credential, requested_license_pool, # generating partial fulfillment URLs. Let's try it out # and make sure it gives the result we expect. expect = url_for( - "fulfill", license_pool_id=self.pool.id, + "fulfill", + license_pool_id=self.pool.id, mechanism_id=mechanism.delivery_mechanism.id, library_short_name=library_short_name, - part=part, _external=True + part=part, + _external=True, ) part_url = fulfill_part_url(part) assert expect == part_url @@ -2180,9 +2290,11 @@ def fulfill(self, patron, credential, requested_license_pool, Response(status=200, response="Here's your response"), Response(status=401, response="Error"), Response(status=500, response="Fault"), - ] + ], ) - def test_fulfill_returns_fulfillment_info_implementing_as_response(self, as_response_value): + def test_fulfill_returns_fulfillment_info_implementing_as_response( + self, as_response_value + ): # If CirculationAPI.fulfill returns a FulfillmentInfo that # defines as_response, the result of as_response is returned # directly and the normal process of converting a FulfillmentInfo @@ -2195,8 +2307,7 @@ def as_response(self): class MockCirculationAPI(object): def fulfill(slf, *args, **kwargs): return MockFulfillmentInfo( - self._default_collection, None, None, None, None, - None, None, None + self._default_collection, None, None, None, None, None, None, None ) controller = self.manager.loans @@ -2210,9 +2321,7 @@ def fulfill(slf, *args, **kwargs): loan, ignore = self.pool.loan_to(authenticated) # Fulfill the loan. - result = controller.fulfill( - self.pool.id, self.mech2.delivery_mechanism.id - ) + result = controller.fulfill(self.pool.id, self.mech2.delivery_mechanism.id) # The result of MockFulfillmentInfo.as_response was # returned directly. @@ -2226,7 +2335,8 @@ def test_fulfill_without_active_loan(self): # patron has no active loan for the title. This might be # because the patron never checked out the book... with self.request_context_with_library( - "/", headers=dict(Authorization=self.valid_auth)): + "/", headers=dict(Authorization=self.valid_auth) + ): controller.authenticated_patron_from_request() response = controller.fulfill( self.pool.id, self.mech2.delivery_mechanism.id @@ -2245,13 +2355,13 @@ def test_fulfill_without_active_loan(self): # ...or it might be because of an error communicating # with the authentication provider. old_authenticated_patron = controller.authenticated_patron_from_request + def mock_authenticated_patron(): return INTEGRATION_ERROR + controller.authenticated_patron_from_request = mock_authenticated_patron with self.request_context_with_library("/"): - problem = controller.fulfill( - self.pool.id, self.mech2.delivery_mechanism.id - ) + problem = controller.fulfill(self.pool.id, self.mech2.delivery_mechanism.id) assert INTEGRATION_ERROR == problem controller.authenticated_patron_from_request = old_authenticated_patron @@ -2270,7 +2380,9 @@ def mock_fulfill(*args, **kwargs): self.pool.data_source.name, self.pool.identifier.type, self.pool.identifier.identifier, - None, "text/html", "here's your book", + None, + "text/html", + "here's your book", utc_now(), ) @@ -2287,32 +2399,35 @@ def mock_fulfill(*args, **kwargs): assert [] == self._db.query(Loan).all() def test_revoke_loan(self): - with self.request_context_with_library( - "/", headers=dict(Authorization=self.valid_auth)): - patron = self.manager.loans.authenticated_patron_from_request() - loan, newly_created = self.pool.loan_to(patron) + with self.request_context_with_library( + "/", headers=dict(Authorization=self.valid_auth) + ): + patron = self.manager.loans.authenticated_patron_from_request() + loan, newly_created = self.pool.loan_to(patron) - self.manager.d_circulation.queue_checkin(self.pool, True) + self.manager.d_circulation.queue_checkin(self.pool, True) - response = self.manager.loans.revoke(self.pool.id) + response = self.manager.loans.revoke(self.pool.id) - assert 200 == response.status_code + assert 200 == response.status_code def test_revoke_hold(self): - with self.request_context_with_library( - "/", headers=dict(Authorization=self.valid_auth)): - patron = self.manager.loans.authenticated_patron_from_request() - hold, newly_created = self.pool.on_hold_to(patron, position=0) + with self.request_context_with_library( + "/", headers=dict(Authorization=self.valid_auth) + ): + patron = self.manager.loans.authenticated_patron_from_request() + hold, newly_created = self.pool.on_hold_to(patron, position=0) - self.manager.d_circulation.queue_release_hold(self.pool, True) + self.manager.d_circulation.queue_release_hold(self.pool, True) - response = self.manager.loans.revoke(self.pool.id) + response = self.manager.loans.revoke(self.pool.id) - assert 200 == response.status_code + assert 200 == response.status_code def test_revoke_hold_nonexistent_licensepool(self): - with self.request_context_with_library( - "/", headers=dict(Authorization=self.valid_auth)): + with self.request_context_with_library( + "/", headers=dict(Authorization=self.valid_auth) + ): patron = self.manager.loans.authenticated_patron_from_request() response = self.manager.loans.revoke(-10) assert isinstance(response, ProblemDetail) @@ -2322,17 +2437,13 @@ def test_hold_fails_when_patron_is_at_hold_limit(self): edition, pool = self._edition(with_license_pool=True) pool.open_access = False with self.request_context_with_library( - "/", headers=dict(Authorization=self.valid_auth)): + "/", headers=dict(Authorization=self.valid_auth) + ): patron = self.manager.loans.authenticated_patron_from_request() - self.manager.d_circulation.queue_checkout( - pool, NoAvailableCopies() - ) - self.manager.d_circulation.queue_hold( - pool, PatronHoldLimitReached() - ) + self.manager.d_circulation.queue_checkout(pool, NoAvailableCopies()) + self.manager.d_circulation.queue_hold(pool, PatronHoldLimitReached()) response = self.manager.loans.borrow( - pool.identifier.type, - pool.identifier.identifier + pool.identifier.type, pool.identifier.identifier ) assert isinstance(response, ProblemDetail) assert HOLD_LIMIT_REACHED.uri == response.uri @@ -2350,16 +2461,19 @@ def test_borrow_fails_with_outstanding_fines(self): pool.open_access = False ConfigurationSetting.for_library( - Configuration.MAX_OUTSTANDING_FINES, self._default_library).value = "$0.50" + Configuration.MAX_OUTSTANDING_FINES, self._default_library + ).value = "$0.50" with self.request_context_with_library( - "/", headers=dict(Authorization=self.valid_auth)): + "/", headers=dict(Authorization=self.valid_auth) + ): # The patron's credentials are valid, but they have a lot # of fines. patron = self.manager.loans.authenticated_patron_from_request() patron.fines = Decimal("12345678.90") response = self.manager.loans.borrow( - pool.identifier.type, pool.identifier.identifier) + pool.identifier.type, pool.identifier.identifier + ) assert 403 == response.status_code assert OUTSTANDING_FINES.uri == response.uri @@ -2367,53 +2481,62 @@ def test_borrow_fails_with_outstanding_fines(self): # Reduce the patron's fines, and there's no problem. with self.request_context_with_library( - "/", headers=dict(Authorization=self.valid_auth)): + "/", headers=dict(Authorization=self.valid_auth) + ): patron = self.manager.loans.authenticated_patron_from_request() patron.fines = Decimal("0.49") self.manager.d_circulation.queue_checkout( pool, LoanInfo( - pool.collection, pool.data_source.name, + pool.collection, + pool.data_source.name, pool.identifier.type, pool.identifier.identifier, utc_now(), utc_now() + datetime.timedelta(seconds=3600), - ) + ), ) response = self.manager.loans.borrow( - pool.identifier.type, pool.identifier.identifier) + pool.identifier.type, pool.identifier.identifier + ) assert 201 == response.status_code def test_3m_cant_revoke_hold_if_reserved(self): - threem_edition, pool = self._edition( - with_open_access_download=False, - data_source_name=DataSource.THREEM, - identifier_type=Identifier.THREEM_ID, - with_license_pool=True, - ) - threem_book = self._work( - presentation_edition=threem_edition, - ) - pool.open_access = False - - with self.request_context_with_library( - "/", headers=dict(Authorization=self.valid_auth)): - patron = self.manager.loans.authenticated_patron_from_request() - hold, newly_created = pool.on_hold_to(patron, position=0) - response = self.manager.loans.revoke(pool.id) - assert 400 == response.status_code - assert CANNOT_RELEASE_HOLD.uri == response.uri - assert "Cannot release a hold once it enters reserved state." == response.detail - - def test_active_loans(self): + threem_edition, pool = self._edition( + with_open_access_download=False, + data_source_name=DataSource.THREEM, + identifier_type=Identifier.THREEM_ID, + with_license_pool=True, + ) + threem_book = self._work( + presentation_edition=threem_edition, + ) + pool.open_access = False + + with self.request_context_with_library( + "/", headers=dict(Authorization=self.valid_auth) + ): + patron = self.manager.loans.authenticated_patron_from_request() + hold, newly_created = pool.on_hold_to(patron, position=0) + response = self.manager.loans.revoke(pool.id) + assert 400 == response.status_code + assert CANNOT_RELEASE_HOLD.uri == response.uri + assert ( + "Cannot release a hold once it enters reserved state." + == response.detail + ) + + def test_active_loans(self): # First, verify that this controller supports conditional HTTP # GET by calling handle_conditional_request and propagating # any Response it returns. response_304 = Response(status=304) + def handle_conditional_request(last_modified=None): return response_304 + original_handle_conditional_request = self.controller.handle_conditional_request self.manager.loans.handle_conditional_request = handle_conditional_request @@ -2421,9 +2544,7 @@ def handle_conditional_request(last_modified=None): # to a known value. patron = None with self.request_context_with_library("/"): - patron = self.controller.authenticated_patron( - self.valid_credentials - ) + patron = self.controller.authenticated_patron(self.valid_credentials) now = utc_now() patron.last_loan_activity_sync = now @@ -2464,7 +2585,7 @@ def handle_conditional_request(last_modified=None): patron = self.manager.loans.authenticated_patron_from_request() response = self.manager.loans.sync() assert not "" in response.get_data(as_text=True) - assert response.headers['Cache-Control'].startswith('private,') + assert response.headers["Cache-Control"].startswith("private,") # patron.last_loan_activity_sync was set to the moment the # LoanController started calling out to the remote APIs. @@ -2496,14 +2617,16 @@ def handle_conditional_request(last_modified=None): bibliotheca_pool.open_access = False self.manager.d_circulation.add_remote_loan( - overdrive_pool.collection, overdrive_pool.data_source, + overdrive_pool.collection, + overdrive_pool.data_source, overdrive_pool.identifier.type, overdrive_pool.identifier.identifier, utc_now(), - utc_now() + datetime.timedelta(seconds=3600) + utc_now() + datetime.timedelta(seconds=3600), ) self.manager.d_circulation.add_remote_hold( - bibliotheca_pool.collection, bibliotheca_pool.data_source, + bibliotheca_pool.collection, + bibliotheca_pool.data_source, bibliotheca_pool.identifier.type, bibliotheca_pool.identifier.identifier, utc_now(), @@ -2516,10 +2639,11 @@ def handle_conditional_request(last_modified=None): # APIs. The resulting feed won't reflect what we know to be # the reality. with self.request_context_with_library( - "/", headers=dict(Authorization=self.valid_auth)): + "/", headers=dict(Authorization=self.valid_auth) + ): patron = self.manager.loans.authenticated_patron_from_request() response = self.manager.loans.sync() - assert '' not in response.get_data(as_text=True) + assert "" not in response.get_data(as_text=True) # patron.last_loan_activity_sync was not changed as the result # of this request, since we didn't go to the vendor APIs. @@ -2533,36 +2657,63 @@ def handle_conditional_request(last_modified=None): # LoanController actually goes out to the vendor APIs for new # information. with self.request_context_with_library( - "/", headers=dict(Authorization=self.valid_auth)): + "/", headers=dict(Authorization=self.valid_auth) + ): patron = self.manager.loans.authenticated_patron_from_request() response = self.manager.loans.sync() # This time, the feed contains entries. feed = feedparser.parse(response.data) - entries = feed['entries'] - - overdrive_entry = [entry for entry in entries if entry['title'] == overdrive_book.title][0] - bibliotheca_entry = [entry for entry in entries if entry['title'] == bibliotheca_book.title][0] - - assert overdrive_entry['opds_availability']['status'] == 'available' - assert bibliotheca_entry['opds_availability']['status'] == 'ready' - - overdrive_links = overdrive_entry['links'] - fulfill_link = [x for x in overdrive_links if x['rel'] == 'http://opds-spec.org/acquisition'][0]['href'] - revoke_link = [x for x in overdrive_links if x['rel'] == OPDSFeed.REVOKE_LOAN_REL][0]['href'] - bibliotheca_links = bibliotheca_entry['links'] - borrow_link = [x for x in bibliotheca_links if x['rel'] == 'http://opds-spec.org/acquisition/borrow'][0]['href'] - bibliotheca_revoke_links = [x for x in bibliotheca_links if x['rel'] == OPDSFeed.REVOKE_LOAN_REL] + entries = feed["entries"] + + overdrive_entry = [ + entry for entry in entries if entry["title"] == overdrive_book.title + ][0] + bibliotheca_entry = [ + entry for entry in entries if entry["title"] == bibliotheca_book.title + ][0] + + assert overdrive_entry["opds_availability"]["status"] == "available" + assert bibliotheca_entry["opds_availability"]["status"] == "ready" + + overdrive_links = overdrive_entry["links"] + fulfill_link = [ + x + for x in overdrive_links + if x["rel"] == "http://opds-spec.org/acquisition" + ][0]["href"] + revoke_link = [ + x for x in overdrive_links if x["rel"] == OPDSFeed.REVOKE_LOAN_REL + ][0]["href"] + bibliotheca_links = bibliotheca_entry["links"] + borrow_link = [ + x + for x in bibliotheca_links + if x["rel"] == "http://opds-spec.org/acquisition/borrow" + ][0]["href"] + bibliotheca_revoke_links = [ + x for x in bibliotheca_links if x["rel"] == OPDSFeed.REVOKE_LOAN_REL + ] assert urllib.parse.quote("%s/fulfill" % overdrive_pool.id) in fulfill_link assert urllib.parse.quote("%s/revoke" % overdrive_pool.id) in revoke_link - assert urllib.parse.quote("%s/%s/borrow" % (bibliotheca_pool.identifier.type, bibliotheca_pool.identifier.identifier)) in borrow_link + assert ( + urllib.parse.quote( + "%s/%s/borrow" + % ( + bibliotheca_pool.identifier.type, + bibliotheca_pool.identifier.identifier, + ) + ) + in borrow_link + ) assert 0 == len(bibliotheca_revoke_links) # Since we went out the the vendor APIs, # patron.last_loan_activity_sync was updated. assert patron.last_loan_activity_sync > new_sync_time + class TestAnnotationController(CirculationControllerTest): def setup_method(self): super(TestAnnotationController, self).setup_method() @@ -2572,30 +2723,32 @@ def setup_method(self): def test_get_empty_container(self): with self.request_context_with_library( - "/", headers=dict(Authorization=self.valid_auth)): + "/", headers=dict(Authorization=self.valid_auth) + ): self.manager.loans.authenticated_patron_from_request() response = self.manager.annotations.container() assert 200 == response.status_code # We've been given an annotation container with no items. container = json.loads(response.get_data(as_text=True)) - assert [] == container['first']['items'] - assert 0 == container['total'] + assert [] == container["first"]["items"] + assert 0 == container["total"] # The response has the appropriate headers. - allow_header = response.headers['Allow'] - for method in ['GET', 'HEAD', 'OPTIONS', 'POST']: + allow_header = response.headers["Allow"] + for method in ["GET", "HEAD", "OPTIONS", "POST"]: assert method in allow_header - assert AnnotationWriter.CONTENT_TYPE == response.headers['Accept-Post'] - assert AnnotationWriter.CONTENT_TYPE == response.headers['Content-Type'] - assert 'W/""' == response.headers['ETag'] + assert AnnotationWriter.CONTENT_TYPE == response.headers["Accept-Post"] + assert AnnotationWriter.CONTENT_TYPE == response.headers["Content-Type"] + assert 'W/""' == response.headers["ETag"] def test_get_container_with_item(self): self.pool.loan_to(self.default_patron) annotation, ignore = create( - self._db, Annotation, + self._db, + Annotation, patron=self.default_patron, identifier=self.identifier, motivation=Annotation.IDLING, @@ -2604,34 +2757,36 @@ def test_get_container_with_item(self): annotation.timestamp = utc_now() with self.request_context_with_library( - "/", headers=dict(Authorization=self.valid_auth)): + "/", headers=dict(Authorization=self.valid_auth) + ): self.manager.annotations.authenticated_patron_from_request() response = self.manager.annotations.container() assert 200 == response.status_code # We've been given an annotation container with one item. container = json.loads(response.get_data(as_text=True)) - assert 1 == container['total'] - item = container['first']['items'][0] - assert annotation.motivation == item['motivation'] + assert 1 == container["total"] + item = container["first"]["items"][0] + assert annotation.motivation == item["motivation"] # The response has the appropriate headers. - allow_header = response.headers['Allow'] - for method in ['GET', 'HEAD', 'OPTIONS', 'POST']: + allow_header = response.headers["Allow"] + for method in ["GET", "HEAD", "OPTIONS", "POST"]: assert method in allow_header - assert AnnotationWriter.CONTENT_TYPE == response.headers['Accept-Post'] - assert AnnotationWriter.CONTENT_TYPE == response.headers['Content-Type'] + assert AnnotationWriter.CONTENT_TYPE == response.headers["Accept-Post"] + assert AnnotationWriter.CONTENT_TYPE == response.headers["Content-Type"] expected_etag = 'W/"%s"' % annotation.timestamp - assert expected_etag == response.headers['ETag'] + assert expected_etag == response.headers["ETag"] expected_time = format_date_time(mktime(annotation.timestamp.timetuple())) - assert expected_time == response.headers['Last-Modified'] + assert expected_time == response.headers["Last-Modified"] def test_get_container_for_work(self): self.pool.loan_to(self.default_patron) annotation, ignore = create( - self._db, Annotation, + self._db, + Annotation, patron=self.default_patron, identifier=self.identifier, motivation=Annotation.IDLING, @@ -2640,56 +2795,71 @@ def test_get_container_for_work(self): annotation.timestamp = utc_now() other_annotation, ignore = create( - self._db, Annotation, + self._db, + Annotation, patron=self.default_patron, identifier=self._identifier(), motivation=Annotation.IDLING, ) with self.request_context_with_library( - "/", headers=dict(Authorization=self.valid_auth)): + "/", headers=dict(Authorization=self.valid_auth) + ): self.manager.annotations.authenticated_patron_from_request() - response = self.manager.annotations.container_for_work(self.identifier.type, self.identifier.identifier) + response = self.manager.annotations.container_for_work( + self.identifier.type, self.identifier.identifier + ) assert 200 == response.status_code # We've been given an annotation container with one item. container = json.loads(response.get_data(as_text=True)) - assert 1 == container['total'] - item = container['first']['items'][0] - assert annotation.motivation == item['motivation'] + assert 1 == container["total"] + item = container["first"]["items"][0] + assert annotation.motivation == item["motivation"] # The response has the appropriate headers - POST is not allowed. - allow_header = response.headers['Allow'] - for method in ['GET', 'HEAD', 'OPTIONS']: + allow_header = response.headers["Allow"] + for method in ["GET", "HEAD", "OPTIONS"]: assert method in allow_header - assert 'Accept-Post' not in list(response.headers.keys()) - assert AnnotationWriter.CONTENT_TYPE == response.headers['Content-Type'] + assert "Accept-Post" not in list(response.headers.keys()) + assert AnnotationWriter.CONTENT_TYPE == response.headers["Content-Type"] expected_etag = 'W/"%s"' % annotation.timestamp - assert expected_etag == response.headers['ETag'] + assert expected_etag == response.headers["ETag"] expected_time = format_date_time(mktime(annotation.timestamp.timetuple())) - assert expected_time == response.headers['Last-Modified'] + assert expected_time == response.headers["Last-Modified"] def test_post_to_container(self): data = dict() - data['@context'] = AnnotationWriter.JSONLD_CONTEXT - data['type'] = "Annotation" - data['motivation'] = Annotation.IDLING - data['target'] = dict(source=self.identifier.urn, selector="epubcfi(/6/4[chap01ref]!/4[body01]/10[para05]/3:10)") + data["@context"] = AnnotationWriter.JSONLD_CONTEXT + data["type"] = "Annotation" + data["motivation"] = Annotation.IDLING + data["target"] = dict( + source=self.identifier.urn, + selector="epubcfi(/6/4[chap01ref]!/4[body01]/10[para05]/3:10)", + ) with self.request_context_with_library( - "/", headers=dict(Authorization=self.valid_auth), method='POST', data=json.dumps(data)): + "/", + headers=dict(Authorization=self.valid_auth), + method="POST", + data=json.dumps(data), + ): patron = self.manager.annotations.authenticated_patron_from_request() patron.synchronize_annotations = True # The patron doesn't have any annotations yet. - annotations = self._db.query(Annotation).filter(Annotation.patron==patron).all() + annotations = ( + self._db.query(Annotation).filter(Annotation.patron == patron).all() + ) assert 0 == len(annotations) response = self.manager.annotations.container() # The patron doesn't have the pool on loan yet, so the request fails. assert 400 == response.status_code - annotations = self._db.query(Annotation).filter(Annotation.patron==patron).all() + annotations = ( + self._db.query(Annotation).filter(Annotation.patron == patron).all() + ) assert 0 == len(annotations) # Give the patron a loan and try again, and the request creates an annotation. @@ -2697,23 +2867,30 @@ def test_post_to_container(self): response = self.manager.annotations.container() assert 200 == response.status_code - annotations = self._db.query(Annotation).filter(Annotation.patron==patron).all() + annotations = ( + self._db.query(Annotation).filter(Annotation.patron == patron).all() + ) assert 1 == len(annotations) annotation = annotations[0] assert Annotation.IDLING == annotation.motivation - selector = json.loads(annotation.target).get("http://www.w3.org/ns/oa#hasSelector")[0].get('@id') - assert data['target']['selector'] == selector + selector = ( + json.loads(annotation.target) + .get("http://www.w3.org/ns/oa#hasSelector")[0] + .get("@id") + ) + assert data["target"]["selector"] == selector # The response contains the annotation in the db. item = json.loads(response.get_data(as_text=True)) - assert str(annotation.id) in item['id'] - assert annotation.motivation == item['motivation'] + assert str(annotation.id) in item["id"] + assert annotation.motivation == item["motivation"] def test_detail(self): self.pool.loan_to(self.default_patron) annotation, ignore = create( - self._db, Annotation, + self._db, + Annotation, patron=self.default_patron, identifier=self.identifier, motivation=Annotation.IDLING, @@ -2721,29 +2898,31 @@ def test_detail(self): annotation.active = True with self.request_context_with_library( - "/", headers=dict(Authorization=self.valid_auth)): + "/", headers=dict(Authorization=self.valid_auth) + ): self.manager.annotations.authenticated_patron_from_request() response = self.manager.annotations.detail(annotation.id) assert 200 == response.status_code # We've been given a single annotation item. item = json.loads(response.get_data(as_text=True)) - assert str(annotation.id) in item['id'] - assert annotation.motivation == item['motivation'] + assert str(annotation.id) in item["id"] + assert annotation.motivation == item["motivation"] # The response has the appropriate headers. - allow_header = response.headers['Allow'] - for method in ['GET', 'HEAD', 'OPTIONS', 'DELETE']: + allow_header = response.headers["Allow"] + for method in ["GET", "HEAD", "OPTIONS", "DELETE"]: assert method in allow_header - assert AnnotationWriter.CONTENT_TYPE == response.headers['Content-Type'] + assert AnnotationWriter.CONTENT_TYPE == response.headers["Content-Type"] def test_detail_for_other_patrons_annotation_returns_404(self): patron = self._patron() self.pool.loan_to(patron) annotation, ignore = create( - self._db, Annotation, + self._db, + Annotation, patron=patron, identifier=self.identifier, motivation=Annotation.IDLING, @@ -2751,7 +2930,8 @@ def test_detail_for_other_patrons_annotation_returns_404(self): annotation.active = True with self.request_context_with_library( - "/", headers=dict(Authorization=self.valid_auth)): + "/", headers=dict(Authorization=self.valid_auth) + ): self.manager.annotations.authenticated_patron_from_request() # The patron can't see that this annotation exists. @@ -2760,7 +2940,8 @@ def test_detail_for_other_patrons_annotation_returns_404(self): def test_detail_for_missing_annotation_returns_404(self): with self.request_context_with_library( - "/", headers=dict(Authorization=self.valid_auth)): + "/", headers=dict(Authorization=self.valid_auth) + ): self.manager.annotations.authenticated_patron_from_request() # This annotation does not exist. @@ -2771,7 +2952,8 @@ def test_detail_for_deleted_annotation_returns_404(self): self.pool.loan_to(self.default_patron) annotation, ignore = create( - self._db, Annotation, + self._db, + Annotation, patron=self.default_patron, identifier=self.identifier, motivation=Annotation.IDLING, @@ -2779,7 +2961,8 @@ def test_detail_for_deleted_annotation_returns_404(self): annotation.active = False with self.request_context_with_library( - "/", headers=dict(Authorization=self.valid_auth)): + "/", headers=dict(Authorization=self.valid_auth) + ): self.manager.annotations.authenticated_patron_from_request() response = self.manager.annotations.detail(annotation.id) assert 404 == response.status_code @@ -2788,7 +2971,8 @@ def test_delete(self): self.pool.loan_to(self.default_patron) annotation, ignore = create( - self._db, Annotation, + self._db, + Annotation, patron=self.default_patron, identifier=self.identifier, motivation=Annotation.IDLING, @@ -2796,7 +2980,8 @@ def test_delete(self): annotation.active = True with self.request_context_with_library( - "/", method='DELETE', headers=dict(Authorization=self.valid_auth)): + "/", method="DELETE", headers=dict(Authorization=self.valid_auth) + ): self.manager.annotations.authenticated_patron_from_request() response = self.manager.annotations.detail(annotation.id) assert 200 == response.status_code @@ -2827,16 +3012,16 @@ def test_contributor(self): contributor.sort_name = None # No contributor name -> ProblemDetail - with self.request_context_with_library('/'): - response = m('', None, None) + with self.request_context_with_library("/"): + response = m("", None, None) assert 404 == response.status_code assert NO_SUCH_LANE.uri == response.uri assert "No contributor provided" == response.detail # Unable to load ContributorData from contributor name -> # ProblemDetail - with self.request_context_with_library('/'): - response = m('Unknown Author', None, None) + with self.request_context_with_library("/"): + response = m("Unknown Author", None, None) assert 404 == response.status_code assert NO_SUCH_LANE.uri == response.uri assert "Unknown contributor: Unknown Author" == response.detail @@ -2845,19 +3030,17 @@ def test_contributor(self): # Search index misconfiguration -> Problem detail self.assert_bad_search_index_gives_problem_detail( - lambda: self.manager.work_controller.series( - contributor, None, None - ) + lambda: self.manager.work_controller.series(contributor, None, None) ) # Bad facet data -> ProblemDetail - with self.request_context_with_library('/?order=nosuchorder'): + with self.request_context_with_library("/?order=nosuchorder"): response = m(contributor, None, None) assert 400 == response.status_code assert INVALID_INPUT.uri == response.uri # Bad pagination data -> ProblemDetail - with self.request_context_with_library('/?size=abc'): + with self.request_context_with_library("/?size=abc"): response = m(contributor, None, None) assert 400 == response.status_code assert INVALID_INPUT.uri == response.uri @@ -2865,32 +3048,31 @@ def test_contributor(self): # Test an end-to-end success (not including a test that the # search engine can actually find books by a given person -- # that's tested in core/tests/test_external_search.py). - with self.request_context_with_library('/'): - response = m(contributor, 'eng,spa', 'Children,Young Adult') + with self.request_context_with_library("/"): + response = m(contributor, "eng,spa", "Children,Young Adult") assert 200 == response.status_code - assert OPDSFeed.ACQUISITION_FEED_TYPE == response.headers['Content-Type'] + assert OPDSFeed.ACQUISITION_FEED_TYPE == response.headers["Content-Type"] feed = feedparser.parse(response.data) # The feed is named after the person we looked up. - assert contributor == feed['feed']['title'] + assert contributor == feed["feed"]["title"] # It's got one entry -- the book added to the search engine # during test setup. - [entry] = feed['entries'] - assert self.english_1.title == entry['title'] + [entry] = feed["entries"] + assert self.english_1.title == entry["title"] # The feed has facet links. - links = feed['feed']['links'] - facet_links = [link for link in links - if link['rel'] == 'http://opds-spec.org/facet'] + links = feed["feed"]["links"] + facet_links = [ + link for link in links if link["rel"] == "http://opds-spec.org/facet" + ] assert 8 == len(facet_links) # The feed was cached. cached = self._db.query(CachedFeed).one() assert CachedFeed.CONTRIBUTOR_TYPE == cached.type - assert ( - 'John Bull-eng,spa-Children,Young+Adult' == - cached.unique_key) + assert "John Bull-eng,spa-Children,Young+Adult" == cached.unique_key # At this point we don't want to generate real feeds anymore. # We can't do a real end-to-end test without setting up a real @@ -2922,9 +3104,7 @@ def page(cls, **kwargs): audiences = "some audiences" sort_key = ["sort", "pagination", "key"] with self.request_context_with_library( - "/?order=title&size=100&key=%s&entrypoint=Audio" % ( - json.dumps(sort_key) - ) + "/?order=title&size=100&key=%s&entrypoint=Audio" % (json.dumps(sort_key)) ): response = m(contributor, languages, audiences, feed_class=Mock) @@ -2937,26 +3117,26 @@ def page(cls, **kwargs): # page(). kwargs = self.called_with - assert self._db == kwargs.pop('_db') - assert self.manager._external_search == kwargs.pop('search_engine') + assert self._db == kwargs.pop("_db") + assert self.manager._external_search == kwargs.pop("search_engine") # The feed is named after the contributor the request asked # about. - assert contributor == kwargs.pop('title') + assert contributor == kwargs.pop("title") # Query string arguments were taken into account when # creating the Facets and Pagination objects. - facets = kwargs.pop('facets') + facets = kwargs.pop("facets") assert isinstance(facets, ContributorFacets) assert AudiobooksEntryPoint == facets.entrypoint - assert 'title' == facets.order + assert "title" == facets.order - pagination = kwargs.pop('pagination') + pagination = kwargs.pop("pagination") assert isinstance(pagination, SortKeyPagination) assert sort_key == pagination.last_item_on_previous_page assert 100 == pagination.size - lane = kwargs.pop('worklist') + lane = kwargs.pop("worklist") assert isinstance(lane, ContributorLane) assert isinstance(lane.contributor, ContributorData) @@ -2977,15 +3157,16 @@ def page(cls, **kwargs): url_kwargs.update(dict(list(pagination.items()))) with self.request_context_with_library(""): expect_url = self.manager.opds_feeds.url_for( - route, lane_identifier=None, + route, + lane_identifier=None, library_short_name=library.short_name, **url_kwargs ) - assert kwargs.pop('url') == expect_url + assert kwargs.pop("url") == expect_url # The Annotator object was instantiated with the proper lane # and the newly created Facets object. - annotator = kwargs.pop('annotator') + annotator = kwargs.pop("annotator") assert lane == annotator.lane assert facets == annotator.facets @@ -3017,17 +3198,16 @@ def test_age_appropriateness_end_to_end(self): # author, we're denied access -- the authenticated # patron's root lane would make any adult books # age-inappropriate. - audiences = ",".join([ - Classifier.AUDIENCE_ADULT, Classifier.AUDIENCE_CHILDREN - ]) + audiences = ",".join( + [Classifier.AUDIENCE_ADULT, Classifier.AUDIENCE_CHILDREN] + ) response = m(contributor.sort_name, "eng", audiences) assert isinstance(response, ProblemDetail) assert NO_SUCH_LANE.uri == response.uri # If we only ask for children's books by the same author, # we're fine. - response = m(contributor.sort_name, "eng", - Classifier.AUDIENCE_CHILDREN) + response = m(contributor.sort_name, "eng", Classifier.AUDIENCE_CHILDREN) assert 200 == response.status_code # We're also fine if we don't authenticate the request at all. @@ -3037,7 +3217,9 @@ def test_age_appropriateness_end_to_end(self): def test_permalink(self): with self.request_context_with_library("/"): - response = self.manager.work_controller.permalink(self.identifier.type, self.identifier.identifier) + response = self.manager.work_controller.permalink( + self.identifier.type, self.identifier.identifier + ) annotator = LibraryAnnotator(None, None, self._default_library) expect = AcquisitionFeed.single_entry( self._db, self.english_1, annotator @@ -3045,9 +3227,11 @@ def test_permalink(self): assert 200 == response.status_code assert expect == response.get_data() - assert OPDSFeed.ENTRY_TYPE == response.headers['Content-Type'] + assert OPDSFeed.ENTRY_TYPE == response.headers["Content-Type"] - def test_permalink_does_not_return_fulfillment_links_for_authenticated_patrons_without_loans(self): + def test_permalink_does_not_return_fulfillment_links_for_authenticated_patrons_without_loans( + self, + ): with self.request_context_with_library("/"): # We have two patrons. patron_1 = self._patron() @@ -3057,17 +3241,15 @@ def test_permalink_does_not_return_fulfillment_links_for_authenticated_patrons_w flask.request.patron = patron_1 identifier_type = Identifier.GUTENBERG_ID - identifier = '1234567890' + identifier = "1234567890" edition, _ = self._edition( - title='Test Book', + title="Test Book", identifier_type=identifier_type, identifier_id=identifier, - with_license_pool=True + with_license_pool=True, ) work = self._work( - 'Test Book', - presentation_edition=edition, - with_license_pool=True + "Test Book", presentation_edition=edition, with_license_pool=True ) pool = work.license_pools[0] @@ -3080,19 +3262,21 @@ def test_permalink_does_not_return_fulfillment_links_for_authenticated_patrons_w None, None, self._default_library, - active_loans_by_work=active_loans_by_work + active_loans_by_work=active_loans_by_work, ) - expect = AcquisitionFeed.single_entry( - self._db, work, annotator - ).data + expect = AcquisitionFeed.single_entry(self._db, work, annotator).data - response = self.manager.work_controller.permalink(identifier_type, identifier) + response = self.manager.work_controller.permalink( + identifier_type, identifier + ) assert 200 == response.status_code assert expect == response.get_data() - assert OPDSFeed.ENTRY_TYPE == response.headers['Content-Type'] + assert OPDSFeed.ENTRY_TYPE == response.headers["Content-Type"] - def test_permalink_returns_fulfillment_links_for_authenticated_patrons_with_loans(self): + def test_permalink_returns_fulfillment_links_for_authenticated_patrons_with_loans( + self, + ): with self.request_context_with_library("/"): # We have two patrons. patron_1 = self._patron() @@ -3102,17 +3286,15 @@ def test_permalink_returns_fulfillment_links_for_authenticated_patrons_with_loan flask.request.patron = patron_1 identifier_type = Identifier.GUTENBERG_ID - identifier = '1234567890' + identifier = "1234567890" edition, _ = self._edition( - title='Test Book', + title="Test Book", identifier_type=identifier_type, identifier_id=identifier, - with_license_pool=True + with_license_pool=True, ) work = self._work( - 'Test Book', - presentation_edition=edition, - with_license_pool=True + "Test Book", presentation_edition=edition, with_license_pool=True ) pool = work.license_pools[0] @@ -3121,30 +3303,30 @@ def test_permalink_returns_fulfillment_links_for_authenticated_patrons_with_loan patron2_loan, _ = pool.loan_to(patron_2) # We want to make sure that only the first patron's loan will be in the feed. - active_loans_by_work = { - work: patron1_loan - } + active_loans_by_work = {work: patron1_loan} annotator = LibraryAnnotator( None, None, self._default_library, - active_loans_by_work=active_loans_by_work + active_loans_by_work=active_loans_by_work, ) - expect = AcquisitionFeed.single_entry( - self._db, work, annotator - ).data + expect = AcquisitionFeed.single_entry(self._db, work, annotator).data - response = self.manager.work_controller.permalink(identifier_type, identifier) + response = self.manager.work_controller.permalink( + identifier_type, identifier + ) assert 200 == response.status_code assert expect == response.get_data() - assert OPDSFeed.ENTRY_TYPE == response.headers['Content-Type'] + assert OPDSFeed.ENTRY_TYPE == response.headers["Content-Type"] - def test_permalink_returns_fulfillment_links_for_authenticated_patrons_with_fulfillment(self): + def test_permalink_returns_fulfillment_links_for_authenticated_patrons_with_fulfillment( + self, + ): auth = dict(Authorization=self.valid_auth) with self.request_context_with_library("/", headers=auth): - content_link = 'https://content' + content_link = "https://content" # We have two patrons. patron_1 = self.controller.authenticated_patron(self.valid_credentials) @@ -3154,32 +3336,28 @@ def test_permalink_returns_fulfillment_links_for_authenticated_patrons_with_fulf flask.request.patron = patron_1 identifier_type = Identifier.GUTENBERG_ID - identifier = '1234567890' + identifier = "1234567890" edition, _ = self._edition( - title='Test Book', + title="Test Book", identifier_type=identifier_type, identifier_id=identifier, - with_license_pool=True + with_license_pool=True, ) work = self._work( - 'Test Book', - presentation_edition=edition, - with_license_pool=True + "Test Book", presentation_edition=edition, with_license_pool=True ) pool = work.license_pools[0] [delivery_mechanism] = pool.delivery_mechanisms loan_info = LoanInfo( - pool.collection, pool.data_source.name, + pool.collection, + pool.data_source.name, pool.identifier.type, pool.identifier.identifier, utc_now(), utc_now() + datetime.timedelta(seconds=3600), ) - self.manager.d_circulation.queue_checkout( - pool, - loan_info - ) + self.manager.d_circulation.queue_checkout(pool, loan_info) fulfillment = FulfillmentInfo( pool.collection, @@ -3189,7 +3367,7 @@ def test_permalink_returns_fulfillment_links_for_authenticated_patrons_with_fulf content_link=content_link, content_type=MediaTypes.EPUB_MEDIA_TYPE, content=None, - content_expires=None + content_expires=None, ) self.manager.d_circulation.queue_fulfill(pool, fulfillment) @@ -3197,10 +3375,13 @@ def test_permalink_returns_fulfillment_links_for_authenticated_patrons_with_fulf # - the first patron's loan and fulfillment will be created via API. # - the second patron's loan will be created via loan_to method. self.manager.loans.borrow( - pool.identifier.type, pool.identifier.identifier, delivery_mechanism.delivery_mechanism.id + pool.identifier.type, + pool.identifier.identifier, + delivery_mechanism.delivery_mechanism.id, ) self.manager.loans.fulfill( - pool.id, delivery_mechanism.delivery_mechanism.id, + pool.id, + delivery_mechanism.delivery_mechanism.id, ) patron1_loan = pool.loans[0] @@ -3211,24 +3392,22 @@ def test_permalink_returns_fulfillment_links_for_authenticated_patrons_with_fulf patron2_loan, _ = pool.loan_to(patron_2) # We want to make sure that only the first patron's fulfillment will be in the feed. - active_loans_by_work = { - work: patron1_loan - } + active_loans_by_work = {work: patron1_loan} annotator = LibraryAnnotator( None, None, self._default_library, active_loans_by_work=active_loans_by_work, ) - expect = AcquisitionFeed.single_entry( - self._db, work, annotator - ).data + expect = AcquisitionFeed.single_entry(self._db, work, annotator).data - response = self.manager.work_controller.permalink(identifier_type, identifier) + response = self.manager.work_controller.permalink( + identifier_type, identifier + ) assert 200 == response.status_code assert expect == response.get_data() - assert OPDSFeed.ENTRY_TYPE == response.headers['Content-Type'] + assert OPDSFeed.ENTRY_TYPE == response.headers["Content-Type"] def test_recommendations(self): # Test the ability to get a feed of works recommended by an @@ -3243,33 +3422,26 @@ def test_recommendations(self): metadata = Metadata(source) mock_api = MockNoveListAPI(self._db) - args = [self.identifier.type, - self.identifier.identifier] + args = [self.identifier.type, self.identifier.identifier] kwargs = dict(novelist_api=mock_api) # We get a 400 response if the pagination data is bad. - with self.request_context_with_library('/?size=abc'): - response = self.manager.work_controller.recommendations( - *args, **kwargs - ) + with self.request_context_with_library("/?size=abc"): + response = self.manager.work_controller.recommendations(*args, **kwargs) assert 400 == response.status_code # Or if the facet data is bad. - with self.request_context_with_library('/?order=nosuchorder'): - response = self.manager.work_controller.recommendations( - *args, **kwargs - ) + with self.request_context_with_library("/?order=nosuchorder"): + response = self.manager.work_controller.recommendations(*args, **kwargs) assert 400 == response.status_code # Or if the search index is misconfigured. self.assert_bad_search_index_gives_problem_detail( - lambda: self.manager.work_controller.recommendations( - *args, **kwargs - ) + lambda: self.manager.work_controller.recommendations(*args, **kwargs) ) # If no NoveList API is configured, the lane does not exist. - with self.request_context_with_library('/'): + with self.request_context_with_library("/"): response = self.manager.work_controller.recommendations( *args, novelist_api=None ) @@ -3285,18 +3457,16 @@ def test_recommendations(self): # created with .return_nothing set, but our mock # ExternalSearchIndex will ignore that setting and return # everything in its index -- as it always does. - with self.request_context_with_library('/'): - response = self.manager.work_controller.recommendations( - *args, **kwargs - ) + with self.request_context_with_library("/"): + response = self.manager.work_controller.recommendations(*args, **kwargs) # A feed is returned with the data from the # ExternalSearchIndex. assert 200 == response.status_code feed = feedparser.parse(response.data) - assert 'Titles recommended by NoveList' == feed['feed']['title'] + assert "Titles recommended by NoveList" == feed["feed"]["title"] [entry] = feed.entries - assert self.english_1.title == entry['title'] + assert self.english_1.title == entry["title"] author = self.edition.author_contributors[0] expected_author_name = author.display_name or author.sort_name assert expected_author_name == entry.author @@ -3309,13 +3479,11 @@ def page(cls, **kwargs): cls.called_with = kwargs return Response("A bunch of titles") - kwargs['feed_class'] = Mock + kwargs["feed_class"] = Mock with self.request_context_with_library( - '/?order=title&size=2&after=30&entrypoint=Audio' + "/?order=title&size=2&after=30&entrypoint=Audio" ): - response = self.manager.work_controller.recommendations( - *args, **kwargs - ) + response = self.manager.work_controller.recommendations(*args, **kwargs) # The return value of Mock.page was used as the response # to the incoming request. @@ -3323,29 +3491,29 @@ def page(cls, **kwargs): assert "A bunch of titles" == response.get_data(as_text=True) kwargs = Mock.called_with - assert self._db == kwargs.pop('_db') - assert 'Titles recommended by NoveList' == kwargs.pop('title') + assert self._db == kwargs.pop("_db") + assert "Titles recommended by NoveList" == kwargs.pop("title") # The RecommendationLane is set up to ask for recommendations # for this book. - lane = kwargs.pop('worklist') + lane = kwargs.pop("worklist") assert isinstance(lane, RecommendationLane) library = self._default_library assert library.id == lane.library_id assert self.english_1 == lane.work - assert 'Recommendations for Quite British by John Bull' == lane.display_name + assert "Recommendations for Quite British by John Bull" == lane.display_name assert mock_api == lane.novelist_api - facets = kwargs.pop('facets') + facets = kwargs.pop("facets") assert isinstance(facets, Facets) assert Facets.ORDER_TITLE == facets.order assert AudiobooksEntryPoint == facets.entrypoint - pagination = kwargs.pop('pagination') + pagination = kwargs.pop("pagination") assert 30 == pagination.offset assert 2 == pagination.size - annotator = kwargs.pop('annotator') + annotator = kwargs.pop("annotator") assert lane == annotator.lane # Checking the URL is difficult because it requires a request @@ -3356,10 +3524,9 @@ def page(cls, **kwargs): url_kwargs.update(dict(list(pagination.items()))) with self.request_context_with_library(""): expect_url = self.manager.work_controller.url_for( - route, library_short_name=library.short_name, - **url_kwargs + route, library_short_name=library.short_name, **url_kwargs ) - assert kwargs.pop('url') == expect_url + assert kwargs.pop("url") == expect_url def test_related_books(self): # Test the related_books controller. @@ -3380,12 +3547,16 @@ def test_related_books(self): # integration is configured. The 'related books' lane ends up # with no sublanes, so the controller acts as if the lane # itself does not exist. - with self.request_context_with_library('/'): + with self.request_context_with_library("/"): response = self.manager.work_controller.related( - identifier.type, identifier.identifier, + identifier.type, + identifier.identifier, ) assert 404 == response.status_code - assert "http://librarysimplified.org/terms/problem/unknown-lane" == response.uri + assert ( + "http://librarysimplified.org/terms/problem/unknown-lane" + == response.uri + ) # Now test some error cases where the lane exists but # something else goes wrong. @@ -3439,51 +3610,53 @@ def test_related_books(self): # Now, ask for works related to self.english_1. with mock_search_index(self.manager.external_search): - with self.request_context_with_library('/?entrypoint=Book'): + with self.request_context_with_library("/?entrypoint=Book"): response = self.manager.work_controller.related( - self.identifier.type, self.identifier.identifier, - novelist_api=mock_api + self.identifier.type, + self.identifier.identifier, + novelist_api=mock_api, ) assert 200 == response.status_code - assert OPDSFeed.ACQUISITION_FEED_TYPE == response.headers['content-type'] + assert OPDSFeed.ACQUISITION_FEED_TYPE == response.headers["content-type"] feed = feedparser.parse(response.data) - assert "Related Books" == feed['feed']['title'] + assert "Related Books" == feed["feed"]["title"] # The feed contains three entries: one for each sublane. - assert 3 == len(feed['entries']) + assert 3 == len(feed["entries"]) # Group the entries by the sublane they're in. def collection_link(entry): - [link] = [l for l in entry['links'] if l['rel']=='collection'] - return link['title'], link['href'] + [link] = [l for l in entry["links"] if l["rel"] == "collection"] + return link["title"], link["href"] + by_collection_link = {} - for entry in feed['entries']: + for entry in feed["entries"]: title, href = collection_link(entry) by_collection_link[title] = (href, entry) # Here's the sublane for books in the same series. - [same_series_href, same_series_entry] = by_collection_link[ - 'Around the World' - ] - assert "Same author and series" == same_series_entry['title'] - expected_series_link = 'series/%s/eng/Adult' % urllib.parse.quote("Around the World") + [same_series_href, same_series_entry] = by_collection_link["Around the World"] + assert "Same author and series" == same_series_entry["title"] + expected_series_link = "series/%s/eng/Adult" % urllib.parse.quote( + "Around the World" + ) assert same_series_href.endswith(expected_series_link) # Here's the sublane for books by this contributor. [same_contributor_href, same_contributor_entry] = by_collection_link[ - 'John Bull' + "John Bull" ] - assert "Same author and series" == same_contributor_entry['title'] - expected_contributor_link = urllib.parse.quote('contributor/John Bull/eng/') + assert "Same author and series" == same_contributor_entry["title"] + expected_contributor_link = urllib.parse.quote("contributor/John Bull/eng/") assert same_contributor_href.endswith(expected_contributor_link) # Here's the sublane for recommendations from NoveList. [recommended_href, recommended_entry] = by_collection_link[ - 'Similar titles recommended by NoveList' + "Similar titles recommended by NoveList" ] - assert "Same author and series" == recommended_entry['title'] + assert "Same author and series" == recommended_entry["title"] work_url = "/works/%s/%s/" % (identifier.type, identifier.identifier) - expected = urllib.parse.quote(work_url + 'recommendations') + expected = urllib.parse.quote(work_url + "recommendations") assert True == recommended_href.endswith(expected) # Finally, let's pass in a mock feed class so we can look at the @@ -3495,10 +3668,12 @@ def groups(cls, **kwargs): return Response("An OPDS feed") mock_api.setup_method(metadata) - with self.request_context_with_library('/?entrypoint=Audio'): + with self.request_context_with_library("/?entrypoint=Audio"): response = self.manager.work_controller.related( - self.identifier.type, self.identifier.identifier, - novelist_api=mock_api, feed_class=Mock + self.identifier.type, + self.identifier.identifier, + novelist_api=mock_api, + feed_class=Mock, ) # The return value of Mock.groups was used as the response @@ -3508,19 +3683,19 @@ def groups(cls, **kwargs): # Verify that groups() was called with the arguments we expect. kwargs = Mock.called_with - assert self._db == kwargs.pop('_db') - assert self.manager.external_search == kwargs.pop('search_engine') - assert "Related Books" == kwargs.pop('title') + assert self._db == kwargs.pop("_db") + assert self.manager.external_search == kwargs.pop("search_engine") + assert "Related Books" == kwargs.pop("title") # We're passing in a FeaturedFacets. Each lane will have a chance # to adapt it to a faceting object appropriate for that lane. - facets = kwargs.pop('facets') + facets = kwargs.pop("facets") assert isinstance(facets, FeaturedFacets) assert AudiobooksEntryPoint == facets.entrypoint # We're generating a grouped feed using a RelatedBooksLane # that has three sublanes. - lane = kwargs.pop('worklist') + lane = kwargs.pop("worklist") assert isinstance(lane, RelatedBooksLane) contributor_lane, novelist_lane, series_lane = lane.children @@ -3534,7 +3709,7 @@ def groups(cls, **kwargs): assert "Around the World" == series_lane.series # The Annotator is associated with the parent RelatedBooksLane. - annotator = kwargs.pop('annotator') + annotator = kwargs.pop("annotator") assert isinstance(annotator, LibraryAnnotator) assert self._default_library == annotator.library assert lane == annotator.lane @@ -3547,31 +3722,33 @@ def groups(cls, **kwargs): url_kwargs.update(dict(list(facets.items()))) with self.request_context_with_library(""): expect_url = self.manager.work_controller.url_for( - route, lane_identifier=None, + route, + lane_identifier=None, library_short_name=library.short_name, **url_kwargs ) - assert kwargs.pop('url') == expect_url + assert kwargs.pop("url") == expect_url # That's it! assert {} == kwargs def test_report_problem_get(self): with self.request_context_with_library("/"): - response = self.manager.work_controller.report(self.identifier.type, self.identifier.identifier) + response = self.manager.work_controller.report( + self.identifier.type, self.identifier.identifier + ) assert 200 == response.status_code - assert "text/uri-list" == response.headers['Content-Type'] + assert "text/uri-list" == response.headers["Content-Type"] for i in Complaint.VALID_TYPES: assert i in response.get_data(as_text=True) def test_report_problem_post_success(self): error_type = random.choice(list(Complaint.VALID_TYPES)) - data = json.dumps({ "type": error_type, - "source": "foo", - "detail": "bar"} - ) + data = json.dumps({"type": error_type, "source": "foo", "detail": "bar"}) with self.request_context_with_library("/", method="POST", data=data): - response = self.manager.work_controller.report(self.identifier.type, self.identifier.identifier) + response = self.manager.work_controller.report( + self.identifier.type, self.identifier.identifier + ) assert 201 == response.status_code [complaint] = self.lp.complaints assert error_type == complaint.type @@ -3585,18 +3762,18 @@ def test_series(self): series_name = "Like As If Whatever Mysteries" # If no series is given, a ProblemDetail is returned. - with self.request_context_with_library('/'): + with self.request_context_with_library("/"): response = self.manager.work_controller.series("", None, None) assert 404 == response.status_code assert "http://librarysimplified.org/terms/problem/unknown-lane" == response.uri # Similarly if the pagination data is bad. - with self.request_context_with_library('/?size=abc'): + with self.request_context_with_library("/?size=abc"): response = self.manager.work_controller.series(series_name, None, None) assert 400 == response.status_code # Or if the facet data is bad - with self.request_context_with_library('/?order=nosuchorder'): + with self.request_context_with_library("/?order=nosuchorder"): response = self.manager.work_controller.series(series_name, None, None) assert 400 == response.status_code @@ -3615,9 +3792,11 @@ def test_series(self): search_engine.bulk_update([work]) # If a series is provided, a feed for that series is returned. - with self.request_context_with_library('/'): + with self.request_context_with_library("/"): response = self.manager.work_controller.series( - series_name, "eng,spa", "Children,Young Adult", + series_name, + "eng,spa", + "Children,Young Adult", ) assert 200 == response.status_code feed = feedparser.parse(response.data) @@ -3625,30 +3804,30 @@ def test_series(self): # The book we added to the mock search engine is in the feed. # This demonstrates that series() asks the search engine for # books to put in the feed. - assert series_name == feed['feed']['title'] - [entry] = feed['entries'] - assert work.title == entry['title'] + assert series_name == feed["feed"]["title"] + [entry] = feed["entries"] + assert work.title == entry["title"] # The feed has facet links. - links = feed['feed']['links'] - facet_links = [link for link in links - if link['rel'] == 'http://opds-spec.org/facet'] + links = feed["feed"]["links"] + facet_links = [ + link for link in links if link["rel"] == "http://opds-spec.org/facet" + ] assert 9 == len(facet_links) # The facet link we care most about is the default sort order, # put into place by SeriesFacets. - [series_position] = [ - x for x in facet_links if x['title'] == 'Series Position' - ] - assert 'Sort by' == series_position['opds:facetgroup'] - assert 'true' == series_position['opds:activefacet'] + [series_position] = [x for x in facet_links if x["title"] == "Series Position"] + assert "Sort by" == series_position["opds:facetgroup"] + assert "true" == series_position["opds:activefacet"] # The feed was cached. cached = self._db.query(CachedFeed).one() assert CachedFeed.SERIES_TYPE == cached.type assert ( - 'Like As If Whatever Mysteries-eng,spa-Children,Young+Adult' == - cached.unique_key) + "Like As If Whatever Mysteries-eng,spa-Children,Young+Adult" + == cached.unique_key + ) # At this point we don't want to generate real feeds anymore. # We can't do a real end-to-end test without setting up a real @@ -3680,8 +3859,7 @@ def page(cls, **kwargs): "/?order=title&size=100&key=%s" % json.dumps(sort_key) ): response = self.manager.work_controller.series( - series_name, "some languages", "some audiences", - feed_class=Mock + series_name, "some languages", "some audiences", feed_class=Mock ) # The return value of Mock.page() is the response to the @@ -3690,14 +3868,14 @@ def page(cls, **kwargs): assert "An OPDS feed" == response.get_data(as_text=True) kwargs = self.called_with - assert self._db == kwargs.pop('_db') + assert self._db == kwargs.pop("_db") # The feed is titled after the series. - assert series_name == kwargs.pop('title') + assert series_name == kwargs.pop("title") # A SeriesLane was created to ask the search index for # matching works. - lane = kwargs.pop('worklist') + lane = kwargs.pop("worklist") assert isinstance(lane, SeriesLane) assert self._default_library.id == lane.library_id assert series_name == lane.series @@ -3707,7 +3885,7 @@ def page(cls, **kwargs): # A SeriesFacets was created to add an extra sort order and # to provide additional search index constraints that can only # be provided through the faceting object. - facets = kwargs.pop('facets') + facets = kwargs.pop("facets") assert isinstance(facets, SeriesFacets) # The 'order' in the query string went into the SeriesFacets @@ -3715,23 +3893,23 @@ def page(cls, **kwargs): assert "title" == facets.order # The 'key' and 'size' went into a SortKeyPagination object. - pagination = kwargs.pop('pagination') + pagination = kwargs.pop("pagination") assert isinstance(pagination, SortKeyPagination) assert sort_key == pagination.last_item_on_previous_page assert 100 == pagination.size # The lane, facets, and pagination were all taken into effect # when constructing the feed URL. - annotator = kwargs.pop('annotator') + annotator = kwargs.pop("annotator") assert lane == annotator.lane with self.request_context_with_library("/"): - assert ( - annotator.feed_url(lane, facets=facets, pagination=pagination) == - kwargs.pop('url')) + assert annotator.feed_url( + lane, facets=facets, pagination=pagination + ) == kwargs.pop("url") # The (mocked) search engine associated with the CirculationManager was # passed in. - assert self.manager.external_search == kwargs.pop('search_engine') + assert self.manager.external_search == kwargs.pop("search_engine") # No other arguments were passed into Mock.page. assert {} == kwargs @@ -3743,7 +3921,7 @@ def page(cls, **kwargs): response = self.manager.work_controller.series( series_name, None, None, feed_class=Mock ) - facets = self.called_with.pop('facets') + facets = self.called_with.pop("facets") assert isinstance(facets, SeriesFacets) assert "series" == facets.order @@ -3770,8 +3948,9 @@ def test_feed(self): response = self.manager.opds_feeds.feed(-1) assert 404 == response.status_code assert ( - "http://librarysimplified.org/terms/problem/unknown-lane" == - response.uri) + "http://librarysimplified.org/terms/problem/unknown-lane" + == response.uri + ) # Bad faceting information -> Problem detail lane_id = self.english_adult_fiction.id @@ -3779,16 +3958,18 @@ def test_feed(self): response = self.manager.opds_feeds.feed(lane_id) assert 400 == response.status_code assert ( - "http://librarysimplified.org/terms/problem/invalid-input" == - response.uri) + "http://librarysimplified.org/terms/problem/invalid-input" + == response.uri + ) # Bad pagination -> Problem detail with self.request_context_with_library("/?size=abc"): response = self.manager.opds_feeds.feed(lane_id) assert 400 == response.status_code assert ( - "http://librarysimplified.org/terms/problem/invalid-input" == - response.uri) + "http://librarysimplified.org/terms/problem/invalid-input" + == response.uri + ) # Bad search index setup -> Problem detail self.assert_bad_search_index_gives_problem_detail( @@ -3799,20 +3980,17 @@ def test_feed(self): # Set up configuration settings for links and entry points library = self._default_library - for rel, value in [(LibraryAnnotator.TERMS_OF_SERVICE, "a"), - (LibraryAnnotator.PRIVACY_POLICY, "b"), - (LibraryAnnotator.COPYRIGHT, "c"), - (LibraryAnnotator.ABOUT, "d"), - ]: + for rel, value in [ + (LibraryAnnotator.TERMS_OF_SERVICE, "a"), + (LibraryAnnotator.PRIVACY_POLICY, "b"), + (LibraryAnnotator.COPYRIGHT, "c"), + (LibraryAnnotator.ABOUT, "d"), + ]: ConfigurationSetting.for_library(rel, library).value = value # Make a real OPDS feed and poke at it. - with self.request_context_with_library( - "/?entrypoint=Book&size=10" - ): - response = self.manager.opds_feeds.feed( - self.english_adult_fiction.id - ) + with self.request_context_with_library("/?entrypoint=Book&size=10"): + response = self.manager.opds_feeds.feed(self.english_adult_fiction.id) # The mock search index returned every book it has, without # respect to which books _ought_ to show up on this page. @@ -3823,62 +4001,64 @@ def test_feed(self): assert 200 == response.status_code assert ( - 'max-age=%d' % Lane.MAX_CACHE_AGE - in response.headers['Cache-Control'] + "max-age=%d" % Lane.MAX_CACHE_AGE in response.headers["Cache-Control"] ) feed = feedparser.parse(response.data) - assert (set([x.title for x in self.works]) == - set([x['title'] for x in feed['entries']])) + assert set([x.title for x in self.works]) == set( + [x["title"] for x in feed["entries"]] + ) # But the rest of the feed looks good. - links = feed['feed']['links'] + links = feed["feed"]["links"] by_rel = dict() # Put the links into a data structure based on their rel values. for i in links: - rel = i['rel'] - href = i['href'] + rel = i["rel"] + href = i["href"] if isinstance(by_rel.get(rel), (bytes, str)): by_rel[rel] = [by_rel[rel]] if isinstance(by_rel.get(rel), list): by_rel[rel].append(href) else: - by_rel[i['rel']] = i['href'] + by_rel[i["rel"]] = i["href"] assert "a" == by_rel[LibraryAnnotator.TERMS_OF_SERVICE] assert "b" == by_rel[LibraryAnnotator.PRIVACY_POLICY] assert "c" == by_rel[LibraryAnnotator.COPYRIGHT] assert "d" == by_rel[LibraryAnnotator.ABOUT] - next_link = by_rel['next'] + next_link = by_rel["next"] lane_str = str(lane_id) assert lane_str in next_link - assert 'entrypoint=Book' in next_link - assert 'size=10' in next_link + assert "entrypoint=Book" in next_link + assert "size=10" in next_link last_item = self.works[-1] # The pagination key for the next page is derived from the # sort fields of the last work in the current page. expected_pagination_key = [ - last_item.sort_title, last_item.sort_author, last_item.id + last_item.sort_title, + last_item.sort_author, + last_item.id, ] expect = "key=%s" % urllib.parse.quote_plus( json.dumps(expected_pagination_key) ) assert expect in next_link - search_link = by_rel['search'] + search_link = by_rel["search"] assert lane_str in search_link - assert 'entrypoint=Book' in search_link + assert "entrypoint=Book" in search_link - shelf_link = by_rel['http://opds-spec.org/shelf'] - assert shelf_link.endswith('/loans/') + shelf_link = by_rel["http://opds-spec.org/shelf"] + assert shelf_link.endswith("/loans/") - facet_links = by_rel['http://opds-spec.org/facet'] + facet_links = by_rel["http://opds-spec.org/facet"] assert all(lane_str in x for x in facet_links) - assert all('entrypoint=Book' in x for x in facet_links) - assert any('order=title' in x for x in facet_links) - assert any('order=author' in x for x in facet_links) + assert all("entrypoint=Book" in x for x in facet_links) + assert any("order=title" in x for x in facet_links) + assert any("order=author" in x for x in facet_links) # Now let's take a closer look at what this controller method # passes into AcquisitionFeed.page(), by mocking page(). @@ -3890,9 +4070,7 @@ def page(cls, **kwargs): sort_key = ["sort", "pagination", "key"] with self.request_context_with_library( - "/?entrypoint=Audio&size=36&key=%s&order=added" % ( - json.dumps(sort_key) - ) + "/?entrypoint=Audio&size=36&key=%s&order=added" % (json.dumps(sort_key)) ): response = self.manager.opds_feeds.feed( self.english_adult_fiction.id, feed_class=Mock @@ -3901,9 +4079,10 @@ def page(cls, **kwargs): # While we're in request context, generate the URL we # expect to be used for this feed. expect_url = self.controller.cdn_url_for( - "feed", lane_identifier=lane_id, + "feed", + lane_identifier=lane_id, library_short_name=self._default_library.short_name, - _facets=load_facets_from_request() + _facets=load_facets_from_request(), ) assert isinstance(response, Response) @@ -3912,32 +4091,32 @@ def page(cls, **kwargs): # Now check all the keyword arguments that were passed into # page(). kwargs = self.called_with - assert kwargs.pop('url') == expect_url - assert self._db == kwargs.pop('_db') - assert self.english_adult_fiction.display_name == kwargs.pop('title') - assert self.english_adult_fiction == kwargs.pop('worklist') + assert kwargs.pop("url") == expect_url + assert self._db == kwargs.pop("_db") + assert self.english_adult_fiction.display_name == kwargs.pop("title") + assert self.english_adult_fiction == kwargs.pop("worklist") # Query string arguments were taken into account when # creating the Facets and Pagination objects. - facets = kwargs.pop('facets') + facets = kwargs.pop("facets") assert AudiobooksEntryPoint == facets.entrypoint - assert 'added' == facets.order + assert "added" == facets.order - pagination = kwargs.pop('pagination') + pagination = kwargs.pop("pagination") assert isinstance(pagination, SortKeyPagination) assert 36 == pagination.size assert sort_key == pagination.last_item_on_previous_page # The Annotator object was instantiated with the proper lane # and the newly created Facets object. - annotator = kwargs.pop('annotator') + annotator = kwargs.pop("annotator") assert self.english_adult_fiction == annotator.lane assert facets == annotator.facets # The ExternalSearchIndex associated with the # CirculationManager was passed in; that way we don't have to # connect to the search engine again. - assert self.manager.external_search == kwargs.pop('search_engine') + assert self.manager.external_search == kwargs.pop("search_engine") # No other arguments were passed into page(). assert {} == kwargs @@ -3961,19 +4140,21 @@ def test_groups(self): response = controller.groups(None) assert 302 == response.status_code expect_url = controller.cdn_url_for( - 'acquisition_groups', + "acquisition_groups", library_short_name=self._default_library.short_name, - lane_identifier=lane.id, _external=True + lane_identifier=lane.id, + _external=True, ) - assert response.headers['Location'] == expect_url + assert response.headers["Location"] == expect_url # Bad lane -> Problem detail with self.request_context_with_library("/"): response = self.manager.opds_feeds.groups(-1) assert 404 == response.status_code assert ( - "http://librarysimplified.org/terms/problem/unknown-lane" == - response.uri) + "http://librarysimplified.org/terms/problem/unknown-lane" + == response.uri + ) # Bad search index setup -> Problem detail self.assert_bad_search_index_gives_problem_detail( @@ -4028,28 +4209,29 @@ def page(cls, **kwargs): # While we're in request context, generate the URL we # expect to be used for this feed. expect_url = self.manager.opds_feeds.cdn_url_for( - "acquisition_groups", lane_identifier=None, + "acquisition_groups", + lane_identifier=None, library_short_name=library.short_name, - _facets=load_facets_from_request() + _facets=load_facets_from_request(), ) kwargs = self.groups_called_with - assert self._db == kwargs.pop('_db') - lane = kwargs.pop('worklist') + assert self._db == kwargs.pop("_db") + lane = kwargs.pop("worklist") assert expect_lane == lane - assert lane.display_name == kwargs.pop('title') - assert expect_url == kwargs.pop('url') + assert lane.display_name == kwargs.pop("title") + assert expect_url == kwargs.pop("url") # A FeaturedFacets object was loaded from library, lane and # request configuration. - facets = kwargs.pop('facets') + facets = kwargs.pop("facets") assert isinstance(facets, FeaturedFacets) assert AudiobooksEntryPoint == facets.entrypoint assert 0.15 == facets.minimum_featured_quality # A LibraryAnnotator object was created from the Lane and # Facets objects. - annotator = kwargs.pop('annotator') + annotator = kwargs.pop("annotator") assert lane == annotator.lane assert facets == annotator.facets @@ -4066,22 +4248,23 @@ def page(cls, **kwargs): # While we're in request context, generate the URL we # expect to be used for this feed. expect_url = self.manager.opds_feeds.cdn_url_for( - "feed", lane_identifier=self.english_adult_fiction.id, + "feed", + lane_identifier=self.english_adult_fiction.id, library_short_name=library.short_name, - _facets=load_facets_from_request() + _facets=load_facets_from_request(), ) - assert self.english_adult_fiction == self.page_called_with.pop('worklist') + assert self.english_adult_fiction == self.page_called_with.pop("worklist") # The canonical URL for this feed is a page-type URL, not a # groups-type URL. - assert expect_url == self.page_called_with.pop('url') + assert expect_url == self.page_called_with.pop("url") # The faceting and pagination objects are typical for the # first page of a paginated feed. - pagination = self.page_called_with.pop('pagination') + pagination = self.page_called_with.pop("pagination") assert isinstance(pagination, SortKeyPagination) - facets = self.page_called_with.pop('facets') + facets = self.page_called_with.pop("facets") assert isinstance(facets, Facets) # groups() was never called. @@ -4095,9 +4278,9 @@ def page(cls, **kwargs): self.english_adult_fiction.id, feed_class=Mock ) assert None == self.page_called_with - assert self.english_adult_fiction == self.groups_called_with.pop('worklist') - assert isinstance(self.groups_called_with.pop('facets'), FeaturedFacets) - assert 'pagination' not in self.groups_called_with + assert self.english_adult_fiction == self.groups_called_with.pop("worklist") + assert isinstance(self.groups_called_with.pop("facets"), FeaturedFacets) + assert "pagination" not in self.groups_called_with def test_navigation(self): library = self._default_library @@ -4107,17 +4290,19 @@ def test_navigation(self): # Mock NavigationFeed.navigation so we can see the arguments going # into it. old_navigation = NavigationFeed.navigation + @classmethod def mock_navigation(cls, *args, **kwargs): self.called_with = (args, kwargs) return old_navigation(*args, **kwargs) + NavigationFeed.navigation = mock_navigation with self.request_context_with_library("/"): response = self.manager.opds_feeds.navigation(lane.id) feed = feedparser.parse(response.data) - entries = feed['entries'] + entries = feed["entries"] # The default top-level lane is "World Languages", which contains # sublanes for English, Spanish, Chinese, and French. assert len(lane.sublanes) == len(entries) @@ -4125,7 +4310,7 @@ def mock_navigation(cls, *args, **kwargs): # A NavigationFacets object was created and passed in to # NavigationFeed.navigation(). args, kwargs = self.called_with - facets = kwargs['facets'] + facets = kwargs["facets"] assert isinstance(facets, NavigationFacets) NavigationFeed.navigation = old_navigation @@ -4140,6 +4325,7 @@ def _set(work, time): work.last_update_time = time for lp in work.license_pools: lp.availability_time = time + the_far_future = now + datetime.timedelta(hours=2) the_future = now + datetime.timedelta(hours=1) the_past = now - datetime.timedelta(hours=1) @@ -4156,7 +4342,10 @@ def test_search_document(self): # term, you get an OpenSearch document. with self.request_context_with_library("/"): response = self.manager.opds_feeds.search(None) - assert response.headers['Content-Type'] == 'application/opensearchdescription+xml' + assert ( + response.headers["Content-Type"] + == "application/opensearchdescription+xml" + ) assert "OpenSearchDescription" in response.get_data(as_text=True) def test_search(self): @@ -4167,16 +4356,18 @@ def test_search(self): response = self.manager.opds_feeds.search(-1) assert 404 == response.status_code assert ( - "http://librarysimplified.org/terms/problem/unknown-lane" == - response.uri) + "http://librarysimplified.org/terms/problem/unknown-lane" + == response.uri + ) # Bad pagination -> problem detail with self.request_context_with_library("/?size=abc"): response = self.manager.opds_feeds.search(None) assert 400 == response.status_code assert ( - "http://librarysimplified.org/terms/problem/invalid-input" == - response.uri) + "http://librarysimplified.org/terms/problem/invalid-input" + == response.uri + ) # Bad search index setup -> Problem detail self.assert_bad_search_index_gives_problem_detail( @@ -4196,29 +4387,27 @@ def search(cls, **kwargs): self.called_with = kwargs return "An OPDS feed" - with self.request_context_with_library( - "/?q=t&size=99&after=22&media=Music" - ): + with self.request_context_with_library("/?q=t&size=99&after=22&media=Music"): # Try the top-level lane, "World Languages" expect_lane = self.manager.opds_feeds.load_lane(None) response = self.manager.opds_feeds.search(None, feed_class=Mock) kwargs = self.called_with - assert self._db == kwargs.pop('_db') + assert self._db == kwargs.pop("_db") # Unlike other types of feeds, here the argument is called # 'lane' instead of 'worklist', because a Lane is the _only_ # kind of WorkList that is currently searchable. - lane = kwargs.pop('lane') + lane = kwargs.pop("lane") assert expect_lane == lane query = kwargs.pop("query") assert "t" == query assert "Search" == kwargs.pop("title") - assert self.manager.external_search == kwargs.pop('search_engine') + assert self.manager.external_search == kwargs.pop("search_engine") # A SearchFacets object was loaded from library, lane and # request configuration. - facets = kwargs.pop('facets') + facets = kwargs.pop("facets") assert isinstance(facets, SearchFacets) # There are multiple possible entry points, and the request @@ -4232,13 +4421,13 @@ def search(cls, **kwargs): # Information from the query string was used to make a # Pagination object. - pagination = kwargs.pop('pagination') + pagination = kwargs.pop("pagination") assert 22 == pagination.offset assert 99 == pagination.size # A LibraryAnnotator object was created from the Lane and # Facets objects. - annotator = kwargs.pop('annotator') + annotator = kwargs.pop("annotator") assert lane == annotator.lane assert facets == annotator.facets @@ -4248,11 +4437,13 @@ def search(cls, **kwargs): library = self._default_library with self.request_context_with_library(""): expect_url = self.manager.opds_feeds.url_for( - 'lane_search', lane_identifier=None, + "lane_search", + lane_identifier=None, library_short_name=library.short_name, - **dict(list(facets.items())), q=query + **dict(list(facets.items())), + q=query ) - assert expect_url == kwargs.pop('url') + assert expect_url == kwargs.pop("url") # No other arguments were passed into search(). assert {} == kwargs @@ -4267,10 +4458,10 @@ def search(cls, **kwargs): kwargs = self.called_with # We're searching that lane. - assert self.english_adult_fiction == kwargs['lane'] + assert self.english_adult_fiction == kwargs["lane"] # And we get the entry point we asked for. - assert AudiobooksEntryPoint == kwargs['facets'].entrypoint + assert AudiobooksEntryPoint == kwargs["facets"].entrypoint # When only a single entry point is enabled, it's used as the # default. @@ -4279,12 +4470,10 @@ def search(cls, **kwargs): ) with self.request_context_with_library("/?q=t"): response = self.manager.opds_feeds.search(None, feed_class=Mock) - assert AudiobooksEntryPoint == self.called_with['facets'].entrypoint + assert AudiobooksEntryPoint == self.called_with["facets"].entrypoint def test_misconfigured_search(self): - class BadSearch(CirculationManager): - @property def setup_search(self): raise Exception("doomed!") @@ -4296,8 +4485,10 @@ def setup_search(self): with self.request_context_with_library("/?q=t"): problem = circulation.opds_feeds.search(None) assert REMOTE_INTEGRATION_FAILED.uri == problem.uri - assert ('The search index for this site is not properly configured.' == - problem.detail) + assert ( + "The search index for this site is not properly configured." + == problem.detail + ) def test__qa_feed(self): # Test the _qa_feed() controller method. @@ -4309,26 +4500,25 @@ def test__qa_feed(self): feed_method = MagicMock(return_value="an OPDS feed") m = self.manager.opds_feeds._qa_feed - args = (feed_method, "QA test feed", "qa_feed", Facets, - worklist_factory) + args = (feed_method, "QA test feed", "qa_feed", Facets, worklist_factory) # Bad search index setup -> Problem detail - self.assert_bad_search_index_gives_problem_detail( - lambda: m(*args) - ) + self.assert_bad_search_index_gives_problem_detail(lambda: m(*args)) # Bad faceting information -> Problem detail with self.request_context_with_library("/?order=nosuchorder"): response = m(*args) assert 400 == response.status_code assert ( - "http://librarysimplified.org/terms/problem/invalid-input" == - response.uri) + "http://librarysimplified.org/terms/problem/invalid-input" + == response.uri + ) # Now test success. with self.request_context_with_library("/"): expect_url = self.manager.opds_feeds.url_for( - 'qa_feed', library_short_name=self._default_library.short_name, + "qa_feed", + library_short_name=self._default_library.short_name, ) response = m(*args) @@ -4349,33 +4539,33 @@ def test__qa_feed(self): [call] = feed_method.mock_calls kwargs = call.kwargs - assert self._db == kwargs.pop('_db') + assert self._db == kwargs.pop("_db") assert "QA test feed" == kwargs.pop("title") - assert self.manager.external_search == kwargs.pop('search_engine') - assert expect_url == kwargs.pop('url') + assert self.manager.external_search == kwargs.pop("search_engine") + assert expect_url == kwargs.pop("url") # These feeds are never to be cached. - assert CachedFeed.IGNORE_CACHE == kwargs.pop('max_age') + assert CachedFeed.IGNORE_CACHE == kwargs.pop("max_age") # To improve performance, a Pagination object was created that # limits each lane in the test feed to a single Work. - pagination = kwargs.pop('pagination') + pagination = kwargs.pop("pagination") assert isinstance(pagination, Pagination) assert 1 == pagination.size # The WorkList returned by worklist_factory was passed into # feed_method. - assert wl == kwargs.pop('worklist') + assert wl == kwargs.pop("worklist") # So was a LibraryAnnotator object created from that WorkList. - annotator = kwargs.pop('annotator') + annotator = kwargs.pop("annotator") assert isinstance(annotator, LibraryAnnotator) assert wl == annotator.lane assert None == annotator.facets # The Facets object used to initialize the feed is the same # one passed into worklist_factory. - assert facets == kwargs.pop('facets') + assert facets == kwargs.pop("facets") # No other arguments were passed into feed_method(). assert {} == kwargs @@ -4393,8 +4583,8 @@ def test_qa_feed(self): # For the most part, we're verifying that the expected values # are passed in to _qa_feed. - assert AcquisitionFeed.groups == kwargs.pop('feed_method') - assert JackpotFacets == kwargs.pop('facet_class') + assert AcquisitionFeed.groups == kwargs.pop("feed_method") + assert JackpotFacets == kwargs.pop("facet_class") assert "qa_feed" == kwargs.pop("controller_name") assert "QA test feed" == kwargs.pop("feed_title") factory = kwargs.pop("worklist_factory") @@ -4406,8 +4596,7 @@ def test_qa_feed(self): # other calls. with self.request_context_with_library("/"): facets = load_facets_from_request( - base_class=JackpotFacets, - default_entrypoint=EverythingEntryPoint + base_class=JackpotFacets, default_entrypoint=EverythingEntryPoint ) worklist = factory(self._default_library, facets) @@ -4431,8 +4620,8 @@ def test_qa_feed(self): # For the most part, we're verifying that the expected values # are passed in to _qa_feed. - assert AcquisitionFeed.groups == kwargs.pop('feed_factory') - assert JackpotFacets == kwargs.pop('facet_class') + assert AcquisitionFeed.groups == kwargs.pop("feed_factory") + assert JackpotFacets == kwargs.pop("facet_class") assert "qa_feed" == kwargs.pop("controller_name") assert "QA test feed" == kwargs.pop("feed_title") factory = kwargs.pop("worklist_factory") @@ -4444,8 +4633,7 @@ def test_qa_feed(self): # other calls. with self.request_context_with_library("/"): facets = load_facets_from_request( - base_class=JackpotFacets, - default_entrypoint=EverythingEntryPoint + base_class=JackpotFacets, default_entrypoint=EverythingEntryPoint ) worklist = factory(self._default_library, facets) @@ -4473,8 +4661,8 @@ def test_qa_series_feed(self): # Note that the feed_method is different from the one in qa_feed. # We want to generate an ungrouped feed rather than a grouped one. - assert AcquisitionFeed.page == kwargs.pop('feed_factory') - assert HasSeriesFacets == kwargs.pop('facet_class') + assert AcquisitionFeed.page == kwargs.pop("feed_factory") + assert HasSeriesFacets == kwargs.pop("facet_class") assert "qa_series_feed" == kwargs.pop("controller_name") assert "QA series test feed" == kwargs.pop("feed_title") factory = kwargs.pop("worklist_factory") @@ -4492,7 +4680,6 @@ def test_qa_series_feed(self): class TestCrawlableFeed(CirculationControllerTest): - @contextmanager def mock_crawlable_feed(self): """Temporarily mock _crawlable_feed with something @@ -4500,13 +4687,17 @@ def mock_crawlable_feed(self): """ controller = self.manager.opds_feeds original = controller._crawlable_feed - def mock(title, url, worklist, annotator=None, - feed_class=AcquisitionFeed): + + def mock(title, url, worklist, annotator=None, feed_class=AcquisitionFeed): self._crawlable_feed_called_with = dict( - title=title, url=url, worklist=worklist, annotator=annotator, - feed_class=feed_class + title=title, + url=url, + worklist=worklist, + annotator=annotator, + feed_class=feed_class, ) return "An OPDS feed." + controller._crawlable_feed = mock yield controller._crawlable_feed = original @@ -4531,14 +4722,14 @@ def test_crawlable_library_feed(self): # Verify that _crawlable_feed was called with the right arguments. kwargs = self._crawlable_feed_called_with - assert expect_url == kwargs.pop('url') - assert library.name == kwargs.pop('title') - assert None == kwargs.pop('annotator') - assert AcquisitionFeed == kwargs.pop('feed_class') + assert expect_url == kwargs.pop("url") + assert library.name == kwargs.pop("title") + assert None == kwargs.pop("annotator") + assert AcquisitionFeed == kwargs.pop("feed_class") # A CrawlableCollectionBasedLane has been set up to show # everything in any of the requested library's collections. - lane = kwargs.pop('worklist') + lane = kwargs.pop("worklist") assert isinstance(lane, CrawlableCollectionBasedLane) assert library.id == lane.library_id assert [x.id for x in library.collections] == lane.collection_ids @@ -4578,12 +4769,12 @@ def test_crawlable_collection_feed(self): # Verify that _crawlable_feed was called with the right arguments. kwargs = self._crawlable_feed_called_with - assert expect_url == kwargs.pop('url') - assert collection.name == kwargs.pop('title') + assert expect_url == kwargs.pop("url") + assert collection.name == kwargs.pop("title") # A CrawlableCollectionBasedLane has been set up to show # everything in the requested collection. - lane = kwargs.pop('worklist') + lane = kwargs.pop("worklist") assert isinstance(lane, CrawlableCollectionBasedLane) assert None == lane.library_id assert [collection.id] == lane.collection_ids @@ -4591,7 +4782,7 @@ def test_crawlable_collection_feed(self): # No specific Annotator as created to build the OPDS # feed. We'll be using the default for a request with no # library context--a CirculationManagerAnnotator. - assert None == kwargs.pop('annotator') + assert None == kwargs.pop("annotator") # A specific annotator _is_ created for an ODL collection: # A SharedCollectionAnnotator that knows about the Collection @@ -4603,10 +4794,10 @@ def test_crawlable_collection_feed(self): collection_name=collection.name ) kwargs = self._crawlable_feed_called_with - annotator = kwargs['annotator'] + annotator = kwargs["annotator"] assert isinstance(annotator, SharedCollectionAnnotator) assert collection == annotator.collection - assert kwargs['worklist'] == annotator.lane + assert kwargs["worklist"] == annotator.lane def test_crawlable_list_feed(self): # Test the creation of a crawlable feed for everything in @@ -4643,14 +4834,14 @@ def test_crawlable_list_feed(self): # Verify that _crawlable_feed was called with the right arguments. kwargs = self._crawlable_feed_called_with - assert expect_url == kwargs.pop('url') - assert customlist.name == kwargs.pop('title') - assert None == kwargs.pop('annotator') - assert AcquisitionFeed == kwargs.pop('feed_class') + assert expect_url == kwargs.pop("url") + assert customlist.name == kwargs.pop("title") + assert None == kwargs.pop("annotator") + assert AcquisitionFeed == kwargs.pop("feed_class") # A CrawlableCustomListBasedLane was created to fetch only # the works in the custom list. - lane = kwargs.pop('worklist') + lane = kwargs.pop("worklist") assert isinstance(lane, CrawlableCustomListBasedLane) assert [customlist.id] == lane.customlist_ids assert {} == kwargs @@ -4658,6 +4849,7 @@ def test_crawlable_list_feed(self): def test__crawlable_feed(self): # Test the helper method called by all other feed methods. self.page_called_with = None + class MockFeed(object): @classmethod def page(cls, **kwargs): @@ -4665,6 +4857,7 @@ def page(cls, **kwargs): return Response("An OPDS feed") work = self._work(with_open_access_download=True) + class MockLane(DynamicLane): def works(self, _db, facets, pagination, *args, **kwargs): # We need to call page_loaded() (normally called by @@ -4675,9 +4868,7 @@ def works(self, _db, facets, pagination, *args, **kwargs): # It's not necessary for this test to call it with a # realistic value, but we might as well. results = [ - MockSearchResult( - work.sort_title, work.sort_author, {}, work.id - ) + MockSearchResult(work.sort_title, work.sort_author, {}, work.id) ] pagination.page_loaded(results) return [work] @@ -4685,10 +4876,7 @@ def works(self, _db, facets, pagination, *args, **kwargs): mock_lane = MockLane() mock_lane.initialize(None) in_kwargs = dict( - title="Lane title", - url="Lane URL", - worklist=mock_lane, - feed_class=MockFeed + title="Lane title", url="Lane URL", worklist=mock_lane, feed_class=MockFeed ) # Bad pagination data -> problem detail @@ -4705,9 +4893,7 @@ def works(self, _db, facets, pagination, *args, **kwargs): # Good pagination data -> feed_class.page() is called. sort_key = ["sort", "pagination", "key"] - with self.app.test_request_context( - "/?size=23&key=%s" % json.dumps(sort_key) - ): + with self.app.test_request_context("/?size=23&key=%s" % json.dumps(sort_key)): response = self.manager.opds_feeds._crawlable_feed(**in_kwargs) # The result of page() was served as an OPDS feed. @@ -4716,28 +4902,27 @@ def works(self, _db, facets, pagination, *args, **kwargs): # Verify the arguments passed in to page(). out_kwargs = self.page_called_with - assert self._db == out_kwargs.pop('_db') - assert (self.manager.opds_feeds.search_engine == - out_kwargs.pop('search_engine')) - assert in_kwargs['worklist'] == out_kwargs.pop('worklist') - assert in_kwargs['title'] == out_kwargs.pop('title') - assert in_kwargs['url'] == out_kwargs.pop('url') + assert self._db == out_kwargs.pop("_db") + assert self.manager.opds_feeds.search_engine == out_kwargs.pop("search_engine") + assert in_kwargs["worklist"] == out_kwargs.pop("worklist") + assert in_kwargs["title"] == out_kwargs.pop("title") + assert in_kwargs["url"] == out_kwargs.pop("url") # Since no annotator was provided and the request did not # happen in a library context, a generic # CirculationManagerAnnotator was created. - annotator = out_kwargs.pop('annotator') + annotator = out_kwargs.pop("annotator") assert isinstance(annotator, CirculationManagerAnnotator) assert mock_lane == annotator.lane # There's only one way to configure CrawlableFacets, so it's # sufficient to check that our faceting object is in fact a # CrawlableFacets. - facets = out_kwargs.pop('facets') + facets = out_kwargs.pop("facets") assert isinstance(facets, CrawlableFacets) # Verify that pagination was picked up from the request. - pagination = out_kwargs.pop('pagination') + pagination = out_kwargs.pop("pagination") assert isinstance(pagination, SortKeyPagination) assert sort_key == pagination.last_item_on_previous_page assert 23 == pagination.size @@ -4752,18 +4937,18 @@ def works(self, _db, facets, pagination, *args, **kwargs): response = self.manager.opds_feeds._crawlable_feed( annotator=mock_annotator, **in_kwargs ) - assert mock_annotator == self.page_called_with['annotator'] + assert mock_annotator == self.page_called_with["annotator"] # Finally, remove the mock feed class and verify that a real OPDS # feed is generated from the result of MockLane.works() - del in_kwargs['feed_class'] + del in_kwargs["feed_class"] with self.request_context_with_library("/"): response = self.manager.opds_feeds._crawlable_feed(**in_kwargs) feed = feedparser.parse(response.data) # There is one entry with the expected title. - [entry] = feed['entries'] - assert entry['title'] == work.title + [entry] = feed["entries"] + assert entry["title"] == work.title class TestMARCRecordController(CirculationControllerTest): @@ -4775,40 +4960,62 @@ def test_download_page_with_exporter_and_files(self): lane = self._lane(display_name="Test Lane") exporter = self._external_integration( - ExternalIntegration.MARC_EXPORT, ExternalIntegration.CATALOG_GOAL, - libraries=[self._default_library]) + ExternalIntegration.MARC_EXPORT, + ExternalIntegration.CATALOG_GOAL, + libraries=[self._default_library], + ) rep1, ignore = create( - self._db, Representation, - url="http://mirror1", mirror_url="http://mirror1", + self._db, + Representation, + url="http://mirror1", + mirror_url="http://mirror1", media_type=Representation.MARC_MEDIA_TYPE, - mirrored_at=now) + mirrored_at=now, + ) cache1, ignore = create( - self._db, CachedMARCFile, - library=self._default_library, lane=None, - representation=rep1, end_time=now) + self._db, + CachedMARCFile, + library=self._default_library, + lane=None, + representation=rep1, + end_time=now, + ) rep2, ignore = create( - self._db, Representation, - url="http://mirror2", mirror_url="http://mirror2", + self._db, + Representation, + url="http://mirror2", + mirror_url="http://mirror2", media_type=Representation.MARC_MEDIA_TYPE, - mirrored_at=yesterday) + mirrored_at=yesterday, + ) cache2, ignore = create( - self._db, CachedMARCFile, - library=self._default_library, lane=lane, - representation=rep2, end_time=yesterday) + self._db, + CachedMARCFile, + library=self._default_library, + lane=lane, + representation=rep2, + end_time=yesterday, + ) rep3, ignore = create( - self._db, Representation, - url="http://mirror3", mirror_url="http://mirror3", + self._db, + Representation, + url="http://mirror3", + mirror_url="http://mirror3", media_type=Representation.MARC_MEDIA_TYPE, - mirrored_at=now) + mirrored_at=now, + ) cache3, ignore = create( - self._db, CachedMARCFile, - library=self._default_library, lane=None, - representation=rep3, end_time=now, - start_time=yesterday) - + self._db, + CachedMARCFile, + library=self._default_library, + lane=None, + representation=rep3, + end_time=now, + start_time=yesterday, + ) with self.request_context_with_library("/"): response = self.manager.marc_records.download_page() @@ -4817,12 +5024,24 @@ def test_download_page_with_exporter_and_files(self): assert ("Download MARC files for %s" % library.name) in html assert "

    All Books

    " in html - assert 'Full file - last updated %s' % now.strftime("%B %-d, %Y") in html + assert ( + 'Full file - last updated %s' + % now.strftime("%B %-d, %Y") + in html + ) assert "

    Update-only files

    " in html - assert 'Updates from %s to %s' % (yesterday.strftime("%B %-d, %Y"), now.strftime("%B %-d, %Y")) in html + assert ( + 'Updates from %s to %s' + % (yesterday.strftime("%B %-d, %Y"), now.strftime("%B %-d, %Y")) + in html + ) - assert '

    Test Lane

    ' in html - assert 'Full file - last updated %s' % yesterday.strftime("%B %-d, %Y") in html + assert "

    Test Lane

    " in html + assert ( + 'Full file - last updated %s' + % yesterday.strftime("%B %-d, %Y") + in html + ) def test_download_page_with_exporter_but_no_files(self): now = utc_now() @@ -4831,8 +5050,10 @@ def test_download_page_with_exporter_but_no_files(self): library = self._default_library exporter = self._external_integration( - ExternalIntegration.MARC_EXPORT, ExternalIntegration.CATALOG_GOAL, - libraries=[self._default_library]) + ExternalIntegration.MARC_EXPORT, + ExternalIntegration.CATALOG_GOAL, + libraries=[self._default_library], + ) with self.request_context_with_library("/"): response = self.manager.marc_records.download_page() @@ -4855,14 +5076,21 @@ def test_download_page_no_exporter(self): # they will still be available to download. now = utc_now() rep, ignore = create( - self._db, Representation, - url="http://mirror1", mirror_url="http://mirror1", + self._db, + Representation, + url="http://mirror1", + mirror_url="http://mirror1", media_type=Representation.MARC_MEDIA_TYPE, - mirrored_at=now) + mirrored_at=now, + ) cache, ignore = create( - self._db, CachedMARCFile, - library=self._default_library, lane=None, - representation=rep, end_time=now) + self._db, + CachedMARCFile, + library=self._default_library, + lane=None, + representation=rep, + end_time=now, + ) with self.request_context_with_library("/"): response = self.manager.marc_records.download_page() @@ -4870,8 +5098,12 @@ def test_download_page_no_exporter(self): html = response.get_data(as_text=True) assert ("Download MARC files for %s" % library.name) in html assert "No MARC exporter is currently configured" in html - assert '

    All Books

    ' in html - assert 'Full file - last updated %s' % now.strftime("%B %-d, %Y") in html + assert "

    All Books

    " in html + assert ( + 'Full file - last updated %s' + % now.strftime("%B %-d, %Y") + in html + ) class TestAnalyticsController(CirculationControllerTest): @@ -4882,7 +5114,8 @@ def setup_method(self): def test_track_event(self): integration, ignore = create( - self._db, ExternalIntegration, + self._db, + ExternalIntegration, goal=ExternalIntegration.ANALYTICS_GOAL, protocol="core.local_analytics_provider", ) @@ -4892,7 +5125,9 @@ def test_track_event(self): self.manager.analytics = Analytics(self._db) with self.request_context_with_library("/"): - response = self.manager.analytics_controller.track_event(self.identifier.type, self.identifier.identifier, "invalid_type") + response = self.manager.analytics_controller.track_event( + self.identifier.type, self.identifier.identifier, "invalid_type" + ) assert 400 == response.status_code assert INVALID_ANALYTICS_EVENT_TYPE.uri == response.uri @@ -4904,15 +5139,12 @@ def test_track_event(self): with self.request_context_with_library("/"): flask.request.patron = request_patron response = self.manager.analytics_controller.track_event( - self.identifier.type, self.identifier.identifier, - "open_book" + self.identifier.type, self.identifier.identifier, "open_book" ) assert 200 == response.status_code circulation_event = get_one( - self._db, CirculationEvent, - type="open_book", - license_pool=self.lp + self._db, CirculationEvent, type="open_book", license_pool=self.lp ) assert None == circulation_event.location self._db.delete(circulation_event) @@ -4930,15 +5162,13 @@ def test_track_event(self): assert 200 == response.status_code circulation_event = get_one( - self._db, CirculationEvent, - type="open_book", - license_pool=self.lp + self._db, CirculationEvent, type="open_book", license_pool=self.lp ) assert patron.neighborhood == circulation_event.location self._db.delete(circulation_event) -class TestDeviceManagementProtocolController(ControllerTest): +class TestDeviceManagementProtocolController(ControllerTest): def setup_method(self): super(TestDeviceManagementProtocolController, self).setup_method() self.initialize_adobe(self.library, self.libraries) @@ -4965,7 +5195,7 @@ def _create_credential(self): return self._credential( DataSource.INTERNAL_PROCESSING, AuthdataUtility.ADOBE_ACCOUNT_ID_PATRON_IDENTIFIER, - self.default_patron + self.default_patron, ) def test_link_template_header(self): @@ -4975,8 +5205,13 @@ def test_link_template_header(self): with self.request_context_with_library("/"): headers = self.controller.link_template_header assert 1 == len(headers) - template = headers['Link-Template'] - expected_url = url_for("adobe_drm_device", library_short_name=self.library.short_name, device_id="{id}", _external=True) + template = headers["Link-Template"] + expected_url = url_for( + "adobe_drm_device", + library_short_name=self.library.short_name, + device_id="{id}", + _external=True, + ) expected_url = expected_url.replace("%7Bid%7D", "{id}") assert '<%s>; rel="item"' % expected_url == template @@ -4994,9 +5229,9 @@ def test_device_id_list_handler_post_success(self): # The patron has no credentials, and thus no registered devices. assert [] == self.default_patron.credentials headers = dict(self.auth) - headers['Content-Type'] = self.controller.DEVICE_ID_LIST_MEDIA_TYPE + headers["Content-Type"] = self.controller.DEVICE_ID_LIST_MEDIA_TYPE with self.request_context_with_library( - "/", method='POST', headers=headers, data="device" + "/", method="POST", headers=headers, data="device" ): self.controller.authenticated_patron_from_request() response = self.controller.device_id_list_handler() @@ -5007,11 +5242,11 @@ def test_device_id_list_handler_post_success(self): # them. [credential] = self.default_patron.credentials assert DataSource.INTERNAL_PROCESSING == credential.data_source.name - assert (AuthdataUtility.ADOBE_ACCOUNT_ID_PATRON_IDENTIFIER == - credential.type) + assert AuthdataUtility.ADOBE_ACCOUNT_ID_PATRON_IDENTIFIER == credential.type - assert (['device'] == - [x.device_identifier for x in credential.drm_device_identifiers]) + assert ["device"] == [ + x.device_identifier for x in credential.drm_device_identifiers + ] def test_device_id_list_handler_get_success(self): credential = self._create_credential() @@ -5023,8 +5258,10 @@ def test_device_id_list_handler_get_success(self): assert 200 == response.status_code # We got a list of device IDs. - assert (self.controller.DEVICE_ID_LIST_MEDIA_TYPE == - response.headers['Content-Type']) + assert ( + self.controller.DEVICE_ID_LIST_MEDIA_TYPE + == response.headers["Content-Type"] + ) assert "device1\ndevice2" == response.get_data(as_text=True) # We got a URL Template (see test_link_template_header()) @@ -5041,9 +5278,7 @@ def device_id_list_handler_bad_auth(self): assert 401 == response.status_code def device_id_list_handler_bad_method(self): - with self.request_context_with_library( - "/", method='DELETE', headers=self.auth - ): + with self.request_context_with_library("/", method="DELETE", headers=self.auth): self.controller.authenticated_patron_from_request() response = self.controller.device_id_list_handler() assert isinstance(response, ProblemDetail) @@ -5052,9 +5287,9 @@ def device_id_list_handler_bad_method(self): def test_device_id_list_handler_too_many_simultaneous_registrations(self): # We only allow registration of one device ID at a time. headers = dict(self.auth) - headers['Content-Type'] = self.controller.DEVICE_ID_LIST_MEDIA_TYPE + headers["Content-Type"] = self.controller.DEVICE_ID_LIST_MEDIA_TYPE with self.request_context_with_library( - "/", method='POST', headers=headers, data="device1\ndevice2" + "/", method="POST", headers=headers, data="device1\ndevice2" ): self.controller.authenticated_patron_from_request() response = self.controller.device_id_list_handler() @@ -5063,32 +5298,32 @@ def test_device_id_list_handler_too_many_simultaneous_registrations(self): def test_device_id_list_handler_wrong_media_type(self): headers = dict(self.auth) - headers['Content-Type'] = "text/plain" + headers["Content-Type"] = "text/plain" with self.request_context_with_library( - "/", method='POST', headers=headers, data="device1\ndevice2" + "/", method="POST", headers=headers, data="device1\ndevice2" ): self.controller.authenticated_patron_from_request() response = self.controller.device_id_list_handler() assert 415 == response.status_code - assert ("Expected vnd.librarysimplified/drm-device-id-list document." == - response.detail) + assert ( + "Expected vnd.librarysimplified/drm-device-id-list document." + == response.detail + ) def test_device_id_handler_success(self): credential = self._create_credential() credential.register_drm_device_identifier("device") - with self.request_context_with_library( - "/", method='DELETE', headers=self.auth - ): + with self.request_context_with_library("/", method="DELETE", headers=self.auth): patron = self.controller.authenticated_patron_from_request() response = self.controller.device_id_handler("device") assert 200 == response.status_code def test_device_id_handler_bad_auth(self): - with self.request_context_with_library("/", method='DELETE'): + with self.request_context_with_library("/", method="DELETE"): with temp_config() as config: config[Configuration.INTEGRATIONS] = { - "Circulation Manager" : { "url" : "http://foo/" } + "Circulation Manager": {"url": "http://foo/"} } patron = self.controller.authenticated_patron_from_request() response = self.controller.device_id_handler("device") @@ -5096,7 +5331,7 @@ def test_device_id_handler_bad_auth(self): assert 401 == response.status_code def test_device_id_handler_bad_method(self): - with self.request_context_with_library("/", method='POST', headers=self.auth): + with self.request_context_with_library("/", method="POST", headers=self.auth): patron = self.controller.authenticated_patron_from_request() response = self.controller.device_id_handler("device") assert isinstance(response, ProblemDetail) @@ -5118,16 +5353,19 @@ def test_notify_success(self): loan.external_identifier = self._str with self.request_context_with_library("/", method="POST"): - flask.request.data = json.dumps({ - "id": loan.external_identifier, - "status": "revoked", - }) - response = self.manager.odl_notification_controller.notify( - loan.id) + flask.request.data = json.dumps( + { + "id": loan.external_identifier, + "status": "revoked", + } + ) + response = self.manager.odl_notification_controller.notify(loan.id) assert 200 == response.status_code # The pool's availability has been updated. - api = self.manager.circulation_apis[self._default_library.id].api_for_license_pool(loan.license_pool) + api = self.manager.circulation_apis[ + self._default_library.id + ].api_for_license_pool(loan.license_pool) assert [loan.license_pool] == api.availability_updated_for def test_notify_errors(self): @@ -5146,6 +5384,7 @@ def test_notify_errors(self): response = self.manager.odl_notification_controller.notify(loan.id) assert INVALID_LOAN_FOR_ODL_NOTIFICATION == response + class TestSharedCollectionController(ControllerTest): """Test that other circ managers can register to borrow books from a shared collection.""" @@ -5154,28 +5393,27 @@ def setup_method(self): self.setup_circulation_manager = False super(TestSharedCollectionController, self).setup_method() from api.odl import ODLAPI + self.collection = self._collection(protocol=ODLAPI.NAME) self._default_library.collections = [self.collection] self.client, ignore = IntegrationClient.register(self._db, "http://library.org") self.app.manager = self.circulation_manager_setup(self._db) - self.work = self._work( - with_license_pool=True, collection=self.collection - ) + self.work = self._work(with_license_pool=True, collection=self.collection) self.pool = self.work.license_pools[0] [self.delivery_mechanism] = self.pool.delivery_mechanisms @contextmanager def request_context_with_client(self, route, *args, **kwargs): - if 'client' in kwargs: - client = kwargs.pop('client') + if "client" in kwargs: + client = kwargs.pop("client") else: client = self.client - if 'headers' in kwargs: - headers = kwargs.pop('headers') + if "headers" in kwargs: + headers = kwargs.pop("headers") else: headers = dict() - headers['Authorization'] = "Bearer " + base64.b64encode(client.shared_secret) - kwargs['headers'] = headers + headers["Authorization"] = "Bearer " + base64.b64encode(client.shared_secret) + kwargs["headers"] = headers with self.app.test_request_context(route, *args, **kwargs) as c: yield c @@ -5184,19 +5422,30 @@ def test_info(self): collection = self.manager.shared_collection_controller.info(self._str) assert NO_SUCH_COLLECTION == collection - response = self.manager.shared_collection_controller.info(self.collection.name) + response = self.manager.shared_collection_controller.info( + self.collection.name + ) assert 200 == response.status_code - assert response.headers.get("Content-Type").startswith("application/opds+json") + assert response.headers.get("Content-Type").startswith( + "application/opds+json" + ) links = json.loads(response.get_data(as_text=True)).get("links") [register_link] = [link for link in links if link.get("rel") == "register"] - assert "/collections/%s/register" % self.collection.name in register_link.get("href") + assert ( + "/collections/%s/register" % self.collection.name + in register_link.get("href") + ) def test_load_collection(self): with self.app.test_request_context("/"): - collection = self.manager.shared_collection_controller.load_collection(self._str) + collection = self.manager.shared_collection_controller.load_collection( + self._str + ) assert NO_SUCH_COLLECTION == collection - collection = self.manager.shared_collection_controller.load_collection(self.collection.name) + collection = self.manager.shared_collection_controller.load_collection( + self.collection.name + ) assert self.collection == collection def test_register(self): @@ -5205,60 +5454,91 @@ def test_register(self): flask.request.form = ImmutableMultiDict([("url", "http://test")]) api.queue_register(InvalidInputException()) - response = self.manager.shared_collection_controller.register(self.collection.name) + response = self.manager.shared_collection_controller.register( + self.collection.name + ) assert 400 == response.status_code assert INVALID_REGISTRATION.uri == response.uri api.queue_register(AuthorizationFailedException()) - response = self.manager.shared_collection_controller.register(self.collection.name) + response = self.manager.shared_collection_controller.register( + self.collection.name + ) assert 401 == response.status_code assert INVALID_CREDENTIALS.uri == response.uri api.queue_register(RemoteInitiatedServerError("Error", "Service")) - response = self.manager.shared_collection_controller.register(self.collection.name) + response = self.manager.shared_collection_controller.register( + self.collection.name + ) assert 502 == response.status_code assert INTEGRATION_ERROR.uri == response.uri api.queue_register(dict(shared_secret="secret")) - response = self.manager.shared_collection_controller.register(self.collection.name) + response = self.manager.shared_collection_controller.register( + self.collection.name + ) assert 200 == response.status_code - assert "secret" == json.loads(response.get_data(as_text=True)).get("shared_secret") + assert "secret" == json.loads(response.get_data(as_text=True)).get( + "shared_secret" + ) def test_loan_info(self): now = utc_now() tomorrow = utc_now() + datetime.timedelta(days=1) - other_client, ignore = IntegrationClient.register(self._db, "http://otherlibrary") + other_client, ignore = IntegrationClient.register( + self._db, "http://otherlibrary" + ) other_client_loan, ignore = create( - self._db, Loan, license_pool=self.pool, integration_client=other_client, + self._db, + Loan, + license_pool=self.pool, + integration_client=other_client, ) ignore, other_pool = self._edition( - with_license_pool=True, collection=self._collection(), + with_license_pool=True, + collection=self._collection(), ) other_pool_loan, ignore = create( - self._db, Loan, license_pool=other_pool, integration_client=self.client, + self._db, + Loan, + license_pool=other_pool, + integration_client=self.client, ) loan, ignore = create( - self._db, Loan, license_pool=self.pool, integration_client=self.client, - start=now, end=tomorrow, + self._db, + Loan, + license_pool=self.pool, + integration_client=self.client, + start=now, + end=tomorrow, ) with self.request_context_with_client("/"): # This loan doesn't exist. - response = self.manager.shared_collection_controller.loan_info(self.collection.name, 1234567) + response = self.manager.shared_collection_controller.loan_info( + self.collection.name, 1234567 + ) assert LOAN_NOT_FOUND == response # This loan belongs to a different library. - response = self.manager.shared_collection_controller.loan_info(self.collection.name, other_client_loan.id) + response = self.manager.shared_collection_controller.loan_info( + self.collection.name, other_client_loan.id + ) assert LOAN_NOT_FOUND == response # This loan's pool belongs to a different collection. - response = self.manager.shared_collection_controller.loan_info(self.collection.name, other_pool_loan.id) + response = self.manager.shared_collection_controller.loan_info( + self.collection.name, other_pool_loan.id + ) assert LOAN_NOT_FOUND == response # This loan is ours. - response = self.manager.shared_collection_controller.loan_info(self.collection.name, loan.id) + response = self.manager.shared_collection_controller.loan_info( + self.collection.name, loan.id + ) assert 200 == response.status_code feed = feedparser.parse(response.data) [entry] = feed.get("entries") @@ -5266,53 +5546,113 @@ def test_loan_info(self): since = availability.get("since") until = availability.get("until") assert datetime.datetime.strftime(now, "%Y-%m-%dT%H:%M:%S+00:00") == since - assert datetime.datetime.strftime(tomorrow, "%Y-%m-%dT%H:%M:%S+00:00") == until - [revoke_url] = [link.get("href") for link in entry.get("links") if link.get("rel") == "http://librarysimplified.org/terms/rel/revoke"] - assert "/collections/%s/loans/%s/revoke" % (self.collection.name, loan.id) in revoke_url - [fulfill_url] = [link.get("href") for link in entry.get("links") if link.get("rel") == "http://opds-spec.org/acquisition"] - assert "/collections/%s/loans/%s/fulfill/%s" % (self.collection.name, loan.id, self.delivery_mechanism.delivery_mechanism.id) in fulfill_url - [self_url] = [link.get("href") for link in entry.get("links") if link.get("rel") == "self"] + assert ( + datetime.datetime.strftime(tomorrow, "%Y-%m-%dT%H:%M:%S+00:00") == until + ) + [revoke_url] = [ + link.get("href") + for link in entry.get("links") + if link.get("rel") == "http://librarysimplified.org/terms/rel/revoke" + ] + assert ( + "/collections/%s/loans/%s/revoke" % (self.collection.name, loan.id) + in revoke_url + ) + [fulfill_url] = [ + link.get("href") + for link in entry.get("links") + if link.get("rel") == "http://opds-spec.org/acquisition" + ] + assert ( + "/collections/%s/loans/%s/fulfill/%s" + % ( + self.collection.name, + loan.id, + self.delivery_mechanism.delivery_mechanism.id, + ) + in fulfill_url + ) + [self_url] = [ + link.get("href") + for link in entry.get("links") + if link.get("rel") == "self" + ] assert "/collections/%s/loans/%s" % (self.collection.name, loan.id) def test_borrow(self): now = utc_now() tomorrow = utc_now() + datetime.timedelta(days=1) loan, ignore = create( - self._db, Loan, license_pool=self.pool, integration_client=self.client, - start=now, end=tomorrow, + self._db, + Loan, + license_pool=self.pool, + integration_client=self.client, + start=now, + end=tomorrow, ) hold, ignore = create( - self._db, Hold, license_pool=self.pool, integration_client=self.client, - start=now, end=tomorrow, + self._db, + Hold, + license_pool=self.pool, + integration_client=self.client, + start=now, + end=tomorrow, ) no_pool = self._identifier() with self.request_context_with_client("/"): - response = self.manager.shared_collection_controller.borrow(self.collection.name, no_pool.type, no_pool.identifier, None) + response = self.manager.shared_collection_controller.borrow( + self.collection.name, no_pool.type, no_pool.identifier, None + ) assert NO_LICENSES.uri == response.uri api = self.app.manager.shared_collection_controller.shared_collection # Attempt to borrow without a previous hold. api.queue_borrow(AuthorizationFailedException()) - response = self.manager.shared_collection_controller.borrow(self.collection.name, self.pool.identifier.type, self.pool.identifier.identifier, None) + response = self.manager.shared_collection_controller.borrow( + self.collection.name, + self.pool.identifier.type, + self.pool.identifier.identifier, + None, + ) assert INVALID_CREDENTIALS.uri == response.uri api.queue_borrow(CannotLoan()) - response = self.manager.shared_collection_controller.borrow(self.collection.name, self.pool.identifier.type, self.pool.identifier.identifier, None) + response = self.manager.shared_collection_controller.borrow( + self.collection.name, + self.pool.identifier.type, + self.pool.identifier.identifier, + None, + ) assert CHECKOUT_FAILED.uri == response.uri api.queue_borrow(NoAvailableCopies()) - response = self.manager.shared_collection_controller.borrow(self.collection.name, self.pool.identifier.type, self.pool.identifier.identifier, None) + response = self.manager.shared_collection_controller.borrow( + self.collection.name, + self.pool.identifier.type, + self.pool.identifier.identifier, + None, + ) assert NO_AVAILABLE_LICENSE.uri == response.uri api.queue_borrow(RemoteIntegrationException("error!", "service")) - response = self.manager.shared_collection_controller.borrow(self.collection.name, self.pool.identifier.type, self.pool.identifier.identifier, None) + response = self.manager.shared_collection_controller.borrow( + self.collection.name, + self.pool.identifier.type, + self.pool.identifier.identifier, + None, + ) assert INTEGRATION_ERROR.uri == response.uri api.queue_borrow(loan) - response = self.manager.shared_collection_controller.borrow(self.collection.name, self.pool.identifier.type, self.pool.identifier.identifier, None) + response = self.manager.shared_collection_controller.borrow( + self.collection.name, + self.pool.identifier.type, + self.pool.identifier.identifier, + None, + ) assert 201 == response.status_code feed = feedparser.parse(response.data) [entry] = feed.get("entries") @@ -5320,34 +5660,72 @@ def test_borrow(self): since = availability.get("since") until = availability.get("until") assert datetime.datetime.strftime(now, "%Y-%m-%dT%H:%M:%S+00:00") == since - assert datetime.datetime.strftime(tomorrow, "%Y-%m-%dT%H:%M:%S+00:00") == until + assert ( + datetime.datetime.strftime(tomorrow, "%Y-%m-%dT%H:%M:%S+00:00") == until + ) assert "available" == availability.get("status") - [revoke_url] = [link.get("href") for link in entry.get("links") if link.get("rel") == "http://librarysimplified.org/terms/rel/revoke"] - assert "/collections/%s/loans/%s/revoke" % (self.collection.name, loan.id) in revoke_url - [fulfill_url] = [link.get("href") for link in entry.get("links") if link.get("rel") == "http://opds-spec.org/acquisition"] - assert "/collections/%s/loans/%s/fulfill/%s" % (self.collection.name, loan.id, self.delivery_mechanism.delivery_mechanism.id) in fulfill_url - [self_url] = [link.get("href") for link in entry.get("links") if link.get("rel") == "self"] + [revoke_url] = [ + link.get("href") + for link in entry.get("links") + if link.get("rel") == "http://librarysimplified.org/terms/rel/revoke" + ] + assert ( + "/collections/%s/loans/%s/revoke" % (self.collection.name, loan.id) + in revoke_url + ) + [fulfill_url] = [ + link.get("href") + for link in entry.get("links") + if link.get("rel") == "http://opds-spec.org/acquisition" + ] + assert ( + "/collections/%s/loans/%s/fulfill/%s" + % ( + self.collection.name, + loan.id, + self.delivery_mechanism.delivery_mechanism.id, + ) + in fulfill_url + ) + [self_url] = [ + link.get("href") + for link in entry.get("links") + if link.get("rel") == "self" + ] assert "/collections/%s/loans/%s" % (self.collection.name, loan.id) # Now try to borrow when we already have a previous hold. api.queue_borrow(AuthorizationFailedException()) - response = self.manager.shared_collection_controller.borrow(self.collection.name, self.pool.identifier.type, self.pool.identifier.identifier, hold.id) + response = self.manager.shared_collection_controller.borrow( + self.collection.name, + self.pool.identifier.type, + self.pool.identifier.identifier, + hold.id, + ) assert INVALID_CREDENTIALS.uri == response.uri api.queue_borrow(CannotLoan()) - response = self.manager.shared_collection_controller.borrow(self.collection.name, None, None, hold.id) + response = self.manager.shared_collection_controller.borrow( + self.collection.name, None, None, hold.id + ) assert CHECKOUT_FAILED.uri == response.uri api.queue_borrow(NoAvailableCopies()) - response = self.manager.shared_collection_controller.borrow(self.collection.name, None, None, hold.id) + response = self.manager.shared_collection_controller.borrow( + self.collection.name, None, None, hold.id + ) assert NO_AVAILABLE_LICENSE.uri == response.uri api.queue_borrow(RemoteIntegrationException("error!", "service")) - response = self.manager.shared_collection_controller.borrow(self.collection.name, None, None, hold.id) + response = self.manager.shared_collection_controller.borrow( + self.collection.name, None, None, hold.id + ) assert INTEGRATION_ERROR.uri == response.uri api.queue_borrow(loan) - response = self.manager.shared_collection_controller.borrow(self.collection.name, None, None, hold.id) + response = self.manager.shared_collection_controller.borrow( + self.collection.name, None, None, hold.id + ) assert 201 == response.status_code feed = feedparser.parse(response.data) [entry] = feed.get("entries") @@ -5356,17 +5734,47 @@ def test_borrow(self): until = availability.get("until") assert "available" == availability.get("status") assert datetime.datetime.strftime(now, "%Y-%m-%dT%H:%M:%S+00:00") == since - assert datetime.datetime.strftime(tomorrow, "%Y-%m-%dT%H:%M:%S+00:00") == until - [revoke_url] = [link.get("href") for link in entry.get("links") if link.get("rel") == "http://librarysimplified.org/terms/rel/revoke"] - assert "/collections/%s/loans/%s/revoke" % (self.collection.name, loan.id) in revoke_url - [fulfill_url] = [link.get("href") for link in entry.get("links") if link.get("rel") == "http://opds-spec.org/acquisition"] - assert "/collections/%s/loans/%s/fulfill/%s" % (self.collection.name, loan.id, self.delivery_mechanism.delivery_mechanism.id) in fulfill_url - [self_url] = [link.get("href") for link in entry.get("links") if link.get("rel") == "self"] + assert ( + datetime.datetime.strftime(tomorrow, "%Y-%m-%dT%H:%M:%S+00:00") == until + ) + [revoke_url] = [ + link.get("href") + for link in entry.get("links") + if link.get("rel") == "http://librarysimplified.org/terms/rel/revoke" + ] + assert ( + "/collections/%s/loans/%s/revoke" % (self.collection.name, loan.id) + in revoke_url + ) + [fulfill_url] = [ + link.get("href") + for link in entry.get("links") + if link.get("rel") == "http://opds-spec.org/acquisition" + ] + assert ( + "/collections/%s/loans/%s/fulfill/%s" + % ( + self.collection.name, + loan.id, + self.delivery_mechanism.delivery_mechanism.id, + ) + in fulfill_url + ) + [self_url] = [ + link.get("href") + for link in entry.get("links") + if link.get("rel") == "self" + ] assert "/collections/%s/loans/%s" % (self.collection.name, loan.id) # Now try to borrow, but actually get a hold. api.queue_borrow(hold) - response = self.manager.shared_collection_controller.borrow(self.collection.name, self.pool.identifier.type, self.pool.identifier.identifier, None) + response = self.manager.shared_collection_controller.borrow( + self.collection.name, + self.pool.identifier.type, + self.pool.identifier.identifier, + None, + ) assert 201 == response.status_code feed = feedparser.parse(response.data) [entry] = feed.get("entries") @@ -5374,92 +5782,152 @@ def test_borrow(self): since = availability.get("since") until = availability.get("until") assert datetime.datetime.strftime(now, "%Y-%m-%dT%H:%M:%S+00:00") == since - assert datetime.datetime.strftime(tomorrow, "%Y-%m-%dT%H:%M:%S+00:00") == until + assert ( + datetime.datetime.strftime(tomorrow, "%Y-%m-%dT%H:%M:%S+00:00") == until + ) assert "reserved" == availability.get("status") - [revoke_url] = [link.get("href") for link in entry.get("links") if link.get("rel") == "http://librarysimplified.org/terms/rel/revoke"] - assert "/collections/%s/holds/%s/revoke" % (self.collection.name, hold.id) in revoke_url - assert [] == [link.get("href") for link in entry.get("links") if link.get("rel") == "http://opds-spec.org/acquisition"] - [self_url] = [link.get("href") for link in entry.get("links") if link.get("rel") == "self"] + [revoke_url] = [ + link.get("href") + for link in entry.get("links") + if link.get("rel") == "http://librarysimplified.org/terms/rel/revoke" + ] + assert ( + "/collections/%s/holds/%s/revoke" % (self.collection.name, hold.id) + in revoke_url + ) + assert [] == [ + link.get("href") + for link in entry.get("links") + if link.get("rel") == "http://opds-spec.org/acquisition" + ] + [self_url] = [ + link.get("href") + for link in entry.get("links") + if link.get("rel") == "self" + ] assert "/collections/%s/holds/%s" % (self.collection.name, hold.id) def test_revoke_loan(self): now = utc_now() tomorrow = utc_now() + datetime.timedelta(days=1) loan, ignore = create( - self._db, Loan, license_pool=self.pool, integration_client=self.client, - start=now, end=tomorrow, + self._db, + Loan, + license_pool=self.pool, + integration_client=self.client, + start=now, + end=tomorrow, ) - other_client, ignore = IntegrationClient.register(self._db, "http://otherlibrary") + other_client, ignore = IntegrationClient.register( + self._db, "http://otherlibrary" + ) other_client_loan, ignore = create( - self._db, Loan, license_pool=self.pool, integration_client=other_client, + self._db, + Loan, + license_pool=self.pool, + integration_client=other_client, ) ignore, other_pool = self._edition( - with_license_pool=True, collection=self._collection(), + with_license_pool=True, + collection=self._collection(), ) other_pool_loan, ignore = create( - self._db, Loan, license_pool=other_pool, integration_client=self.client, + self._db, + Loan, + license_pool=other_pool, + integration_client=self.client, ) with self.request_context_with_client("/"): - response = self.manager.shared_collection_controller.revoke_loan(self.collection.name, other_pool_loan.id) + response = self.manager.shared_collection_controller.revoke_loan( + self.collection.name, other_pool_loan.id + ) assert LOAN_NOT_FOUND.uri == response.uri - response = self.manager.shared_collection_controller.revoke_loan(self.collection.name, other_client_loan.id) + response = self.manager.shared_collection_controller.revoke_loan( + self.collection.name, other_client_loan.id + ) assert LOAN_NOT_FOUND.uri == response.uri api = self.app.manager.shared_collection_controller.shared_collection api.queue_revoke_loan(AuthorizationFailedException()) - response = self.manager.shared_collection_controller.revoke_loan(self.collection.name, loan.id) + response = self.manager.shared_collection_controller.revoke_loan( + self.collection.name, loan.id + ) assert INVALID_CREDENTIALS.uri == response.uri api.queue_revoke_loan(CannotReturn()) - response = self.manager.shared_collection_controller.revoke_loan(self.collection.name, loan.id) + response = self.manager.shared_collection_controller.revoke_loan( + self.collection.name, loan.id + ) assert COULD_NOT_MIRROR_TO_REMOTE.uri == response.uri api.queue_revoke_loan(NotCheckedOut()) - response = self.manager.shared_collection_controller.revoke_loan(self.collection.name, loan.id) + response = self.manager.shared_collection_controller.revoke_loan( + self.collection.name, loan.id + ) assert NO_ACTIVE_LOAN.uri == response.uri def test_fulfill(self): now = utc_now() tomorrow = utc_now() + datetime.timedelta(days=1) loan, ignore = create( - self._db, Loan, license_pool=self.pool, integration_client=self.client, - start=now, end=tomorrow, + self._db, + Loan, + license_pool=self.pool, + integration_client=self.client, + start=now, + end=tomorrow, ) ignore, other_pool = self._edition( - with_license_pool=True, collection=self._collection(), + with_license_pool=True, + collection=self._collection(), ) other_pool_loan, ignore = create( - self._db, Loan, license_pool=other_pool, integration_client=self.client, + self._db, + Loan, + license_pool=other_pool, + integration_client=self.client, ) with self.request_context_with_client("/"): - response = self.manager.shared_collection_controller.fulfill(self.collection.name, other_pool_loan.id, None) + response = self.manager.shared_collection_controller.fulfill( + self.collection.name, other_pool_loan.id, None + ) assert LOAN_NOT_FOUND.uri == response.uri api = self.app.manager.shared_collection_controller.shared_collection # If the loan doesn't have a mechanism set, we need to specify one. - response = self.manager.shared_collection_controller.fulfill(self.collection.name, loan.id, None) + response = self.manager.shared_collection_controller.fulfill( + self.collection.name, loan.id, None + ) assert BAD_DELIVERY_MECHANISM.uri == response.uri loan.fulfillment = self.delivery_mechanism api.queue_fulfill(AuthorizationFailedException()) - response = self.manager.shared_collection_controller.fulfill(self.collection.name, loan.id, None) + response = self.manager.shared_collection_controller.fulfill( + self.collection.name, loan.id, None + ) assert INVALID_CREDENTIALS.uri == response.uri api.queue_fulfill(CannotFulfill()) - response = self.manager.shared_collection_controller.fulfill(self.collection.name, loan.id, None) + response = self.manager.shared_collection_controller.fulfill( + self.collection.name, loan.id, None + ) assert CANNOT_FULFILL.uri == response.uri api.queue_fulfill(RemoteIntegrationException("error!", "service")) - response = self.manager.shared_collection_controller.fulfill(self.collection.name, loan.id, self.delivery_mechanism.delivery_mechanism.id) + response = self.manager.shared_collection_controller.fulfill( + self.collection.name, + loan.id, + self.delivery_mechanism.delivery_mechanism.id, + ) assert INTEGRATION_ERROR.uri == response.uri fulfillment_info = FulfillmentInfo( @@ -5467,20 +5935,36 @@ def test_fulfill(self): self.pool.data_source.name, self.pool.identifier.type, self.pool.identifier.identifier, - "http://content", "text/html", None, + "http://content", + "text/html", + None, utc_now(), ) api.queue_fulfill(fulfillment_info) + def do_get_error(url): raise RemoteIntegrationException("error!", "service") - response = self.manager.shared_collection_controller.fulfill(self.collection.name, loan.id, self.delivery_mechanism.delivery_mechanism.id, do_get=do_get_error) + + response = self.manager.shared_collection_controller.fulfill( + self.collection.name, + loan.id, + self.delivery_mechanism.delivery_mechanism.id, + do_get=do_get_error, + ) assert INTEGRATION_ERROR.uri == response.uri api.queue_fulfill(fulfillment_info) + def do_get_success(url): return MockRequestsResponse(200, content="Content") - response = self.manager.shared_collection_controller.fulfill(self.collection.name, loan.id, self.delivery_mechanism.delivery_mechanism.id, do_get=do_get_success) + + response = self.manager.shared_collection_controller.fulfill( + self.collection.name, + loan.id, + self.delivery_mechanism.delivery_mechanism.id, + do_get=do_get_success, + ) assert 200 == response.status_code assert "Content" == response.get_data(as_text=True) assert "text/html" == response.headers.get("Content-Type") @@ -5488,7 +5972,11 @@ def do_get_success(url): fulfillment_info.content_link = None fulfillment_info.content = "Content" api.queue_fulfill(fulfillment_info) - response = self.manager.shared_collection_controller.fulfill(self.collection.name, loan.id, self.delivery_mechanism.delivery_mechanism.id) + response = self.manager.shared_collection_controller.fulfill( + self.collection.name, + loan.id, + self.delivery_mechanism.delivery_mechanism.id, + ) assert 200 == response.status_code assert "Content" == response.get_data(as_text=True) assert "text/html" == response.headers.get("Content-Type") @@ -5497,37 +5985,58 @@ def test_hold_info(self): now = utc_now() tomorrow = utc_now() + datetime.timedelta(days=1) - other_client, ignore = IntegrationClient.register(self._db, "http://otherlibrary") + other_client, ignore = IntegrationClient.register( + self._db, "http://otherlibrary" + ) other_client_hold, ignore = create( - self._db, Hold, license_pool=self.pool, integration_client=other_client, + self._db, + Hold, + license_pool=self.pool, + integration_client=other_client, ) ignore, other_pool = self._edition( - with_license_pool=True, collection=self._collection(), + with_license_pool=True, + collection=self._collection(), ) other_pool_hold, ignore = create( - self._db, Hold, license_pool=other_pool, integration_client=self.client, + self._db, + Hold, + license_pool=other_pool, + integration_client=self.client, ) hold, ignore = create( - self._db, Hold, license_pool=self.pool, integration_client=self.client, - start=now, end=tomorrow, + self._db, + Hold, + license_pool=self.pool, + integration_client=self.client, + start=now, + end=tomorrow, ) with self.request_context_with_client("/"): # This hold doesn't exist. - response = self.manager.shared_collection_controller.hold_info(self.collection.name, 1234567) + response = self.manager.shared_collection_controller.hold_info( + self.collection.name, 1234567 + ) assert HOLD_NOT_FOUND == response # This hold belongs to a different library. - response = self.manager.shared_collection_controller.hold_info(self.collection.name, other_client_hold.id) + response = self.manager.shared_collection_controller.hold_info( + self.collection.name, other_client_hold.id + ) assert HOLD_NOT_FOUND == response # This hold's pool belongs to a different collection. - response = self.manager.shared_collection_controller.hold_info(self.collection.name, other_pool_hold.id) + response = self.manager.shared_collection_controller.hold_info( + self.collection.name, other_pool_hold.id + ) assert HOLD_NOT_FOUND == response # This hold is ours. - response = self.manager.shared_collection_controller.hold_info(self.collection.name, hold.id) + response = self.manager.shared_collection_controller.hold_info( + self.collection.name, hold.id + ) assert 200 == response.status_code feed = feedparser.parse(response.data) [entry] = feed.get("entries") @@ -5535,52 +6044,92 @@ def test_hold_info(self): since = availability.get("since") until = availability.get("until") assert datetime.datetime.strftime(now, "%Y-%m-%dT%H:%M:%S+00:00") == since - assert datetime.datetime.strftime(tomorrow, "%Y-%m-%dT%H:%M:%S+00:00") == until - [revoke_url] = [link.get("href") for link in entry.get("links") if link.get("rel") == "http://librarysimplified.org/terms/rel/revoke"] - assert "/collections/%s/holds/%s/revoke" % (self.collection.name, hold.id) in revoke_url - assert [] == [link.get("href") for link in entry.get("links") if link.get("rel") == "http://opds-spec.org/acquisition"] - [self_url] = [link.get("href") for link in entry.get("links") if link.get("rel") == "self"] + assert ( + datetime.datetime.strftime(tomorrow, "%Y-%m-%dT%H:%M:%S+00:00") == until + ) + [revoke_url] = [ + link.get("href") + for link in entry.get("links") + if link.get("rel") == "http://librarysimplified.org/terms/rel/revoke" + ] + assert ( + "/collections/%s/holds/%s/revoke" % (self.collection.name, hold.id) + in revoke_url + ) + assert [] == [ + link.get("href") + for link in entry.get("links") + if link.get("rel") == "http://opds-spec.org/acquisition" + ] + [self_url] = [ + link.get("href") + for link in entry.get("links") + if link.get("rel") == "self" + ] assert "/collections/%s/holds/%s" % (self.collection.name, hold.id) def test_revoke_hold(self): now = utc_now() tomorrow = utc_now() + datetime.timedelta(days=1) hold, ignore = create( - self._db, Hold, license_pool=self.pool, integration_client=self.client, - start=now, end=tomorrow, + self._db, + Hold, + license_pool=self.pool, + integration_client=self.client, + start=now, + end=tomorrow, ) - other_client, ignore = IntegrationClient.register(self._db, "http://otherlibrary") + other_client, ignore = IntegrationClient.register( + self._db, "http://otherlibrary" + ) other_client_hold, ignore = create( - self._db, Hold, license_pool=self.pool, integration_client=other_client, + self._db, + Hold, + license_pool=self.pool, + integration_client=other_client, ) ignore, other_pool = self._edition( - with_license_pool=True, collection=self._collection(), + with_license_pool=True, + collection=self._collection(), ) other_pool_hold, ignore = create( - self._db, Hold, license_pool=other_pool, integration_client=self.client, + self._db, + Hold, + license_pool=other_pool, + integration_client=self.client, ) with self.request_context_with_client("/"): - response = self.manager.shared_collection_controller.revoke_hold(self.collection.name, other_pool_hold.id) + response = self.manager.shared_collection_controller.revoke_hold( + self.collection.name, other_pool_hold.id + ) assert HOLD_NOT_FOUND.uri == response.uri - response = self.manager.shared_collection_controller.revoke_hold(self.collection.name, other_client_hold.id) + response = self.manager.shared_collection_controller.revoke_hold( + self.collection.name, other_client_hold.id + ) assert HOLD_NOT_FOUND.uri == response.uri api = self.app.manager.shared_collection_controller.shared_collection api.queue_revoke_hold(AuthorizationFailedException()) - response = self.manager.shared_collection_controller.revoke_hold(self.collection.name, hold.id) + response = self.manager.shared_collection_controller.revoke_hold( + self.collection.name, hold.id + ) assert INVALID_CREDENTIALS.uri == response.uri api.queue_revoke_hold(CannotReleaseHold()) - response = self.manager.shared_collection_controller.revoke_hold(self.collection.name, hold.id) + response = self.manager.shared_collection_controller.revoke_hold( + self.collection.name, hold.id + ) assert CANNOT_RELEASE_HOLD.uri == response.uri api.queue_revoke_hold(NotOnHold()) - response = self.manager.shared_collection_controller.revoke_hold(self.collection.name, hold.id) + response = self.manager.shared_collection_controller.revoke_hold( + self.collection.name, hold.id + ) assert NO_ACTIVE_HOLD.uri == response.uri @@ -5599,27 +6148,25 @@ def test_work_lookup(self): # We got an OPDS feed. assert 200 == response.status_code - assert ( - OPDSFeed.ACQUISITION_FEED_TYPE == - response.headers['Content-Type']) + assert OPDSFeed.ACQUISITION_FEED_TYPE == response.headers["Content-Type"] # Parse it. feed = feedparser.parse(response.data) # The route name we passed into work_lookup shows up in # the feed-level link with rel="self". - [self_link] = feed['feed']['links'] - assert '/' + route_name in self_link['href'] + [self_link] = feed["feed"]["links"] + assert "/" + route_name in self_link["href"] # The work we looked up has an OPDS entry. - [entry] = feed['entries'] - assert work.title == entry['title'] + [entry] = feed["entries"] + assert work.title == entry["title"] # The OPDS feed includes an open-access acquisition link # -- something that only gets inserted by the # CirculationManagerAnnotator. [link] = entry.links - assert LinkRelations.OPEN_ACCESS_DOWNLOAD == link['rel'] + assert LinkRelations.OPEN_ACCESS_DOWNLOAD == link["rel"] class TestProfileController(ControllerTest): @@ -5638,46 +6185,47 @@ def setup_method(self): def test_controller_uses_circulation_patron_profile_storage(self): """Verify that this controller uses circulation manager-specific extensions.""" - with self.request_context_with_library( - "/", method='GET', headers=self.auth - ): - assert isinstance(self.manager.profiles._controller.storage, CirculationPatronProfileStorage) + with self.request_context_with_library("/", method="GET", headers=self.auth): + assert isinstance( + self.manager.profiles._controller.storage, + CirculationPatronProfileStorage, + ) def test_get(self): """Verify that a patron can see their own profile.""" - with self.request_context_with_library( - "/", method='GET', headers=self.auth - ): + with self.request_context_with_library("/", method="GET", headers=self.auth): patron = self.controller.authenticated_patron_from_request() patron.synchronize_annotations = True response = self.manager.profiles.protocol() assert "200 OK" == response.status data = json.loads(response.get_data(as_text=True)) - settings = data['settings'] + settings = data["settings"] assert True == settings[ProfileStorage.SYNCHRONIZE_ANNOTATIONS] def test_put(self): """Verify that a patron can modify their own profile.""" - payload = { - 'settings': { - ProfileStorage.SYNCHRONIZE_ANNOTATIONS: True - } - } + payload = {"settings": {ProfileStorage.SYNCHRONIZE_ANNOTATIONS: True}} request_patron = None identifier = self._identifier() with self.request_context_with_library( - "/", method='PUT', headers=self.auth, - content_type=ProfileController.MEDIA_TYPE, - data=json.dumps(payload) + "/", + method="PUT", + headers=self.auth, + content_type=ProfileController.MEDIA_TYPE, + data=json.dumps(payload), ): # By default, a patron has no value for synchronize_annotations. request_patron = self.controller.authenticated_patron_from_request() assert None == request_patron.synchronize_annotations # This means we can't create annotations for them. - pytest.raises(ValueError, Annotation.get_one_or_create, - self._db, patron=request_patron, identifier=identifier + pytest.raises( + ValueError, + Annotation.get_one_or_create, + self._db, + patron=request_patron, + identifier=identifier, ) # But by sending a PUT request... @@ -5692,16 +6240,19 @@ def test_put(self): # Now we can create an annotation for the patron who enabled # annotation sync. annotation = Annotation.get_one_or_create( - self._db, patron=request_patron, identifier=identifier) + self._db, patron=request_patron, identifier=identifier + ) assert 1 == len(request_patron.annotations) # But if we make another request and change their # synchronize_annotations field to False... - payload['settings'][ProfileStorage.SYNCHRONIZE_ANNOTATIONS] = False + payload["settings"][ProfileStorage.SYNCHRONIZE_ANNOTATIONS] = False with self.request_context_with_library( - "/", method='PUT', headers=self.auth, - content_type=ProfileController.MEDIA_TYPE, - data=json.dumps(payload) + "/", + method="PUT", + headers=self.auth, + content_type=ProfileController.MEDIA_TYPE, + data=json.dumps(payload), ): response = self.manager.profiles.protocol() @@ -5715,14 +6266,15 @@ def test_problemdetail_on_error(self): from the controller. """ with self.request_context_with_library( - "/", method='PUT', headers=self.auth, - content_type="text/plain", + "/", + method="PUT", + headers=self.auth, + content_type="text/plain", ): response = self.manager.profiles.protocol() assert isinstance(response, ProblemDetail) assert 415 == response.status_code - assert ("Expected vnd.librarysimplified/user-profile+json" == - response.detail) + assert "Expected vnd.librarysimplified/user-profile+json" == response.detail class TestScopedSession(ControllerTest): @@ -5760,7 +6312,9 @@ def make_default_collection(self, _db, library): uses the scoped session. """ collection, ignore = create( - _db, Collection, name=self._str + " (collection for scoped session)", + _db, + Collection, + name=self._str + " (collection for scoped session)", ) collection.create_external_integration(ExternalIntegration.OPDS_IMPORT) library.collections.append(collection) @@ -5773,9 +6327,7 @@ def test_request_context_and_transaction(self, *args): """ with self.app.test_request_context(*args) as ctx: transaction = current_session.begin_nested() - self.app.manager = self.circulation_manager_setup( - current_session - ) + self.app.manager = self.circulation_manager_setup(current_session) yield ctx transaction.rollback() @@ -5860,6 +6412,7 @@ def test_scoped_session(self): # used by most other unit tests. assert session1 != session2 + class TestStaticFileController(CirculationControllerTest): def test_static_file(self): cache_timeout = ConfigurationSetting.sitewide( @@ -5867,7 +6420,9 @@ def test_static_file(self): ) cache_timeout.value = 10 - directory = os.path.join(os.path.abspath(os.path.dirname(__file__)), "files", "images") + directory = os.path.join( + os.path.abspath(os.path.dirname(__file__)), "files", "images" + ) filename = "blue.jpg" with open(os.path.join(directory, filename), "rb") as f: expected_content = f.read() @@ -5876,15 +6431,21 @@ def test_static_file(self): response = self.app.manager.static_files.static_file(directory, filename) assert 200 == response.status_code - assert 'public, max-age=10' == response.headers.get('Cache-Control') + assert "public, max-age=10" == response.headers.get("Cache-Control") assert expected_content == response.response.file.read() with self.app.test_request_context("/"): - pytest.raises(NotFound, self.app.manager.static_files.static_file, - directory, "missing.png") + pytest.raises( + NotFound, + self.app.manager.static_files.static_file, + directory, + "missing.png", + ) def test_image(self): - directory = os.path.join(os.path.abspath(os.path.dirname(__file__)), "..", "resources", "images") + directory = os.path.join( + os.path.abspath(os.path.dirname(__file__)), "..", "resources", "images" + ) filename = "CleverLoginButton280.png" with open(os.path.join(directory, filename), "rb") as f: expected_content = f.read() diff --git a/tests/test_coverage.py b/tests/test_coverage.py index 485311d787..5522504117 100644 --- a/tests/test_coverage.py +++ b/tests/test_coverage.py @@ -1,16 +1,13 @@ import pytest -from core.testing import ( - DatabaseTest, -) - -from core.testing import MockRequestsResponse - -from core.config import ( - CannotLoadConfiguration, - Configuration, - temp_config, +from api.coverage import ( + MockOPDSImportCoverageProvider, + OPDSImportCoverageProvider, + ReaperImporter, + RegistrarImporter, ) +from core.config import CannotLoadConfiguration, Configuration, temp_config +from core.coverage import CoverageFailure from core.model import ( Collection, DataSource, @@ -18,25 +15,11 @@ Identifier, LicensePool, ) -from core.util.opds_writer import ( - OPDSFeed, -) -from core.opds_import import ( - MockSimplifiedOPDSLookup, - OPDSImporter, -) -from core.coverage import ( - CoverageFailure, -) - +from core.opds_import import MockSimplifiedOPDSLookup, OPDSImporter +from core.testing import DatabaseTest, MockRequestsResponse from core.util.http import BadResponseException +from core.util.opds_writer import OPDSFeed -from api.coverage import ( - MockOPDSImportCoverageProvider, - OPDSImportCoverageProvider, - ReaperImporter, - RegistrarImporter, -) class TestImporterSubclasses(DatabaseTest): """Test the subclasses of OPDSImporter.""" @@ -50,7 +33,6 @@ def test_success_status_codes(self): class TestOPDSImportCoverageProvider(DatabaseTest): - def _provider(self): """Create a generic MockOPDSImportCoverageProvider for testing purposes.""" return MockOPDSImportCoverageProvider(self._default_collection) @@ -62,7 +44,9 @@ def test_badresponseexception_on_non_opds_feed(self): provider = self._provider() provider.lookup_client = MockSimplifiedOPDSLookup(self._url) - response = MockRequestsResponse(200, {"content-type" : "text/plain"}, "Some data") + response = MockRequestsResponse( + 200, {"content-type": "text/plain"}, "Some data" + ) provider.lookup_client.queue_response(response) with pytest.raises(BadResponseException) as excinfo: provider.import_feed_response(response, None) @@ -89,7 +73,8 @@ def create_identifier_mapping(self, batch): # And create an ExternalIntegration for the metadata_client object. self._external_integration( ExternalIntegration.METADATA_WRANGLER, - goal=ExternalIntegration.METADATA_GOAL, url=self._url + goal=ExternalIntegration.METADATA_GOAL, + url=self._url, ) self._default_collection.external_integration.set_setting( @@ -101,10 +86,13 @@ def create_identifier_mapping(self, batch): # foreign data source knows the book as id2. id1 = self._identifier() id2 = self._identifier() - provider.mapping = { id2 : id1 } + provider.mapping = {id2: id1} - feed = "%sHere's your title!" % id2.urn - headers = {"content-type" : OPDSFeed.ACQUISITION_FEED_TYPE} + feed = ( + "%sHere's your title!" + % id2.urn + ) + headers = {"content-type": OPDSFeed.ACQUISITION_FEED_TYPE} lookup.queue_response(200, headers=headers, content=feed) [identifier] = provider.process_batch([id1]) @@ -127,8 +115,11 @@ def test_process_batch(self): license_source = DataSource.lookup(self._db, DataSource.GUTENBERG) pool, is_new = LicensePool.for_foreign_id( - self._db, license_source, identifier.type, identifier.identifier, - collection=self._default_collection + self._db, + license_source, + identifier.type, + identifier.identifier, + collection=self._default_collection, ) assert None == pool.work @@ -144,26 +135,26 @@ def test_process_batch(self): error_identifier = self._identifier() not_an_error_identifier = self._identifier() messages_by_id = { - error_identifier.urn : CoverageFailure( + error_identifier.urn: CoverageFailure( error_identifier, "500: internal error" ), - not_an_error_identifier.urn : not_an_error_identifier, + not_an_error_identifier.urn: not_an_error_identifier, } # When we call CoverageProvider.process_batch(), it's going to # return the information we just set up: a matched # Edition/LicensePool pair, a mismatched LicensePool, and an # error message. - provider.queue_import_results( - [edition], [pool, pool2], [], messages_by_id - ) + provider.queue_import_results([edition], [pool, pool2], [], messages_by_id) # Make the CoverageProvider do its thing. fake_batch = [object()] - (success_import, failure_mismatched, failure_message, - success_message) = provider.process_batch( - fake_batch - ) + ( + success_import, + failure_mismatched, + failure_message, + success_message, + ) = provider.process_batch(fake_batch) # The fake batch was provided to lookup_and_import_batch. assert [fake_batch] == provider.batches @@ -178,8 +169,10 @@ def test_process_batch(self): # The mismatched LicensePool turned into a CoverageFailure # object. assert isinstance(failure_mismatched, CoverageFailure) - assert ('OPDS import operation imported LicensePool, but no Edition.' == - failure_mismatched.exception) + assert ( + "OPDS import operation imported LicensePool, but no Edition." + == failure_mismatched.exception + ) assert pool2.identifier == failure_mismatched.obj assert True == failure_mismatched.transient @@ -238,8 +231,10 @@ def import_from_feed(self, text): right values. """ return ( - text, self.collection, - self.identifier_mapping, self.data_source_name + text, + self.collection, + self.identifier_mapping, + self.data_source_name, ) class MockProvider(MockOPDSImportCoverageProvider): @@ -249,11 +244,12 @@ class MockProvider(MockOPDSImportCoverageProvider): provider.lookup_client = MockSimplifiedOPDSLookup(self._url) response = MockRequestsResponse( - 200, {'content-type': OPDSFeed.ACQUISITION_FEED_TYPE}, "some data" + 200, {"content-type": OPDSFeed.ACQUISITION_FEED_TYPE}, "some data" ) id_mapping = object() - (text, collection, mapping, - data_source_name) = provider.import_feed_response(response, id_mapping) + (text, collection, mapping, data_source_name) = provider.import_feed_response( + response, id_mapping + ) assert "some data" == text assert provider.collection == collection assert id_mapping == mapping diff --git a/tests/test_custom_index.py b/tests/test_custom_index.py index 16a5b721f2..871e7b3896 100644 --- a/tests/test_custom_index.py +++ b/tests/test_custom_index.py @@ -1,23 +1,15 @@ import pytest - -from lxml import etree - from flask import Response +from lxml import etree +from api.config import CannotLoadConfiguration +from api.custom_index import COPPAGate, CustomIndexView from core.model import ConfigurationSetting - +from core.testing import DatabaseTest from core.util.opds_writer import OPDSFeed -from api.config import CannotLoadConfiguration -from api.custom_index import ( - CustomIndexView, - COPPAGate, -) - -from core.testing import DatabaseTest class TestCustomIndexView(DatabaseTest): - def test_register(self): c = CustomIndexView old_registry = c.BY_PROTOCOL @@ -39,10 +31,7 @@ class Mock2(object): def test_default_registry(self): """Verify the default contents of the registry.""" - assert ( - {COPPAGate.PROTOCOL : COPPAGate} == - CustomIndexView.BY_PROTOCOL) - + assert {COPPAGate.PROTOCOL: COPPAGate} == CustomIndexView.BY_PROTOCOL def test_for_library(self): m = CustomIndexView.for_library @@ -51,8 +40,10 @@ def test_for_library(self): # instantiated. class MockCustomIndexView(object): PROTOCOL = self._str + def __init__(self, library, integration): self.instantiated_with = (library, integration) + CustomIndexView.register(MockCustomIndexView) # By default, a library has no CustomIndexView. @@ -61,8 +52,9 @@ def __init__(self, library, integration): # But if a library has an ExternalIntegration that corresponds # to a registered CustomIndexView... integration = self._external_integration( - MockCustomIndexView.PROTOCOL, CustomIndexView.GOAL, - libraries=[self._default_library] + MockCustomIndexView.PROTOCOL, + CustomIndexView.GOAL, + libraries=[self._default_library], ) # A CustomIndexView of the appropriate class is instantiated @@ -73,24 +65,26 @@ def __init__(self, library, integration): class TestCOPPAGate(DatabaseTest): - def setup_method(self): super(TestCOPPAGate, self).setup_method() # Configure a COPPAGate for the default library. self.integration = self._external_integration( - COPPAGate.PROTOCOL, CustomIndexView.GOAL, - libraries=[self._default_library] + COPPAGate.PROTOCOL, CustomIndexView.GOAL, libraries=[self._default_library] ) self.lane1 = self._lane() self.lane2 = self._lane() m = ConfigurationSetting.for_library_and_externalintegration m( - self._db, COPPAGate.REQUIREMENT_MET_LANE, self._default_library, - self.integration + self._db, + COPPAGate.REQUIREMENT_MET_LANE, + self._default_library, + self.integration, ).value = self.lane1.id m( - self._db, COPPAGate.REQUIREMENT_NOT_MET_LANE, self._default_library, - self.integration + self._db, + COPPAGate.REQUIREMENT_NOT_MET_LANE, + self._default_library, + self.integration, ).value = self.lane2.id def test_lane_loading(self): @@ -106,14 +100,18 @@ def test_lane_loading(self): self._db.commit() with pytest.raises(CannotLoadConfiguration) as excinfo: COPPAGate(self._default_library, self.integration) - assert "Lane {} is for the wrong library".format(self.lane1.id) in str(excinfo.value) + assert "Lane {} is for the wrong library".format(self.lane1.id) in str( + excinfo.value + ) self.lane1.library_id = self._default_library.id # If the lane ID doesn't correspond to a real lane, the # COPPAGate cannot be instantiated. ConfigurationSetting.for_library_and_externalintegration( - self._db, COPPAGate.REQUIREMENT_MET_LANE, self._default_library, - self.integration + self._db, + COPPAGate.REQUIREMENT_MET_LANE, + self._default_library, + self.integration, ).value = -100 with pytest.raises(CannotLoadConfiguration) as excinfo: COPPAGate(self._default_library, self.integration) @@ -125,6 +123,7 @@ def test_invocation(self): class MockCOPPAGate(COPPAGate): def _navigation_feed(self, *args, **kwargs): return "fake feed" + gate = MockCOPPAGate(self._default_library, self.integration) # Calling a COPPAGate creates a Response. @@ -134,7 +133,7 @@ def _navigation_feed(self, *args, **kwargs): # The entity-body is the result of calling _navigation_feed, # which has been cached as .navigation_feed. assert "200 OK" == response.status - assert OPDSFeed.NAVIGATION_FEED_TYPE == response.headers['Content-Type'] + assert OPDSFeed.NAVIGATION_FEED_TYPE == response.headers["Content-Type"] response_data = response.get_data(as_text=True) assert "fake feed" == response_data assert response_data == gate.navigation_feed @@ -146,21 +145,23 @@ class MockAnnotator(object): """This annotator will have its chance to annotate the feed before it's finalized. """ + def annotate_feed(self, feed, lane): self.called_with = (feed, lane) + annotator = MockAnnotator() url_for_calls = [] + def mock_url_for(controller, library_short_name, **kwargs): """Create a real-looking URL for any random controller.""" url_for_calls.append((controller, library_short_name, kwargs)) - query = "&".join( - ["%s=%s" % (k,v) for k, v in sorted(kwargs.items())] - ) + query = "&".join(["%s=%s" % (k, v) for k, v in sorted(kwargs.items())]) return "http://%s/%s?%s" % (library_short_name, controller, query) navigation_entry_calls = [] gate_tag_calls = [] + class MockCOPPAGate(COPPAGate): def navigation_entry(self, url, title, content): navigation_entry_calls.append((url, title, content)) @@ -174,9 +175,7 @@ def gate_tag(cls, restriction, met_uri, not_met_uri): self._default_library.name = "The Library" self._default_library.short_name = "LIBR" gate = MockCOPPAGate(self._default_library, self.integration) - feed = gate._navigation_feed( - self._default_library, annotator, mock_url_for - ) + feed = gate._navigation_feed(self._default_library, annotator, mock_url_for) # The feed was passed to our mock Annotator, which decided to do # nothing to it. @@ -188,8 +187,9 @@ def gate_tag(cls, restriction, met_uri, not_met_uri): lane_url, title, content = older yes_url = mock_url_for( - "acquisition_groups", self._default_library.short_name, - lane_identifier=gate.yes_lane_id + "acquisition_groups", + self._default_library.short_name, + lane_identifier=gate.yes_lane_id, ) assert lane_url == yes_url assert title == gate.YES_TITLE @@ -197,8 +197,9 @@ def gate_tag(cls, restriction, met_uri, not_met_uri): lane_url, title, content = younger no_url = mock_url_for( - "acquisition_groups", self._default_library.short_name, - lane_identifier=gate.no_lane_id + "acquisition_groups", + self._default_library.short_name, + lane_identifier=gate.no_lane_id, ) assert lane_url == no_url assert title == gate.NO_TITLE @@ -221,33 +222,29 @@ def gate_tag(cls, restriction, met_uri, not_met_uri): index = mock_url_for("index", self._default_library.short_name) assert ('' % index) in feed assert ("%s" % self._default_library.name) in feed - assert ('%s' % index) in feed - assert '' in feed + assert ("%s" % index) in feed + assert "" in feed def test_navigation_entry(self): # navigation_entry creates an OPDS entry with a subsection link. entry = etree.tostring( - COPPAGate.navigation_entry( - "some href", "some title", "some content" - ), - encoding="unicode" + COPPAGate.navigation_entry("some href", "some title", "some content"), + encoding="unicode", ) - assert entry.startswith('some href', - 'some title', - 'some content', - '', - 'some href", + "some title", + 'some content', + '', + "/book/ client_url_template = "{client_base}/book/{work_link}" - qualified_identifier = urllib.parse.quote(identifier.type + "/" + identifier.identifier, safe='') + qualified_identifier = urllib.parse.quote( + identifier.type + "/" + identifier.identifier, safe="" + ) cm_base_url = "http://test-circulation-manager" expected_work_link = work_link_template.format( cm_base=cm_base_url, lib=lib_short_name, qid=qualified_identifier ) - encoded_work_link = urllib.parse.quote(expected_work_link, safe='') + encoded_work_link = urllib.parse.quote(expected_work_link, safe="") client_base_1 = "http://web_catalog" client_base_2 = "http://another_web_catalog" @@ -141,7 +182,9 @@ def test_add_web_client_urls(self): assert expected_client_url_1.startswith(client_base_1) assert expected_client_url_2.startswith(client_base_2) - ConfigurationSetting.sitewide(self._db, Configuration.BASE_URL_KEY).value = cm_base_url + ConfigurationSetting.sitewide( + self._db, Configuration.BASE_URL_KEY + ).value = cm_base_url annotator = LibraryAnnotator(self._default_library) @@ -152,11 +195,16 @@ def test_add_web_client_urls(self): # Add a URL from a library registry. registry = self._external_integration( - ExternalIntegration.OPDS_REGISTRATION, ExternalIntegration.DISCOVERY_GOAL, - libraries=[self._default_library]) + ExternalIntegration.OPDS_REGISTRATION, + ExternalIntegration.DISCOVERY_GOAL, + libraries=[self._default_library], + ) ConfigurationSetting.for_library_and_externalintegration( - self._db, Registration.LIBRARY_REGISTRATION_WEB_CLIENT, - self._default_library, registry).value = client_base_1 + self._db, + Registration.LIBRARY_REGISTRATION_WEB_CLIENT, + self._default_library, + registry, + ).value = client_base_1 record = Record() annotator.add_web_client_urls(record, self._default_library, identifier) @@ -166,15 +214,19 @@ def test_add_web_client_urls(self): # Add a manually configured URL on a MARC export integration. integration = self._external_integration( - ExternalIntegration.MARC_EXPORT, ExternalIntegration.CATALOG_GOAL, - libraries=[self._default_library]) + ExternalIntegration.MARC_EXPORT, + ExternalIntegration.CATALOG_GOAL, + libraries=[self._default_library], + ) ConfigurationSetting.for_library_and_externalintegration( - self._db, MARCExporter.WEB_CLIENT_URL, - self._default_library, integration).value = client_base_2 + self._db, MARCExporter.WEB_CLIENT_URL, self._default_library, integration + ).value = client_base_2 record = Record() - annotator.add_web_client_urls(record, self._default_library, identifier, integration) + annotator.add_web_client_urls( + record, self._default_library, identifier, integration + ) [field1, field2] = record.get_fields("856") assert ["4", "0"] == field1.indicators assert expected_client_url_2 == field1.get_subfields("u")[0] diff --git a/tests/test_metadata_wrangler.py b/tests/test_metadata_wrangler.py index 9ebaba4be7..27b632b12d 100644 --- a/tests/test_metadata_wrangler.py +++ b/tests/test_metadata_wrangler.py @@ -3,17 +3,21 @@ """ import datetime + import feedparser import pytest -from core.config import ( - CannotLoadConfiguration, - Configuration, - temp_config, -) -from core.coverage import ( - CoverageFailure, +from api.metadata_wrangler import ( + BaseMetadataWranglerCoverageProvider, + MetadataUploadCoverageProvider, + MetadataWranglerCollectionReaper, + MetadataWranglerCollectionRegistrar, + MWAuxiliaryMetadataMonitor, + MWCollectionUpdateMonitor, ) +from api.testing import MonitorTest +from core.config import CannotLoadConfiguration, Configuration, temp_config +from core.coverage import CoverageFailure from core.model import ( CoverageRecord, DataSource, @@ -25,51 +29,42 @@ ) from core.opds_import import MockMetadataWranglerOPDSLookup from core.testing import ( - MockRequestsResponse, AlwaysSuccessfulCoverageProvider, + DatabaseTest, + MockRequestsResponse, ) -from core.util.datetime_helpers import ( - datetime_utc, - utc_now, -) +from core.util.datetime_helpers import datetime_utc, utc_now from core.util.http import BadResponseException from core.util.opds_writer import OPDSFeed -from api.metadata_wrangler import ( - BaseMetadataWranglerCoverageProvider, - MetadataUploadCoverageProvider, - MetadataWranglerCollectionReaper, - MetadataWranglerCollectionRegistrar, - MWAuxiliaryMetadataMonitor, - MWCollectionUpdateMonitor, -) -from api.testing import MonitorTest -from core.testing import DatabaseTest from . import sample_data -class InstrumentedMWCollectionUpdateMonitor(MWCollectionUpdateMonitor): +class InstrumentedMWCollectionUpdateMonitor(MWCollectionUpdateMonitor): def __init__(self, *args, **kwargs): super(InstrumentedMWCollectionUpdateMonitor, self).__init__(*args, **kwargs) self.imports = [] def import_one_feed(self, timestamp, url): self.imports.append((timestamp, url)) - return super(InstrumentedMWCollectionUpdateMonitor, - self).import_one_feed(timestamp, url) + return super(InstrumentedMWCollectionUpdateMonitor, self).import_one_feed( + timestamp, url + ) -class TestMWCollectionUpdateMonitor(MonitorTest): +class TestMWCollectionUpdateMonitor(MonitorTest): def setup_method(self): super(TestMWCollectionUpdateMonitor, self).setup_method() self._external_integration( ExternalIntegration.METADATA_WRANGLER, ExternalIntegration.METADATA_GOAL, - username='abc', password='def', url=self._url + username="abc", + password="def", + url=self._url, ) self.collection = self._collection( - protocol=ExternalIntegration.BIBLIOTHECA, external_account_id='lib' + protocol=ExternalIntegration.BIBLIOTHECA, external_account_id="lib" ) self.lookup = MockMetadataWranglerOPDSLookup.from_config( @@ -83,28 +78,26 @@ def setup_method(self): def test_monitor_requires_authentication(self): class Mock(object): authenticated = False + self.monitor.lookup = Mock() with pytest.raises(Exception) as excinfo: self.monitor.run_once(self.ts) assert "no authentication credentials" in str(excinfo.value) def test_import_one_feed(self): - data = sample_data('metadata_updates_response.opds', 'opds') + data = sample_data("metadata_updates_response.opds", "opds") self.lookup.queue_response( - 200, {'content-type' : OPDSFeed.ACQUISITION_FEED_TYPE}, data + 200, {"content-type": OPDSFeed.ACQUISITION_FEED_TYPE}, data ) - next_links, editions, timestamp = self.monitor.import_one_feed( - None, None - ) + next_links, editions, timestamp = self.monitor.import_one_feed(None, None) # The 'next' links found in the OPDS feed are returned. - assert ['http://next-link/'] == next_links + assert ["http://next-link/"] == next_links # Insofar as is possible, all tags are converted into # Editions. - assert ['9781594632556'] == [x.primary_identifier.identifier - for x in editions] + assert ["9781594632556"] == [x.primary_identifier.identifier for x in editions] # The earliest time found in the OPDS feed is returned as a # candidate for the Monitor's timestamp. @@ -112,9 +105,9 @@ def test_import_one_feed(self): def test_empty_feed_stops_import(self): # We don't follow the 'next' link of an empty feed. - data = sample_data('metadata_updates_empty_response.opds', 'opds') + data = sample_data("metadata_updates_empty_response.opds", "opds") self.lookup.queue_response( - 200, {'content-type' : OPDSFeed.ACQUISITION_FEED_TYPE}, data + 200, {"content-type": OPDSFeed.ACQUISITION_FEED_TYPE}, data ) new_timestamp = self.monitor.run() @@ -126,17 +119,15 @@ def test_empty_feed_stops_import(self): # Since there were no tags, the timestamp's finish # date was set to the date of the feed itself, minus # one day (to avoid race conditions). - assert (datetime_utc(2016, 9, 19, 19, 37, 10) == - self.monitor.timestamp().finish) + assert datetime_utc(2016, 9, 19, 19, 37, 10) == self.monitor.timestamp().finish def test_run_once(self): # Setup authentication and Metadata Wrangler details. lp = self._licensepool( - None, data_source_name=DataSource.BIBLIOTHECA, - collection=self.collection + None, data_source_name=DataSource.BIBLIOTHECA, collection=self.collection ) lp.identifier.type = Identifier.BIBLIOTHECA_ID - isbn = Identifier.parse_urn(self._db, 'urn:isbn:9781594632556')[0] + isbn = Identifier.parse_urn(self._db, "urn:isbn:9781594632556")[0] lp.identifier.equivalent_to( DataSource.lookup(self._db, DataSource.BIBLIOTHECA), isbn, 1 ) @@ -145,13 +136,13 @@ def test_run_once(self): # Queue some data to be found. responses = ( - 'metadata_updates_response.opds', - 'metadata_updates_empty_response.opds', + "metadata_updates_response.opds", + "metadata_updates_empty_response.opds", ) for filename in responses: - data = sample_data(filename, 'opds') + data = sample_data(filename, "opds") self.lookup.queue_response( - 200, {'content-type' : OPDSFeed.ACQUISITION_FEED_TYPE}, data + 200, {"content-type": OPDSFeed.ACQUISITION_FEED_TYPE}, data ) timestamp = self.ts @@ -194,9 +185,11 @@ def test_no_changes_means_no_timestamp_update(self): # We're going to ask the metadata wrangler for updates, but # there will be none -- not even a feed-level update - data = sample_data('metadata_updates_empty_response_no_feed_timestamp.opds', 'opds') + data = sample_data( + "metadata_updates_empty_response_no_feed_timestamp.opds", "opds" + ) self.lookup.queue_response( - 200, {'content-type' : OPDSFeed.ACQUISITION_FEED_TYPE}, data + 200, {"content-type": OPDSFeed.ACQUISITION_FEED_TYPE}, data ) new_timestamp = self.monitor.run_once(self.ts) @@ -210,7 +203,7 @@ def test_no_changes_means_no_timestamp_update(self): # to None. self.monitor.timestamp().finish = None self.lookup.queue_response( - 200, {'content-type' : OPDSFeed.ACQUISITION_FEED_TYPE}, data + 200, {"content-type": OPDSFeed.ACQUISITION_FEED_TYPE}, data ) new_timestamp = self.monitor.run_once(self.ts) assert Timestamp.CLEAR_VALUE == new_timestamp.finish @@ -219,19 +212,19 @@ def test_no_import_loop(self): # We stop processing a feed's 'next' link if it links to a URL we've # already seen. - data = sample_data('metadata_updates_response.opds', 'opds') + data = sample_data("metadata_updates_response.opds", "opds") self.lookup.queue_response( - 200, {'content-type' : OPDSFeed.ACQUISITION_FEED_TYPE}, data + 200, {"content-type": OPDSFeed.ACQUISITION_FEED_TYPE}, data ) data = data.replace(b"http://next-link/", b"http://different-link/") self.lookup.queue_response( - 200, {'content-type' : OPDSFeed.ACQUISITION_FEED_TYPE}, data + 200, {"content-type": OPDSFeed.ACQUISITION_FEED_TYPE}, data ) # This introduces a loop. data = data.replace(b"http://next-link/", b"http://next-link/") self.lookup.queue_response( - 200, {'content-type' : OPDSFeed.ACQUISITION_FEED_TYPE}, data + 200, {"content-type": OPDSFeed.ACQUISITION_FEED_TYPE}, data ) new_timestamp = self.monitor.run_once(self.ts) @@ -240,13 +233,12 @@ def test_no_import_loop(self): # seen before; then we stopped. first, second, third = self.monitor.imports assert (None, None) == first - assert (None, 'http://next-link/') == second - assert (None, 'http://different-link/') == third + assert (None, "http://next-link/") == second + assert (None, "http://different-link/") == third assert datetime_utc(2016, 9, 20, 19, 37, 2) == new_timestamp.finish def test_get_response(self): - class Mock(MockMetadataWranglerOPDSLookup): def __init__(self): self.last_timestamp = None @@ -267,9 +259,7 @@ def _get(self, _url): # If you pass in None for the URL, it passes the timestamp into # updates() lookup = Mock() - monitor = MWCollectionUpdateMonitor( - self._db, self.collection, lookup - ) + monitor = MWCollectionUpdateMonitor(self._db, self.collection, lookup) timestamp = object() response = monitor.get_response(timestamp=timestamp, url=None) assert 200 == response.status_code @@ -279,28 +269,27 @@ def _get(self, _url): # If you pass in a URL, the timestamp is ignored and # the URL is passed into _get(). lookup = Mock() - monitor = MWCollectionUpdateMonitor( - self._db, self.collection, lookup - ) - response = monitor.get_response(timestamp=None, url='http://now used/') + monitor = MWCollectionUpdateMonitor(self._db, self.collection, lookup) + response = monitor.get_response(timestamp=None, url="http://now used/") assert 200 == response.status_code assert None == lookup.last_timestamp - assert ['http://now used/'] == lookup.urls + assert ["http://now used/"] == lookup.urls class TestMWAuxiliaryMetadataMonitor(MonitorTest): - def setup_method(self): super(TestMWAuxiliaryMetadataMonitor, self).setup_method() self._external_integration( ExternalIntegration.METADATA_WRANGLER, ExternalIntegration.METADATA_GOAL, - username='abc', password='def', url=self._url + username="abc", + password="def", + url=self._url, ) self.collection = self._collection( - protocol=ExternalIntegration.OVERDRIVE, external_account_id='lib' + protocol=ExternalIntegration.OVERDRIVE, external_account_id="lib" ) self.lookup = MockMetadataWranglerOPDSLookup.from_config( @@ -316,6 +305,7 @@ def setup_method(self): def test_monitor_requires_authentication(self): class Mock(object): authenticated = False + self.monitor.lookup = Mock() with pytest.raises(Exception) as excinfo: self.monitor.run_once(self.ts) @@ -327,20 +317,20 @@ def prep_feed_identifiers(self): # Create an Overdrive ID to match the one in the feed. overdrive = self._identifier( identifier_type=Identifier.OVERDRIVE_ID, - foreign_id='4981c34f-d518-48ff-9659-2601b2b9bdc1' + foreign_id="4981c34f-d518-48ff-9659-2601b2b9bdc1", ) # Create an ISBN to match the one in the feed. isbn = self._identifier( - identifier_type=Identifier.ISBN, foreign_id='9781602835740' + identifier_type=Identifier.ISBN, foreign_id="9781602835740" ) # Create a Axis 360 ID equivalent to the other ISBN in the feed. axis_360 = self._identifier( - identifier_type=Identifier.AXIS_360_ID, foreign_id='fake' + identifier_type=Identifier.AXIS_360_ID, foreign_id="fake" ) axis_360_isbn = self._identifier( - identifier_type=Identifier.ISBN, foreign_id='9781569478295' + identifier_type=Identifier.ISBN, foreign_id="9781569478295" ) axis_source = DataSource.lookup(self._db, DataSource.AXIS_360) axis_360.equivalent_to(axis_source, axis_360_isbn, 1) @@ -360,9 +350,9 @@ def prep_feed_identifiers(self): def test_get_identifiers(self): overdrive, isbn, axis_360 = self.prep_feed_identifiers() - data = sample_data('metadata_data_needed_response.opds', 'opds') + data = sample_data("metadata_data_needed_response.opds", "opds") self.lookup.queue_response( - 200, {'content-type' : OPDSFeed.ACQUISITION_FEED_TYPE}, data + 200, {"content-type": OPDSFeed.ACQUISITION_FEED_TYPE}, data ) identifiers, next_links = self.monitor.get_identifiers() @@ -370,7 +360,7 @@ def test_get_identifiers(self): # identifier. assert sorted([overdrive, axis_360, isbn]) == sorted(identifiers) - assert ['http://next-link'] == next_links + assert ["http://next-link"] == next_links def test_run_once(self): overdrive, isbn, axis_360 = self.prep_feed_identifiers() @@ -382,11 +372,11 @@ def test_run_once(self): w.simple_opds_entry = w.verbose_opds_entry = None # Queue some response feeds. - feed1 = sample_data('metadata_data_needed_response.opds', 'opds') - feed2 = sample_data('metadata_data_needed_empty_response.opds', 'opds') + feed1 = sample_data("metadata_data_needed_response.opds", "opds") + feed2 = sample_data("metadata_data_needed_empty_response.opds", "opds") for feed in [feed1, feed2]: self.lookup.queue_response( - 200, {'content-type' : OPDSFeed.ACQUISITION_FEED_TYPE}, feed + 200, {"content-type": OPDSFeed.ACQUISITION_FEED_TYPE}, feed ) progress = self.monitor.run_once(self.ts) @@ -400,21 +390,22 @@ def test_run_once(self): assert None == progress.finish record = CoverageRecord.lookup( - overdrive, self.monitor.provider.data_source, - operation=self.monitor.provider.operation + overdrive, + self.monitor.provider.data_source, + operation=self.monitor.provider.operation, ) assert record for identifier in [axis_360, isbn]: record = CoverageRecord.lookup( - identifier, self.monitor.provider.data_source, - operation=self.monitor.provider.operation + identifier, + self.monitor.provider.data_source, + operation=self.monitor.provider.operation, ) assert None == record class MetadataWranglerCoverageProviderTest(DatabaseTest): - def create_provider(self, **kwargs): lookup = MockMetadataWranglerOPDSLookup.from_config(self._db, self.collection) return self.TEST_CLASS(self.collection, lookup, **kwargs) @@ -423,12 +414,14 @@ def setup_method(self): super(MetadataWranglerCoverageProviderTest, self).setup_method() self.integration = self._external_integration( ExternalIntegration.METADATA_WRANGLER, - goal=ExternalIntegration.METADATA_GOAL, url=self._url, - username='abc', password='def' + goal=ExternalIntegration.METADATA_GOAL, + url=self._url, + username="abc", + password="def", ) self.source = DataSource.lookup(self._db, DataSource.METADATA_WRANGLER) self.collection = self._collection( - protocol=ExternalIntegration.BIBLIOTHECA, external_account_id='lib' + protocol=ExternalIntegration.BIBLIOTHECA, external_account_id="lib" ) self.provider = self.create_provider() self.lookup_client = self.provider.lookup_client @@ -437,16 +430,16 @@ def opds_feed_identifiers(self): """Creates three Identifiers to use for testing with sample OPDS files.""" # An identifier directly represented in the OPDS response. - valid_id = self._identifier(foreign_id='2020110') + valid_id = self._identifier(foreign_id="2020110") # An identifier mapped to an identifier represented in the OPDS # response. source = DataSource.lookup(self._db, DataSource.AXIS_360) mapped_id = self._identifier( - identifier_type=Identifier.AXIS_360_ID, foreign_id='0015187876' + identifier_type=Identifier.AXIS_360_ID, foreign_id="0015187876" ) equivalent_id = self._identifier( - identifier_type=Identifier.ISBN, foreign_id='9781936460236' + identifier_type=Identifier.ISBN, foreign_id="9781936460236" ) mapped_id.equivalent_to(source, equivalent_id, 1) @@ -457,7 +450,6 @@ def opds_feed_identifiers(self): class TestBaseMetadataWranglerCoverageProvider(MetadataWranglerCoverageProviderTest): - class Mock(BaseMetadataWranglerCoverageProvider): SERVICE_NAME = "Mock" DATA_SOURCE_NAME = DataSource.OVERDRIVE @@ -469,25 +461,31 @@ def test_must_be_authenticated(self): metadata wrangler coverage provider that can't authenticate with the metadata wrangler. """ + class UnauthenticatedLookupClient(object): authenticated = False with pytest.raises(CannotLoadConfiguration) as excinfo: self.Mock(self.collection, UnauthenticatedLookupClient()) - assert "Authentication for the Library Simplified Metadata Wrangler " in str(excinfo.value) + assert "Authentication for the Library Simplified Metadata Wrangler " in str( + excinfo.value + ) def test_input_identifier_types(self): """Verify all the different types of identifiers we send to the metadata wrangler. """ assert ( - set([ - Identifier.OVERDRIVE_ID, - Identifier.BIBLIOTHECA_ID, - Identifier.AXIS_360_ID, - Identifier.URI, - ]) == - set(BaseMetadataWranglerCoverageProvider.INPUT_IDENTIFIER_TYPES)) + set( + [ + Identifier.OVERDRIVE_ID, + Identifier.BIBLIOTHECA_ID, + Identifier.AXIS_360_ID, + Identifier.URI, + ] + ) + == set(BaseMetadataWranglerCoverageProvider.INPUT_IDENTIFIER_TYPES) + ) def test_create_identifier_mapping(self): # Most identifiers map to themselves. @@ -512,9 +510,9 @@ def test_create_identifier_mapping(self): def test_coverage_records_for_unhandled_items_include_collection(self): # NOTE: This could be made redundant by adding test coverage to # CoverageProvider.process_batch_and_handle_results in core. - data = sample_data('metadata_sync_response.opds', 'opds') + data = sample_data("metadata_sync_response.opds", "opds") self.lookup_client.queue_response( - 200, {'content-type': OPDSFeed.ACQUISITION_FEED_TYPE}, data + 200, {"content-type": OPDSFeed.ACQUISITION_FEED_TYPE}, data ) identifier = self._identifier() @@ -538,17 +536,14 @@ def test_constants(self): assert CoverageRecord.IMPORT_OPERATION == self.TEST_CLASS.OPERATION def test_process_batch(self): - """End-to-end test of the registrar's process_batch() implementation. - """ - data = sample_data('metadata_sync_response.opds', 'opds') + """End-to-end test of the registrar's process_batch() implementation.""" + data = sample_data("metadata_sync_response.opds", "opds") self.lookup_client.queue_response( - 200, {'content-type': OPDSFeed.ACQUISITION_FEED_TYPE}, data + 200, {"content-type": OPDSFeed.ACQUISITION_FEED_TYPE}, data ) valid_id, mapped_id, lost_id = self.opds_feed_identifiers() - results = self.provider.process_batch( - [valid_id, mapped_id, lost_id] - ) + results = self.provider.process_batch([valid_id, mapped_id, lost_id]) # The Identifier that resulted in a 200 message was returned. # @@ -568,21 +563,22 @@ def test_process_batch_errors(self): # This happens if the 'server' sends data with the wrong media # type. self.lookup_client.queue_response( - 200, {'content-type': 'json/application'}, '{ "title": "It broke." }' + 200, {"content-type": "json/application"}, '{ "title": "It broke." }' ) id1 = self._identifier() id2 = self._identifier() with pytest.raises(BadResponseException) as excinfo: self.provider.process_batch([id1, id2]) - assert 'Wrong media type' in str(excinfo.value) + assert "Wrong media type" in str(excinfo.value) assert [] == id1.coverage_records assert [] == id2.coverage_records # Of if the 'server' sends an error response code. self.lookup_client.queue_response( - 500, {'content-type': OPDSFeed.ACQUISITION_FEED_TYPE}, - 'Internal Server Error' + 500, + {"content-type": OPDSFeed.ACQUISITION_FEED_TYPE}, + "Internal Server Error", ) with pytest.raises(BadResponseException) as excinfo: self.provider.process_batch([id1, id2]) @@ -592,15 +588,15 @@ def test_process_batch_errors(self): # If a message comes back with an unexpected status, a # CoverageFailure is created. - data = sample_data('unknown_message_status_code.opds', 'opds') + data = sample_data("unknown_message_status_code.opds", "opds") valid_id = self.opds_feed_identifiers()[0] self.lookup_client.queue_response( - 200, {'content-type': OPDSFeed.ACQUISITION_FEED_TYPE}, data + 200, {"content-type": OPDSFeed.ACQUISITION_FEED_TYPE}, data ) [result] = self.provider.process_batch([valid_id]) assert True == isinstance(result, CoverageFailure) assert valid_id == result.obj - assert '418: Mad Hatter' == result.exception + assert "418: Mad Hatter" == result.exception # The OPDS importer didn't know which Collection to associate # with this CoverageFailure, but the CoverageProvider does, @@ -608,11 +604,11 @@ def test_process_batch_errors(self): assert self.provider.collection == result.collection def test_items_that_need_coverage_excludes_unavailable_items(self): - """A LicensePool that's not actually available doesn't need coverage. - """ + """A LicensePool that's not actually available doesn't need coverage.""" edition, pool = self._edition( - with_license_pool=True, collection=self.collection, - identifier_type=Identifier.BIBLIOTHECA_ID + with_license_pool=True, + collection=self.collection, + identifier_type=Identifier.BIBLIOTHECA_ID, ) pool.licenses_owned = 0 assert 0 == self.provider.items_that_need_coverage().count() @@ -622,11 +618,11 @@ def test_items_that_need_coverage_excludes_unavailable_items(self): assert [pool.identifier] == self.provider.items_that_need_coverage().all() def test_items_that_need_coverage_removes_reap_records_for_relicensed_items(self): - """A LicensePool that's not actually available doesn't need coverage. - """ + """A LicensePool that's not actually available doesn't need coverage.""" edition, pool = self._edition( - with_license_pool=True, collection=self.collection, - identifier_type=Identifier.BIBLIOTHECA_ID + with_license_pool=True, + collection=self.collection, + identifier_type=Identifier.BIBLIOTHECA_ID, ) identifier = pool.identifier @@ -634,13 +630,12 @@ def test_items_that_need_coverage_removes_reap_records_for_relicensed_items(self # This identifier was reaped... cr = self._coverage_record( - pool.identifier, self.provider.data_source, + pool.identifier, + self.provider.data_source, operation=CoverageRecord.REAP_OPERATION, - collection=self.collection + collection=self.collection, ) - assert ( - set(original_coverage_records + [cr]) == - set(identifier.coverage_records)) + assert set(original_coverage_records + [cr]) == set(identifier.coverage_records) # ... but then it was relicensed. pool.licenses_owned = 10 @@ -652,8 +647,9 @@ def test_items_that_need_coverage_removes_reap_records_for_relicensed_items(self def test_identifier_covered_in_one_collection_not_covered_in_another(self): edition, pool = self._edition( - with_license_pool=True, collection=self.collection, - identifier_type=Identifier.BIBLIOTHECA_ID + with_license_pool=True, + collection=self.collection, + identifier_type=Identifier.BIBLIOTHECA_ID, ) identifier = pool.identifier @@ -665,17 +661,19 @@ def test_identifier_covered_in_one_collection_not_covered_in_another(self): # Adding coverage for an irrelevant collection won't fix that. cr = self._coverage_record( - pool.identifier, self.provider.data_source, + pool.identifier, + self.provider.data_source, operation=self.provider.OPERATION, - collection=other_collection + collection=other_collection, ) assert [identifier] == qu.all() # Adding coverage for the relevant collection will. cr = self._coverage_record( - pool.identifier, self.provider.data_source, + pool.identifier, + self.provider.data_source, operation=self.provider.OPERATION, - collection=self.provider.collection + collection=self.provider.collection, ) assert [] == qu.all() @@ -684,8 +682,9 @@ def test_identifier_reaped_from_one_collection_covered_in_another(self): need coverage in another. """ edition, pool = self._edition( - with_license_pool=True, collection=self.collection, - identifier_type=Identifier.BIBLIOTHECA_ID + with_license_pool=True, + collection=self.collection, + identifier_type=Identifier.BIBLIOTHECA_ID, ) identifier = pool.identifier @@ -694,9 +693,10 @@ def test_identifier_reaped_from_one_collection_covered_in_another(self): # This identifier was reaped from other_collection, but not # from self.provider.collection. cr = self._coverage_record( - pool.identifier, self.provider.data_source, + pool.identifier, + self.provider.data_source, operation=CoverageRecord.REAP_OPERATION, - collection=other_collection + collection=other_collection, ) # It still needs to be covered in self.provider.collection. @@ -708,12 +708,15 @@ def test_items_that_need_coverage_respects_cutoff(self): """ edition, pool = self._edition( - with_license_pool=True, collection=self.collection, - identifier_type=Identifier.BIBLIOTHECA_ID + with_license_pool=True, + collection=self.collection, + identifier_type=Identifier.BIBLIOTHECA_ID, ) cr = self._coverage_record( - pool.identifier, self.provider.data_source, - operation=self.provider.OPERATION, collection=self.collection + pool.identifier, + self.provider.data_source, + operation=self.provider.OPERATION, + collection=self.collection, ) # We have a coverage record already, so this book doesn't show @@ -723,28 +726,27 @@ def test_items_that_need_coverage_respects_cutoff(self): # But if we send a cutoff_time that's later than the time # associated with the coverage record... - one_hour_from_now = ( - utc_now() + datetime.timedelta(seconds=3600) - ) - provider_with_cutoff = self.create_provider( - cutoff_time=one_hour_from_now - ) + one_hour_from_now = utc_now() + datetime.timedelta(seconds=3600) + provider_with_cutoff = self.create_provider(cutoff_time=one_hour_from_now) # The book starts showing up in items_that_need_coverage. - assert ([pool.identifier] == - provider_with_cutoff.items_that_need_coverage().all()) + assert [ + pool.identifier + ] == provider_with_cutoff.items_that_need_coverage().all() def test_items_that_need_coverage_respects_count_as_covered(self): # Here's a coverage record with a transient failure. edition, pool = self._edition( - with_license_pool=True, collection=self.collection, + with_license_pool=True, + collection=self.collection, identifier_type=Identifier.OVERDRIVE_ID, ) cr = self._coverage_record( - pool.identifier, self.provider.data_source, + pool.identifier, + self.provider.data_source, operation=self.provider.operation, status=CoverageRecord.TRANSIENT_FAILURE, - collection=self.collection + collection=self.collection, ) # Ordinarily, a transient failure does not count as coverage. @@ -753,10 +755,12 @@ def test_items_that_need_coverage_respects_count_as_covered(self): # But if we say that transient failure counts as coverage, it # does count. - assert ([] == - self.provider.items_that_need_coverage( + assert ( + [] + == self.provider.items_that_need_coverage( count_as_covered=CoverageRecord.TRANSIENT_FAILURE - ).all()) + ).all() + ) def test_isbn_covers_are_imported_from_mapped_identifiers(self): # Now that we pass ISBN equivalents instead of Bibliotheca identifiers @@ -767,29 +771,39 @@ def test_isbn_covers_are_imported_from_mapped_identifiers(self): source = DataSource.lookup(self._db, DataSource.BIBLIOTHECA) identifier = self._identifier(identifier_type=Identifier.BIBLIOTHECA_ID) LicensePool.for_foreign_id( - self._db, source, identifier.type, identifier.identifier, - collection=self.provider.collection + self._db, + source, + identifier.type, + identifier.identifier, + collection=self.provider.collection, ) # Create an ISBN and set it equivalent. isbn = self._identifier(identifier_type=Identifier.ISBN) - isbn.identifier = '9781594632556' + isbn.identifier = "9781594632556" identifier.equivalent_to(source, isbn, 1) - opds = sample_data('metadata_isbn_response.opds', 'opds') + opds = sample_data("metadata_isbn_response.opds", "opds") self.provider.lookup_client.queue_response( - 200, {'content-type': 'application/atom+xml;profile=opds-catalog;kind=acquisition'}, opds + 200, + { + "content-type": "application/atom+xml;profile=opds-catalog;kind=acquisition" + }, + opds, ) result = self.provider.process_item(identifier) # The lookup is successful assert result == identifier # The appropriate cover links are transferred. - identifier_uris = [l.resource.url for l in identifier.links - if l.rel in [Hyperlink.IMAGE, Hyperlink.THUMBNAIL_IMAGE]] + identifier_uris = [ + l.resource.url + for l in identifier.links + if l.rel in [Hyperlink.IMAGE, Hyperlink.THUMBNAIL_IMAGE] + ] expected = [ - 'http://book-covers.nypl.org/Content%20Cafe/ISBN/9781594632556/cover.jpg', - 'http://book-covers.nypl.org/scaled/300/Content%20Cafe/ISBN/9781594632556/cover.jpg' + "http://book-covers.nypl.org/Content%20Cafe/ISBN/9781594632556/cover.jpg", + "http://book-covers.nypl.org/scaled/300/Content%20Cafe/ISBN/9781594632556/cover.jpg", ] assert sorted(identifier_uris) == sorted(expected) @@ -797,18 +811,20 @@ def test_isbn_covers_are_imported_from_mapped_identifiers(self): # The ISBN doesn't get any information. assert isbn.links == [] -class MetadataWranglerCollectionManagerTest(DatabaseTest): +class MetadataWranglerCollectionManagerTest(DatabaseTest): def setup_method(self): super(MetadataWranglerCollectionManagerTest, self).setup_method() self.integration = self._external_integration( ExternalIntegration.METADATA_WRANGLER, - goal=ExternalIntegration.METADATA_GOAL, url=self._url, - username='abc', password='def' + goal=ExternalIntegration.METADATA_GOAL, + url=self._url, + username="abc", + password="def", ) self.source = DataSource.lookup(self._db, DataSource.METADATA_WRANGLER) self.collection = self._collection( - protocol=ExternalIntegration.BIBLIOTHECA, external_account_id='lib' + protocol=ExternalIntegration.BIBLIOTHECA, external_account_id="lib" ) self.lookup = MockMetadataWranglerOPDSLookup.from_config( self._db, collection=self.collection @@ -834,12 +850,15 @@ def test_items_that_need_coverage(self): # Create an item that was imported into the Wrangler-side # collection but no longer has any owned licenses covered_unlicensed_lp = self._licensepool( - None, open_access=False, set_edition_as_presentation=True, - collection=self.collection + None, + open_access=False, + set_edition_as_presentation=True, + collection=self.collection, ) covered_unlicensed_lp.update_availability(0, 0, 0, 0) cr = self._coverage_record( - covered_unlicensed_lp.presentation_edition, self.source, + covered_unlicensed_lp.presentation_edition, + self.source, operation=CoverageRecord.IMPORT_OPERATION, collection=self.provider.collection, ) @@ -875,9 +894,9 @@ def test_items_that_need_coverage(self): assert [] == self.provider.items_that_need_coverage().all() def test_process_batch(self): - data = sample_data('metadata_reaper_response.opds', 'opds') + data = sample_data("metadata_reaper_response.opds", "opds") self.lookup_client.queue_response( - 200, {'content-type': OPDSFeed.ACQUISITION_FEED_TYPE}, data + 200, {"content-type": OPDSFeed.ACQUISITION_FEED_TYPE}, data ) valid_id, mapped_id, lost_id = self.opds_feed_identifiers() @@ -896,26 +915,30 @@ def test_finalize_batch(self): # Create an identifier that has been imported and one that's # been reaped. sync_cr = self._coverage_record( - self._edition(), self.source, + self._edition(), + self.source, operation=CoverageRecord.IMPORT_OPERATION, - collection=self.provider.collection + collection=self.provider.collection, ) reaped_cr = self._coverage_record( - self._edition(), self.source, + self._edition(), + self.source, operation=CoverageRecord.REAP_OPERATION, - collection=self.provider.collection + collection=self.provider.collection, ) # Create coverage records for an Identifier that has been both synced # and reaped. doubly_covered = self._edition() doubly_sync_record = self._coverage_record( - doubly_covered, self.source, + doubly_covered, + self.source, operation=CoverageRecord.IMPORT_OPERATION, - collection=self.provider.collection + collection=self.provider.collection, ) doubly_reap_record = self._coverage_record( - doubly_covered, self.source, + doubly_covered, + self.source, operation=CoverageRecord.REAP_OPERATION, collection=self.provider.collection, ) @@ -925,28 +948,30 @@ def test_finalize_batch(self): # The syncing record has been deleted from the database assert doubly_sync_record not in remaining_records - assert (sorted([sync_cr, reaped_cr, doubly_reap_record], key=lambda x: x.id) == - sorted(remaining_records, key=lambda x: x.id)) + assert sorted( + [sync_cr, reaped_cr, doubly_reap_record], key=lambda x: x.id + ) == sorted(remaining_records, key=lambda x: x.id) class TestMetadataUploadCoverageProvider(DatabaseTest): - def create_provider(self, **kwargs): - upload_client = MockMetadataWranglerOPDSLookup.from_config(self._db, self.collection) - return MetadataUploadCoverageProvider( - self.collection, upload_client, **kwargs + upload_client = MockMetadataWranglerOPDSLookup.from_config( + self._db, self.collection ) + return MetadataUploadCoverageProvider(self.collection, upload_client, **kwargs) def setup_method(self): super(TestMetadataUploadCoverageProvider, self).setup_method() self.integration = self._external_integration( ExternalIntegration.METADATA_WRANGLER, - goal=ExternalIntegration.METADATA_GOAL, url=self._url, - username='abc', password='def' + goal=ExternalIntegration.METADATA_GOAL, + url=self._url, + username="abc", + password="def", ) self.source = DataSource.lookup(self._db, DataSource.METADATA_WRANGLER) self.collection = self._collection( - protocol=ExternalIntegration.BIBLIOTHECA, external_account_id='lib' + protocol=ExternalIntegration.BIBLIOTHECA, external_account_id="lib" ) self.provider = self.create_provider() @@ -956,16 +981,19 @@ def test_items_that_need_coverage_only_finds_transient_failures(self): """ edition, pool = self._edition( - with_license_pool=True, collection=self.collection, - identifier_type=Identifier.BIBLIOTHECA_ID + with_license_pool=True, + collection=self.collection, + identifier_type=Identifier.BIBLIOTHECA_ID, ) # We don't have a CoverageRecord yet, so the book doesn't show up. items = self.provider.items_that_need_coverage().all() assert [] == items cr = self._coverage_record( - pool.identifier, self.provider.data_source, - operation=self.provider.OPERATION, collection=self.collection + pool.identifier, + self.provider.data_source, + operation=self.provider.OPERATION, + collection=self.collection, ) # With a successful or persistent failure CoverageRecord, it still doesn't show up. @@ -986,27 +1014,27 @@ def test_process_batch_uploads_metadata(self): class MockMetadataClient(object): metadata_feed = None authenticated = True + def canonicalize_author_name(self, identifier, working_display_name): return working_display_name + def add_with_metadata(self, feed): self.metadata_feed = feed - metadata_client = MockMetadataClient() - provider = MetadataUploadCoverageProvider( - self.collection, metadata_client - ) + metadata_client = MockMetadataClient() + provider = MetadataUploadCoverageProvider(self.collection, metadata_client) edition, pool = self._edition( - with_license_pool=True, collection=self.collection, - identifier_type=Identifier.BIBLIOTHECA_ID + with_license_pool=True, + collection=self.collection, + identifier_type=Identifier.BIBLIOTHECA_ID, ) work = pool.calculate_work() # This identifier has no Work. no_work = self._identifier() - results = provider.process_batch([pool.identifier, no_work]) # An OPDS feed of metadata was sent to the metadata wrangler. diff --git a/tests/test_migration_scripts.py b/tests/test_migration_scripts.py index a75d2bad1e..d0f91b80ff 100644 --- a/tests/test_migration_scripts.py +++ b/tests/test_migration_scripts.py @@ -1,11 +1,10 @@ import json -from migartion_scripts import RandomSortOptionRemover - from core.facets import FacetConstants from core.lane import Facets from core.model import Library from core.testing import DatabaseTest +from migartion_scripts import RandomSortOptionRemover class TestRandomSortOptionRemover(DatabaseTest): diff --git a/tests/test_millenium_patron.py b/tests/test_millenium_patron.py index a2cc4b5b6b..ff686ca039 100644 --- a/tests/test_millenium_patron.py +++ b/tests/test_millenium_patron.py @@ -1,37 +1,35 @@ +import json import pkgutil from datetime import date, timedelta from decimal import Decimal -import json from urllib import parse import pytest -from api.config import ( - CannotLoadConfiguration, - Configuration, -) -from core.model import ConfigurationSetting + from api.authenticator import PatronData +from api.config import CannotLoadConfiguration, Configuration from api.millenium_patron import MilleniumPatronAPI +from core.model import ConfigurationSetting from core.testing import DatabaseTest -from core.util.datetime_helpers import ( - utc_now -) +from core.util.datetime_helpers import utc_now + from . import sample_data + class MockResponse(object): def __init__(self, content): self.status_code = 200 self.content = content -class MockAPI(MilleniumPatronAPI): +class MockAPI(MilleniumPatronAPI): def __init__(self, library_id, integration): super(MockAPI, self).__init__(library_id, integration) self.queue = [] self.requests_made = [] def sample_data(self, filename): - return sample_data(filename, 'millenium_patron') + return sample_data(filename, "millenium_patron") def enqueue(self, filename): data = self.sample_data(filename) @@ -45,29 +43,47 @@ def request(self, *args, **kwargs): class TestMilleniumPatronAPI(DatabaseTest): - - def mock_api(self, url="http://url/", blacklist=[], auth_mode=None, verify_certificate=True, - block_types=None, password_keyboard=None, library_identifier_field=None, - neighborhood_mode=None + def mock_api( + self, + url="http://url/", + blacklist=[], + auth_mode=None, + verify_certificate=True, + block_types=None, + password_keyboard=None, + library_identifier_field=None, + neighborhood_mode=None, ): integration = self._external_integration(self._str) integration.url = url - integration.setting(MilleniumPatronAPI.IDENTIFIER_BLACKLIST).value = json.dumps(blacklist) - integration.setting(MilleniumPatronAPI.VERIFY_CERTIFICATE).value = json.dumps(verify_certificate) + integration.setting(MilleniumPatronAPI.IDENTIFIER_BLACKLIST).value = json.dumps( + blacklist + ) + integration.setting(MilleniumPatronAPI.VERIFY_CERTIFICATE).value = json.dumps( + verify_certificate + ) if block_types: integration.setting(MilleniumPatronAPI.BLOCK_TYPES).value = block_types if auth_mode: - integration.setting(MilleniumPatronAPI.AUTHENTICATION_MODE).value = auth_mode + integration.setting( + MilleniumPatronAPI.AUTHENTICATION_MODE + ).value = auth_mode if neighborhood_mode: - integration.setting(MilleniumPatronAPI.NEIGHBORHOOD_MODE).value = neighborhood_mode + integration.setting( + MilleniumPatronAPI.NEIGHBORHOOD_MODE + ).value = neighborhood_mode if password_keyboard: - integration.setting(MilleniumPatronAPI.PASSWORD_KEYBOARD).value = password_keyboard + integration.setting( + MilleniumPatronAPI.PASSWORD_KEYBOARD + ).value = password_keyboard if library_identifier_field: ConfigurationSetting.for_library_and_externalintegration( - self._db, MilleniumPatronAPI.LIBRARY_IDENTIFIER_FIELD, - self._default_library, integration + self._db, + MilleniumPatronAPI.LIBRARY_IDENTIFIER_FIELD, + self._default_library, + integration, ).value = library_identifier_field return MockAPI(self._default_library, integration) @@ -83,7 +99,9 @@ def test_constructor(self): with pytest.raises(CannotLoadConfiguration) as excinfo: self.mock_api(neighborhood_mode="nope") - assert "Unrecognized Millenium Patron API neighborhood mode: nope." in str(excinfo.value) + assert "Unrecognized Millenium Patron API neighborhood mode: nope." in str( + excinfo.value + ) def test__remote_patron_lookup_no_such_patron(self): self.api.enqueue("dump.no such barcode.html") @@ -113,7 +131,10 @@ def test__remote_patron_lookup_barcode_spaces(self): patrondata = PatronData(authorization_identifier="44444444444447") patrondata = self.api._remote_patron_lookup(patrondata) assert "44444444444447" == patrondata.authorization_identifier - assert ["44444444444447", "4 444 4444 44444 7"] == patrondata.authorization_identifiers + assert [ + "44444444444447", + "4 444 4444 44444 7", + ] == patrondata.authorization_identifiers def test__remote_patron_lookup_block_rules(self): """This patron has a value of "m" in MBLOCK[56], which generally @@ -127,7 +148,7 @@ def test__remote_patron_lookup_block_rules(self): # If we set custom block types that say 'm' doesn't really # mean the patron is blocked, they're not blocked. - api = self.mock_api(block_types='abcde') + api = self.mock_api(block_types="abcde") api.enqueue("dump.blocked.html") patrondata = PatronData(authorization_identifier="good barcode") patrondata = api._remote_patron_lookup(patrondata) @@ -135,7 +156,7 @@ def test__remote_patron_lookup_block_rules(self): # If we set custom block types that include 'm', the patron # is blocked. - api = self.mock_api(block_types='lmn') + api = self.mock_api(block_types="lmn") api.enqueue("dump.blocked.html") patrondata = PatronData(authorization_identifier="good barcode") patrondata = api._remote_patron_lookup(patrondata) @@ -167,9 +188,7 @@ def test_incoming_authorization_identifier_retained(self): assert "SECOND-barcode" == patrondata.authorization_identifier # Let's say they authenticate with a username. - patrondata = self.api.patron_dump_to_patrondata( - "username", dump - ) + patrondata = self.api.patron_dump_to_patrondata("username", dump) # Their Patron record will suggest the second barcode as # authorization identifier, because it's likely to be the most # recently added one. @@ -196,11 +215,11 @@ def test_remote_authenticate_correct_pin(self): [args, kwargs] = self.api.requests_made.pop() [url] = args assert kwargs == {} - assert url == 'http://url/%s/%s/pintest' % (barcode, parse.quote(pin, safe='')) + assert url == "http://url/%s/%s/pintest" % (barcode, parse.quote(pin, safe="")) # In particular, verify that the slash character in the PIN was encoded; # by default, parse.quote leaves it alone. - assert '%2F' in url + assert "%2F" in url def test_authentication_updates_patron_authorization_identifier(self): """Verify that Patron.authorization_identifier is updated when @@ -297,9 +316,7 @@ def test_authenticated_patron_success(self): # Patron is valid, but not in our database yet self.api.enqueue("pintest.good.html") self.api.enqueue("dump.success.html") - alice = self.api.authenticate( - self._db, dict(username="alice", password="4444") - ) + alice = self.api.authenticate(self._db, dict(username="alice", password="4444")) assert "44444444444447" == alice.authorization_identifier assert "alice" == alice.username @@ -307,19 +324,23 @@ def test_authenticated_patron_success(self): # to verify that our authentication mechanism chooses the right patron # and doesn't look up whoever happens to be in the database. p = self._patron() - p.username = 'notalice' - p.authorization_identifier='111111111111' + p.username = "notalice" + p.authorization_identifier = "111111111111" self._db.commit() # Patron is in the db, now authenticate with barcode self.api.enqueue("pintest.good.html") - alice = self.api.authenticated_patron(self._db, dict(username="44444444444447", password="4444")) + alice = self.api.authenticated_patron( + self._db, dict(username="44444444444447", password="4444") + ) assert "44444444444447" == alice.authorization_identifier assert "alice" == alice.username # Authenticate with username again self.api.enqueue("pintest.good.html") - alice = self.api.authenticated_patron(self._db, dict(username="alice", password="4444")) + alice = self.api.authenticated_patron( + self._db, dict(username="alice", password="4444") + ) assert "44444444444447" == alice.authorization_identifier assert "alice" == alice.username @@ -386,7 +407,7 @@ def test_authentication_patron_invalid_fine_amount(self): def test_patron_dump_to_patrondata(self): content = self.api.sample_data("dump.success.html") - patrondata = self.api.patron_dump_to_patrondata('alice', content) + patrondata = self.api.patron_dump_to_patrondata("alice", content) assert "44444444444447" == patrondata.authorization_identifier assert "alice" == patrondata.username assert None == patrondata.library_identifier @@ -394,11 +415,11 @@ def test_patron_dump_to_patrondata(self): def test_patron_dump_to_patrondata_restriction_field(self): api = self.mock_api(library_identifier_field="HOME LIBR[p53]") content = api.sample_data("dump.success.html") - patrondata = api.patron_dump_to_patrondata('alice', content) + patrondata = api.patron_dump_to_patrondata("alice", content) assert "mm" == patrondata.library_identifier api = self.mock_api(library_identifier_field="P TYPE[p47]") content = api.sample_data("dump.success.html") - patrondata = api.patron_dump_to_patrondata('alice', content) + patrondata = api.patron_dump_to_patrondata("alice", content) assert "10" == patrondata.library_identifier def test_neighborhood(self): @@ -407,32 +428,35 @@ def test_neighborhood(self): # Default behavior is not to gather neighborhood information at all. api = self.mock_api() content = api.sample_data("dump.success.html") - patrondata = api.patron_dump_to_patrondata('alice', content) + patrondata = api.patron_dump_to_patrondata("alice", content) assert PatronData.NO_VALUE == patrondata.neighborhood # Patron neighborhood may be the identifier of their home library branch. - api = self.mock_api(neighborhood_mode=MilleniumPatronAPI.HOME_BRANCH_NEIGHBORHOOD_MODE) + api = self.mock_api( + neighborhood_mode=MilleniumPatronAPI.HOME_BRANCH_NEIGHBORHOOD_MODE + ) content = api.sample_data("dump.success.html") - patrondata = api.patron_dump_to_patrondata('alice', content) + patrondata = api.patron_dump_to_patrondata("alice", content) assert "mm" == patrondata.neighborhood # Or it may be the ZIP code of their home address. - api = self.mock_api(neighborhood_mode=MilleniumPatronAPI.POSTAL_CODE_NEIGHBORHOOD_MODE) - patrondata = api.patron_dump_to_patrondata('alice', content) + api = self.mock_api( + neighborhood_mode=MilleniumPatronAPI.POSTAL_CODE_NEIGHBORHOOD_MODE + ) + patrondata = api.patron_dump_to_patrondata("alice", content) assert "10001" == patrondata.neighborhood - def test_authorization_identifier_blacklist(self): """A patron has two authorization identifiers. Ordinarily the second one (which would normally be preferred), but it contains a blacklisted string, so the first takes precedence. """ content = self.api.sample_data("dump.two_barcodes.html") - patrondata = self.api.patron_dump_to_patrondata('alice', content) + patrondata = self.api.patron_dump_to_patrondata("alice", content) assert "SECOND-barcode" == patrondata.authorization_identifier api = self.mock_api(blacklist=["second"]) - patrondata = api.patron_dump_to_patrondata('alice', content) + patrondata = api.patron_dump_to_patrondata("alice", content) assert "FIRST-barcode" == patrondata.authorization_identifier def test_blacklist_may_remove_every_authorization_identifier(self): @@ -441,7 +465,7 @@ def test_blacklist_may_remove_every_authorization_identifier(self): """ api = self.mock_api(blacklist=["barcode"]) content = api.sample_data("dump.two_barcodes.html") - patrondata = api.patron_dump_to_patrondata('alice', content) + patrondata = api.patron_dump_to_patrondata("alice", content) assert patrondata.NO_VALUE == patrondata.authorization_identifier assert [] == patrondata.authorization_identifiers @@ -458,9 +482,9 @@ def test_verify_certificate(self): # Test that the value of verify_certificate becomes the # 'verify' argument when _modify_request_kwargs() is called. kwargs = dict(verify=False) - api = self.mock_api(verify_certificate = "yes please") + api = self.mock_api(verify_certificate="yes please") api._update_request_kwargs(kwargs) - assert "yes please" == kwargs['verify'] + assert "yes please" == kwargs["verify"] # NOTE: We can't automatically test that request() actually # calls _modify_request_kwargs() because request() is the @@ -500,23 +524,22 @@ def test_family_name_match(self): def test_misconfigured_authentication_mode(self): with pytest.raises(CannotLoadConfiguration) as excinfo: - self.mock_api(auth_mode = 'nosuchauthmode') - assert "Unrecognized Millenium Patron API authentication mode: nosuchauthmode." in str(excinfo.value) + self.mock_api(auth_mode="nosuchauthmode") + assert ( + "Unrecognized Millenium Patron API authentication mode: nosuchauthmode." + in str(excinfo.value) + ) def test_authorization_without_password(self): """Test authorization when no password is required, only patron identifier. """ - self.api = self.mock_api( - password_keyboard=MilleniumPatronAPI.NULL_KEYBOARD - ) + self.api = self.mock_api(password_keyboard=MilleniumPatronAPI.NULL_KEYBOARD) assert False == self.api.collects_password # If the patron lookup succeeds, the user is authenticated # as that patron. self.api.enqueue("dump.success.html") - patrondata = self.api.remote_authenticate( - "44444444444447", None - ) + patrondata = self.api.remote_authenticate("44444444444447", None) assert "44444444444447" == patrondata.authorization_identifier # If it fails, the user is not authenticated. @@ -528,11 +551,9 @@ def test_authorization_family_name_success(self): """Test authenticating against the patron's family name, given the correct name (case insensitive) """ - self.api = self.mock_api(auth_mode = "family_name") + self.api = self.mock_api(auth_mode="family_name") self.api.enqueue("dump.success.html") - patrondata = self.api.remote_authenticate( - "44444444444447", "Sheldon" - ) + patrondata = self.api.remote_authenticate("44444444444447", "Sheldon") assert "44444444444447" == patrondata.authorization_identifier # Since we got a full patron dump, the PatronData we get back @@ -543,7 +564,7 @@ def test_authorization_family_name_failure(self): """Test authenticating against the patron's family name, given the incorrect name """ - self.api = self.mock_api(auth_mode = "family_name") + self.api = self.mock_api(auth_mode="family_name") self.api.enqueue("dump.success.html") assert False == self.api.remote_authenticate("44444444444447", "wrong name") @@ -551,7 +572,7 @@ def test_authorization_family_name_no_such_patron(self): """If no patron is found, authorization based on family name cannot proceed. """ - self.api = self.mock_api(auth_mode = "family_name") + self.api = self.mock_api(auth_mode="family_name") self.api.enqueue("dump.no such barcode.html") assert False == self.api.remote_authenticate("44444444444447", "somebody") @@ -565,11 +586,13 @@ def test_extract_postal_code(self): assert "93203" == m("10145-6789 Main Street$Arvin CA 93203-1234") assert "93203" == m("10145-6789 Main Street$Arvin CA 93203-1234 (old address)") assert "93203" == m("10145-6789 Main Street$Arvin CA 93203 (old address)") - assert "93203" == m("10145-6789 Main Street Apartment #12345$Arvin CA 93203 (old address)") + assert "93203" == m( + "10145-6789 Main Street Apartment #12345$Arvin CA 93203 (old address)" + ) assert None == m("10145 Main Street Apartment 123456$Arvin CA") assert None == m("10145 Main Street$Arvin CA") assert None == m("123 Main Street") # Some cases where we incorrectly detect a ZIP code where there is none. - assert '12345' == m("10145 Main Street, Apartment #12345$Arvin CA") + assert "12345" == m("10145 Main Street, Apartment #12345$Arvin CA") diff --git a/tests/test_monitor.py b/tests/test_monitor.py index 3b4a22c600..ea7235b998 100644 --- a/tests/test_monitor.py +++ b/tests/test_monitor.py @@ -1,9 +1,14 @@ import datetime import random -from core.testing import ( - DatabaseTest, -) +from api.monitor import ( + HoldReaper, + IdlingAnnotationReaper, + LoanlikeReaperMonitor, + LoanReaper, +) +from api.odl import ODLAPI, SharedODLAPI +from api.testing import MonitorTest from core.metadata_layer import TimestampData from core.model import ( Annotation, @@ -13,21 +18,9 @@ ExternalIntegration, Identifier, ) +from core.testing import DatabaseTest from core.util.datetime_helpers import utc_now -from api.monitor import ( - HoldReaper, - IdlingAnnotationReaper, - LoanlikeReaperMonitor, - LoanReaper, -) - -from api.odl import ( - ODLAPI, - SharedODLAPI, -) -from api.testing import MonitorTest - class TestLoanlikeReaperMonitor(DatabaseTest): """Tests the loan and hold reapers.""" @@ -37,13 +30,12 @@ def test_source_of_truth_protocols(self): will be exempt from the reaper. """ for i in ( - ODLAPI.NAME, - SharedODLAPI.NAME, - ExternalIntegration.OPDS_FOR_DISTRIBUTORS, + ODLAPI.NAME, + SharedODLAPI.NAME, + ExternalIntegration.OPDS_FOR_DISTRIBUTORS, ): assert i in LoanlikeReaperMonitor.SOURCE_OF_TRUTH_PROTOCOLS - def test_reaping(self): # This patron stopped using the circulation manager a long time # ago. @@ -54,37 +46,44 @@ def test_reaping(self): # We're going to give these patrons some loans and holds. edition, open_access = self._edition( - with_license_pool=True, with_open_access_download=True) + with_license_pool=True, with_open_access_download=True + ) - not_open_access_1 = self._licensepool(edition, - open_access=False, data_source_name=DataSource.OVERDRIVE) - not_open_access_2 = self._licensepool(edition, - open_access=False, data_source_name=DataSource.BIBLIOTHECA) - not_open_access_3 = self._licensepool(edition, - open_access=False, data_source_name=DataSource.AXIS_360) - not_open_access_4 = self._licensepool(edition, - open_access=False, data_source_name=DataSource.ODILO) + not_open_access_1 = self._licensepool( + edition, open_access=False, data_source_name=DataSource.OVERDRIVE + ) + not_open_access_2 = self._licensepool( + edition, open_access=False, data_source_name=DataSource.BIBLIOTHECA + ) + not_open_access_3 = self._licensepool( + edition, open_access=False, data_source_name=DataSource.AXIS_360 + ) + not_open_access_4 = self._licensepool( + edition, open_access=False, data_source_name=DataSource.ODILO + ) # Here's a collection that is the source of truth for its # loans and holds, rather than mirroring loan and hold information # from some remote source. sot_collection = self._collection( "Source of Truth", - protocol=random.choice(LoanReaper.SOURCE_OF_TRUTH_PROTOCOLS) + protocol=random.choice(LoanReaper.SOURCE_OF_TRUTH_PROTOCOLS), ) edition2 = self._edition(with_license_pool=False) sot_lp1 = self._licensepool( - edition2, open_access=False, + edition2, + open_access=False, data_source_name=DataSource.OVERDRIVE, - collection=sot_collection + collection=sot_collection, ) sot_lp2 = self._licensepool( - edition2, open_access=False, + edition2, + open_access=False, data_source_name=DataSource.BIBLIOTHECA, - collection=sot_collection + collection=sot_collection, ) now = utc_now() @@ -100,45 +99,41 @@ def test_reaping(self): # This hold expired without ever becoming a loan (that we saw). not_open_access_2.on_hold_to( - inactive_patron, - start=even_longer, - end=a_long_time_ago + inactive_patron, start=even_longer, end=a_long_time_ago ) # This hold has no end date and is older than a year. not_open_access_3.on_hold_to( - inactive_patron, start=a_long_time_ago, end=None, + inactive_patron, + start=a_long_time_ago, + end=None, ) # This loan has no end date and is older than 90 days. not_open_access_4.loan_to( - inactive_patron, start=a_long_time_ago, end=None, + inactive_patron, + start=a_long_time_ago, + end=None, ) # This loan has no end date, but it's for an open-access work. open_access_loan, ignore = open_access.loan_to( - inactive_patron, start=a_long_time_ago, end=None, + inactive_patron, + start=a_long_time_ago, + end=None, ) # This loan has not expired yet. - not_open_access_1.loan_to( - current_patron, start=now, end=the_future - ) + not_open_access_1.loan_to(current_patron, start=now, end=the_future) # This hold has not expired yet. - not_open_access_2.on_hold_to( - current_patron, start=now, end=the_future - ) + not_open_access_2.on_hold_to(current_patron, start=now, end=the_future) # This loan has no end date but is pretty recent. - not_open_access_3.loan_to( - current_patron, start=not_very_long_ago, end=None - ) + not_open_access_3.loan_to(current_patron, start=not_very_long_ago, end=None) # This hold has no end date but is pretty recent. - not_open_access_4.on_hold_to( - current_patron, start=not_very_long_ago, end=None - ) + not_open_access_4.on_hold_to(current_patron, start=not_very_long_ago, end=None) # Reapers will not touch loans or holds from the # source-of-truth collection, even ones that have 'obviously' @@ -188,7 +183,6 @@ def test_reaping(self): class TestIdlingAnnotationReaper(DatabaseTest): - def test_where_clause(self): # Two books. @@ -204,8 +198,9 @@ def test_where_clause(self): not_that_old = now - datetime.timedelta(days=59) very_old = now - datetime.timedelta(days=61) - def _annotation(patron, pool, content, motivation=Annotation.IDLING, - timestamp=very_old): + def _annotation( + patron, pool, content, motivation=Annotation.IDLING, timestamp=very_old + ): annotation, ignore = Annotation.get_one_or_create( self._db, patron=patron, @@ -240,9 +235,7 @@ def _annotation(patron, pool, content, motivation=Annotation.IDLING, # The second patron has a non-old idling annotation for the # second book, which will not be reaped (even though there is # no active loan or hold) because it's not old enough. - new_idling = _annotation( - p2, lp2, "recent", timestamp=not_that_old - ) + new_idling = _annotation(p2, lp2, "recent", timestamp=not_that_old) reaper = IdlingAnnotationReaper(self._db) qu = self._db.query(Annotation).filter(reaper.where_clause) assert [reapable] == qu.all() diff --git a/tests/test_novelist.py b/tests/test_novelist.py index 99e161c146..faae5459fe 100644 --- a/tests/test_novelist.py +++ b/tests/test_novelist.py @@ -2,29 +2,23 @@ import json import pytest -from core.testing import DatabaseTest -from . import sample_data +from api.config import CannotLoadConfiguration +from api.novelist import MockNoveListAPI, NoveListAPI from core.metadata_layer import Metadata from core.model import ( - get_one, - get_one_or_create, DataSource, Edition, ExternalIntegration, Identifier, Representation, + get_one, + get_one_or_create, ) -from core.testing import DummyHTTPClient -from api.config import CannotLoadConfiguration -from api.novelist import ( - MockNoveListAPI, - NoveListAPI, -) -from core.util.http import ( - HTTP -) -from core.testing import MockRequestsResponse +from core.testing import DatabaseTest, DummyHTTPClient, MockRequestsResponse +from core.util.http import HTTP + +from . import sample_data class TestNoveListAPI(DatabaseTest): @@ -34,8 +28,10 @@ def setup_method(self): super(TestNoveListAPI, self).setup_method() self.integration = self._external_integration( ExternalIntegration.NOVELIST, - ExternalIntegration.METADATA_GOAL, username='library', - password='yep', libraries=[self._default_library], + ExternalIntegration.METADATA_GOAL, + username="library", + password="yep", + libraries=[self._default_library], ) self.novelist = NoveListAPI.from_config(self._default_library) @@ -44,13 +40,11 @@ def teardown_method(self): super(TestNoveListAPI, self).teardown_method() def sample_data(self, filename): - return sample_data(filename, 'novelist') + return sample_data(filename, "novelist") def sample_representation(self, filename): content = self.sample_data(filename) - return self._representation( - media_type='application/json', content=content - )[0] + return self._representation(media_type="application/json", content=content)[0] def test_from_config(self): """Confirms that NoveListAPI can be built from config successfully""" @@ -61,11 +55,15 @@ def test_from_config(self): # Without either configuration value, an error is raised. self.integration.password = None - pytest.raises(CannotLoadConfiguration, NoveListAPI.from_config, self._default_library) + pytest.raises( + CannotLoadConfiguration, NoveListAPI.from_config, self._default_library + ) - self.integration.password = 'yep' + self.integration.password = "yep" self.integration.username = None - pytest.raises(CannotLoadConfiguration, NoveListAPI.from_config, self._default_library) + pytest.raises( + CannotLoadConfiguration, NoveListAPI.from_config, self._default_library + ) def test_is_configured(self): # If an ExternalIntegration exists, the API is_configured @@ -80,11 +78,19 @@ def test_is_configured(self): assert library.id == NoveListAPI._configuration_library_id def test_review_response(self): - invalid_credential_response = (403, {}, b'HTML Access Denied page') - pytest.raises(Exception, self.novelist.review_response, invalid_credential_response) + invalid_credential_response = (403, {}, b"HTML Access Denied page") + pytest.raises( + Exception, self.novelist.review_response, invalid_credential_response + ) - missing_argument_response = (200, {}, b'"Missing ISBN, UPC, or Client Identifier!"') - pytest.raises(Exception, self.novelist.review_response, missing_argument_response) + missing_argument_response = ( + 200, + {}, + b'"Missing ISBN, UPC, or Client Identifier!"', + ) + pytest.raises( + Exception, self.novelist.review_response, missing_argument_response + ) response = (200, {}, b"Here's the goods!") assert response == self.novelist.review_response(response) @@ -99,7 +105,7 @@ def test_lookup_info_to_metadata(self): assert True == isinstance(metadata, Metadata) assert Identifier.NOVELIST_ID == metadata.primary_identifier.type - assert '10392078' == metadata.primary_identifier.identifier + assert "10392078" == metadata.primary_identifier.identifier assert "A bad character" == metadata.title assert None == metadata.subtitle assert 1 == len(metadata.contributors) @@ -118,12 +124,12 @@ def test_lookup_info_to_metadata(self): vampire = self.sample_representation("vampire_kisses.json") metadata = self.novelist.lookup_info_to_metadata(vampire) - [lexile] = filter(lambda s: s.type=='Lexile', metadata.subjects) - assert '630' == lexile.identifier - assert 'Vampire kisses manga' == metadata.series + [lexile] = filter(lambda s: s.type == "Lexile", metadata.subjects) + assert "630" == lexile.identifier + assert "Vampire kisses manga" == metadata.series # The full title should be selected, since every volume # has the same main title: 'Vampire kisses' - assert 'Vampire kisses: blood relatives. Volume 1' == metadata.title + assert "Vampire kisses: blood relatives. Volume 1" == metadata.title assert 1 == metadata.series_position assert 5 == len(metadata.recommendations) @@ -131,58 +137,63 @@ def test_get_series_information(self): metadata = Metadata(data_source=DataSource.NOVELIST) vampire = json.loads(self.sample_data("vampire_kisses.json")) - book_info = vampire['TitleInfo'] - series_info = vampire['FeatureContent']['SeriesInfo'] + book_info = vampire["TitleInfo"] + series_info = vampire["FeatureContent"]["SeriesInfo"] (metadata, ideal_title_key) = self.novelist.get_series_information( metadata, series_info, book_info ) # Relevant series information is extracted - assert 'Vampire kisses manga' == metadata.series + assert "Vampire kisses manga" == metadata.series assert 1 == metadata.series_position # The 'full_title' key should be returned as ideal because # all the volumes have the same 'main_title' - assert 'full_title' == ideal_title_key - + assert "full_title" == ideal_title_key watchman = json.loads(self.sample_data("alternate_series_example.json")) - book_info = watchman['TitleInfo'] - series_info = watchman['FeatureContent']['SeriesInfo'] + book_info = watchman["TitleInfo"] + series_info = watchman["FeatureContent"]["SeriesInfo"] # Confirms that the new example doesn't match any volume's full title - assert [] == [v for v in series_info['series_titles'] - if v.get('full_title')==book_info.get('full_title')] + assert [] == [ + v + for v in series_info["series_titles"] + if v.get("full_title") == book_info.get("full_title") + ] # But it still finds its matching volume (metadata, ideal_title_key) = self.novelist.get_series_information( metadata, series_info, book_info ) - assert 'Elvis Cole/Joe Pike novels' == metadata.series + assert "Elvis Cole/Joe Pike novels" == metadata.series assert 11 == metadata.series_position # And recommends using the main_title - assert 'main_title' == ideal_title_key + assert "main_title" == ideal_title_key # If the volume is found in the series more than once... book_info = dict( - main_title='The Baby-Sitters Club', - full_title='The Baby-Sitters Club: Claudia and Mean Janine' + main_title="The Baby-Sitters Club", + full_title="The Baby-Sitters Club: Claudia and Mean Janine", ) series_info = dict( - full_title='The Baby-Sitters Club series', + full_title="The Baby-Sitters Club series", series_titles=[ # The volume is here twice! book_info, book_info, dict( - full_title='The Baby-Sitters Club', - main_title='The Baby-Sitters Club: Claudia and Mean Janine', - series_position='3.' - ) - ] + full_title="The Baby-Sitters Club", + main_title="The Baby-Sitters Club: Claudia and Mean Janine", + series_position="3.", + ), + ], ) # An error is raised. pytest.raises( - ValueError, self.novelist.get_series_information, - metadata, series_info, book_info + ValueError, + self.novelist.get_series_information, + metadata, + series_info, + book_info, ) def test_lookup(self): @@ -221,21 +232,25 @@ def lookup_info_to_metadata(self, representation): assert params1 == params2 assert ( - dict(profile=novelist.profile, - ClientIdentifier=identifier.urn, - ISBN=identifier.identifier, - password=novelist.password, - version=novelist.version, - ) == - params1) + dict( + profile=novelist.profile, + ClientIdentifier=identifier.urn, + ISBN=identifier.identifier, + password=novelist.password, + version=novelist.version, + ) + == params1 + ) # The HTTP request went out to the query URL -- not the scrubbed URL. assert ["http://query-url/"] == h.requests # The HTTP response was passed into novelist.review_response() assert ( - (200, {'content-type': 'text/html'}, b'yay') == - novelist.review_response_called_with) + 200, + {"content-type": "text/html"}, + b"yay", + ) == novelist.review_response_called_with # Finally, the Representation was passed into # lookup_info_to_metadata, which returned a hard-coded string @@ -264,31 +279,31 @@ def test_lookup_info_to_metadata_ignores_empty_responses(self): def test_build_query_url(self): params = dict( - ClientIdentifier='C I', - ISBN='456', - version='2.2', - profile='username', - password='secret' + ClientIdentifier="C I", + ISBN="456", + version="2.2", + profile="username", + password="secret", ) # Authentication information is included in the URL by default full_result = self.novelist.build_query_url(params) - auth_details = '&profile=username&password=secret' + auth_details = "&profile=username&password=secret" assert True == full_result.endswith(auth_details) - assert 'profile=username' in full_result - assert 'password=secret' in full_result + assert "profile=username" in full_result + assert "password=secret" in full_result # With a scrub, no authentication information is included. scrubbed_result = self.novelist.build_query_url(params, include_auth=False) assert False == scrubbed_result.endswith(auth_details) - assert 'profile=username' not in scrubbed_result - assert 'password=secret' not in scrubbed_result + assert "profile=username" not in scrubbed_result + assert "password=secret" not in scrubbed_result # Other details are urlencoded and available in both versions. for url in (scrubbed_result, full_result): - assert 'ClientIdentifier=C%20I' in url - assert 'ISBN=456' in url - assert 'version=2.2' in url + assert "ClientIdentifier=C%20I" in url + assert "ISBN=456" in url + assert "version=2.2" in url # The method to create a scrubbed url returns the same result # as the NoveListAPI.build_query_url @@ -299,17 +314,17 @@ def test_scrub_subtitle(self): scrub = self.novelist._scrub_subtitle assert None == scrub(None) - assert None == scrub('[electronic resource]') - assert None == scrub('[electronic resource] : ') - assert 'A Biomythography' == scrub('[electronic resource] : A Biomythography') + assert None == scrub("[electronic resource]") + assert None == scrub("[electronic resource] : ") + assert "A Biomythography" == scrub("[electronic resource] : A Biomythography") def test_confirm_same_identifier(self): source = DataSource.lookup(self._db, DataSource.NOVELIST) identifier, ignore = Identifier.for_foreign_id( - self._db, Identifier.NOVELIST_ID, '84752928' + self._db, Identifier.NOVELIST_ID, "84752928" ) unmatched_identifier, ignore = Identifier.for_foreign_id( - self._db, Identifier.NOVELIST_ID, '23781947' + self._db, Identifier.NOVELIST_ID, "23781947" ) metadata = Metadata(source, primary_identifier=identifier) match = Metadata(source, primary_identifier=identifier) @@ -341,8 +356,10 @@ def test_lookup_equivalent_isbns(self): # Create an API class that can mockout NoveListAPI.choose_best_metadata class MockBestMetadataAPI(MockNoveListAPI): choose_best_metadata_return = None + def choose_best_metadata(self, *args, **kwargs): return self.choose_best_metadata_return + api = MockBestMetadataAPI.from_config(self._default_library) # Give the identifier another ISBN equivalent. @@ -386,11 +403,17 @@ def test_choose_best_metadata(self): assert 1.0 == result[1] # When top identifiers have equal representation, the method returns none. - metadatas.append(Metadata(DataSource.NOVELIST, primary_identifier=less_identifier)) - assert (None, None) == self.novelist.choose_best_metadata(metadatas, self._identifier()) + metadatas.append( + Metadata(DataSource.NOVELIST, primary_identifier=less_identifier) + ) + assert (None, None) == self.novelist.choose_best_metadata( + metadatas, self._identifier() + ) # But when one pulls ahead, we get the metadata object again. - metadatas.append(Metadata(DataSource.NOVELIST, primary_identifier=more_identifier)) + metadatas.append( + Metadata(DataSource.NOVELIST, primary_identifier=more_identifier) + ) result = self.novelist.choose_best_metadata(metadatas, self._identifier()) assert True == isinstance(result, tuple) metadata, confidence = result @@ -404,26 +427,37 @@ def test_get_items_from_query(self): assert items == [] # Set up a book for this library. - edition = self._edition(identifier_type=Identifier.ISBN, publication_date="2012-01-01") + edition = self._edition( + identifier_type=Identifier.ISBN, publication_date="2012-01-01" + ) pool = self._licensepool(edition, collection=self._default_collection) - contributor = self._contributor(sort_name=edition.sort_author, name=edition.author) + contributor = self._contributor( + sort_name=edition.sort_author, name=edition.author + ) items = self.novelist.get_items_from_query(self._default_library) item = dict( author=contributor[0]._sort_name, title=edition.title, - mediaType=self.novelist.medium_to_book_format_type_values.get(edition.medium, ""), + mediaType=self.novelist.medium_to_book_format_type_values.get( + edition.medium, "" + ), isbn=edition.primary_identifier.identifier, distributor=edition.data_source.name, - publicationDate=edition.published.strftime("%Y%m%d") + publicationDate=edition.published.strftime("%Y%m%d"), ) assert items == [item] def test_create_item_object(self): # We pass no identifier or item to process so we get nothing back. - (currentIdentifier, existingItem, newItem, addItem) = self.novelist.create_item_object(None, None, None) + ( + currentIdentifier, + existingItem, + newItem, + addItem, + ) = self.novelist.create_item_object(None, None, None) assert currentIdentifier == None assert existingItem == None assert newItem == None @@ -435,25 +469,49 @@ def test_create_item_object(self): # contribution role, contributor sort name # distributor) book1_from_query = ( - "12345", "Axis 360 ID", "23456", - "Title 1", "Book", datetime.date(2002, 1, 1), - "Author", "Author 1", - "Gutenberg") + "12345", + "Axis 360 ID", + "23456", + "Title 1", + "Book", + datetime.date(2002, 1, 1), + "Author", + "Author 1", + "Gutenberg", + ) book1_from_query_primary_author = ( - "12345", "Axis 360 ID", "23456", - "Title 1", "Book", datetime.date(2002, 1, 1), - "Primary Author", "Author 2", - "Gutenberg") + "12345", + "Axis 360 ID", + "23456", + "Title 1", + "Book", + datetime.date(2002, 1, 1), + "Primary Author", + "Author 2", + "Gutenberg", + ) book1_narrator_from_query = ( - "12345", "Axis 360 ID", "23456", - "Title 1", "Book", datetime.date(2002, 1, 1), - "Narrator", "Narrator 1", - "Gutenberg") + "12345", + "Axis 360 ID", + "23456", + "Title 1", + "Book", + datetime.date(2002, 1, 1), + "Narrator", + "Narrator 1", + "Gutenberg", + ) book2_from_query = ( - "34567", "Axis 360 ID", "56789", - "Title 2", "Book", datetime.date(1414, 1, 1), - "Author", "Author 3", - "Gutenberg") + "34567", + "Axis 360 ID", + "56789", + "Title 2", + "Book", + datetime.date(1414, 1, 1), + "Author", + "Author 3", + "Gutenberg", + ) (currentIdentifier, existingItem, newItem, addItem) = ( # params: new item, identifier, existing item @@ -461,16 +519,15 @@ def test_create_item_object(self): ) assert currentIdentifier == book1_from_query[2] assert existingItem == None - assert ( - newItem == - {"isbn": "23456", + assert newItem == { + "isbn": "23456", "mediaType": "EBook", "title": "Title 1", "role": "Author", "author": "Author 1", "distributor": "Gutenberg", - "publicationDate": "20020101" - }) + "publicationDate": "20020101", + } # We want to still process this item along with the next one in case # the following one has the same ISBN. assert addItem == False @@ -479,33 +536,39 @@ def test_create_item_object(self): # We are now processing the previous object along with the new one. # This is to check and update the value for `author` if the role changes # to `Primary Author`. - (currentIdentifier, existingItem, newItem, addItem) = ( - self.novelist.create_item_object( - book1_from_query_primary_author, - currentIdentifier, - newItem - ) + ( + currentIdentifier, + existingItem, + newItem, + addItem, + ) = self.novelist.create_item_object( + book1_from_query_primary_author, currentIdentifier, newItem ) assert currentIdentifier == book1_from_query[2] - assert (existingItem == - {"isbn": "23456", + assert existingItem == { + "isbn": "23456", "mediaType": "EBook", "title": "Title 1", "author": "Author 2", "role": "Primary Author", "distributor": "Gutenberg", - "publicationDate": "20020101" - }) + "publicationDate": "20020101", + } assert newItem == None assert addItem == False # Test that a narrator gets added along with an author. - (currentIdentifier, existingItem, newItem, addItem) = ( - self.novelist.create_item_object(book1_narrator_from_query, currentIdentifier, existingItem) + ( + currentIdentifier, + existingItem, + newItem, + addItem, + ) = self.novelist.create_item_object( + book1_narrator_from_query, currentIdentifier, existingItem ) assert currentIdentifier == book1_narrator_from_query[2] - assert (existingItem == - {"isbn": "23456", + assert existingItem == { + "isbn": "23456", "mediaType": "EBook", "title": "Title 1", "author": "Author 2", @@ -515,57 +578,64 @@ def test_create_item_object(self): "role": "Narrator", "narrator": "Narrator 1", "distributor": "Gutenberg", - "publicationDate": "20020101" - }) + "publicationDate": "20020101", + } assert newItem == None assert addItem == False # New Object - (currentIdentifier, existingItem, newItem, addItem) = ( - self.novelist.create_item_object(book2_from_query, currentIdentifier, existingItem) + ( + currentIdentifier, + existingItem, + newItem, + addItem, + ) = self.novelist.create_item_object( + book2_from_query, currentIdentifier, existingItem ) assert currentIdentifier == book2_from_query[2] - assert (existingItem == - {"isbn": "23456", + assert existingItem == { + "isbn": "23456", "mediaType": "EBook", "title": "Title 1", "author": "Author 2", "role": "Narrator", "narrator": "Narrator 1", "distributor": "Gutenberg", - "publicationDate": "20020101" - }) - assert (newItem == - {"isbn": "56789", + "publicationDate": "20020101", + } + assert newItem == { + "isbn": "56789", "mediaType": "EBook", "title": "Title 2", "role": "Author", "author": "Author 3", "distributor": "Gutenberg", - "publicationDate": "14140101" - }) + "publicationDate": "14140101", + } assert addItem == True # New Object # Test that a narrator got added but not an author - (currentIdentifier, existingItem, newItem, addItem) = ( - self.novelist.create_item_object(book1_narrator_from_query, None, None) - ) + ( + currentIdentifier, + existingItem, + newItem, + addItem, + ) = self.novelist.create_item_object(book1_narrator_from_query, None, None) assert currentIdentifier == book1_narrator_from_query[2] assert existingItem == None - assert (newItem == - {"isbn": "23456", + assert newItem == { + "isbn": "23456", "mediaType": "EBook", "title": "Title 1", "role": "Narrator", "narrator": "Narrator 1", "distributor": "Gutenberg", - "publicationDate": "20020101" - }) + "publicationDate": "20020101", + } assert addItem == False - def test_put_items_novelist(self): response = self.novelist.put_items_novelist(self._default_library) @@ -573,7 +643,7 @@ def test_put_items_novelist(self): edition = self._edition(identifier_type=Identifier.ISBN) pool = self._licensepool(edition, collection=self._default_collection) - mock_response = {'Customer': 'NYPL', 'RecordsReceived': 10} + mock_response = {"Customer": "NYPL", "RecordsReceived": 10} def mockHTTPPut(url, headers, **kwargs): return MockRequestsResponse(200, content=json.dumps(mock_response)) @@ -591,21 +661,25 @@ def test_make_novelist_data_object(self): bad_data = [] result = self.novelist.make_novelist_data_object(bad_data) - assert result == { - "customer": "library:yep", - "records": [] - } + assert result == {"customer": "library:yep", "records": []} data = [ - {"isbn":"12345", "mediaType": "http://schema.org/EBook", "title": "Book 1", "author": "Author 1" }, - {"isbn":"12346", "mediaType": "http://schema.org/EBook", "title": "Book 2", "author": "Author 2" }, + { + "isbn": "12345", + "mediaType": "http://schema.org/EBook", + "title": "Book 1", + "author": "Author 1", + }, + { + "isbn": "12346", + "mediaType": "http://schema.org/EBook", + "title": "Book 2", + "author": "Author 2", + }, ] result = self.novelist.make_novelist_data_object(data) - assert result == { - "customer": "library:yep", - "records": data - } + assert result == {"customer": "library:yep", "records": data} def mockHTTPPut(self, *args, **kwargs): self.called_with = (args, kwargs) diff --git a/tests/test_nyt.py b/tests/test_nyt.py index efe7b8aa1f..f810869de5 100644 --- a/tests/test_nyt.py +++ b/tests/test_nyt.py @@ -1,40 +1,29 @@ # encoding: utf-8 +import datetime +import json import os from pdb import set_trace -import pytest -import datetime + import dateutil -import json +import pytest -from core.testing import ( - DatabaseTest, -) -from core.testing import DummyMetadataClient +from api.nyt import NYTAPI, NYTBestSellerAPI, NYTBestSellerList, NYTBestSellerListTitle from core.config import CannotLoadConfiguration -from api.nyt import ( - NYTAPI, - NYTBestSellerAPI, - NYTBestSellerList, - NYTBestSellerListTitle, -) from core.model import ( Contributor, + CustomListEntry, Edition, ExternalIntegration, Hyperlink, Identifier, Resource, - CustomListEntry, -) -from core.opds_import import ( - MetadataWranglerOPDSLookup, - MockMetadataWranglerOPDSLookup ) +from core.opds_import import MetadataWranglerOPDSLookup, MockMetadataWranglerOPDSLookup +from core.testing import DatabaseTest, DummyMetadataClient from core.util.http import IntegrationException class DummyNYTBestSellerAPI(NYTBestSellerAPI): - def __init__(self, _db): self._db = _db self.metadata_client = DummyMetadataClient() @@ -51,13 +40,16 @@ def list_of_lists(self): def update(self, list, date=None, max_age=None): if date: - filename = "list_%s_%s.json" % (list.foreign_identifier, self.date_string(date)) + filename = "list_%s_%s.json" % ( + list.foreign_identifier, + self.date_string(date), + ) else: filename = "list_%s.json" % list.foreign_identifier list.update(self.sample_json(filename)) -class NYTBestSellerAPITest(DatabaseTest): +class NYTBestSellerAPITest(DatabaseTest): def setup_method(self): super(NYTBestSellerAPITest, self).setup_method() self.api = DummyNYTBestSellerAPI(self._db) @@ -80,8 +72,7 @@ def test_from_config(self): NYTBestSellerAPI.from_config(self._db) assert "No ExternalIntegration found for the NYT." in str(excinfo.value) integration = self._external_integration( - protocol=ExternalIntegration.NYT, - goal=ExternalIntegration.METADATA_GOAL + protocol=ExternalIntegration.NYT, goal=ExternalIntegration.METADATA_GOAL ) # It has to have the api key in its 'password' setting. @@ -100,7 +91,7 @@ def test_from_config(self): # But if you do, it's picked up. mw = self._external_integration( protocol=ExternalIntegration.METADATA_WRANGLER, - goal=ExternalIntegration.METADATA_GOAL + goal=ExternalIntegration.METADATA_GOAL, ) mw.url = self._url @@ -116,6 +107,7 @@ def test_run_self_tests(self): class Mock(NYTBestSellerAPI): def __init__(self): pass + def list_of_lists(self): return "some lists" @@ -126,21 +118,24 @@ def list_of_lists(self): def test_list_of_lists(self): all_lists = self.api.list_of_lists() - assert (['copyright', 'num_results', 'results', 'status'] == - sorted(all_lists.keys())) - assert 47 == len(all_lists['results']) + assert ["copyright", "num_results", "results", "status"] == sorted( + all_lists.keys() + ) + assert 47 == len(all_lists["results"]) def test_list_info(self): list_info = self.api.list_info("combined-print-and-e-book-fiction") - assert "Combined Print & E-Book Fiction" == list_info['display_name'] + assert "Combined Print & E-Book Fiction" == list_info["display_name"] def test_request_failure(self): # Verify that certain unexpected HTTP results are turned into # IntegrationExceptions. self.api.api_key = "some key" + def result_403(*args, **kwargs): return 403, None, None + self.api.do_get = result_403 with pytest.raises(IntegrationException) as excinfo: self.api.request("some path") @@ -148,6 +143,7 @@ def result_403(*args, **kwargs): def result_500(*args, **kwargs): return 500, {}, "bad value" + self.api.do_get = result_500 try: self.api.request("some path") @@ -157,6 +153,7 @@ def result_500(*args, **kwargs): assert e.debug_message.startswith("Response from") assert e.debug_message.endswith("was: 'bad value'") + class TestNYTBestSellerList(NYTBestSellerAPITest): """Test the NYTBestSellerList object and its ability to be turned @@ -181,7 +178,7 @@ def test_medium(self): def test_update(self): list_name = "combined-print-and-e-book-fiction" - self.metadata_client.lookups['Paula Hawkins'] = 'Hawkins, Paula' + self.metadata_client.lookups["Paula Hawkins"] = "Hawkins, Paula" l = self.api.best_seller_list(list_name) self.api.update(l) @@ -192,7 +189,7 @@ def test_update(self): assert list_name == l.foreign_identifier # Let's do a spot check on the list items. - title = [x for x in l if x.metadata.title=='THE GIRL ON THE TRAIN'][0] + title = [x for x in l if x.metadata.title == "THE GIRL ON THE TRAIN"][0] [isbn] = title.metadata.identifiers assert "ISBN" == isbn.type assert "9780698185395" == isbn.identifier @@ -203,8 +200,10 @@ def test_update(self): [contributor] = title.metadata.contributors assert "Paula Hawkins" == contributor.display_name assert "Riverhead" == title.metadata.publisher - assert ("A psychological thriller set in London is full of complications and betrayals." == - title.annotation) + assert ( + "A psychological thriller set in London is full of complications and betrayals." + == title.annotation + ) assert self._midnight(2015, 1, 17) == title.first_appearance assert self._midnight(2015, 2, 1) == title.most_recent_appearance @@ -221,7 +220,7 @@ def test_historical_dates(self): def test_to_customlist(self): list_name = "combined-print-and-e-book-fiction" - self.metadata_client.lookups['Paula Hawkins'] = 'Hawkins, Paula' + self.metadata_client.lookups["Paula Hawkins"] = "Hawkins, Paula" l = self.api.best_seller_list(list_name) self.api.update(l) custom = l.to_customlist(self._db) @@ -229,25 +228,22 @@ def test_to_customlist(self): assert custom.updated == l.updated assert custom.name == l.name assert len(l) == len(custom.entries) - assert True == all([isinstance(x, CustomListEntry) - for x in custom.entries]) + assert True == all([isinstance(x, CustomListEntry) for x in custom.entries]) assert 20 == len(custom.entries) # The publication of a NYT best-seller list is treated as # midnight Eastern time on the publication date. jan_17 = self._midnight(2015, 1, 17) - assert (True == - all([x.first_appearance == jan_17 for x in custom.entries])) + assert True == all([x.first_appearance == jan_17 for x in custom.entries]) feb_1 = self._midnight(2015, 2, 1) - assert (True == - all([x.most_recent_appearance == feb_1 for x in custom.entries])) + assert True == all([x.most_recent_appearance == feb_1 for x in custom.entries]) # Now replace this list's entries with the entries from a # different list. We wouldn't do this in real life, but it's # a convenient way to change the contents of a list. - other_nyt_list = self.api.best_seller_list('hardcover-fiction') + other_nyt_list = self.api.best_seller_list("hardcover-fiction") self.api.update(other_nyt_list) other_nyt_list.update_custom_list(custom) @@ -267,7 +263,9 @@ def test_fill_in_history(self): class TestNYTBestSellerListTitle(NYTBestSellerAPITest): - one_list_title = json.loads("""{"list_name":"Combined Print and E-Book Fiction","display_name":"Combined Print & E-Book Fiction","bestsellers_date":"2015-01-17","published_date":"2015-02-01","rank":1,"rank_last_week":0,"weeks_on_list":1,"asterisk":0,"dagger":0,"amazon_product_url":"http:\/\/www.amazon.com\/The-Girl-Train-A-Novel-ebook\/dp\/B00L9B7IKE?tag=thenewyorktim-20","isbns":[{"isbn10":"1594633665","isbn13":"9781594633669"},{"isbn10":"0698185390","isbn13":"9780698185395"}],"book_details":[{"title":"THE GIRL ON THE TRAIN","description":"A psychological thriller set in London is full of complications and betrayals.","contributor":"by Paula Hawkins","author":"Paula Hawkins","contributor_note":"","price":0,"age_group":"","publisher":"Riverhead","isbns":[{"isbn10":"1594633665","isbn13":"9781594633669"},{"isbn10":"0698185390","isbn13":"9780698185395"}],"primary_isbn13":"9780698185395","primary_isbn10":"0698185390"}],"reviews":[{"book_review_link":"","first_chapter_link":"","sunday_review_link":"","article_chapter_link":""}]}""") + one_list_title = json.loads( + """{"list_name":"Combined Print and E-Book Fiction","display_name":"Combined Print & E-Book Fiction","bestsellers_date":"2015-01-17","published_date":"2015-02-01","rank":1,"rank_last_week":0,"weeks_on_list":1,"asterisk":0,"dagger":0,"amazon_product_url":"http:\/\/www.amazon.com\/The-Girl-Train-A-Novel-ebook\/dp\/B00L9B7IKE?tag=thenewyorktim-20","isbns":[{"isbn10":"1594633665","isbn13":"9781594633669"},{"isbn10":"0698185390","isbn13":"9780698185395"}],"book_details":[{"title":"THE GIRL ON THE TRAIN","description":"A psychological thriller set in London is full of complications and betrayals.","contributor":"by Paula Hawkins","author":"Paula Hawkins","contributor_note":"","price":0,"age_group":"","publisher":"Riverhead","isbns":[{"isbn10":"1594633665","isbn13":"9781594633669"},{"isbn10":"0698185390","isbn13":"9780698185395"}],"primary_isbn13":"9780698185395","primary_isbn10":"0698185390"}],"reviews":[{"book_review_link":"","first_chapter_link":"","sunday_review_link":"","article_chapter_link":""}]}""" + ) def test_creation(self): title = NYTBestSellerListTitle(self.one_list_title, Edition.BOOK_MEDIUM) @@ -293,8 +291,7 @@ def test_creation(self): assert "Riverhead" == edition.publisher def test_to_edition_sets_sort_author_name_if_obvious(self): - [contributor], ignore = Contributor.lookup( - self._db, "Hawkins, Paula") + [contributor], ignore = Contributor.lookup(self._db, "Hawkins, Paula") contributor.display_name = "Paula Hawkins" title = NYTBestSellerListTitle(self.one_list_title, Edition.BOOK_MEDIUM) diff --git a/tests/test_odilo.py b/tests/test_odilo.py index ea10e0733d..01e4f1819d 100644 --- a/tests/test_odilo.py +++ b/tests/test_odilo.py @@ -1,34 +1,20 @@ # encoding: utf-8 import json - -import pytest - import os -from core.util.http import ( - BadResponseException, -) +import pytest from api.authenticator import BasicAuthenticationProvider - +from api.circulation import CirculationAPI +from api.circulation_exceptions import * from api.odilo import ( - OdiloAPI, MockOdiloAPI, - OdiloRepresentationExtractor, + OdiloAPI, OdiloBibliographicCoverageProvider, - OdiloCirculationMonitor -) - -from api.circulation import ( - CirculationAPI, + OdiloCirculationMonitor, + OdiloRepresentationExtractor, ) - -from api.circulation_exceptions import * - -from . import sample_data - from core.metadata_layer import TimestampData - from core.model import ( Classification, Contributor, @@ -40,25 +26,22 @@ Identifier, Representation, ) +from core.testing import DatabaseTest, MockRequestsResponse +from core.util.datetime_helpers import datetime_utc, utc_now +from core.util.http import BadResponseException + +from . import sample_data -from core.testing import ( - DatabaseTest, - MockRequestsResponse, -) -from core.util.datetime_helpers import ( - datetime_utc, - utc_now, -) class OdiloAPITest(DatabaseTest): - PIN = 'c4ca4238a0b923820dcc509a6f75849b' - RECORD_ID = '00010982' + PIN = "c4ca4238a0b923820dcc509a6f75849b" + RECORD_ID = "00010982" def setup_method(self): super(OdiloAPITest, self).setup_method() library = self._default_library self.patron = self._patron() - self.patron.authorization_identifier='0001000265' + self.patron.authorization_identifier = "0001000265" self.collection = MockOdiloAPI.mock_collection(self._db) self.circulation = CirculationAPI( self._db, library, api_map={ExternalIntegration.ODILO: MockOdiloAPI} @@ -70,12 +53,12 @@ def setup_method(self): identifier_type=Identifier.ODILO_ID, collection=self.collection, identifier_id=self.RECORD_ID, - with_license_pool=True + with_license_pool=True, ) @classmethod def sample_data(cls, filename): - return sample_data(filename, 'odilo') + return sample_data(filename, "odilo") @classmethod def sample_json(cls, filename): @@ -93,20 +76,21 @@ def error_message(self, error_code, message=None, token=None): class TestOdiloAPI(OdiloAPITest): - def test_token_post_success(self): self.api.queue_response(200, content="some content") response = self.api.token_post(self._url, "the payload") - assert 200 == response.status_code, "Status code != 200 --> %i" % response.status_code + assert 200 == response.status_code, ( + "Status code != 200 --> %i" % response.status_code + ) assert self.api.access_token_response.content == response.content - self.api.log.info('Test token post success ok!') + self.api.log.info("Test token post success ok!") def test_get_success(self): self.api.queue_response(200, content="some content") status_code, headers, content = self.api.get(self._url, {}) assert 200 == status_code assert b"some content" == content - self.api.log.info('Test get success ok!') + self.api.log.info("Test get success ok!") def test_401_on_get_refreshes_bearer_token(self): assert "bearer token" == self.api.token @@ -131,11 +115,10 @@ def test_401_on_get_refreshes_bearer_token(self): # The bearer token has been updated. assert "new bearer token" == self.api.token - self.api.log.info('Test 401 on get refreshes bearer token ok!') + self.api.log.info("Test 401 on get refreshes bearer token ok!") def test_credential_refresh_success(self): - """Verify the process of refreshing the Odilo bearer token. - """ + """Verify the process of refreshing the Odilo bearer token.""" credential = self.api.credential_object(lambda x: x) assert "bearer token" == credential.credential assert self.api.token == credential.credential @@ -169,20 +152,22 @@ def test_credential_refresh_failure(self): of failure on a new setup. """ self.api.access_token_response = MockRequestsResponse( - 200, {"Content-Type": "text/html"}, - "Hi, this is the website, not the API." + 200, {"Content-Type": "text/html"}, "Hi, this is the website, not the API." ) credential = self.api.credential_object(lambda x: x) with pytest.raises(BadResponseException) as excinfo: self.api.refresh_creds(credential) assert "Bad response from " in str(excinfo.value) - assert "may not be the right base URL. Response document was: 'Hi, this is the website, not the API.'" in str(excinfo.value) + assert ( + "may not be the right base URL. Response document was: 'Hi, this is the website, not the API.'" + in str(excinfo.value) + ) # Also test a 400 response code. self.api.access_token_response = MockRequestsResponse( - 400, {"Content-Type": "application/json"}, - - json.dumps(dict(errors=[dict(description="Oops")])) + 400, + {"Content-Type": "application/json"}, + json.dumps(dict(errors=[dict(description="Oops")])), ) with pytest.raises(BadResponseException) as excinfo: self.api.refresh_creds(credential) @@ -192,9 +177,7 @@ def test_credential_refresh_failure(self): # If there's a 400 response but no error information, # the generic error message is used. self.api.access_token_response = MockRequestsResponse( - 400, {"Content-Type": "application/json"}, - - json.dumps(dict()) + 400, {"Content-Type": "application/json"}, json.dumps(dict()) ) with pytest.raises(BadResponseException) as excinfo: self.api.refresh_creds(credential) @@ -218,19 +201,23 @@ def test_401_after_token_refresh_raises_error(self): # That raises a BadResponseException with pytest.raises(BadResponseException) as excinfo: self.api.get(self._url, {}) - assert "Something's wrong with the Odilo OAuth Bearer Token!" in str(excinfo.value) + assert "Something's wrong with the Odilo OAuth Bearer Token!" in str( + excinfo.value + ) # The bearer token has been updated. assert "new bearer token" == self.api.token def test_external_integration(self): - assert (self.collection.external_integration == - self.api.external_integration(self._db)) + assert self.collection.external_integration == self.api.external_integration( + self._db + ) def test__run_self_tests(self): """Verify that OdiloAPI._run_self_tests() calls the right methods. """ + class Mock(MockOdiloAPI): "Mock every method used by OdiloAPI._run_self_tests." @@ -241,6 +228,7 @@ def __init__(self, _db, collection): # First we will call check_creds() to get a fresh credential. mock_credential = object() + def check_creds(self, force_refresh=False): self.check_creds_called_with = force_refresh return self.mock_credential @@ -250,10 +238,9 @@ def check_creds(self, force_refresh=False): # the credentials of that library's test patron. mock_patron_checkouts = object() get_patron_checkouts_called_with = [] + def get_patron_checkouts(self, patron, pin): - self.get_patron_checkouts_called_with.append( - (patron, pin) - ) + self.get_patron_checkouts_called_with.append((patron, pin)) return self.mock_patron_checkouts # Now let's make sure two Libraries have access to this @@ -266,7 +253,7 @@ def get_patron_checkouts(self, patron, pin): integration = self._external_integration( "api.simple_authentication", ExternalIntegration.PATRON_AUTH_GOAL, - libraries=[with_default_patron] + libraries=[with_default_patron], ) p = BasicAuthenticationProvider integration.setting(p.TEST_IDENTIFIER).value = "username1" @@ -274,16 +261,14 @@ def get_patron_checkouts(self, patron, pin): # Now that everything is set up, run the self-test. api = Mock(self._db, self.collection) - results = sorted( - api._run_self_tests(self._db), key=lambda x: x.name - ) + results = sorted(api._run_self_tests(self._db), key=lambda x: x.name) loans_failure, sitewide, loans_success = results # Make sure all three tests were run and got the expected result. # # We got a sitewide access token. - assert 'Obtaining a sitewide access token' == sitewide.name + assert "Obtaining a sitewide access token" == sitewide.name assert True == sitewide.success assert api.mock_credential == sitewide.result assert True == api.check_creds_called_with @@ -291,8 +276,10 @@ def get_patron_checkouts(self, patron, pin): # We got the default patron's checkouts for the library that had # a default patron configured. assert ( - 'Viewing the active loans for the test patron for library %s' % with_default_patron.name == - loans_success.name) + "Viewing the active loans for the test patron for library %s" + % with_default_patron.name + == loans_success.name + ) assert True == loans_success.success # get_patron_checkouts was only called once. [(patron, pin)] = api.get_patron_checkouts_called_with @@ -302,11 +289,11 @@ def get_patron_checkouts(self, patron, pin): # We couldn't get a patron access token for the other library. assert ( - 'Acquiring test patron credentials for library %s' % no_default_patron.name == - loans_failure.name) + "Acquiring test patron credentials for library %s" % no_default_patron.name + == loans_failure.name + ) assert False == loans_failure.success - assert ("Library has no test patron configured." == - str(loans_failure.exception)) + assert "Library has no test patron configured." == str(loans_failure.exception) def test_run_self_tests_short_circuit(self): """If OdiloAPI.check_creds can't get credentials, the rest of @@ -315,8 +302,10 @@ def test_run_self_tests_short_circuit(self): This probably doesn't matter much, because if check_creds doesn't work we won't be able to instantiate the OdiloAPI class. """ + def explode(*args, **kwargs): raise Exception("Failure!") + self.api.check_creds = explode # Only one test will be run. @@ -331,22 +320,40 @@ class TestOdiloCirculationAPI(OdiloAPITest): # Test 404 Not Found --> patron not found --> 'patronNotFound' def test_01_patron_not_found(self): - patron_not_found_data, patron_not_found_json = self.sample_json("error_patron_not_found.json") + patron_not_found_data, patron_not_found_json = self.sample_json( + "error_patron_not_found.json" + ) self.api.queue_response(404, content=patron_not_found_json) patron = self._patron() patron.authorization_identifier = "no such patron" - pytest.raises(PatronNotFoundOnRemote, self.api.checkout, patron, self.PIN, self.licensepool, 'ACSM_EPUB') - self.api.log.info('Test patron not found ok!') + pytest.raises( + PatronNotFoundOnRemote, + self.api.checkout, + patron, + self.PIN, + self.licensepool, + "ACSM_EPUB", + ) + self.api.log.info("Test patron not found ok!") # Test 404 Not Found --> record not found --> 'ERROR_DATA_NOT_FOUND' def test_02_data_not_found(self): - data_not_found_data, data_not_found_json = self.sample_json("error_data_not_found.json") + data_not_found_data, data_not_found_json = self.sample_json( + "error_data_not_found.json" + ) self.api.queue_response(404, content=data_not_found_json) - self.licensepool.identifier.identifier = '12345678' - pytest.raises(NotFoundOnRemote, self.api.checkout, self.patron, self.PIN, self.licensepool, 'ACSM_EPUB') - self.api.log.info('Test resource not found on remote ok!') + self.licensepool.identifier.identifier = "12345678" + pytest.raises( + NotFoundOnRemote, + self.api.checkout, + self.patron, + self.PIN, + self.licensepool, + "ACSM_EPUB", + ) + self.api.log.info("Test resource not found on remote ok!") def test_make_absolute_url(self): @@ -356,11 +363,10 @@ def test_make_absolute_url(self): assert absolute == self.api.library_api_base_url.decode("utf-8") + relative # An absolute URL is not modified. - for protocol in ('http', 'https'): + for protocol in ("http", "https"): already_absolute = "%s://example.com/" % protocol assert already_absolute == self.api._make_absolute_url(already_absolute) - ################# # Checkout tests ################# @@ -368,32 +374,43 @@ def test_make_absolute_url(self): # Test 400 Bad Request --> Invalid format for that resource def test_11_checkout_fake_format(self): self.api.queue_response(400, content="") - pytest.raises(NoAcceptableFormat, self.api.checkout, self.patron, self.PIN, self.licensepool, 'FAKE_FORMAT') - self.api.log.info('Test invalid format for resource ok!') + pytest.raises( + NoAcceptableFormat, + self.api.checkout, + self.patron, + self.PIN, + self.licensepool, + "FAKE_FORMAT", + ) + self.api.log.info("Test invalid format for resource ok!") def test_12_checkout_acsm_epub(self): checkout_data, checkout_json = self.sample_json("checkout_acsm_epub_ok.json") self.api.queue_response(200, content=checkout_json) - self.perform_and_validate_checkout('ACSM_EPUB') + self.perform_and_validate_checkout("ACSM_EPUB") def test_13_checkout_acsm_pdf(self): checkout_data, checkout_json = self.sample_json("checkout_acsm_pdf_ok.json") self.api.queue_response(200, content=checkout_json) - self.perform_and_validate_checkout('ACSM_PDF') + self.perform_and_validate_checkout("ACSM_PDF") def test_14_checkout_ebook_streaming(self): - checkout_data, checkout_json = self.sample_json("checkout_ebook_streaming_ok.json") + checkout_data, checkout_json = self.sample_json( + "checkout_ebook_streaming_ok.json" + ) self.api.queue_response(200, content=checkout_json) - self.perform_and_validate_checkout('EBOOK_STREAMING') + self.perform_and_validate_checkout("EBOOK_STREAMING") def test_mechanism_set_on_borrow(self): """The delivery mechanism for an Odilo title is set on checkout.""" assert OdiloAPI.SET_DELIVERY_MECHANISM_AT == OdiloAPI.BORROW_STEP def perform_and_validate_checkout(self, internal_format): - loan_info = self.api.checkout(self.patron, self.PIN, self.licensepool, internal_format) + loan_info = self.api.checkout( + self.patron, self.PIN, self.licensepool, internal_format + ) assert loan_info, "LoanInfo null --> checkout failed!" - self.api.log.info('Loan ok: %s' % loan_info.identifier) + self.api.log.info("Loan ok: %s" % loan_info.identifier) ################# # Fulfill tests @@ -406,7 +423,7 @@ def test_21_fulfill_acsm_epub(self): acsm_data = self.sample_data("fulfill_ok_acsm_epub.acsm") self.api.queue_response(200, content=acsm_data) - fulfillment_info = self.fulfill('ACSM_EPUB') + fulfillment_info = self.fulfill("ACSM_EPUB") assert fulfillment_info.content_type[0] == Representation.EPUB_MEDIA_TYPE assert fulfillment_info.content_type[1] == DeliveryMechanism.ADOBE_DRM @@ -417,7 +434,7 @@ def test_22_fulfill_acsm_pdf(self): acsm_data = self.sample_data("fulfill_ok_acsm_pdf.acsm") self.api.queue_response(200, content=acsm_data) - fulfillment_info = self.fulfill('ACSM_PDF') + fulfillment_info = self.fulfill("ACSM_PDF") assert fulfillment_info.content_type[0] == Representation.PDF_MEDIA_TYPE assert fulfillment_info.content_type[1] == DeliveryMechanism.ADOBE_DRM @@ -425,19 +442,24 @@ def test_23_fulfill_ebook_streaming(self): checkout_data, checkout_json = self.sample_json("patron_checkouts.json") self.api.queue_response(200, content=checkout_json) - self.licensepool.identifier.identifier = '00011055' - fulfillment_info = self.fulfill('EBOOK_STREAMING') + self.licensepool.identifier.identifier = "00011055" + fulfillment_info = self.fulfill("EBOOK_STREAMING") assert fulfillment_info.content_type[0] == Representation.TEXT_HTML_MEDIA_TYPE - assert fulfillment_info.content_type[1] == DeliveryMechanism.STREAMING_TEXT_CONTENT_TYPE + assert ( + fulfillment_info.content_type[1] + == DeliveryMechanism.STREAMING_TEXT_CONTENT_TYPE + ) def fulfill(self, internal_format): - fulfillment_info = self.api.fulfill(self.patron, self.PIN, self.licensepool, internal_format) - assert fulfillment_info, 'Cannot Fulfill !!' + fulfillment_info = self.api.fulfill( + self.patron, self.PIN, self.licensepool, internal_format + ) + assert fulfillment_info, "Cannot Fulfill !!" if fulfillment_info.content_link: - self.api.log.info('Fulfill link: %s' % fulfillment_info.content_link) + self.api.log.info("Fulfill link: %s" % fulfillment_info.content_link) if fulfillment_info.content: - self.api.log.info('Fulfill content: %s' % fulfillment_info.content) + self.api.log.info("Fulfill content: %s" % fulfillment_info.content) return fulfillment_info @@ -446,36 +468,52 @@ def fulfill(self, internal_format): ################# def test_31_already_on_hold(self): - already_on_hold_data, already_on_hold_json = self.sample_json("error_hold_already_in_hold.json") + already_on_hold_data, already_on_hold_json = self.sample_json( + "error_hold_already_in_hold.json" + ) self.api.queue_response(403, content=already_on_hold_json) - pytest.raises(AlreadyOnHold, self.api.place_hold, self.patron, self.PIN, self.licensepool, - 'ejcepas@odilotid.es') + pytest.raises( + AlreadyOnHold, + self.api.place_hold, + self.patron, + self.PIN, + self.licensepool, + "ejcepas@odilotid.es", + ) - self.api.log.info('Test hold already on hold ok!') + self.api.log.info("Test hold already on hold ok!") def test_32_place_hold(self): hold_ok_data, hold_ok_json = self.sample_json("place_hold_ok.json") self.api.queue_response(200, content=hold_ok_json) - hold_info = self.api.place_hold(self.patron, self.PIN, self.licensepool, 'ejcepas@odilotid.es') + hold_info = self.api.place_hold( + self.patron, self.PIN, self.licensepool, "ejcepas@odilotid.es" + ) assert hold_info, "HoldInfo null --> place hold failed!" - self.api.log.info('Hold ok: %s' % hold_info.identifier) + self.api.log.info("Hold ok: %s" % hold_info.identifier) ################# # Patron Activity tests ################# def test_41_patron_activity_invalid_patron(self): - patron_not_found_data, patron_not_found_json = self.sample_json("error_patron_not_found.json") + patron_not_found_data, patron_not_found_json = self.sample_json( + "error_patron_not_found.json" + ) self.api.queue_response(404, content=patron_not_found_json) - pytest.raises(PatronNotFoundOnRemote, self.api.patron_activity, self.patron, self.PIN) + pytest.raises( + PatronNotFoundOnRemote, self.api.patron_activity, self.patron, self.PIN + ) - self.api.log.info('Test patron activity --> invalid patron ok!') + self.api.log.info("Test patron activity --> invalid patron ok!") def test_42_patron_activity(self): - patron_checkouts_data, patron_checkouts_json = self.sample_json("patron_checkouts.json") + patron_checkouts_data, patron_checkouts_json = self.sample_json( + "patron_checkouts.json" + ) patron_holds_data, patron_holds_json = self.sample_json("patron_holds.json") self.api.queue_response(200, content=patron_checkouts_json) self.api.queue_response(200, content=patron_holds_json) @@ -483,27 +521,39 @@ def test_42_patron_activity(self): loans_and_holds = self.api.patron_activity(self.patron, self.PIN) assert loans_and_holds assert 12 == len(loans_and_holds) - self.api.log.info('Test patron activity ok !!') + self.api.log.info("Test patron activity ok !!") ################# # Checkin tests ################# def test_51_checkin_patron_not_found(self): - patron_not_found_data, patron_not_found_json = self.sample_json("error_patron_not_found.json") + patron_not_found_data, patron_not_found_json = self.sample_json( + "error_patron_not_found.json" + ) self.api.queue_response(404, content=patron_not_found_json) - pytest.raises(PatronNotFoundOnRemote, self.api.checkin, self.patron, self.PIN, self.licensepool) + pytest.raises( + PatronNotFoundOnRemote, + self.api.checkin, + self.patron, + self.PIN, + self.licensepool, + ) - self.api.log.info('Test checkin --> invalid patron ok!') + self.api.log.info("Test checkin --> invalid patron ok!") def test_52_checkin_checkout_not_found(self): - checkout_not_found_data, checkout_not_found_json = self.sample_json("error_checkout_not_found.json") + checkout_not_found_data, checkout_not_found_json = self.sample_json( + "error_checkout_not_found.json" + ) self.api.queue_response(404, content=checkout_not_found_json) - pytest.raises(NotCheckedOut, self.api.checkin, self.patron, self.PIN, self.licensepool) + pytest.raises( + NotCheckedOut, self.api.checkin, self.patron, self.PIN, self.licensepool + ) - self.api.log.info('Test checkin --> invalid checkout ok!') + self.api.log.info("Test checkin --> invalid checkout ok!") def test_53_checkin(self): checkout_data, checkout_json = self.sample_json("patron_checkouts.json") @@ -513,27 +563,38 @@ def test_53_checkin(self): self.api.queue_response(200, content=checkin_json) response = self.api.checkin(self.patron, self.PIN, self.licensepool) - assert response.status_code == 200, \ - "Response code != 200, cannot perform checkin for record: " \ - + self.licensepool.identifier.identifier + " patron: " + self.patron.authorization_identifier + assert response.status_code == 200, ( + "Response code != 200, cannot perform checkin for record: " + + self.licensepool.identifier.identifier + + " patron: " + + self.patron.authorization_identifier + ) checkout_returned = response.json() assert checkout_returned - assert '4318' == checkout_returned['id'] - self.api.log.info('Checkout returned: %s' % checkout_returned['id']) + assert "4318" == checkout_returned["id"] + self.api.log.info("Checkout returned: %s" % checkout_returned["id"]) ################# # Patron Activity tests ################# def test_61_return_hold_patron_not_found(self): - patron_not_found_data, patron_not_found_json = self.sample_json("error_patron_not_found.json") + patron_not_found_data, patron_not_found_json = self.sample_json( + "error_patron_not_found.json" + ) self.api.queue_response(404, content=patron_not_found_json) - pytest.raises(PatronNotFoundOnRemote, self.api.release_hold, self.patron, self.PIN, self.licensepool) + pytest.raises( + PatronNotFoundOnRemote, + self.api.release_hold, + self.patron, + self.PIN, + self.licensepool, + ) - self.api.log.info('Test release hold --> invalid patron ok!') + self.api.log.info("Test release hold --> invalid patron ok!") def test_62_return_hold_not_found(self): holds_data, holds_json = self.sample_json("patron_holds.json") @@ -543,31 +604,39 @@ def test_62_return_hold_not_found(self): self.api.queue_response(404, content=checkin_json) response = self.api.release_hold(self.patron, self.PIN, self.licensepool) - assert response == True, \ - "Cannot release hold, response false " \ - + self.licensepool.identifier.identifier + " patron: " + self.patron.authorization_identifier + assert response == True, ( + "Cannot release hold, response false " + + self.licensepool.identifier.identifier + + " patron: " + + self.patron.authorization_identifier + ) - self.api.log.info('Hold returned: %s' % self.licensepool.identifier.identifier) + self.api.log.info("Hold returned: %s" % self.licensepool.identifier.identifier) def test_63_return_hold(self): holds_data, holds_json = self.sample_json("patron_holds.json") self.api.queue_response(200, content=holds_json) - release_hold_ok_data, release_hold_ok_json = self.sample_json("release_hold_ok.json") + release_hold_ok_data, release_hold_ok_json = self.sample_json( + "release_hold_ok.json" + ) self.api.queue_response(200, content=release_hold_ok_json) response = self.api.release_hold(self.patron, self.PIN, self.licensepool) - assert response == True, \ - "Cannot release hold, response false " \ - + self.licensepool.identifier.identifier + " patron: " + self.patron.authorization_identifier + assert response == True, ( + "Cannot release hold, response false " + + self.licensepool.identifier.identifier + + " patron: " + + self.patron.authorization_identifier + ) - self.api.log.info('Hold returned: %s' % self.licensepool.identifier.identifier) + self.api.log.info("Hold returned: %s" % self.licensepool.identifier.identifier) class TestOdiloDiscoveryAPI(OdiloAPITest): - def test_run(self): """Verify that running the OdiloCirculationMonitor calls all_ids().""" + class Mock(OdiloCirculationMonitor): def all_ids(self, modification_date=None): self.called_with = modification_date @@ -589,56 +658,65 @@ def all_ids(self, modification_date=None): # modification date five minutes earlier than the completion # of the last run. monitor.run() - expect = completed-monitor.OVERLAP - assert (expect-monitor.called_with).total_seconds() < 2 + expect = completed - monitor.OVERLAP + assert (expect - monitor.called_with).total_seconds() < 2 def test_all_ids_with_date(self): # TODO: This tests that all_ids doesn't crash when you pass in # a date. It doesn't test anything about all_ids except the # return value. - monitor = OdiloCirculationMonitor(self._db, self.collection, api_class=MockOdiloAPI) - assert monitor, 'Monitor null !!' - assert ExternalIntegration.ODILO == monitor.protocol, 'Wat??' + monitor = OdiloCirculationMonitor( + self._db, self.collection, api_class=MockOdiloAPI + ) + assert monitor, "Monitor null !!" + assert ExternalIntegration.ODILO == monitor.protocol, "Wat??" - records_metadata_data, records_metadata_json = self.sample_json("records_metadata.json") + records_metadata_data, records_metadata_json = self.sample_json( + "records_metadata.json" + ) monitor.api.queue_response(200, content=records_metadata_data) availability_data = self.sample_data("record_availability.json") for record in records_metadata_json: monitor.api.queue_response(200, content=availability_data) - monitor.api.queue_response(200, content='[]') # No more resources retrieved + monitor.api.queue_response(200, content="[]") # No more resources retrieved timestamp = TimestampData(start=datetime_utc(2017, 9, 1)) updated, new = monitor.all_ids(None) assert 10 == updated assert 10 == new - self.api.log.info('Odilo circulation monitor with date finished ok!!') + self.api.log.info("Odilo circulation monitor with date finished ok!!") def test_all_ids_without_date(self): # TODO: This tests that all_ids doesn't crash when you pass in # an empty date. It doesn't test anything about all_ids except the # return value. - monitor = OdiloCirculationMonitor(self._db, self.collection, api_class=MockOdiloAPI) - assert monitor, 'Monitor null !!' - assert ExternalIntegration.ODILO == monitor.protocol, 'Wat??' + monitor = OdiloCirculationMonitor( + self._db, self.collection, api_class=MockOdiloAPI + ) + assert monitor, "Monitor null !!" + assert ExternalIntegration.ODILO == monitor.protocol, "Wat??" - records_metadata_data, records_metadata_json = self.sample_json("records_metadata.json") + records_metadata_data, records_metadata_json = self.sample_json( + "records_metadata.json" + ) monitor.api.queue_response(200, content=records_metadata_data) availability_data = self.sample_data("record_availability.json") for record in records_metadata_json: monitor.api.queue_response(200, content=availability_data) - monitor.api.queue_response(200, content='[]') # No more resources retrieved + monitor.api.queue_response(200, content="[]") # No more resources retrieved updated, new = monitor.all_ids(datetime_utc(2017, 9, 1)) assert 10 == updated assert 10 == new - self.api.log.info('Odilo circulation monitor without date finished ok!!') + self.api.log.info("Odilo circulation monitor without date finished ok!!") + class TestOdiloBibliographicCoverageProvider(OdiloAPITest): def setup_method(self): @@ -654,12 +732,12 @@ def test_process_item(self): availability, availability_json = self.sample_json("odilo_availability.json") self.api.queue_response(200, content=availability) - identifier, made_new = self.provider.process_item('00010982') + identifier, made_new = self.provider.process_item("00010982") # Check that the Identifier returned has the right .type and .identifier. assert identifier, "Problem while testing process item !!!" assert identifier.type == Identifier.ODILO_ID - assert identifier.identifier == '00010982' + assert identifier.identifier == "00010982" # Check that metadata and availability information were imported properly [pool] = identifier.licensed_through @@ -671,30 +749,48 @@ def test_process_item(self): assert 1 == pool.licenses_reserved names = [x.delivery_mechanism.name for x in pool.delivery_mechanisms] - assert (sorted([Representation.EPUB_MEDIA_TYPE + ' (' + DeliveryMechanism.ADOBE_DRM + ')', - Representation.TEXT_HTML_MEDIA_TYPE + ' (' + DeliveryMechanism.STREAMING_TEXT_CONTENT_TYPE + ')']) == - sorted(names)) + assert ( + sorted( + [ + Representation.EPUB_MEDIA_TYPE + + " (" + + DeliveryMechanism.ADOBE_DRM + + ")", + Representation.TEXT_HTML_MEDIA_TYPE + + " (" + + DeliveryMechanism.STREAMING_TEXT_CONTENT_TYPE + + ")", + ] + ) + == sorted(names) + ) # Check that handle_success was called --> A Work was created and made presentation ready. assert True == pool.work.presentation_ready - self.api.log.info('Testing process item finished ok !!') + self.api.log.info("Testing process item finished ok !!") def test_process_inactive_item(self): - record_metadata, record_metadata_json = self.sample_json("odilo_metadata_inactive.json") + record_metadata, record_metadata_json = self.sample_json( + "odilo_metadata_inactive.json" + ) self.api.queue_response(200, content=record_metadata_json) - availability, availability_json = self.sample_json("odilo_availability_inactive.json") + availability, availability_json = self.sample_json( + "odilo_availability_inactive.json" + ) self.api.queue_response(200, content=availability) - identifier, made_new = self.provider.process_item('00011135') + identifier, made_new = self.provider.process_item("00011135") # Check that the Identifier returned has the right .type and .identifier. assert identifier, "Problem while testing process inactive item !!!" assert identifier.type == Identifier.ODILO_ID - assert identifier.identifier == '00011135' + assert identifier.identifier == "00011135" [pool] = identifier.licensed_through - assert "!Tention A Story of Boy-Life during the Peninsular War" == pool.work.title + assert ( + "!Tention A Story of Boy-Life during the Peninsular War" == pool.work.title + ) # Check work not available assert 0 == pool.licenses_owned @@ -702,7 +798,8 @@ def test_process_inactive_item(self): assert True == pool.work.presentation_ready - self.api.log.info('Testing process item inactive finished ok !!') + self.api.log.info("Testing process item inactive finished ok !!") + class TestOdiloRepresentationExtractor(OdiloAPITest): def test_book_info_with_metadata(self): @@ -710,13 +807,21 @@ def test_book_info_with_metadata(self): raw, book_json = self.sample_json("odilo_metadata.json") raw, availability = self.sample_json("odilo_availability.json") - metadata, active = OdiloRepresentationExtractor.record_info_to_metadata(book_json, availability) + metadata, active = OdiloRepresentationExtractor.record_info_to_metadata( + book_json, availability + ) assert "Busy Brownies" == metadata.title - assert " (The Classic Fantasy Literature of Elves for Children)" == metadata.subtitle + assert ( + " (The Classic Fantasy Literature of Elves for Children)" + == metadata.subtitle + ) assert "eng" == metadata.language assert Edition.BOOK_MEDIUM == metadata.medium - assert "The Classic Fantasy Literature for Children written in 1896 retold for Elves adventure." == metadata.series + assert ( + "The Classic Fantasy Literature for Children written in 1896 retold for Elves adventure." + == metadata.series + ) assert "1" == metadata.series_position assert "ANBOCO" == metadata.publisher assert 2013 == metadata.published.year @@ -726,26 +831,27 @@ def test_book_info_with_metadata(self): assert 3 == metadata.data_source_last_updated.month assert 10 == metadata.data_source_last_updated.day # Related IDs. - assert ((Identifier.ODILO_ID, '00010982') == - (metadata.primary_identifier.type, metadata.primary_identifier.identifier)) + assert (Identifier.ODILO_ID, "00010982") == ( + metadata.primary_identifier.type, + metadata.primary_identifier.identifier, + ) ids = [(x.type, x.identifier) for x in metadata.identifiers] - assert ( - [ - (Identifier.ISBN, '9783736418837'), - (Identifier.ODILO_ID, '00010982') - ] == - sorted(ids)) + assert [ + (Identifier.ISBN, "9783736418837"), + (Identifier.ODILO_ID, "00010982"), + ] == sorted(ids) subjects = sorted(metadata.subjects, key=lambda x: x.identifier) weight = Classification.TRUSTED_DISTRIBUTOR_WEIGHT - assert ([('Children', 'tag', weight), - ('Classics', 'tag', weight), - ('FIC004000', 'BISAC', weight), - ('Fantasy', 'tag', weight), - ('K-12', 'Grade level', weight), - ('LIT009000', 'BISAC', weight), - ('YAF019020', 'BISAC', weight)] == - [(x.identifier, x.type, x.weight) for x in subjects]) + assert [ + ("Children", "tag", weight), + ("Classics", "tag", weight), + ("FIC004000", "BISAC", weight), + ("Fantasy", "tag", weight), + ("K-12", "Grade level", weight), + ("LIT009000", "BISAC", weight), + ("YAF019020", "BISAC", weight), + ] == [(x.identifier, x.type, x.weight) for x in subjects] [author] = metadata.contributors assert "Veale, E." == author.sort_name @@ -753,29 +859,36 @@ def test_book_info_with_metadata(self): assert [Contributor.AUTHOR_ROLE] == author.roles # Available formats. - [acsm_epub, ebook_streaming] = sorted(metadata.circulation.formats, key=lambda x: x.content_type) + [acsm_epub, ebook_streaming] = sorted( + metadata.circulation.formats, key=lambda x: x.content_type + ) assert Representation.EPUB_MEDIA_TYPE == acsm_epub.content_type assert DeliveryMechanism.ADOBE_DRM == acsm_epub.drm_scheme assert Representation.TEXT_HTML_MEDIA_TYPE == ebook_streaming.content_type - assert DeliveryMechanism.STREAMING_TEXT_CONTENT_TYPE == ebook_streaming.drm_scheme + assert ( + DeliveryMechanism.STREAMING_TEXT_CONTENT_TYPE == ebook_streaming.drm_scheme + ) # Links to various resources. image, thumbnail, description = sorted(metadata.links, key=lambda x: x.rel) assert Hyperlink.IMAGE == image.rel assert ( - 'http://pruebasotk.odilotk.es/public/OdiloPlace_eduDistUS/pg54159.jpg' == - image.href) + "http://pruebasotk.odilotk.es/public/OdiloPlace_eduDistUS/pg54159.jpg" + == image.href + ) assert Hyperlink.THUMBNAIL_IMAGE == thumbnail.rel assert ( - 'http://pruebasotk.odilotk.es/public/OdiloPlace_eduDistUS/pg54159_225x318.jpg' == - thumbnail.href) + "http://pruebasotk.odilotk.es/public/OdiloPlace_eduDistUS/pg54159_225x318.jpg" + == thumbnail.href + ) assert Hyperlink.DESCRIPTION == description.rel assert description.content.startswith( - "All the Brownies had promised to help, and when a Brownie undertakes a thing he works as busily") + "All the Brownies had promised to help, and when a Brownie undertakes a thing he works as busily" + ) circulation = metadata.circulation assert 2 == circulation.licenses_owned @@ -783,15 +896,15 @@ def test_book_info_with_metadata(self): assert 2 == circulation.patrons_in_hold_queue assert 1 == circulation.licenses_reserved - self.api.log.info('Testing book info with metadata finished ok !!') + self.api.log.info("Testing book info with metadata finished ok !!") def test_book_info_missing_metadata(self): # Verify that we properly handle missing metadata from Odilo. raw, book_json = self.sample_json("odilo_metadata.json") # This was seen in real data. - book_json['series'] = ' ' - book_json['seriesPosition'] = ' ' + book_json["series"] = " " + book_json["seriesPosition"] = " " metadata, active = OdiloRepresentationExtractor.record_info_to_metadata( book_json, {} @@ -806,6 +919,8 @@ def test_default_language_spanish(self): """ raw, book_json = self.sample_json("odilo_metadata.json") raw, availability = self.sample_json("odilo_availability.json") - del book_json['language'] - metadata, active = OdiloRepresentationExtractor.record_info_to_metadata(book_json, availability) - assert 'spa' == metadata.language + del book_json["language"] + metadata, active = OdiloRepresentationExtractor.record_info_to_metadata( + book_json, availability + ) + assert "spa" == metadata.language diff --git a/tests/test_odl.py b/tests/test_odl.py index 09d8a71f27..9b1d15a851 100644 --- a/tests/test_odl.py +++ b/tests/test_odl.py @@ -7,6 +7,11 @@ import dateutil import pytest +from freezegun import freeze_time +from jinja2 import Environment, FileSystemLoader, select_autoescape +from mock import MagicMock, PropertyMock, patch +from parameterized import parameterized + from api.circulation_exceptions import * from api.odl import ( ODLAPI, @@ -18,11 +23,6 @@ SharedODLAPI, SharedODLImporter, ) -from freezegun import freeze_time -from jinja2 import Environment, FileSystemLoader, select_autoescape -from mock import MagicMock, PropertyMock, patch -from parameterized import parameterized - from core.model import ( Collection, ConfigurationSetting, @@ -58,19 +58,18 @@ def get_data(cls, filename): class TestODLAPI(DatabaseTest, BaseODLTest): - def setup_method(self): super(TestODLAPI, self).setup_method() self.collection = MockODLAPI.mock_collection(self._db) self.collection.external_integration.set_setting( - Collection.DATA_SOURCE_NAME_SETTING, - "Feedbooks" + Collection.DATA_SOURCE_NAME_SETTING, "Feedbooks" ) self.api = MockODLAPI(self._db, self.collection) self.work = self._work(with_license_pool=True, collection=self.collection) self.pool = self.work.license_pools[0] self.license = self._license( - self.pool, checkout_url="https://loan.feedbooks.net/loan/get/{?id,checkout_id,expires,patron_id,notification_url}", + self.pool, + checkout_url="https://loan.feedbooks.net/loan/get/{?id,checkout_id,expires,patron_id,notification_url}", concurrent_checkouts=1, ) self.patron = self._patron() @@ -103,7 +102,7 @@ def test_get_license_status_document_success(self): # The expiration time passed to the server is associated with # the UTC time zone. - assert expires.endswith('+00:00') + assert expires.endswith("+00:00") expires = dateutil.parser.parse(expires) assert expires.tzinfo == dateutil.tz.tz.tzutc() @@ -112,8 +111,11 @@ def test_get_license_status_document_success(self): assert expires < after_expiration notification_url = urllib.parse.unquote_plus(params.get("notification_url")[0]) - assert ("http://odl_notify?library_short_name=%s&loan_id=%s" % (self._default_library.short_name, loan.id) == - notification_url) + assert ( + "http://odl_notify?library_short_name=%s&loan_id=%s" + % (self._default_library.short_name, loan.id) + == notification_url + ) # With an existing loan. loan, ignore = self.license.loan_to(self.patron) @@ -129,12 +131,16 @@ def test_get_license_status_document_errors(self): self.api.queue_response(200, content="not json") pytest.raises( - BadResponseException, self.api.get_license_status_document, loan, + BadResponseException, + self.api.get_license_status_document, + loan, ) self.api.queue_response(200, content=json.dumps(dict(status="unknown"))) pytest.raises( - BadResponseException, self.api.get_license_status_document, loan, + BadResponseException, + self.api.get_license_status_document, + loan, ) def test_checkin_success(self): @@ -146,16 +152,22 @@ def test_checkin_success(self): loan.end = utc_now() + datetime.timedelta(days=3) # The patron returns the book successfully. - lsd = json.dumps({ - "status": "ready", - "links": [{ - "rel": "return", - "href": "http://return", - }], - }) - returned_lsd = json.dumps({ - "status": "returned", - }) + lsd = json.dumps( + { + "status": "ready", + "links": [ + { + "rel": "return", + "href": "http://return", + } + ], + } + ) + returned_lsd = json.dumps( + { + "status": "returned", + } + ) self.api.queue_response(200, content=lsd) self.api.queue_response(200) @@ -182,19 +194,27 @@ def test_checkin_success_with_holds_queue(self): # Another patron has the book on hold. patron_with_hold = self._patron() self.pool.patrons_in_hold_queue = 1 - hold, ignore = self.pool.on_hold_to(patron_with_hold, start=utc_now(), end=None, position=1) + hold, ignore = self.pool.on_hold_to( + patron_with_hold, start=utc_now(), end=None, position=1 + ) # The first patron returns the book successfully. - lsd = json.dumps({ - "status": "ready", - "links": [{ - "rel": "return", - "href": "http://return", - }], - }) - returned_lsd = json.dumps({ - "status": "returned", - }) + lsd = json.dumps( + { + "status": "ready", + "links": [ + { + "rel": "return", + "href": "http://return", + } + ], + } + ) + returned_lsd = json.dumps( + { + "status": "returned", + } + ) self.api.queue_response(200, content=lsd) self.api.queue_response(200) @@ -220,9 +240,11 @@ def test_checkin_already_fulfilled(self): loan.external_identifier = self._str loan.end = utc_now() + datetime.timedelta(days=3) - lsd = json.dumps({ - "status": "active", - }) + lsd = json.dumps( + { + "status": "active", + } + ) self.api.queue_response(200, content=lsd) # Checking in the book silently does nothing. @@ -234,8 +256,11 @@ def test_checkin_already_fulfilled(self): def test_checkin_not_checked_out(self): # Not checked out locally. pytest.raises( - NotCheckedOut, self.api.checkin, - self.patron, "pin", self.pool, + NotCheckedOut, + self.api.checkin, + self.patron, + "pin", + self.pool, ) # Not checked out according to the distributor. @@ -243,14 +268,19 @@ def test_checkin_not_checked_out(self): loan.external_identifier = self._str loan.end = utc_now() + datetime.timedelta(days=3) - lsd = json.dumps({ - "status": "revoked", - }) + lsd = json.dumps( + { + "status": "revoked", + } + ) self.api.queue_response(200, content=lsd) pytest.raises( - NotCheckedOut, self.api.checkin, - self.patron, "pin", self.pool, + NotCheckedOut, + self.api.checkin, + self.patron, + "pin", + self.pool, ) def test_checkin_cannot_return(self): @@ -259,9 +289,11 @@ def test_checkin_cannot_return(self): loan.external_identifier = self._str loan.end = utc_now() + datetime.timedelta(days=3) - lsd = json.dumps({ - "status": "ready", - }) + lsd = json.dumps( + { + "status": "ready", + } + ) self.api.queue_response(200, content=lsd) # Checking in silently does nothing. @@ -269,13 +301,17 @@ def test_checkin_cannot_return(self): # If the return link doesn't change the status, it still # silently ignores the problem. - lsd = json.dumps({ - "status": "ready", - "links": [{ - "rel": "return", - "href": "http://return", - }], - }) + lsd = json.dumps( + { + "status": "ready", + "links": [ + { + "rel": "return", + "href": "http://return", + } + ], + } + ) self.api.queue_response(200, content=lsd) self.api.queue_response(200, content="Deleted") @@ -291,19 +327,23 @@ def test_checkout_success(self): # A patron checks out the book successfully. loan_url = self._str - lsd = json.dumps({ - "status": "ready", - "potential_rights": { - "end": "3017-10-21T11:12:13Z" - }, - "links": [{ - "rel": "self", - "href": loan_url, - }], - }) + lsd = json.dumps( + { + "status": "ready", + "potential_rights": {"end": "3017-10-21T11:12:13Z"}, + "links": [ + { + "rel": "self", + "href": loan_url, + } + ], + } + ) self.api.queue_response(200, content=lsd) - loan = self.api.checkout(self.patron, "pin", self.pool, Representation.EPUB_MEDIA_TYPE) + loan = self.api.checkout( + self.patron, "pin", self.pool, Representation.EPUB_MEDIA_TYPE + ) assert self.collection == loan.collection(self._db) assert self.pool.data_source.name == loan.data_source_name assert self.pool.identifier.type == loan.identifier_type @@ -333,25 +373,31 @@ def test_checkout_success_with_hold(self): self.pool.licenses_reserved = 1 self.pool.patrons_in_hold_queue = 1 self.license.remaining_checkouts = 5 - self.pool.on_hold_to(self.patron, start=utc_now() - datetime.timedelta(days=1), position=0) + self.pool.on_hold_to( + self.patron, start=utc_now() - datetime.timedelta(days=1), position=0 + ) # The patron checks out the book. loan_url = self._str - lsd = json.dumps({ - "status": "ready", - "potential_rights": { - "end": "3017-10-21T11:12:13Z" - }, - "links": [{ - "rel": "self", - "href": loan_url, - }], - }) + lsd = json.dumps( + { + "status": "ready", + "potential_rights": {"end": "3017-10-21T11:12:13Z"}, + "links": [ + { + "rel": "self", + "href": loan_url, + } + ], + } + ) self.api.queue_response(200, content=lsd) # The patron gets a loan successfully. - loan = self.api.checkout(self.patron, "pin", self.pool, Representation.EPUB_MEDIA_TYPE) + loan = self.api.checkout( + self.patron, "pin", self.pool, Representation.EPUB_MEDIA_TYPE + ) assert self.collection == loan.collection(self._db) assert self.pool.data_source.name == loan.data_source_name assert self.pool.identifier.type == loan.identifier_type @@ -381,8 +427,12 @@ def test_checkout_already_checked_out(self): existing_loan.end = utc_now() + datetime.timedelta(days=3) pytest.raises( - AlreadyCheckedOut, self.api.checkout, - self.patron, "pin", self.pool, Representation.EPUB_MEDIA_TYPE, + AlreadyCheckedOut, + self.api.checkout, + self.patron, + "pin", + self.pool, + Representation.EPUB_MEDIA_TYPE, ) assert 1 == self._db.query(Loan).count() @@ -391,12 +441,18 @@ def test_checkout_expired_hold(self): # The patron was at the beginning of the hold queue, but the hold already expired. self.pool.licenses_owned = 1 yesterday = utc_now() - datetime.timedelta(days=1) - hold, ignore = self.pool.on_hold_to(self.patron, start=yesterday, end=yesterday, position=0) + hold, ignore = self.pool.on_hold_to( + self.patron, start=yesterday, end=yesterday, position=0 + ) other_hold, ignore = self.pool.on_hold_to(self._patron(), start=utc_now()) pytest.raises( - NoAvailableCopies, self.api.checkout, - self.patron, "pin", self.pool, Representation.EPUB_MEDIA_TYPE, + NoAvailableCopies, + self.api.checkout, + self.patron, + "pin", + self.pool, + Representation.EPUB_MEDIA_TYPE, ) def test_checkout_no_available_copies(self): @@ -406,8 +462,12 @@ def test_checkout_no_available_copies(self): existing_loan, ignore = self.license.loan_to(self._patron()) pytest.raises( - NoAvailableCopies, self.api.checkout, - self.patron, "pin", self.pool, Representation.EPUB_MEDIA_TYPE, + NoAvailableCopies, + self.api.checkout, + self.patron, + "pin", + self.pool, + Representation.EPUB_MEDIA_TYPE, ) assert 1 == self._db.query(Loan).count() @@ -419,11 +479,17 @@ def test_checkout_no_available_copies(self): last_week = now - datetime.timedelta(weeks=1) # A different patron has the only copy reserved. - other_patron_hold, ignore = self.pool.on_hold_to(self._patron(), position=0, start=last_week) + other_patron_hold, ignore = self.pool.on_hold_to( + self._patron(), position=0, start=last_week + ) pytest.raises( - NoAvailableCopies, self.api.checkout, - self.patron, "pin", self.pool, Representation.EPUB_MEDIA_TYPE, + NoAvailableCopies, + self.api.checkout, + self.patron, + "pin", + self.pool, + Representation.EPUB_MEDIA_TYPE, ) assert 0 == self._db.query(Loan).count() @@ -432,8 +498,12 @@ def test_checkout_no_available_copies(self): hold, ignore = self.pool.on_hold_to(self._patron(), position=1, start=yesterday) pytest.raises( - NoAvailableCopies, self.api.checkout, - self.patron, "pin", self.pool, Representation.EPUB_MEDIA_TYPE, + NoAvailableCopies, + self.api.checkout, + self.patron, + "pin", + self.pool, + Representation.EPUB_MEDIA_TYPE, ) assert 0 == self._db.query(Loan).count() @@ -443,20 +513,27 @@ def test_checkout_no_available_copies(self): hold.end = yesterday pytest.raises( - NoAvailableCopies, self.api.checkout, - self.patron, "pin", self.pool, Representation.EPUB_MEDIA_TYPE, + NoAvailableCopies, + self.api.checkout, + self.patron, + "pin", + self.pool, + Representation.EPUB_MEDIA_TYPE, ) assert 0 == self._db.query(Loan).count() - def test_checkout_no_licenses(self): self.pool.licenses_owned = 0 self.license.remaining_checkouts = 0 pytest.raises( - NoLicenses, self.api.checkout, - self.patron, "pin", self.pool, Representation.EPUB_MEDIA_TYPE, + NoLicenses, + self.api.checkout, + self.patron, + "pin", + self.pool, + Representation.EPUB_MEDIA_TYPE, ) assert 0 == self._db.query(Loan).count() @@ -469,8 +546,12 @@ def test_checkout_when_all_licenses_expired(self): self.license.expires = utc_now() - datetime.timedelta(weeks=1) pytest.raises( - NoLicenses, self.api.checkout, - self.patron, "pin", self.pool, Representation.EPUB_MEDIA_TYPE, + NoLicenses, + self.api.checkout, + self.patron, + "pin", + self.pool, + Representation.EPUB_MEDIA_TYPE, ) # license expired by no remaining checkouts @@ -480,35 +561,49 @@ def test_checkout_when_all_licenses_expired(self): self.license.expires = utc_now() + datetime.timedelta(weeks=1) pytest.raises( - NoLicenses, self.api.checkout, - self.patron, "pin", self.pool, Representation.EPUB_MEDIA_TYPE, + NoLicenses, + self.api.checkout, + self.patron, + "pin", + self.pool, + Representation.EPUB_MEDIA_TYPE, ) def test_checkout_cannot_loan(self): - lsd = json.dumps({ - "status": "revoked", - }) + lsd = json.dumps( + { + "status": "revoked", + } + ) self.api.queue_response(200, content=lsd) pytest.raises( - CannotLoan, self.api.checkout, - self.patron, "pin", self.pool, Representation.EPUB_MEDIA_TYPE, + CannotLoan, + self.api.checkout, + self.patron, + "pin", + self.pool, + Representation.EPUB_MEDIA_TYPE, ) assert 0 == self._db.query(Loan).count() # No external identifier. - lsd = json.dumps({ - "status": "ready", - "potential_rights": { - "end": "2017-10-21T11:12:13Z" - }, - }) + lsd = json.dumps( + { + "status": "ready", + "potential_rights": {"end": "2017-10-21T11:12:13Z"}, + } + ) self.api.queue_response(200, content=lsd) pytest.raises( - CannotLoan, self.api.checkout, - self.patron, "pin", self.pool, Representation.EPUB_MEDIA_TYPE, + CannotLoan, + self.api.checkout, + self.patron, + "pin", + self.pool, + Representation.EPUB_MEDIA_TYPE, ) assert 0 == self._db.query(Loan).count() @@ -519,20 +614,24 @@ def test_fulfill_success_license(self): loan.external_identifier = self._str loan.end = utc_now() + datetime.timedelta(days=3) - lsd = json.dumps({ - "status": "ready", - "potential_rights": { - "end": "2017-10-21T11:12:13Z" - }, - "links": [{ - "rel": "license", - "href": "http://acsm", - "type": DeliveryMechanism.ADOBE_DRM, - }], - }) + lsd = json.dumps( + { + "status": "ready", + "potential_rights": {"end": "2017-10-21T11:12:13Z"}, + "links": [ + { + "rel": "license", + "href": "http://acsm", + "type": DeliveryMechanism.ADOBE_DRM, + } + ], + } + ) self.api.queue_response(200, content=lsd) - fulfillment = self.api.fulfill(self.patron, "pin", self.pool, DeliveryMechanism.ADOBE_DRM) + fulfillment = self.api.fulfill( + self.patron, "pin", self.pool, DeliveryMechanism.ADOBE_DRM + ) assert self.collection == fulfillment.collection(self._db) assert self.pool.data_source.name == fulfillment.data_source_name assert self.pool.identifier.type == fulfillment.identifier_type @@ -550,17 +649,19 @@ def test_fulfill_success_manifest(self): audiobook = MediaTypes.AUDIOBOOK_MANIFEST_MEDIA_TYPE - lsd = json.dumps({ - "status": "ready", - "potential_rights": { - "end": "2017-10-21T11:12:13Z" - }, - "links": [{ - "rel": "manifest", - "href": "http://manifest", - "type": audiobook, - }], - }) + lsd = json.dumps( + { + "status": "ready", + "potential_rights": {"end": "2017-10-21T11:12:13Z"}, + "links": [ + { + "rel": "manifest", + "href": "http://manifest", + "type": audiobook, + } + ], + } + ) self.api.queue_response(200, content=lsd) fulfillment = self.api.fulfill(self.patron, "pin", self.pool, audiobook) @@ -572,7 +673,6 @@ def test_fulfill_success_manifest(self): assert "http://manifest" == fulfillment.content_link assert audiobook == fulfillment.content_type - def test_fulfill_cannot_fulfill(self): self.pool.licenses_owned = 7 self.pool.licenses_available = 6 @@ -581,14 +681,20 @@ def test_fulfill_cannot_fulfill(self): loan.external_identifier = self._str loan.end = utc_now() + datetime.timedelta(days=3) - lsd = json.dumps({ - "status": "revoked", - }) + lsd = json.dumps( + { + "status": "revoked", + } + ) self.api.queue_response(200, content=lsd) pytest.raises( - CannotFulfill, self.api.fulfill, - self.patron, "pin", self.pool, Representation.EPUB_MEDIA_TYPE, + CannotFulfill, + self.api.fulfill, + self.patron, + "pin", + self.pool, + Representation.EPUB_MEDIA_TYPE, ) # The pool's availability has been updated and the local @@ -625,7 +731,9 @@ def test_count_holds_before(self): assert 1 == self.api._count_holds_before(hold) for i in range(3): - self.pool.on_hold_to(self._patron(), start=yesterday, end=tomorrow, position=1) + self.pool.on_hold_to( + self._patron(), start=yesterday, end=tomorrow, position=1 + ) assert 4 == self.api._count_holds_before(hold) def test_update_hold_end_date(self): @@ -732,7 +840,12 @@ def test_update_hold_end_date(self): self.pool.loan_to(self._patron(), end=next_week + datetime.timedelta(days=2)) self.pool.licenses_reserved = 3 for i in range(3): - self.pool.on_hold_to(self._patron(), start=last_week + datetime.timedelta(days=i), end=next_week + datetime.timedelta(days=i), position=0) + self.pool.on_hold_to( + self._patron(), + start=last_week + datetime.timedelta(days=i), + end=next_week + datetime.timedelta(days=i), + position=0, + ) for i in range(5): self.pool.on_hold_to(self._patron(), start=yesterday) self.api._update_hold_end_date(hold) @@ -833,7 +946,9 @@ def test_update_hold_queue(self): # If there are holds, a license will get reserved for the next hold # and its end date will be set. hold, ignore = self.pool.on_hold_to(self.patron, start=utc_now(), position=1) - later_hold, ignore = self.pool.on_hold_to(self._patron(), start=utc_now() + datetime.timedelta(days=1), position=2) + later_hold, ignore = self.pool.on_hold_to( + self._patron(), start=utc_now() + datetime.timedelta(days=1), position=2 + ) self.api.update_hold_queue(self.pool) # The pool's licenses were updated. @@ -843,7 +958,9 @@ def test_update_hold_queue(self): # And the first hold changed. assert 0 == hold.position - assert hold.end - utc_now() - datetime.timedelta(days=3) < datetime.timedelta(hours=1) + assert hold.end - utc_now() - datetime.timedelta(days=3) < datetime.timedelta( + hours=1 + ) # The later hold is the same. assert 2 == later_hold.position @@ -858,7 +975,9 @@ def test_update_hold_queue(self): assert 2 == self.pool.licenses_reserved assert 2 == self.pool.patrons_in_hold_queue assert 0 == later_hold.position - assert later_hold.end - utc_now() - datetime.timedelta(days=3) < datetime.timedelta(hours=1) + assert later_hold.end - utc_now() - datetime.timedelta( + days=3 + ) < datetime.timedelta(hours=1) # Now there are no more holds. If we add another license, # it ends up being available. @@ -876,10 +995,16 @@ def test_update_hold_queue(self): loans = [] holds = [] for i in range(3): - loan, ignore = self.license.loan_to(self._patron(), end=utc_now() + datetime.timedelta(days=1)) + loan, ignore = self.license.loan_to( + self._patron(), end=utc_now() + datetime.timedelta(days=1) + ) loans.append(loan) for i in range(3): - hold, ignore = self.pool.on_hold_to(self._patron(), start=utc_now() - datetime.timedelta(days=3-i), position=i+1) + hold, ignore = self.pool.on_hold_to( + self._patron(), + start=utc_now() - datetime.timedelta(days=3 - i), + position=i + 1, + ) holds.append(hold) self.pool.licenses_owned = 5 self.pool.licenses_available = 0 @@ -892,8 +1017,12 @@ def test_update_hold_queue(self): assert 0 == holds[0].position assert 0 == holds[1].position assert 3 == holds[2].position - assert holds[0].end - utc_now() - datetime.timedelta(days=3) < datetime.timedelta(hours=1) - assert holds[1].end - utc_now() - datetime.timedelta(days=3) < datetime.timedelta(hours=1) + assert holds[0].end - utc_now() - datetime.timedelta( + days=3 + ) < datetime.timedelta(hours=1) + assert holds[1].end - utc_now() - datetime.timedelta( + days=3 + ) < datetime.timedelta(hours=1) # If there are more licenses that change than holds, some of them become available. loans[0].end = utc_now() - datetime.timedelta(days=1) @@ -904,14 +1033,18 @@ def test_update_hold_queue(self): assert 3 == self.pool.patrons_in_hold_queue for hold in holds: assert 0 == hold.position - assert hold.end - utc_now() - datetime.timedelta(days=3) < datetime.timedelta(hours=1) + assert hold.end - utc_now() - datetime.timedelta( + days=3 + ) < datetime.timedelta(hours=1) def test_place_hold_success(self): tomorrow = utc_now() + datetime.timedelta(days=1) self.pool.licenses_owned = 1 self.license.loan_to(self._patron(), end=tomorrow) - hold = self.api.place_hold(self.patron, "pin", self.pool, "notifications@librarysimplified.org") + hold = self.api.place_hold( + self.patron, "pin", self.pool, "notifications@librarysimplified.org" + ) assert 1 == self.pool.patrons_in_hold_queue assert self.collection == hold.collection(self._db) @@ -927,15 +1060,23 @@ def test_place_hold_success(self): def test_place_hold_already_on_hold(self): self.pool.on_hold_to(self.patron) pytest.raises( - AlreadyOnHold, self.api.place_hold, - self.patron, "pin", self.pool, "notifications@librarysimplified.org", + AlreadyOnHold, + self.api.place_hold, + self.patron, + "pin", + self.pool, + "notifications@librarysimplified.org", ) def test_place_hold_currently_available(self): self.pool.licenses_owned = 1 pytest.raises( - CurrentlyAvailable, self.api.place_hold, - self.patron, "pin", self.pool, "notifications@librarysimplified.org", + CurrentlyAvailable, + self.api.place_hold, + self.patron, + "pin", + self.pool, + "notifications@librarysimplified.org", ) def test_release_hold_success(self): @@ -970,8 +1111,11 @@ def test_release_hold_success(self): def test_release_hold_not_on_hold(self): pytest.raises( - NotOnHold, self.api.release_hold, - self.patron, "pin", self.pool, + NotOnHold, + self.api.release_hold, + self.patron, + "pin", + self.pool, ) def test_patron_activity(self): @@ -1030,7 +1174,9 @@ def test_patron_activity(self): # One hold. pool2.licenses_owned = 1 - other_patron_loan, ignore = license2.loan_to(self._patron(), end=utc_now() + datetime.timedelta(days=1)) + other_patron_loan, ignore = license2.loan_to( + self._patron(), end=utc_now() + datetime.timedelta(days=1) + ) hold, ignore = pool2.on_hold_to(self.patron) hold.start = utc_now() - datetime.timedelta(days=2) hold.end = hold.start + datetime.timedelta(days=3) @@ -1119,16 +1265,18 @@ def test_checkout_from_external_library(self): # An integration client checks out the book successfully. loan_url = self._str - lsd = json.dumps({ - "status": "ready", - "potential_rights": { - "end": "3017-10-21T11:12:13Z" - }, - "links": [{ - "rel": "self", - "href": loan_url, - }], - }) + lsd = json.dumps( + { + "status": "ready", + "potential_rights": {"end": "3017-10-21T11:12:13Z"}, + "links": [ + { + "rel": "self", + "href": loan_url, + } + ], + } + ) self.api.queue_response(200, content=lsd) loan = self.api.checkout_to_external_library(self.client, self.pool) @@ -1166,20 +1314,24 @@ def test_checkout_from_external_library_with_hold(self): self.pool.licenses_available = 0 self.pool.licenses_reserved = 1 self.pool.patrons_in_hold_queue = 1 - hold, ignore = self.pool.on_hold_to(self.client, start=utc_now() - datetime.timedelta(days=1), position=0) + hold, ignore = self.pool.on_hold_to( + self.client, start=utc_now() - datetime.timedelta(days=1), position=0 + ) # The patron checks out the book. loan_url = self._str - lsd = json.dumps({ - "status": "ready", - "potential_rights": { - "end": "3017-10-21T11:12:13Z" - }, - "links": [{ - "rel": "self", - "href": loan_url, - }], - }) + lsd = json.dumps( + { + "status": "ready", + "potential_rights": {"end": "3017-10-21T11:12:13Z"}, + "links": [ + { + "rel": "self", + "href": loan_url, + } + ], + } + ) self.api.queue_response(200, content=lsd) @@ -1209,16 +1361,22 @@ def test_checkin_from_external_library(self): loan.end = utc_now() + datetime.timedelta(days=3) # The patron returns the book successfully. - lsd = json.dumps({ - "status": "ready", - "links": [{ - "rel": "return", - "href": "http://return", - }], - }) - returned_lsd = json.dumps({ - "status": "returned", - }) + lsd = json.dumps( + { + "status": "ready", + "links": [ + { + "rel": "return", + "href": "http://return", + } + ], + } + ) + returned_lsd = json.dumps( + { + "status": "returned", + } + ) self.api.queue_response(200, content=lsd) self.api.queue_response(200) @@ -1239,17 +1397,19 @@ def test_fulfill_for_external_library(self): loan.external_identifier = self._str loan.end = utc_now() + datetime.timedelta(days=3) - lsd = json.dumps({ - "status": "ready", - "potential_rights": { - "end": "2017-10-21T11:12:13Z" - }, - "links": [{ - "rel": "license", - "href": "http://acsm", - "type": DeliveryMechanism.ADOBE_DRM, - }], - }) + lsd = json.dumps( + { + "status": "ready", + "potential_rights": {"end": "2017-10-21T11:12:13Z"}, + "links": [ + { + "rel": "license", + "href": "http://acsm", + "type": DeliveryMechanism.ADOBE_DRM, + } + ], + } + ) self.api.queue_response(200, content=lsd) fulfillment = self.api.fulfill_for_external_library(self.client, loan, None) @@ -1303,13 +1463,13 @@ def test_import(self): data_source = DataSource.lookup(self._db, "Feedbooks", autocreate=True) collection = MockODLAPI.mock_collection(self._db) collection.external_integration.set_setting( - Collection.DATA_SOURCE_NAME_SETTING, - data_source.name + Collection.DATA_SOURCE_NAME_SETTING, data_source.name ) class MockMetadataClient(object): def canonicalize_author_name(self, identifier, working_display_name): return working_display_name + metadata_client = MockMetadataClient() warrior_time_limited = dict(checkouts=dict(left=52, available=1)) @@ -1319,37 +1479,55 @@ def canonicalize_author_name(self, identifier, working_display_name): midnight_loan_limited_2 = dict(checkouts=dict(left=52, available=1)) everglades_loan = dict(checkouts=dict(left=10, available=5)) poetry_loan = dict(checkouts=dict(left=10, available=5)) - mock_responses = [json.dumps(r) for r in [ - warrior_time_limited, canadianity_loan_limited, canadianity_perpetual, - midnight_loan_limited_1, midnight_loan_limited_2, everglades_loan, poetry_loan - ]] + mock_responses = [ + json.dumps(r) + for r in [ + warrior_time_limited, + canadianity_loan_limited, + canadianity_perpetual, + midnight_loan_limited_1, + midnight_loan_limited_2, + everglades_loan, + poetry_loan, + ] + ] def do_get(url, headers): return 200, {}, mock_responses.pop(0) importer = ODLImporter( - self._db, collection=collection, + self._db, + collection=collection, metadata_client=metadata_client, http_get=do_get, ) - imported_editions, imported_pools, imported_works, failures = ( - importer.import_from_feed(feed) - ) + ( + imported_editions, + imported_pools, + imported_works, + failures, + ) = importer.import_from_feed(feed) self._db.commit() # This importer works the same as the base OPDSImporter, except that # it extracts format information from 'odl:license' tags and creates # LicensePoolDeliveryMechanisms. - # The importer created 6 editions, pools, and works. assert {} == failures assert 6 == len(imported_editions) assert 6 == len(imported_pools) assert 6 == len(imported_works) - [canadianity, everglades, dragons, warrior, blazing, midnight,] = sorted(imported_editions, key=lambda x: x.title) + [ + canadianity, + everglades, + dragons, + warrior, + blazing, + midnight, + ] = sorted(imported_editions, key=lambda x: x.title) assert "The Blazing World" == blazing.title assert "Sun Warrior" == warrior.title assert "Canadianity" == canadianity.title @@ -1358,132 +1536,185 @@ def do_get(url, headers): assert "Rise of the Dragons, Book 1" == dragons.title # This book is open access and has no applicable DRM - [blazing_pool] = [p for p in imported_pools if p.identifier == blazing.primary_identifier] + [blazing_pool] = [ + p for p in imported_pools if p.identifier == blazing.primary_identifier + ] assert True == blazing_pool.open_access [lpdm] = blazing_pool.delivery_mechanisms assert Representation.EPUB_MEDIA_TYPE == lpdm.delivery_mechanism.content_type assert DeliveryMechanism.NO_DRM == lpdm.delivery_mechanism.drm_scheme # # This book has a single 'odl:license' tag. - [warrior_pool] = [p for p in imported_pools if p.identifier == warrior.primary_identifier] + [warrior_pool] = [ + p for p in imported_pools if p.identifier == warrior.primary_identifier + ] assert False == warrior_pool.open_access [lpdm] = warrior_pool.delivery_mechanisms assert Edition.BOOK_MEDIUM == warrior_pool.presentation_edition.medium assert Representation.EPUB_MEDIA_TYPE == lpdm.delivery_mechanism.content_type assert DeliveryMechanism.ADOBE_DRM == lpdm.delivery_mechanism.drm_scheme assert RightsStatus.IN_COPYRIGHT == lpdm.rights_status.uri - assert 52 == warrior_pool.licenses_owned # 52 remaining checkouts in the License Info Document + assert ( + 52 == warrior_pool.licenses_owned + ) # 52 remaining checkouts in the License Info Document assert 1 == warrior_pool.licenses_available [license] = warrior_pool.licenses assert "1" == license.identifier - assert ("https://loan.feedbooks.net/loan/get/{?id,checkout_id,expires,patron_id,notification_url}" == - license.checkout_url) - assert ("https://license.feedbooks.net/license/status/?uuid=1" == - license.status_url) + assert ( + "https://loan.feedbooks.net/loan/get/{?id,checkout_id,expires,patron_id,notification_url}" + == license.checkout_url + ) + assert ( + "https://license.feedbooks.net/license/status/?uuid=1" == license.status_url + ) # The original value for 'expires' in the ODL is: # 2019-03-31T03:13:35+02:00 # # As stored in the database, license.expires may not have the # same tzinfo, but it does represent the same point in time. - assert datetime.datetime( - 2019, 3, 31, 3, 13, 35, tzinfo=dateutil.tz.tzoffset("", 3600*2) - ) == license.expires - assert 52 == license.remaining_checkouts # 52 remaining checkouts in the License Info Document + assert ( + datetime.datetime( + 2019, 3, 31, 3, 13, 35, tzinfo=dateutil.tz.tzoffset("", 3600 * 2) + ) + == license.expires + ) + assert ( + 52 == license.remaining_checkouts + ) # 52 remaining checkouts in the License Info Document assert 1 == license.concurrent_checkouts # This item is an open access audiobook. - [everglades_pool] = [p for p in imported_pools if p.identifier == everglades.primary_identifier] + [everglades_pool] = [ + p for p in imported_pools if p.identifier == everglades.primary_identifier + ] assert True == everglades_pool.open_access [lpdm] = everglades_pool.delivery_mechanisms assert Edition.AUDIO_MEDIUM == everglades_pool.presentation_edition.medium - assert Representation.AUDIOBOOK_MANIFEST_MEDIA_TYPE == lpdm.delivery_mechanism.content_type + assert ( + Representation.AUDIOBOOK_MANIFEST_MEDIA_TYPE + == lpdm.delivery_mechanism.content_type + ) assert DeliveryMechanism.NO_DRM == lpdm.delivery_mechanism.drm_scheme # This is a non-open access audiobook. There is no # tag; the drm_scheme is implied by the value # of . [dragons_pool] = [ - p for p in imported_pools - if p.identifier == dragons.primary_identifier + p for p in imported_pools if p.identifier == dragons.primary_identifier ] assert Edition.AUDIO_MEDIUM == dragons_pool.presentation_edition.medium assert False == dragons_pool.open_access [lpdm] = dragons_pool.delivery_mechanisms - assert Representation.AUDIOBOOK_MANIFEST_MEDIA_TYPE == lpdm.delivery_mechanism.content_type - assert DeliveryMechanism.FEEDBOOKS_AUDIOBOOK_DRM == lpdm.delivery_mechanism.drm_scheme + assert ( + Representation.AUDIOBOOK_MANIFEST_MEDIA_TYPE + == lpdm.delivery_mechanism.content_type + ) + assert ( + DeliveryMechanism.FEEDBOOKS_AUDIOBOOK_DRM + == lpdm.delivery_mechanism.drm_scheme + ) # This book has two 'odl:license' tags for the same format and drm scheme # (this happens if the library purchases two copies). - [canadianity_pool] = [p for p in imported_pools if p.identifier == canadianity.primary_identifier] + [canadianity_pool] = [ + p for p in imported_pools if p.identifier == canadianity.primary_identifier + ] assert False == canadianity_pool.open_access [lpdm] = canadianity_pool.delivery_mechanisms assert Representation.EPUB_MEDIA_TYPE == lpdm.delivery_mechanism.content_type assert DeliveryMechanism.ADOBE_DRM == lpdm.delivery_mechanism.drm_scheme assert RightsStatus.IN_COPYRIGHT == lpdm.rights_status.uri - assert 40 == canadianity_pool.licenses_owned # 40 remaining checkouts in the License Info Document + assert ( + 40 == canadianity_pool.licenses_owned + ) # 40 remaining checkouts in the License Info Document assert 11 == canadianity_pool.licenses_available - [license1, license2] = sorted(canadianity_pool.licenses, key=lambda x: x.identifier) + [license1, license2] = sorted( + canadianity_pool.licenses, key=lambda x: x.identifier + ) assert "2" == license1.identifier - assert ("https://loan.feedbooks.net/loan/get/{?id,checkout_id,expires,patron_id,notification_url}" == - license1.checkout_url) - assert ("https://license.feedbooks.net/license/status/?uuid=2" == - license1.status_url) + assert ( + "https://loan.feedbooks.net/loan/get/{?id,checkout_id,expires,patron_id,notification_url}" + == license1.checkout_url + ) + assert ( + "https://license.feedbooks.net/license/status/?uuid=2" + == license1.status_url + ) assert None == license1.expires assert 40 == license1.remaining_checkouts assert 10 == license1.concurrent_checkouts assert "3" == license2.identifier - assert ("https://loan.feedbooks.net/loan/get/{?id,checkout_id,expires,patron_id,notification_url}" == - license2.checkout_url) - assert ("https://license.feedbooks.net/license/status/?uuid=3" == - license2.status_url) + assert ( + "https://loan.feedbooks.net/loan/get/{?id,checkout_id,expires,patron_id,notification_url}" + == license2.checkout_url + ) + assert ( + "https://license.feedbooks.net/license/status/?uuid=3" + == license2.status_url + ) assert None == license2.expires assert None == license2.remaining_checkouts assert 1 == license2.concurrent_checkouts # This book has two 'odl:license' tags, and they have different formats. # TODO: the format+license association is not handled yet. - [midnight_pool] = [p for p in imported_pools if p.identifier == midnight.primary_identifier] + [midnight_pool] = [ + p for p in imported_pools if p.identifier == midnight.primary_identifier + ] assert False == midnight_pool.open_access lpdms = midnight_pool.delivery_mechanisms assert 2 == len(lpdms) - assert (set([Representation.EPUB_MEDIA_TYPE, Representation.PDF_MEDIA_TYPE]) == - set([lpdm.delivery_mechanism.content_type for lpdm in lpdms])) - assert ([DeliveryMechanism.ADOBE_DRM, DeliveryMechanism.ADOBE_DRM] == - [lpdm.delivery_mechanism.drm_scheme for lpdm in lpdms]) - assert ([RightsStatus.IN_COPYRIGHT, RightsStatus.IN_COPYRIGHT] == - [lpdm.rights_status.uri for lpdm in lpdms]) - assert 72 == midnight_pool.licenses_owned # 20 + 52 remaining checkouts in corresponding License Info Documents + assert set( + [Representation.EPUB_MEDIA_TYPE, Representation.PDF_MEDIA_TYPE] + ) == set([lpdm.delivery_mechanism.content_type for lpdm in lpdms]) + assert [DeliveryMechanism.ADOBE_DRM, DeliveryMechanism.ADOBE_DRM] == [ + lpdm.delivery_mechanism.drm_scheme for lpdm in lpdms + ] + assert [RightsStatus.IN_COPYRIGHT, RightsStatus.IN_COPYRIGHT] == [ + lpdm.rights_status.uri for lpdm in lpdms + ] + assert ( + 72 == midnight_pool.licenses_owned + ) # 20 + 52 remaining checkouts in corresponding License Info Documents assert 2 == midnight_pool.licenses_available - [license1, license2] = sorted(midnight_pool.licenses, key=lambda x: x.identifier) + [license1, license2] = sorted( + midnight_pool.licenses, key=lambda x: x.identifier + ) assert "4" == license1.identifier - assert ("https://loan.feedbooks.net/loan/get/{?id,checkout_id,expires,patron_id,notification_url}" == - license1.checkout_url) - assert ("https://license.feedbooks.net/license/status/?uuid=4" == - license1.status_url) + assert ( + "https://loan.feedbooks.net/loan/get/{?id,checkout_id,expires,patron_id,notification_url}" + == license1.checkout_url + ) + assert ( + "https://license.feedbooks.net/license/status/?uuid=4" + == license1.status_url + ) assert None == license1.expires assert 20 == license1.remaining_checkouts assert 1 == license1.concurrent_checkouts assert "5" == license2.identifier - assert ("https://loan.feedbooks.net/loan/get/{?id,checkout_id,expires,patron_id,notification_url}" == - license2.checkout_url) - assert ("https://license.feedbooks.net/license/status/?uuid=5" == - license2.status_url) + assert ( + "https://loan.feedbooks.net/loan/get/{?id,checkout_id,expires,patron_id,notification_url}" + == license2.checkout_url + ) + assert ( + "https://license.feedbooks.net/license/status/?uuid=5" + == license2.status_url + ) assert None == license2.expires assert 52 == license2.remaining_checkouts assert 1 == license2.concurrent_checkouts class TestODLHoldReaper(DatabaseTest, BaseODLTest): - def test_run_once(self): data_source = DataSource.lookup(self._db, "Feedbooks", autocreate=True) collection = MockODLAPI.mock_collection(self._db) collection.external_integration.set_setting( - Collection.DATA_SOURCE_NAME_SETTING, - data_source.name + Collection.DATA_SOURCE_NAME_SETTING, data_source.name ) api = MockODLAPI(self._db, collection) reaper = ODLHoldReaper(self._db, collection, api=api) @@ -1495,19 +1726,29 @@ def test_run_once(self): pool.licenses_owned = 3 pool.licenses_available = 0 pool.licenses_reserved = 3 - expired_hold1, ignore = pool.on_hold_to(self._patron(), end=yesterday, position=0) - expired_hold2, ignore = pool.on_hold_to(self._patron(), end=yesterday, position=0) - expired_hold3, ignore = pool.on_hold_to(self._patron(), end=yesterday, position=0) + expired_hold1, ignore = pool.on_hold_to( + self._patron(), end=yesterday, position=0 + ) + expired_hold2, ignore = pool.on_hold_to( + self._patron(), end=yesterday, position=0 + ) + expired_hold3, ignore = pool.on_hold_to( + self._patron(), end=yesterday, position=0 + ) current_hold, ignore = pool.on_hold_to(self._patron(), position=3) # This hold has an end date in the past, but its position is greater than 0 # so the end date is not reliable. - bad_end_date, ignore = pool.on_hold_to(self._patron(), end=yesterday, position=4) + bad_end_date, ignore = pool.on_hold_to( + self._patron(), end=yesterday, position=4 + ) progress = reaper.run_once(reaper.timestamp().to_data()) # The expired holds have been deleted and the other holds have been updated. assert 2 == self._db.query(Hold).count() - assert [current_hold, bad_end_date] == self._db.query(Hold).order_by(Hold.start).all() + assert [current_hold, bad_end_date] == self._db.query(Hold).order_by( + Hold.start + ).all() assert 0 == current_hold.position assert 0 == bad_end_date.position assert current_hold.end > now @@ -1516,7 +1757,7 @@ def test_run_once(self): assert 2 == pool.licenses_reserved # The TimestampData returned reflects what work was done. - assert 'Holds deleted: 3. License pools updated: 1' == progress.achievements + assert "Holds deleted: 3. License pools updated: 1" == progress.achievements # The TimestampData does not include any timing information -- # that will be applied by run(). @@ -1525,17 +1766,17 @@ def test_run_once(self): class TestSharedODLAPI(DatabaseTest, BaseODLTest): - def setup_method(self): super(TestSharedODLAPI, self).setup_method() self.collection = MockSharedODLAPI.mock_collection(self._db) self.collection.external_integration.set_setting( - Collection.DATA_SOURCE_NAME_SETTING, - "Feedbooks" + Collection.DATA_SOURCE_NAME_SETTING, "Feedbooks" ) self.api = MockSharedODLAPI(self._db, self.collection) self.pool = self._licensepool(None, collection=self.collection) - self.pool.identifier.add_link(Hyperlink.BORROW, self._str, self.collection.data_source) + self.pool.identifier.add_link( + Hyperlink.BORROW, self._str, self.collection.data_source + ) self.patron = self._patron() def test_get(self): @@ -1546,38 +1787,59 @@ def test_get(self): # The library has not registered with the remote collection yet. def do_get(url, headers=None, allowed_response_codes=None): raise Exception("do_get should not be called") - pytest.raises(LibraryAuthorizationFailedException, api._get, - "test url", patron=self.patron, do_get=do_get) + + pytest.raises( + LibraryAuthorizationFailedException, + api._get, + "test url", + patron=self.patron, + do_get=do_get, + ) # Once the library registers, it gets a shared secret that is included # in request headers. ConfigurationSetting.for_library_and_externalintegration( - self._db, ExternalIntegration.PASSWORD, self.patron.library, - self.collection.external_integration).value = "secret" + self._db, + ExternalIntegration.PASSWORD, + self.patron.library, + self.collection.external_integration, + ).value = "secret" + def do_get(url, headers=None, allowed_response_codes=None): assert "test url" == url assert "test header value" == headers.get("test_key") - assert "Bearer " + base64.b64encode("secret") == headers.get("Authorization") + assert "Bearer " + base64.b64encode("secret") == headers.get( + "Authorization" + ) assert ["200"] == allowed_response_codes - api._get("test url", headers=dict(test_key="test header value"), - patron=self.patron, allowed_response_codes=["200"], - do_get=do_get) + + api._get( + "test url", + headers=dict(test_key="test header value"), + patron=self.patron, + allowed_response_codes=["200"], + do_get=do_get, + ) def test_checkout_success(self): response = self.get_data("shared_collection_borrow_success.opds") self.api.queue_response(200, content=response) - loan = self.api.checkout(self.patron, "pin", self.pool, Representation.EPUB_MEDIA_TYPE) + loan = self.api.checkout( + self.patron, "pin", self.pool, Representation.EPUB_MEDIA_TYPE + ) assert self.collection == loan.collection(self._db) assert self.pool.data_source.name == loan.data_source_name assert self.pool.identifier.type == loan.identifier_type assert self.pool.identifier.identifier == loan.identifier assert datetime_utc(2018, 3, 8, 17, 41, 31) == loan.start_date assert datetime_utc(2018, 3, 29, 17, 41, 30) == loan.end_date - assert "http://localhost:6500/AL/collections/DPLA%20Exchange/loans/31" == loan.external_identifier + assert ( + "http://localhost:6500/AL/collections/DPLA%20Exchange/loans/31" + == loan.external_identifier + ) - assert ([self.pool.identifier.links[0].resource.url] == - self.api.requests) + assert [self.pool.identifier.links[0].resource.url] == self.api.requests def test_checkout_from_hold(self): hold, ignore = self.pool.on_hold_to(self.patron, external_identifier=self._str) @@ -1586,31 +1848,48 @@ def test_checkout_from_hold(self): borrow_response = self.get_data("shared_collection_borrow_success.opds") self.api.queue_response(200, content=borrow_response) - loan = self.api.checkout(self.patron, "pin", self.pool, Representation.EPUB_MEDIA_TYPE) + loan = self.api.checkout( + self.patron, "pin", self.pool, Representation.EPUB_MEDIA_TYPE + ) assert self.collection == loan.collection(self._db) assert self.pool.data_source.name == loan.data_source_name assert self.pool.identifier.type == loan.identifier_type assert self.pool.identifier.identifier == loan.identifier assert datetime_utc(2018, 3, 8, 17, 41, 31) == loan.start_date assert datetime_utc(2018, 3, 29, 17, 41, 30) == loan.end_date - assert "http://localhost:6500/AL/collections/DPLA%20Exchange/loans/31" == loan.external_identifier + assert ( + "http://localhost:6500/AL/collections/DPLA%20Exchange/loans/31" + == loan.external_identifier + ) - assert ([hold.external_identifier, - "http://localhost:6500/AL/collections/DPLA%20Exchange/holds/17/borrow"] == - self.api.requests) + assert [ + hold.external_identifier, + "http://localhost:6500/AL/collections/DPLA%20Exchange/holds/17/borrow", + ] == self.api.requests def test_checkout_already_checked_out(self): loan, ignore = self.pool.loan_to(self.patron) - pytest.raises(AlreadyCheckedOut, self.api.checkout, self.patron, "pin", - self.pool, Representation.EPUB_MEDIA_TYPE) + pytest.raises( + AlreadyCheckedOut, + self.api.checkout, + self.patron, + "pin", + self.pool, + Representation.EPUB_MEDIA_TYPE, + ) assert [] == self.api.requests def test_checkout_no_available_copies(self): self.api.queue_response(403) - pytest.raises(NoAvailableCopies, self.api.checkout, self.patron, "pin", - self.pool, Representation.EPUB_MEDIA_TYPE) - assert ([self.pool.identifier.links[0].resource.url] == - self.api.requests) + pytest.raises( + NoAvailableCopies, + self.api.checkout, + self.patron, + "pin", + self.pool, + Representation.EPUB_MEDIA_TYPE, + ) + assert [self.pool.identifier.links[0].resource.url] == self.api.requests def test_checkout_no_licenses(self): self.api.queue_response( @@ -1618,30 +1897,52 @@ def test_checkout_no_licenses(self): headers=NO_LICENSES.response[2], content=NO_LICENSES.response[0], ) - pytest.raises(NoLicenses, self.api.checkout, self.patron, "pin", - self.pool, Representation.EPUB_MEDIA_TYPE) - assert ([self.pool.identifier.links[0].resource.url] == - self.api.requests) + pytest.raises( + NoLicenses, + self.api.checkout, + self.patron, + "pin", + self.pool, + Representation.EPUB_MEDIA_TYPE, + ) + assert [self.pool.identifier.links[0].resource.url] == self.api.requests def test_checkout_from_hold_not_available(self): hold, ignore = self.pool.on_hold_to(self.patron) hold_info_response = self.get_data("shared_collection_hold_info_reserved.opds") self.api.queue_response(200, content=hold_info_response) - pytest.raises(NoAvailableCopies, self.api.checkout, self.patron, "pin", - self.pool, Representation.EPUB_MEDIA_TYPE) + pytest.raises( + NoAvailableCopies, + self.api.checkout, + self.patron, + "pin", + self.pool, + Representation.EPUB_MEDIA_TYPE, + ) assert [hold.external_identifier] == self.api.requests def test_checkout_cannot_loan(self): self.api.queue_response(500) - pytest.raises(CannotLoan, self.api.checkout, self.patron, "pin", - self.pool, Representation.EPUB_MEDIA_TYPE) - assert ([self.pool.identifier.links[0].resource.url] == - self.api.requests) + pytest.raises( + CannotLoan, + self.api.checkout, + self.patron, + "pin", + self.pool, + Representation.EPUB_MEDIA_TYPE, + ) + assert [self.pool.identifier.links[0].resource.url] == self.api.requests # This pool has no borrow link. pool = self._licensepool(None, collection=self.collection) - pytest.raises(CannotLoan, self.api.checkout, self.patron, "pin", - pool, Representation.EPUB_MEDIA_TYPE) + pytest.raises( + CannotLoan, + self.api.checkout, + self.patron, + "pin", + pool, + Representation.EPUB_MEDIA_TYPE, + ) def test_checkin_success(self): loan, ignore = self.pool.loan_to(self.patron, external_identifier=self._str) @@ -1650,9 +1951,10 @@ def test_checkin_success(self): self.api.queue_response(200, content="Deleted") response = self.api.checkin(self.patron, "pin", self.pool) assert True == response - assert ([loan.external_identifier, - "http://localhost:6500/AL/collections/DPLA%20Exchange/loans/33/revoke"] == - self.api.requests) + assert [ + loan.external_identifier, + "http://localhost:6500/AL/collections/DPLA%20Exchange/loans/33/revoke", + ] == self.api.requests def test_checkin_not_checked_out(self): pytest.raises(NotCheckedOut, self.api.checkin, self.patron, "pin", self.pool) @@ -1669,21 +1971,23 @@ def test_checkin_cannot_return(self): pytest.raises(CannotReturn, self.api.checkin, self.patron, "pin", self.pool) assert [loan.external_identifier] == self.api.requests - loan_info_response = self.get_data("shared_collection_loan_info.opds") self.api.queue_response(200, content=loan_info_response) self.api.queue_response(500) pytest.raises(CannotReturn, self.api.checkin, self.patron, "pin", self.pool) - assert ([loan.external_identifier, - "http://localhost:6500/AL/collections/DPLA%20Exchange/loans/33/revoke"] == - self.api.requests[1:]) + assert [ + loan.external_identifier, + "http://localhost:6500/AL/collections/DPLA%20Exchange/loans/33/revoke", + ] == self.api.requests[1:] def test_fulfill_success(self): loan, ignore = self.pool.loan_to(self.patron, external_identifier=self._str) loan_info_response = self.get_data("shared_collection_loan_info.opds") self.api.queue_response(200, content=loan_info_response) self.api.queue_response(200, content="An ACSM file") - fulfillment = self.api.fulfill(self.patron, "pin", self.pool, self.pool.delivery_mechanisms[0]) + fulfillment = self.api.fulfill( + self.patron, "pin", self.pool, self.pool.delivery_mechanisms[0] + ) assert self.collection == fulfillment.collection(self._db) assert self.pool.data_source.name == fulfillment.data_source_name assert self.pool.identifier.type == fulfillment.identifier_type @@ -1692,54 +1996,94 @@ def test_fulfill_success(self): assert b"An ACSM file" == fulfillment.content assert datetime_utc(2018, 3, 29, 17, 44, 11) == fulfillment.content_expires - assert ([loan.external_identifier, - "http://localhost:6500/AL/collections/DPLA%20Exchange/loans/33/fulfill/2"] == - self.api.requests) + assert [ + loan.external_identifier, + "http://localhost:6500/AL/collections/DPLA%20Exchange/loans/33/fulfill/2", + ] == self.api.requests def test_fulfill_not_checked_out(self): - pytest.raises(NotCheckedOut, self.api.fulfill, self.patron, "pin", - self.pool, self.pool.delivery_mechanisms[0]) + pytest.raises( + NotCheckedOut, + self.api.fulfill, + self.patron, + "pin", + self.pool, + self.pool.delivery_mechanisms[0], + ) assert [] == self.api.requests loan, ignore = self.pool.loan_to(self.patron, external_identifier=self._str) self.api.queue_response(404) - pytest.raises(NotCheckedOut, self.api.fulfill, self.patron, "pin", - self.pool, self.pool.delivery_mechanisms[0]) + pytest.raises( + NotCheckedOut, + self.api.fulfill, + self.patron, + "pin", + self.pool, + self.pool.delivery_mechanisms[0], + ) assert [loan.external_identifier] == self.api.requests def test_fulfill_cannot_fulfill(self): loan, ignore = self.pool.loan_to(self.patron, external_identifier=self._str) self.api.queue_response(500) - pytest.raises(CannotFulfill, self.api.fulfill, self.patron, "pin", - self.pool, self.pool.delivery_mechanisms[0]) + pytest.raises( + CannotFulfill, + self.api.fulfill, + self.patron, + "pin", + self.pool, + self.pool.delivery_mechanisms[0], + ) assert [loan.external_identifier] == self.api.requests self.api.queue_response(200, content="not opds") - pytest.raises(CannotFulfill, self.api.fulfill, self.patron, "pin", - self.pool, self.pool.delivery_mechanisms[0]) + pytest.raises( + CannotFulfill, + self.api.fulfill, + self.patron, + "pin", + self.pool, + self.pool.delivery_mechanisms[0], + ) assert [loan.external_identifier] == self.api.requests[1:] loan_info_response = self.get_data("shared_collection_loan_info.opds") self.api.queue_response(200, content=loan_info_response) self.api.queue_response(500) - pytest.raises(CannotFulfill, self.api.fulfill, self.patron, "pin", - self.pool, self.pool.delivery_mechanisms[0]) - assert ([loan.external_identifier, - "http://localhost:6500/AL/collections/DPLA%20Exchange/loans/33/fulfill/2"] == - self.api.requests[2:]) + pytest.raises( + CannotFulfill, + self.api.fulfill, + self.patron, + "pin", + self.pool, + self.pool.delivery_mechanisms[0], + ) + assert [ + loan.external_identifier, + "http://localhost:6500/AL/collections/DPLA%20Exchange/loans/33/fulfill/2", + ] == self.api.requests[2:] def test_fulfill_format_not_available(self): loan, ignore = self.pool.loan_to(self.patron) loan_info_response = self.get_data("shared_collection_loan_info_no_epub.opds") self.api.queue_response(200, content=loan_info_response) - pytest.raises(FormatNotAvailable, self.api.fulfill, self.patron, "pin", - self.pool, self.pool.delivery_mechanisms[0]) + pytest.raises( + FormatNotAvailable, + self.api.fulfill, + self.patron, + "pin", + self.pool, + self.pool.delivery_mechanisms[0], + ) assert [loan.external_identifier] == self.api.requests def test_place_hold_success(self): hold_response = self.get_data("shared_collection_hold_info_reserved.opds") self.api.queue_response(200, content=hold_response) - hold = self.api.place_hold(self.patron, "pin", self.pool, "notifications@librarysimplified.org") + hold = self.api.place_hold( + self.patron, "pin", self.pool, "notifications@librarysimplified.org" + ) assert self.collection == hold.collection(self._db) assert self.pool.data_source.name == hold.data_source_name assert self.pool.identifier.type == hold.identifier_type @@ -1747,15 +2091,23 @@ def test_place_hold_success(self): assert datetime_utc(2018, 3, 8, 18, 50, 18) == hold.start_date assert datetime_utc(2018, 3, 29, 17, 44, 1) == hold.end_date assert 1 == hold.hold_position - assert "http://localhost:6500/AL/collections/DPLA%20Exchange/holds/18" == hold.external_identifier + assert ( + "http://localhost:6500/AL/collections/DPLA%20Exchange/holds/18" + == hold.external_identifier + ) - assert ([self.pool.identifier.links[0].resource.url] == - self.api.requests) + assert [self.pool.identifier.links[0].resource.url] == self.api.requests def test_place_hold_already_checked_out(self): loan, ignore = self.pool.loan_to(self.patron) - pytest.raises(AlreadyCheckedOut, self.api.place_hold, self.patron, "pin", - self.pool, "notification@librarysimplified.org") + pytest.raises( + AlreadyCheckedOut, + self.api.place_hold, + self.patron, + "pin", + self.pool, + "notification@librarysimplified.org", + ) assert [] == self.api.requests def test_release_hold_success(self): @@ -1765,9 +2117,10 @@ def test_release_hold_success(self): self.api.queue_response(200, content="Deleted") response = self.api.release_hold(self.patron, "pin", self.pool) assert True == response - assert ([hold.external_identifier, - "http://localhost:6500/AL/collections/DPLA%20Exchange/holds/18/revoke"] == - self.api.requests) + assert [ + hold.external_identifier, + "http://localhost:6500/AL/collections/DPLA%20Exchange/holds/18/revoke", + ] == self.api.requests def test_release_hold_not_on_hold(self): pytest.raises(NotOnHold, self.api.release_hold, self.patron, "pin", self.pool) @@ -1781,16 +2134,21 @@ def test_release_hold_not_on_hold(self): def test_release_hold_cannot_release_hold(self): hold, ignore = self.pool.on_hold_to(self.patron, external_identifier=self._str) self.api.queue_response(500) - pytest.raises(CannotReleaseHold, self.api.release_hold, self.patron, "pin", self.pool) + pytest.raises( + CannotReleaseHold, self.api.release_hold, self.patron, "pin", self.pool + ) assert [hold.external_identifier] == self.api.requests hold_response = self.get_data("shared_collection_hold_info_reserved.opds") self.api.queue_response(200, content=hold_response) self.api.queue_response(500) - pytest.raises(CannotReleaseHold, self.api.release_hold, self.patron, "pin", self.pool) - assert ([hold.external_identifier, - "http://localhost:6500/AL/collections/DPLA%20Exchange/holds/18/revoke"] == - self.api.requests[1:]) + pytest.raises( + CannotReleaseHold, self.api.release_hold, self.patron, "pin", self.pool + ) + assert [ + hold.external_identifier, + "http://localhost:6500/AL/collections/DPLA%20Exchange/holds/18/revoke", + ] == self.api.requests[1:] def test_patron_activity_success(self): # The patron has one loan, and the remote circ manager returns it. @@ -1844,46 +2202,61 @@ def test_patron_activity_success(self): def test_patron_activity_remote_integration_exception(self): loan, ignore = self.pool.loan_to(self.patron, external_identifier=self._str) self.api.queue_response(500) - pytest.raises(RemoteIntegrationException, self.api.patron_activity, self.patron, "pin") + pytest.raises( + RemoteIntegrationException, self.api.patron_activity, self.patron, "pin" + ) assert [loan.external_identifier] == self.api.requests self._db.delete(loan) hold, ignore = self.pool.on_hold_to(self.patron, external_identifier=self._str) self.api.queue_response(500) - pytest.raises(RemoteIntegrationException, self.api.patron_activity, self.patron, "pin") + pytest.raises( + RemoteIntegrationException, self.api.patron_activity, self.patron, "pin" + ) assert [hold.external_identifier] == self.api.requests[1:] class TestSharedODLImporter(DatabaseTest, BaseODLTest): - def test_get_fulfill_url(self): entry = self.get_data("shared_collection_loan_info.opds") - assert ("http://localhost:6500/AL/collections/DPLA%20Exchange/loans/33/fulfill/2" == - SharedODLImporter.get_fulfill_url(entry, "application/epub+zip", "application/vnd.adobe.adept+xml")) - assert None == SharedODLImporter.get_fulfill_url(entry, "application/pdf", "application/vnd.adobe.adept+xml") - assert None == SharedODLImporter.get_fulfill_url(entry, "application/epub+zip", None) + assert ( + "http://localhost:6500/AL/collections/DPLA%20Exchange/loans/33/fulfill/2" + == SharedODLImporter.get_fulfill_url( + entry, "application/epub+zip", "application/vnd.adobe.adept+xml" + ) + ) + assert None == SharedODLImporter.get_fulfill_url( + entry, "application/pdf", "application/vnd.adobe.adept+xml" + ) + assert None == SharedODLImporter.get_fulfill_url( + entry, "application/epub+zip", None + ) def test_import(self): feed = self.get_data("shared_collection_feed.opds") data_source = DataSource.lookup(self._db, "DPLA Exchange", autocreate=True) collection = MockSharedODLAPI.mock_collection(self._db) collection.external_integration.set_setting( - Collection.DATA_SOURCE_NAME_SETTING, - data_source.name + Collection.DATA_SOURCE_NAME_SETTING, data_source.name ) class MockMetadataClient(object): def canonicalize_author_name(self, identifier, working_display_name): return working_display_name + metadata_client = MockMetadataClient() importer = SharedODLImporter( - self._db, collection=collection, + self._db, + collection=collection, metadata_client=metadata_client, ) - imported_editions, imported_pools, imported_works, failures = ( - importer.import_from_feed(feed) - ) + ( + imported_editions, + imported_pools, + imported_works, + failures, + ) = importer.import_from_feed(feed) # This importer works the same as the base OPDSImporter, except that # it extracts license pool information from acquisition links. @@ -1899,18 +2272,24 @@ def canonicalize_author_name(self, identifier, working_display_name): assert "The Great Gatsby" == gatsby.title # This book is open access. - [gatsby_pool] = [p for p in imported_pools if p.identifier == gatsby.primary_identifier] + [gatsby_pool] = [ + p for p in imported_pools if p.identifier == gatsby.primary_identifier + ] assert True == gatsby_pool.open_access # This pool has two delivery mechanisms, from a borrow link and an open-access link. # Both are DRM-free epubs. lpdms = gatsby_pool.delivery_mechanisms assert 2 == len(lpdms) for lpdm in lpdms: - assert Representation.EPUB_MEDIA_TYPE == lpdm.delivery_mechanism.content_type + assert ( + Representation.EPUB_MEDIA_TYPE == lpdm.delivery_mechanism.content_type + ) assert DeliveryMechanism.NO_DRM == lpdm.delivery_mechanism.drm_scheme # This book is already checked out and has a hold. - [six_months_pool] = [p for p in imported_pools if p.identifier == six_months.primary_identifier] + [six_months_pool] = [ + p for p in imported_pools if p.identifier == six_months.primary_identifier + ] assert False == six_months_pool.open_access assert 1 == six_months_pool.licenses_owned assert 0 == six_months_pool.licenses_available @@ -1919,11 +2298,18 @@ def canonicalize_author_name(self, identifier, working_display_name): assert Representation.EPUB_MEDIA_TYPE == lpdm.delivery_mechanism.content_type assert DeliveryMechanism.ADOBE_DRM == lpdm.delivery_mechanism.drm_scheme assert RightsStatus.IN_COPYRIGHT == lpdm.rights_status.uri - [borrow_link] = [l for l in six_months_pool.identifier.links if l.rel == Hyperlink.BORROW] - assert 'http://localhost:6500/AL/works/URI/http://www.feedbooks.com/item/2493650/borrow' == borrow_link.resource.url + [borrow_link] = [ + l for l in six_months_pool.identifier.links if l.rel == Hyperlink.BORROW + ] + assert ( + "http://localhost:6500/AL/works/URI/http://www.feedbooks.com/item/2493650/borrow" + == borrow_link.resource.url + ) # This book is currently available. - [essex_pool] = [p for p in imported_pools if p.identifier == essex.primary_identifier] + [essex_pool] = [ + p for p in imported_pools if p.identifier == essex.primary_identifier + ] assert False == essex_pool.open_access assert 4 == essex_pool.licenses_owned assert 4 == essex_pool.licenses_available @@ -2071,7 +2457,8 @@ def _get_test_feed(self, licenses: List[TestLicense]) -> str: :return: Test ODL feed """ env = Environment( - loader=FileSystemLoader(self.ODL_TEMPLATE_DIR), autoescape=select_autoescape() + loader=FileSystemLoader(self.ODL_TEMPLATE_DIR), + autoescape=select_autoescape(), ) template = env.get_template(self.ODL_TEMPLATE_FILENAME) feed = template.render(licenses=licenses) @@ -2098,7 +2485,8 @@ def _import_test_feed( ) license_status_response = MagicMock( side_effect=[ - (200, {}, str(license_status) if license_status else "{}") for license_status in license_infos + (200, {}, str(license_status) if license_status else "{}") + for license_status in license_infos ] if license_infos else [(200, {}, {})] @@ -2118,29 +2506,32 @@ def _import_test_feed( class TestODLExpiredItemsReaperSingleLicense(TestODLExpiredItemsReaper): """Class testing that the ODL 1.x reaper correctly processes publications with a single license.""" - @parameterized.expand([ - ( - "expiration_date_in_the_past", - # The license expires 2021-01-01T00:01:00+01:00 that equals to 2010-01-01T00:00:00+00:00, the current time. - # It means the license had already expired at the time of the import. - TestLicense(expires=dateutil.parser.isoparse("2021-01-01T00:01:00+01:00")) - ), - ( - "total_checkouts_is_zero", - TestLicense(total_checkouts=0) - ), - ( - "remaining_checkouts_is_zero", - TestLicense(total_checkouts=10, concurrent_checkouts=5), - TestLicenseInfo(remaining_checkouts=0, available_concurrent_checkouts=0) - ) - ]) + @parameterized.expand( + [ + ( + "expiration_date_in_the_past", + # The license expires 2021-01-01T00:01:00+01:00 that equals to 2010-01-01T00:00:00+00:00, the current time. + # It means the license had already expired at the time of the import. + TestLicense( + expires=dateutil.parser.isoparse("2021-01-01T00:01:00+01:00") + ), + ), + ("total_checkouts_is_zero", TestLicense(total_checkouts=0)), + ( + "remaining_checkouts_is_zero", + TestLicense(total_checkouts=10, concurrent_checkouts=5), + TestLicenseInfo( + remaining_checkouts=0, available_concurrent_checkouts=0 + ), + ), + ] + ) @freeze_time("2021-01-01T00:00:00+00:00") def test_odl_importer_skips_expired_licenses( self, _, test_license: TestLicense, - test_license_info: Optional[TestLicenseInfo] = None + test_license_info: Optional[TestLicenseInfo] = None, ) -> None: """Ensure ODLImporter skips expired licenses and does not count them in the total number of available licenses. @@ -2151,8 +2542,7 @@ def test_odl_importer_skips_expired_licenses( """ # 1.1. Import the test feed with an expired ODL license. imported_editions, imported_pools, imported_works = self._import_test_feed( - [test_license], - [test_license_info] + [test_license], [test_license_info] ) # Commit to expire the SQLAlchemy cache. @@ -2228,7 +2618,9 @@ def test_odl_reaper_removes_expired_licenses(self): assert imported_pool.licenses_available == available_concurrent_checkouts # 4. Expire the license. - with patch("core.model.License.is_expired", new_callable=PropertyMock) as is_expired: + with patch( + "core.model.License.is_expired", new_callable=PropertyMock + ) as is_expired: is_expired.return_value = True # 5.1. Run ODLExpiredItemsReaper again. This time it should remove the expired license. @@ -2265,27 +2657,27 @@ def test_odl_importer_skips_expired_licenses(self): available_concurrent_checkouts = 5 imported_editions, imported_pools, imported_works = self._import_test_feed( [ - TestLicense( # Expired - total_checkouts=10, # (expiry date in the past) + TestLicense( # Expired + total_checkouts=10, # (expiry date in the past) concurrent_checkouts=5, expires=datetime_helpers.utc_now() - datetime.timedelta(days=1), ), - TestLicense( # Expired - total_checkouts=0, # (total_checkouts is 0) + TestLicense( # Expired + total_checkouts=0, # (total_checkouts is 0) concurrent_checkouts=0, expires=datetime_helpers.utc_now() + datetime.timedelta(days=1), ), - TestLicense( # Expired - total_checkouts=10, # (remaining_checkout is 0) + TestLicense( # Expired + total_checkouts=10, # (remaining_checkout is 0) concurrent_checkouts=5, expires=datetime_helpers.utc_now() + datetime.timedelta(days=1), ), - TestLicense( # Valid + TestLicense( # Valid total_checkouts=10, concurrent_checkouts=5, expires=datetime_helpers.utc_now() + datetime.timedelta(days=2), ), - TestLicense( # Valid + TestLicense( # Valid total_checkouts=10, concurrent_checkouts=5, expires=datetime_helpers.utc_now() + datetime.timedelta(weeks=12), @@ -2293,17 +2685,16 @@ def test_odl_importer_skips_expired_licenses(self): ], [ TestLicenseInfo( - remaining_checkouts=0, - available_concurrent_checkouts=0 + remaining_checkouts=0, available_concurrent_checkouts=0 ), TestLicenseInfo( remaining_checkouts=remaining_checkouts, - available_concurrent_checkouts=available_concurrent_checkouts + available_concurrent_checkouts=available_concurrent_checkouts, ), TestLicenseInfo( remaining_checkouts=remaining_checkouts, - available_concurrent_checkouts=available_concurrent_checkouts - ) + available_concurrent_checkouts=available_concurrent_checkouts, + ), ], ) @@ -2349,16 +2740,16 @@ def test_odl_reaper_removes_expired_licenses(self): [ TestLicenseInfo( remaining_checkouts=total_checkouts, - available_concurrent_checkouts=available_concurrent_checkouts + available_concurrent_checkouts=available_concurrent_checkouts, ), TestLicenseInfo( remaining_checkouts=remaining_checkouts, - available_concurrent_checkouts=available_concurrent_checkouts + available_concurrent_checkouts=available_concurrent_checkouts, ), TestLicenseInfo( remaining_checkouts=remaining_checkouts, - available_concurrent_checkouts=available_concurrent_checkouts - ) + available_concurrent_checkouts=available_concurrent_checkouts, + ), ], ) diff --git a/tests/test_odl2.py b/tests/test_odl2.py index cadf2c0d88..c3563b820a 100644 --- a/tests/test_odl2.py +++ b/tests/test_odl2.py @@ -3,8 +3,6 @@ import os import requests_mock -from api.odl import ODLImporter -from api.odl2 import ODL2API, ODL2APIConfiguration, ODL2ExpiredItemsReaper, ODL2Importer from freezegun import freeze_time from mock import MagicMock from webpub_manifest_parser.core.ast import PresentationMetadata @@ -14,6 +12,8 @@ ODL_PUBLICATION_MUST_CONTAIN_EITHER_LICENSES_OR_OA_ACQUISITION_LINK_ERROR, ) +from api.odl import ODLImporter +from api.odl2 import ODL2API, ODL2APIConfiguration, ODL2ExpiredItemsReaper, ODL2Importer from core.coverage import CoverageFailure from core.model import ( Contribution, @@ -223,8 +223,10 @@ def test(self): == moby_dick_license.checkout_url ) assert "http://www.example.com/status/294024" == moby_dick_license.status_url - assert datetime.datetime(2016, 4, 25, 10, 25, 21, tzinfo=datetime.timezone.utc) \ - == moby_dick_license.expires + assert ( + datetime.datetime(2016, 4, 25, 10, 25, 21, tzinfo=datetime.timezone.utc) + == moby_dick_license.expires + ) assert 10 == moby_dick_license.remaining_checkouts assert 10 == moby_dick_license.concurrent_checkouts @@ -249,11 +251,13 @@ def test(self): assert isinstance(huck_finn_failure, CoverageFailure) assert "9781234567897" == huck_finn_failure.obj.identifier - huck_finn_semantic_error = ODL_PUBLICATION_MUST_CONTAIN_EITHER_LICENSES_OR_OA_ACQUISITION_LINK_ERROR( - node=ODLPublication( - metadata=PresentationMetadata(identifier="urn:isbn:9781234567897") - ), - node_property=None, + huck_finn_semantic_error = ( + ODL_PUBLICATION_MUST_CONTAIN_EITHER_LICENSES_OR_OA_ACQUISITION_LINK_ERROR( + node=ODLPublication( + metadata=PresentationMetadata(identifier="urn:isbn:9781234567897") + ), + node_property=None, + ) ) assert str(huck_finn_semantic_error) == huck_finn_failure.exception @@ -262,7 +266,9 @@ class TestODL2ExpiredItemsReaper(TestODLExpiredItemsReaper): """Base class for all ODL 2.x reaper tests.""" ODL_PROTOCOL = ODL2API.NAME - ODL_TEMPLATE_DIR = os.path.join(TestODLExpiredItemsReaper.base_path, "files", "odl2") + ODL_TEMPLATE_DIR = os.path.join( + TestODLExpiredItemsReaper.base_path, "files", "odl2" + ) ODL_TEMPLATE_FILENAME = "feed_template.json.jinja" ODL_REAPER_CLASS = ODL2ExpiredItemsReaper @@ -290,9 +296,13 @@ def _create_importer(self, collection, http_get): return importer -class TestODL2ExpiredItemsReaperSingleLicense(TestODL2ExpiredItemsReaper, TestODLExpiredItemsReaperSingleLicense): +class TestODL2ExpiredItemsReaperSingleLicense( + TestODL2ExpiredItemsReaper, TestODLExpiredItemsReaperSingleLicense +): """Class testing that the ODL 2.x reaper correctly processes publications with a single license.""" -class TestODL2ExpiredItemsReaperMultipleLicense(TestODL2ExpiredItemsReaper, TestODLExpiredItemsReaperMultipleLicense): +class TestODL2ExpiredItemsReaperMultipleLicense( + TestODL2ExpiredItemsReaper, TestODLExpiredItemsReaperMultipleLicense +): """Class testing that the ODL 2.x reaper correctly processes publications with multiple licenses.""" diff --git a/tests/test_onix.py b/tests/test_onix.py index 2faa94d2fa..a39c7fc7fc 100644 --- a/tests/test_onix.py +++ b/tests/test_onix.py @@ -5,17 +5,13 @@ from api.onix import ONIXExtractor from core.classifier import Classifier from core.metadata_layer import CirculationData -from core.model import ( - Classification, - Edition, - Identifier, - LicensePool) +from core.model import Classification, Edition, Identifier, LicensePool from core.util.datetime_helpers import datetime_utc + from . import sample_data class TestONIXExtractor(object): - def sample_data(self, filename): return sample_data(filename, "onix") @@ -50,36 +46,41 @@ def test_parser(self): assert 2017 == record.issued.year assert 1 == len(record.links) - assert "the essential democratic values of diversity and free expression" in record.links[0].content + assert ( + "the essential democratic values of diversity and free expression" + in record.links[0].content + ) record = metadata_records[1] assert Edition.AUDIO_MEDIUM == record.medium assert "The Test Corporation" == record.contributors[0].display_name assert "Test Corporation, The" == record.contributors[0].sort_name - @parameterized.expand([ - ( - 'limited_usage_status', - 'onix_3_usage_constraints_example.xml', - 20 - ), - ( - 'unlimited_usage_status', - 'onix_3_usage_constraints_with_unlimited_usage_status.xml', - LicensePool.UNLIMITED_ACCESS - ), - ( - 'wrong_usage_unit', - 'onix_3_usage_constraints_example_with_day_usage_unit.xml', - LicensePool.UNLIMITED_ACCESS - ) - ]) - def test_parse_parses_correctly_onix_3_usage_constraints(self, _, file_name, licenses_number): + @parameterized.expand( + [ + ("limited_usage_status", "onix_3_usage_constraints_example.xml", 20), + ( + "unlimited_usage_status", + "onix_3_usage_constraints_with_unlimited_usage_status.xml", + LicensePool.UNLIMITED_ACCESS, + ), + ( + "wrong_usage_unit", + "onix_3_usage_constraints_example_with_day_usage_unit.xml", + LicensePool.UNLIMITED_ACCESS, + ), + ] + ) + def test_parse_parses_correctly_onix_3_usage_constraints( + self, _, file_name, licenses_number + ): # Arrange file = self.sample_data(file_name) # Act - metadata_records = ONIXExtractor().parse(BytesIO(file), 'ONIX 3 Usage Constraints Example') + metadata_records = ONIXExtractor().parse( + BytesIO(file), "ONIX 3 Usage Constraints Example" + ) # Assert assert len(metadata_records) == 1 diff --git a/tests/test_opds.py b/tests/test_opds.py index 337d956531..967638d453 100644 --- a/tests/test_opds.py +++ b/tests/test_opds.py @@ -1,25 +1,37 @@ -from collections import defaultdict import contextlib import datetime -import dateutil +import json import os import re -import json +from collections import defaultdict +from pdb import set_trace + +import dateutil +import feedparser +import jwt import pytest from lxml import etree from mock import create_autospec -import feedparser -from core.testing import DatabaseTest -from pdb import set_trace -from core.analytics import Analytics -from core.lane import ( - FacetsWithEntryPoint, - Lane, - WorkList, + +from api.adobe_vendor_id import AuthdataUtility +from api.circulation import BaseCirculationAPI, CirculationAPI, FulfillmentInfo +from api.config import Configuration, temp_config +from api.lanes import ContributorLane, CrawlableCustomListBasedLane +from api.novelist import NoveListAPI +from api.opds import ( + CirculationManagerAnnotator, + LibraryAnnotator, + LibraryLoanAndHoldAnnotator, + SharedCollectionAnnotator, + SharedCollectionLoanAndHoldAnnotator, ) +from api.testing import VendorIDTest +from core.analytics import Analytics +from core.classifier import Classifier, Fantasy, Urban_Fantasy +from core.entrypoint import AudiobooksEntryPoint, EverythingEntryPoint +from core.external_search import MockExternalSearchIndex, WorkSearchResult +from core.lane import FacetsWithEntryPoint, Lane, WorkList from core.model import ( - create, - get_one_or_create, CirculationEvent, ConfigurationSetting, Contributor, @@ -33,67 +45,16 @@ RightsStatus, SessionManager, Work, + create, + get_one_or_create, ) - -from core.classifier import ( - Classifier, - Fantasy, - Urban_Fantasy -) - -from core.entrypoint import ( - AudiobooksEntryPoint, - EverythingEntryPoint, -) -from core.external_search import ( - MockExternalSearchIndex, - WorkSearchResult, -) -from core.util.datetime_helpers import ( - datetime_utc, - utc_now, -) -from core.util.opds_writer import ( - AtomFeed, - OPDSFeed, -) - -from core.opds import ( - AcquisitionFeed, - TestAnnotator, - UnfulfillableWork, -) - -from core.opds_import import ( - OPDSXMLParser -) -from core.util.flask_util import ( - OPDSEntryResponse, - OPDSFeedResponse, -) +from core.opds import AcquisitionFeed, TestAnnotator, UnfulfillableWork +from core.opds_import import OPDSXMLParser +from core.testing import DatabaseTest +from core.util.datetime_helpers import datetime_utc, utc_now +from core.util.flask_util import OPDSEntryResponse, OPDSFeedResponse +from core.util.opds_writer import AtomFeed, OPDSFeed from core.util.string_helpers import base64 -from api.circulation import ( - BaseCirculationAPI, - CirculationAPI, - FulfillmentInfo, -) -from api.config import ( - Configuration, - temp_config, -) -from api.opds import ( - CirculationManagerAnnotator, - LibraryAnnotator, - SharedCollectionAnnotator, - LibraryLoanAndHoldAnnotator, - SharedCollectionLoanAndHoldAnnotator, -) - -from api.testing import VendorIDTest -from api.adobe_vendor_id import AuthdataUtility -from api.novelist import NoveListAPI -from api.lanes import ContributorLane, CrawlableCustomListBasedLane -import jwt _strftime = AtomFeed._strftime @@ -104,7 +65,8 @@ def setup_method(self): self.work = self._work(with_open_access_download=True) self.lane = self._lane(display_name="Fantasy") self.annotator = CirculationManagerAnnotator( - self.lane, test_mode=True, + self.lane, + test_mode=True, ) def test_open_access_link(self): @@ -120,34 +82,34 @@ def test_open_access_link(self): lpdm.resource.representation = None lpdm.resource.url = "http://foo.com/thefile.epub" link_tag = self.annotator.open_access_link(pool, lpdm) - assert lpdm.resource.url == link_tag.get('href') + assert lpdm.resource.url == link_tag.get("href") # The dcterms:rights attribute may provide a more detailed # explanation of the book's copyright status. - rights = link_tag.attrib['{http://purl.org/dc/terms/}rights'] + rights = link_tag.attrib["{http://purl.org/dc/terms/}rights"] assert lpdm.rights_status.uri == rights # If we have a CDN set up for open-access links, the CDN hostname # replaces the original hostname. with temp_config() as config: config[Configuration.INTEGRATIONS][ExternalIntegration.CDN] = { - 'foo.com' : 'https://cdn.com/' + "foo.com": "https://cdn.com/" } link_tag = self.annotator.open_access_link(pool, lpdm) - link_url = link_tag.get('href') + link_url = link_tag.get("href") assert "https://cdn.com/thefile.epub" == link_url # If the Resource has a Representation, the public URL is used # instead of the original Resource URL. lpdm.resource.representation = representation link_tag = self.annotator.open_access_link(pool, lpdm) - assert representation.public_url == link_tag.get('href') + assert representation.public_url == link_tag.get("href") # If there is no Representation, the Resource's original URL is used. lpdm.resource.representation = None link_tag = self.annotator.open_access_link(pool, lpdm) - assert lpdm.resource.url == link_tag.get('href') + assert lpdm.resource.url == link_tag.get("href") def test_default_lane_url(self): default_lane_url = self.annotator.default_lane_url() @@ -174,7 +136,8 @@ def test_visible_delivery_mechanisms(self): # Create an annotator that hides PDFs. no_pdf = CirculationManagerAnnotator( - self.lane, hidden_content_types=["application/pdf"], + self.lane, + hidden_content_types=["application/pdf"], test_mode=True, ) @@ -184,7 +147,8 @@ def test_visible_delivery_mechanisms(self): # Create an annotator that hides EPUBs. no_epub = CirculationManagerAnnotator( - self.lane, hidden_content_types=["application/epub+zip"], + self.lane, + hidden_content_types=["application/epub+zip"], test_mode=True, ) @@ -200,8 +164,7 @@ def test_rights_attributes(self): # attribute to the URI associated with the RightsStatus. lp = self._licensepool(None) [lpdm] = lp.delivery_mechanisms - assert ({"{http://purl.org/dc/terms/}rights":lpdm.rights_status.uri} == - m(lpdm)) + assert {"{http://purl.org/dc/terms/}rights": lpdm.rights_status.uri} == m(lpdm) # If any link in the chain is broken, rights_attributes returns # an empty dictionary. @@ -234,7 +197,7 @@ def entry_for(work): return entry entry = entry_for(work) - assert '2018-02-04' in entry.get("updated") + assert "2018-02-04" in entry.get("updated") # If the work passed in is a WorkSearchResult that indicates # the search index found a later 'update time', then the later @@ -246,19 +209,20 @@ def __init__(self, last_update): # Store the time the way we get it from ElasticSearch -- # as a single-element list containing seconds since epoch. self.last_update = [ - (last_update-datetime_utc(1970, 1, 1)).total_seconds() + (last_update - datetime_utc(1970, 1, 1)).total_seconds() ] + hit = MockHit(datetime_utc(2018, 2, 5)) result = WorkSearchResult(work, hit) entry = entry_for(result) - assert '2018-02-05' in entry.get("updated") + assert "2018-02-05" in entry.get("updated") # Any 'update time' provided by ElasticSearch is used even if # it's clearly earlier than Work.last_update_time. hit = MockHit(datetime_utc(2017, 1, 1)) result._hit = hit entry = entry_for(result) - assert '2017-01-01' in entry.get("updated") + assert "2017-01-01" in entry.get("updated") def test__single_entry_response(self): # Test the helper method that makes OPDSEntryResponse objects. @@ -271,12 +235,12 @@ def test__single_entry_response(self): annotator = TestAnnotator() response = m(self._db, work, annotator, url) assert isinstance(response, OPDSEntryResponse) - assert '%s' % work.title in response.get_data(as_text=True) + assert "%s" % work.title in response.get_data(as_text=True) # By default, the representation is private but can be cached # by the recipient. assert True == response.private - assert 30*60 == response.max_age + assert 30 * 60 == response.max_age # Test the case where we override the defaults. response = m(self._db, work, annotator, url, max_age=12, private=False) @@ -290,8 +254,8 @@ def test__single_entry_response(self): # Instead of an entry based on the Work, we get an empty feed. assert isinstance(response, OPDSFeedResponse) response_data = response.get_data(as_text=True) - assert 'Unknown work' in response_data - assert '' not in response_data + assert "Unknown work" in response_data + assert "" not in response_data # Since it's an error message, the representation is private # and not to be cached. @@ -304,14 +268,16 @@ def setup_method(self): super(TestLibraryAnnotator, self).setup_method() self.work = self._work(with_open_access_download=True) - parent = self._lane( - display_name="Fiction", languages=["eng"], fiction=True - ) + parent = self._lane(display_name="Fiction", languages=["eng"], fiction=True) self.lane = self._lane(display_name="Fantasy", languages=["eng"]) self.lane.add_genre(Fantasy.name) self.lane.parent = parent self.annotator = LibraryAnnotator( - None, self.lane, self._default_library, test_mode=True, top_level_title="Test Top Level Title" + None, + self.lane, + self._default_library, + test_mode=True, + top_level_title="Test Top Level Title", ) # Initialize library with Adobe Vendor ID details @@ -321,12 +287,10 @@ def setup_method(self): # A ContributorLane to test code that handles it differently. self.contributor, ignore = self._contributor("Someone") self.contributor_lane = ContributorLane( - self._default_library, self.contributor, languages=["eng"], - audiences=None + self._default_library, self.contributor, languages=["eng"], audiences=None ) def test__hidden_content_types(self): - def f(value): """Set the default library's HIDDEN_CONTENT_TYPES setting to a specific value and see what _hidden_content_types @@ -337,9 +301,7 @@ def f(value): return LibraryAnnotator._hidden_content_types(library) # When the value is not set at all, no content types are hidden. - assert ( - [] == - list(LibraryAnnotator._hidden_content_types(self._default_library))) + assert [] == list(LibraryAnnotator._hidden_content_types(self._default_library)) # Now set various values and see what happens. assert [] == f(None) @@ -347,10 +309,8 @@ def f(value): assert [] == f(json.dumps([])) assert ["text/html"] == f("text/html") assert ["text/html"] == f(json.dumps("text/html")) - assert ["text/html"] == f(json.dumps({"text/html" : "some value"})) - assert ( - ["text/html", "text/plain"] == - f(json.dumps(["text/html", "text/plain"]))) + assert ["text/html"] == f(json.dumps({"text/html": "some value"})) + assert ["text/html", "text/plain"] == f(json.dumps(["text/html", "text/plain"])) def test_add_configuration_links(self): mock_feed = [] @@ -360,9 +320,9 @@ def test_add_configuration_links(self): LibraryAnnotator.COPYRIGHT: "http://copyright/", LibraryAnnotator.ABOUT: "http://about/", LibraryAnnotator.LICENSE: "http://license/", - Configuration.HELP_EMAIL : "help@me", - Configuration.HELP_WEB : "http://help/", - Configuration.HELP_URI : "uri:help", + Configuration.HELP_EMAIL: "help@me", + Configuration.HELP_WEB: "http://help/", + Configuration.HELP_URI: "uri:help", } # Set up configuration settings for links. @@ -385,25 +345,24 @@ def test_add_configuration_links(self): # They are the links we'd expect. links = {} for link in mock_feed: - rel = link.attrib['rel'] - href = link.attrib['href'] - if rel == 'help' or rel == 'related': - continue # Tested below + rel = link.attrib["rel"] + href = link.attrib["href"] + if rel == "help" or rel == "related": + continue # Tested below # Check that the configuration value made it into the link. assert href == link_config[rel] - assert "text/html" == link.attrib['type'] + assert "text/html" == link.attrib["type"] # There are three help links using different protocols. - help_links = [x.attrib['href'] for x in mock_feed - if x.attrib['rel'] == 'help'] - assert (set(["mailto:help@me", "http://help/", "uri:help"]) == - set(help_links)) + help_links = [x.attrib["href"] for x in mock_feed if x.attrib["rel"] == "help"] + assert set(["mailto:help@me", "http://help/", "uri:help"]) == set(help_links) # There are two navigation links. - navigation_links = [x for x in mock_feed if x.attrib['rel'] == 'related'] + navigation_links = [x for x in mock_feed if x.attrib["rel"] == "related"] assert set(["navigation"]) == set([x.attrib["role"] for x in navigation_links]) - assert (set(["http://example.com/1", "http://example.com/2"]) == - set([x.attrib["href"] for x in navigation_links])) + assert set(["http://example.com/1", "http://example.com/2"]) == set( + [x.attrib["href"] for x in navigation_links] + ) assert set(["one", "two"]) == set([x.attrib["title"] for x in navigation_links]) def test_top_level_title(self): @@ -411,15 +370,11 @@ def test_top_level_title(self): def test_group_uri_with_flattened_lane(self): spanish_lane = self._lane(display_name="Spanish", languages=["spa"]) - flat_spanish_lane = dict({ - "lane": spanish_lane, - "label": "All Spanish", - "link_to_list_feed": True - }) + flat_spanish_lane = dict( + {"lane": spanish_lane, "label": "All Spanish", "link_to_list_feed": True} + ) spanish_work = self._work( - title="Spanish Book", - with_license_pool=True, - language="spa" + title="Spanish Book", with_license_pool=True, language="spa" ) lp = spanish_work.license_pools[0] self.annotator.lanes_by_work[spanish_work].append(flat_spanish_lane) @@ -454,15 +409,21 @@ def test_lane_url(self): assert groups_url == self.annotator.groups_url(fantasy_lane_with_sublanes) groups_url = self.annotator.lane_url(fantasy_lane_with_sublanes, facets=facets) - assert groups_url == self.annotator.groups_url(fantasy_lane_with_sublanes, facets=facets) + assert groups_url == self.annotator.groups_url( + fantasy_lane_with_sublanes, facets=facets + ) feed_url = self.annotator.lane_url(fantasy_lane_without_sublanes) assert feed_url == self.annotator.feed_url(fantasy_lane_without_sublanes) feed_url = self.annotator.lane_url(fantasy_lane_without_sublanes, facets=facets) - assert feed_url == self.annotator.feed_url(fantasy_lane_without_sublanes, facets=facets) + assert feed_url == self.annotator.feed_url( + fantasy_lane_without_sublanes, facets=facets + ) - def test_fulfill_link_issues_only_open_access_links_when_library_does_not_identify_patrons(self): + def test_fulfill_link_issues_only_open_access_links_when_library_does_not_identify_patrons( + self, + ): # This library doesn't identify patrons. self.annotator.identifies_patrons = False @@ -470,15 +431,12 @@ def test_fulfill_link_issues_only_open_access_links_when_library_does_not_identi # Because of this, normal fulfillment links are not generated. [pool] = self.work.license_pools [lpdm] = pool.delivery_mechanisms - assert (None == - self.annotator.fulfill_link(pool, None, lpdm)) + assert None == self.annotator.fulfill_link(pool, None, lpdm) # However, fulfillment links _can_ be generated with the # 'open-access' link relation. - link = self.annotator.fulfill_link( - pool, None, lpdm, OPDSFeed.OPEN_ACCESS_REL - ) - assert OPDSFeed.OPEN_ACCESS_REL == link.attrib['rel'] + link = self.annotator.fulfill_link(pool, None, lpdm, OPDSFeed.OPEN_ACCESS_REL) + assert OPDSFeed.OPEN_ACCESS_REL == link.attrib["rel"] def test_fulfill_link_includes_device_registration_tags(self): """Verify that when Adobe Vendor ID delegation is included, the @@ -501,9 +459,7 @@ def test_fulfill_link_includes_device_registration_tags(self): # The fulfill link for non-Adobe DRM does not # include the drm:licensor tag. - link = self.annotator.fulfill_link( - pool, loan, other_delivery_mechanism - ) + link = self.annotator.fulfill_link(pool, loan, other_delivery_mechanism) for child in link: assert child.tag != "{http://librarysimplified.org/terms/drm}licensor" @@ -512,20 +468,19 @@ def test_fulfill_link_includes_device_registration_tags(self): # The fulfill link for Adobe DRM includes information # on how to get an Adobe ID in the drm:licensor tag. - link = self.annotator.fulfill_link( - pool, loan, adobe_delivery_mechanism - ) + link = self.annotator.fulfill_link(pool, loan, adobe_delivery_mechanism) licensor = link[-1] - assert ("{http://librarysimplified.org/terms/drm}licensor" == - licensor.tag) + assert "{http://librarysimplified.org/terms/drm}licensor" == licensor.tag # An Adobe ID-specific identifier has been created for the patron. - [adobe_id_identifier] = [x for x in patron.credentials - if x not in old_credentials] - assert (AuthdataUtility.ADOBE_ACCOUNT_ID_PATRON_IDENTIFIER == - adobe_id_identifier.type) - assert (DataSource.INTERNAL_PROCESSING == - adobe_id_identifier.data_source.name) + [adobe_id_identifier] = [ + x for x in patron.credentials if x not in old_credentials + ] + assert ( + AuthdataUtility.ADOBE_ACCOUNT_ID_PATRON_IDENTIFIER + == adobe_id_identifier.type + ) + assert DataSource.INTERNAL_PROCESSING == adobe_id_identifier.data_source.name assert None == adobe_id_identifier.expires # The drm:licensor tag is the one we get by calling @@ -549,14 +504,14 @@ def test_adobe_id_tags_when_vendor_id_configured(self): self.initialize_adobe(library) patron_identifier = "patron identifier" [element] = self.annotator.adobe_id_tags(patron_identifier) - assert '{http://librarysimplified.org/terms/drm}licensor' == element.tag + assert "{http://librarysimplified.org/terms/drm}licensor" == element.tag - key = '{http://librarysimplified.org/terms/drm}vendor' + key = "{http://librarysimplified.org/terms/drm}vendor" assert self.adobe_vendor_id.username == element.attrib[key] [token, device_management_link] = element - assert '{http://librarysimplified.org/terms/drm}clientToken' == token.tag + assert "{http://librarysimplified.org/terms/drm}clientToken" == token.tag # token.text is a token which we can decode, since we know # the secret. token = token.text @@ -568,13 +523,14 @@ def test_adobe_id_tags_when_vendor_id_configured(self): assert (expected_url, patron_identifier) == decoded assert "link" == device_management_link.tag - assert ("http://librarysimplified.org/terms/drm/rel/devices" == - device_management_link.attrib['rel']) + assert ( + "http://librarysimplified.org/terms/drm/rel/devices" + == device_management_link.attrib["rel"] + ) expect_url = self.annotator.url_for( - 'adobe_drm_devices', library_short_name=library.short_name, - _external=True + "adobe_drm_devices", library_short_name=library.short_name, _external=True ) - assert expect_url == device_management_link.attrib['href'] + assert expect_url == device_management_link.attrib["href"] # If we call adobe_id_tags again we'll get a distinct tag # object that renders to the same XML. @@ -588,8 +544,7 @@ def test_adobe_id_tags_when_vendor_id_configured(self): # Delete one setting from the existing integration to check # this. setting = ConfigurationSetting.for_library_and_externalintegration( - self._db, ExternalIntegration.USERNAME, library, - self.registry + self._db, ExternalIntegration.USERNAME, library, self.registry ) self._db.delete(setting) assert [] == self.annotator.adobe_id_tags("new identifier") @@ -628,7 +583,9 @@ def test_feed_url(self): # A QueryGeneratedLane. self.annotator.lane = self.contributor_lane - feed_url_contributor = self.annotator.feed_url(self.contributor_lane, dict(), dict()) + feed_url_contributor = self.annotator.feed_url( + self.contributor_lane, dict(), dict() + ) assert self.contributor_lane.ROUTE in feed_url_contributor assert self.contributor_lane.contributor_key in feed_url_contributor assert self._default_library.name in feed_url_contributor @@ -660,14 +617,18 @@ def test_facet_url(self): def test_alternate_link_is_permalink(self): work = self._work(with_open_access_download=True) works = self._db.query(Work) - annotator = LibraryAnnotator(None, self.lane, self._default_library, test_mode=True) + annotator = LibraryAnnotator( + None, self.lane, self._default_library, test_mode=True + ) pool = annotator.active_licensepool_for(work) feed = self.get_parsed_feed([work]) - [entry] = feed['entries'] - assert entry['id'] == pool.identifier.urn + [entry] = feed["entries"] + assert entry["id"] == pool.identifier.urn - [(alternate, type)] = [(x['href'], x['type']) for x in entry['links'] if x['rel'] == 'alternate'] + [(alternate, type)] = [ + (x["href"], x["type"]) for x in entry["links"] if x["rel"] == "alternate" + ] permalink, permalink_type = self.annotator.permalink_for( work, pool, pool.identifier ) @@ -677,7 +638,7 @@ def test_alternate_link_is_permalink(self): # Make sure we are using the 'permalink' controller -- we were using # 'work' and that was wrong. - assert '/host/permalink' in permalink + assert "/host/permalink" in permalink def test_annotate_work_entry(self): lane = self._lane() @@ -694,22 +655,23 @@ def test_annotate_work_entry(self): linksets = [] for auth in (True, False): annotator = LibraryAnnotator( - None, lane, self._default_library, test_mode=True, - library_identifies_patrons=auth + None, + lane, + self._default_library, + test_mode=True, + library_identifies_patrons=auth, ) feed = AcquisitionFeed(self._db, "test", "url", [], annotator) entry = feed._make_entry_xml(work, edition) - annotator.annotate_work_entry( - work, pool, edition, identifier, feed, entry - ) + annotator.annotate_work_entry(work, pool, edition, identifier, feed, entry) parsed = feedparser.parse(etree.tostring(entry)) - [entry_parsed] = parsed['entries'] - linksets.append(set([x['rel'] for x in entry_parsed['links']])) + [entry_parsed] = parsed["entries"] + linksets.append(set([x["rel"] for x in entry_parsed["links"]])) with_auth, no_auth = linksets # Some links are present no matter what. - for expect in ['alternate', 'issues', 'related']: + for expect in ["alternate", "issues", "related"]: assert expect in with_auth assert expect in no_auth @@ -717,8 +679,8 @@ def test_annotate_work_entry(self): # links -- one to borrow the book and one to annotate the # book. for expect in [ - 'http://www.w3.org/ns/oa#annotationservice', - 'http://opds-spec.org/acquisition/borrow' + "http://www.w3.org/ns/oa#annotationservice", + "http://opds-spec.org/acquisition/borrow", ]: assert expect in with_auth assert expect not in no_auth @@ -730,48 +692,53 @@ def test_annotate_work_entry(self): identifier = edition.primary_identifier annotator = LibraryAnnotator( - None, lane, self._default_library, test_mode=True, - library_identifies_patrons=True + None, + lane, + self._default_library, + test_mode=True, + library_identifies_patrons=True, ) feed = AcquisitionFeed(self._db, "test", "url", [], annotator) entry = feed._make_entry_xml(work, edition) - annotator.annotate_work_entry( - work, None, edition, identifier, feed, entry - ) + annotator.annotate_work_entry(work, None, edition, identifier, feed, entry) parsed = feedparser.parse(etree.tostring(entry)) - [entry_parsed] = parsed['entries'] - links = set([x['rel'] for x in entry_parsed['links']]) + [entry_parsed] = parsed["entries"] + links = set([x["rel"] for x in entry_parsed["links"]]) # These links are still present. - for expect in ['alternate', 'issues', 'related', 'http://www.w3.org/ns/oa#annotationservice']: + for expect in [ + "alternate", + "issues", + "related", + "http://www.w3.org/ns/oa#annotationservice", + ]: assert expect in links # But the borrow link is gone. - assert 'http://opds-spec.org/acquisition/borrow' not in links + assert "http://opds-spec.org/acquisition/borrow" not in links # There are no links to create analytics events for this title, # because the library has no analytics configured. - open_book_rel = 'http://librarysimplified.org/terms/rel/analytics/open-book' + open_book_rel = "http://librarysimplified.org/terms/rel/analytics/open-book" assert open_book_rel not in links # If analytics are configured, a link is added to # create an 'open_book' analytics event for this title. Analytics.GLOBAL_ENABLED = True entry = feed._make_entry_xml(work, edition) - annotator.annotate_work_entry( - work, None, edition, identifier, feed, entry - ) + annotator.annotate_work_entry(work, None, edition, identifier, feed, entry) parsed = feedparser.parse(etree.tostring(entry)) - [entry_parsed] = parsed['entries'] - [analytics_link] = [x['href'] for x in entry_parsed['links'] - if x['rel'] == open_book_rel] + [entry_parsed] = parsed["entries"] + [analytics_link] = [ + x["href"] for x in entry_parsed["links"] if x["rel"] == open_book_rel + ] expect = annotator.url_for( - 'track_analytics_event', + "track_analytics_event", identifier_type=identifier.type, identifier=identifier.identifier, event_type=CirculationEvent.OPEN_BOOK, library_short_name=self._default_library.short_name, - _external=True + _external=True, ) assert expect == analytics_link @@ -780,54 +747,59 @@ def test_annotate_feed(self): linksets = [] for auth in (True, False): annotator = LibraryAnnotator( - None, lane, self._default_library, test_mode=True, - library_identifies_patrons=auth + None, + lane, + self._default_library, + test_mode=True, + library_identifies_patrons=auth, ) feed = AcquisitionFeed(self._db, "test", "url", [], annotator) annotator.annotate_feed(feed, lane) parsed = feedparser.parse(str(feed)) - linksets.append([x['rel'] for x in parsed['feed']['links']]) + linksets.append([x["rel"] for x in parsed["feed"]["links"]]) with_auth, without_auth = linksets # There's always a self link, a search link, and an auth # document link. - for rel in ( - 'self', 'search', 'http://opds-spec.org/auth/document' - ): + for rel in ("self", "search", "http://opds-spec.org/auth/document"): assert rel in with_auth assert rel in without_auth # But there's only a bookshelf link and an annotation link # when patron authentication is enabled. for rel in ( - 'http://opds-spec.org/shelf', - 'http://www.w3.org/ns/oa#annotationservice' + "http://opds-spec.org/shelf", + "http://www.w3.org/ns/oa#annotationservice", ): assert rel in with_auth assert rel not in without_auth - def get_parsed_feed(self, works, lane=None, **kwargs): if not lane: lane = self._lane(display_name="Main Lane") feed = AcquisitionFeed( - self._db, "test", "url", works, - LibraryAnnotator(None, lane, self._default_library, test_mode=True, - **kwargs) + self._db, + "test", + "url", + works, + LibraryAnnotator( + None, lane, self._default_library, test_mode=True, **kwargs + ), ) return feedparser.parse(str(feed)) - def assert_link_on_entry(self, entry, link_type=None, rels=None, - partials_by_rel=None + def assert_link_on_entry( + self, entry, link_type=None, rels=None, partials_by_rel=None ): """Asserts that a link with a certain 'rel' value exists on a given feed or entry, as well as its link 'type' value and parts of its 'href' value. """ + def get_link_by_rel(rel): try: - [link] = [x for x in entry['links'] if x['rel']==rel] + [link] = [x for x in entry["links"] if x["rel"] == rel] except ValueError as e: raise AssertionError if link_type: @@ -849,7 +821,7 @@ def test_work_entry_includes_problem_reporting_link(self): work = self._work(with_open_access_download=True) feed = self.get_parsed_feed([work]) [entry] = feed.entries - expected_rel_and_partial = { 'issues' : '/report' } + expected_rel_and_partial = {"issues": "/report"} self.assert_link_on_entry(entry, partials_by_rel=expected_rel_and_partial) def test_work_entry_includes_open_access_or_borrow_link(self): @@ -864,30 +836,32 @@ def test_work_entry_includes_open_access_or_borrow_link(self): self.assert_link_on_entry(licensed_entry, rels=[OPDSFeed.BORROW_REL]) def test_language_and_audience_key_from_work(self): - work = self._work(language='eng', audience=Classifier.AUDIENCE_CHILDREN) + work = self._work(language="eng", audience=Classifier.AUDIENCE_CHILDREN) result = self.annotator.language_and_audience_key_from_work(work) - assert ('eng', 'Children') == result + assert ("eng", "Children") == result - work = self._work(language='fre', audience=Classifier.AUDIENCE_YOUNG_ADULT) + work = self._work(language="fre", audience=Classifier.AUDIENCE_YOUNG_ADULT) result = self.annotator.language_and_audience_key_from_work(work) - assert ('fre', 'All+Ages,Children,Young+Adult') == result + assert ("fre", "All+Ages,Children,Young+Adult") == result - work = self._work(language='spa', audience=Classifier.AUDIENCE_ADULT) + work = self._work(language="spa", audience=Classifier.AUDIENCE_ADULT) result = self.annotator.language_and_audience_key_from_work(work) - assert ('spa', 'Adult,Adults+Only,All+Ages,Children,Young+Adult') == result + assert ("spa", "Adult,Adults+Only,All+Ages,Children,Young+Adult") == result work = self._work(audience=Classifier.AUDIENCE_ADULTS_ONLY) result = self.annotator.language_and_audience_key_from_work(work) - assert ('eng', 'Adult,Adults+Only,All+Ages,Children,Young+Adult') == result + assert ("eng", "Adult,Adults+Only,All+Ages,Children,Young+Adult") == result work = self._work(audience=Classifier.AUDIENCE_RESEARCH) result = self.annotator.language_and_audience_key_from_work(work) - assert ('eng', 'Adult,Adults+Only,All+Ages,Children,Research,Young+Adult') == result + assert ( + "eng", + "Adult,Adults+Only,All+Ages,Children,Research,Young+Adult", + ) == result work = self._work(audience=Classifier.AUDIENCE_ALL_AGES) result = self.annotator.language_and_audience_key_from_work(work) - assert ('eng', 'All+Ages,Children') == result - + assert ("eng", "All+Ages,Children") == result def test_work_entry_includes_contributor_links(self): """ContributorLane links are added to works with contributors""" @@ -896,27 +870,28 @@ def test_work_entry_includes_contributor_links(self): feed = self.get_parsed_feed([work]) [entry] = feed.entries - expected_rel_and_partial = dict(contributor='/contributor') + expected_rel_and_partial = dict(contributor="/contributor") self.assert_link_on_entry( - entry, link_type=OPDSFeed.ACQUISITION_FEED_TYPE, + entry, + link_type=OPDSFeed.ACQUISITION_FEED_TYPE, partials_by_rel=expected_rel_and_partial, ) # When there are two authors, they each get a contributor link. - work.presentation_edition.add_contributor('Oprah', Contributor.AUTHOR_ROLE) + work.presentation_edition.add_contributor("Oprah", Contributor.AUTHOR_ROLE) work.calculate_presentation( PresentationCalculationPolicy(regenerate_opds_entries=True), - MockExternalSearchIndex() + MockExternalSearchIndex(), ) [entry] = self.get_parsed_feed([work]).entries - contributor_links = [l for l in entry.links if l.rel == 'contributor'] + contributor_links = [l for l in entry.links if l.rel == "contributor"] assert 2 == len(contributor_links) contributor_links.sort(key=lambda l: l.href) for l in contributor_links: assert l.type == OPDSFeed.ACQUISITION_FEED_TYPE - assert '/contributor' in l.href + assert "/contributor" in l.href assert contributor1.sort_name in contributor_links[0].href - assert 'Oprah' in contributor_links[1].href + assert "Oprah" in contributor_links[1].href # When there's no author, there's no contributor link. self._db.delete(work.presentation_edition.contributions[0]) @@ -924,30 +899,30 @@ def test_work_entry_includes_contributor_links(self): self._db.commit() work.calculate_presentation( PresentationCalculationPolicy(regenerate_opds_entries=True), - MockExternalSearchIndex() + MockExternalSearchIndex(), ) [entry] = self.get_parsed_feed([work]).entries - assert [] == [l for l in entry.links if l.rel=='contributor'] + assert [] == [l for l in entry.links if l.rel == "contributor"] def test_work_entry_includes_series_link(self): - """A series lane link is added to the work entry when its in a series - """ + """A series lane link is added to the work entry when its in a series""" work = self._work( - with_open_access_download=True, series='Serious Cereals Series' + with_open_access_download=True, series="Serious Cereals Series" ) feed = self.get_parsed_feed([work]) [entry] = feed.entries - expected_rel_and_partial = dict(series='/series') + expected_rel_and_partial = dict(series="/series") self.assert_link_on_entry( - entry, link_type=OPDSFeed.ACQUISITION_FEED_TYPE, - partials_by_rel=expected_rel_and_partial + entry, + link_type=OPDSFeed.ACQUISITION_FEED_TYPE, + partials_by_rel=expected_rel_and_partial, ) # When there's no series, there's no series link. work = self._work(with_open_access_download=True) feed = self.get_parsed_feed([work]) [entry] = feed.entries - assert [] == [l for l in entry.links if l.rel=='series'] + assert [] == [l for l in entry.links if l.rel == "series"] def test_work_entry_includes_recommendations_link(self): work = self._work(with_open_access_download=True) @@ -955,29 +930,33 @@ def test_work_entry_includes_recommendations_link(self): # If NoveList Select isn't configured, there's no recommendations link. feed = self.get_parsed_feed([work]) [entry] = feed.entries - assert [] == [l for l in entry.links if l.rel=='recommendations'] + assert [] == [l for l in entry.links if l.rel == "recommendations"] # There's a recommendation link when configuration is found, though! NoveListAPI.IS_CONFIGURED = None self._external_integration( ExternalIntegration.NOVELIST, - goal=ExternalIntegration.METADATA_GOAL, username='library', - password='sure', libraries=[self._default_library], + goal=ExternalIntegration.METADATA_GOAL, + username="library", + password="sure", + libraries=[self._default_library], ) feed = self.get_parsed_feed([work]) [entry] = feed.entries - expected_rel_and_partial = dict(recommendations='/recommendations') + expected_rel_and_partial = dict(recommendations="/recommendations") self.assert_link_on_entry( - entry, link_type=OPDSFeed.ACQUISITION_FEED_TYPE, - partials_by_rel=expected_rel_and_partial) + entry, + link_type=OPDSFeed.ACQUISITION_FEED_TYPE, + partials_by_rel=expected_rel_and_partial, + ) def test_work_entry_includes_annotations_link(self): work = self._work(with_open_access_download=True) identifier_str = work.license_pools[0].identifier.identifier - uri_parts = ['/annotations', identifier_str] - annotation_rel = 'http://www.w3.org/ns/oa#annotationservice' - rel_with_partials = { annotation_rel : uri_parts } + uri_parts = ["/annotations", identifier_str] + annotation_rel = "http://www.w3.org/ns/oa#annotationservice" + rel_with_partials = {annotation_rel: uri_parts} feed = self.get_parsed_feed([work]) [entry] = feed.entries @@ -985,11 +964,9 @@ def test_work_entry_includes_annotations_link(self): # If the library does not authenticate patrons, no link to the # annotation service is provided. - feed = self.get_parsed_feed( - [work], library_identifies_patrons=False - ) + feed = self.get_parsed_feed([work], library_identifies_patrons=False) [entry] = feed.entries - assert annotation_rel not in [x['rel'] for x in entry['links']] + assert annotation_rel not in [x["rel"] for x in entry["links"]] def test_active_loan_feed(self): self.initialize_adobe(self._default_library) @@ -1021,20 +998,25 @@ def test_active_loan_feed(self): # No entries in the feed... raw = str(response) feed = feedparser.parse(raw) - assert 0 == len(feed['entries']) + assert 0 == len(feed["entries"]) # ... but we have a link to the User Profile Management # Protocol endpoint... - links = feed['feed']['links'] + links = feed["feed"]["links"] [upmp_link] = [ - x for x in links - if x['rel'] == 'http://librarysimplified.org/terms/rel/user-profile' + x + for x in links + if x["rel"] == "http://librarysimplified.org/terms/rel/user-profile" ] - annotator = cls(None, None, library=patron.library, patron=patron, test_mode=True) + annotator = cls( + None, None, library=patron.library, patron=patron, test_mode=True + ) expect_url = annotator.url_for( - 'patron_profile', library_short_name=patron.library.short_name, _external=True + "patron_profile", + library_short_name=patron.library.short_name, + _external=True, ) - assert expect_url == upmp_link['href'] + assert expect_url == upmp_link["href"] # ... and we have DRM licensing information. tree = etree.fromstring(response.get_data(as_text=True)) @@ -1045,26 +1027,31 @@ def test_active_loan_feed(self): # The DRM licensing information includes the Adobe vendor ID # and the patron's patron identifier for Adobe purposes. - assert (self.adobe_vendor_id.username == - licensor.attrib['{http://librarysimplified.org/terms/drm}vendor']) + assert ( + self.adobe_vendor_id.username + == licensor.attrib["{http://librarysimplified.org/terms/drm}vendor"] + ) [client_token, device_management_link] = licensor expected = ConfigurationSetting.for_library_and_externalintegration( self._db, ExternalIntegration.USERNAME, self._default_library, self.registry ).value.upper() assert client_token.text.startswith(expected) assert adobe_patron_identifier in client_token.text - assert ("{http://www.w3.org/2005/Atom}link" == - device_management_link.tag) - assert ("http://librarysimplified.org/terms/drm/rel/devices" == - device_management_link.attrib['rel']) + assert "{http://www.w3.org/2005/Atom}link" == device_management_link.tag + assert ( + "http://librarysimplified.org/terms/drm/rel/devices" + == device_management_link.attrib["rel"] + ) # Unlike other places this tag shows up, we use the # 'scheme' attribute to explicitly state that this # tag is talking about an ACS licensing # scheme. Since we're in a and not a to a # specific book, that context would otherwise be lost. - assert ('http://librarysimplified.org/terms/drm/scheme/ACS' == - licensor.attrib['{http://librarysimplified.org/terms/drm}scheme']) + assert ( + "http://librarysimplified.org/terms/drm/scheme/ACS" + == licensor.attrib["{http://librarysimplified.org/terms/drm}scheme"] + ) # Since we're taking a round trip to and from OPDS, which only # represents times with second precision, generate the current @@ -1084,16 +1071,17 @@ def test_active_loan_feed(self): # Get the feed. feed_obj = LibraryLoanAndHoldAnnotator.active_loans_for( - None, patron, test_mode=True) + None, patron, test_mode=True + ) raw = str(feed_obj) feed = feedparser.parse(raw) # The only entries in the feed is the work currently out on loan # to this patron. - assert 2 == len(feed['entries']) - e1, e2 = sorted(feed['entries'], key=lambda x: x['title']) - assert work1.title == e1['title'] - assert work2.title == e2['title'] + assert 2 == len(feed["entries"]) + e1, e2 = sorted(feed["entries"], key=lambda x: x["title"]) + assert work1.title == e1["title"] + assert work2.title == e2["title"] # Make sure that the start and end dates from the loan are present # in an child of the acquisition link. @@ -1104,44 +1092,53 @@ def test_active_loan_feed(self): ) assert 2 == len(acquisitions) - availabilities = [ - parser._xpath1(x, "opds:availability") for x in acquisitions - ] + availabilities = [parser._xpath1(x, "opds:availability") for x in acquisitions] # One of these availability tags has 'since' but not 'until'. # The other one has both. - [no_until] = [x for x in availabilities if 'until' not in x.attrib] - assert now == dateutil.parser.parse(no_until.attrib['since']) + [no_until] = [x for x in availabilities if "until" not in x.attrib] + assert now == dateutil.parser.parse(no_until.attrib["since"]) - [has_until] = [x for x in availabilities if 'until' in x.attrib] - assert now == dateutil.parser.parse(has_until.attrib['since']) - assert tomorrow == dateutil.parser.parse(has_until.attrib['until']) + [has_until] = [x for x in availabilities if "until" in x.attrib] + assert now == dateutil.parser.parse(has_until.attrib["since"]) + assert tomorrow == dateutil.parser.parse(has_until.attrib["until"]) def test_loan_feed_includes_patron(self): patron = self._patron() - patron.username = 'bellhooks' - patron.authorization_identifier = '987654321' + patron.username = "bellhooks" + patron.authorization_identifier = "987654321" feed_obj = LibraryLoanAndHoldAnnotator.active_loans_for( - None, patron, test_mode=True) + None, patron, test_mode=True + ) raw = str(feed_obj) - feed_details = feedparser.parse(raw)['feed'] + feed_details = feedparser.parse(raw)["feed"] assert "simplified:authorizationIdentifier" in raw assert "simplified:username" in raw - assert patron.username == feed_details['simplified_patron']['simplified:username'] - assert '987654321' == feed_details['simplified_patron']['simplified:authorizationidentifier'] + assert ( + patron.username == feed_details["simplified_patron"]["simplified:username"] + ) + assert ( + "987654321" + == feed_details["simplified_patron"]["simplified:authorizationidentifier"] + ) def test_loans_feed_includes_annotations_link(self): patron = self._patron() feed_obj = LibraryLoanAndHoldAnnotator.active_loans_for( - None, patron, test_mode=True) + None, patron, test_mode=True + ) raw = str(feed_obj) - feed = feedparser.parse(raw)['feed'] - links = feed['links'] + feed = feedparser.parse(raw)["feed"] + links = feed["links"] - [annotations_link] = [x for x in links if x['rel'].lower() == "http://www.w3.org/ns/oa#annotationService".lower()] - assert '/annotations' in annotations_link['href'] + [annotations_link] = [ + x + for x in links + if x["rel"].lower() == "http://www.w3.org/ns/oa#annotationService".lower() + ] + assert "/annotations" in annotations_link["href"] def test_active_loan_feed_ignores_inconsistent_local_data(self): patron = self._patron() @@ -1159,10 +1156,11 @@ def test_active_loan_feed_ignores_inconsistent_local_data(self): # We can still get a feed... feed_obj = LibraryLoanAndHoldAnnotator.active_loans_for( - None, patron, test_mode=True) + None, patron, test_mode=True + ) # ...but it's empty. - assert '' not in str(feed_obj) + assert "" not in str(feed_obj) def test_acquisition_feed_includes_license_information(self): work = self._work(with_open_access_download=True) @@ -1175,9 +1173,7 @@ def test_acquisition_feed_includes_license_information(self): pool.licenses_available = 50 pool.patrons_in_hold_queue = 25 - feed = AcquisitionFeed( - self._db, "title", "url", [work], self.annotator - ) + feed = AcquisitionFeed(self._db, "title", "url", [work], self.annotator) u = str(feed) holds_re = re.compile('', re.S) assert holds_re.search(u) is not None @@ -1196,46 +1192,58 @@ def test_loans_feed_includes_fulfill_links(self): pool.open_access = False mech1 = pool.delivery_mechanisms[0] mech2 = pool.set_delivery_mechanism( - Representation.PDF_MEDIA_TYPE, DeliveryMechanism.ADOBE_DRM, - RightsStatus.IN_COPYRIGHT, None + Representation.PDF_MEDIA_TYPE, + DeliveryMechanism.ADOBE_DRM, + RightsStatus.IN_COPYRIGHT, + None, ) streaming_mech = pool.set_delivery_mechanism( - DeliveryMechanism.STREAMING_TEXT_CONTENT_TYPE, DeliveryMechanism.OVERDRIVE_DRM, - RightsStatus.IN_COPYRIGHT, None + DeliveryMechanism.STREAMING_TEXT_CONTENT_TYPE, + DeliveryMechanism.OVERDRIVE_DRM, + RightsStatus.IN_COPYRIGHT, + None, ) now = utc_now() loan, ignore = pool.loan_to(patron, start=now) feed_obj = LibraryLoanAndHoldAnnotator.active_loans_for( - None, patron, test_mode=True) + None, patron, test_mode=True + ) raw = str(feed_obj) - entries = feedparser.parse(raw)['entries'] + entries = feedparser.parse(raw)["entries"] assert 1 == len(entries) - links = entries[0]['links'] + links = entries[0]["links"] # Before we fulfill the loan, there are fulfill links for all three mechanisms. - fulfill_links = [link for link in links if link['rel'] == "http://opds-spec.org/acquisition"] + fulfill_links = [ + link for link in links if link["rel"] == "http://opds-spec.org/acquisition" + ] assert 3 == len(fulfill_links) - assert (set([mech1.delivery_mechanism.drm_scheme_media_type, mech2.delivery_mechanism.drm_scheme_media_type, - OPDSFeed.ENTRY_TYPE]) == - set([link['type'] for link in fulfill_links])) + assert ( + set( + [ + mech1.delivery_mechanism.drm_scheme_media_type, + mech2.delivery_mechanism.drm_scheme_media_type, + OPDSFeed.ENTRY_TYPE, + ] + ) + == set([link["type"] for link in fulfill_links]) + ) # If one of the content types is hidden, the corresponding # delivery mechanism does not have a link. - setting = self._default_library.setting( - Configuration.HIDDEN_CONTENT_TYPES - ) + setting = self._default_library.setting(Configuration.HIDDEN_CONTENT_TYPES) setting.value = json.dumps([mech1.delivery_mechanism.content_type]) feed_obj = LibraryLoanAndHoldAnnotator.active_loans_for( None, patron, test_mode=True ) - assert (set([mech2.delivery_mechanism.drm_scheme_media_type, - OPDSFeed.ENTRY_TYPE]) == - set([link['type'] for link in fulfill_links])) + assert set( + [mech2.delivery_mechanism.drm_scheme_media_type, OPDSFeed.ENTRY_TYPE] + ) == set([link["type"] for link in fulfill_links]) setting.value = None # When the loan is fulfilled, there are only fulfill links for that mechanism @@ -1243,22 +1251,27 @@ def test_loans_feed_includes_fulfill_links(self): loan.fulfillment = mech1 feed_obj = LibraryLoanAndHoldAnnotator.active_loans_for( - None, patron, test_mode=True) + None, patron, test_mode=True + ) raw = str(feed_obj) - entries = feedparser.parse(raw)['entries'] + entries = feedparser.parse(raw)["entries"] assert 1 == len(entries) - links = entries[0]['links'] + links = entries[0]["links"] - fulfill_links = [link for link in links if link['rel'] == "http://opds-spec.org/acquisition"] + fulfill_links = [ + link for link in links if link["rel"] == "http://opds-spec.org/acquisition" + ] assert 2 == len(fulfill_links) - assert (set([mech1.delivery_mechanism.drm_scheme_media_type, - OPDSFeed.ENTRY_TYPE]) == - set([link['type'] for link in fulfill_links])) + assert set( + [mech1.delivery_mechanism.drm_scheme_media_type, OPDSFeed.ENTRY_TYPE] + ) == set([link["type"] for link in fulfill_links]) - def test_incomplete_catalog_entry_contains_an_alternate_link_to_the_complete_entry(self): + def test_incomplete_catalog_entry_contains_an_alternate_link_to_the_complete_entry( + self, + ): circulation = create_autospec(spec=CirculationAPI) circulation.library = self._default_library work = self._work(with_license_pool=True, with_open_access_download=False) @@ -1269,13 +1282,17 @@ def test_incomplete_catalog_entry_contains_an_alternate_link_to_the_complete_ent ) raw = str(feed_obj) - entries = feedparser.parse(raw)['entries'] + entries = feedparser.parse(raw)["entries"] assert 1 == len(entries) - links = entries[0]['links'] + links = entries[0]["links"] # We want to make sure that an incomplete catalog entry contains an alternate link to the complete entry. - alternate_links = [link for link in links if link['type'] == OPDSFeed.ENTRY_TYPE and link['rel'] == 'alternate'] + alternate_links = [ + link + for link in links + if link["type"] == OPDSFeed.ENTRY_TYPE and link["rel"] == "alternate" + ] assert 1 == len(alternate_links) def test_complete_catalog_entry_with_fulfillment_link_contains_self_link(self): @@ -1291,22 +1308,30 @@ def test_complete_catalog_entry_with_fulfillment_link_contains_self_link(self): ) raw = str(feed_obj) - entries = feedparser.parse(raw)['entries'] + entries = feedparser.parse(raw)["entries"] assert 1 == len(entries) - links = entries[0]['links'] + links = entries[0]["links"] # We want to make sure that a complete catalog entry contains an alternate link # because it's required by some clients (for example, an Android version of SimplyE). - alternate_links = [link for link in links if link['type'] == OPDSFeed.ENTRY_TYPE and link['rel'] == 'alternate'] + alternate_links = [ + link + for link in links + if link["type"] == OPDSFeed.ENTRY_TYPE and link["rel"] == "alternate" + ] assert 1 == len(alternate_links) # We want to make sure that the complete catalog entry contains a self link. - self_links = [link for link in links if link['type'] == OPDSFeed.ENTRY_TYPE and link['rel'] == 'self'] + self_links = [ + link + for link in links + if link["type"] == OPDSFeed.ENTRY_TYPE and link["rel"] == "self" + ] assert 1 == len(self_links) # We want to make sure that alternate and self links are the same. - assert alternate_links[0]['href'] == self_links[0]['href'] + assert alternate_links[0]["href"] == self_links[0]["href"] def test_complete_catalog_entry_with_fulfillment_info_contains_self_link(self): patron = self._patron() @@ -1323,7 +1348,7 @@ def test_complete_catalog_entry_with_fulfillment_info_contains_self_link(self): "http://link", Representation.EPUB_MEDIA_TYPE, None, - None + None, ) feed_obj = LibraryLoanAndHoldAnnotator.single_item_feed( @@ -1331,22 +1356,30 @@ def test_complete_catalog_entry_with_fulfillment_info_contains_self_link(self): ) raw = str(feed_obj) - entries = feedparser.parse(raw)['entries'] + entries = feedparser.parse(raw)["entries"] assert 1 == len(entries) - links = entries[0]['links'] + links = entries[0]["links"] # We want to make sure that a complete catalog entry contains an alternate link # because it's required by some clients (for example, an Android version of SimplyE). - alternate_links = [link for link in links if link['type'] == OPDSFeed.ENTRY_TYPE and link['rel'] == 'alternate'] + alternate_links = [ + link + for link in links + if link["type"] == OPDSFeed.ENTRY_TYPE and link["rel"] == "alternate" + ] assert 1 == len(alternate_links) # We want to make sure that the complete catalog entry contains a self link. - self_links = [link for link in links if link['type'] == OPDSFeed.ENTRY_TYPE and link['rel'] == 'self'] + self_links = [ + link + for link in links + if link["type"] == OPDSFeed.ENTRY_TYPE and link["rel"] == "self" + ] assert 1 == len(self_links) # We want to make sure that alternate and self links are the same. - assert alternate_links[0]['href'] == self_links[0]['href'] + assert alternate_links[0]["href"] == self_links[0]["href"] def test_fulfill_feed(self): patron = self._patron() @@ -1355,36 +1388,46 @@ def test_fulfill_feed(self): pool = work.license_pools[0] pool.open_access = False streaming_mech = pool.set_delivery_mechanism( - DeliveryMechanism.STREAMING_TEXT_CONTENT_TYPE, DeliveryMechanism.OVERDRIVE_DRM, - RightsStatus.IN_COPYRIGHT, None + DeliveryMechanism.STREAMING_TEXT_CONTENT_TYPE, + DeliveryMechanism.OVERDRIVE_DRM, + RightsStatus.IN_COPYRIGHT, + None, ) now = utc_now() loan, ignore = pool.loan_to(patron, start=now) fulfillment = FulfillmentInfo( - pool.collection, pool.data_source.name, - pool.identifier.type, pool.identifier.identifier, + pool.collection, + pool.data_source.name, + pool.identifier.type, + pool.identifier.identifier, "http://streaming_link", Representation.TEXT_HTML_MEDIA_TYPE + DeliveryMechanism.STREAMING_PROFILE, - None, None) + None, + None, + ) response = LibraryLoanAndHoldAnnotator.single_item_feed( None, loan, fulfillment, test_mode=True ) raw = response.get_data(as_text=True) - entries = feedparser.parse(raw)['entries'] + entries = feedparser.parse(raw)["entries"] assert 1 == len(entries) - links = entries[0]['links'] + links = entries[0]["links"] # The feed for a single fulfillment only includes one fulfill link. - fulfill_links = [link for link in links if link['rel'] == "http://opds-spec.org/acquisition"] + fulfill_links = [ + link for link in links if link["rel"] == "http://opds-spec.org/acquisition" + ] assert 1 == len(fulfill_links) - assert (Representation.TEXT_HTML_MEDIA_TYPE + DeliveryMechanism.STREAMING_PROFILE == - fulfill_links[0]['type']) - assert "http://streaming_link" == fulfill_links[0]['href'] + assert ( + Representation.TEXT_HTML_MEDIA_TYPE + DeliveryMechanism.STREAMING_PROFILE + == fulfill_links[0]["type"] + ) + assert "http://streaming_link" == fulfill_links[0]["href"] def test_drm_device_registration_feed_tags(self): """Check that drm_device_registration_feed_tags returns @@ -1392,15 +1435,18 @@ def test_drm_device_registration_feed_tags(self): set. """ self.initialize_adobe(self._default_library) - annotator = LibraryLoanAndHoldAnnotator(None, None, self._default_library, test_mode=True) + annotator = LibraryLoanAndHoldAnnotator( + None, None, self._default_library, test_mode=True + ) patron = self._patron() [feed_tag] = annotator.drm_device_registration_feed_tags(patron) [generic_tag] = annotator.adobe_id_tags(patron) # The feed-level tag has the drm:scheme attribute set. - key = '{http://librarysimplified.org/terms/drm}scheme' - assert ("http://librarysimplified.org/terms/drm/scheme/ACS" == - feed_tag.attrib[key]) + key = "{http://librarysimplified.org/terms/drm}scheme" + assert ( + "http://librarysimplified.org/terms/drm/scheme/ACS" == feed_tag.attrib[key] + ) # If we remove that attribute, the feed-level tag is the same as the # generic tag. @@ -1410,35 +1456,37 @@ def test_drm_device_registration_feed_tags(self): def test_borrow_link_raises_unfulfillable_work(self): edition, pool = self._edition(with_license_pool=True) kindle_mechanism = pool.set_delivery_mechanism( - DeliveryMechanism.KINDLE_CONTENT_TYPE, DeliveryMechanism.KINDLE_DRM, - RightsStatus.IN_COPYRIGHT, None) + DeliveryMechanism.KINDLE_CONTENT_TYPE, + DeliveryMechanism.KINDLE_DRM, + RightsStatus.IN_COPYRIGHT, + None, + ) epub_mechanism = pool.set_delivery_mechanism( - Representation.EPUB_MEDIA_TYPE, DeliveryMechanism.ADOBE_DRM, - RightsStatus.IN_COPYRIGHT, None) + Representation.EPUB_MEDIA_TYPE, + DeliveryMechanism.ADOBE_DRM, + RightsStatus.IN_COPYRIGHT, + None, + ) data_source_name = pool.data_source.name identifier = pool.identifier - annotator = LibraryLoanAndHoldAnnotator(None, None, self._default_library, test_mode=True) + annotator = LibraryLoanAndHoldAnnotator( + None, None, self._default_library, test_mode=True + ) # If there's no way to fulfill the book, borrow_link raises # UnfulfillableWork. - pytest.raises( - UnfulfillableWork, - annotator.borrow_link, - pool, None, []) + pytest.raises(UnfulfillableWork, annotator.borrow_link, pool, None, []) pytest.raises( - UnfulfillableWork, - annotator.borrow_link, - pool, None, [kindle_mechanism]) + UnfulfillableWork, annotator.borrow_link, pool, None, [kindle_mechanism] + ) # If there's a fulfillable mechanism, everything's fine. link = annotator.borrow_link(pool, None, [epub_mechanism]) assert link != None - link = annotator.borrow_link( - pool, None, [epub_mechanism, kindle_mechanism] - ) + link = annotator.borrow_link(pool, None, [epub_mechanism, kindle_mechanism]) assert link != None def test_feed_includes_lane_links(self): @@ -1449,12 +1497,12 @@ def annotated_links(lane, annotator): feed = AcquisitionFeed(self._db, "test", "url", [], annotator) annotator.annotate_feed(feed, lane) raw = str(feed) - parsed = feedparser.parse(raw)['feed'] - links = parsed['links'] + parsed = feedparser.parse(raw)["feed"] + links = parsed["links"] d = defaultdict(list) for link in links: - d[link['rel'].lower()].append(link['href']) + d[link["rel"].lower()].append(link["href"]) return d # When an EntryPoint is explicitly selected, it shows up in the @@ -1464,17 +1512,17 @@ def annotated_links(lane, annotator): annotator = LibraryAnnotator( None, lane, self._default_library, test_mode=True, facets=facets ) - [url] = annotated_links(lane, annotator)['search'] - assert '/lane_search' in url - assert 'entrypoint=%s' % AudiobooksEntryPoint.INTERNAL_NAME in url + [url] = annotated_links(lane, annotator)["search"] + assert "/lane_search" in url + assert "entrypoint=%s" % AudiobooksEntryPoint.INTERNAL_NAME in url assert str(lane.id) in url # When the selected EntryPoint is a default, it's not used -- # instead, we search everything. annotator.facets.entrypoint_is_default = True links = annotated_links(lane, annotator) - [url] = links['search'] - assert 'entrypoint=%s' % EverythingEntryPoint.INTERNAL_NAME in url + [url] = links["search"] + assert "entrypoint=%s" % EverythingEntryPoint.INTERNAL_NAME in url # This lane isn't based on a custom list, so there's no crawlable link. assert [] == links["http://opds-spec.org/crawlable"] @@ -1489,15 +1537,15 @@ def annotated_links(lane, annotator): # A lane based on a single list gets a crawlable link. lane.customlists = [list1] links = annotated_links(lane, annotator) - [crawlable] = links['http://opds-spec.org/crawlable'] - assert '/crawlable_list_feed' in crawlable + [crawlable] = links["http://opds-spec.org/crawlable"] + assert "/crawlable_list_feed" in crawlable assert str(list1.name) in crawlable def test_acquisition_links(self): - annotator = LibraryLoanAndHoldAnnotator(None, None, self._default_library, test_mode=True) - feed = AcquisitionFeed( - self._db, "test", "url", [], annotator + annotator = LibraryLoanAndHoldAnnotator( + None, None, self._default_library, test_mode=True ) + feed = AcquisitionFeed(self._db, "test", "url", [], annotator) patron = self._patron() @@ -1514,39 +1562,47 @@ def test_acquisition_links(self): # Hold on a licensed book. work3 = self._work(with_license_pool=True) - hold, ignore = work3.license_pools[0].on_hold_to(patron, start=now, end=tomorrow) + hold, ignore = work3.license_pools[0].on_hold_to( + patron, start=now, end=tomorrow + ) # Book with no loans or holds yet. work4 = self._work(with_license_pool=True) loan1_links = annotator.acquisition_links( - loan1.license_pool, loan1, None, None, feed, loan1.license_pool.identifier) + loan1.license_pool, loan1, None, None, feed, loan1.license_pool.identifier + ) # Fulfill, open access, and revoke. - [revoke, fulfill, open_access] = sorted(loan1_links, key=lambda x: x.attrib.get("rel")) - assert 'revoke_loan_or_hold' in revoke.attrib.get("href") - assert 'http://librarysimplified.org/terms/rel/revoke' == revoke.attrib.get("rel") + [revoke, fulfill, open_access] = sorted( + loan1_links, key=lambda x: x.attrib.get("rel") + ) + assert "revoke_loan_or_hold" in revoke.attrib.get("href") + assert "http://librarysimplified.org/terms/rel/revoke" == revoke.attrib.get( + "rel" + ) assert "fulfill" in fulfill.attrib.get("href") - assert 'http://opds-spec.org/acquisition' == fulfill.attrib.get("rel") - assert 'fulfill' in open_access.attrib.get("href") - assert 'http://opds-spec.org/acquisition/open-access' == open_access.attrib.get("rel") + assert "http://opds-spec.org/acquisition" == fulfill.attrib.get("rel") + assert "fulfill" in open_access.attrib.get("href") + assert "http://opds-spec.org/acquisition/open-access" == open_access.attrib.get( + "rel" + ) loan2_links = annotator.acquisition_links( - loan2.license_pool, loan2, None, None, feed, - loan2.license_pool.identifier + loan2.license_pool, loan2, None, None, feed, loan2.license_pool.identifier ) # Fulfill and revoke. [revoke, fulfill] = sorted(loan2_links, key=lambda x: x.attrib.get("rel")) - assert 'revoke_loan_or_hold' in revoke.attrib.get("href") - assert 'http://librarysimplified.org/terms/rel/revoke' == revoke.attrib.get("rel") + assert "revoke_loan_or_hold" in revoke.attrib.get("href") + assert "http://librarysimplified.org/terms/rel/revoke" == revoke.attrib.get( + "rel" + ) assert "fulfill" in fulfill.attrib.get("href") - assert 'http://opds-spec.org/acquisition' == fulfill.attrib.get("rel") + assert "http://opds-spec.org/acquisition" == fulfill.attrib.get("rel") # If a book is ready to be fulfilled, but the library has # hidden all of its available content types, the fulfill link does # not show up -- only the revoke link. - hidden = self._default_library.setting( - Configuration.HIDDEN_CONTENT_TYPES - ) + hidden = self._default_library.setting(Configuration.HIDDEN_CONTENT_TYPES) available_types = [ lpdm.delivery_mechanism.content_type for lpdm in loan2.license_pool.delivery_mechanisms @@ -1559,38 +1615,50 @@ def test_acquisition_links(self): None, None, self._default_library, test_mode=True ) loan2_links = annotator_with_hidden_types.acquisition_links( - loan2.license_pool, loan2, None, None, feed, - loan2.license_pool.identifier + loan2.license_pool, loan2, None, None, feed, loan2.license_pool.identifier ) [revoke] = loan2_links - assert ('http://librarysimplified.org/terms/rel/revoke' == - revoke.attrib.get("rel")) + assert "http://librarysimplified.org/terms/rel/revoke" == revoke.attrib.get( + "rel" + ) # Un-hide the content types so the test can continue. hidden.value = None hold_links = annotator.acquisition_links( - hold.license_pool, None, hold, None, feed, hold.license_pool.identifier) + hold.license_pool, None, hold, None, feed, hold.license_pool.identifier + ) # Borrow and revoke. [revoke, borrow] = sorted(hold_links, key=lambda x: x.attrib.get("rel")) - assert 'revoke_loan_or_hold' in revoke.attrib.get("href") - assert 'http://librarysimplified.org/terms/rel/revoke' == revoke.attrib.get("rel") + assert "revoke_loan_or_hold" in revoke.attrib.get("href") + assert "http://librarysimplified.org/terms/rel/revoke" == revoke.attrib.get( + "rel" + ) assert "borrow" in borrow.attrib.get("href") - assert 'http://opds-spec.org/acquisition/borrow' == borrow.attrib.get("rel") + assert "http://opds-spec.org/acquisition/borrow" == borrow.attrib.get("rel") work4_links = annotator.acquisition_links( - work4.license_pools[0], None, None, None, feed, work4.license_pools[0].identifier) + work4.license_pools[0], + None, + None, + None, + feed, + work4.license_pools[0].identifier, + ) # Borrow only. [borrow] = work4_links assert "borrow" in borrow.attrib.get("href") - assert 'http://opds-spec.org/acquisition/borrow' == borrow.attrib.get("rel") + assert "http://opds-spec.org/acquisition/borrow" == borrow.attrib.get("rel") # If patron authentication is turned off for the library, then # only open-access links are displayed. annotator.identifies_patrons = False [open_access] = annotator.acquisition_links( - loan1.license_pool, loan1, None, None, feed, loan1.license_pool.identifier) - assert 'http://opds-spec.org/acquisition/open-access' == open_access.attrib.get("rel") + loan1.license_pool, loan1, None, None, feed, loan1.license_pool.identifier + ) + assert "http://opds-spec.org/acquisition/open-access" == open_access.attrib.get( + "rel" + ) # This may include links with the open-access relation for # non-open-access works that are available without @@ -1601,40 +1669,51 @@ def test_acquisition_links(self): [lpdm4] = lp4.delivery_mechanisms lpdm4.set_rights_status(RightsStatus.IN_COPYRIGHT) [not_open_access] = annotator.acquisition_links( - lp4, None, None, None, feed, lp4.identifier, - direct_fulfillment_delivery_mechanisms=[lpdm4] + lp4, + None, + None, + None, + feed, + lp4.identifier, + direct_fulfillment_delivery_mechanisms=[lpdm4], ) # The link relation is OPDS 'open-access', which just means the # book can be downloaded with no hassle. - assert 'http://opds-spec.org/acquisition/open-access' == not_open_access.attrib.get("rel") + assert ( + "http://opds-spec.org/acquisition/open-access" + == not_open_access.attrib.get("rel") + ) # The dcterms:rights attribute provides a more detailed # explanation of the book's copyright status -- note that it's # not "open access" in the typical sense. - rights = not_open_access.attrib['{http://purl.org/dc/terms/}rights'] + rights = not_open_access.attrib["{http://purl.org/dc/terms/}rights"] assert RightsStatus.IN_COPYRIGHT == rights # Hold links are absent even when there are active holds in the # database -- there is no way to distinguish one patron from # another so the concept of a 'hold' is meaningless. hold_links = annotator.acquisition_links( - hold.license_pool, None, hold, None, feed, hold.license_pool.identifier) + hold.license_pool, None, hold, None, feed, hold.license_pool.identifier + ) assert [] == hold_links def test_acquisition_links_multiple_links(self): - annotator = LibraryLoanAndHoldAnnotator(None, None, self._default_library, test_mode=True) - feed = AcquisitionFeed( - self._db, "test", "url", [], annotator + annotator = LibraryLoanAndHoldAnnotator( + None, None, self._default_library, test_mode=True ) + feed = AcquisitionFeed(self._db, "test", "url", [], annotator) # This book has two delivery mechanisms work = self._work(with_license_pool=True) [pool] = work.license_pools [mech1] = pool.delivery_mechanisms mech2 = pool.set_delivery_mechanism( - Representation.PDF_MEDIA_TYPE, DeliveryMechanism.NO_DRM, - RightsStatus.IN_COPYRIGHT, None + Representation.PDF_MEDIA_TYPE, + DeliveryMechanism.NO_DRM, + RightsStatus.IN_COPYRIGHT, + None, ) # The vendor API for LicensePools of this type requires that a @@ -1645,8 +1724,7 @@ class MockAPI(object): # This means that two different acquisition links will be # generated -- one for each delivery mechanism. links = annotator.acquisition_links( - pool, None, None, None, feed, pool.identifier, - mock_api=MockAPI() + pool, None, None, None, feed, pool.identifier, mock_api=MockAPI() ) assert 2 == len(links) @@ -1656,7 +1734,7 @@ class MockAPI(object): # Instead of sorting, which may be wrong if the id is greater than 10 # due to how double digits are sorted, extract the links associated # with the expected delivery mechanism. - if mech1_param in links[0].attrib['href']: + if mech1_param in links[0].attrib["href"]: [mech1_link, mech2_link] = links else: [mech2_link, mech1_link] = links @@ -1664,18 +1742,16 @@ class MockAPI(object): indirects = [] for link in [mech1_link, mech2_link]: # Both links should have the same subtags. - [availability, copies, holds, indirect] = sorted( - link, key=lambda x: x.tag - ) - assert availability.tag.endswith('availability') - assert copies.tag.endswith('copies') - assert holds.tag.endswith('holds') - assert indirect.tag.endswith('indirectAcquisition') + [availability, copies, holds, indirect] = sorted(link, key=lambda x: x.tag) + assert availability.tag.endswith("availability") + assert copies.tag.endswith("copies") + assert holds.tag.endswith("holds") + assert indirect.tag.endswith("indirectAcquisition") indirects.append(indirect) # The target of the top-level link is different. - assert mech1_param in mech1_link.attrib['href'] - assert mech2_param in mech2_link.attrib['href'] + assert mech1_param in mech1_link.attrib["href"] + assert mech2_param in mech2_link.attrib["href"] # So is the media type seen in the indirectAcquisition subtag. [mech1_indirect, mech2_indirect] = indirects @@ -1683,11 +1759,11 @@ class MockAPI(object): # The first delivery mechanism (created when the Work was created) # uses Adobe DRM, so that shows up as the first indirect acquisition # type. - assert mech1.delivery_mechanism.drm_scheme == mech1_indirect.attrib['type'] + assert mech1.delivery_mechanism.drm_scheme == mech1_indirect.attrib["type"] # The second delivery mechanism doesn't use DRM, so the content # type shows up as the first (and only) indirect acquisition type. - assert mech2.delivery_mechanism.content_type == mech2_indirect.attrib['type'] + assert mech2.delivery_mechanism.content_type == mech2_indirect.attrib["type"] # If we configure the library to hide one of the content types, # we end up with only one link -- the one for the delivery @@ -1699,22 +1775,17 @@ class MockAPI(object): None, None, self._default_library, test_mode=True ) [link] = annotator.acquisition_links( - pool, None, None, None, feed, pool.identifier, - mock_api=MockAPI() + pool, None, None, None, feed, pool.identifier, mock_api=MockAPI() ) - [availability, copies, holds, indirect] = sorted( - link, key=lambda x: x.tag - ) - assert mech2.delivery_mechanism.content_type == indirect.attrib['type'] + [availability, copies, holds, indirect] = sorted(link, key=lambda x: x.tag) + assert mech2.delivery_mechanism.content_type == indirect.attrib["type"] class TestLibraryLoanAndHoldAnnotator(DatabaseTest): - def test_single_item_feed(self): # Test the generation of single-item OPDS feeds for loans (with and # without fulfillment) and holds. class MockAnnotator(LibraryLoanAndHoldAnnotator): - def url_for(self, controller, **kwargs): self.url_for_called_with = (controller, kwargs) return "a URL" @@ -1731,8 +1802,7 @@ def test_annotator(item, fulfillment=None): test_mode = object() feed_class = object() result = MockAnnotator.single_item_feed( - circulation, item, fulfillment, test_mode, feed_class, - extra_arg="value" + circulation, item, fulfillment, test_mode, feed_class, extra_arg="value" ) # The final result is a MockAnnotator object. This isn't @@ -1754,20 +1824,17 @@ def test_annotator(item, fulfillment=None): url_call = result.url_for_called_with controller_name, kwargs = url_call assert "loan_or_hold_detail" == controller_name - assert (self._default_library.short_name == - kwargs.pop('library_short_name')) - assert pool.identifier.type == kwargs.pop('identifier_type') - assert pool.identifier.identifier == kwargs.pop('identifier') - assert True == kwargs.pop('_external') + assert self._default_library.short_name == kwargs.pop("library_short_name") + assert pool.identifier.type == kwargs.pop("identifier_type") + assert pool.identifier.identifier == kwargs.pop("identifier") + assert True == kwargs.pop("_external") assert {} == kwargs # The return value of that was the string "a URL". We then # passed that into _single_entry_response, along with # `item` and a number of arguments that we made up. response_call = result._single_entry_response_called_with - (_db, _work, annotator, url, _feed_class), kwargs = ( - response_call - ) + (_db, _work, annotator, url, _feed_class), kwargs = response_call assert self._db == _db assert work == _work assert result == annotator @@ -1776,7 +1843,7 @@ def test_annotator(item, fulfillment=None): # The only keyword argument is an extra argument propagated from # the single_item_feed call. - assert 'value' == kwargs.pop('extra_arg') + assert "value" == kwargs.pop("extra_arg") # Return the MockAnnotator for further examination. return result @@ -1794,7 +1861,7 @@ def test_annotator(item, fulfillment=None): # Everything tested by test_annotator happened, but _also_, # when the annotator was created, the Loan was stored in # active_loans_by_work. - assert {work : loan} == annotator.active_loans_by_work + assert {work: loan} == annotator.active_loans_by_work # Since we passed in a loan rather than a hold, # active_holds_by_work is empty. @@ -1807,13 +1874,13 @@ def test_annotator(item, fulfillment=None): # Now try it again, but give the loan a fulfillment. fulfillment = object() annotator = test_annotator(loan, fulfillment) - assert {work : loan} == annotator.active_loans_by_work - assert {work : fulfillment} == annotator.active_fulfillments_by_work + assert {work: loan} == annotator.active_loans_by_work + assert {work: fulfillment} == annotator.active_fulfillments_by_work # Finally, try it with a hold. hold, ignore = pool.on_hold_to(patron) annotator = test_annotator(hold) - assert {work : hold} == annotator.active_holds_by_work + assert {work: hold} == annotator.active_holds_by_work assert {} == annotator.active_loans_by_work assert {} == annotator.active_fulfillments_by_work @@ -1825,7 +1892,9 @@ def setup_method(self): self.collection = self._collection() self.lane = self._lane(display_name="Fantasy") self.annotator = SharedCollectionAnnotator( - self.collection, self.lane, test_mode=True, + self.collection, + self.lane, + test_mode=True, ) def test_top_level_title(self): @@ -1844,21 +1913,25 @@ def get_parsed_feed(self, works, lane=None): if not lane: lane = self._lane(display_name="Main Lane") feed = AcquisitionFeed( - self._db, "test", "url", works, - SharedCollectionAnnotator(self.collection, lane, test_mode=True) + self._db, + "test", + "url", + works, + SharedCollectionAnnotator(self.collection, lane, test_mode=True), ) return feedparser.parse(str(feed)) - def assert_link_on_entry(self, entry, link_type=None, rels=None, - partials_by_rel=None + def assert_link_on_entry( + self, entry, link_type=None, rels=None, partials_by_rel=None ): """Asserts that a link with a certain 'rel' value exists on a given feed or entry, as well as its link 'type' value and parts of its 'href' value. """ + def get_link_by_rel(rel, should_exist=True): try: - [link] = [x for x in entry['links'] if x['rel']==rel] + [link] = [x for x in entry["links"] if x["rel"] == rel] except ValueError as e: raise AssertionError if link_type: @@ -1883,7 +1956,7 @@ def test_work_entry_includes_updated(self): feed = self.get_parsed_feed([work]) [entry] = feed.entries - assert '2018-02-04' in entry.get("updated") + assert "2018-02-04" in entry.get("updated") def test_work_entry_includes_open_access_or_borrow_link(self): open_access_work = self._work(with_open_access_download=True) @@ -1893,55 +1966,65 @@ def test_work_entry_includes_open_access_or_borrow_link(self): feed = self.get_parsed_feed([open_access_work, licensed_work]) [open_access_entry, licensed_entry] = feed.entries - self.assert_link_on_entry(open_access_entry, rels=[Hyperlink.OPEN_ACCESS_DOWNLOAD]) + self.assert_link_on_entry( + open_access_entry, rels=[Hyperlink.OPEN_ACCESS_DOWNLOAD] + ) self.assert_link_on_entry(licensed_entry, rels=[OPDSFeed.BORROW_REL]) # The open access entry shouldn't have a borrow link, and the licensed entry # shouldn't have an open access link. - links = [x for x in open_access_entry['links'] if x['rel']==OPDSFeed.BORROW_REL] + links = [ + x for x in open_access_entry["links"] if x["rel"] == OPDSFeed.BORROW_REL + ] assert 0 == len(links) - links = [x for x in licensed_entry['links'] if x['rel']==Hyperlink.OPEN_ACCESS_DOWNLOAD] + links = [ + x + for x in licensed_entry["links"] + if x["rel"] == Hyperlink.OPEN_ACCESS_DOWNLOAD + ] assert 0 == len(links) def test_borrow_link_raises_unfulfillable_work(self): edition, pool = self._edition(with_license_pool=True) kindle_mechanism = pool.set_delivery_mechanism( - DeliveryMechanism.KINDLE_CONTENT_TYPE, DeliveryMechanism.KINDLE_DRM, - RightsStatus.IN_COPYRIGHT, None) + DeliveryMechanism.KINDLE_CONTENT_TYPE, + DeliveryMechanism.KINDLE_DRM, + RightsStatus.IN_COPYRIGHT, + None, + ) epub_mechanism = pool.set_delivery_mechanism( - Representation.EPUB_MEDIA_TYPE, DeliveryMechanism.ADOBE_DRM, - RightsStatus.IN_COPYRIGHT, None) + Representation.EPUB_MEDIA_TYPE, + DeliveryMechanism.ADOBE_DRM, + RightsStatus.IN_COPYRIGHT, + None, + ) data_source_name = pool.data_source.name identifier = pool.identifier - annotator = SharedCollectionLoanAndHoldAnnotator(self.collection, None, test_mode=True) + annotator = SharedCollectionLoanAndHoldAnnotator( + self.collection, None, test_mode=True + ) # If there's no way to fulfill the book, borrow_link raises # UnfulfillableWork. - pytest.raises( - UnfulfillableWork, - annotator.borrow_link, - pool, None, []) + pytest.raises(UnfulfillableWork, annotator.borrow_link, pool, None, []) pytest.raises( - UnfulfillableWork, - annotator.borrow_link, - pool, None, [kindle_mechanism]) + UnfulfillableWork, annotator.borrow_link, pool, None, [kindle_mechanism] + ) # If there's a fulfillable mechanism, everything's fine. link = annotator.borrow_link(pool, None, [epub_mechanism]) assert link != None - link = annotator.borrow_link( - pool, None, [epub_mechanism, kindle_mechanism] - ) + link = annotator.borrow_link(pool, None, [epub_mechanism, kindle_mechanism]) assert link != None def test_acquisition_links(self): - annotator = SharedCollectionLoanAndHoldAnnotator(self.collection, None, test_mode=True) - feed = AcquisitionFeed( - self._db, "test", "url", [], annotator + annotator = SharedCollectionLoanAndHoldAnnotator( + self.collection, None, test_mode=True ) + feed = AcquisitionFeed(self._db, "test", "url", [], annotator) client = self._integration_client() @@ -1958,58 +2041,80 @@ def test_acquisition_links(self): # Hold on a licensed book. work3 = self._work(with_license_pool=True) - hold, ignore = work3.license_pools[0].on_hold_to(client, start=now, end=tomorrow) + hold, ignore = work3.license_pools[0].on_hold_to( + client, start=now, end=tomorrow + ) # Book with no loans or holds yet. work4 = self._work(with_license_pool=True) loan1_links = annotator.acquisition_links( - loan1.license_pool, loan1, None, None, feed, loan1.license_pool.identifier) + loan1.license_pool, loan1, None, None, feed, loan1.license_pool.identifier + ) # Fulfill, open access, revoke, and loan info. - [revoke, fulfill, open_access, info] = sorted(loan1_links, key=lambda x: x.attrib.get("rel")) - assert 'shared_collection_revoke_loan' in revoke.attrib.get("href") - assert 'http://librarysimplified.org/terms/rel/revoke' == revoke.attrib.get("rel") + [revoke, fulfill, open_access, info] = sorted( + loan1_links, key=lambda x: x.attrib.get("rel") + ) + assert "shared_collection_revoke_loan" in revoke.attrib.get("href") + assert "http://librarysimplified.org/terms/rel/revoke" == revoke.attrib.get( + "rel" + ) assert "shared_collection_fulfill" in fulfill.attrib.get("href") - assert 'http://opds-spec.org/acquisition' == fulfill.attrib.get("rel") - assert work1.license_pools[0].delivery_mechanisms[0].resource.representation.mirror_url == open_access.attrib.get("href") - assert 'http://opds-spec.org/acquisition/open-access' == open_access.attrib.get("rel") - assert 'shared_collection_loan_info' in info.attrib.get("href") + assert "http://opds-spec.org/acquisition" == fulfill.attrib.get("rel") + assert work1.license_pools[0].delivery_mechanisms[ + 0 + ].resource.representation.mirror_url == open_access.attrib.get("href") + assert "http://opds-spec.org/acquisition/open-access" == open_access.attrib.get( + "rel" + ) + assert "shared_collection_loan_info" in info.attrib.get("href") assert "self" == info.attrib.get("rel") loan2_links = annotator.acquisition_links( - loan2.license_pool, loan2, None, None, feed, loan2.license_pool.identifier) + loan2.license_pool, loan2, None, None, feed, loan2.license_pool.identifier + ) # Fulfill, revoke, and loan info. [revoke, fulfill, info] = sorted(loan2_links, key=lambda x: x.attrib.get("rel")) - assert 'shared_collection_revoke_loan' in revoke.attrib.get("href") - assert 'http://librarysimplified.org/terms/rel/revoke' == revoke.attrib.get("rel") + assert "shared_collection_revoke_loan" in revoke.attrib.get("href") + assert "http://librarysimplified.org/terms/rel/revoke" == revoke.attrib.get( + "rel" + ) assert "shared_collection_fulfill" in fulfill.attrib.get("href") - assert 'http://opds-spec.org/acquisition' == fulfill.attrib.get("rel") - assert 'shared_collection_loan_info' in info.attrib.get("href") + assert "http://opds-spec.org/acquisition" == fulfill.attrib.get("rel") + assert "shared_collection_loan_info" in info.attrib.get("href") assert "self" == info.attrib.get("rel") hold_links = annotator.acquisition_links( - hold.license_pool, None, hold, None, feed, hold.license_pool.identifier) + hold.license_pool, None, hold, None, feed, hold.license_pool.identifier + ) # Borrow, revoke, and hold info. [revoke, borrow, info] = sorted(hold_links, key=lambda x: x.attrib.get("rel")) - assert 'shared_collection_revoke_hold' in revoke.attrib.get("href") - assert 'http://librarysimplified.org/terms/rel/revoke' == revoke.attrib.get("rel") + assert "shared_collection_revoke_hold" in revoke.attrib.get("href") + assert "http://librarysimplified.org/terms/rel/revoke" == revoke.attrib.get( + "rel" + ) assert "shared_collection_borrow" in borrow.attrib.get("href") - assert 'http://opds-spec.org/acquisition/borrow' == borrow.attrib.get("rel") - assert 'shared_collection_hold_info' in info.attrib.get("href") + assert "http://opds-spec.org/acquisition/borrow" == borrow.attrib.get("rel") + assert "shared_collection_hold_info" in info.attrib.get("href") assert "self" == info.attrib.get("rel") work4_links = annotator.acquisition_links( - work4.license_pools[0], None, None, None, feed, work4.license_pools[0].identifier) + work4.license_pools[0], + None, + None, + None, + feed, + work4.license_pools[0].identifier, + ) # Borrow only. [borrow] = work4_links assert "shared_collection_borrow" in borrow.attrib.get("href") - assert 'http://opds-spec.org/acquisition/borrow' == borrow.attrib.get("rel") + assert "http://opds-spec.org/acquisition/borrow" == borrow.attrib.get("rel") def test_single_item_feed(self): # Test the generation of single-item OPDS feeds for loans (with and # without fulfillment) and holds. class MockAnnotator(SharedCollectionLoanAndHoldAnnotator): - def url_for(self, controller, **kwargs): self.url_for_called_with = (controller, kwargs) return "a URL" @@ -2019,15 +2124,18 @@ def _single_entry_response(self, *args, **kwargs): # Return the annotator itself so we can look at it. return self - def test_annotator(item, fulfillment, expect_route, - expect_route_kwargs): + def test_annotator(item, fulfillment, expect_route, expect_route_kwargs): # Call MockAnnotator.single_item_feed with certain arguments # and make some general assertions about the return value. test_mode = object() feed_class = object() result = MockAnnotator.single_item_feed( - self.collection, item, fulfillment, test_mode, feed_class, - extra_arg="value" + self.collection, + item, + fulfillment, + test_mode, + feed_class, + extra_arg="value", ) # The final result is a MockAnnotator object. This isn't @@ -2054,17 +2162,15 @@ def test_annotator(item, fulfillment, expect_route, # Apart from a few keyword arguments that are always the same, # the keyword arguments are the ones we expect. - assert self.collection.name == route_kwargs.pop('collection_name') - assert True == route_kwargs.pop('_external') + assert self.collection.name == route_kwargs.pop("collection_name") + assert True == route_kwargs.pop("_external") assert expect_route_kwargs == route_kwargs # The return value of that was the string "a URL". We then # passed that into _single_entry_response, along with # `item` and a number of arguments that we made up. response_call = result._single_entry_response_called_with - (_db, _work, annotator, url, _feed_class), kwargs = ( - response_call - ) + (_db, _work, annotator, url, _feed_class), kwargs = response_call assert self._db == _db assert work == _work assert result == annotator @@ -2073,7 +2179,7 @@ def test_annotator(item, fulfillment, expect_route, # The only keyword argument is an extra argument propagated from # the single_item_feed call. - assert 'value' == kwargs.pop('extra_arg') + assert "value" == kwargs.pop("extra_arg") # Return the MockAnnotator for further examination. return result @@ -2087,14 +2193,16 @@ def test_annotator(item, fulfillment, expect_route, # First, let's ask for a single-item feed for a loan. annotator = test_annotator( - loan, None, expect_route="shared_collection_loan_info", - expect_route_kwargs=dict(loan_id=loan.id) + loan, + None, + expect_route="shared_collection_loan_info", + expect_route_kwargs=dict(loan_id=loan.id), ) # Everything tested by test_annotator happened, but _also_, # when the annotator was created, the Loan was stored in # active_loans_by_work. - assert {work : loan} == annotator.active_loans_by_work + assert {work: loan} == annotator.active_loans_by_work # Since we passed in a loan rather than a hold, # active_holds_by_work is empty. @@ -2107,18 +2215,22 @@ def test_annotator(item, fulfillment, expect_route, # Now try it again, but give the loan a fulfillment. fulfillment = object() annotator = test_annotator( - loan, fulfillment, expect_route="shared_collection_loan_info", - expect_route_kwargs=dict(loan_id=loan.id) + loan, + fulfillment, + expect_route="shared_collection_loan_info", + expect_route_kwargs=dict(loan_id=loan.id), ) - assert {work : loan} == annotator.active_loans_by_work - assert {work : fulfillment} == annotator.active_fulfillments_by_work + assert {work: loan} == annotator.active_loans_by_work + assert {work: fulfillment} == annotator.active_fulfillments_by_work # Finally, try it with a hold. hold, ignore = pool.on_hold_to(patron) annotator = test_annotator( - hold, None, expect_route="shared_collection_hold_info", - expect_route_kwargs=dict(hold_id=hold.id) + hold, + None, + expect_route="shared_collection_hold_info", + expect_route_kwargs=dict(hold_id=hold.id), ) - assert {work : hold} == annotator.active_holds_by_work + assert {work: hold} == annotator.active_holds_by_work assert {} == annotator.active_loans_by_work assert {} == annotator.active_fulfillments_by_work diff --git a/tests/test_opds_for_distributors.py b/tests/test_opds_for_distributors.py index 15827395bd..c78ea83e3b 100644 --- a/tests/test_opds_for_distributors.py +++ b/tests/test_opds_for_distributors.py @@ -1,21 +1,17 @@ -import pytest import datetime -import os import json +import os + +import pytest +from api.circulation_exceptions import * from api.opds_for_distributors import ( + MockOPDSForDistributorsAPI, OPDSForDistributorsAPI, OPDSForDistributorsImporter, OPDSForDistributorsReaperMonitor, - MockOPDSForDistributorsAPI, -) -from api.circulation_exceptions import * -from core.testing import DatabaseTest -from core.metadata_layer import ( - CirculationData, - LinkData, - TimestampData, ) +from core.metadata_layer import CirculationData, LinkData, TimestampData from core.model import ( Collection, Credential, @@ -29,6 +25,7 @@ Representation, RightsStatus, ) +from core.testing import DatabaseTest from core.util.datetime_helpers import utc_now from core.util.opds_writer import OPDSFeed @@ -42,21 +39,23 @@ def get_data(cls, filename): path = os.path.join(cls.resource_path, filename) return open(path).read() -class TestOPDSForDistributorsAPI(DatabaseTest): +class TestOPDSForDistributorsAPI(DatabaseTest): def setup_method(self): super(TestOPDSForDistributorsAPI, self).setup_method() self.collection = MockOPDSForDistributorsAPI.mock_collection(self._db) self.api = MockOPDSForDistributorsAPI(self._db, self.collection) def test_external_integration(self): - assert (self.collection.external_integration == - self.api.external_integration(self._db)) + assert self.collection.external_integration == self.api.external_integration( + self._db + ) def test__run_self_tests(self): """The self-test for OPDSForDistributorsAPI just tries to negotiate a fulfillment token. """ + class Mock(OPDSForDistributorsAPI): def __init__(self): pass @@ -126,15 +125,21 @@ def test_get_token_success(self): # document and authenticate url. feed = '' self.api.queue_response(200, content=feed) - auth_doc = json.dumps({ - "authentication": [{ - "type": "http://opds-spec.org/auth/oauth/client_credentials", - "links": [{ - "rel": "authenticate", - "href": "http://authenticate", - }] - }] - }) + auth_doc = json.dumps( + { + "authentication": [ + { + "type": "http://opds-spec.org/auth/oauth/client_credentials", + "links": [ + { + "rel": "authenticate", + "href": "http://authenticate", + } + ], + } + ] + } + ) self.api.queue_response(200, content=auth_doc) token = self._str token_response = json.dumps({"access_token": token, "expires_in": 60}) @@ -160,15 +165,21 @@ def test_get_token_success(self): self.api = MockOPDSForDistributorsAPI(self._db, self.collection) # This feed requires authentication and returns the auth document. - auth_doc = json.dumps({ - "authentication": [{ - "type": "http://opds-spec.org/auth/oauth/client_credentials", - "links": [{ - "rel": "authenticate", - "href": "http://authenticate", - }] - }] - }) + auth_doc = json.dumps( + { + "authentication": [ + { + "type": "http://opds-spec.org/auth/oauth/client_credentials", + "links": [ + { + "rel": "authenticate", + "href": "http://authenticate", + } + ], + } + ] + } + ) self.api.queue_response(401, content=auth_doc) token = self._str token_response = json.dumps({"access_token": token, "expires_in": 60}) @@ -177,49 +188,67 @@ def test_get_token_success(self): assert token == self.api._get_token(self._db).credential def test_get_token_errors(self): - no_auth_document = '' + no_auth_document = "" self.api.queue_response(200, content=no_auth_document) with pytest.raises(LibraryAuthorizationFailedException) as excinfo: self.api._get_token(self._db) - assert "No authentication document link found in http://opds" in str(excinfo.value) + assert "No authentication document link found in http://opds" in str( + excinfo.value + ) feed = '' self.api.queue_response(200, content=feed) - auth_doc_without_client_credentials = json.dumps({ - "authentication": [] - }) + auth_doc_without_client_credentials = json.dumps({"authentication": []}) self.api.queue_response(200, content=auth_doc_without_client_credentials) with pytest.raises(LibraryAuthorizationFailedException) as excinfo: self.api._get_token(self._db) - assert "Could not find any credential-based authentication mechanisms in http://authdoc" in str(excinfo.value) + assert ( + "Could not find any credential-based authentication mechanisms in http://authdoc" + in str(excinfo.value) + ) self.api.queue_response(200, content=feed) - auth_doc_without_links = json.dumps({ - "authentication": [{ - "type": "http://opds-spec.org/auth/oauth/client_credentials", - }] - }) + auth_doc_without_links = json.dumps( + { + "authentication": [ + { + "type": "http://opds-spec.org/auth/oauth/client_credentials", + } + ] + } + ) self.api.queue_response(200, content=auth_doc_without_links) with pytest.raises(LibraryAuthorizationFailedException) as excinfo: self.api._get_token(self._db) - assert "Could not find any authentication links in http://authdoc" in str(excinfo.value) + assert "Could not find any authentication links in http://authdoc" in str( + excinfo.value + ) self.api.queue_response(200, content=feed) - auth_doc = json.dumps({ - "authentication": [{ - "type": "http://opds-spec.org/auth/oauth/client_credentials", - "links": [{ - "rel": "authenticate", - "href": "http://authenticate", - }] - }] - }) + auth_doc = json.dumps( + { + "authentication": [ + { + "type": "http://opds-spec.org/auth/oauth/client_credentials", + "links": [ + { + "rel": "authenticate", + "href": "http://authenticate", + } + ], + } + ] + } + ) self.api.queue_response(200, content=auth_doc) token_response = json.dumps({"error": "unexpected error"}) self.api.queue_response(200, content=token_response) with pytest.raises(LibraryAuthorizationFailedException) as excinfo: self.api._get_token(self._db) - assert 'Document retrieved from http://authenticate is not a bearer token: {"error": "unexpected error"}' in str(excinfo.value) + assert ( + 'Document retrieved from http://authenticate is not a bearer token: {"error": "unexpected error"}' + in str(excinfo.value) + ) def test_checkin(self): # The patron has two loans, one from this API's collection and @@ -265,7 +294,9 @@ def test_checkout(self): collection=self.collection, ) - loan_info = self.api.checkout(patron, "1234", pool, Representation.EPUB_MEDIA_TYPE) + loan_info = self.api.checkout( + patron, "1234", pool, Representation.EPUB_MEDIA_TYPE + ) assert self.collection.id == loan_info.collection_id assert data_source.name == loan_info.data_source_name assert Identifier.URI == loan_info.identifier_type @@ -290,14 +321,21 @@ def test_fulfill(self): ) # This pool doesn't have an acquisition link, so # we can't fulfill it yet. - pytest.raises(CannotFulfill, self.api.fulfill, - patron, "1234", pool, Representation.EPUB_MEDIA_TYPE) + pytest.raises( + CannotFulfill, + self.api.fulfill, + patron, + "1234", + pool, + Representation.EPUB_MEDIA_TYPE, + ) # Set up an epub acquisition link for the pool. url = self._url link, ignore = pool.identifier.add_link( Hyperlink.GENERIC_OPDS_ACQUISITION, - url, data_source, + url, + data_source, Representation.EPUB_MEDIA_TYPE, ) pool.set_delivery_mechanism( @@ -315,7 +353,9 @@ def test_fulfill(self): self.api.queue_response(200, content=token_response) fulfillment_time = utc_now() - fulfillment_info = self.api.fulfill(patron, "1234", pool, Representation.EPUB_MEDIA_TYPE) + fulfillment_info = self.api.fulfill( + patron, "1234", pool, Representation.EPUB_MEDIA_TYPE + ) assert self.collection.id == fulfillment_info.collection_id assert data_source.name == fulfillment_info.data_source_name assert Identifier.URI == fulfillment_info.identifier_type @@ -324,23 +364,21 @@ def test_fulfill(self): assert DeliveryMechanism.BEARER_TOKEN == fulfillment_info.content_type bearer_token_document = json.loads(fulfillment_info.content) - expires_in = bearer_token_document['expires_in'] + expires_in = bearer_token_document["expires_in"] assert expires_in < 60 - assert "Bearer" == bearer_token_document['token_type'] - assert "token" == bearer_token_document['access_token'] - assert url == bearer_token_document['location'] + assert "Bearer" == bearer_token_document["token_type"] + assert "token" == bearer_token_document["access_token"] + assert url == bearer_token_document["location"] # The FulfillmentInfo's content_expires is approximately the # time you get if you add the number of seconds until the # bearer token expires to the time at which the title was # originally fulfilled. - expect_expiration = fulfillment_time + datetime.timedelta( - seconds=expires_in + expect_expiration = fulfillment_time + datetime.timedelta(seconds=expires_in) + assert ( + abs((fulfillment_info.content_expires - expect_expiration).total_seconds()) + < 5 ) - assert abs( - (fulfillment_info.content_expires-expect_expiration).total_seconds() - ) < 5 - def test_patron_activity(self): # The patron has two loans from this API's collection and @@ -378,33 +416,38 @@ def test_patron_activity(self): [l1, l2] = activity assert l1.collection_id == self.collection.id assert l2.collection_id == self.collection.id - assert (set([l1.identifier, l2.identifier]) == - set([p1.identifier.identifier, p2.identifier.identifier])) + assert set([l1.identifier, l2.identifier]) == set( + [p1.identifier.identifier, p2.identifier.identifier] + ) -class TestOPDSForDistributorsImporter(DatabaseTest, BaseOPDSForDistributorsTest): +class TestOPDSForDistributorsImporter(DatabaseTest, BaseOPDSForDistributorsTest): def test_import(self): feed = self.get_data("biblioboard_mini_feed.opds") data_source = DataSource.lookup(self._db, "Biblioboard", autocreate=True) collection = MockOPDSForDistributorsAPI.mock_collection(self._db) collection.external_integration.set_setting( - Collection.DATA_SOURCE_NAME_SETTING, - data_source.name + Collection.DATA_SOURCE_NAME_SETTING, data_source.name ) class MockMetadataClient(object): def canonicalize_author_name(self, identifier, working_display_name): return working_display_name + metadata_client = MockMetadataClient() importer = OPDSForDistributorsImporter( - self._db, collection=collection, + self._db, + collection=collection, metadata_client=metadata_client, ) - imported_editions, imported_pools, imported_works, failures = ( - importer.import_from_feed(feed) - ) + ( + imported_editions, + imported_pools, + imported_works, + failures, + ) = importer.import_from_feed(feed) # This importer works the same as the base OPDSImporter, except that # it adds delivery mechanisms for books with epub acquisition links @@ -421,26 +464,45 @@ def canonicalize_author_name(self, identifier, working_display_name): for pool in [camelot_pool, southern_pool]: assert False == pool.open_access - assert RightsStatus.IN_COPYRIGHT == pool.delivery_mechanisms[0].rights_status.uri - assert Representation.EPUB_MEDIA_TYPE == pool.delivery_mechanisms[0].delivery_mechanism.content_type - assert DeliveryMechanism.BEARER_TOKEN == pool.delivery_mechanisms[0].delivery_mechanism.drm_scheme + assert ( + RightsStatus.IN_COPYRIGHT + == pool.delivery_mechanisms[0].rights_status.uri + ) + assert ( + Representation.EPUB_MEDIA_TYPE + == pool.delivery_mechanisms[0].delivery_mechanism.content_type + ) + assert ( + DeliveryMechanism.BEARER_TOKEN + == pool.delivery_mechanisms[0].delivery_mechanism.drm_scheme + ) assert 1 == pool.licenses_owned assert 1 == pool.licenses_available assert (pool.work.last_update_time - now).total_seconds() <= 2 - [camelot_acquisition_link] = [l for l in camelot_pool.identifier.links - if l.rel == Hyperlink.GENERIC_OPDS_ACQUISITION - and l.resource.representation.media_type == Representation.EPUB_MEDIA_TYPE] + [camelot_acquisition_link] = [ + l + for l in camelot_pool.identifier.links + if l.rel == Hyperlink.GENERIC_OPDS_ACQUISITION + and l.resource.representation.media_type == Representation.EPUB_MEDIA_TYPE + ] camelot_acquisition_url = camelot_acquisition_link.resource.representation.url - assert ("https://library.biblioboard.com/ext/api/media/04377e87-ab69-41c8-a2a4-812d55dc0952/assets/content.epub" == - camelot_acquisition_url) + assert ( + "https://library.biblioboard.com/ext/api/media/04377e87-ab69-41c8-a2a4-812d55dc0952/assets/content.epub" + == camelot_acquisition_url + ) - [southern_acquisition_link] = [l for l in southern_pool.identifier.links - if l.rel == Hyperlink.GENERIC_OPDS_ACQUISITION - and l.resource.representation.media_type == Representation.EPUB_MEDIA_TYPE] + [southern_acquisition_link] = [ + l + for l in southern_pool.identifier.links + if l.rel == Hyperlink.GENERIC_OPDS_ACQUISITION + and l.resource.representation.media_type == Representation.EPUB_MEDIA_TYPE + ] southern_acquisition_url = southern_acquisition_link.resource.representation.url - assert ("https://library.biblioboard.com/ext/api/media/04da95cd-6cfc-4e82-810f-121d418b6963/assets/content.epub" == - southern_acquisition_url) + assert ( + "https://library.biblioboard.com/ext/api/media/04da95cd-6cfc-4e82-810f-121d418b6963/assets/content.epub" + == southern_acquisition_url + ) def test__add_format_data(self): @@ -485,26 +547,22 @@ def test__add_format_data(self): class TestOPDSForDistributorsReaperMonitor(DatabaseTest, BaseOPDSForDistributorsTest): - def test_reaper(self): feed = self.get_data("biblioboard_mini_feed.opds") class MockOPDSForDistributorsReaperMonitor(OPDSForDistributorsReaperMonitor): """An OPDSForDistributorsReaperMonitor that overrides _get.""" + def _get(self, url, headers): - return ( - 200, {'content-type': OPDSFeed.ACQUISITION_FEED_TYPE}, feed - ) + return (200, {"content-type": OPDSFeed.ACQUISITION_FEED_TYPE}, feed) data_source = DataSource.lookup(self._db, "Biblioboard", autocreate=True) collection = MockOPDSForDistributorsAPI.mock_collection(self._db) collection.external_integration.set_setting( - Collection.DATA_SOURCE_NAME_SETTING, - data_source.name + Collection.DATA_SOURCE_NAME_SETTING, data_source.name ) monitor = MockOPDSForDistributorsReaperMonitor( - self._db, collection, OPDSForDistributorsImporter, - metadata_client=object() + self._db, collection, OPDSForDistributorsImporter, metadata_client=object() ) # There's a license pool in the database that isn't in the feed anymore. diff --git a/tests/test_overdrive.py b/tests/test_overdrive.py index ba5e3edb2b..66a7d5ae25 100644 --- a/tests/test_overdrive.py +++ b/tests/test_overdrive.py @@ -1,11 +1,15 @@ # encoding: utf-8 -import pytest -import pkgutil import json -from datetime import ( - timedelta, -) +import pkgutil import random +from datetime import timedelta + +import pytest + +from api.authenticator import BasicAuthenticationProvider +from api.circulation import CirculationAPI, FulfillmentInfo, HoldInfo, LoanInfo +from api.circulation_exceptions import * +from api.config import Configuration, temp_config from api.overdrive import ( MockOverdriveAPI, NewTitlesOverdriveCollectionMonitor, @@ -14,30 +18,12 @@ OverdriveCollectionReaper, OverdriveFormatSweep, OverdriveManifestFulfillmentInfo, - RecentOverdriveCollectionMonitor -) - -from api.authenticator import BasicAuthenticationProvider -from api.circulation import ( - CirculationAPI, - FulfillmentInfo, - HoldInfo, - LoanInfo, -) -from api.circulation_exceptions import * -from api.config import Configuration - -from core.testing import DatabaseTest -from core.util.datetime_helpers import ( - datetime_utc, - utc_now, + RecentOverdriveCollectionMonitor, ) -from . import sample_data - from core.metadata_layer import TimestampData from core.model import ( - Collection, CirculationEvent, + Collection, ConfigurationSetting, DataSource, DeliveryMechanism, @@ -49,27 +35,25 @@ Representation, RightsStatus, ) -from core.testing import ( - DummyHTTPClient, - MockRequestsResponse, -) +from core.testing import DatabaseTest, DummyHTTPClient, MockRequestsResponse +from core.util.datetime_helpers import datetime_utc, utc_now -from api.config import temp_config +from . import sample_data -class OverdriveAPITest(DatabaseTest): +class OverdriveAPITest(DatabaseTest): def setup_method(self): super(OverdriveAPITest, self).setup_method() library = self._default_library self.collection = MockOverdriveAPI.mock_collection(self._db) self.circulation = CirculationAPI( - self._db, library, api_map={ExternalIntegration.OVERDRIVE:MockOverdriveAPI} + self._db, library, api_map={ExternalIntegration.OVERDRIVE: MockOverdriveAPI} ) self.api = self.circulation.api_for_collection[self.collection.id] @classmethod def sample_data(self, filename): - return sample_data(filename, 'overdrive') + return sample_data(filename, "overdrive") @classmethod def sample_json(self, filename): @@ -85,11 +69,12 @@ def error_message(self, error_code, message=None, token=None): data = dict(errorCode=error_code, message=message, token=token) return json.dumps(data) -class TestOverdriveAPI(OverdriveAPITest): +class TestOverdriveAPI(OverdriveAPITest): def test_external_integration(self): - assert (self.collection.external_integration == - self.api.external_integration(self._db)) + assert self.collection.external_integration == self.api.external_integration( + self._db + ) def test_lock_in_format(self): # Verify which formats do or don't need to be locked in before @@ -98,9 +83,8 @@ def test_lock_in_format(self): # Streaming and manifest-based formats are exempt; all # other formats need lock-in. - exempt = ( - list(self.api.STREAMING_FORMATS) + - list(self.api.MANIFEST_INTERNAL_FORMATS) + exempt = list(self.api.STREAMING_FORMATS) + list( + self.api.MANIFEST_INTERNAL_FORMATS ) for i in self.api.FORMATS: if i not in exempt: @@ -117,12 +101,14 @@ class Mock(MockOverdriveAPI): # First we will call check_creds() to get a fresh credential. mock_credential = object() + def check_creds(self, force_refresh=False): self.check_creds_called_with = force_refresh return self.mock_credential # Then we will call get_advantage_accounts(). mock_advantage_accounts = [object(), object()] + def get_advantage_accounts(self): return self.mock_advantage_accounts @@ -136,6 +122,7 @@ def get(self, url, extra_headers, exception_on_401=False): # the credentials of that library's test patron. mock_patron_credential = object() get_patron_credential_called_with = [] + def get_patron_credential(self, patron, pin): self.get_patron_credential_called_with.append((patron, pin)) return self.mock_patron_credential @@ -150,7 +137,7 @@ def get_patron_credential(self, patron, pin): integration = self._external_integration( "api.simple_authentication", ExternalIntegration.PATRON_AUTH_GOAL, - libraries=[with_default_patron] + libraries=[with_default_patron], ) p = BasicAuthenticationProvider integration.setting(p.TEST_IDENTIFIER).value = "username1" @@ -158,39 +145,47 @@ def get_patron_credential(self, patron, pin): # Now that everything is set up, run the self-test. api = Mock(self._db, self.collection) - results = sorted( - api._run_self_tests(self._db), key=lambda x: x.name - ) - [no_patron_credential, default_patron_credential, - global_privileges, collection_size, advantage] = results + results = sorted(api._run_self_tests(self._db), key=lambda x: x.name) + [ + no_patron_credential, + default_patron_credential, + global_privileges, + collection_size, + advantage, + ] = results # Verify that each test method was called and returned the # expected SelfTestResult object. - assert ('Checking global Client Authentication privileges' == - global_privileges.name) + assert ( + "Checking global Client Authentication privileges" == global_privileges.name + ) assert True == global_privileges.success assert api.mock_credential == global_privileges.result - assert 'Looking up Overdrive Advantage accounts' == advantage.name + assert "Looking up Overdrive Advantage accounts" == advantage.name assert True == advantage.success - assert 'Found 2 Overdrive Advantage account(s).' == advantage.result + assert "Found 2 Overdrive Advantage account(s)." == advantage.result - assert 'Counting size of collection' == collection_size.name + assert "Counting size of collection" == collection_size.name assert True == collection_size.success - assert '2010 item(s) in collection' == collection_size.result + assert "2010 item(s) in collection" == collection_size.result url, headers, error_on_401 = api.get_called_with assert api._all_products_link == url assert ( - "Acquiring test patron credentials for library %s" % no_default_patron.name == - no_patron_credential.name) + "Acquiring test patron credentials for library %s" % no_default_patron.name + == no_patron_credential.name + ) assert False == no_patron_credential.success - assert ("Library has no test patron configured." == - str(no_patron_credential.exception)) + assert "Library has no test patron configured." == str( + no_patron_credential.exception + ) assert ( - "Checking Patron Authentication privileges, using test patron for library %s" % with_default_patron.name == - default_patron_credential.name) + "Checking Patron Authentication privileges, using test patron for library %s" + % with_default_patron.name + == default_patron_credential.name + ) assert True == default_patron_credential.success assert api.mock_patron_credential == default_patron_credential.result @@ -208,8 +203,10 @@ def test_run_self_tests_short_circuit(self): This probably doesn't matter much, because if check_creds doesn't work we won't be able to instantiate the OverdriveAPI class. """ + def explode(*args, **kwargs): raise Exception("Failure!") + self.api.check_creds = explode # Only one test will be run. @@ -221,54 +218,50 @@ def test_default_notification_email_address(self): previously given by the patron to Overdrive for the purpose of notifications. """ - ignore, patron_with_email = self.sample_json( - "patron_info.json" - ) + ignore, patron_with_email = self.sample_json("patron_info.json") self.api.queue_response(200, content=patron_with_email) patron = self._patron() # The site default for notification emails will never be used. configuration_setting = ConfigurationSetting.for_library( - Configuration.DEFAULT_NOTIFICATION_EMAIL_ADDRESS, - self._default_library + Configuration.DEFAULT_NOTIFICATION_EMAIL_ADDRESS, self._default_library ) configuration_setting.value = "notifications@example.com" # If the patron has used a particular email address to put # books on hold, use that email address, not the site default. - assert ("foo@bar.com" == - self.api.default_notification_email_address(patron, 'pin')) + assert "foo@bar.com" == self.api.default_notification_email_address( + patron, "pin" + ) # If the patron's email address according to Overdrive _is_ # the site default, it is ignored. This can only happen if # this patron placed a hold using an older version of the # circulation manager. - patron_with_email['lastHoldEmail'] = configuration_setting.value + patron_with_email["lastHoldEmail"] = configuration_setting.value self.api.queue_response(200, content=patron_with_email) - assert (None == - self.api.default_notification_email_address(patron, 'pin')) + assert None == self.api.default_notification_email_address(patron, "pin") # If the patron has never before put an Overdrive book on # hold, their JSON object has no `lastHoldEmail` key. In this # case we return None -- again, ignoring the site default. patron_with_no_email = dict(patron_with_email) - del patron_with_no_email['lastHoldEmail'] + del patron_with_no_email["lastHoldEmail"] self.api.queue_response(200, content=patron_with_no_email) - assert (None == - self.api.default_notification_email_address(patron, 'pin')) + assert None == self.api.default_notification_email_address(patron, "pin") # If there's an error getting the information from Overdrive, # we return None. self.api.queue_response(404) - assert (None == - self.api.default_notification_email_address(patron, 'pin')) + assert None == self.api.default_notification_email_address(patron, "pin") def test_scope_string(self): # scope_string() puts the website ID of the Overdrive # integration and the ILS name associated with the library # into the form expected by Overdrive. expect = "websiteid:%s authorizationname:%s" % ( - self.api.website_id.decode("utf-8"), self.api.ils_name(self._default_library) + self.api.website_id.decode("utf-8"), + self.api.ils_name(self._default_library), ) assert expect == self.api.scope_string(self._default_library) @@ -312,13 +305,13 @@ def _process_checkout_error(self, patron, pin, licensepool, data): # Verify that a good-looking patron request went out. endpoint, ignore, kwargs = api.requests.pop() assert endpoint.endswith("/me/checkouts") - assert patron == kwargs.pop('_patron') - extra_headers = kwargs.pop('extra_headers') + assert patron == kwargs.pop("_patron") + extra_headers = kwargs.pop("extra_headers") assert {"Content-Type": "application/json"} == extra_headers - data = json.loads(kwargs.pop('data')) - assert ( - {'fields': [{'name': 'reserveId', 'value': pool.identifier.identifier}]} == - data) + data = json.loads(kwargs.pop("data")) + assert { + "fields": [{"name": "reserveId", "value": pool.identifier.identifier}] + } == data # The API response was passed into extract_expiration_date. # @@ -345,19 +338,27 @@ def _process_checkout_error(self, patron, pin, licensepool, data): with pytest.raises(Exception) as excinfo: api.checkout(patron, pin, pool, "internal format is ignored") assert "exception in _process_checkout_error" in str(excinfo.value) - assert (patron, pin, pool, "some data") == api._process_checkout_error_called_with.pop() + assert ( + patron, + pin, + pool, + "some data", + ) == api._process_checkout_error_called_with.pop() # However, if _process_checkout_error is able to recover from # the error and ends up returning something, the return value # is propagated from checkout(). api.PROCESS_CHECKOUT_ERROR_RESULT = "Actually, I was able to recover" api.queue_response(400, content=api_response) + assert "Actually, I was able to recover" == api.checkout( + patron, pin, pool, "internal format is ignored" + ) assert ( - "Actually, I was able to recover" == - api.checkout( - patron, pin, pool, "internal format is ignored" - )) - assert (patron, pin, pool, "some data") == api._process_checkout_error_called_with.pop() + patron, + pin, + pool, + "some data", + ) == api._process_checkout_error_called_with.pop() def test__process_checkout_error(self): # Verify that _process_checkout_error handles common API-side errors, @@ -415,10 +416,14 @@ def with_error_code(code): assert "Unknown Error" in str(excinfo.value) # Some known errors become specific subclasses of CannotLoan. - pytest.raises(PatronLoanLimitReached, with_error_code, - "PatronHasExceededCheckoutLimit") - pytest.raises(PatronLoanLimitReached, with_error_code, - "PatronHasExceededCheckoutLimit_ForCPC") + pytest.raises( + PatronLoanLimitReached, with_error_code, "PatronHasExceededCheckoutLimit" + ) + pytest.raises( + PatronLoanLimitReached, + with_error_code, + "PatronHasExceededCheckoutLimit_ForCPC", + ) # There are two cases where we need to make follow-up API # requests as the result of a failure during the loan process. @@ -426,8 +431,7 @@ def with_error_code(code): # First, if the error is "NoCopiesAvailable", we know we have # out-of-date availability information and we need to call # update_licensepool before raising NoAvailbleCopies(). - pytest.raises(NoAvailableCopies, with_error_code, - "NoCopiesAvailable") + pytest.raises(NoAvailableCopies, with_error_code, "NoCopiesAvailable") assert identifier.identifier == api.update_licensepool_called_with.pop() # If the error is "TitleAlreadyCheckedOut", then the problem @@ -438,8 +442,7 @@ def with_error_code(code): loan = with_error_code("TitleAlreadyCheckedOut") # get_loan was called with the patron's details. - assert ((patron, pin, identifier.identifier) == - api.get_loan_called_with.pop()) + assert (patron, pin, identifier.identifier) == api.get_loan_called_with.pop() # extract_expiration_date was called on the return value of get_loan. assert api.MOCK_LOAN == api.extract_expiration_date_called_with.pop() @@ -458,7 +461,9 @@ def test_extract_expiration_date(self): m = OverdriveAPI.extract_expiration_date # Success - assert datetime_utc(2020, 1, 2, 3, 4, 5) == m(dict(expires="2020-01-02T03:04:05Z")) + assert datetime_utc(2020, 1, 2, 3, 4, 5) == m( + dict(expires="2020-01-02T03:04:05Z") + ) # Various failure cases. assert None == m(dict(expiresPresent=False)) @@ -478,9 +483,7 @@ def __init__(self, *args, **kwargs): self.DEFAULT_NOTIFICATION_EMAIL_ADDRESS = None def default_notification_email_address(self, patron, pin): - self.default_notification_email_address_called_with = ( - patron, pin - ) + self.default_notification_email_address_called_with = (patron, pin) return self.DEFAULT_NOTIFICATION_EMAIL_ADDRESS def fill_out_form(self, **form_fields): @@ -493,11 +496,12 @@ def patron_request(self, *args, **kwargs): self.patron_request_called_with = (args, kwargs) return "A mock response" - def process_place_hold_response( - self, response, patron, pin, licensepool - ): + def process_place_hold_response(self, response, patron, pin, licensepool): self.process_place_hold_response_called_with = ( - response, patron, pin, licensepool + response, + patron, + pin, + licensepool, ) return "OK, I processed it." @@ -525,15 +529,17 @@ def process_place_hold_response( # patron_request was called with the filled-out form and other # information necessary to authenticate the request. args, kwargs = api.patron_request_called_with - assert ((patron, pin, api.HOLDS_ENDPOINT, 'headers', 'filled-out form') == - args) + assert (patron, pin, api.HOLDS_ENDPOINT, "headers", "filled-out form") == args assert {} == kwargs # Finally, process_place_hold_response was called on # the return value of patron_request assert ( - ("A mock response", patron, pin, pool) == - api.process_place_hold_response_called_with) + "A mock response", + patron, + pin, + pool, + ) == api.process_place_hold_response_called_with assert "OK, I processed it." == response # Now we need to test two more cases. @@ -564,6 +570,7 @@ def test_process_place_hold_response(self): # to a HOLDS_ENDPOINT request. ignore, successful_hold = self.sample_json("successful_hold.json") + class Mock(MockOverdriveAPI): def get_hold(self, patron, pin, overdrive_id): # Return a sample hold representation rather than @@ -583,12 +590,9 @@ def process_error_response(message): return api.process_place_hold_response(response, None, None, None) # Some error messages result in specific CirculationExceptions. + pytest.raises(CannotRenew, process_error_response, "NotWithinRenewalWindow") pytest.raises( - CannotRenew, process_error_response, "NotWithinRenewalWindow" - ) - pytest.raises( - PatronHoldLimitReached, process_error_response, - "PatronExceededHoldLimit" + PatronHoldLimitReached, process_error_response, "PatronExceededHoldLimit" ) # An unrecognized error message results in a generic @@ -604,8 +608,7 @@ def process_error_response(message): # (which shouldn't happen in real life). response = MockRequestsResponse(999) pytest.raises( - CannotHold, api.process_place_hold_response, - response, None, None, None + CannotHold, api.process_place_hold_response, response, None, None, None ) # At this point patron and book details become important -- @@ -632,14 +635,11 @@ def assert_correct_holdinfo(x): # on hold. already_on_hold = dict(errorCode="AlreadyOnWaitList") response = MockRequestsResponse(400, content=already_on_hold) - result = api.process_place_hold_response( - response, patron, pin, licensepool - ) + result = api.process_place_hold_response(response, patron, pin, licensepool) # get_hold() was called with the arguments we expect. identifier = licensepool.identifier - assert ((patron, pin, identifier.identifier) == - api.get_hold_called_with) + assert (patron, pin, identifier.identifier) == api.get_hold_called_with # The result was converted into a HoldInfo object. The # effective result is exactly as if we had successfully put @@ -650,9 +650,7 @@ def assert_correct_holdinfo(x): # there is. api.get_hold_called_with = None response = MockRequestsResponse(200, content=successful_hold) - result = api.process_place_hold_response( - response, patron, pin, licensepool - ) + result = api.process_place_hold_response(response, patron, pin, licensepool) assert_correct_holdinfo(result) # Here, get_hold was _not_ called, because the hold didn't @@ -660,7 +658,6 @@ def assert_correct_holdinfo(x): assert None == api.get_hold_called_with def test_checkin(self): - class Mock(MockOverdriveAPI): EARLY_RETURN_SUCCESS = False @@ -680,8 +677,7 @@ def patron_request(self, *args, **kwargs): patron = self._patron() pin = object() expect_url = overdrive.endpoint( - overdrive.CHECKOUT_ENDPOINT, - overdrive_id=pool.identifier.identifier + overdrive.CHECKOUT_ENDPOINT, overdrive_id=pool.identifier.identifier ) def assert_no_early_return(): @@ -740,7 +736,6 @@ def assert_no_early_return(): assert None == overdrive.patron_request_call def test_perform_early_return(self): - class Mock(MockOverdriveAPI): EARLY_RETURN_URL = "http://early-return/" @@ -792,19 +787,22 @@ def _extract_early_return_url(self, *args): # # get_fulfillment_link was called with appropriate arguments. assert ( - (patron, pin, pool.identifier.identifier, 'ebook-epub-adobe') == - overdrive.get_fulfillment_link_call) + patron, + pin, + pool.identifier.identifier, + "ebook-epub-adobe", + ) == overdrive.get_fulfillment_link_call # The URL returned by that method was 'requested'. - assert 'http://fulfillment/' == http.requests.pop(0) + assert "http://fulfillment/" == http.requests.pop(0) # The resulting URL was passed into _extract_early_return_url. assert ( - ('http://fulfill-this-book/?or=return-early',) == - overdrive._extract_early_return_url_call) + "http://fulfill-this-book/?or=return-early", + ) == overdrive._extract_early_return_url_call # Then the URL returned by _that_ method was 'requested'. - assert 'http://early-return/' == http.requests.pop(0) + assert "http://early-return/" == http.requests.pop(0) # If no early return URL can be extracted from the fulfillment URL, # perform_early_return has no effect. @@ -812,18 +810,17 @@ def _extract_early_return_url(self, *args): overdrive._extract_early_return_url_call = None overdrive.EARLY_RETURN_URL = None http.responses.append( - MockRequestsResponse( - 302, dict(location="http://fulfill-this-book/") - ) + MockRequestsResponse(302, dict(location="http://fulfill-this-book/")) ) success = overdrive.perform_early_return(patron, pin, loan, http.do_get) assert False == success # extract_early_return_url_call was called, but since it returned # None, no second HTTP request was made. - assert 'http://fulfillment/' == http.requests.pop(0) - assert (("http://fulfill-this-book/",) == - overdrive._extract_early_return_url_call) + assert "http://fulfillment/" == http.requests.pop(0) + assert ( + "http://fulfill-this-book/", + ) == overdrive._extract_early_return_url_call assert [] == http.requests # If we can't map the delivery mechanism to one of Overdrive's @@ -845,9 +842,7 @@ def _extract_early_return_url(self, *args): 302, dict(location="http://fulfill-this-book/?or=return-early") ) ) - http.responses.append( - MockRequestsResponse(401, content="Unauthorized!") - ) + http.responses.append(MockRequestsResponse(401, content="Unauthorized!")) success = overdrive.perform_early_return(patron, pin, loan, http.do_get) assert False == success @@ -858,54 +853,57 @@ def test_extract_early_return_url(self): assert None == m(None) # This is based on a real Overdrive early return URL. - has_early_return = 'https://openepub-gk.cdn.overdrive.com/OpenEPUBStore1/1577-1/%7B5880F6D0-48AC-44DE-8BF1-FD1CE62E97A8%7DFzr418.epub?e=1518753718&loanExpirationDate=2018-03-01T17%3a12%3a33Z&loanEarlyReturnUrl=https%3a%2f%2fnotifications-ofs.contentreserve.com%2fEarlyReturn%2fnypl%2f037-1374147-00279%2f5480F6E1-48F3-00DE-96C1-FD3CE32D94FD-312%3fh%3dVgvxBQHdQxtsbgb43AH6%252bEmpni9LoffkPczNiUz7%252b10%253d&sourceId=nypl&h=j7nGk7qxE71X2ZcdLw%2bqa04jqEw%3d' - assert 'https://notifications-ofs.contentreserve.com/EarlyReturn/nypl/037-1374147-00279/5480F6E1-48F3-00DE-96C1-FD3CE32D94FD-312?h=VgvxBQHdQxtsbgb43AH6%2bEmpni9LoffkPczNiUz7%2b10%3d' == m(has_early_return) + has_early_return = "https://openepub-gk.cdn.overdrive.com/OpenEPUBStore1/1577-1/%7B5880F6D0-48AC-44DE-8BF1-FD1CE62E97A8%7DFzr418.epub?e=1518753718&loanExpirationDate=2018-03-01T17%3a12%3a33Z&loanEarlyReturnUrl=https%3a%2f%2fnotifications-ofs.contentreserve.com%2fEarlyReturn%2fnypl%2f037-1374147-00279%2f5480F6E1-48F3-00DE-96C1-FD3CE32D94FD-312%3fh%3dVgvxBQHdQxtsbgb43AH6%252bEmpni9LoffkPczNiUz7%252b10%253d&sourceId=nypl&h=j7nGk7qxE71X2ZcdLw%2bqa04jqEw%3d" + assert ( + "https://notifications-ofs.contentreserve.com/EarlyReturn/nypl/037-1374147-00279/5480F6E1-48F3-00DE-96C1-FD3CE32D94FD-312?h=VgvxBQHdQxtsbgb43AH6%2bEmpni9LoffkPczNiUz7%2b10%3d" + == m(has_early_return) + ) def test_place_hold_raises_exception_if_patron_over_hold_limit(self): over_hold_limit = self.error_message( "PatronExceededHoldLimit", - "Patron cannot place any more holds, already has maximum holds placed." + "Patron cannot place any more holds, already has maximum holds placed.", ) edition, pool = self._edition( identifier_type=Identifier.OVERDRIVE_ID, data_source_name=DataSource.OVERDRIVE, - with_license_pool=True + with_license_pool=True, ) self.api.queue_response(400, content=over_hold_limit) pytest.raises( PatronHoldLimitReached, - self.api.place_hold, self._patron(), 'pin', pool, - notification_email_address='foo@bar.com' + self.api.place_hold, + self._patron(), + "pin", + pool, + notification_email_address="foo@bar.com", ) def test_place_hold_looks_up_notification_address(self): edition, pool = self._edition( identifier_type=Identifier.OVERDRIVE_ID, data_source_name=DataSource.OVERDRIVE, - with_license_pool=True + with_license_pool=True, ) # The first request we make will be to get patron info, # so that we know that the most recent email address used # to put a book on hold is foo@bar.com. - ignore, patron_with_email = self.sample_json( - "patron_info.json" - ) + ignore, patron_with_email = self.sample_json("patron_info.json") # The second request we make will be to put a book on hold, # and when we do so we will ask for the notification to be # sent to foo@bar.com. - ignore, successful_hold = self.sample_json( - "successful_hold.json" - ) + ignore, successful_hold = self.sample_json("successful_hold.json") self.api.queue_response(200, content=patron_with_email) self.api.queue_response(200, content=successful_hold) with temp_config() as config: - config['default_notification_email_address'] = "notifications@example.com" - hold = self.api.place_hold(self._patron(), 'pin', pool, - notification_email_address=None) + config["default_notification_email_address"] = "notifications@example.com" + hold = self.api.place_hold( + self._patron(), "pin", pool, notification_email_address=None + ) # The book was placed on hold. assert 1 == hold.hold_position @@ -921,7 +919,8 @@ def test_fulfill_returns_fulfillmentinfo_if_returned_by_get_fulfillment_link(sel # If get_fulfillment_link returns a FulfillmentInfo, it is returned # immediately and the rest of fulfill() does not run. - fulfillment = FulfillmentInfo(self.collection, *[None]*7) + fulfillment = FulfillmentInfo(self.collection, *[None] * 7) + class MockAPI(OverdriveAPI): def get_fulfillment_link(*args, **kwargs): return fulfillment @@ -937,17 +936,19 @@ def test_fulfill_raises_exception_and_updates_formats_for_outdated_format(self): edition, pool = self._edition( identifier_type=Identifier.OVERDRIVE_ID, data_source_name=DataSource.OVERDRIVE, - with_license_pool=True + with_license_pool=True, ) # This pool has a format that's no longer available from overdrive. - pool.set_delivery_mechanism(Representation.PDF_MEDIA_TYPE, DeliveryMechanism.ADOBE_DRM, - RightsStatus.IN_COPYRIGHT, None) - - ignore, loan = self.sample_json( - "single_loan.json" + pool.set_delivery_mechanism( + Representation.PDF_MEDIA_TYPE, + DeliveryMechanism.ADOBE_DRM, + RightsStatus.IN_COPYRIGHT, + None, ) + ignore, loan = self.sample_json("single_loan.json") + ignore, lock_in_format_not_available = self.sample_json( "lock_in_format_not_available.json" ) @@ -960,18 +961,18 @@ def test_fulfill_raises_exception_and_updates_formats_for_outdated_format(self): pytest.raises( FormatNotAvailable, self.api.get_fulfillment_link, - self._patron(), 'pin', pool.identifier.identifier, - 'ebook-epub-adobe' + self._patron(), + "pin", + pool.identifier.identifier, + "ebook-epub-adobe", ) # Fulfill will also update the formats. - ignore, bibliographic = self.sample_json( - "bibliographic_information.json" - ) + ignore, bibliographic = self.sample_json("bibliographic_information.json") # To avoid a mismatch, make it look like the information is # for the correct Identifier. - bibliographic['id'] = pool.identifier.identifier + bibliographic["id"] = pool.identifier.identifier # If we have the LicensePool available (as opposed to just the # identifier), we will get the loan, try to lock in the @@ -983,16 +984,34 @@ def test_fulfill_raises_exception_and_updates_formats_for_outdated_format(self): pytest.raises( FormatNotAvailable, self.api.fulfill, - self._patron(), 'pin', pool, - 'ebook-epub-adobe' + self._patron(), + "pin", + pool, + "ebook-epub-adobe", ) # The delivery mechanisms have been updated. assert 4 == len(pool.delivery_mechanisms) - assert (set([MediaTypes.EPUB_MEDIA_TYPE, DeliveryMechanism.KINDLE_CONTENT_TYPE, DeliveryMechanism.STREAMING_TEXT_CONTENT_TYPE, MediaTypes.OVERDRIVE_EBOOK_MANIFEST_MEDIA_TYPE]) == - set([lpdm.delivery_mechanism.content_type for lpdm in pool.delivery_mechanisms])) - assert (set([DeliveryMechanism.ADOBE_DRM, DeliveryMechanism.KINDLE_DRM, DeliveryMechanism.LIBBY_DRM, DeliveryMechanism.STREAMING_DRM]) == - set([lpdm.delivery_mechanism.drm_scheme for lpdm in pool.delivery_mechanisms])) + assert set( + [ + MediaTypes.EPUB_MEDIA_TYPE, + DeliveryMechanism.KINDLE_CONTENT_TYPE, + DeliveryMechanism.STREAMING_TEXT_CONTENT_TYPE, + MediaTypes.OVERDRIVE_EBOOK_MANIFEST_MEDIA_TYPE, + ] + ) == set( + [lpdm.delivery_mechanism.content_type for lpdm in pool.delivery_mechanisms] + ) + assert set( + [ + DeliveryMechanism.ADOBE_DRM, + DeliveryMechanism.KINDLE_DRM, + DeliveryMechanism.LIBBY_DRM, + DeliveryMechanism.STREAMING_DRM, + ] + ) == set( + [lpdm.delivery_mechanism.drm_scheme for lpdm in pool.delivery_mechanisms] + ) def test_get_fulfillment_link_from_download_link(self): patron = self._patron() @@ -1003,9 +1022,13 @@ def test_get_fulfillment_link_from_download_link(self): self.api.queue_response(200, content=streaming_fulfill_link) - href, type = self.api.get_fulfillment_link_from_download_link(patron, '1234', "http://download-link", fulfill_url="http://fulfill") - assert ("https://fulfill.contentreserve.com/PerfectLife9780345530967.epub-sample.overdrive.com?RetailerID=nypl&Expires=1469825647&Token=dd0e19b4-eb70-439d-8c50-a65201060f4c&Signature=asl67/G154KeeUsL1mHPwEbZfgc=" == - href) + href, type = self.api.get_fulfillment_link_from_download_link( + patron, "1234", "http://download-link", fulfill_url="http://fulfill" + ) + assert ( + "https://fulfill.contentreserve.com/PerfectLife9780345530967.epub-sample.overdrive.com?RetailerID=nypl&Expires=1469825647&Token=dd0e19b4-eb70-439d-8c50-a65201060f4c&Signature=asl67/G154KeeUsL1mHPwEbZfgc=" + == href + ) assert "text/html" == type def test_get_fulfillment_link_returns_fulfillmentinfo_for_manifest_format(self): @@ -1017,16 +1040,14 @@ def test_get_fulfillment_link_returns_fulfillmentinfo_for_manifest_format(self): # To keep things simple, our mock API will always return the same # fulfillment link. loan_info = {"isFormatLockedIn": False} - class MockAPI(MockOverdriveAPI): + class MockAPI(MockOverdriveAPI): def get_loan(self, patron, pin, overdrive_id): self.get_loan_called_with = (patron, pin, overdrive_id) return loan_info def get_download_link(self, loan, format_type, error_url): - self.get_download_link_called_with = ( - loan, format_type, error_url - ) + self.get_download_link_called_with = (loan, format_type, error_url) return "http://fulfillment-link/" def get_fulfillment_link_from_download_link(self, *args, **kwargs): @@ -1038,14 +1059,14 @@ def get_fulfillment_link_from_download_link(self, *args, **kwargs): # Randomly choose one of the formats that must be fulfilled as # a link to a manifest. - overdrive_format = random.choice( - list(OverdriveAPI.MANIFEST_INTERNAL_FORMATS) - ) + overdrive_format = random.choice(list(OverdriveAPI.MANIFEST_INTERNAL_FORMATS)) # Get the fulfillment link. patron = self._patron() fulfillmentinfo = api.get_fulfillment_link( - patron, '1234', "http://download-link", + patron, + "1234", + "http://download-link", overdrive_format, ) assert isinstance(fulfillmentinfo, OverdriveManifestFulfillmentInfo) @@ -1054,7 +1075,7 @@ def get_fulfillment_link_from_download_link(self, *args, **kwargs): # let's see how we got there. # First, our mocked get_loan() was called. - assert (patron, '1234', 'http://download-link') == api.get_loan_called_with + assert (patron, "1234", "http://download-link") == api.get_loan_called_with # It returned a dictionary that contained no information # except isFormatLockedIn: false. @@ -1064,8 +1085,10 @@ def get_fulfillment_link_from_download_link(self, *args, **kwargs): # loan info was passed into our mocked get_download_link. assert ( - (loan_info, overdrive_format, api.DEFAULT_ERROR_URL) == - api.get_download_link_called_with) + loan_info, + overdrive_format, + api.DEFAULT_ERROR_URL, + ) == api.get_download_link_called_with # Since the manifest formats cannot be retrieved by the # circulation manager, the result of get_download_link was @@ -1080,22 +1103,24 @@ def test_update_formats(self): edition, pool = self._edition( data_source_name=DataSource.OVERDRIVE, identifier_type=Identifier.OVERDRIVE_ID, - with_license_pool=True + with_license_pool=True, ) edition.medium = Edition.PERIODICAL_MEDIUM # Add the bad delivery mechanism. - pool.set_delivery_mechanism(Representation.PDF_MEDIA_TYPE, DeliveryMechanism.ADOBE_DRM, - RightsStatus.IN_COPYRIGHT, None) + pool.set_delivery_mechanism( + Representation.PDF_MEDIA_TYPE, + DeliveryMechanism.ADOBE_DRM, + RightsStatus.IN_COPYRIGHT, + None, + ) # Prepare the bibliographic information. - ignore, bibliographic = self.sample_json( - "bibliographic_information.json" - ) + ignore, bibliographic = self.sample_json("bibliographic_information.json") # To avoid a mismatch, make it look like the information is # for the new pool's Identifier. - bibliographic['id'] = pool.identifier.identifier + bibliographic["id"] = pool.identifier.identifier self.api.queue_response(200, content=bibliographic) @@ -1103,10 +1128,26 @@ def test_update_formats(self): # The delivery mechanisms have been updated. assert 4 == len(pool.delivery_mechanisms) - assert (set([MediaTypes.EPUB_MEDIA_TYPE, DeliveryMechanism.KINDLE_CONTENT_TYPE, DeliveryMechanism.STREAMING_TEXT_CONTENT_TYPE, MediaTypes.OVERDRIVE_EBOOK_MANIFEST_MEDIA_TYPE]) == - set([lpdm.delivery_mechanism.content_type for lpdm in pool.delivery_mechanisms])) - assert (set([DeliveryMechanism.ADOBE_DRM, DeliveryMechanism.KINDLE_DRM, DeliveryMechanism.LIBBY_DRM, DeliveryMechanism.STREAMING_DRM]) == - set([lpdm.delivery_mechanism.drm_scheme for lpdm in pool.delivery_mechanisms])) + assert set( + [ + MediaTypes.EPUB_MEDIA_TYPE, + DeliveryMechanism.KINDLE_CONTENT_TYPE, + DeliveryMechanism.STREAMING_TEXT_CONTENT_TYPE, + MediaTypes.OVERDRIVE_EBOOK_MANIFEST_MEDIA_TYPE, + ] + ) == set( + [lpdm.delivery_mechanism.content_type for lpdm in pool.delivery_mechanisms] + ) + assert set( + [ + DeliveryMechanism.ADOBE_DRM, + DeliveryMechanism.KINDLE_DRM, + DeliveryMechanism.LIBBY_DRM, + DeliveryMechanism.STREAMING_DRM, + ] + ) == set( + [lpdm.delivery_mechanism.drm_scheme for lpdm in pool.delivery_mechanisms] + ) # The Edition's medium has been corrected. assert Edition.BOOK_MEDIUM == edition.medium @@ -1120,7 +1161,7 @@ def test_update_availability(self): identifier_type=Identifier.OVERDRIVE_ID, data_source_name=DataSource.OVERDRIVE, with_license_pool=True, - collection=self.collection + collection=self.collection, ) # We have never checked the circulation information for this @@ -1137,14 +1178,12 @@ def test_update_availability(self): ) # Since this is the first time we've seen this book, # we'll also be updating the bibliographic information. - ignore, bibliographic = self.sample_json( - "bibliographic_information.json" - ) + ignore, bibliographic = self.sample_json("bibliographic_information.json") # To avoid a mismatch, make it look like the information is # for the new pool's Identifier. - availability['id'] = pool.identifier.identifier - bibliographic['id'] = pool.identifier.identifier + availability["id"] = pool.identifier.identifier + bibliographic["id"] = pool.identifier.identifier self.api.queue_response(200, content=availability) self.api.queue_response(200, content=bibliographic) @@ -1176,10 +1215,10 @@ def test_circulation_lookup(self): request_url, ignore1, ignore2 = self.api.requests.pop() expect_url = self.api.endpoint( - self.api.AVAILABILITY_ENDPOINT, - collection_token=self.api.collection_token, - product_id="an-identifier" - ) + self.api.AVAILABILITY_ENDPOINT, + collection_token=self.api.collection_token, + product_id="an-identifier", + ) assert request_url == expect_url assert "/v2/collections" in request_url @@ -1203,9 +1242,7 @@ def test_circulation_lookup(self): def test_update_licensepool_error(self): # Create an identifier. - identifier = self._identifier( - identifier_type=Identifier.OVERDRIVE_ID - ) + identifier = self._identifier(identifier_type=Identifier.OVERDRIVE_ID) ignore, availability = self.sample_json( "overdrive_availability_information.json" ) @@ -1218,12 +1255,8 @@ def test_update_licensepool_not_found(self): # If the Overdrive API says a book is not found in the # collection, that's treated as useful information, not an error. # Create an identifier. - identifier = self._identifier( - identifier_type=Identifier.OVERDRIVE_ID - ) - ignore, not_found = self.sample_json( - "overdrive_availability_not_found.json" - ) + identifier = self._identifier(identifier_type=Identifier.OVERDRIVE_ID) + ignore, not_found = self.sample_json("overdrive_availability_not_found.json") # Queue the 'not found' response twice -- once for the circulation # lookup and once for the metadata lookup. @@ -1238,23 +1271,19 @@ def test_update_licensepool_not_found(self): def test_update_licensepool_provides_bibliographic_coverage(self): # Create an identifier. - identifier = self._identifier( - identifier_type=Identifier.OVERDRIVE_ID - ) + identifier = self._identifier(identifier_type=Identifier.OVERDRIVE_ID) # Prepare bibliographic and availability information # for this identifier. ignore, availability = self.sample_json( "overdrive_availability_information.json" ) - ignore, bibliographic = self.sample_json( - "bibliographic_information.json" - ) + ignore, bibliographic = self.sample_json("bibliographic_information.json") # To avoid a mismatch, make it look like the information is # for the newly created Identifier. - availability['id'] = identifier.identifier - bibliographic['id'] = identifier.identifier + availability["id"] = identifier.identifier + bibliographic["id"] = identifier.identifier self.api.queue_response(200, content=availability) self.api.queue_response(200, content=bibliographic) @@ -1266,22 +1295,22 @@ def test_update_licensepool_provides_bibliographic_coverage(self): # create an Edition and a presentation-ready Work. pool, was_new, changed = self.api.update_licensepool(identifier.identifier) assert True == was_new - assert availability['copiesOwned'] == pool.licenses_owned + assert availability["copiesOwned"] == pool.licenses_owned edition = pool.presentation_edition assert "Ancillary Justice" == edition.title assert True == pool.work.presentation_ready assert pool.work.cover_thumbnail_url.startswith( - 'http://images.contentreserve.com/' + "http://images.contentreserve.com/" ) # The book has been run through the bibliographic coverage # provider. coverage = [ - x for x in identifier.coverage_records - if x.operation is None - and x.data_source.name == DataSource.OVERDRIVE + x + for x in identifier.coverage_records + if x.operation is None and x.data_source.name == DataSource.OVERDRIVE ] assert 1 == len(coverage) @@ -1290,8 +1319,11 @@ def test_update_licensepool_provides_bibliographic_coverage(self): self._db.delete(pool.work) self._db.commit() pool, is_new = LicensePool.for_foreign_id( - self._db, DataSource.OVERDRIVE, Identifier.OVERDRIVE_ID, identifier.identifier, - collection=self.collection + self._db, + DataSource.OVERDRIVE, + Identifier.OVERDRIVE_ID, + identifier.identifier, + collection=self.collection, ) assert not pool.work self.api.queue_response(200, content=availability) @@ -1304,18 +1336,18 @@ def test_update_new_licensepool(self): data, raw = self.sample_json("overdrive_availability_information.json") # Create an identifier - identifier = self._identifier( - identifier_type=Identifier.OVERDRIVE_ID - ) + identifier = self._identifier(identifier_type=Identifier.OVERDRIVE_ID) # Make it look like the availability information is for the # newly created Identifier. - raw['reserveId'] = identifier.identifier + raw["reserveId"] = identifier.identifier pool, was_new = LicensePool.for_foreign_id( - self._db, DataSource.OVERDRIVE, - identifier.type, identifier.identifier, - collection=self.collection + self._db, + DataSource.OVERDRIVE, + identifier.type, + identifier.identifier, + collection=self.collection, ) pool, was_new, changed = self.api.update_licensepool_with_book_info( @@ -1326,10 +1358,10 @@ def test_update_new_licensepool(self): self._db.commit() - assert raw['copiesOwned'] == pool.licenses_owned - assert raw['copiesAvailable'] == pool.licenses_available + assert raw["copiesOwned"] == pool.licenses_owned + assert raw["copiesAvailable"] == pool.licenses_available assert 0 == pool.licenses_reserved - assert raw['numberOfHolds'] == pool.patrons_in_hold_queue + assert raw["numberOfHolds"] == pool.patrons_in_hold_queue def test_update_existing_licensepool(self): data, raw = self.sample_json("overdrive_availability_information.json") @@ -1338,12 +1370,12 @@ def test_update_existing_licensepool(self): wr, pool = self._edition( data_source_name=DataSource.OVERDRIVE, identifier_type=Identifier.OVERDRIVE_ID, - with_license_pool=True + with_license_pool=True, ) # Make it look like the availability information is for the # newly created LicensePool. - raw['id'] = pool.identifier.identifier + raw["id"] = pool.identifier.identifier wr.title = "The real title." assert 1 == pool.licenses_owned @@ -1360,12 +1392,14 @@ def test_update_existing_licensepool(self): # The title didn't change to that title given in the availability # information, because we already set a title for that work. assert "The real title." == wr.title - assert raw['copiesOwned'] == pool.licenses_owned - assert raw['copiesAvailable'] == pool.licenses_available + assert raw["copiesOwned"] == pool.licenses_owned + assert raw["copiesAvailable"] == pool.licenses_available assert 0 == pool.licenses_reserved - assert raw['numberOfHolds'] == pool.patrons_in_hold_queue + assert raw["numberOfHolds"] == pool.patrons_in_hold_queue - def test_update_new_licensepool_when_same_book_has_pool_in_different_collection(self): + def test_update_new_licensepool_when_same_book_has_pool_in_different_collection( + self, + ): old_edition, old_pool = self._edition( data_source_name=DataSource.OVERDRIVE, identifier_type=Identifier.OVERDRIVE_ID, @@ -1379,12 +1413,14 @@ def test_update_new_licensepool_when_same_book_has_pool_in_different_collection( # Make it look like the availability information is for the # old pool's Identifier. identifier = old_pool.identifier - raw['id'] = identifier.identifier + raw["id"] = identifier.identifier new_pool, was_new = LicensePool.for_foreign_id( - self._db, DataSource.OVERDRIVE, - identifier.type, identifier.identifier, - collection=collection + self._db, + DataSource.OVERDRIVE, + identifier.type, + identifier.identifier, + collection=collection, ) # The new pool doesn't have a presentation edition yet, # but it will be updated to share the old pool's edition. @@ -1400,14 +1436,15 @@ def test_update_new_licensepool_when_same_book_has_pool_in_different_collection( def test_update_licensepool_with_holds(self): data, raw = self.sample_json("overdrive_availability_information_holds.json") - identifier = self._identifier( - identifier_type=Identifier.OVERDRIVE_ID - ) - raw['id'] = identifier.identifier + identifier = self._identifier(identifier_type=Identifier.OVERDRIVE_ID) + raw["id"] = identifier.identifier license_pool, is_new = LicensePool.for_foreign_id( - self._db, DataSource.OVERDRIVE, identifier.type, - identifier.identifier, collection=self._default_collection + self._db, + DataSource.OVERDRIVE, + identifier.type, + identifier.identifier, + collection=self._default_collection, ) pool, was_new, changed = self.api.update_licensepool_with_book_info( raw, license_pool, is_new @@ -1420,7 +1457,7 @@ def test_refresh_patron_access_token(self): when refreshing a patron access token. """ patron = self._patron() - patron.authorization_identifier = 'barcode' + patron.authorization_identifier = "barcode" credential = self._credential(patron=patron) data, raw = self.sample_json("patron_token.json") @@ -1437,48 +1474,48 @@ def test_refresh_patron_access_token(self): with_pin, without_pin = self.api.access_token_requests url, payload, headers, kwargs = with_pin assert "https://oauth-patron.overdrive.com/patrontoken" == url - assert "barcode" == payload['username'] + assert "barcode" == payload["username"] expect_scope = "websiteid:%s authorizationname:%s" % ( - self.api.website_id.decode("utf-8"), self.api.ils_name(patron.library) + self.api.website_id.decode("utf-8"), + self.api.ils_name(patron.library), ) - assert expect_scope == payload['scope'] - assert "a pin" == payload['password'] - assert not 'password_required' in payload + assert expect_scope == payload["scope"] + assert "a pin" == payload["password"] + assert not "password_required" in payload url, payload, headers, kwargs = without_pin assert "https://oauth-patron.overdrive.com/patrontoken" == url - assert "barcode" == payload['username'] - assert expect_scope == payload['scope'] - assert "false" == payload['password_required'] - assert "[ignore]" == payload['password'] + assert "barcode" == payload["username"] + assert expect_scope == payload["scope"] + assert "false" == payload["password_required"] + assert "[ignore]" == payload["password"] class TestOverdriveAPICredentials(OverdriveAPITest): - def test_patron_correct_credentials_for_multiple_overdrive_collections(self): # Verify that the correct credential will be used # when a library has more than one OverDrive collection. def _optional_value(self, obj, key): - return obj.get(key, 'none') + return obj.get(key, "none") - def _make_token(scope, username, password, grant_type='password'): - return '%s|%s|%s|%s' % (grant_type, scope, username, password) + def _make_token(scope, username, password, grant_type="password"): + return "%s|%s|%s|%s" % (grant_type, scope, username, password) class MockAPI(MockOverdriveAPI): - def token_post(self, url, payload, headers={}, **kwargs): url = self.endpoint(url) self.access_token_requests.append((url, payload, headers, kwargs)) token = _make_token( - _optional_value(self, payload, 'scope'), - _optional_value(self, payload, 'username'), - _optional_value(self, payload, 'password'), - grant_type=_optional_value(self, payload, 'grant_type'), + _optional_value(self, payload, "scope"), + _optional_value(self, payload, "username"), + _optional_value(self, payload, "password"), + grant_type=_optional_value(self, payload, "grant_type"), ) response = self.mock_access_token_response(token) from core.util.http import HTTP + return HTTP._process_response(url, response, **kwargs) library = self._default_library @@ -1493,65 +1530,90 @@ def token_post(self, url, payload, headers={}, **kwargs): # library has membership. library_collection_properties = [ dict( - library=library, name="Test OD Collection 1", - client_key="client_key_1", client_secret="client_secret_1", - library_id="lib_id_1", website_id="ws_id_1", ils_name="lib1_coll1_ils" + library=library, + name="Test OD Collection 1", + client_key="client_key_1", + client_secret="client_secret_1", + library_id="lib_id_1", + website_id="ws_id_1", + ils_name="lib1_coll1_ils", ), dict( - library=library, name="Test OD Collection 2", - client_key="client_key_2", client_secret="client_secret_2", - library_id="lib_id_2", website_id="ws_id_2", ils_name="lib1_coll2_ils" - ) + library=library, + name="Test OD Collection 2", + client_key="client_key_2", + client_secret="client_secret_2", + library_id="lib_id_2", + website_id="ws_id_2", + ils_name="lib1_coll2_ils", + ), ] # These are the credentials we'll expect for each of our collections. expected_credentials = { - props['name']: _make_token( - 'websiteid:%s authorizationname:%s' % (props['website_id'], props['ils_name']), - patron.authorization_identifier, pin, + props["name"]: _make_token( + "websiteid:%s authorizationname:%s" + % (props["website_id"], props["ils_name"]), + patron.authorization_identifier, + pin, ) for props in library_collection_properties } # Add the collections. - collections = [MockAPI.mock_collection(self._db, **props) - for props in library_collection_properties] + collections = [ + MockAPI.mock_collection(self._db, **props) + for props in library_collection_properties + ] circulation = CirculationAPI( self._db, library, api_map={ExternalIntegration.OVERDRIVE: MockAPI} ) - od_apis = {api.collection.name: api - for api in list(circulation.api_for_collection.values())} + od_apis = { + api.collection.name: api + for api in list(circulation.api_for_collection.values()) + } # Ensure that we have the correct number of OverDrive collections. assert len(library_collection_properties) == len(od_apis) # Verify that the expected credentials match what we got. - for name in list(expected_credentials.keys()) + list(reversed(list(expected_credentials.keys()))): + for name in list(expected_credentials.keys()) + list( + reversed(list(expected_credentials.keys())) + ): credential = od_apis[name].get_patron_credential(patron, pin) assert expected_credentials[name] == credential.credential class TestExtractData(OverdriveAPITest): - def test_get_download_link(self): data, json = self.sample_json("checkout_response_locked_in_format.json") url = MockOverdriveAPI.get_download_link( - json, "ebook-epub-adobe", "http://foo.com/") - assert "http://patron.api.overdrive.com/v1/patrons/me/checkouts/76C1B7D0-17F4-4C05-8397-C66C17411584/formats/ebook-epub-adobe/downloadlink?errorpageurl=http://foo.com/" == url + json, "ebook-epub-adobe", "http://foo.com/" + ) + assert ( + "http://patron.api.overdrive.com/v1/patrons/me/checkouts/76C1B7D0-17F4-4C05-8397-C66C17411584/formats/ebook-epub-adobe/downloadlink?errorpageurl=http://foo.com/" + == url + ) pytest.raises( NoAcceptableFormat, MockOverdriveAPI.get_download_link, - json, "no-such-format", "http://foo.com/" + json, + "no-such-format", + "http://foo.com/", ) - def test_get_download_link_raises_exception_if_loan_fulfilled_on_incompatible_platform(self): + def test_get_download_link_raises_exception_if_loan_fulfilled_on_incompatible_platform( + self, + ): data, json = self.sample_json("checkout_response_book_fulfilled_on_kindle.json") pytest.raises( FulfilledOnIncompatiblePlatform, MockOverdriveAPI.get_download_link, - json, "ebook-epub-adobe", "http://foo.com/" + json, + "ebook-epub-adobe", + "http://foo.com/", ) def test_get_download_link_for_manifest_format(self): @@ -1561,7 +1623,7 @@ def test_get_download_link_for_manifest_format(self): # This is part of the URL from `json` that we expect # get_download_link to use as a base. - base_url = 'http://patron.api.overdrive.com/v1/patrons/me/checkouts/98EA8135-52C0-4480-9C0E-1D0779670D4A/formats/ebook-overdrive/downloadlink' + base_url = "http://patron.api.overdrive.com/v1/patrons/me/checkouts/98EA8135-52C0-4480-9C0E-1D0779670D4A/formats/ebook-overdrive/downloadlink" # First, let's ask for the streaming format. link = MockOverdriveAPI.get_download_link( @@ -1571,8 +1633,9 @@ def test_get_download_link_for_manifest_format(self): # The base URL is returned, with {errorpageurl} filled in and # {odreadauthurl} left for other code to fill in. assert ( - base_url + "?errorpageurl=http://foo.com/&odreadauthurl={odreadauthurl}" == - link) + base_url + "?errorpageurl=http://foo.com/&odreadauthurl={odreadauthurl}" + == link + ) # Now let's ask for the manifest format. link = MockOverdriveAPI.get_download_link( @@ -1581,7 +1644,7 @@ def test_get_download_link_for_manifest_format(self): # The {errorpageurl} and {odreadauthurl} parameters # have been removed, and contentfile=true has been appended. - assert base_url + '?contentfile=true' == link + assert base_url + "?contentfile=true" == link def test_extract_download_link(self): # Verify that extract_download_link can or cannot find a @@ -1589,10 +1652,12 @@ def test_extract_download_link(self): class Mock(OverdriveAPI): called_with = None + @classmethod def make_direct_download_link(cls, download_link): cls.called_with = download_link return "http://manifest/" + m = Mock.extract_download_link error_url = "http://error/" @@ -1603,27 +1668,21 @@ def make_direct_download_link(cls, download_link): assert "No linkTemplates for format (unknown)" in str(excinfo.value) # Here we know the name, but there are no link templates. - no_templates = dict(formatType='someformat') + no_templates = dict(formatType="someformat") with pytest.raises(IOError) as excinfo: m(no_templates, error_url) assert "No linkTemplates for format someformat" in str(excinfo.value) # Here there's a link template structure, but no downloadLink # inside. - no_download_link = dict( - formatType='someformat', - linkTemplates=dict() - ) + no_download_link = dict(formatType="someformat", linkTemplates=dict()) with pytest.raises(IOError) as excinfo: m(no_download_link, error_url) assert "No downloadLink for format someformat" in str(excinfo.value) # Here there's a downloadLink structure, but no href inside. href_is_missing = dict( - formatType='someformat', - linkTemplates=dict( - downloadLink=dict() - ) + formatType="someformat", linkTemplates=dict(downloadLink=dict()) ) with pytest.raises(IOError) as excinfo: m(href_is_missing, error_url) @@ -1634,12 +1693,10 @@ def make_direct_download_link(cls, download_link): # or not we want to return a link to the manifest file. working = dict( - formatType='someformat', + formatType="someformat", linkTemplates=dict( - downloadLink=dict( - href='http://download/?errorpageurl={errorpageurl}' - ) - ) + downloadLink=dict(href="http://download/?errorpageurl={errorpageurl}") + ), ) # If we don't want a manifest, make_direct_download_link is @@ -1648,14 +1705,12 @@ def make_direct_download_link(cls, download_link): assert None == Mock.called_with # The errorpageurl template is filled in. - assert ("http://download/?errorpageurl=http://error/" == - do_not_fetch_manifest) + assert "http://download/?errorpageurl=http://error/" == do_not_fetch_manifest # If we do want a manifest, make_direct_download_link is called # without errorpageurl being affected. do_fetch_manifest = m(working, error_url, fetch_manifest=True) - assert ("http://download/?errorpageurl={errorpageurl}" == - Mock.called_with) + assert "http://download/?errorpageurl={errorpageurl}" == Mock.called_with assert "http://manifest/" == do_fetch_manifest def test_make_direct_download_link(self): @@ -1665,30 +1720,40 @@ def test_make_direct_download_link(self): base = "http://overdrive/downloadlink" m = OverdriveAPI.make_direct_download_link assert base + "?contentfile=true" == m(base) - assert (base + "?contentfile=true" == - m(base + "?odreadauthurl={odreadauthurl}")) - assert (base + "?other=other&contentfile=true" == - m(base + "?odreadauthurl={odreadauthurl}&other=other")) + assert base + "?contentfile=true" == m(base + "?odreadauthurl={odreadauthurl}") + assert base + "?other=other&contentfile=true" == m( + base + "?odreadauthurl={odreadauthurl}&other=other" + ) def test_extract_data_from_checkout_resource(self): data, json = self.sample_json("checkout_response_locked_in_format.json") expires, url = MockOverdriveAPI.extract_data_from_checkout_response( - json, "ebook-epub-adobe", "http://foo.com/") + json, "ebook-epub-adobe", "http://foo.com/" + ) assert 2013 == expires.year assert 10 == expires.month assert 4 == expires.day - assert "http://patron.api.overdrive.com/v1/patrons/me/checkouts/76C1B7D0-17F4-4C05-8397-C66C17411584/formats/ebook-epub-adobe/downloadlink?errorpageurl=http://foo.com/" == url + assert ( + "http://patron.api.overdrive.com/v1/patrons/me/checkouts/76C1B7D0-17F4-4C05-8397-C66C17411584/formats/ebook-epub-adobe/downloadlink?errorpageurl=http://foo.com/" + == url + ) def test_process_checkout_data(self): - data, json = self.sample_json("shelf_with_book_already_fulfilled_on_kindle.json") + data, json = self.sample_json( + "shelf_with_book_already_fulfilled_on_kindle.json" + ) [on_kindle, not_on_kindle] = json["checkouts"] # The book already fulfilled on Kindle doesn't get turned into # LoanInfo at all. - assert None == MockOverdriveAPI.process_checkout_data(on_kindle, self.collection) + assert None == MockOverdriveAPI.process_checkout_data( + on_kindle, self.collection + ) # The book not yet fulfilled does show up as a LoanInfo. - loan_info = MockOverdriveAPI.process_checkout_data(not_on_kindle, self.collection) + loan_info = MockOverdriveAPI.process_checkout_data( + not_on_kindle, self.collection + ) assert "2fadd2ac-a8ec-4938-a369-4c3260e8922b" == loan_info.identifier # Since there are two usable formats (Adobe EPUB and Adobe @@ -1697,8 +1762,12 @@ def test_process_checkout_data(self): # A book that's on loan and locked to a specific format has a # DeliveryMechanismInfo associated with that format. - data, format_locked_in = self.sample_json("checkout_response_locked_in_format.json") - loan_info = MockOverdriveAPI.process_checkout_data(format_locked_in, self.collection) + data, format_locked_in = self.sample_json( + "checkout_response_locked_in_format.json" + ) + loan_info = MockOverdriveAPI.process_checkout_data( + format_locked_in, self.collection + ) delivery = loan_info.locked_to assert Representation.EPUB_MEDIA_TYPE == delivery.content_type assert DeliveryMechanism.ADOBE_DRM == delivery.drm_scheme @@ -1707,8 +1776,12 @@ def test_process_checkout_data(self): # EPUB has not yet been made, but as far as we're concerned, # Adobe EPUB is the only *usable* format, so it's effectively # locked. - data, no_format_locked_in = self.sample_json("checkout_response_no_format_locked_in.json") - loan_info = MockOverdriveAPI.process_checkout_data(no_format_locked_in, self.collection) + data, no_format_locked_in = self.sample_json( + "checkout_response_no_format_locked_in.json" + ) + loan_info = MockOverdriveAPI.process_checkout_data( + no_format_locked_in, self.collection + ) assert loan_info != None delivery = loan_info.locked_to assert Representation.EPUB_MEDIA_TYPE == delivery.content_type @@ -1718,10 +1791,12 @@ def test_process_checkout_data(self): # LoanInfo with appropriate FulfillmentInfo. The calling code # would then decide whether or not to show the loan. -class TestSyncBookshelf(OverdriveAPITest): +class TestSyncBookshelf(OverdriveAPITest): def test_sync_bookshelf_creates_local_loans(self): - loans_data, json_loans = self.sample_json("shelf_with_some_checked_out_books.json") + loans_data, json_loans = self.sample_json( + "shelf_with_some_checked_out_books.json" + ) holds_data, json_holds = self.sample_json("no_holds.json") self.api.queue_response(200, content=loans_data) @@ -1736,13 +1811,18 @@ def test_sync_bookshelf_creates_local_loans(self): # We have created previously unknown LicensePools and # Identifiers. - identifiers = [loan.license_pool.identifier.identifier - for loan in loans] - assert (sorted(['a5a3d737-34d4-4d69-aad8-eba4e46019a3', - '99409f99-45a5-4238-9e10-98d1435cde04', - '993e4b33-823c-40af-8f61-cac54e1cba5d', - 'a2ec6f3a-ebfe-4c95-9638-2cb13be8de5a']) == - sorted(identifiers)) + identifiers = [loan.license_pool.identifier.identifier for loan in loans] + assert ( + sorted( + [ + "a5a3d737-34d4-4d69-aad8-eba4e46019a3", + "99409f99-45a5-4238-9e10-98d1435cde04", + "993e4b33-823c-40af-8f61-cac54e1cba5d", + "a2ec6f3a-ebfe-4c95-9638-2cb13be8de5a", + ] + ) + == sorted(identifiers) + ) # We have recorded a new DeliveryMechanism associated with # each loan. @@ -1750,17 +1830,13 @@ def test_sync_bookshelf_creates_local_loans(self): for loan in loans: if loan.fulfillment: mechanism = loan.fulfillment.delivery_mechanism - mechanisms.append( - (mechanism.content_type, mechanism.drm_scheme) - ) - assert ( - [ - (Representation.EPUB_MEDIA_TYPE, DeliveryMechanism.NO_DRM), - (Representation.EPUB_MEDIA_TYPE, DeliveryMechanism.ADOBE_DRM), - (Representation.PDF_MEDIA_TYPE, DeliveryMechanism.ADOBE_DRM), - (Representation.EPUB_MEDIA_TYPE, DeliveryMechanism.ADOBE_DRM), - ] == - mechanisms) + mechanisms.append((mechanism.content_type, mechanism.drm_scheme)) + assert [ + (Representation.EPUB_MEDIA_TYPE, DeliveryMechanism.NO_DRM), + (Representation.EPUB_MEDIA_TYPE, DeliveryMechanism.ADOBE_DRM), + (Representation.PDF_MEDIA_TYPE, DeliveryMechanism.ADOBE_DRM), + (Representation.EPUB_MEDIA_TYPE, DeliveryMechanism.ADOBE_DRM), + ] == mechanisms # There are no holds. assert [] == holds @@ -1774,7 +1850,9 @@ def test_sync_bookshelf_creates_local_loans(self): assert loans.sort() == patron.loans.sort() def test_sync_bookshelf_removes_loans_not_present_on_remote(self): - loans_data, json_loans = self.sample_json("shelf_with_some_checked_out_books.json") + loans_data, json_loans = self.sample_json( + "shelf_with_some_checked_out_books.json" + ) holds_data, json_holds = self.sample_json("no_holds.json") self.api.queue_response(200, content=loans_data) @@ -1784,11 +1862,12 @@ def test_sync_bookshelf_removes_loans_not_present_on_remote(self): patron = self._patron() overdrive_edition, new = self._edition( data_source_name=DataSource.OVERDRIVE, - with_license_pool=True, collection=self.collection + with_license_pool=True, + collection=self.collection, ) [pool] = overdrive_edition.license_pools overdrive_loan, new = pool.loan_to(patron) - yesterday = utc_now()- timedelta(days=1) + yesterday = utc_now() - timedelta(days=1) overdrive_loan.start = yesterday # Sync with Overdrive, and the loan not present in the sample @@ -1801,11 +1880,14 @@ def test_sync_bookshelf_removes_loans_not_present_on_remote(self): def test_sync_bookshelf_ignores_loans_from_other_sources(self): patron = self._patron() - gutenberg, new = self._edition(data_source_name=DataSource.GUTENBERG, - with_license_pool=True) + gutenberg, new = self._edition( + data_source_name=DataSource.GUTENBERG, with_license_pool=True + ) [pool] = gutenberg.license_pools gutenberg_loan, new = pool.loan_to(patron) - loans_data, json_loans = self.sample_json("shelf_with_some_checked_out_books.json") + loans_data, json_loans = self.sample_json( + "shelf_with_some_checked_out_books.json" + ) holds_data, json_holds = self.sample_json("no_holds.json") # Overdrive doesn't know about the Gutenberg loan, but it was @@ -1847,12 +1929,11 @@ def test_sync_bookshelf_removes_holds_not_present_on_remote(self): overdrive_edition, new = self._edition( data_source_name=DataSource.OVERDRIVE, with_license_pool=True, - collection=self.collection + collection=self.collection, ) [pool] = overdrive_edition.license_pools overdrive_hold, new = pool.on_hold_to(patron) - self.api.queue_response(200, content=loans_data) self.api.queue_response(200, content=holds_data) @@ -1874,7 +1955,7 @@ def test_sync_bookshelf_ignores_holds_from_other_collections(self): overdrive, new = self._edition( data_source_name=DataSource.OVERDRIVE, with_license_pool=True, - collection=self._collection() + collection=self._collection(), ) [pool] = overdrive.license_pools overdrive_hold, new = pool.on_hold_to(patron) @@ -1890,30 +1971,30 @@ def test_sync_bookshelf_ignores_holds_from_other_collections(self): class TestOverdriveManifestFulfillmentInfo(OverdriveAPITest): - def test_as_response(self): # An OverdriveManifestFulfillmentInfo just links the client # directly to the manifest file, bypassing normal FulfillmentInfo # processing. info = OverdriveManifestFulfillmentInfo( - self._default_collection, "http://content-link/", - "abcd-efgh", "scope string" + self._default_collection, + "http://content-link/", + "abcd-efgh", + "scope string", ) response = info.as_response assert 302 == response.status_code assert "" == response.get_data(as_text=True) headers = response.headers - assert "text/plain" == headers['Content-Type'] + assert "text/plain" == headers["Content-Type"] # These are the important headers; the location of the manifest file # and the scope necessary to initiate Patron Authentication for # it. - assert "scope string" == headers['X-Overdrive-Scope'] - assert "http://content-link/" == headers['Location'] + assert "scope string" == headers["X-Overdrive-Scope"] + assert "http://content-link/" == headers["Location"] class TestOverdriveCirculationMonitor(OverdriveAPITest): - def test_run(self): # An end-to-end test verifying that this Monitor manages its # state across multiple runs. @@ -1938,7 +2019,7 @@ def catch_up_from(self, start, cutoff, progress): # # (This isn't how the Overdrive collection is initially # populated, BTW -- that's NewTitlesOverdriveCollectionMonitor.) - self.time_eq(start, now-monitor.OVERLAP) + self.time_eq(start, now - monitor.OVERLAP) self.time_eq(cutoff, now) timestamp = monitor.timestamp() assert start == timestamp.start @@ -1949,7 +2030,7 @@ def catch_up_from(self, start, cutoff, progress): monitor.run() new_start, new_cutoff, new_progress = monitor.catch_up_from_called_with now = utc_now() - assert new_start == cutoff-monitor.OVERLAP + assert new_start == cutoff - monitor.OVERLAP self.time_eq(new_cutoff, now) def test_catch_up_from(self): @@ -1976,7 +2057,7 @@ def update_licensepool(self, book_id): class MockAnalytics(object): def __init__(self, _db): - self._db= _db + self._db = _db self.events = [] def collect_event(self, *args): @@ -1986,6 +2067,7 @@ class MockMonitor(OverdriveCirculationMonitor): recently_changed_ids_called_with = None should_stop_calls = [] + def recently_changed_ids(self, start, cutoff): self.recently_changed_ids_called_with = (start, cutoff) return [1, 2, None, 3, 4] @@ -1999,8 +2081,9 @@ def should_stop(self, start, book, is_changed): return True return False - monitor = MockMonitor(self._db, self.collection, api_class=MockAPI, - analytics_class=MockAnalytics) + monitor = MockMonitor( + self._db, self.collection, api_class=MockAPI, analytics_class=MockAnalytics + ) api = monitor.api # A MockAnalytics object was created and is ready to receive analytics @@ -2035,17 +2118,17 @@ def should_stop(self, start, book, is_changed): # update_licensepool on the first three valid 'books'. The # mock API delivered the first three LicensePools from the # queue. - assert [(1, lp1),(2, lp2),(3, lp3)] == api.update_licensepool_calls + assert [(1, lp1), (2, lp2), (3, lp3)] == api.update_licensepool_calls # After each book was processed, should_stop was called, using # the LicensePool, the start date, plus information about # whether the LicensePool was changed (or created) during # update_licensepool(). - assert ( - [(start, 1, True), - (start, 2, False), - (start, 3, True)] == - monitor.should_stop_calls) + assert [ + (start, 1, True), + (start, 2, False), + (start, 3, True), + ] == monitor.should_stop_calls # should_stop returned True on the third call, and at that # point we gave up. @@ -2074,11 +2157,11 @@ def should_stop(self, start, book, is_changed): class TestNewTitlesOverdriveCollectionMonitor(OverdriveAPITest): - def test_recently_changed_ids(self): class MockAPI(object): def __init__(self, *args, **kwargs): pass + def all_ids(self): return "all of the ids" @@ -2103,19 +2186,22 @@ def test_should_stop(self): # should keep going. start = datetime_utc(2018, 1, 1) assert False == m(start, {}, object()) - assert False == m(start, {'date_added': None}, object()) - assert False == m(start, {'date_added': "Not a date"}, object()) + assert False == m(start, {"date_added": None}, object()) + assert False == m(start, {"date_added": "Not a date"}, object()) # Here, we're actually comparing real dates, using the date # format found in the Overdrive API. A date that's after the # `start` date means we should keep going backwards. A date before # the `start` date means we should stop. - assert False == m(start, {'date_added': '2019-07-12T11:06:38.157+01:00'}, object()) - assert True == m(start, {'date_added': '2017-07-12T11:06:38.157-04:00'}, object()) + assert False == m( + start, {"date_added": "2019-07-12T11:06:38.157+01:00"}, object() + ) + assert True == m( + start, {"date_added": "2017-07-12T11:06:38.157-04:00"}, object() + ) class TestNewTitlesOverdriveCollectionMonitor(OverdriveAPITest): - def test_should_stop(self): monitor = RecentOverdriveCollectionMonitor( self._db, self.collection, api_class=MockOverdriveAPI @@ -2137,19 +2223,21 @@ def test_should_stop(self): # When we're at the limit, and another book comes along that hasn't # been changed, _then_ we decide to stop. - monitor.consecutive_unchanged_books = monitor.MAXIMUM_CONSECUTIVE_UNCHANGED_BOOKS + monitor.consecutive_unchanged_books = ( + monitor.MAXIMUM_CONSECUTIVE_UNCHANGED_BOOKS + ) assert True == m(object(), object(), False) - assert (monitor.MAXIMUM_CONSECUTIVE_UNCHANGED_BOOKS+1 == - monitor.consecutive_unchanged_books) + assert ( + monitor.MAXIMUM_CONSECUTIVE_UNCHANGED_BOOKS + 1 + == monitor.consecutive_unchanged_books + ) class TestOverdriveFormatSweep(OverdriveAPITest): - def test_process_item(self): # Validate the standard CollectionMonitor interface. monitor = OverdriveFormatSweep( - self._db, self.collection, - api_class=MockOverdriveAPI + self._db, self.collection, api_class=MockOverdriveAPI ) monitor.api.queue_collection_token() # We're not testing that the work actually gets done (that's @@ -2165,13 +2253,11 @@ def test_process_item_multiple_licence_pools(self): class MockApi(MockOverdriveAPI): update_format_calls = 0 + def update_formats(self, licensepool): self.update_format_calls += 1 - monitor = OverdriveFormatSweep( - self._db, self.collection, - api_class=MockApi - ) + monitor = OverdriveFormatSweep(self._db, self.collection, api_class=MockApi) monitor.api.queue_collection_token() monitor.api.queue_response(404) @@ -2187,10 +2273,8 @@ def update_formats(self, licensepool): class TestReaper(OverdriveAPITest): - def test_instantiate(self): # Validate the standard CollectionMonitor interface. monitor = OverdriveCollectionReaper( - self._db, self.collection, - api_class=MockOverdriveAPI + self._db, self.collection, api_class=MockOverdriveAPI ) diff --git a/tests/test_patron_utility.py b/tests/test_patron_utility.py index 9d8b2336d1..4e04e93b59 100644 --- a/tests/test_patron_utility.py +++ b/tests/test_patron_utility.py @@ -1,23 +1,20 @@ import datetime -import dateutil from decimal import Decimal +import dateutil import pytest -from core.testing import ( - DatabaseTest, -) -from api.config import Configuration, temp_config from api.authenticator import PatronData -from api.util.patron import PatronUtility from api.circulation_exceptions import * -from core.util.datetime_helpers import utc_now +from api.config import Configuration, temp_config +from api.util.patron import PatronUtility from core.model import ConfigurationSetting +from core.testing import DatabaseTest from core.util import MoneyUtility +from core.util.datetime_helpers import utc_now class TestPatronUtility(DatabaseTest): - def test_needs_external_sync(self): """Test the method that encapsulates the determination of whether or not a patron needs to have their account @@ -84,8 +81,7 @@ def test_has_borrowing_privileges(self): patron.authorization_expires = one_day_ago assert False == PatronUtility.has_borrowing_privileges(patron) pytest.raises( - AuthorizationExpired, - PatronUtility.assert_borrowing_privileges, patron + AuthorizationExpired, PatronUtility.assert_borrowing_privileges, patron ) patron.authorization_expires = None @@ -96,12 +92,10 @@ class Mock(PatronUtility): def has_excess_fines(cls, patron): cls.called_with = patron return True + assert False == Mock.has_borrowing_privileges(patron) assert patron == Mock.called_with - pytest.raises( - OutstandingFines, - Mock.assert_borrowing_privileges, patron - ) + pytest.raises(OutstandingFines, Mock.assert_borrowing_privileges, patron) # Even if the circulation manager is not configured to know # what "excessive fines" are, the authentication mechanism @@ -109,8 +103,7 @@ def has_excess_fines(cls, patron): # patron's block_reason. patron.block_reason = PatronData.EXCESSIVE_FINES pytest.raises( - OutstandingFines, - PatronUtility.assert_borrowing_privileges, patron + OutstandingFines, PatronUtility.assert_borrowing_privileges, patron ) # If your card is blocked for any reason you lose borrowing @@ -118,8 +111,7 @@ def has_excess_fines(cls, patron): patron.block_reason = "some reason" assert False == PatronUtility.has_borrowing_privileges(patron) pytest.raises( - AuthorizationBlocked, - PatronUtility.assert_borrowing_privileges, patron + AuthorizationBlocked, PatronUtility.assert_borrowing_privileges, patron ) patron.block_reason = None @@ -131,8 +123,7 @@ def test_has_excess_fines(self): # If you accrue excessive fines you lose borrowing privileges. setting = ConfigurationSetting.for_library( - Configuration.MAX_OUTSTANDING_FINES, - self._default_library + Configuration.MAX_OUTSTANDING_FINES, self._default_library ) # Verify that all these tests work no matter what data type has been stored in @@ -142,20 +133,22 @@ def test_has_excess_fines(self): # Test cases where the patron's fines exceed a well-defined limit, # or when any amount of fines is too much. - for max_fines in ( - ["$0.50", "0.5", .5] + # well-defined limit - ["$0", "$0.00", "0", 0] # any fines is too much - ): + for max_fines in ["$0.50", "0.5", 0.5] + [ # well-defined limit + "$0", + "$0.00", + "0", + 0, + ]: # any fines is too much setting.value = max_fines assert True == PatronUtility.has_excess_fines(patron) # Test cases where the patron's fines are below a # well-defined limit, or where fines are ignored # altogether. - for max_fines in ( - ["$100", 100] + # well-defined-limit - [None, ""] # fines ignored - ): + for max_fines in ["$100", 100] + [ # well-defined-limit + None, + "", + ]: # fines ignored setting.value = max_fines assert False == PatronUtility.has_excess_fines(patron) diff --git a/tests/test_registry.py b/tests/test_registry.py index 33cfc807b3..9644da757c 100644 --- a/tests/test_registry.py +++ b/tests/test_registry.py @@ -1,38 +1,24 @@ -import pytest - +import base64 import json -from Crypto.PublicKey import RSA -from Crypto.Cipher import PKCS1_OAEP import os -from core.testing import ( - DatabaseTest -) -from core.testing import ( - DummyHTTPClient, - MockRequestsResponse, -) -from core.util.http import HTTP -from core.util.problem_detail import ( - ProblemDetail, - JSON_MEDIA_TYPE as PROBLEM_DETAIL_JSON_MEDIA_TYPE, -) -from core.model import ( - ConfigurationSetting, - ExternalIntegration, -) from pdb import set_trace -import base64 + +import pytest +from Crypto.Cipher import PKCS1_OAEP +from Crypto.PublicKey import RSA + from api.adobe_vendor_id import AuthdataUtility from api.config import Configuration from api.problem_details import * -from api.registry import ( - RemoteRegistry, - Registration, - LibraryRegistrationScript, -) +from api.registry import LibraryRegistrationScript, Registration, RemoteRegistry +from core.model import ConfigurationSetting, ExternalIntegration +from core.testing import DatabaseTest, DummyHTTPClient, MockRequestsResponse +from core.util.http import HTTP +from core.util.problem_detail import JSON_MEDIA_TYPE as PROBLEM_DETAIL_JSON_MEDIA_TYPE +from core.util.problem_detail import ProblemDetail -class TestRemoteRegistry(DatabaseTest): +class TestRemoteRegistry(DatabaseTest): def setup_method(self): super(TestRemoteRegistry, self).setup_method() @@ -52,9 +38,7 @@ def test_for_integration_id(self): """ m = RemoteRegistry.for_integration_id - registry = m( - self._db, self.integration.id, ExternalIntegration.DISCOVERY_GOAL - ) + registry = m(self._db, self.integration.id, ExternalIntegration.DISCOVERY_GOAL) assert isinstance(registry, RemoteRegistry) assert self.integration == registry.integration @@ -148,17 +132,16 @@ def _extract_catalog_information(self, response): # method is the return value of fetch_catalog. client.queue_requests_response(200, content="A root catalog") [queued] = client.responses - assert ("Essential information" == - registry.fetch_catalog("custom catalog URL", do_get=client.do_get)) + assert "Essential information" == registry.fetch_catalog( + "custom catalog URL", do_get=client.do_get + ) assert "custom catalog URL" == client.requests.pop() def test__extract_catalog_information(self): # Test our ability to extract a registration link and an # Adobe Vendor ID from an OPDS 1 or OPDS 2 catalog. def extract(document, type=RemoteRegistry.OPDS_2_TYPE): - response = MockRequestsResponse( - 200, { "Content-Type" : type }, document - ) + response = MockRequestsResponse(200, {"Content-Type": type}, document) return RemoteRegistry._extract_catalog_information(response) def assert_no_link(*args, **kwargs): @@ -167,12 +150,14 @@ def assert_no_link(*args, **kwargs): """ result = extract(*args, **kwargs) assert REMOTE_INTEGRATION_FAILED.uri == result.uri - assert ("The service at http://url/ did not provide a register link." == - result.detail) + assert ( + "The service at http://url/ did not provide a register link." + == result.detail + ) # OPDS 2 feed with link and Adobe Vendor ID. - link = { 'rel': 'register', 'href': 'register url' } - metadata = { 'adobe_vendor_id': 'vendorid' } + link = {"rel": "register", "href": "register url"} + metadata = {"adobe_vendor_id": "vendorid"} feed = json.dumps(dict(links=[link], metadata=metadata)) assert ("register url", "vendorid") == extract(feed) @@ -186,18 +171,18 @@ def assert_no_link(*args, **kwargs): # OPDS 1 feed with link. feed = '' - assert (("register url", None) == - extract(feed, RemoteRegistry.OPDS_1_PREFIX + ";foo")) + assert ("register url", None) == extract( + feed, RemoteRegistry.OPDS_1_PREFIX + ";foo" + ) # OPDS 1 feed with no link. - feed = '' + feed = "" assert_no_link(feed, RemoteRegistry.OPDS_1_PREFIX + ";foo") # Non-OPDS document. result = extract("plain text here", "text/plain") assert REMOTE_INTEGRATION_FAILED.uri == result.uri - assert ("The service at http://url/ did not return OPDS." == - result.detail) + assert "The service at http://url/ did not return OPDS." == result.detail def test_fetch_registration_document(self): # Test our ability to retrieve terms-of-service information @@ -227,6 +212,7 @@ def fetch_catalog(self, do_get): # get the registration document. client = DummyHTTPClient() client.responses.append(REMOTE_INTEGRATION_FAILED) + class Mock(RemoteRegistry): def fetch_catalog(self, do_get): return "http://register-here/", "vendor id" @@ -272,8 +258,7 @@ def test__extract_registration_information(self): def data_link(data, type="text/html"): encoded = base64.b64encode(data.encode("utf-8")).decode("utf-8") return dict( - rel="terms-of-service", - href="data:%s;base64,%s" % (type, encoded) + rel="terms-of-service", href="data:%s;base64,%s" % (type, encoded) ) class Mock(RemoteRegistry): @@ -285,22 +270,21 @@ def _decode_data_url(cls, url): def extract(document, type=RemoteRegistry.OPDS_2_TYPE): if type == RemoteRegistry.OPDS_2_TYPE: document = json.dumps(dict(links=document)) - response = MockRequestsResponse( - 200, { "Content-Type" : type }, document - ) + response = MockRequestsResponse(200, {"Content-Type": type}, document) return Mock._extract_registration_information(response) # OPDS 2 feed with TOS in http: and data: links. - tos_link = dict(rel='terms-of-service', href='http://tos/') + tos_link = dict(rel="terms-of-service", href="http://tos/") tos_data = data_link("

    Some HTML

    ") - assert (("http://tos/", "Decoded:

    Some HTML

    ") == - extract([tos_link, tos_data])) + assert ("http://tos/", "Decoded:

    Some HTML

    ") == extract( + [tos_link, tos_data] + ) # At this point it's clear that the data: URL found in # `tos_data` was run through `_decode_data()`. This gives us # permission to test all the fiddly bits of `_decode_data` in # isolation, below. - assert tos_data['href'] == Mock.decoded + assert tos_data["href"] == Mock.decoded # OPDS 2 feed with http: link only. assert ("http://tos/", None) == extract([tos_link]) @@ -313,18 +297,19 @@ def extract(document, type=RemoteRegistry.OPDS_2_TYPE): # OPDS 1 feed with link. feed = '' - assert (("http://tos/", None) == - extract(feed, RemoteRegistry.OPDS_1_PREFIX + ";foo")) + assert ("http://tos/", None) == extract( + feed, RemoteRegistry.OPDS_1_PREFIX + ";foo" + ) # OPDS 1 feed with no link. - feed = '' + feed = "" assert (None, None) == extract(feed, RemoteRegistry.OPDS_1_PREFIX + ";foo") # Non-OPDS document. assert (None, None) == extract("plain text here", "text/plain") # Unrecognized URI schemes are ignored. - ftp_link = dict(rel='terms-of-service', href='ftp://tos/') + ftp_link = dict(rel="terms-of-service", href="ftp://tos/") assert (None, None) == extract([ftp_link]) def test__decode_data_url(self): @@ -378,7 +363,6 @@ def data_url(data, type="text/html"): class TestRegistration(DatabaseTest): - def setup_method(self): super(TestRegistration, self).setup_method() @@ -396,10 +380,10 @@ def test_constructor(self): assert self.registry == reg.registry assert self._default_library == reg.library - settings = [x for x in reg.integration.settings - if x.library is not None] - assert (set([reg.status_field, reg.stage_field, reg.web_client_field]) == - set(settings)) + settings = [x for x in reg.integration.settings if x.library is not None] + assert set([reg.status_field, reg.stage_field, reg.web_client_field]) == set( + settings + ) assert Registration.FAILURE_STATUS == reg.status_field.value assert Registration.TESTING_STAGE == reg.stage_field.value assert None == reg.web_client_field.value @@ -429,8 +413,9 @@ def _find(key): store the default value. """ values = [ - x for x in self.registration.integration.settings - if x.library and x.key==key + x + for x in self.registration.integration.settings + if x.library and x.key == key ] if len(values) == 1: return values[0] @@ -458,7 +443,6 @@ def test_push(self): # Test the other methods orchestrated by the push() method. class MockRegistry(RemoteRegistry): - def fetch_catalog(self, catalog_url, do_get): # Pretend to fetch a root catalog and extract a # registration URL from it. @@ -466,7 +450,6 @@ def fetch_catalog(self, catalog_url, do_get): return "register_url", "vendor_id" class MockRegistration(Registration): - def _create_registration_payload(self, url_for, stage): self.payload_ingredients = (url_for, stage) return dict(payload="this is it") @@ -476,18 +459,21 @@ def _create_registration_headers(self): return dict(Header="Value") def _send_registration_request( - self, register_url, headers, payload, do_post + self, register_url, headers, payload, do_post ): self._send_registration_request_called_with = ( - register_url, headers, payload, do_post - ) - return MockRequestsResponse( - 200, content=json.dumps("you did it!") + register_url, + headers, + payload, + do_post, ) + return MockRequestsResponse(200, content=json.dumps("you did it!")) def _process_registration_result(self, catalog, encryptor, stage): self._process_registration_result_called_with = ( - catalog, encryptor, stage + catalog, + encryptor, + stage, ) return "all done!" @@ -503,10 +489,9 @@ def _process_registration_result(self, catalog, encryptor, stage): catalog_url = "http://catalog/" do_get = object() do_post = object() + def push(): - return registration.push( - stage, url_for, catalog_url, do_get, do_post - ) + return registration.push(stage, url_for, catalog_url, do_get, do_post) result = push() expect = "Library %s has no key pair set." % library.short_name @@ -533,10 +518,11 @@ def push(): # The vendor ID was set as a ConfigurationSetting on # the ExternalIntegration associated with this registry. assert ( - "vendor_id" == - ConfigurationSetting.for_externalintegration( + "vendor_id" + == ConfigurationSetting.for_externalintegration( AuthdataUtility.VENDOR_ID_KEY, self.integration - ).value) + ).value + ) # _create_registration_payload was called to create the body # of the registration request. @@ -550,9 +536,11 @@ def push(): # payload to "register_url", the registration URL we got earlier. results = registration._send_registration_request_called_with assert ( - ("register_url", {"Header": "Value"}, dict(payload="this is it"), - do_post) == - results) + "register_url", + {"Header": "Value"}, + dict(payload="this is it"), + do_post, + ) == results # Finally, the return value of that method was loaded as JSON # and passed into _process_registration_result, along with @@ -570,8 +558,7 @@ def push(): "no such stage", url_for, catalog_url, do_get, do_post ) assert INVALID_INPUT.uri == result.uri - assert ("'no such stage' is not a valid registration stage" == - result.detail) + assert "'no such stage' is not a valid registration stage" == result.detail # Now in reverse order, let's replace the mocked methods so # that they return ProblemDetail documents. This tests that if @@ -585,14 +572,14 @@ def fail(*args, **kwargs): return INVALID_REGISTRATION.detailed( "could not process registration result" ) + registration._process_registration_result = fail problem = cause_problem() assert "could not process registration result" == problem.detail def fail(*args, **kwargs): - return INVALID_REGISTRATION.detailed( - "could not send registration request" - ) + return INVALID_REGISTRATION.detailed("could not send registration request") + registration._send_registration_request = fail problem = cause_problem() assert "could not send registration request" == problem.detail @@ -601,12 +588,14 @@ def fail(*args, **kwargs): return INVALID_REGISTRATION.detailed( "could not create registration payload" ) + registration._create_registration_payload = fail problem = cause_problem() assert "could not create registration payload" == problem.detail def fail(*args, **kwargs): return INVALID_REGISTRATION.detailed("could not fetch catalog") + registry.fetch_catalog = fail problem = cause_problem() assert "could not fetch catalog" == problem.detail @@ -632,8 +621,8 @@ def url_for(controller, library_short_name): ConfigurationSetting.for_library( Configuration.CONFIGURATION_CONTACT_EMAIL, self.registration.library, - ).value=contact - expect_payload['contact'] = contact + ).value = contact + expect_payload["contact"] = contact assert expect_payload == m(url_for, stage) def test_create_registration_headers(self): @@ -645,10 +634,12 @@ def test_create_registration_headers(self): # If a shared secret is configured, it shows up as part of # the Authorization header. setting = ConfigurationSetting.for_library_and_externalintegration( - self._db, ExternalIntegration.PASSWORD, self.registration.library, - self.registration.registry.integration - ).value="a secret" - expect_headers['Authorization'] = 'Bearer a secret' + self._db, + ExternalIntegration.PASSWORD, + self.registration.library, + self.registration.registry.integration, + ).value = "a secret" + expect_headers["Authorization"] = "Bearer a secret" assert expect_headers == m() def test__send_registration_request(self): @@ -670,14 +661,15 @@ def do_post(self, url, payload, **kwargs): result = m(url, headers, payload, mock.do_post) assert mock.response == result called_with = mock.called_with - assert (called_with == - (url, payload, - dict( - headers=headers, - timeout=60, - allowed_response_codes=["2xx", "3xx", "400", "401"] - ) - )) + assert called_with == ( + url, + payload, + dict( + headers=headers, + timeout=60, + allowed_response_codes=["2xx", "3xx", "400", "401"], + ), + ) # Most error handling is expected to be handled by do_post # raising an exception, but certain responses get special @@ -686,21 +678,20 @@ def do_post(self, url, payload, **kwargs): # The remote sends a 401 response with a problem detail. mock = Mock( MockRequestsResponse( - 401, { "Content-Type": PROBLEM_DETAIL_JSON_MEDIA_TYPE }, - content=json.dumps(dict(detail="this is a problem detail")) + 401, + {"Content-Type": PROBLEM_DETAIL_JSON_MEDIA_TYPE}, + content=json.dumps(dict(detail="this is a problem detail")), ) ) result = m(url, headers, payload, mock.do_post) assert isinstance(result, ProblemDetail) assert REMOTE_INTEGRATION_FAILED.uri == result.uri - assert ('Remote service returned: "this is a problem detail"' == - result.detail) + assert 'Remote service returned: "this is a problem detail"' == result.detail # The remote sends some other kind of 401 response. mock = Mock( MockRequestsResponse( - 401, { "Content-Type": "text/html" }, - content="log in why don't you" + 401, {"Content-Type": "text/html"}, content="log in why don't you" ) ) result = m(url, headers, payload, mock.do_post) @@ -738,7 +729,10 @@ def test__process_registration_result(self): # Result must be a dictionary. result = m("not a dictionary", None, None) assert INTEGRATION_ERROR.uri == result.uri - assert "Remote service served 'not a dictionary', which I can't make sense of as an OPDS document." == result.detail + assert ( + "Remote service served 'not a dictionary', which I can't make sense of as an OPDS document." + == result.detail + ) # When the result is empty, the registration is marked as successful. new_stage = "new stage" @@ -776,7 +770,10 @@ def _decrypt_shared_secret(self, encryptor, shared_secret): assert "👉 cleartext 👈" == reg.setting(ExternalIntegration.PASSWORD).value # Web client URL is set. - assert "http://web/library" == reg.setting(reg.LIBRARY_REGISTRATION_WEB_CLIENT).value + assert ( + "http://web/library" + == reg.setting(reg.LIBRARY_REGISTRATION_WEB_CLIENT).value + ) assert "another new stage" == reg.stage_field.value @@ -784,6 +781,7 @@ def _decrypt_shared_secret(self, encryptor, shared_secret): class Mock(Registration): def _decrypt_shared_secret(self, encryptor, shared_secret): return SHARED_SECRET_DECRYPTION_ERROR + reg = Mock(self.registry, self._default_library) result = reg._process_registration_result( catalog, encryptor, "another new stage" @@ -792,15 +790,14 @@ def _decrypt_shared_secret(self, encryptor, shared_secret): class TestLibraryRegistrationScript(DatabaseTest): - def setup_method(self): """Make sure there's a base URL for url_for to use.""" super(TestLibraryRegistrationScript, self).setup_method() def test_do_run(self): - class Mock(LibraryRegistrationScript): processed = [] + def process_library(self, *args): self.processed.append(args) @@ -809,13 +806,16 @@ def process_library(self, *args): base_url_setting = ConfigurationSetting.sitewide( self._db, Configuration.BASE_URL_KEY ) - base_url_setting.value = 'http://test-circulation-manager/' + base_url_setting.value = "http://test-circulation-manager/" library = self._default_library library2 = self._library() - cmd_args = [library.short_name, "--stage=testing", - "--registry-url=http://registry/"] + cmd_args = [ + library.short_name, + "--stage=testing", + "--registry-url=http://registry/", + ] app = script.do_run(cmd_args=cmd_args, in_unit_test=True) # One library was processed. @@ -842,8 +842,7 @@ def process_library(self, *args): app = script.do_run(cmd_args=[], in_unit_test=True) # Every library was processed. - assert (set([library, library2]) == - set([x[0].library for x in script.processed])) + assert set([library, library2]) == set([x[0].library for x in script.processed]) for i in script.processed: # Since no stage was provided, each library was registered @@ -852,9 +851,7 @@ def process_library(self, *args): # Every library was registered with the default # library registry. - assert ( - RemoteRegistry.DEFAULT_LIBRARY_REGISTRY_URL == - i[0].integration.url) + assert RemoteRegistry.DEFAULT_LIBRARY_REGISTRY_URL == i[0].integration.url def test_process_library(self): """Test the things that might happen when process_library is called.""" @@ -870,6 +867,7 @@ class Success(Registration): def push(self, stage, url_for): self.pushed = (stage, url_for) return True + registration = Success(registry, library) stage = object() @@ -897,6 +895,7 @@ def push(self, stage, url_for): class FailsWithProblemDetail(Registration): def push(self, stage, url_for): return INVALID_INPUT.detailed("oops") + registration = FailsWithProblemDetail(registry, library) result = script.process_library(registration, stage, url_for) @@ -905,4 +904,3 @@ def push(self, stage, url_for): # actually running the script will see it. assert INVALID_INPUT.uri == result.uri assert "oops" == result.detail - diff --git a/tests/test_routes.py b/tests/test_routes.py index 7b98140010..ce67e5fe51 100644 --- a/tests/test_routes.py +++ b/tests/test_routes.py @@ -1,28 +1,26 @@ import contextlib import logging -import pytest import flask +import pytest from flask import Response from werkzeug.exceptions import MethodNotAllowed -from core.app_server import ErrorHandler - -from api import app -from api import routes -from api.opds import CirculationManagerAnnotator +from api import app, routes from api.controller import CirculationManager -from api.routes import ( - exception_handler, - h as error_handler_object, -) +from api.opds import CirculationManagerAnnotator +from api.routes import exception_handler +from api.routes import h as error_handler_object +from core.app_server import ErrorHandler from .test_controller import ControllerTest + class MockApp(object): """Pretends to be a Flask application with a configured CirculationManager. """ + def __init__(self): self.manager = MockManager() @@ -37,12 +35,12 @@ def __init__(self): self.patron_web_domains = set(["http://patron/web"]) def __getattr__(self, controller_name): - return self._cache.setdefault( - controller_name, MockController(controller_name) - ) + return self._cache.setdefault(controller_name, MockController(controller_name)) + class MockControllerMethod(object): """Pretends to be one of the methods of a controller class.""" + def __init__(self, controller, name): """Constructor. @@ -66,9 +64,8 @@ def __call__(self, *args, **kwargs): return response def __repr__(self): - return "" % ( - self.controller.name, self.name - ) + return "" % (self.controller.name, self.name) + class MockController(MockControllerMethod): """Pretends to be a controller. @@ -76,6 +73,7 @@ class MockController(MockControllerMethod): A controller has methods, but it may also be called _as_ a method, so this class subclasses MockControllerMethod. """ + AUTHENTICATED_PATRON = "i am a mock patron" def __init__(self, name): @@ -87,7 +85,7 @@ def __init__(self, name): # If this controller were to be called as a method, the method # name would be __call__, not the name of the controller. - self.callable_name = '__call__' + self.callable_name = "__call__" self._cache = {} self.authenticated = False @@ -101,8 +99,7 @@ def authenticated_patron_from_request(self): return self.AUTHENTICATED_PATRON else: return Response( - "authenticated_patron_from_request called without authorizing", - 401 + "authenticated_patron_from_request called without authorizing", 401 ) def __getattr__(self, method_name): @@ -116,7 +113,7 @@ def __repr__(self): class RouteTestFixtures(object): - def request(self, url, method='GET'): + def request(self, url, method="GET"): """Simulate a request to a URL without triggering any code outside routes.py. """ @@ -135,7 +132,7 @@ def assert_request_calls(self, url, method, *args, **kwargs): the given controller `method` was called with the given `args` and `kwargs`. """ - http_method = kwargs.pop('http_method', 'GET') + http_method = kwargs.pop("http_method", "GET") response = self.request(url, http_method) assert response.method == method assert response.method.args == args @@ -153,7 +150,9 @@ def assert_request_calls(self, url, method, *args, **kwargs): # the mock method. This might remove the need to call the # mock method at all. - def assert_request_calls_method_using_identifier(self, url, method, *args, **kwargs): + def assert_request_calls_method_using_identifier( + self, url, method, *args, **kwargs + ): # Call an assertion method several times, using different # types of identifier in the URL, to make sure the identifier # is always passed through correctly. @@ -161,22 +160,23 @@ def assert_request_calls_method_using_identifier(self, url, method, *args, **kwa # The url must contain the string '' standing in # for the place where an identifier should be plugged in, and # the *args list must include the string ''. - authenticated = kwargs.pop('authenticated', False) + authenticated = kwargs.pop("authenticated", False) if authenticated: assertion_method = self.assert_authenticated_request_calls else: assertion_method = self.assert_request_calls - assert '' in url + assert "" in url args = list(args) - identifier_index = args.index('') + identifier_index = args.index("") for identifier in ( - '', 'an/identifier/', 'http://an-identifier/', 'http://an-identifier', + "", + "an/identifier/", + "http://an-identifier/", + "http://an-identifier", ): - modified_url = url.replace('', identifier) + modified_url = url.replace("", identifier) args[identifier_index] = identifier - assertion_method( - modified_url, method, *args, **kwargs - ) + assertion_method(modified_url, method, *args, **kwargs) def assert_authenticated_request_calls(self, url, method, *args, **kwargs): """First verify that an unauthenticated request fails. Then make an @@ -185,12 +185,14 @@ def assert_authenticated_request_calls(self, url, method, *args, **kwargs): """ authentication_required = kwargs.pop("authentication_required", True) - http_method = kwargs.pop('http_method', 'GET') + http_method = kwargs.pop("http_method", "GET") response = self.request(url, http_method) if authentication_required: assert 401 == response.status_code - assert ("authenticated_patron_from_request called without authorizing" == - response.get_data(as_text=True)) + assert ( + "authenticated_patron_from_request called without authorizing" + == response.get_data(as_text=True) + ) else: assert 200 == response.status_code @@ -198,7 +200,7 @@ def assert_authenticated_request_calls(self, url, method, *args, **kwargs): # will succeed, and try again. self.manager.index_controller.authenticated = True try: - kwargs['http_method'] = http_method + kwargs["http_method"] = http_method self.assert_request_calls(url, method, *args, **kwargs) finally: # Un-set authentication for the benefit of future @@ -212,12 +214,12 @@ def assert_supported_methods(self, url, *methods): # The simplest way to do this seems to be to try each of the # other potential methods and verify that MethodNotAllowed is # raised each time. - check = set(['GET', 'POST', 'PUT', 'DELETE']) - set(methods) + check = set(["GET", "POST", "PUT", "DELETE"]) - set(methods) # Treat HEAD specially. Any controller that supports GET # automatically supports HEAD. So we only assert that HEAD # fails if the method supports neither GET nor HEAD. - if 'GET' not in methods and 'HEAD' not in methods: - check.add('HEAD') + if "GET" not in methods and "HEAD" not in methods: + check.add("HEAD") for method in check: logging.debug("MethodNotAllowed should be raised on %s", method) pytest.raises(MethodNotAllowed, self.request, url, method) @@ -260,11 +262,11 @@ def setup_method(self): self.routes = routes self.manager = app.manager self.original_app = self.routes.app - self.resolver = self.original_app.url_map.bind('', '/') + self.resolver = self.original_app.url_map.bind("", "/") # For convenience, set self.controller to a specific controller # whose routes are being tested. - controller_name = getattr(self, 'CONTROLLER_NAME', None) + controller_name = getattr(self, "CONTROLLER_NAME", None) if controller_name: self.controller = getattr(self.manager, controller_name) @@ -295,21 +297,21 @@ class TestIndex(RouteTest): CONTROLLER_NAME = "index_controller" def test_index(self): - for url in '/', '': + for url in "/", "": self.assert_request_calls(url, self.controller) def test_authentication_document(self): - url = '/authentication_document' + url = "/authentication_document" self.assert_request_calls(url, self.controller.authentication_document) def test_public_key_document(self): - url = '/public_key_document' + url = "/public_key_document" self.assert_request_calls(url, self.controller.public_key_document) class TestOPDSFeed(RouteTest): - CONTROLLER_NAME = 'opds_feeds' + CONTROLLER_NAME = "opds_feeds" def test_acquisition_groups(self): # An incoming lane identifier is passed in to the groups() @@ -317,143 +319,137 @@ def test_acquisition_groups(self): method = self.controller.groups self.assert_request_calls("/groups", method, None) self.assert_request_calls( - "/groups/", method, '' + "/groups/", method, "" ) def test_feed(self): # An incoming lane identifier is passed in to the feed() # method. - url = '/feed' + url = "/feed" self.assert_request_calls(url, self.controller.feed, None) - url = '/feed/' - self.assert_request_calls( - url, self.controller.feed, '' - ) + url = "/feed/" + self.assert_request_calls(url, self.controller.feed, "") def test_navigation_feed(self): # An incoming lane identifier is passed in to the navigation_feed() # method. - url = '/navigation' + url = "/navigation" self.assert_request_calls(url, self.controller.navigation, None) - url = '/navigation/' - self.assert_request_calls( - url, self.controller.navigation, '' - ) + url = "/navigation/" + self.assert_request_calls(url, self.controller.navigation, "") def test_crawlable_library_feed(self): - url = '/crawlable' + url = "/crawlable" self.assert_request_calls(url, self.controller.crawlable_library_feed) def test_crawlable_list_feed(self): - url = '/lists//crawlable' + url = "/lists//crawlable" self.assert_request_calls( - url, self.controller.crawlable_list_feed, '' + url, self.controller.crawlable_list_feed, "" ) def test_crawlable_collection_feed(self): - url = '/collections//crawlable' + url = "/collections//crawlable" self.assert_request_calls( - url, self.manager.opds_feeds.crawlable_collection_feed, - '' + url, self.manager.opds_feeds.crawlable_collection_feed, "" ) def test_lane_search(self): - url = '/search' + url = "/search" self.assert_request_calls(url, self.controller.search, None) - url = '/search/' - self.assert_request_calls( - url, self.controller.search, "" - ) + url = "/search/" + self.assert_request_calls(url, self.controller.search, "") def test_qa_feed(self): - url = '/feed/qa' + url = "/feed/qa" self.assert_authenticated_request_calls(url, self.controller.qa_feed) def test_qa_series_feed(self): - url = '/feed/qa/series' + url = "/feed/qa/series" self.assert_authenticated_request_calls(url, self.controller.qa_series_feed) class TestMARCRecord(RouteTest): - CONTROLLER_NAME = 'marc_records' + CONTROLLER_NAME = "marc_records" def test_marc_page(self): url = "/marc" self.assert_request_calls(url, self.controller.download_page) + class TestSharedCollection(RouteTest): - CONTROLLER_NAME = 'shared_collection_controller' + CONTROLLER_NAME = "shared_collection_controller" def test_shared_collection_info(self): - url = '/collections/' - self.assert_request_calls( - url, self.controller.info, '' - ) + url = "/collections/" + self.assert_request_calls(url, self.controller.info, "") def test_shared_collection_register(self): - url = '/collections//register' + url = "/collections//register" self.assert_request_calls( - url, self.controller.register, '', - http_method='POST' + url, self.controller.register, "", http_method="POST" ) - self.assert_supported_methods(url, 'POST') + self.assert_supported_methods(url, "POST") def test_shared_collection_borrow_identifier(self): - url = '/collections////borrow' + url = "/collections////borrow" self.assert_request_calls_method_using_identifier( - url, self.controller.borrow, '', - '', "", None + url, + self.controller.borrow, + "", + "", + "", + None, ) - self.assert_supported_methods(url, 'GET', 'POST') + self.assert_supported_methods(url, "GET", "POST") def test_shared_collection_borrow_hold_id(self): - url = '/collections//holds//borrow' + url = "/collections//holds//borrow" self.assert_request_calls( - url, self.controller.borrow, '', None, None, - '' + url, self.controller.borrow, "", None, None, "" ) - self.assert_supported_methods(url, 'GET', 'POST') + self.assert_supported_methods(url, "GET", "POST") def test_shared_collection_loan_info(self): - url = '/collections//loans/' + url = "/collections//loans/" self.assert_request_calls( - url, self.controller.loan_info, '', '' + url, self.controller.loan_info, "", "" ) def test_shared_collection_revoke_loan(self): - url = '/collections//loans//revoke' + url = "/collections//loans//revoke" self.assert_request_calls( - url, self.controller.revoke_loan, '', '' + url, self.controller.revoke_loan, "", "" ) def test_shared_collection_fulfill_no_mechanism(self): - url = '/collections//loans//fulfill' + url = "/collections//loans//fulfill" self.assert_request_calls( - url, self.controller.fulfill, '', '', - None + url, self.controller.fulfill, "", "", None ) def test_shared_collection_fulfill_with_mechanism(self): - url = '/collections//loans//fulfill/' + url = "/collections//loans//fulfill/" self.assert_request_calls( - url, self.controller.fulfill, '', '', - '' + url, + self.controller.fulfill, + "", + "", + "", ) def test_shared_collection_hold_info(self): - url = '/collections//holds/' + url = "/collections//holds/" self.assert_request_calls( - url, self.controller.hold_info, '', - '' + url, self.controller.hold_info, "", "" ) def test_shared_collection_revoke_hold(self): - url = '/collections//holds//revoke' + url = "/collections//holds//revoke" self.assert_request_calls( - url, self.controller.revoke_hold, '', - '' + url, self.controller.revoke_hold, "", "" ) @@ -462,9 +458,10 @@ class TestProfileController(RouteTest): CONTROLLER_NAME = "profiles" def test_patron_profile(self): - url = '/patrons/me' + url = "/patrons/me" self.assert_authenticated_request_calls( - url, self.controller.protocol, + url, + self.controller.protocol, ) @@ -473,69 +470,81 @@ class TestLoansController(RouteTest): CONTROLLER_NAME = "loans" def test_active_loans(self): - url = '/loans' + url = "/loans" self.assert_authenticated_request_calls( - url, self.controller.sync, + url, + self.controller.sync, ) - self.assert_supported_methods(url, 'GET', 'HEAD') + self.assert_supported_methods(url, "GET", "HEAD") def test_borrow(self): - url = '/works///borrow' + url = "/works///borrow" self.assert_request_calls_method_using_identifier( - url, self.controller.borrow, - "", "", None, - authenticated=True + url, + self.controller.borrow, + "", + "", + None, + authenticated=True, ) - self.assert_supported_methods(url, 'GET', 'PUT') + self.assert_supported_methods(url, "GET", "PUT") - url = '/works///borrow/' + url = "/works///borrow/" self.assert_request_calls_method_using_identifier( - url, self.controller.borrow, - "", "", "", - authenticated=True + url, + self.controller.borrow, + "", + "", + "", + authenticated=True, ) - self.assert_supported_methods(url, 'GET', 'PUT') + self.assert_supported_methods(url, "GET", "PUT") def test_fulfill(self): # fulfill does *not* require authentication, because this # controller is how a no-authentication library fulfills # open-access titles. - url = '/works//fulfill' + url = "/works//fulfill" self.assert_request_calls( url, self.controller.fulfill, "", None, None ) - url = '/works//fulfill/' + url = "/works//fulfill/" self.assert_request_calls( - url, self.controller.fulfill, "", - "", None + url, self.controller.fulfill, "", "", None ) - url = '/works//fulfill//' + url = "/works//fulfill//" self.assert_request_calls( - url, self.controller.fulfill, "", - "", "" + url, + self.controller.fulfill, + "", + "", + "", ) def test_revoke_loan_or_hold(self): - url = '/loans//revoke' + url = "/loans//revoke" self.assert_authenticated_request_calls( - url, self.controller.revoke, '' + url, self.controller.revoke, "" ) # TODO: DELETE shouldn't be in here, but "DELETE # /loans//revoke" is interpreted as an attempt # to match /loans//, the # method tested directly below, which does support DELETE. - self.assert_supported_methods(url, 'GET', 'PUT', 'DELETE') + self.assert_supported_methods(url, "GET", "PUT", "DELETE") def test_loan_or_hold_detail(self): - url = '/loans//' + url = "/loans//" self.assert_request_calls_method_using_identifier( - url, self.controller.detail, - "", "", authenticated=True + url, + self.controller.detail, + "", + "", + authenticated=True, ) - self.assert_supported_methods(url, 'GET', 'DELETE') + self.assert_supported_methods(url, "GET", "DELETE") class TestAnnotationsController(RouteTest): @@ -543,27 +552,27 @@ class TestAnnotationsController(RouteTest): CONTROLLER_NAME = "annotations" def test_annotations(self): - url = '/annotations/' - self.assert_authenticated_request_calls( - url, self.controller.container - ) - self.assert_supported_methods(url, 'HEAD', 'GET', 'POST') + url = "/annotations/" + self.assert_authenticated_request_calls(url, self.controller.container) + self.assert_supported_methods(url, "HEAD", "GET", "POST") def test_annotation_detail(self): - url = '/annotations/' + url = "/annotations/" self.assert_authenticated_request_calls( - url, self.controller.detail, '' + url, self.controller.detail, "" ) - self.assert_supported_methods(url, 'HEAD', 'GET', 'DELETE') + self.assert_supported_methods(url, "HEAD", "GET", "DELETE") def test_annotations_for_work(self): - url = '/annotations//' + url = "/annotations//" self.assert_request_calls_method_using_identifier( - url, self.controller.container_for_work, - '', "", - authenticated=True + url, + self.controller.container_for_work, + "", + "", + authenticated=True, ) - self.assert_supported_methods(url, 'GET') + self.assert_supported_methods(url, "GET") class TestURNLookupController(RouteTest): @@ -571,8 +580,8 @@ class TestURNLookupController(RouteTest): CONTROLLER_NAME = "urn_lookup" def test_work(self): - url = '/works' - self.assert_request_calls(url, self.controller.work_lookup, 'work') + url = "/works" + self.assert_request_calls(url, self.controller.work_lookup, "work") class TestWorkController(RouteTest): @@ -580,86 +589,90 @@ class TestWorkController(RouteTest): CONTROLLER_NAME = "work_controller" def test_contributor(self): - url = '/works/contributor/' + url = "/works/contributor/" self.assert_request_calls( url, self.controller.contributor, "", None, None ) def test_contributor_language(self): - url = '/works/contributor//' + url = "/works/contributor//" self.assert_request_calls( - url, self.controller.contributor, - "", "", None + url, self.controller.contributor, "", "", None ) def test_contributor_language_audience(self): - url = '/works/contributor///' + url = "/works/contributor///" self.assert_request_calls( - url, self.controller.contributor, - "", "", "" + url, + self.controller.contributor, + "", + "", + "", ) def test_series(self): - url = '/works/series/' + url = "/works/series/" self.assert_request_calls( url, self.controller.series, "", None, None ) def test_series_language(self): - url = '/works/series//' + url = "/works/series//" self.assert_request_calls( url, self.controller.series, "", "", None ) def test_series_language_audience(self): - url = '/works/series///' + url = "/works/series///" self.assert_request_calls( - url, self.controller.series, "", "", - "" + url, self.controller.series, "", "", "" ) def test_permalink(self): - url = '/works//' + url = "/works//" self.assert_request_calls_method_using_identifier( - url, self.controller.permalink, - "", "" + url, self.controller.permalink, "", "" ) def test_recommendations(self): - url = '/works///recommendations' + url = "/works///recommendations" self.assert_request_calls_method_using_identifier( - url, self.controller.recommendations, - "", "" + url, self.controller.recommendations, "", "" ) def test_related_books(self): - url = '/works///related_books' + url = "/works///related_books" self.assert_request_calls_method_using_identifier( url, self.controller.related, "", "" ) def test_report(self): - url = '/works///report' + url = "/works///report" self.assert_request_calls_method_using_identifier( - url, self.controller.report, - "", "", + url, + self.controller.report, + "", + "", ) - self.assert_supported_methods(url, 'GET', 'POST') + self.assert_supported_methods(url, "GET", "POST") class TestAnalyticsController(RouteTest): CONTROLLER_NAME = "analytics_controller" def test_track_analytics_event(self): - url = '/analytics///' + url = "/analytics///" # This controller can be called either authenticated or # unauthenticated. self.assert_request_calls_method_using_identifier( - url, self.controller.track_event, - "", "", "", + url, + self.controller.track_event, + "", + "", + "", authenticated=True, - authentication_required=False + authentication_required=False, ) @@ -668,31 +681,33 @@ class TestAdobeVendorID(RouteTest): CONTROLLER_NAME = "adobe_vendor_id" def test_adobe_vendor_id_get_token(self): - url = '/AdobeAuth/authdata' + url = "/AdobeAuth/authdata" self.assert_authenticated_request_calls( - url, self.controller.create_authdata_handler, - self.controller.AUTHENTICATED_PATRON + url, + self.controller.create_authdata_handler, + self.controller.AUTHENTICATED_PATRON, ) # TODO: test what happens when vendor ID is not configured. def test_adobe_vendor_id_signin(self): - url = '/AdobeAuth/SignIn' + url = "/AdobeAuth/SignIn" self.assert_request_calls( - url, self.controller.signin_handler, http_method='POST' + url, self.controller.signin_handler, http_method="POST" ) - self.assert_supported_methods(url, 'POST') + self.assert_supported_methods(url, "POST") def test_adobe_vendor_id_accountinfo(self): - url = '/AdobeAuth/AccountInfo' + url = "/AdobeAuth/AccountInfo" self.assert_request_calls( - url, self.controller.userinfo_handler, http_method='POST' + url, self.controller.userinfo_handler, http_method="POST" ) - self.assert_supported_methods(url, 'POST') + self.assert_supported_methods(url, "POST") def test_adobe_vendor_id_status(self): - url = '/AdobeAuth/Status' + url = "/AdobeAuth/Status" self.assert_request_calls( - url, self.controller.status_handler, + url, + self.controller.status_handler, ) @@ -700,19 +715,18 @@ class TestAdobeDeviceManagement(RouteTest): CONTROLLER_NAME = "adobe_device_management" def test_adobe_drm_devices(self): - url = '/AdobeAuth/devices' + url = "/AdobeAuth/devices" self.assert_authenticated_request_calls( url, self.controller.device_id_list_handler ) - self.assert_supported_methods(url, 'GET', 'POST') + self.assert_supported_methods(url, "GET", "POST") def test_adobe_drm_device(self): - url = '/AdobeAuth/devices/' + url = "/AdobeAuth/devices/" self.assert_authenticated_request_calls( - url, self.controller.device_id_handler, "", - http_method='DELETE' + url, self.controller.device_id_handler, "", http_method="DELETE" ) - self.assert_supported_methods(url, 'DELETE') + self.assert_supported_methods(url, "DELETE") class TestOAuthController(RouteTest): @@ -722,14 +736,14 @@ class TestOAuthController(RouteTest): CONTROLLER_NAME = "oauth_controller" def test_oauth_authenticate(self): - url = '/oauth_authenticate' + url = "/oauth_authenticate" _db = self.manager._db self.assert_request_calls( url, self.controller.oauth_authentication_redirect, {}, _db ) def test_oauth_callback(self): - url = '/oauth_callback' + url = "/oauth_callback" _db = self.manager._db self.assert_request_calls( url, self.controller.oauth_authentication_callback, _db, {} @@ -740,18 +754,16 @@ class TestODLNotificationController(RouteTest): CONTROLLER_NAME = "odl_notification_controller" def test_odl_notify(self): - url = '/odl_notify/' - self.assert_request_calls( - url, self.controller.notify, "" - ) - self.assert_supported_methods(url, 'GET', 'POST') + url = "/odl_notify/" + self.assert_request_calls(url, self.controller.notify, "") + self.assert_supported_methods(url, "GET", "POST") class TestHeartbeatController(RouteTest): CONTROLLER_NAME = "heartbeat" def test_heartbeat(self): - url = '/heartbeat' + url = "/heartbeat" self.assert_request_calls(url, self.controller.heartbeat) @@ -769,7 +781,6 @@ def test_health_check(self): class TestExceptionHandler(RouteTest): - def test_exception_handling(self): # The exception handler deals with most exceptions by running them # through ErrorHandler.handle() @@ -782,6 +793,7 @@ class MockErrorHandler(object): def handle(self, exception): self.handled = exception return Response("handled it", 500) + routes.h = MockErrorHandler() # Simulate a request that causes an unhandled exception. diff --git a/tests/test_scripts.py b/tests/test_scripts.py index b5e9e9f973..0b4e711e1b 100644 --- a/tests/test_scripts.py +++ b/tests/test_scripts.py @@ -1,45 +1,24 @@ -import pytest import contextlib import datetime -import flask import json from io import StringIO +import flask +import pytest + from api.adobe_vendor_id import ( AdobeVendorIDModel, AuthdataUtility, ShortClientTokenLibraryConfigurationScript, ) from api.authenticator import BasicAuthenticationProvider - -from api.config import ( - temp_config, - Configuration, -) - -from api.novelist import ( - NoveListAPI -) - -from core.entrypoint import ( - AudiobooksEntryPoint, - EbooksEntryPoint, - EntryPoint, -) - -from core.external_search import ( - MockExternalSearchIndex, - mock_search_index, -) - -from core.lane import ( - Lane, - Facets, - FeaturedFacets, - Pagination, - WorkList, -) - +from api.config import Configuration, temp_config +from api.marc import LibraryAnnotator as MARCLibraryAnnotator +from api.novelist import NoveListAPI +from core.entrypoint import AudiobooksEntryPoint, EbooksEntryPoint, EntryPoint +from core.external_search import MockExternalSearchIndex, mock_search_index +from core.lane import Facets, FeaturedFacets, Lane, Pagination, WorkList +from core.marc import MARCExporter from core.metadata_layer import ( CirculationData, IdentifierData, @@ -47,63 +26,48 @@ Metadata, ReplacementPolicy, ) - +from core.mirror import MirrorUploader from core.model import ( CachedFeed, CachedMARCFile, ConfigurationSetting, - create, Credential, DataSource, DeliveryMechanism, + EditionConstants, ExternalIntegration, - get_one, Hyperlink, Identifier, LicensePool, Representation, RightsStatus, Timestamp, - EditionConstants) + create, + get_one, +) from core.model.configuration import ExternalIntegrationLink - from core.opds import AcquisitionFeed - from core.s3 import MockS3Uploader - -from core.mirror import MirrorUploader - -from core.marc import MARCExporter from core.scripts import CollectionType - -from core.util.flask_util import ( - Response, - OPDSFeedResponse -) - -from api.marc import LibraryAnnotator as MARCLibraryAnnotator - -from core.testing import ( - DatabaseTest, -) - +from core.testing import DatabaseTest +from core.util.datetime_helpers import datetime_utc, utc_now +from core.util.flask_util import OPDSFeedResponse, Response from scripts import ( AdobeAccountIDResetScript, - CacheRepresentationPerLane, CacheFacetListsPerLane, - CacheOPDSGroupFeedPerLane, CacheMARCFiles, + CacheOPDSGroupFeedPerLane, + CacheRepresentationPerLane, DirectoryImportScript, GenerateShortTokenScript, InstanceInitializationScript, LanguageListScript, - NovelistSnapshotScript, LocalAnalyticsExportScript, + NovelistSnapshotScript, ) -from core.util.datetime_helpers import datetime_utc, utc_now -class TestAdobeAccountIDResetScript(DatabaseTest): +class TestAdobeAccountIDResetScript(DatabaseTest): def test_process_patron(self): patron = self._patron() @@ -119,14 +83,14 @@ def set_value(credential): # Create two Credentials that will be deleted and one that will be # left alone. - for type in (AdobeVendorIDModel.VENDOR_ID_UUID_TOKEN_TYPE, - AuthdataUtility.ADOBE_ACCOUNT_ID_PATRON_IDENTIFIER, - "Some other type" + for type in ( + AdobeVendorIDModel.VENDOR_ID_UUID_TOKEN_TYPE, + AuthdataUtility.ADOBE_ACCOUNT_ID_PATRON_IDENTIFIER, + "Some other type", ): credential = Credential.lookup( - self._db, data_source, type, patron, - set_value, True + self._db, data_source, type, patron, set_value, True ) assert 3 == len(patron.credentials) @@ -151,59 +115,62 @@ def set_value(credential): class TestLaneScript(DatabaseTest): - def setup_method(self): super(TestLaneScript, self).setup_method() base_url_setting = ConfigurationSetting.sitewide( - self._db, Configuration.BASE_URL_KEY) - base_url_setting.value = 'http://test-circulation-manager/' + self._db, Configuration.BASE_URL_KEY + ) + base_url_setting.value = "http://test-circulation-manager/" for k, v in [ - (Configuration.LARGE_COLLECTION_LANGUAGES, []), - (Configuration.SMALL_COLLECTION_LANGUAGES, []), - (Configuration.TINY_COLLECTION_LANGUAGES, ['eng', 'fre']) + (Configuration.LARGE_COLLECTION_LANGUAGES, []), + (Configuration.SMALL_COLLECTION_LANGUAGES, []), + (Configuration.TINY_COLLECTION_LANGUAGES, ["eng", "fre"]), ]: ConfigurationSetting.for_library( - k, self._default_library).value = json.dumps(v) + k, self._default_library + ).value = json.dumps(v) class TestCacheRepresentationPerLane(TestLaneScript): - def test_should_process_lane(self): # Test that should_process_lane respects any specified # language restrictions. script = CacheRepresentationPerLane( - self._db, ["--language=fre", "--language=English", "--language=none", "--min-depth=0"], - manager=object() + self._db, + [ + "--language=fre", + "--language=English", + "--language=none", + "--min-depth=0", + ], + manager=object(), ) - assert ['fre', 'eng'] == script.languages + assert ["fre", "eng"] == script.languages - english_lane = self._lane(languages=['eng']) + english_lane = self._lane(languages=["eng"]) assert True == script.should_process_lane(english_lane) - no_english_lane = self._lane(languages=['spa','fre']) + no_english_lane = self._lane(languages=["spa", "fre"]) assert True == script.should_process_lane(no_english_lane) - no_english_or_french_lane = self._lane(languages=['spa']) + no_english_or_french_lane = self._lane(languages=["spa"]) assert False == script.should_process_lane(no_english_or_french_lane) # Test that should_process_lane respects maximum depth # restrictions. script = CacheRepresentationPerLane( - self._db, ["--max-depth=0", "--min-depth=0"], - manager=object() + self._db, ["--max-depth=0", "--min-depth=0"], manager=object() ) assert 0 == script.max_depth child = self._lane(display_name="sublane") parent = self._lane(display_name="parent") - parent.sublanes=[child] + parent.sublanes = [child] assert True == script.should_process_lane(parent) assert False == script.should_process_lane(child) - script = CacheRepresentationPerLane( - self._db, ["--min-depth=1"], testing=True - ) + script = CacheRepresentationPerLane(self._db, ["--min-depth=1"], testing=True) assert 1 == script.min_depth assert False == script.should_process_lane(parent) assert True == script.should_process_lane(child) @@ -213,7 +180,6 @@ def test_process_lane(self): # combination of items yielded by facets() and pagination(). class MockFacets(object): - def __init__(self, query): self.query = query @@ -228,6 +194,7 @@ def query_string(self): class Mock(CacheRepresentationPerLane): generated = [] + def do_generate(self, lane, facets, pagination): value = (lane, facets, pagination) response = Response("mock response") @@ -256,50 +223,39 @@ def pagination(self, lane): def test_default_facets(self): # By default, do_generate will only be called once, with facets=None. - script = CacheRepresentationPerLane( - self._db, manager=object(), cmd_args=[] - ) + script = CacheRepresentationPerLane(self._db, manager=object(), cmd_args=[]) assert [None] == list(script.facets(object())) def test_default_pagination(self): # By default, do_generate will only be called once, with pagination=None. - script = CacheRepresentationPerLane( - self._db, manager=object(), cmd_args=[] - ) + script = CacheRepresentationPerLane(self._db, manager=object(), cmd_args=[]) assert [None] == list(script.pagination(object())) class TestCacheFacetListsPerLane(TestLaneScript): - def test_arguments(self): # Verify that command-line arguments become attributes of # the CacheFacetListsPerLane object. script = CacheFacetListsPerLane( - self._db, ["--order=title", "--order=added"], - manager=object() + self._db, ["--order=title", "--order=added"], manager=object() ) - assert ['title', 'added'] == script.orders + assert ["title", "added"] == script.orders script = CacheFacetListsPerLane( - self._db, ["--availability=all", "--availability=always"], - manager=object() + self._db, ["--availability=all", "--availability=always"], manager=object() ) - assert ['all', 'always'] == script.availabilities + assert ["all", "always"] == script.availabilities script = CacheFacetListsPerLane( - self._db, ["--collection=main", "--collection=full"], - manager=object() + self._db, ["--collection=main", "--collection=full"], manager=object() ) - assert ['main', 'full'] == script.collections + assert ["main", "full"] == script.collections script = CacheFacetListsPerLane( - self._db, ["--entrypoint=Audio", "--entrypoint=Book"], - manager=object() + self._db, ["--entrypoint=Audio", "--entrypoint=Book"], manager=object() ) - assert ['Audio', 'Book'] == script.entrypoints + assert ["Audio", "Book"] == script.entrypoints - script = CacheFacetListsPerLane( - self._db, ['--pages=1'], manager=object() - ) + script = CacheFacetListsPerLane(self._db, ["--pages=1"], manager=object()) assert 1 == script.pages def test_facets(self): @@ -308,8 +264,9 @@ def test_facets(self): script = CacheFacetListsPerLane(self._db, manager=object(), cmd_args=[]) script.orders = [Facets.ORDER_TITLE, Facets.ORDER_AUTHOR, "nonsense"] script.entrypoints = [ - AudiobooksEntryPoint.INTERNAL_NAME, "nonsense", - EbooksEntryPoint.INTERNAL_NAME + AudiobooksEntryPoint.INTERNAL_NAME, + "nonsense", + EbooksEntryPoint.INTERNAL_NAME, ] script.availabilities = [Facets.AVAILABLE_NOW, "nonsense"] script.collections = [Facets.COLLECTION_FULL, "nonsense"] @@ -362,6 +319,7 @@ def test_do_generate(self): # is called with the right arguments. class MockAcquisitionFeed(object): called_with = None + @classmethod def page(cls, **kwargs): cls.called_with = kwargs @@ -380,45 +338,42 @@ def page(cls, **kwargs): assert "here's your feed" == result args = MockAcquisitionFeed.called_with - assert self._db == args['_db'] - assert lane == args['worklist'] - assert lane.display_name == args['title'] - assert 0 == args['max_age'] + assert self._db == args["_db"] + assert lane == args["worklist"] + assert lane.display_name == args["title"] + assert 0 == args["max_age"] # The Pagination object was passed into # MockAcquisitionFeed.page, and it was also used to make the # feed URL (see below). - assert pagination == args['pagination'] + assert pagination == args["pagination"] # The Facets object was passed into # MockAcquisitionFeed.page, and it was also used to make # the feed URL and to create the feed annotator. - assert facets == args['facets'] - annotator = args['annotator'] + assert facets == args["facets"] + annotator = args["annotator"] assert facets == annotator.facets - assert ( - args['url'] == - annotator.feed_url(lane, facets=facets, pagination=pagination)) + assert args["url"] == annotator.feed_url( + lane, facets=facets, pagination=pagination + ) # Try again without mocking AcquisitionFeed, to verify that # we get a Flask Response containing an OPDS feed. response = script.do_generate(lane, facets, pagination) assert isinstance(response, OPDSFeedResponse) assert AcquisitionFeed.ACQUISITION_FEED_TYPE == response.content_type - assert response.get_data(as_text=True).startswith(' last_week - class TestInstanceInitializationScript(DatabaseTest): - def test_run(self): # If the database has been initialized -- which happened @@ -707,6 +667,7 @@ def test_run(self): class Mock(InstanceInitializationScript): def do_run(self): raise Exception("I'll never be called.") + Mock().run() # If the database has not been initialized, run() will detect @@ -718,6 +679,7 @@ def do_run(self): # and do_run() will be called. class Mock(InstanceInitializationScript): TEST_SQL = "select * from nosuchtable" + def do_run(self, *args, **kwargs): self.was_run = True @@ -725,21 +687,23 @@ def do_run(self, *args, **kwargs): script.run() assert True == script.was_run - def test_do_run(self): # Normally, do_run is only called by run() if the database has # not yet meen initialized. But we can test it by calling it # directly. timestamp = get_one( - self._db, Timestamp, service="Database Migration", - service_type=Timestamp.SCRIPT_TYPE + self._db, + Timestamp, + service="Database Migration", + service_type=Timestamp.SCRIPT_TYPE, ) assert None == timestamp # Remove all secret keys, should they exist, before running the # script. secret_keys = self._db.query(ConfigurationSetting).filter( - ConfigurationSetting.key==Configuration.SECRET_KEY) + ConfigurationSetting.key == Configuration.SECRET_KEY + ) [self._db.delete(secret_key) for secret_key in secret_keys] script = InstanceInitializationScript(_db=self._db) @@ -747,23 +711,24 @@ def test_do_run(self): # It initializes the database. timestamp = get_one( - self._db, Timestamp, service="Database Migration", - service_type=Timestamp.SCRIPT_TYPE + self._db, + Timestamp, + service="Database Migration", + service_type=Timestamp.SCRIPT_TYPE, ) assert timestamp # It creates a secret key. assert 1 == secret_keys.count() - assert ( - secret_keys.one().value == - ConfigurationSetting.sitewide_secret(self._db, Configuration.SECRET_KEY)) + assert secret_keys.one().value == ConfigurationSetting.sitewide_secret( + self._db, Configuration.SECRET_KEY + ) class TestLanguageListScript(DatabaseTest): - def test_languages(self): """Test the method that gives this script the bulk of its output.""" - english = self._work(language='eng', with_open_access_download=True) + english = self._work(language="eng", with_open_access_download=True) tagalog = self._work(language="tgl", with_license_pool=True) [pool] = tagalog.license_pools self._add_generic_delivery_mechanism(pool) @@ -776,60 +741,59 @@ def test_languages(self): class TestShortClientTokenLibraryConfigurationScript(DatabaseTest): - def setup_method(self): super(TestShortClientTokenLibraryConfigurationScript, self).setup_method() - self._default_library.setting( - Configuration.WEBSITE_URL - ).value = "http://foo/" + self._default_library.setting(Configuration.WEBSITE_URL).value = "http://foo/" self.script = ShortClientTokenLibraryConfigurationScript(self._db) def test_identify_library_by_url(self): with pytest.raises(Exception) as excinfo: - self.script.set_secret(self._db, "http://bar/", "vendorid", "libraryname", "secret", None) - assert "Could not locate library with URL http://bar/. Available URLs: http://foo/" in str(excinfo.value) + self.script.set_secret( + self._db, "http://bar/", "vendorid", "libraryname", "secret", None + ) + assert ( + "Could not locate library with URL http://bar/. Available URLs: http://foo/" + in str(excinfo.value) + ) def test_set_secret(self): assert [] == self._default_library.integrations output = StringIO() self.script.set_secret( - self._db, "http://foo/", "vendorid", "libraryname", "secret", - output + self._db, "http://foo/", "vendorid", "libraryname", "secret", output ) assert ( - 'Current Short Client Token configuration for http://foo/:\n Vendor ID: vendorid\n Library name: libraryname\n Shared secret: secret\n' == - output.getvalue()) + "Current Short Client Token configuration for http://foo/:\n Vendor ID: vendorid\n Library name: libraryname\n Shared secret: secret\n" + == output.getvalue() + ) [integration] = self._default_library.integrations - assert ( - [('password', 'secret'), ('username', 'libraryname'), - ('vendor_id', 'vendorid')] == - sorted((x.key, x.value) for x in integration.settings)) + assert [ + ("password", "secret"), + ("username", "libraryname"), + ("vendor_id", "vendorid"), + ] == sorted((x.key, x.value) for x in integration.settings) # We can modify an existing configuration. output = StringIO() self.script.set_secret( - self._db, "http://foo/", "newid", "newname", "newsecret", - output + self._db, "http://foo/", "newid", "newname", "newsecret", output ) - expect = 'Current Short Client Token configuration for http://foo/:\n Vendor ID: newid\n Library name: newname\n Shared secret: newsecret\n' + expect = "Current Short Client Token configuration for http://foo/:\n Vendor ID: newid\n Library name: newname\n Shared secret: newsecret\n" assert expect == output.getvalue() expect_settings = [ - ('password', 'newsecret'), ('username', 'newname'), - ('vendor_id', 'newid') + ("password", "newsecret"), + ("username", "newname"), + ("vendor_id", "newid"), ] - assert (expect_settings == - sorted((x.key, x.value) for x in integration.settings)) + assert expect_settings == sorted((x.key, x.value) for x in integration.settings) # We can also just check on the existing configuration without # changing anything. output = StringIO() - self.script.set_secret( - self._db, "http://foo/", None, None, None, output - ) + self.script.set_secret(self._db, "http://foo/", None, None, None, output) assert expect == output.getvalue() - assert (expect_settings == - sorted((x.key, x.value) for x in integration.settings)) + assert expect_settings == sorted((x.key, x.value) for x in integration.settings) class MockDirectoryImportScript(DirectoryImportScript): @@ -842,13 +806,10 @@ def __init__(self, _db, mock_filesystem={}): def _locate_file(self, identifier, directory, extensions, file_type): self._locate_file_args = (identifier, directory, extensions, file_type) - return self.mock_filesystem.get( - directory, (None, None, None) - ) + return self.mock_filesystem.get(directory, (None, None, None)) class TestDirectoryImportScript(DatabaseTest): - def test_do_run(self): # Calling do_run with command-line arguments parses the # arguments and calls run_with_arguments. @@ -868,23 +829,21 @@ def run_with_arguments(self, *args, **kwargs): "--ebook-directory=ebooks", "--rights-uri=rights", "--dry-run", - "--default-medium-type={0}".format(EditionConstants.AUDIO_MEDIUM) + "--default-medium-type={0}".format(EditionConstants.AUDIO_MEDIUM), ] ) - assert ( - { - 'collection_name': 'coll1', - 'collection_type': CollectionType.OPEN_ACCESS, - 'data_source_name': 'ds1', - 'metadata_file': 'metadata', - 'metadata_format': 'marc', - 'cover_directory': 'covers', - 'ebook_directory': 'ebooks', - 'rights_uri': 'rights', - 'dry_run': True, - 'default_medium_type': EditionConstants.AUDIO_MEDIUM - } == - script.ran_with) + assert { + "collection_name": "coll1", + "collection_type": CollectionType.OPEN_ACCESS, + "data_source_name": "ds1", + "metadata_file": "metadata", + "metadata_format": "marc", + "cover_directory": "covers", + "ebook_directory": "ebooks", + "rights_uri": "rights", + "dry_run": True, + "default_medium_type": EditionConstants.AUDIO_MEDIUM, + } == script.ran_with def test_run_with_arguments(self): @@ -897,6 +856,7 @@ def test_run_with_arguments(self): class Mock(DirectoryImportScript): """Mock the methods called by run_with_arguments.""" + def __init__(self, _db): super(DirectoryImportScript, self).__init__(_db) self.load_collection_calls = [] @@ -919,7 +879,7 @@ def work_from_metadata(self, *args): # Make a change to a model object so we can track when the # session is committed. - self._default_collection.name = 'changed' + self._default_collection.name = "changed" script = Mock(self._db) basic_args = [ @@ -930,22 +890,28 @@ def work_from_metadata(self, *args): "marc", "cover directory", "ebook directory", - "rights URI" + "rights URI", ] - script.run_with_arguments(*(basic_args + [True] + [EditionConstants.BOOK_MEDIUM])) + script.run_with_arguments( + *(basic_args + [True] + [EditionConstants.BOOK_MEDIUM]) + ) # load_collection was called with the collection and data source names. - assert ( - [('collection name', CollectionType.OPEN_ACCESS, 'data source name')] == - script.load_collection_calls) + assert [ + ("collection name", CollectionType.OPEN_ACCESS, "data source name") + ] == script.load_collection_calls # load_metadata was called with the metadata file and data source name. - assert [('metadata file', 'marc', 'data source name', EditionConstants.BOOK_MEDIUM)] == script.load_metadata_calls + assert [ + ("metadata file", "marc", "data source name", EditionConstants.BOOK_MEDIUM) + ] == script.load_metadata_calls # work_from_metadata was called twice, once on each metadata # object. - [(coll1, t1, o1, policy1, c1, e1, r1), - (coll2, t2, o2, policy2, c2, e2, r2)] = script.work_from_metadata_calls + [ + (coll1, t1, o1, policy1, c1, e1, r1), + (coll2, t2, o2, policy2, c2, e2, r2), + ] = script.work_from_metadata_calls assert coll1 == self._default_collection assert coll1 == coll2 @@ -953,10 +919,10 @@ def work_from_metadata(self, *args): assert o1 == metadata1 assert o2 == metadata2 - assert c1 == 'cover directory' + assert c1 == "cover directory" assert c1 == c2 - assert e1 == 'ebook directory' + assert e1 == "ebook directory" assert e1 == e2 assert "rights URI" == r1 @@ -977,8 +943,10 @@ def work_from_metadata(self, *args): # This time, the ReplacementPolicy has a mirror set # appropriately. - [(coll1, t1, o1, policy1, c1, e1, r1), - (coll1, t2, o2, policy2, c2, e2, r2)] = script.work_from_metadata_calls + [ + (coll1, t1, o1, policy1, c1, e1, r1), + (coll1, t2, o2, policy2, c2, e2, r2), + ] = script.work_from_metadata_calls for policy in policy1, policy2: assert mirrors == policy.mirrors @@ -990,7 +958,8 @@ def test_load_collection_setting_mirrors(self): # Calling load_collection does not create a new collection. script = DirectoryImportScript(self._db) collection, mirrors = script.load_collection( - "New collection", CollectionType.OPEN_ACCESS, "data source name") + "New collection", CollectionType.OPEN_ACCESS, "data source name" + ) assert None == collection assert None == mirrors @@ -999,7 +968,8 @@ def test_load_collection_setting_mirrors(self): ) collection, mirrors = script.load_collection( - "some collection", CollectionType.OPEN_ACCESS, "data source name") + "some collection", CollectionType.OPEN_ACCESS, "data source name" + ) # No covers or books mirrors were created beforehand for this collection # so nothing is returned. @@ -1008,36 +978,44 @@ def test_load_collection_setting_mirrors(self): # Both mirrors need to set up or else nothing is returned. storage1 = self._external_integration( - ExternalIntegration.S3, ExternalIntegration.STORAGE_GOAL, - username="name", password="password" + ExternalIntegration.S3, + ExternalIntegration.STORAGE_GOAL, + username="name", + password="password", ) external_integration_link = self._external_integration_link( integration=existing_collection.external_integration, other_integration=storage1, - purpose=ExternalIntegrationLink.COVERS + purpose=ExternalIntegrationLink.COVERS, ) collection, mirrors = script.load_collection( - "some collection", CollectionType.OPEN_ACCESS, "data source name") + "some collection", CollectionType.OPEN_ACCESS, "data source name" + ) assert None == collection assert None == mirrors # Create another storage and assign it for the books mirror storage2 = self._external_integration( - ExternalIntegration.S3, ExternalIntegration.STORAGE_GOAL, - username="name", password="password" + ExternalIntegration.S3, + ExternalIntegration.STORAGE_GOAL, + username="name", + password="password", ) external_integration_link = self._external_integration_link( integration=existing_collection.external_integration, other_integration=storage2, - purpose=ExternalIntegrationLink.OPEN_ACCESS_BOOKS + purpose=ExternalIntegrationLink.OPEN_ACCESS_BOOKS, ) collection, mirrors = script.load_collection( - "some collection", CollectionType.OPEN_ACCESS, "data source name") + "some collection", CollectionType.OPEN_ACCESS, "data source name" + ) assert collection == existing_collection assert isinstance(mirrors[ExternalIntegrationLink.COVERS], MirrorUploader) - assert isinstance(mirrors[ExternalIntegrationLink.OPEN_ACCESS_BOOKS], MirrorUploader) + assert isinstance( + mirrors[ExternalIntegrationLink.OPEN_ACCESS_BOOKS], MirrorUploader + ) def test_work_from_metadata(self): # Validate the ability to create a new Work from appropriate metadata. @@ -1046,6 +1024,7 @@ class Mock(MockDirectoryImportScript): """In this test we need to verify that annotate_metadata was called but did nothing. """ + def annotate_metadata(self, collection_type, metadata, *args, **kwargs): metadata.annotated = True return super(Mock, self).annotate_metadata( @@ -1055,14 +1034,12 @@ def annotate_metadata(self, collection_type, metadata, *args, **kwargs): identifier = IdentifierData(Identifier.GUTENBERG_ID, "1003") identifier_obj, ignore = identifier.load(self._db) metadata = Metadata( - DataSource.GUTENBERG, - primary_identifier=identifier, - title="A book" + DataSource.GUTENBERG, primary_identifier=identifier, title="A book" ) metadata.annotated = False datasource = DataSource.lookup(self._db, DataSource.GUTENBERG) policy = ReplacementPolicy.from_license_source(self._db) - mirrors = dict(books_mirror=MockS3Uploader(),covers_mirror=MockS3Uploader()) + mirrors = dict(books_mirror=MockS3Uploader(), covers_mirror=MockS3Uploader()) mirror_type_books = ExternalIntegrationLink.OPEN_ACCESS_BOOKS mirror_type_covers = ExternalIntegrationLink.COVERS policy.mirrors = mirrors @@ -1072,27 +1049,31 @@ def annotate_metadata(self, collection_type, metadata, *args, **kwargs): # disk' and thus no way to actually get the book. collection = self._default_collection collection_type = CollectionType.OPEN_ACCESS - shared_args = (collection_type, metadata, policy, - "cover directory", "ebook directory", RightsStatus.CC0) + shared_args = ( + collection_type, + metadata, + policy, + "cover directory", + "ebook directory", + RightsStatus.CC0, + ) # args = (collection, *shared_args) script = Mock(self._db) assert None == script.work_from_metadata(collection, *shared_args) assert True == metadata.annotated # Now let's try it with some files 'on disk'. - with open(self.sample_cover_path('test-book-cover.png'), "rb") as fh: + with open(self.sample_cover_path("test-book-cover.png"), "rb") as fh: image = fh.read() mock_filesystem = { - 'cover directory' : ( - 'cover.jpg', Representation.JPEG_MEDIA_TYPE, image + "cover directory": ("cover.jpg", Representation.JPEG_MEDIA_TYPE, image), + "ebook directory": ( + "book.epub", + Representation.EPUB_MEDIA_TYPE, + "I'm an EPUB.", ), - 'ebook directory' : ( - 'book.epub', Representation.EPUB_MEDIA_TYPE, "I'm an EPUB." - ) } - script = MockDirectoryImportScript( - self._db, mock_filesystem=mock_filesystem - ) + script = MockDirectoryImportScript(self._db, mock_filesystem=mock_filesystem) work, licensepool_for_work = script.work_from_metadata(collection, *shared_args) # Get the edition that was created for this book. It should have @@ -1104,21 +1085,25 @@ def annotate_metadata(self, collection_type, metadata, *args, **kwargs): # thumbnail. assert "A book" == work.title assert ( - work.cover_full_url == - 'https://test-cover-bucket.s3.amazonaws.com/Gutenberg/Gutenberg%20ID/1003/1003.jpg') + work.cover_full_url + == "https://test-cover-bucket.s3.amazonaws.com/Gutenberg/Gutenberg%20ID/1003/1003.jpg" + ) assert ( - work.cover_thumbnail_url == - 'https://test-cover-bucket.s3.amazonaws.com/scaled/300/Gutenberg/Gutenberg%20ID/1003/1003.png') + work.cover_thumbnail_url + == "https://test-cover-bucket.s3.amazonaws.com/scaled/300/Gutenberg/Gutenberg%20ID/1003/1003.png" + ) assert 1 == len(work.license_pools) assert 1 == len(edition.license_pools) - assert 1 == len([lp for lp in edition.license_pools if lp.collection == collection]) + assert 1 == len( + [lp for lp in edition.license_pools if lp.collection == collection] + ) [pool] = work.license_pools assert licensepool_for_work == pool assert ( - pool.open_access_download_url == - 'https://test-content-bucket.s3.amazonaws.com/Gutenberg/Gutenberg%20ID/1003/A%20book.epub') - assert (RightsStatus.CC0 == - pool.delivery_mechanisms[0].rights_status.uri) + pool.open_access_download_url + == "https://test-content-bucket.s3.amazonaws.com/Gutenberg/Gutenberg%20ID/1003/A%20book.epub" + ) + assert RightsStatus.CC0 == pool.delivery_mechanisms[0].rights_status.uri # The two mock S3Uploaders have records of 'uploading' all these files # to S3. The "books" mirror has the epubs and the "covers" mirror @@ -1138,8 +1123,10 @@ def annotate_metadata(self, collection_type, metadata, *args, **kwargs): # the same metadata. # Even though there will be two license pools associated with the # work's presentation edition, the call should be successful. - collection2 = self._collection('second collection') - work2, licensepool_for_work2 = script.work_from_metadata(collection2, *shared_args) + collection2 = self._collection("second collection") + work2, licensepool_for_work2 = script.work_from_metadata( + collection2, *shared_args + ) # The presentation edition should be the same for both works. edition2 = work2.presentation_edition @@ -1153,7 +1140,9 @@ def annotate_metadata(self, collection_type, metadata, *args, **kwargs): # one for each collection. assert 2 == len(work2.license_pools) assert 2 == len(edition2.license_pools) - assert 1 == len([lp for lp in edition2.license_pools if lp.collection == collection2]) + assert 1 == len( + [lp for lp in edition2.license_pools if lp.collection == collection2] + ) def test_annotate_metadata(self): """Verify that annotate_metadata calls load_circulation_data @@ -1165,6 +1154,7 @@ class MockNoCirculationData(DirectoryImportScript): """Do nothing when load_circulation_data is called. Explode if load_cover_link is called. """ + def load_circulation_data(self, *args): self.load_circulation_data_args = args return None @@ -1177,9 +1167,7 @@ def load_cover_link(self, *args): identifier = IdentifierData(Identifier.GUTENBERG_ID, "11111") identifier_obj, ignore = identifier.load(self._db) metadata = Metadata( - title=self._str, - data_source=gutenberg, - primary_identifier=identifier + title=self._str, data_source=gutenberg, primary_identifier=identifier ) mirrors = object() policy = ReplacementPolicy(mirrors=mirrors) @@ -1194,22 +1182,20 @@ def load_cover_link(self, *args): policy, cover_directory, ebook_directory, - rights_uri + rights_uri, ) script.annotate_metadata(*args) # load_circulation_data was called. assert ( - ( - collection_type, - identifier_obj, - gutenberg, - ebook_directory, - mirrors, - metadata.title, - rights_uri - ) == - script.load_circulation_data_args) + collection_type, + identifier_obj, + gutenberg, + ebook_directory, + mirrors, + metadata.title, + rights_uri, + ) == script.load_circulation_data_args # But because load_circulation_data returned None, # metadata.circulation_data was not modified and @@ -1222,6 +1208,7 @@ class MockNoCoverLink(DirectoryImportScript): """Return an object when load_circulation_data is called. Do nothing when load_cover_link is called. """ + def load_circulation_data(self, *args): return "Some circulation data" @@ -1238,8 +1225,11 @@ def load_cover_link(self, *args): # load_cover_link was called. assert ( - (identifier_obj, gutenberg, cover_directory, mirrors) == - script.load_cover_link_args) + identifier_obj, + gutenberg, + cover_directory, + mirrors, + ) == script.load_cover_link_args # But since it provided no cover link, metadata.links was empty. assert [] == metadata.links @@ -1249,6 +1239,7 @@ class MockWithCoverLink(DirectoryImportScript): """Mock success for both load_circulation_data and load_cover_link. """ + def load_circulation_data(self, *args): return "Some circulation data" @@ -1260,7 +1251,7 @@ def load_cover_link(self, *args): script.annotate_metadata(*args) assert "Some circulation data" == metadata.circulation - assert ['A cover link'] == metadata.links + assert ["A cover link"] == metadata.links def test_load_circulation_data(self): # Create a directory import script with an empty mock filesystem. @@ -1268,7 +1259,7 @@ def test_load_circulation_data(self): identifier = self._identifier(Identifier.GUTENBERG_ID, "2345") gutenberg = DataSource.lookup(self._db, DataSource.GUTENBERG) - mirrors = dict(books_mirror=MockS3Uploader(),covers_mirror=None) + mirrors = dict(books_mirror=MockS3Uploader(), covers_mirror=None) args = ( CollectionType.OPEN_ACCESS, identifier, @@ -1276,7 +1267,7 @@ def test_load_circulation_data(self): "ebooks", mirrors, "Name of book", - "rights URI" + "rights URI", ) # There is nothing on the mock filesystem, so in this case @@ -1285,15 +1276,15 @@ def test_load_circulation_data(self): # But we tried. assert ( - ('2345', 'ebooks', Representation.COMMON_EBOOK_EXTENSIONS, - 'ebook file') == - script._locate_file_args) + "2345", + "ebooks", + Representation.COMMON_EBOOK_EXTENSIONS, + "ebook file", + ) == script._locate_file_args # Try another script that has a populated mock filesystem. mock_filesystem = { - 'ebooks' : ( - 'book.epub', Representation.EPUB_MEDIA_TYPE, "I'm an EPUB." - ) + "ebooks": ("book.epub", Representation.EPUB_MEDIA_TYPE, "I'm an EPUB.") } script = MockDirectoryImportScript(self._db, mock_filesystem) @@ -1309,8 +1300,9 @@ def test_load_circulation_data(self): [link] = circulation.links assert Hyperlink.OPEN_ACCESS_DOWNLOAD == link.rel assert ( - link.href == - 'https://test-content-bucket.s3.amazonaws.com/Gutenberg/Gutenberg%20ID/2345/Name%20of%20book.epub') + link.href + == "https://test-content-bucket.s3.amazonaws.com/Gutenberg/Gutenberg%20ID/2345/Name%20of%20book.epub" + ) assert Representation.EPUB_MEDIA_TYPE == link.media_type assert "I'm an EPUB." == link.content @@ -1327,7 +1319,7 @@ def test_load_cover_link(self): identifier = self._identifier(Identifier.GUTENBERG_ID, "2345") gutenberg = DataSource.lookup(self._db, DataSource.GUTENBERG) - mirrors = dict(covers_mirror=MockS3Uploader(),books_mirror=None) + mirrors = dict(covers_mirror=MockS3Uploader(), books_mirror=None) args = (identifier, gutenberg, "covers", mirrors) # There is nothing on the mock filesystem, so in this case @@ -1336,22 +1328,23 @@ def test_load_cover_link(self): # But we tried. assert ( - ('2345', 'covers', Representation.COMMON_IMAGE_EXTENSIONS, - 'cover image') == - script._locate_file_args) + "2345", + "covers", + Representation.COMMON_IMAGE_EXTENSIONS, + "cover image", + ) == script._locate_file_args # Try another script that has a populated mock filesystem. mock_filesystem = { - 'covers' : ( - 'acover.jpeg', Representation.JPEG_MEDIA_TYPE, "I'm an image." - ) + "covers": ("acover.jpeg", Representation.JPEG_MEDIA_TYPE, "I'm an image.") } script = MockDirectoryImportScript(self._db, mock_filesystem) link = script.load_cover_link(*args) assert Hyperlink.IMAGE == link.rel assert ( - link.href == - 'https://test-cover-bucket.s3.amazonaws.com/Gutenberg/Gutenberg%20ID/2345/2345.jpg') + link.href + == "https://test-cover-bucket.s3.amazonaws.com/Gutenberg/Gutenberg%20ID/2345/2345.jpg" + ) assert Representation.JPEG_MEDIA_TYPE == link.media_type assert "I'm an image." == link.content @@ -1360,15 +1353,15 @@ def test_locate_file(self): to find files on a mock filesystem. """ # Create a mock filesystem with a single file. - mock_filesystem = { - "directory/thefile.JPEG" : "The contents" - } + mock_filesystem = {"directory/thefile.JPEG": "The contents"} + def mock_exists(path): return path in mock_filesystem @contextlib.contextmanager def mock_open(path, mode="r"): yield StringIO(mock_filesystem[path]) + mock_filesystem_operations = mock_exists, mock_open def assert_not_found(base_filename, directory, extensions): @@ -1376,8 +1369,11 @@ def assert_not_found(base_filename, directory, extensions): _locate_file() does not find anything. """ result = DirectoryImportScript._locate_file( - base_filename, directory, extensions, file_type="some file", - mock_filesystem_operations=mock_filesystem_operations + base_filename, + directory, + extensions, + file_type="some file", + mock_filesystem_operations=mock_filesystem_operations, ) assert (None, None, None) == result @@ -1386,36 +1382,39 @@ def assert_found(base_filename, directory, extensions): finds and loads the single file on the mock filesystem.. """ result = DirectoryImportScript._locate_file( - base_filename, directory, extensions, file_type="some file", - mock_filesystem_operations=mock_filesystem_operations + base_filename, + directory, + extensions, + file_type="some file", + mock_filesystem_operations=mock_filesystem_operations, ) assert ( - ("thefile.JPEG", Representation.JPEG_MEDIA_TYPE, - "The contents") == - result) + "thefile.JPEG", + Representation.JPEG_MEDIA_TYPE, + "The contents", + ) == result # As long as the file and directory match we have some flexibility # regarding the extensions we look for. - assert_found('thefile', 'directory', ['.jpeg']) - assert_found('thefile', 'directory', ['.JPEG']) - assert_found('thefile', 'directory', ['jpeg']) - assert_found('thefile', 'directory', ['JPEG']) - assert_found('thefile', 'directory', ['.another-extension', '.jpeg']) + assert_found("thefile", "directory", [".jpeg"]) + assert_found("thefile", "directory", [".JPEG"]) + assert_found("thefile", "directory", ["jpeg"]) + assert_found("thefile", "directory", ["JPEG"]) + assert_found("thefile", "directory", [".another-extension", ".jpeg"]) # But file, directory, and (flexible) extension must all match. - assert_not_found('anotherfile', 'directory', ['.jpeg']) - assert_not_found('thefile', 'another_directory', ['.jpeg']) - assert_not_found('thefile', 'directory', ['.another-extension']) - assert_not_found('thefile', 'directory', []) + assert_not_found("anotherfile", "directory", [".jpeg"]) + assert_not_found("thefile", "another_directory", [".jpeg"]) + assert_not_found("thefile", "directory", [".another-extension"]) + assert_not_found("thefile", "directory", []) -class TestNovelistSnapshotScript(DatabaseTest): +class TestNovelistSnapshotScript(DatabaseTest): def mockNoveListAPI(self, *args, **kwargs): self.called_with = (args, kwargs) def test_do_run(self): - """Test that NovelistSnapshotScript.do_run() calls the NoveList api. - """ + """Test that NovelistSnapshotScript.do_run() calls the NoveList api.""" class MockNovelistSnapshotScript(NovelistSnapshotScript): pass @@ -1434,27 +1433,24 @@ class MockNovelistSnapshotScript(NovelistSnapshotScript): NoveListAPI.from_config = oldNovelistConfig -class TestLocalAnalyticsExportScript(DatabaseTest): +class TestLocalAnalyticsExportScript(DatabaseTest): def test_do_run(self): - class MockLocalAnalyticsExporter(object): def export(self, _db, start, end): self.called_with = [start, end] return "test" output = StringIO() - cmd_args = ['--start=20190820', '--end=20190827'] + cmd_args = ["--start=20190820", "--end=20190827"] exporter = MockLocalAnalyticsExporter() script = LocalAnalyticsExportScript() - script.do_run( - output=output, cmd_args=cmd_args, - exporter=exporter) + script.do_run(output=output, cmd_args=cmd_args, exporter=exporter) assert "test" == output.getvalue() - assert ['20190820', '20190827'] == exporter.called_with + assert ["20190820", "20190827"] == exporter.called_with -class TestGenerateShortTokenScript(DatabaseTest): +class TestGenerateShortTokenScript(DatabaseTest): @pytest.fixture def script(self): return GenerateShortTokenScript() @@ -1477,21 +1473,22 @@ def authdata(self, monkeypatch): @pytest.fixture def patron(self, authdata): - patron = self._patron(external_identifier='test') - patron.authorization_identifier = 'test' + patron = self._patron(external_identifier="test") + patron.authorization_identifier = "test" adobe_credential = self._credential( data_source_name=DataSource.INTERNAL_PROCESSING, patron=patron, - type=authdata.ADOBE_ACCOUNT_ID_PATRON_IDENTIFIER) - adobe_credential.credential = '1234567' + type=authdata.ADOBE_ACCOUNT_ID_PATRON_IDENTIFIER, + ) + adobe_credential.credential = "1234567" return patron @pytest.fixture def authentication_provider(self): - barcode = '12345' - pin = 'abcd' + barcode = "12345" + pin = "abcd" integration = self._external_integration( - 'api.simple_authentication', goal=ExternalIntegration.PATRON_AUTH_GOAL + "api.simple_authentication", goal=ExternalIntegration.PATRON_AUTH_GOAL ) self._default_library.integrations.append(integration) integration.setting(BasicAuthenticationProvider.TEST_IDENTIFIER).value = barcode @@ -1500,74 +1497,88 @@ def authentication_provider(self): def test_run_days(self, script, output, authdata, patron): # Test with --days - cmd_args = ['--barcode={}'.format(patron.authorization_identifier), '--days=2', self._default_library.short_name] - script.do_run( - _db=self._db, - output=output, cmd_args=cmd_args, - authdata=authdata) - assert output.getvalue().split('\n') == [ - 'Vendor ID: The Vendor ID', - 'Token: YOU|1620345600|1234567|ZP45vhpfs3fHREvFkDDVgDAmhoD699elFD3PGaZu7yo@', - 'Username: YOU|1620345600|1234567', - 'Password: ZP45vhpfs3fHREvFkDDVgDAmhoD699elFD3PGaZu7yo@', - '' + cmd_args = [ + "--barcode={}".format(patron.authorization_identifier), + "--days=2", + self._default_library.short_name, + ] + script.do_run(_db=self._db, output=output, cmd_args=cmd_args, authdata=authdata) + assert output.getvalue().split("\n") == [ + "Vendor ID: The Vendor ID", + "Token: YOU|1620345600|1234567|ZP45vhpfs3fHREvFkDDVgDAmhoD699elFD3PGaZu7yo@", + "Username: YOU|1620345600|1234567", + "Password: ZP45vhpfs3fHREvFkDDVgDAmhoD699elFD3PGaZu7yo@", + "", ] def test_run_minutes(self, script, output, authdata, patron): # Test with --minutes - cmd_args = ['--barcode={}'.format(patron.authorization_identifier), '--minutes=20', self._default_library.short_name] - script.do_run( - _db=self._db, - output=output, cmd_args=cmd_args, - authdata=authdata) - assert output.getvalue().split('\n')[2] == 'Username: YOU|1620174000|1234567' + cmd_args = [ + "--barcode={}".format(patron.authorization_identifier), + "--minutes=20", + self._default_library.short_name, + ] + script.do_run(_db=self._db, output=output, cmd_args=cmd_args, authdata=authdata) + assert output.getvalue().split("\n")[2] == "Username: YOU|1620174000|1234567" def test_run_hours(self, script, output, authdata, patron): # Test with --hours - cmd_args = ['--barcode={}'.format(patron.authorization_identifier), '--hours=4', self._default_library.short_name] - script.do_run( - _db=self._db, - output=output, cmd_args=cmd_args, - authdata=authdata) - assert output.getvalue().split('\n')[2] == 'Username: YOU|1620187200|1234567' + cmd_args = [ + "--barcode={}".format(patron.authorization_identifier), + "--hours=4", + self._default_library.short_name, + ] + script.do_run(_db=self._db, output=output, cmd_args=cmd_args, authdata=authdata) + assert output.getvalue().split("\n")[2] == "Username: YOU|1620187200|1234567" def test_no_registry(self, script, output, patron): - cmd_args = ['--barcode={}'.format(patron.authorization_identifier), '--minutes=20', self._default_library.short_name] + cmd_args = [ + "--barcode={}".format(patron.authorization_identifier), + "--minutes=20", + self._default_library.short_name, + ] with pytest.raises(SystemExit) as pytest_exit: - script.do_run( - _db=self._db, - output=output, cmd_args=cmd_args) + script.do_run(_db=self._db, output=output, cmd_args=cmd_args) assert pytest_exit.value.code == -1 assert "Library not registered with library registry" in output.getvalue() def test_no_patron_auth_method(self, script, output): # Test running when the patron does not exist - cmd_args = ['--barcode={}'.format('1234567'), '--hours=4', self._default_library.short_name] + cmd_args = [ + "--barcode={}".format("1234567"), + "--hours=4", + self._default_library.short_name, + ] with pytest.raises(SystemExit) as pytest_exit: - script.do_run( - _db=self._db, - output=output, cmd_args=cmd_args) + script.do_run(_db=self._db, output=output, cmd_args=cmd_args) assert pytest_exit.value.code == -1 assert "No methods to authenticate patron found" in output.getvalue() def test_patron_auth(self, script, output, authdata, authentication_provider): barcode, pin = authentication_provider # Test running when the patron does not exist - cmd_args = ['--barcode={}'.format(barcode), '--pin={}'.format(pin), '--hours=4', self._default_library.short_name] - script.do_run( - _db=self._db, - output=output, cmd_args=cmd_args, - authdata=authdata) + cmd_args = [ + "--barcode={}".format(barcode), + "--pin={}".format(pin), + "--hours=4", + self._default_library.short_name, + ] + script.do_run(_db=self._db, output=output, cmd_args=cmd_args, authdata=authdata) assert "Token: YOU|1620187200" in output.getvalue() - def test_patron_auth_no_patron(self, script, output, authdata, authentication_provider): - barcode = 'nonexistent' + def test_patron_auth_no_patron( + self, script, output, authdata, authentication_provider + ): + barcode = "nonexistent" # Test running when the patron does not exist - cmd_args = ['--barcode={}'.format(barcode), '--hours=4', self._default_library.short_name] + cmd_args = [ + "--barcode={}".format(barcode), + "--hours=4", + self._default_library.short_name, + ] with pytest.raises(SystemExit) as pytest_exit: script.do_run( - _db=self._db, - output=output, cmd_args=cmd_args, - authdata=authdata) + _db=self._db, output=output, cmd_args=cmd_args, authdata=authdata + ) assert pytest_exit.value.code == -1 assert "Patron not found" in output.getvalue() diff --git a/tests/test_selftest.py b/tests/test_selftest.py index 4bd18d1fc3..afdd91cc58 100644 --- a/tests/test_selftest.py +++ b/tests/test_selftest.py @@ -2,27 +2,21 @@ import datetime from io import StringIO -from core.testing import DatabaseTest -from core.model import ( - ExternalIntegration, -) -from core.opds_import import ( - OPDSImportMonitor, -) from api.authenticator import BasicAuthenticationProvider from api.circulation import CirculationAPI +from api.feedbooks import FeedbooksImportMonitor from api.selftest import ( HasCollectionSelfTests, HasSelfTests, RunSelfTestsScript, SelfTestResult, ) -from api.feedbooks import ( - FeedbooksImportMonitor, -) +from core.model import ExternalIntegration +from core.opds_import import OPDSImportMonitor +from core.testing import DatabaseTest -class TestHasSelfTests(DatabaseTest): +class TestHasSelfTests(DatabaseTest): def test_default_patrons(self): """Some self-tests must run with a patron's credentials. The default_patrons() method finds the default Patron for every @@ -36,11 +30,13 @@ def test_default_patrons(self): [result] = h.default_patrons(not_in_library) assert "Acquiring test patron credentials." == result.name assert False == result.success - assert ("Collection is not associated with any libraries." == - str(result.exception)) + assert "Collection is not associated with any libraries." == str( + result.exception + ) assert ( - "Add the collection to a library that has a patron authentication service." == - result.exception.debug_message) + "Add the collection to a library that has a patron authentication service." + == result.exception.debug_message + ) # This collection is in two libraries. collection = self._default_collection @@ -51,8 +47,9 @@ def test_default_patrons(self): # This library has a default patorn set up. integration = self._external_integration( - "api.simple_authentication", ExternalIntegration.PATRON_AUTH_GOAL, - libraries=[self._default_library] + "api.simple_authentication", + ExternalIntegration.PATRON_AUTH_GOAL, + libraries=[self._default_library], ) p = BasicAuthenticationProvider integration.setting(p.TEST_IDENTIFIER).value = "username1" @@ -71,10 +68,14 @@ def test_default_patrons(self): # a test patron. assert False == failure.success assert ( - "Acquiring test patron credentials for library %s" % no_default_patron.name == - failure.name) + "Acquiring test patron credentials for library %s" % no_default_patron.name + == failure.name + ) assert "Library has no test patron configured." == str(failure.exception) - assert "You can specify a test patron when you configure the library's patron authentication service." == failure.exception.debug_message + assert ( + "You can specify a test patron when you configure the library's patron authentication service." + == failure.exception.debug_message + ) # The test patron for the library that has one was looked up, # and the test can proceed using this patron. @@ -85,7 +86,6 @@ def test_default_patrons(self): class TestRunSelfTestsScript(DatabaseTest): - def test_do_run(self): library1 = self._default_library library2 = self._library(name="library2") @@ -96,6 +96,7 @@ class MockParsed(object): class MockScript(RunSelfTestsScript): tested = [] + def parse_command_line(self, *args, **kwargs): parsed = MockParsed() parsed.libraries = [library1, library2] @@ -107,8 +108,10 @@ def test_collection(self, collection, api_map): script = MockScript(self._db, out) script.do_run() # Both libraries were tested. - assert (out.getvalue() == - "Testing %s\nTesting %s\n" % (library1.name, library2.name)) + assert out.getvalue() == "Testing %s\nTesting %s\n" % ( + library1.name, + library2.name, + ) # The default library is the only one with a collection; # test_collection() was called on that collection. @@ -133,15 +136,20 @@ def test_collection(self, collection, api_map): class MockScript2(MockScript): def test_collection(self, collection, api_map): raise Exception("blah") + out = StringIO() script = MockScript2(self._db, out) script.do_run() - assert (out.getvalue() == - "Testing %s\n Exception while running self-test: 'blah'\nTesting %s\n" % (library1.name, library2.name)) + assert ( + out.getvalue() + == "Testing %s\n Exception while running self-test: 'blah'\nTesting %s\n" + % (library1.name, library2.name) + ) def test_test_collection(self): class MockScript(RunSelfTestsScript): processed = [] + def process_result(self, result): self.processed.append(result) @@ -152,8 +160,10 @@ def process_result(self, result): out = StringIO() script = MockScript(self._db, out) script.test_collection(collection, api_map={}) - assert (out.getvalue() == - ' Cannot find a self-test for %s, ignoring.\n' % collection.name) + assert ( + out.getvalue() + == " Cannot find a self-test for %s, ignoring.\n" % collection.name + ) # If the api_map does map the colelction's protocol to a # HasSelfTests class, the class's run_self_tests class method @@ -170,15 +180,19 @@ def run_self_tests(cls, _db, constructor_method, *constructor_args): script = MockScript(self._db, out) protocol = self._default_collection.protocol script.test_collection( - collection, api_map={protocol:MockHasSelfTests}, - extra_args={MockHasSelfTests:["an extra arg"]} + collection, + api_map={protocol: MockHasSelfTests}, + extra_args={MockHasSelfTests: ["an extra arg"]}, ) # run_self_tests() was called with the correct arguments, # including the extra one. assert (self._db, None) == MockHasSelfTests.run_self_tests_called_with - assert ((self._db, collection, "an extra arg") == - MockHasSelfTests.run_self_tests_constructor_args) + assert ( + self._db, + collection, + "an extra arg", + ) == MockHasSelfTests.run_self_tests_constructor_args # Each result was run through process_result(). assert ["result 1", "result 2"] == script.processed @@ -193,8 +207,7 @@ def test_process_result(self): out = StringIO() script = RunSelfTestsScript(self._db, out) script.process_result(success) - assert (out.getvalue() == - ' SUCCESS i succeeded (1.5sec)\n Result: a result\n') + assert out.getvalue() == " SUCCESS i succeeded (1.5sec)\n Result: a result\n" # Test a failed test that raised an exception. failure = SelfTestResult("i failed") @@ -203,12 +216,10 @@ def test_process_result(self): out = StringIO() script = RunSelfTestsScript(self._db, out) script.process_result(failure) - assert (out.getvalue() == - " FAILURE i failed (0.0sec)\n Exception: 'bah'\n") + assert out.getvalue() == " FAILURE i failed (0.0sec)\n Exception: 'bah'\n" class TestHasCollectionSelfTests(DatabaseTest): - def test__run_self_tests(self): # Verify that _run_self_tests calls all the test methods # we want it to. @@ -230,8 +241,10 @@ def test__no_delivery_mechanisms_test(self): # There's one LicensePool, and it has a delivery mechanism, # so a string is returned. pool = self._licensepool(None) + class Mock(HasCollectionSelfTests): collection = self._default_collection + hastests = Mock() result = hastests._no_delivery_mechanisms_test() success = "All titles in this collection have delivery mechanisms." @@ -243,8 +256,7 @@ class Mock(HasCollectionSelfTests): # Now a list of strings is returned, one for each problematic # book. [result] = hastests._no_delivery_mechanisms_test() - assert ("[title unknown] (ID: %s)" % pool.identifier.identifier == - result) + assert "[title unknown] (ID: %s)" % pool.identifier.identifier == result # Change the LicensePool so it has no owned licenses. # Now the book is no longer considered problematic, diff --git a/tests/test_shared_collection.py b/tests/test_shared_collection.py index b56fb9fd94..97e5a3bc0e 100644 --- a/tests/test_shared_collection.py +++ b/tests/test_shared_collection.py @@ -1,16 +1,16 @@ -import pytest +import base64 import json + import flask -from Crypto.PublicKey import RSA +import pytest from Crypto.Cipher import PKCS1_OAEP +from Crypto.PublicKey import RSA +from api.circulation import FulfillmentInfo from api.circulation_exceptions import * -from api.shared_collection import ( - SharedCollectionAPI, - BaseSharedCollectionAPI, -) -from core.config import CannotLoadConfiguration from api.odl import ODLAPI +from api.shared_collection import BaseSharedCollectionAPI, SharedCollectionAPI +from core.config import CannotLoadConfiguration from core.model import ( ConfigurationSetting, Hold, @@ -19,11 +19,8 @@ create, get_one, ) -import base64 -from api.circulation import FulfillmentInfo +from core.testing import DatabaseTest, MockRequestsResponse -from core.testing import DatabaseTest -from core.testing import MockRequestsResponse class MockAPI(BaseSharedCollectionAPI): def __init__(self, _db, collection): @@ -47,19 +44,18 @@ def fulfill_for_external_library(self, client, loan, mechanism): def release_hold_from_external_library(self, client, hold): self.released_holds.append((client, hold)) -class TestSharedCollectionAPI(DatabaseTest): +class TestSharedCollectionAPI(DatabaseTest): def setup_method(self): super(TestSharedCollectionAPI, self).setup_method() self.collection = self._collection(protocol="Mock") self.shared_collection = SharedCollectionAPI( - self._db, api_map = { - "Mock" : MockAPI - } + self._db, api_map={"Mock": MockAPI} ) self.api = self.shared_collection.api(self.collection) ConfigurationSetting.for_externalintegration( - BaseSharedCollectionAPI.EXTERNAL_LIBRARY_URLS, self.collection.external_integration + BaseSharedCollectionAPI.EXTERNAL_LIBRARY_URLS, + self.collection.external_integration, ).value = json.dumps(["http://library.org"]) self.client, ignore = IntegrationClient.register(self._db, "http://library.org") edition, self.pool = self._edition( @@ -72,10 +68,8 @@ class MisconfiguredAPI(object): def __init__(self, _db, collection): raise CannotLoadConfiguration("doomed!") - api_map = { self._default_collection.protocol: MisconfiguredAPI } - shared_collection = SharedCollectionAPI( - self._db, api_map=api_map - ) + api_map = {self._default_collection.protocol: MisconfiguredAPI} + shared_collection = SharedCollectionAPI(self._db, api_map=api_map) # Although the SharedCollectionAPI was created, it has no functioning # APIs. assert {} == shared_collection.api_for_collection @@ -105,109 +99,219 @@ def test_api_for_collection(self): def test_register(self): # An auth document URL is required to register. - pytest.raises(InvalidInputException, self.shared_collection.register, - self.collection, None) + pytest.raises( + InvalidInputException, + self.shared_collection.register, + self.collection, + None, + ) # If the url doesn't return a valid auth document, there's an exception. auth_response = "not json" + def do_get(*args, **kwargs): return MockRequestsResponse(200, content=auth_response) - pytest.raises(RemoteInitiatedServerError, self.shared_collection.register, - self.collection, "http://library.org/auth", do_get=do_get) + + pytest.raises( + RemoteInitiatedServerError, + self.shared_collection.register, + self.collection, + "http://library.org/auth", + do_get=do_get, + ) # The auth document also must have a link to the library's catalog. auth_response = json.dumps({"links": []}) - pytest.raises(RemoteInitiatedServerError, self.shared_collection.register, - self.collection, "http://library.org/auth", do_get=do_get) + pytest.raises( + RemoteInitiatedServerError, + self.shared_collection.register, + self.collection, + "http://library.org/auth", + do_get=do_get, + ) # If no external library URLs are configured, no one can register. - auth_response = json.dumps({"links": [{"href": "http://library.org", "rel": "start"}]}) + auth_response = json.dumps( + {"links": [{"href": "http://library.org", "rel": "start"}]} + ) ConfigurationSetting.for_externalintegration( - BaseSharedCollectionAPI.EXTERNAL_LIBRARY_URLS, self.collection.external_integration + BaseSharedCollectionAPI.EXTERNAL_LIBRARY_URLS, + self.collection.external_integration, ).value = None - pytest.raises(AuthorizationFailedException, self.shared_collection.register, - self.collection, "http://library.org/auth", do_get=do_get) + pytest.raises( + AuthorizationFailedException, + self.shared_collection.register, + self.collection, + "http://library.org/auth", + do_get=do_get, + ) # If the library's URL isn't in the configuration, it can't register. - auth_response = json.dumps({"links": [{"href": "http://differentlibrary.org", "rel": "start"}]}) + auth_response = json.dumps( + {"links": [{"href": "http://differentlibrary.org", "rel": "start"}]} + ) ConfigurationSetting.for_externalintegration( - BaseSharedCollectionAPI.EXTERNAL_LIBRARY_URLS, self.collection.external_integration + BaseSharedCollectionAPI.EXTERNAL_LIBRARY_URLS, + self.collection.external_integration, ).value = json.dumps(["http://library.org"]) - pytest.raises(AuthorizationFailedException, self.shared_collection.register, - self.collection, "http://differentlibrary.org/auth", do_get=do_get) + pytest.raises( + AuthorizationFailedException, + self.shared_collection.register, + self.collection, + "http://differentlibrary.org/auth", + do_get=do_get, + ) # Or if the public key is missing from the auth document. - auth_response = json.dumps({"links": [{"href": "http://library.org", "rel": "start"}]}) - pytest.raises(RemoteInitiatedServerError, self.shared_collection.register, - self.collection, "http://library.org/auth", do_get=do_get) - - auth_response = json.dumps({"public_key": { "type": "not RSA", "value": "123" }, - "links": [{"href": "http://library.org", "rel": "start"}]}) - pytest.raises(RemoteInitiatedServerError, self.shared_collection.register, - self.collection, "http://library.org/auth", do_get=do_get) + auth_response = json.dumps( + {"links": [{"href": "http://library.org", "rel": "start"}]} + ) + pytest.raises( + RemoteInitiatedServerError, + self.shared_collection.register, + self.collection, + "http://library.org/auth", + do_get=do_get, + ) - auth_response = json.dumps({"public_key": { "type": "RSA" }, - "links": [{"href": "http://library.org", "rel": "start"}]}) - pytest.raises(RemoteInitiatedServerError, self.shared_collection.register, - self.collection, "http://library.org/auth", do_get=do_get) + auth_response = json.dumps( + { + "public_key": {"type": "not RSA", "value": "123"}, + "links": [{"href": "http://library.org", "rel": "start"}], + } + ) + pytest.raises( + RemoteInitiatedServerError, + self.shared_collection.register, + self.collection, + "http://library.org/auth", + do_get=do_get, + ) + auth_response = json.dumps( + { + "public_key": {"type": "RSA"}, + "links": [{"href": "http://library.org", "rel": "start"}], + } + ) + pytest.raises( + RemoteInitiatedServerError, + self.shared_collection.register, + self.collection, + "http://library.org/auth", + do_get=do_get, + ) # Here's an auth document with a valid key. key = RSA.generate(2048) public_key = key.publickey().exportKey().decode("utf-8") encryptor = PKCS1_OAEP.new(key) - auth_response = json.dumps({"public_key": { "type": "RSA", "value": public_key }, - "links": [{"href": "http://library.org", "rel": "start"}]}) - response = self.shared_collection.register(self.collection, "http://library.org/auth", do_get=do_get) + auth_response = json.dumps( + { + "public_key": {"type": "RSA", "value": public_key}, + "links": [{"href": "http://library.org", "rel": "start"}], + } + ) + response = self.shared_collection.register( + self.collection, "http://library.org/auth", do_get=do_get + ) # An IntegrationClient has been created. - client = get_one(self._db, IntegrationClient, url=IntegrationClient.normalize_url("http://library.org/")) - decrypted_secret = encryptor.decrypt(base64.b64decode(response.get("metadata", {}).get("shared_secret"))) + client = get_one( + self._db, + IntegrationClient, + url=IntegrationClient.normalize_url("http://library.org/"), + ) + decrypted_secret = encryptor.decrypt( + base64.b64decode(response.get("metadata", {}).get("shared_secret")) + ) assert client.shared_secret == decrypted_secret.decode("utf-8") def test_borrow(self): # This client is registered, but isn't one of the allowed URLs for the collection # (maybe it was registered for a different shared collection). - other_client, ignore = IntegrationClient.register(self._db, "http://other_library.org") + other_client, ignore = IntegrationClient.register( + self._db, "http://other_library.org" + ) # Trying to borrow raises an exception. - pytest.raises(AuthorizationFailedException, self.shared_collection.borrow, - self.collection, other_client, self.pool) + pytest.raises( + AuthorizationFailedException, + self.shared_collection.borrow, + self.collection, + other_client, + self.pool, + ) # A client that's registered with the collection can borrow. self.shared_collection.borrow(self.collection, self.client, self.pool) assert [(self.client, self.pool)] == self.api.checkouts # If the client's checking out an existing hold, the hold must be for that client. - hold, ignore = create(self._db, Hold, integration_client=other_client, license_pool=self.pool) - pytest.raises(CannotLoan, self.shared_collection.borrow, - self.collection, self.client, self.pool, hold=hold) + hold, ignore = create( + self._db, Hold, integration_client=other_client, license_pool=self.pool + ) + pytest.raises( + CannotLoan, + self.shared_collection.borrow, + self.collection, + self.client, + self.pool, + hold=hold, + ) hold.integration_client = self.client - self.shared_collection.borrow(self.collection, self.client, self.pool, hold=hold) + self.shared_collection.borrow( + self.collection, self.client, self.pool, hold=hold + ) assert [(self.client, self.pool)] == self.api.checkouts[1:] def test_revoke_loan(self): - other_client, ignore = IntegrationClient.register(self._db, "http://other_library.org") - loan, ignore = create(self._db, Loan, integration_client=other_client, license_pool=self.pool) - pytest.raises(NotCheckedOut, self.shared_collection.revoke_loan, - self.collection, self.client, loan) + other_client, ignore = IntegrationClient.register( + self._db, "http://other_library.org" + ) + loan, ignore = create( + self._db, Loan, integration_client=other_client, license_pool=self.pool + ) + pytest.raises( + NotCheckedOut, + self.shared_collection.revoke_loan, + self.collection, + self.client, + loan, + ) loan.integration_client = self.client self.shared_collection.revoke_loan(self.collection, self.client, loan) assert [(self.client, loan)] == self.api.returns def test_fulfill(self): - other_client, ignore = IntegrationClient.register(self._db, "http://other_library.org") - loan, ignore = create(self._db, Loan, integration_client=other_client, license_pool=self.pool) - pytest.raises(CannotFulfill, self.shared_collection.fulfill, - self.collection, self.client, loan, self.delivery_mechanism) + other_client, ignore = IntegrationClient.register( + self._db, "http://other_library.org" + ) + loan, ignore = create( + self._db, Loan, integration_client=other_client, license_pool=self.pool + ) + pytest.raises( + CannotFulfill, + self.shared_collection.fulfill, + self.collection, + self.client, + loan, + self.delivery_mechanism, + ) loan.integration_client = self.client # If the API does not return content or a content link, the loan can't be fulfilled. - pytest.raises(CannotFulfill, self.shared_collection.fulfill, - self.collection, self.client, loan, self.delivery_mechanism) + pytest.raises( + CannotFulfill, + self.shared_collection.fulfill, + self.collection, + self.client, + loan, + self.delivery_mechanism, + ) assert [(self.client, loan, self.delivery_mechanism)] == self.api.fulfills self.api.fulfillment = FulfillmentInfo( @@ -220,16 +324,27 @@ def test_fulfill(self): None, None, ) - fulfillment = self.shared_collection.fulfill(self.collection, self.client, loan, self.delivery_mechanism) + fulfillment = self.shared_collection.fulfill( + self.collection, self.client, loan, self.delivery_mechanism + ) assert [(self.client, loan, self.delivery_mechanism)] == self.api.fulfills[1:] assert self.delivery_mechanism == loan.fulfillment def test_revoke_hold(self): - other_client, ignore = IntegrationClient.register(self._db, "http://other_library.org") - hold, ignore = create(self._db, Hold, integration_client=other_client, license_pool=self.pool) + other_client, ignore = IntegrationClient.register( + self._db, "http://other_library.org" + ) + hold, ignore = create( + self._db, Hold, integration_client=other_client, license_pool=self.pool + ) - pytest.raises(CannotReleaseHold, self.shared_collection.revoke_hold, - self.collection, self.client, hold) + pytest.raises( + CannotReleaseHold, + self.shared_collection.revoke_hold, + self.collection, + self.client, + hold, + ) hold.integration_client = self.client self.shared_collection.revoke_hold(self.collection, self.client, hold) diff --git a/tests/test_simple_auth.py b/tests/test_simple_auth.py index 65940d22f3..1b6fdf44d0 100644 --- a/tests/test_simple_auth.py +++ b/tests/test_simple_auth.py @@ -1,20 +1,14 @@ -import pytest import json -from api.authenticator import PatronData - -from api.simple_authentication import ( - SimpleAuthenticationProvider, -) - -from api.config import ( - CannotLoadConfiguration, -) +import pytest +from api.authenticator import PatronData +from api.config import CannotLoadConfiguration +from api.simple_authentication import SimpleAuthenticationProvider from core.testing import DatabaseTest -class TestSimpleAuth(DatabaseTest): +class TestSimpleAuth(DatabaseTest): def test_simple(self): p = SimpleAuthenticationProvider integration = self._external_integration(self._str) @@ -61,7 +55,7 @@ def test_no_password_authentication(self): user = provider.remote_authenticate("barcode", None) assert isinstance(user, PatronData) - user2 = provider.remote_authenticate("barcode", '') + user2 = provider.remote_authenticate("barcode", "") assert user2.authorization_identifier == user.authorization_identifier # If you provide any password, you're out. @@ -77,7 +71,9 @@ def test_additional_identifiers(self): integration.setting(p.TEST_IDENTIFIER).value = "barcode" integration.setting(p.TEST_PASSWORD).value = "pass" - integration.setting(p.ADDITIONAL_TEST_IDENTIFIERS).value = json.dumps(["a", "b", "c"]) + integration.setting(p.ADDITIONAL_TEST_IDENTIFIERS).value = json.dumps( + ["a", "b", "c"] + ) provider = p(self._default_library, integration) assert None == provider.remote_authenticate("a", None) @@ -110,20 +106,20 @@ def test_generate_patrondata(self): m = SimpleAuthenticationProvider.generate_patrondata - #Pass in numeric barcode as identifier + # Pass in numeric barcode as identifier result = m("1234") assert result.permanent_id == "1234_id" - assert result.authorization_identifier == '1234' + assert result.authorization_identifier == "1234" assert result.personal_name == "PersonalName1234" - assert result.username == '1234_username' + assert result.username == "1234_username" assert result.neighborhood == None - #Pass in username as identifier + # Pass in username as identifier result = m("1234_username") assert result.permanent_id == "1234_id" - assert result.authorization_identifier == '1234' + assert result.authorization_identifier == "1234" assert result.personal_name == "PersonalName1234" - assert result.username == '1234_username' + assert result.username == "1234_username" assert result.neighborhood == None # Pass in a neighborhood. @@ -141,18 +137,18 @@ def test__remote_patron_lookup(self): patron = self._patron() patron.authorization_identifier = "barcode" - #Returns None if nothing is passed in + # Returns None if nothing is passed in assert provider._remote_patron_lookup(None) == None - #Returns a patron if a patron is passed in and something is found + # Returns a patron if a patron is passed in and something is found result = provider._remote_patron_lookup(patron) assert result.permanent_id == "barcode_id" - #Returns None if no patron is found + # Returns None if no patron is found patron.authorization_identifier = "wrong barcode" result = provider._remote_patron_lookup(patron) assert result == None - #Returns a patron if a PatronData object is passed in and something is found + # Returns a patron if a PatronData object is passed in and something is found result = provider._remote_patron_lookup(patron_data) assert result.permanent_id == "barcode_id" diff --git a/tox.ini b/tox.ini index ac689a4fe4..1db677b79f 100644 --- a/tox.ini +++ b/tox.ini @@ -1,5 +1,5 @@ [tox] -envlist = py{36,37,38,39}-docker +envlist = isort, black, py{36,37,38,39}-docker skipsdist = true [testenv] @@ -21,6 +21,47 @@ allowlist_externals = docker: docker python +[testenv:isort] +description = Run isort (linter) +skip_install = True +basepython = python3.8 +deps = isort==5.10.0 +setenv = ISORT_ARGS=--check-only +commands_pre = +commands = isort {env:ISORT_ARGS:} . + +[testenv:isort-reformat] +description = {[testenv:isort]description} and reformat +skip_install = {[testenv:isort]skip_install} +basepython = {[testenv:isort]basepython} +deps = {[testenv:isort]deps} +commands_pre = {[testenv:isort]commands_pre} +commands = {[testenv:isort]commands} + +[testenv:black] +description = Run Black (linter) +skip_install = True +basepython = python3.8 +deps = black==21.10b0 +commands_pre = +setenv = BLACK_LINT_ARGS=--check +commands = black {env:BLACK_LINT_ARGS:} . + +[testenv:black-reformat] +description = {[testenv:black]description} and reformat +skip_install = {[testenv:black]skip_install} +basepython = {[testenv:black]basepython} +deps = {[testenv:black]deps} +commands_pre = +commands = {[testenv:black]commands} + +[testenv:flake8] +description = Run Flake8 (linter) +skip_install = True +deps = flake8 +commands_pre = +commands = flake8 . + [docker:db-circ] image = postgres:12 environment = @@ -29,6 +70,9 @@ environment = POSTGRES_DB=simplified_circulation_test ports = 9005:5432/tcp +healthcheck_cmd = pg_isready +healthcheck_interval = 5 +healthcheck_retries = 10 [docker:es-circ] image = elasticsearch:6.8.6 @@ -48,3 +92,11 @@ python = timeout_method = thread timeout = 600 testpaths = tests + +[flake8] +max-line-length = 120 +extend-ignore = E203, E501, E711, E712 + +[isort] +profile = black +known_first_party = core,api